diff --git a/crates/agent/src/db.rs b/crates/agent/src/db.rs index 0ed03ed5170..a34290742ad 100644 --- a/crates/agent/src/db.rs +++ b/crates/agent/src/db.rs @@ -261,7 +261,7 @@ impl DbThread { tool_use_id: tool_result.tool_use_id, tool_name: name.into(), is_error: tool_result.is_error, - content: tool_result.content, + content: vec![tool_result.content], output: tool_result.output, }, ); diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs index c1c2886f84e..7e4f314afd0 100644 --- a/crates/agent/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -1156,7 +1156,7 @@ fn tool_result( tool_use_id: LanguageModelToolUseId::from(id.into()), tool_name: name.into(), is_error: false, - content: LanguageModelToolResultContent::Text(result.into()), + content: vec![LanguageModelToolResultContent::Text(result.into())], output: None, }) } diff --git a/crates/agent/src/tests/edit_file_thread_test.rs b/crates/agent/src/tests/edit_file_thread_test.rs index b5ce6441e79..3efd7753740 100644 --- a/crates/agent/src/tests/edit_file_thread_test.rs +++ b/crates/agent/src/tests/edit_file_thread_test.rs @@ -387,10 +387,7 @@ async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes( "Tool result should succeed, got: {:?}", tool_result ); - let content_text = match &tool_result.content { - language_model::LanguageModelToolResultContent::Text(t) => t.to_string(), - other => panic!("Expected text content, got: {:?}", other), - }; + let content_text = tool_result.text_contents(); assert!( !content_text.contains("file has been modified since you last read it"), "Did not expect a stale last-read error, got: {content_text}" diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index f8d74e0df95..996e753952b 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -494,7 +494,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) { assert_eq!(pending_completion.messages[0].role, Role::System); let system_message = &pending_completion.messages[0]; - let system_prompt = system_message.content[0].to_str().unwrap(); + let MessageContent::Text(system_prompt) = &system_message.content[0] else { + panic!("Expected text content"); + }; assert!( system_prompt.contains("test-shell"), "unexpected system message: {:?}", @@ -530,7 +532,9 @@ async fn test_system_prompt_without_tools(cx: &mut TestAppContext) { assert_eq!(pending_completion.messages[0].role, Role::System); let system_message = &pending_completion.messages[0]; - let system_prompt = system_message.content[0].to_str().unwrap(); + let MessageContent::Text(system_prompt) = &system_message.content[0] else { + panic!("Expected text content"); + }; assert!( !system_prompt.contains("## Tool Use"), "unexpected system message: {:?}", @@ -637,7 +641,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) { tool_use_id: "tool_1".into(), tool_name: EchoTool::NAME.into(), is_error: false, - content: "test".into(), + content: vec!["test".into()], output: Some("test".into()), }; assert_eq!( @@ -866,14 +870,14 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(), tool_name: ToolRequiringPermission::NAME.into(), is_error: false, - content: "Allowed".into(), + content: vec!["Allowed".into()], output: Some("Allowed".into()) }), language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(), tool_name: ToolRequiringPermission::NAME.into(), is_error: true, - content: "Permission to run tool denied by user".into(), + content: vec!["Permission to run tool denied by user".into()], output: Some("Permission to run tool denied by user".into()) }) ] @@ -912,7 +916,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(), tool_name: ToolRequiringPermission::NAME.into(), is_error: false, - content: "Allowed".into(), + content: vec!["Allowed".into()], output: Some("Allowed".into()) } )] @@ -940,7 +944,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { tool_use_id: "tool_id_4".into(), tool_name: ToolRequiringPermission::NAME.into(), is_error: false, - content: "Allowed".into(), + content: vec!["Allowed".into()], output: Some("Allowed".into()) } )] @@ -1562,14 +1566,14 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { tool_use_id: "tool_3".into(), tool_name: "echo".into(), is_error: false, - content: "native".into(), + content: vec!["native".into()], output: Some("native".into()), },), MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: "tool_2".into(), tool_name: "test_server_echo".into(), is_error: false, - content: "mcp".into(), + content: vec!["mcp".into()], output: Some("mcp".into()), },), ] @@ -1578,6 +1582,126 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { events.collect::>().await; } +#[gpui::test] +async fn test_mcp_tool_multi_content_response(cx: &mut TestAppContext) { + let ThreadTest { + model, + thread, + context_server_store, + fs, + .. + } = setup(cx, TestModel::Fake).await; + let fake_model = model.as_fake(); + fake_model.set_supports_images(true); + + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "tool_permissions": { "default": "allow" }, + "profiles": { + "test": { + "name": "Test Profile", + "enable_all_context_servers": true, + "tools": {} + }, + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + cx.run_until_parked(); + thread.update(cx, |thread, cx| { + thread.set_profile(AgentProfileId("test".into()), cx) + }); + + let mut mcp_tool_calls = setup_context_server( + "screenshot_server", + vec![context_server::types::Tool { + name: "screenshot".into(), + description: None, + input_schema: json!({"type": "object", "properties": {}}), + output_schema: None, + annotations: None, + }], + &context_server_store, + cx, + ); + + let events = thread.update(cx, |thread, cx| { + thread + .send(UserMessageId::new(), ["Take a screenshot"], cx) + .unwrap() + }); + cx.run_until_parked(); + + let completion = fake_model.pending_completions().pop().unwrap(); + fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: "tool_1".into(), + name: "screenshot".into(), + raw_input: json!({}).to_string(), + input: json!({}), + is_input_complete: true, + thought_signature: None, + }, + )); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + let _ = completion; + + let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap(); + assert_eq!(tool_call_params.name, "screenshot"); + tool_call_response + .send(context_server::types::CallToolResponse { + content: vec![ + context_server::types::ToolResponseContent::Text { + text: "Some text".into(), + }, + context_server::types::ToolResponseContent::Image { + data: "aGVsbG8=".into(), + mime_type: "image/png".into(), + }, + context_server::types::ToolResponseContent::Text { + text: "Some more text".into(), + }, + ], + is_error: None, + meta: None, + structured_content: None, + }) + .unwrap(); + cx.run_until_parked(); + + // Verify the tool result round-trips back to the model as a multi-part Vec. + let completion = fake_model.pending_completions().pop().unwrap(); + let tool_result = completion + .messages + .last() + .unwrap() + .content + .iter() + .find_map(|c| match c { + MessageContent::ToolResult(r) => Some(r.clone()), + _ => None, + }) + .expect("expected a tool result"); + assert_eq!(tool_result.tool_use_id, "tool_1".into()); + assert_eq!(tool_result.content.len(), 2); + assert_eq!( + tool_result.content[0], + language_model::LanguageModelToolResultContent::Text(Arc::from("Some text")) + ); + assert_eq!( + tool_result.content[1], + language_model::LanguageModelToolResultContent::Text(Arc::from("Some more text")) + ); + fake_model.end_last_completion_stream(); + events.collect::>().await; +} + #[gpui::test] async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) { let ThreadTest { @@ -2106,10 +2230,7 @@ async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext .get(&tool_use.id) .expect("expected tool result"); - let result_text = match &tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), - _ => panic!("expected text content in tool result"), - }; + let result_text = tool_result.text_contents(); // "partial output" comes from FakeTerminalHandle's output field assert!( @@ -2571,10 +2692,7 @@ async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppCon .get(&tool_use.id) .expect("expected tool result"); - let result_text = match &tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), - _ => panic!("expected text content in tool result"), - }; + let result_text = tool_result.text_contents(); assert!( result_text.contains("The user stopped this command"), @@ -2666,10 +2784,7 @@ async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) { .get(&tool_use.id) .expect("expected tool result"); - let result_text = match &tool_result.content { - language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), - _ => panic!("expected text content in tool result"), - }; + let result_text = tool_result.text_contents(); assert!( result_text.contains("timed out"), @@ -3290,7 +3405,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) { tool_use_id: echo_tool_use.id.clone(), tool_name: echo_tool_use.name, is_error: false, - content: "test".into(), + content: vec!["test".into()], output: Some("test".into()) })], cache: false, @@ -3776,7 +3891,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { tool_use_id: tool_use_1.id.clone(), tool_name: tool_use_1.name.clone(), is_error: false, - content: "test".into(), + content: vec!["test".into()], output: Some("test".into()) } )], @@ -3936,8 +4051,10 @@ async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input( tool_use_id: tool_use.id.clone(), tool_name: tool_use.name, is_error: true, - content: "Failed to receive tool input: tool input was not fully received" - .into(), + content: vec![ + "Failed to receive tool input: tool input was not fully received" + .into(), + ], output: Some( "Failed to receive tool input: tool input was not fully received" .into() @@ -4044,10 +4161,7 @@ async fn test_streaming_tool_json_parse_error_is_forwarded_to_running_tool( let result = tool_results[0]; assert!(result.is_error); - let content_text = match &result.content { - language_model::LanguageModelToolResultContent::Text(text) => text.to_string(), - other => panic!("Expected text content, got {:?}", other), - }; + let content_text = result.text_contents(); assert!( content_text.contains("Saw partial text 'partial' before invalid JSON"), "Expected tool-enriched partial context, got: {content_text}" @@ -7069,7 +7183,7 @@ async fn test_streaming_tool_error_breaks_stream_loop_immediately(cx: &mut TestA tool_use_id: tool_use.id.clone(), tool_name: tool_use.name, is_error: true, - content: "failed".into(), + content: vec!["failed".into()], output: Some("failed".into()), } )], @@ -7180,14 +7294,14 @@ async fn test_streaming_tool_error_waits_for_prior_tools_to_complete(cx: &mut Te tool_use_id: second_tool_use.id.clone(), tool_name: second_tool_use.name, is_error: true, - content: "failed".into(), + content: vec!["failed".into()], output: Some("failed".into()), }), language_model::MessageContent::ToolResult(LanguageModelToolResult { tool_use_id: first_tool_use.id.clone(), tool_name: first_tool_use.name, is_error: false, - content: "hello world".into(), + content: vec!["hello world".into()], output: Some("hello world".into()), }), ], diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 89b3b0eb251..2a8a6a5b3cb 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -518,12 +518,14 @@ impl AgentMessage { markdown.push_str("**ERROR:**\n"); } - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - writeln!(markdown, "{text}\n").ok(); - } - LanguageModelToolResultContent::Image(_) => { - writeln!(markdown, "\n").ok(); + for part in &tool_result.content { + match part { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); + } + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); + } } } @@ -588,8 +590,8 @@ impl AgentMessage { let mut tool_result = tool_result.clone(); // Surprisingly, the API fails if we return an empty string here. // It thinks we are sending a tool use without a tool result. - if tool_result.content.is_empty() { - tool_result.content = "".into(); + if tool_result.is_content_empty() { + tool_result.content = vec!["".into()]; } user_message .content @@ -2332,7 +2334,7 @@ impl Thread { let Some(tool) = tool else { let content = format!("No tool named {} exists", tool_use.name); return Some(Task::ready(LanguageModelToolResult { - content: LanguageModelToolResultContent::Text(Arc::from(content)), + content: vec![LanguageModelToolResultContent::Text(Arc::from(content))], tool_use_id: tool_use.id, tool_name: tool_use.name, is_error: true, @@ -2418,13 +2420,39 @@ impl Thread { cx.foreground_executor().spawn(async move { let (is_error, output) = match tool_result.await { Ok(mut output) => { - if let LanguageModelToolResultContent::Image(_) = &output.llm_output - && !supports_images - { - output = AgentToolOutput::from_error( - "Attempted to read an image, but this model doesn't support it.", - ); - (true, output) + let contains_image = output + .llm_output + .iter() + .any(|part| matches!(part, LanguageModelToolResultContent::Image(_))); + if contains_image && !supports_images { + // Replace each image part with an inline placeholder so + // any accompanying text is still presented to the model. + // If there's nothing else in the output, surface an error + // to match the pre-multi-part behavior for image-only + // tool results. + let placeholder = LanguageModelToolResultContent::Text(Arc::from( + "[Tool responded with an image, but this model doesn't support images]", + )); + let has_non_image = output + .llm_output + .iter() + .any(|part| !matches!(part, LanguageModelToolResultContent::Image(_))); + if has_non_image { + output.llm_output = output + .llm_output + .into_iter() + .map(|part| match part { + LanguageModelToolResultContent::Image(_) => placeholder.clone(), + other => other, + }) + .collect(); + (false, output) + } else { + let output = AgentToolOutput::from_error( + "Attempted to read an image, but this model doesn't support it.", + ); + (true, output) + } } else { (false, output) } @@ -2472,7 +2500,7 @@ impl Thread { let Some(tool) = tool else { let content = format!("No tool named {} exists", tool_use.name); return Some(Task::ready(LanguageModelToolResult { - content: LanguageModelToolResultContent::Text(Arc::from(content)), + content: vec![LanguageModelToolResultContent::Text(Arc::from(content))], tool_use_id: tool_use.id, tool_name: tool_use.name, is_error: true, @@ -2743,7 +2771,9 @@ impl Thread { tool_use_id: tool_use.id.clone(), tool_name: tool_use.name.clone(), is_error: true, - content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), + content: vec![LanguageModelToolResultContent::Text( + TOOL_CANCELED_MESSAGE.into(), + )], output: None, }, ); @@ -3392,14 +3422,16 @@ where pub struct Erased(T); pub struct AgentToolOutput { - pub llm_output: LanguageModelToolResultContent, + pub llm_output: Vec, pub raw_output: serde_json::Value, } impl AgentToolOutput { pub fn from_error(message: impl Into) -> Self { let message = message.into(); - let llm_output = LanguageModelToolResultContent::Text(Arc::from(message.as_str())); + let llm_output = vec![LanguageModelToolResultContent::Text(Arc::from( + message.as_str(), + ))]; Self { raw_output: serde_json::Value::String(message), llm_output, @@ -3484,7 +3516,7 @@ where AgentToolOutput::from_error(format!("Failed to serialize tool output: {e}")) })?; Ok(AgentToolOutput { - llm_output: output.into(), + llm_output: vec![output.into()], raw_output, }) } @@ -3494,7 +3526,7 @@ where serde_json::Value::Null }); Err(AgentToolOutput { - llm_output: error_output.into(), + llm_output: vec![error_output.into()], raw_output, }) } @@ -4518,8 +4550,8 @@ mod tests { assert_eq!(result.tool_use_id, tool_use_id); assert_eq!(result.tool_name, tool_name); assert!(matches!( - result.content, - LanguageModelToolResultContent::Text(_) + result.content.as_slice(), + [LanguageModelToolResultContent::Text(_)] )); thread.update(cx, |thread, _cx| { diff --git a/crates/agent/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs index 65b5df8abfe..c5476d6343d 100644 --- a/crates/agent/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -5,6 +5,7 @@ use collections::{BTreeMap, HashMap}; use context_server::{ContextServerId, client::NotificationSubscription}; use futures::FutureExt as _; use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task}; +use language_model::LanguageModelToolResultContent; use project::context_server_store::{ContextServerStatus, ContextServerStore}; use std::sync::Arc; use util::ResultExt; @@ -389,11 +390,13 @@ impl AnyAgentTool for ContextServerTool { return Err(AgentToolOutput::from_error(error_message)); } - let mut result = String::new(); + let mut llm_output = Vec::new(); + let mut concatenated_text = String::new(); for content in response.content { match content { context_server::types::ToolResponseContent::Text { text } => { - result.push_str(&text); + concatenated_text.push_str(&text); + llm_output.push(LanguageModelToolResultContent::Text(text.into())); } context_server::types::ToolResponseContent::Image { .. } => { log::warn!("Ignoring image content from tool response"); @@ -406,9 +409,10 @@ impl AnyAgentTool for ContextServerTool { } } } + let raw_output = serde_json::Value::String(concatenated_text); Ok(AgentToolOutput { - raw_output: result.clone().into(), - llm_output: result.into(), + raw_output, + llm_output, }) }) } diff --git a/crates/agent/src/tools/evals/streaming_edit_file.rs b/crates/agent/src/tools/evals/streaming_edit_file.rs index 3156fd25397..c82f652daca 100644 --- a/crates/agent/src/tools/evals/streaming_edit_file.rs +++ b/crates/agent/src/tools/evals/streaming_edit_file.rs @@ -666,7 +666,7 @@ fn tool_result( tool_use_id: LanguageModelToolUseId::from(id.into()), tool_name: name.into(), is_error: false, - content: LanguageModelToolResultContent::Text(result.into()), + content: vec![LanguageModelToolResultContent::Text(result.into())], output: None, }) } diff --git a/crates/anthropic/src/completion.rs b/crates/anthropic/src/completion.rs index 7bb4821cc78..48eed580d68 100644 --- a/crates/anthropic/src/completion.rs +++ b/crates/anthropic/src/completion.rs @@ -70,25 +70,38 @@ fn to_anthropic_content(content: MessageContent) -> Option { input: tool_use.input, cache_control: None, }), - MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult { - tool_use_id: tool_result.tool_use_id.to_string(), - is_error: tool_result.is_error, - content: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { + MessageContent::ToolResult(tool_result) => { + let content = match tool_result.content.as_slice() { + [LanguageModelToolResultContent::Text(text)] => { ToolResultContent::Plain(text.to_string()) } - LanguageModelToolResultContent::Image(image) => { - ToolResultContent::Multipart(vec![ToolResultPart::Image { - source: ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }]) + _ => { + let parts = tool_result + .content + .into_iter() + .map(|part| match part { + LanguageModelToolResultContent::Text(text) => ToolResultPart::Text { + text: text.to_string(), + }, + LanguageModelToolResultContent::Image(image) => ToolResultPart::Image { + source: ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }, + }) + .collect(); + ToolResultContent::Multipart(parts) } - }, - cache_control: None, - }), + }; + Some(RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + is_error: tool_result.is_error, + content, + cache_control: None, + }) + } } } diff --git a/crates/google_ai/src/completion.rs b/crates/google_ai/src/completion.rs index efbd1dc9ff7..b546682c552 100644 --- a/crates/google_ai/src/completion.rs +++ b/crates/google_ai/src/completion.rs @@ -70,38 +70,39 @@ pub fn into_google( })] } MessageContent::ToolResult(tool_result) => { - match tool_result.content { - language_model_core::LanguageModelToolResultContent::Text(text) => { - vec![Part::FunctionResponsePart(crate::FunctionResponsePart { - function_response: crate::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": text - }), - }, - })] - } - language_model_core::LanguageModelToolResultContent::Image(image) => { - vec![ - Part::FunctionResponsePart(crate::FunctionResponsePart { - function_response: crate::FunctionResponse { - name: tool_result.tool_name.to_string(), - // The API expects a valid JSON object - response: serde_json::json!({ - "output": "Tool responded with an image" - }), - }, - }), - Part::InlineDataPart(InlineDataPart { + let mut text_output = String::new(); + let mut images: Vec = Vec::new(); + for part in tool_result.content { + match part { + language_model_core::LanguageModelToolResultContent::Text(text) => { + text_output.push_str(&text); + } + language_model_core::LanguageModelToolResultContent::Image(image) => { + images.push(InlineDataPart { inline_data: GenerativeContentBlob { mime_type: "image/png".to_string(), data: image.source.to_string(), }, - }), - ] + }); + } } } + let output = if text_output.is_empty() && !images.is_empty() { + "Tool responded with an image".to_string() + } else { + text_output + }; + let mut parts = vec![Part::FunctionResponsePart(crate::FunctionResponsePart { + function_response: crate::FunctionResponse { + name: tool_result.tool_name.to_string(), + // The API expects a valid JSON object + response: serde_json::json!({ + "output": output + }), + }, + })]; + parts.extend(images.into_iter().map(Part::InlineDataPart)); + parts } }) .collect() diff --git a/crates/language_model/src/fake_provider.rs b/crates/language_model/src/fake_provider.rs index 4466a3f2762..dfef78b5fce 100644 --- a/crates/language_model/src/fake_provider.rs +++ b/crates/language_model/src/fake_provider.rs @@ -125,6 +125,7 @@ pub struct FakeLanguageModel { forbid_requests: AtomicBool, supports_thinking: AtomicBool, supports_streaming_tools: AtomicBool, + supports_images: AtomicBool, } impl Default for FakeLanguageModel { @@ -138,6 +139,7 @@ impl Default for FakeLanguageModel { forbid_requests: AtomicBool::new(false), supports_thinking: AtomicBool::new(false), supports_streaming_tools: AtomicBool::new(false), + supports_images: AtomicBool::new(false), } } } @@ -174,6 +176,10 @@ impl FakeLanguageModel { self.supports_streaming_tools.store(supports, SeqCst); } + pub fn set_supports_images(&self, supports: bool) { + self.supports_images.store(supports, SeqCst); + } + pub fn pending_completions(&self) -> Vec { self.current_completion_txs .lock() @@ -280,7 +286,7 @@ impl LanguageModel for FakeLanguageModel { } fn supports_images(&self) -> bool { - false + self.supports_images.load(SeqCst) } fn supports_thinking(&self) -> bool { diff --git a/crates/language_model_core/src/request.rs b/crates/language_model_core/src/request.rs index a35f4883389..f352ce16d22 100644 --- a/crates/language_model_core/src/request.rs +++ b/crates/language_model_core/src/request.rs @@ -102,12 +102,74 @@ pub struct LanguageModelToolResult { pub tool_use_id: LanguageModelToolUseId, pub tool_name: Arc, pub is_error: bool, - /// The tool output formatted for presenting to the model - pub content: LanguageModelToolResultContent, + #[serde(with = "tool_result_content_vec")] + pub content: Vec, /// The raw tool output, if available, often for debugging or extra state for replay pub output: Option, } +impl LanguageModelToolResult { + /// Concatenates all `Text` parts of the content, ignoring non-text parts. + pub fn text_contents(&self) -> String { + let mut buffer = String::new(); + for part in &self.content { + if let LanguageModelToolResultContent::Text(text) = part { + buffer.push_str(text); + } + } + buffer + } + + /// Returns true when there are no content parts, or every part is empty. + pub fn is_content_empty(&self) -> bool { + self.content.iter().all(|part| part.is_empty()) + } +} + +/// Serde helper that accepts both the legacy single-value shape and the new +/// array shape for `LanguageModelToolResult::content`, and normalizes both to +/// `Vec`. +mod tool_result_content_vec { + use super::LanguageModelToolResultContent; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize( + value: &Vec, + serializer: S, + ) -> Result + where + S: Serializer, + { + value.serialize(serializer) + } + + pub fn deserialize<'de, D>( + deserializer: D, + ) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let value = serde_json::Value::deserialize(deserializer)?; + match value { + serde_json::Value::Array(items) => { + let mut out = Vec::with_capacity(items.len()); + for item in items { + out.push( + serde_json::from_value::(item) + .map_err(serde::de::Error::custom)?, + ); + } + Ok(out) + } + other => { + let single = serde_json::from_value::(other) + .map_err(serde::de::Error::custom)?; + Ok(vec![single]) + } + } + } +} + #[derive(Debug, Clone, Serialize, Eq, PartialEq, Hash)] pub enum LanguageModelToolResultContent { Text(Arc), @@ -231,21 +293,11 @@ pub enum MessageContent { } impl MessageContent { - pub fn to_str(&self) -> Option<&str> { - match self { - MessageContent::Text(text) => Some(text.as_str()), - MessageContent::Thinking { text, .. } => Some(text.as_str()), - MessageContent::RedactedThinking(_) => None, - MessageContent::ToolResult(tool_result) => tool_result.content.to_str(), - MessageContent::ToolUse(_) | MessageContent::Image(_) => None, - } - } - pub fn is_empty(&self) -> bool { match self { MessageContent::Text(text) => text.chars().all(|c| c.is_whitespace()), MessageContent::Thinking { text, .. } => text.chars().all(|c| c.is_whitespace()), - MessageContent::ToolResult(tool_result) => tool_result.content.is_empty(), + MessageContent::ToolResult(tool_result) => tool_result.is_content_empty(), MessageContent::RedactedThinking(_) | MessageContent::ToolUse(_) | MessageContent::Image(_) => false, @@ -277,8 +329,25 @@ pub struct LanguageModelRequestMessage { impl LanguageModelRequestMessage { pub fn string_contents(&self) -> String { let mut buffer = String::new(); - for string in self.content.iter().filter_map(|content| content.to_str()) { - buffer.push_str(string); + for content in &self.content { + match content { + MessageContent::Text(text) => { + buffer.push_str(text); + } + MessageContent::Thinking { text, .. } => { + buffer.push_str(text); + } + MessageContent::ToolResult(tool_result) => { + for part in &tool_result.content { + if let LanguageModelToolResultContent::Text(text) = part { + buffer.push_str(text); + } + } + } + MessageContent::RedactedThinking(_) + | MessageContent::ToolUse(_) + | MessageContent::Image(_) => {} + } } buffer } @@ -462,4 +531,90 @@ mod tests { _ => panic!("Expected Image variant"), } } + + #[test] + fn test_language_model_tool_result_content_vec_deserialization() { + // Legacy single-value shape is normalized to a Vec. + let json = serde_json::json!({ + "tool_use_id": "abc", + "tool_name": "echo", + "is_error": false, + "content": "hello", + "output": null, + }); + let result: LanguageModelToolResult = serde_json::from_value(json).unwrap(); + assert_eq!( + result.content, + vec![LanguageModelToolResultContent::Text(Arc::from("hello"))] + ); + + // Legacy wrapped single-value shape also works. + let json = serde_json::json!({ + "tool_use_id": "abc", + "tool_name": "echo", + "is_error": false, + "content": {"type": "text", "text": "hello"}, + "output": null, + }); + let result: LanguageModelToolResult = serde_json::from_value(json).unwrap(); + assert_eq!( + result.content, + vec![LanguageModelToolResultContent::Text(Arc::from("hello"))] + ); + + // New array shape with text + image deserializes into a Vec. + let json = serde_json::json!({ + "tool_use_id": "abc", + "tool_name": "echo", + "is_error": false, + "content": [ + {"type": "text", "text": "foo"}, + {"source": "data", "size": {"width": 1, "height": 2}} + ], + "output": null, + }); + let result: LanguageModelToolResult = serde_json::from_value(json).unwrap(); + assert_eq!(result.content.len(), 2); + assert_eq!( + result.content[0], + LanguageModelToolResultContent::Text(Arc::from("foo")) + ); + match &result.content[1] { + LanguageModelToolResultContent::Image(image) => { + assert_eq!(image.source.as_ref(), "data"); + } + _ => panic!("Expected Image variant"), + } + + // Round-tripping preserves multi-part content. + let roundtripped: LanguageModelToolResult = + serde_json::from_value(serde_json::to_value(&result).unwrap()).unwrap(); + assert_eq!(roundtripped, result); + } + + #[test] + fn test_string_contents_includes_all_tool_result_text_parts() { + let tool_result = LanguageModelToolResult { + tool_use_id: LanguageModelToolUseId::from("id".to_string()), + tool_name: Arc::from("tool"), + is_error: false, + content: vec![ + LanguageModelToolResultContent::Text(Arc::from("first ")), + LanguageModelToolResultContent::Image(LanguageModelImage::empty()), + LanguageModelToolResultContent::Text(Arc::from("second")), + ], + output: None, + }; + let message = LanguageModelRequestMessage { + role: Role::User, + content: vec![ + MessageContent::Text("prefix ".to_string()), + MessageContent::ToolResult(tool_result), + MessageContent::Text(" suffix".to_string()), + ], + cache: false, + reasoning_details: None, + }; + assert_eq!(message.string_contents(), "prefix first second suffix"); + } } diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 1a8d477192e..fb48e7d73a2 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -918,9 +918,10 @@ pub fn into_bedrock( } MessageContent::ToolResult(tool_result) => { messages_contain_tool_content = true; - BedrockToolResultBlock::builder() - .tool_use_id(tool_result.tool_use_id.to_string()) - .content(match tool_result.content { + let mut builder = BedrockToolResultBlock::builder() + .tool_use_id(tool_result.tool_use_id.to_string()); + for part in tool_result.content { + let block = match part { LanguageModelToolResultContent::Text(text) => { BedrockToolResultContentBlock::Text(text.to_string()) } @@ -961,7 +962,10 @@ pub fn into_bedrock( } } } - }) + }; + builder = builder.content(block); + } + builder .status({ if tool_result.is_error { BedrockToolResultStatus::Error diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index db50f5161e3..1fc1dc3ce4a 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -868,23 +868,40 @@ fn into_copilot_chat( Role::User => { for content in &message.content { if let MessageContent::ToolResult(tool_result) = content { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => text.to_string().into(), - LanguageModelToolResultContent::Image(image) => { - if model.supports_vision() { - ChatMessageContent::Multipart(vec![ChatMessagePart::Image { - image_url: ImageUrl { - url: image.to_base64_url(), - }, - }]) - } else { - debug_panic!( - "This should be caught at {} level", - tool_result.tool_name - ); - "[Tool responded with an image, but this model does not support vision]".to_string().into() + let parts: Vec = tool_result + .content + .iter() + .map(|part| match part { + LanguageModelToolResultContent::Text(text) => { + ChatMessagePart::Text { + text: text.to_string(), + } } + LanguageModelToolResultContent::Image(image) => { + if model.supports_vision() { + ChatMessagePart::Image { + image_url: ImageUrl { + url: image.to_base64_url(), + }, + } + } else { + debug_panic!( + "This should be caught at {} level", + tool_result.tool_name + ); + ChatMessagePart::Text { + text: "[Tool responded with an image, but this model does not support vision]".to_string(), + } + } + } + }) + .collect(); + + let content = match parts.as_slice() { + [ChatMessagePart::Text { text }] => { + ChatMessageContent::Plain(text.clone()) } + _ => ChatMessageContent::Multipart(parts), }; messages.push(ChatMessage::Tool { @@ -1088,27 +1105,39 @@ fn into_copilot_responses( Role::User => { for content in &message.content { if let MessageContent::ToolResult(tool_result) = content { - let output = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { + let output = match tool_result.content.as_slice() { + [LanguageModelToolResultContent::Text(text)] => { responses::ResponseFunctionOutput::Text(text.to_string()) } - LanguageModelToolResultContent::Image(image) => { - if model.supports_vision() { - responses::ResponseFunctionOutput::Content(vec![ - responses::ResponseInputContent::InputImage { - image_url: Some(image.to_base64_url()), - detail: Default::default(), - }, - ]) - } else { - debug_panic!( - "This should be caught at {} level", - tool_result.tool_name - ); - responses::ResponseFunctionOutput::Text( - "[Tool responded with an image, but this model does not support vision]".into(), - ) - } + _ => { + let parts = tool_result + .content + .iter() + .map(|part| match part { + LanguageModelToolResultContent::Text(text) => { + responses::ResponseInputContent::InputText { + text: text.to_string(), + } + } + LanguageModelToolResultContent::Image(image) => { + if model.supports_vision() { + responses::ResponseInputContent::InputImage { + image_url: Some(image.to_base64_url()), + detail: Default::default(), + } + } else { + debug_panic!( + "This should be caught at {} level", + tool_result.tool_name + ); + responses::ResponseInputContent::InputText { + text: "[Tool responded with an image, but this model does not support vision]".to_string(), + } + } + } + }) + .collect(); + responses::ResponseFunctionOutput::Content(parts) } }; diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index dfc8521154e..a08cc25c7b5 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -378,15 +378,26 @@ pub fn into_deepseek( } } MessageContent::ToolResult(tool_result) => { - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - messages.push(deepseek::RequestMessage::Tool { - content: text.to_string(), - tool_call_id: tool_result.tool_use_id.to_string(), - }); + let mut text_parts: Vec = Vec::new(); + for part in &tool_result.content { + match part { + LanguageModelToolResultContent::Text(text) => { + text_parts.push(text.to_string()); + } + LanguageModelToolResultContent::Image(_) => { + text_parts.push("[Tool responded with an image]".to_string()); + } } - LanguageModelToolResultContent::Image(_) => {} + } + let content = if text_parts.is_empty() { + "".to_string() + } else { + text_parts.join("\n") }; + messages.push(deepseek::RequestMessage::Tool { + content, + tool_call_id: tool_result.tool_use_id.to_string(), + }); } } } diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index f035e765f07..50ac1286524 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -380,21 +380,25 @@ impl LmStudioLanguageModel { } } MessageContent::ToolResult(tool_result) => { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - vec![lmstudio::MessagePart::Text { - text: text.to_string(), - }] - } - LanguageModelToolResultContent::Image(image) => { - vec![lmstudio::MessagePart::Image { - image_url: lmstudio::ImageUrl { - url: image.to_base64_url(), - detail: None, - }, - }] - } - }; + let content: Vec = tool_result + .content + .iter() + .map(|part| match part { + LanguageModelToolResultContent::Text(text) => { + lmstudio::MessagePart::Text { + text: text.to_string(), + } + } + LanguageModelToolResultContent::Image(image) => { + lmstudio::MessagePart::Image { + image_url: lmstudio::ImageUrl { + url: image.to_base64_url(), + detail: None, + }, + } + } + }) + .collect(); messages.push(lmstudio::ChatMessage::Tool { content: content.into(), diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index cce5448b993..403d94e9832 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -390,14 +390,19 @@ pub fn into_mistral( // Tool use is not supported in User messages for Mistral } MessageContent::ToolResult(tool_result) => { - let tool_content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => text.to_string(), - LanguageModelToolResultContent::Image(_) => { - "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string() + let mut text_parts: Vec = Vec::new(); + for part in &tool_result.content { + match part { + LanguageModelToolResultContent::Text(text) => { + text_parts.push(text.to_string()); + } + LanguageModelToolResultContent::Image(_) => { + text_parts.push("[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()); + } } - }; + } messages.push(mistral::RequestMessage::Tool { - content: tool_content, + content: text_parts.join("\n"), tool_call_id: tool_result.tool_use_id.to_string(), }); } diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 229b59e2bfd..f38321b7c88 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -363,7 +363,7 @@ impl OllamaLanguageModel { MessageContent::ToolResult(tool_result) => { messages.push(ChatMessage::Tool { tool_name: tool_result.tool_name.to_string(), - content: tool_result.content.to_str().unwrap_or("").to_string(), + content: tool_result.text_contents(), }) } _ => unreachable!("Only tool result should be extracted"), diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 6562d9de085..bc4fbcc9aa7 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -465,18 +465,22 @@ pub fn into_open_router( } } MessageContent::ToolResult(tool_result) => { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - vec![open_router::MessagePart::Text { - text: text.to_string(), - }] - } - LanguageModelToolResultContent::Image(image) => { - vec![open_router::MessagePart::Image { - image_url: image.to_base64_url(), - }] - } - }; + let content: Vec = tool_result + .content + .iter() + .map(|part| match part { + LanguageModelToolResultContent::Text(text) => { + open_router::MessagePart::Text { + text: text.to_string(), + } + } + LanguageModelToolResultContent::Image(image) => { + open_router::MessagePart::Image { + image_url: image.to_base64_url(), + } + } + }) + .collect(); messages.push(open_router::RequestMessage::Tool { content: content.into(), diff --git a/crates/open_ai/src/completion.rs b/crates/open_ai/src/completion.rs index 3068f57f582..4abc752c4d5 100644 --- a/crates/open_ai/src/completion.rs +++ b/crates/open_ai/src/completion.rs @@ -104,21 +104,21 @@ pub fn into_open_ai( } } MessageContent::ToolResult(tool_result) => { - let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - vec![MessagePart::Text { + let content: Vec = tool_result + .content + .iter() + .map(|part| match part { + LanguageModelToolResultContent::Text(text) => MessagePart::Text { text: text.to_string(), - }] - } - LanguageModelToolResultContent::Image(image) => { - vec![MessagePart::Image { + }, + LanguageModelToolResultContent::Image(image) => MessagePart::Image { image_url: ImageUrl { url: image.to_base64_url(), detail: None, }, - }] - } - }; + }, + }) + .collect(); messages.push(crate::RequestMessage::Tool { content: content.into(), @@ -270,21 +270,34 @@ fn append_message_to_response_items( } MessageContent::ToolResult(tool_result) => { flush_response_parts(&message.role, index, &mut content_parts, input_items); + let output = match tool_result.content.as_slice() { + [LanguageModelToolResultContent::Text(text)] => { + ResponseFunctionCallOutputContent::Text(text.to_string()) + } + _ => { + let parts = tool_result + .content + .into_iter() + .map(|part| match part { + LanguageModelToolResultContent::Text(text) => { + ResponseInputContent::Text { + text: text.to_string(), + } + } + LanguageModelToolResultContent::Image(image) => { + ResponseInputContent::Image { + image_url: image.to_base64_url(), + } + } + }) + .collect(); + ResponseFunctionCallOutputContent::List(parts) + } + }; input_items.push(ResponseInputItem::FunctionCallOutput( ResponseFunctionCallOutputItem { call_id: tool_result.tool_use_id.to_string(), - output: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ResponseFunctionCallOutputContent::Text(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ResponseFunctionCallOutputContent::List(vec![ - ResponseInputContent::Image { - image_url: image.to_base64_url(), - }, - ]) - } - }, + output, }, )); } @@ -933,7 +946,7 @@ mod tests { tool_use_id: tool_call_id, tool_name: Arc::from("get_weather"), is_error: false, - content: LanguageModelToolResultContent::Text(Arc::from("Sunny")), + content: vec![LanguageModelToolResultContent::Text(Arc::from("Sunny"))], output: Some(json!({ "forecast": "Sunny" })), }; let user_image = LanguageModelImage { @@ -1634,7 +1647,7 @@ mod tests { tool_use_id: tool_use_id, tool_name: Arc::from("search"), is_error: false, - content: LanguageModelToolResultContent::Text(Arc::from("result")), + content: vec![LanguageModelToolResultContent::Text(Arc::from("result"))], output: None, }; let request = LanguageModelRequest {