From 35d06eb236493dba4bd62aad368576b646c19f91 Mon Sep 17 00:00:00 2001 From: "zed-zippy[bot]" <234243425+zed-zippy[bot]@users.noreply.github.com> Date: Thu, 21 May 2026 10:43:39 +0000 Subject: [PATCH] google: Add Google thinking level support (#57358) (cherry-pick to preview) (#57377) Cherry-pick of #57358 to preview ---- Also makes sure we are properly catching and processing thinking events. Self-Review Checklist: - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - google: Support thinking levels for Google models. Co-authored-by: Ben Brandt --- crates/google_ai/src/completion.rs | 381 ++++++++++++++++-- crates/google_ai/src/google_ai.rs | 110 ++++- crates/language_models/src/provider/google.rs | 23 +- 3 files changed, 452 insertions(+), 62 deletions(-) diff --git a/crates/google_ai/src/completion.rs b/crates/google_ai/src/completion.rs index b96679cfbc1..862e9f08fb7 100644 --- a/crates/google_ai/src/completion.rs +++ b/crates/google_ai/src/completion.rs @@ -12,8 +12,8 @@ use std::sync::atomic::{self, AtomicU64}; use crate::{ Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration, GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode, - InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig, - UsageMetadata, + InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ThinkingLevel, + ToolConfig, UsageMetadata, }; pub fn into_google( @@ -27,19 +27,24 @@ pub fn into_google( .flat_map(|content| match content { MessageContent::Text(text) => { if !text.is_empty() { - vec![Part::TextPart(TextPart { text })] + vec![Part::TextPart(TextPart { + text, + thought: false, + thought_signature: None, + })] } else { vec![] } } MessageContent::Thinking { - text: _, + text, signature: Some(signature), } => { if !signature.is_empty() { - vec![Part::ThoughtPart(crate::ThoughtPart { + vec![Part::TextPart(TextPart { + text, thought: true, - thought_signature: signature, + thought_signature: Some(signature), })] } else { vec![] @@ -110,6 +115,8 @@ pub fn into_google( .collect() } + let thinking_config = thinking_config_for_request(&request, &model_id, mode); + let system_instructions = if request .messages .first() @@ -150,32 +157,7 @@ pub fn into_google( stop_sequences: Some(request.stop), max_output_tokens: None, temperature: request.temperature.map(|t| t as f64), - thinking_config: match (request.thinking_allowed, mode) { - (true, GoogleModelMode::Thinking { budget_tokens }) => { - let effort = request.thinking_effort.as_deref().map(|s| s.to_lowercase()); - let thinking_level = match effort.as_deref() { - Some("high") => Some(crate::ThinkingLevel::High), - Some("medium") => Some(crate::ThinkingLevel::Medium), - Some("low") => Some(crate::ThinkingLevel::Low), - Some("minimal") => Some(crate::ThinkingLevel::Minimal), - _ => None, - }; - - Some(ThinkingConfig { - thinking_budget: budget_tokens, - thinking_level, - }) - .filter( - |ThinkingConfig { - thinking_budget, - thinking_level, - }| { - thinking_level.is_some() || thinking_budget.is_some() - }, - ) - } - _ => None, - }, + thinking_config, top_p: None, top_k: None, }), @@ -206,6 +188,76 @@ pub fn into_google( } } +fn thinking_config_for_request( + request: &LanguageModelRequest, + model_id: &str, + mode: GoogleModelMode, +) -> Option { + let supports_thinking = + matches!(mode, GoogleModelMode::Thinking { .. }) || is_google_thinking_model(model_id); + if !supports_thinking { + return None; + } + + let mut config = ThinkingConfig::default(); + + if request.thinking_allowed { + config.include_thoughts = Some(true); + config.thinking_level = request + .thinking_effort + .as_deref() + .and_then(ThinkingLevel::from_effort); + + if config.thinking_level.is_none() + && let GoogleModelMode::Thinking { + budget_tokens: Some(budget_tokens), + } = mode + { + config.thinking_budget = Some(budget_tokens); + } + } else if let Some(thinking_level) = disabled_thinking_level(model_id) { + config.thinking_level = Some(thinking_level); + } else if supports_thinking_budget_disable(model_id) { + config.thinking_budget = Some(0); + } + + (!config.is_empty()).then_some(config) +} + +impl ThinkingConfig { + fn is_empty(&self) -> bool { + self.thinking_budget.is_none() + && self.thinking_level.is_none() + && self.include_thoughts.is_none() + } +} + +fn is_google_thinking_model(model_id: &str) -> bool { + model_id.starts_with("gemini-2.5-") || model_id.starts_with("gemini-3") +} + +fn disabled_thinking_level(model_id: &str) -> Option { + match model_id { + model_id if model_id.starts_with("gemini-3") && model_id.contains("-pro") => { + Some(ThinkingLevel::Low) + } + model_id if model_id.starts_with("gemini-3") => Some(ThinkingLevel::Minimal), + _ => None, + } +} + +fn supports_thinking_budget_disable(model_id: &str) -> bool { + matches!( + model_id, + "gemini-2.5-flash" + | "gemini-2.5-flash-lite" + | "gemini-2.5-flash-preview-latest" + | "gemini-2.5-flash-preview-04-17" + | "gemini-2.5-flash-preview-05-20" + | "gemini-2.5-flash-lite-preview-06-17" + ) +} + pub struct GoogleEventMapper { usage: UsageMetadata, stop_reason: StopReason, @@ -276,6 +328,23 @@ impl GoogleEventMapper { self.stop_reason = match finish_reason { "STOP" => StopReason::EndTurn, "MAX_TOKENS" => StopReason::MaxTokens, + "SAFETY" + | "RECITATION" + | "LANGUAGE" + | "OTHER" + | "BLOCKLIST" + | "PROHIBITED_CONTENT" + | "SPII" + | "MALFORMED_FUNCTION_CALL" + | "IMAGE_SAFETY" + | "IMAGE_PROHIBITED_CONTENT" + | "IMAGE_OTHER" + | "NO_IMAGE" + | "IMAGE_RECITATION" + | "UNEXPECTED_TOOL_CALL" + | "TOO_MANY_TOOL_CALLS" + | "MISSING_THOUGHT_SIGNATURE" + | "MALFORMED_RESPONSE" => StopReason::Refusal, _ => { log::error!("Unexpected google finish_reason: {finish_reason}"); StopReason::EndTurn @@ -288,7 +357,28 @@ impl GoogleEventMapper { .into_iter() .for_each(|part| match part { Part::TextPart(text_part) => { - events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text))) + let thought_signature = + text_part.thought_signature.filter(|s| !s.is_empty()); + if text_part.thought { + if !text_part.text.is_empty() || thought_signature.is_some() { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: text_part.text, + signature: thought_signature, + })) + } + } else { + if let Some(thought_signature) = thought_signature { + events.push(Ok(LanguageModelCompletionEvent::Thinking { + text: String::new(), + signature: Some(thought_signature), + })); + } + if !text_part.text.is_empty() { + events.push(Ok(LanguageModelCompletionEvent::Text( + text_part.text, + ))); + } + } } Part::InlineDataPart(_) => {} Part::FunctionCallPart(function_call_part) => { @@ -320,12 +410,6 @@ impl GoogleEventMapper { ))); } Part::FunctionResponsePart(_) => {} - Part::ThoughtPart(part) => { - events.push(Ok(LanguageModelCompletionEvent::Thinking { - text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries? - signature: Some(part.thought_signature), - })); - } }); } } @@ -382,8 +466,227 @@ mod tests { Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, Part, Role as GoogleRole, }; + use language_model_core::LanguageModelRequestMessage; use serde_json::json; + fn text_request() -> LanguageModelRequest { + LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text("Hello".to_string())], + cache: false, + reasoning_details: None, + }], + ..Default::default() + } + } + + #[test] + fn into_google_requests_thought_summaries_and_thinking_level() { + let mut request = text_request(); + request.thinking_allowed = true; + request.thinking_effort = Some("low".to_string()); + + let request = into_google( + request, + "gemini-3.5-flash".to_string(), + GoogleModelMode::Thinking { + budget_tokens: None, + }, + ); + + let thinking_config = request.generation_config.unwrap().thinking_config.unwrap(); + assert_eq!(thinking_config.include_thoughts, Some(true)); + assert_eq!(thinking_config.thinking_level, Some(ThinkingLevel::Low)); + + let serialized = serde_json::to_value(thinking_config).unwrap(); + assert_eq!(serialized["thinkingLevel"], "LOW"); + assert_eq!(serialized["includeThoughts"], true); + } + + #[test] + fn into_google_turns_off_budget_thinking_when_supported() { + let mut request = text_request(); + request.thinking_allowed = false; + + let request = into_google( + request, + "gemini-2.5-flash".to_string(), + GoogleModelMode::Thinking { + budget_tokens: None, + }, + ); + + let thinking_config = request.generation_config.unwrap().thinking_config.unwrap(); + assert_eq!(thinking_config.thinking_budget, Some(0)); + assert_eq!(thinking_config.include_thoughts, None); + } + + #[test] + fn into_google_uses_minimal_level_when_gemini_3_flash_thinking_is_off() { + let mut request = text_request(); + request.thinking_allowed = false; + + let request = into_google( + request, + "gemini-3.5-flash".to_string(), + GoogleModelMode::Thinking { + budget_tokens: None, + }, + ); + + let thinking_config = request.generation_config.unwrap().thinking_config.unwrap(); + assert_eq!(thinking_config.thinking_level, Some(ThinkingLevel::Minimal)); + assert_eq!(thinking_config.include_thoughts, None); + } + + #[test] + fn into_google_replays_signed_thinking_as_thought_text_part() { + let request = LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::Thinking { + text: "summary".to_string(), + signature: Some("signature".to_string()), + }], + cache: false, + reasoning_details: None, + }], + ..Default::default() + }; + + let request = into_google( + request, + "gemini-3.5-flash".to_string(), + GoogleModelMode::Thinking { + budget_tokens: None, + }, + ); + + let Part::TextPart(text_part) = &request.contents[0].parts[0] else { + panic!("expected text part"); + }; + assert_eq!(text_part.text, "summary"); + assert!(text_part.thought); + assert_eq!(text_part.thought_signature.as_deref(), Some("signature")); + } + + #[test] + fn thought_text_part_deserializes_and_maps_to_thinking_event() { + let part: Part = serde_json::from_value(json!({ + "text": "checking the constraints", + "thought": true, + "thoughtSignature": "thought-signature" + })) + .unwrap(); + + let mut mapper = GoogleEventMapper::new(); + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![part], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 1); + assert!(matches!( + &events[0], + Ok(LanguageModelCompletionEvent::Thinking { text, signature }) + if text == "checking the constraints" + && signature.as_deref() == Some("thought-signature") + )); + } + + #[test] + fn signed_non_thought_text_part_preserves_signature() { + let part: Part = serde_json::from_value(json!({ + "text": "visible text", + "thoughtSignature": "visible-signature" + })) + .unwrap(); + + let Part::TextPart(text_part) = part else { + panic!("expected text part"); + }; + assert_eq!(text_part.text, "visible text"); + assert!(!text_part.thought); + assert_eq!( + text_part.thought_signature.as_deref(), + Some("visible-signature") + ); + } + + #[test] + fn signed_non_thought_text_part_maps_signature_carrier() { + let part: Part = serde_json::from_value(json!({ + "text": "visible text", + "thoughtSignature": "visible-signature" + })) + .unwrap(); + + let mut mapper = GoogleEventMapper::new(); + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![part], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + assert_eq!(events.len(), 2); + assert!(matches!( + &events[0], + Ok(LanguageModelCompletionEvent::Thinking { text, signature }) + if text.is_empty() && signature.as_deref() == Some("visible-signature") + )); + assert!(matches!( + &events[1], + Ok(LanguageModelCompletionEvent::Text(text)) if text == "visible text" + )); + } + + #[test] + fn safety_finish_reason_is_refusal() { + let mut mapper = GoogleEventMapper::new(); + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: Vec::new(), + role: GoogleRole::Model, + }, + finish_reason: Some("SAFETY".to_string()), + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + mapper.map_event(response); + assert_eq!(mapper.stop_reason, StopReason::Refusal); + } + #[test] fn test_function_call_with_signature_creates_tool_use_with_signature() { let mut mapper = GoogleEventMapper::new(); diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 305f43ab3a3..56ec48c83a8 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -166,17 +166,24 @@ pub enum Role { #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum Part { - TextPart(TextPart), - InlineDataPart(InlineDataPart), FunctionCallPart(FunctionCallPart), FunctionResponsePart(FunctionResponsePart), - ThoughtPart(ThoughtPart), + InlineDataPart(InlineDataPart), + TextPart(TextPart), +} + +fn is_false(value: &bool) -> bool { + !*value } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct TextPart { pub text: String, + #[serde(default, skip_serializing_if = "is_false")] + pub thought: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub thought_signature: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -208,13 +215,6 @@ pub struct FunctionResponsePart { pub function_response: FunctionResponse, } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ThoughtPart { - pub thought: bool, - pub thought_signature: String, -} - #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationSource { @@ -261,16 +261,18 @@ pub struct UsageMetadata { pub total_token_count: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] #[serde(rename_all = "camelCase")] pub struct ThinkingConfig { #[serde(skip_serializing_if = "Option::is_none")] pub thinking_budget: Option, #[serde(skip_serializing_if = "Option::is_none")] pub thinking_level: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_thoughts: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "UPPERCASE")] pub enum ThinkingLevel { Minimal, @@ -279,6 +281,36 @@ pub enum ThinkingLevel { High, } +impl ThinkingLevel { + pub fn from_effort(effort: &str) -> Option { + match effort.to_lowercase().as_str() { + "minimal" => Some(Self::Minimal), + "low" => Some(Self::Low), + "medium" => Some(Self::Medium), + "high" => Some(Self::High), + _ => None, + } + } + + pub fn name(self) -> &'static str { + match self { + Self::Minimal => "Minimal", + Self::Low => "Low", + Self::Medium => "Medium", + Self::High => "High", + } + } + + pub fn value(self) -> &'static str { + match self { + Self::Minimal => "minimal", + Self::Low => "low", + Self::Medium => "medium", + Self::High => "high", + } + } +} + #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { @@ -579,6 +611,50 @@ impl Model { true } + pub fn supports_thinking(&self) -> bool { + matches!( + self, + Self::Gemini25FlashLite + | Self::Gemini25Flash + | Self::Gemini25Pro + | Self::Gemini31FlashLite + | Self::Gemini3Flash + | Self::Gemini35Flash + | Self::Gemini31Pro + | Self::Custom { + mode: GoogleModelMode::Thinking { .. }, + .. + } + ) + } + + pub fn supported_thinking_levels(&self) -> &'static [ThinkingLevel] { + match self { + Self::Gemini31FlashLite | Self::Gemini3Flash | Self::Gemini35Flash => &[ + ThinkingLevel::Minimal, + ThinkingLevel::Low, + ThinkingLevel::Medium, + ThinkingLevel::High, + ], + Self::Gemini31Pro => &[ + ThinkingLevel::Low, + ThinkingLevel::Medium, + ThinkingLevel::High, + ], + _ => &[], + } + } + + pub fn default_thinking_level(&self) -> Option { + match self { + Self::Gemini31FlashLite => Some(ThinkingLevel::Minimal), + Self::Gemini3Flash => Some(ThinkingLevel::High), + Self::Gemini35Flash => Some(ThinkingLevel::Medium), + Self::Gemini31Pro => Some(ThinkingLevel::High), + _ => None, + } + } + pub fn mode(&self) -> GoogleModelMode { match self { Self::Gemini25FlashLite | Self::Gemini25Flash | Self::Gemini25Pro => { @@ -588,12 +664,10 @@ impl Model { budget_tokens: None, } } - Self::Gemini3Flash => GoogleModelMode::Default, - Self::Gemini31FlashLite => GoogleModelMode::Default, - Self::Gemini35Flash => GoogleModelMode::Thinking { - budget_tokens: None, - }, - Self::Gemini31Pro => GoogleModelMode::Thinking { + Self::Gemini31FlashLite + | Self::Gemini3Flash + | Self::Gemini35Flash + | Self::Gemini31Pro => GoogleModelMode::Thinking { budget_tokens: None, }, Self::Custom { mode, .. } => *mode, diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index d5b47bf4583..774de74a6b2 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -2,8 +2,8 @@ use anyhow::{Context as _, Result}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; use futures::{FutureExt, StreamExt, future::BoxFuture}; +use google_ai::GenerateContentResponse; pub use google_ai::completion::{GoogleEventMapper, into_google}; -use google_ai::{GenerateContentResponse, GoogleModelMode}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, TaskExt, Window}; use http_client::HttpClient; use language_model::{ @@ -11,9 +11,9 @@ use language_model::{ LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat, }; use language_model::{ - GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, RateLimiter, + GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelEffortLevel, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -300,7 +300,20 @@ impl LanguageModel for GoogleLanguageModel { } fn supports_thinking(&self) -> bool { - matches!(self.model.mode(), GoogleModelMode::Thinking { .. }) + self.model.supports_thinking() + } + + fn supported_effort_levels(&self) -> Vec { + let default_level = self.model.default_thinking_level(); + self.model + .supported_thinking_levels() + .iter() + .map(|level| LanguageModelEffortLevel { + name: level.name().into(), + value: level.value().into(), + is_default: Some(*level) == default_level, + }) + .collect() } fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {