diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index e4134dd79d..ff398f95d0 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1550,6 +1550,16 @@ impl Agent { ); break; } + Err(ref provider_err @ ProviderError::NetworkError(_)) => { + crate::posthog::emit_error(provider_err.telemetry_type(), &provider_err.to_string()); + error!("Error: {}", provider_err); + yield AgentEvent::Message( + Message::assistant().with_text( + format!("{provider_err}\n\nPlease resend your message to try again.") + ) + ); + break; + } Err(ref provider_err) => { crate::posthog::emit_error(provider_err.telemetry_type(), &provider_err.to_string()); error!("Error: {}", provider_err); diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 5e6613a57e..2748011178 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -16,6 +16,7 @@ use super::formats::anthropic::{ }; use super::openai_compatible::handle_status_openai_compat; use super::openai_compatible::map_http_error_to_provider_error; +use super::retry::ProviderRetry; use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use crate::model::ModelConfig; @@ -226,19 +227,22 @@ impl Provider for AnthropicProvider { .unwrap() .insert("stream".to_string(), Value::Bool(true)); - let mut request = self.api_client.request(Some(session_id), "v1/messages"); + let conditional_headers = self.get_conditional_headers(); let mut log = RequestLog::start(model_config, &payload)?; - for (key, value) in self.get_conditional_headers() { - request = request.header(key, value)?; - } - - let resp = request.response_post(&payload).await.inspect_err(|e| { - let _ = log.error(e); - })?; - let response = handle_status_openai_compat(resp).await.inspect_err(|e| { - let _ = log.error(e); - })?; + let response = self + .with_retry(|| async { + let mut request = self.api_client.request(Some(session_id), "v1/messages"); + for (key, value) in &conditional_headers { + request = request.header(key, value)?; + } + let resp = request.response_post(&payload).await?; + handle_status_openai_compat(resp).await + }) + .await + .inspect_err(|e| { + let _ = log.error(e); + })?; let stream = response.bytes_stream().map_err(io::Error::other); diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index 71a99bbefa..214a14837d 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -19,6 +19,9 @@ pub enum ProviderError { #[error("Server error: {0}")] ServerError(String), + #[error("Network error: {0}")] + NetworkError(String), + #[error("Request failed: {0}")] RequestFailed(String), @@ -45,6 +48,7 @@ impl ProviderError { ProviderError::ContextLengthExceeded(_) => "context_length", ProviderError::RateLimitExceeded { .. } => "rate_limit", ProviderError::ServerError(_) => "server", + ProviderError::NetworkError(_) => "network", ProviderError::RequestFailed(_) => "request", ProviderError::ExecutionError(_) => "execution", ProviderError::UsageError(_) => "usage", @@ -54,38 +58,51 @@ impl ProviderError { } } +fn is_network_error(err: &reqwest::Error) -> bool { + err.is_connect() || err.is_timeout() || (err.status().is_none() && err.is_request()) +} + +fn provider_error_from_reqwest(error: &reqwest::Error) -> ProviderError { + if is_network_error(error) { + let msg = if error.is_timeout() { + "Request timed out — check your network connection and try again.".to_string() + } else if error.is_connect() { + if let Some(url) = error.url() { + if let Some(host) = url.host_str() { + let port_info = url.port().map(|p| format!(":{}", p)).unwrap_or_default(); + format!( + "Could not connect to {}{} — check your network connection and try again.", + host, port_info + ) + } else { + "Could not connect to the provider — check your network connection and try again.".to_string() + } + } else { + "Could not connect to the provider — check your network connection and try again." + .to_string() + } + } else { + "Network error — check your network connection and try again.".to_string() + }; + return ProviderError::NetworkError(msg); + } + + let mut details = vec![]; + if let Some(status) = error.status() { + details.push(format!("status: {}", status)); + } + let msg = if details.is_empty() { + error.to_string() + } else { + format!("{} ({})", error, details.join(", ")) + }; + ProviderError::RequestFailed(msg) +} + impl From for ProviderError { fn from(error: anyhow::Error) -> Self { if let Some(reqwest_err) = error.downcast_ref::() { - let mut details = vec![]; - - if let Some(status) = reqwest_err.status() { - details.push(format!("status: {}", status)); - } - if reqwest_err.is_timeout() { - details.push("timeout".to_string()); - } - if reqwest_err.is_connect() { - if let Some(url) = reqwest_err.url() { - if let Some(host) = url.host_str() { - let port_info = url.port().map(|p| format!(":{}", p)).unwrap_or_default(); - - details.push(format!("failed to connect to {}{}", host, port_info)); - - if url.port().is_some() { - details.push("check that the port is correct".to_string()); - } - } - } else { - details.push("connection failed".to_string()); - } - } - let msg = if details.is_empty() { - reqwest_err.to_string() - } else { - format!("{} ({})", reqwest_err, details.join(", ")) - }; - return ProviderError::RequestFailed(msg); + return provider_error_from_reqwest(reqwest_err); } ProviderError::ExecutionError(error.to_string()) } @@ -93,7 +110,7 @@ impl From for ProviderError { impl From for ProviderError { fn from(error: reqwest::Error) -> Self { - ProviderError::RequestFailed(error.to_string()) + provider_error_from_reqwest(&error) } } diff --git a/crates/goose/src/providers/retry.rs b/crates/goose/src/providers/retry.rs index ca1a6f7261..4a8a17a92f 100644 --- a/crates/goose/src/providers/retry.rs +++ b/crates/goose/src/providers/retry.rs @@ -76,6 +76,7 @@ pub fn should_retry(error: &ProviderError) -> bool { error, ProviderError::RateLimitExceeded { .. } | ProviderError::ServerError(_) + | ProviderError::NetworkError(_) | ProviderError::RequestFailed(_) ) }