From 860c7d7b972e41b97539dcf6a74f043c2c667b8e Mon Sep 17 00:00:00 2001 From: Rabi Mishra Date: Fri, 13 Feb 2026 07:22:15 +0530 Subject: [PATCH] feat(gemini-cli): use stream-json output and re-use session (#7118) Signed-off-by: rabi --- crates/goose/src/providers/gemini_cli.rs | 493 +++++++++++++++-------- 1 file changed, 325 insertions(+), 168 deletions(-) diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 7d8a28797d..8a4fdc77bc 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -1,9 +1,10 @@ use anyhow::Result; use async_trait::async_trait; -use serde_json::json; +use serde_json::{json, Value}; use std::path::PathBuf; use std::process::Stdio; -use tokio::io::{AsyncBufReadExt, BufReader}; +use std::sync::OnceLock; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader}; use tokio::process::Command; use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; @@ -30,12 +31,37 @@ pub const GEMINI_CLI_KNOWN_MODELS: &[&str] = &[ pub const GEMINI_CLI_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs"; +fn extract_usage_from_stats(stats: &Value) -> Usage { + let get = |key: &str| { + stats + .get(key) + .and_then(|v| v.as_i64()) + .and_then(|v| i32::try_from(v).ok()) + }; + Usage::new( + get("input_tokens"), + get("output_tokens"), + get("total_tokens"), + ) +} + +fn error_from_event(parsed: &Value) -> ProviderError { + let error_msg = parsed + .get("error") + .and_then(|e| e.as_str()) + .or_else(|| parsed.get("message").and_then(|m| m.as_str())) + .unwrap_or("Unknown error"); + ProviderError::RequestFailed(format!("Gemini CLI error: {error_msg}")) +} + #[derive(Debug, serde::Serialize)] pub struct GeminiCliProvider { command: PathBuf, model: ModelConfig, #[serde(skip)] name: String, + #[serde(skip)] + cli_session_id: OnceLock, } impl GeminiCliProvider { @@ -48,154 +74,35 @@ impl GeminiCliProvider { command: resolved_command, model, name: GEMINI_CLI_PROVIDER_NAME.to_string(), + cli_session_id: OnceLock::new(), }) } - /// Execute gemini CLI command with simple text prompt - async fn execute_command( - &self, - system: &str, - messages: &[Message], - _tools: &[Tool], - ) -> Result, ProviderError> { - // Create a simple prompt combining system + conversation - let mut full_prompt = String::new(); - - let filtered_system = filter_extensions_from_system_prompt(system); - full_prompt.push_str(&filtered_system); - full_prompt.push_str("\n\n"); - - // Add conversation history - for message in messages.iter().filter(|m| m.is_agent_visible()) { - let role_prefix = match message.role { - Role::User => "Human: ", - Role::Assistant => "Assistant: ", - }; - full_prompt.push_str(role_prefix); - - for content in &message.content { - if let MessageContent::Text(text_content) = content { - full_prompt.push_str(&text_content.text); - full_prompt.push('\n'); - } - } - full_prompt.push('\n'); - } - - full_prompt.push_str("Assistant: "); - - if std::env::var("GOOSE_GEMINI_CLI_DEBUG").is_ok() { - println!("=== GEMINI CLI PROVIDER DEBUG ==="); - println!("Command: {:?}", self.command); - println!("Full prompt: {}", full_prompt); - println!("================================"); - } - - let mut cmd = Command::new(&self.command); - configure_subprocess(&mut cmd); - - if let Ok(path) = SearchPaths::builder().with_npm().path() { - cmd.env("PATH", path); - } - - // Only pass model parameter if it's in the known models list - if GEMINI_CLI_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { - cmd.arg("-m").arg(&self.model.model_name); - } - - if cfg!(windows) { - let sanitized_prompt = full_prompt.replace("\r\n", "\\n").replace('\n', "\\n"); - - cmd.arg("-p").arg(&sanitized_prompt).arg("--yolo"); - } else { - cmd.arg("-p").arg(&full_prompt).arg("--yolo"); - } - - cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); - - let mut child = cmd.spawn().map_err(|e| { - ProviderError::RequestFailed(format!( - "Failed to spawn Gemini CLI command '{:?}': {}. \ - Make sure the Gemini CLI is installed and available in the configured search paths.", - self.command, e - )) - })?; - - let stdout = child - .stdout - .take() - .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?; - - let mut reader = BufReader::new(stdout); - let mut lines = Vec::new(); - let mut line = String::new(); - - loop { - line.clear(); - match reader.read_line(&mut line).await { - Ok(0) => break, // EOF - Ok(_) => { - let trimmed = line.trim(); - if !trimmed.is_empty() && !trimmed.starts_with("Loaded cached credentials") { - lines.push(trimmed.to_string()); - } - } - Err(e) => { - return Err(ProviderError::RequestFailed(format!( - "Failed to read output: {}", - e - ))); - } - } - } - - let exit_status = child.wait().await.map_err(|e| { - ProviderError::RequestFailed(format!("Failed to wait for command: {}", e)) - })?; - - if !exit_status.success() { - return Err(ProviderError::RequestFailed(format!( - "Command failed with exit code: {:?}", - exit_status.code() - ))); - } - - tracing::debug!( - "Gemini CLI executed successfully, got {} lines", - lines.len() - ); - - Ok(lines) + fn session_id(&self) -> Option<&str> { + self.cli_session_id.get().map(|s| s.as_str()) } - /// Parse simple text response - fn parse_response(&self, lines: &[String]) -> Result<(Message, Usage), ProviderError> { - // Join all lines into a single response - let response_text = lines.join("\n"); - - if response_text.trim().is_empty() { - return Err(ProviderError::RequestFailed( - "Empty response from gemini command".to_string(), - )); - } - - let message = Message::new( - Role::Assistant, - chrono::Utc::now().timestamp(), - vec![MessageContent::text(response_text)], - ); - - let usage = Usage::default(); // No usage info available for gemini CLI - - Ok((message, usage)) + fn set_session_id(&self, sid: String) { + let _ = self.cli_session_id.set(sid); + } + + fn last_user_message_text(messages: &[Message]) -> String { + messages + .iter() + .rev() + .find(|m| m.role == Role::User) + .map(|m| m.as_concat_text()) + .unwrap_or_default() + } + + fn is_session_description_request(system: &str) -> bool { + system.contains("four words or less") || system.contains("4 words or less") } - /// Generate a simple session description without calling subprocess fn generate_simple_session_description( &self, messages: &[Message], ) -> Result<(Message, ProviderUsage), ProviderError> { - // Extract the first user message text let description = messages .iter() .find(|m| m.role == Role::User) @@ -206,7 +113,6 @@ impl GeminiCliProvider { }) }) .map(|text| { - // Take first few words, limit to 4 words text.split_whitespace() .take(4) .collect::>() @@ -214,25 +120,210 @@ impl GeminiCliProvider { }) .unwrap_or_else(|| "Simple task".to_string()); - if std::env::var("GOOSE_GEMINI_CLI_DEBUG").is_ok() { - println!("=== GEMINI CLI PROVIDER DEBUG ==="); - println!("Generated simple session description: {}", description); - println!("Skipped subprocess call for session description"); - println!("================================"); + tracing::debug!( + description = %description, + "Generated simple session description, skipped subprocess" + ); + + let message = Message::new( + Role::Assistant, + chrono::Utc::now().timestamp(), + vec![MessageContent::text(description)], + ); + + Ok(( + message, + ProviderUsage::new(self.model.model_name.clone(), Usage::default()), + )) + } + + /// Build the prompt for the CLI invocation. When resuming a session the CLI + /// maintains conversation context internally, so only the latest user + /// message is needed. On the first turn (no session yet) the system prompt + /// is prepended — there is typically only one user message at that point. + fn build_prompt(&self, system: &str, messages: &[Message]) -> String { + let user_text = Self::last_user_message_text(messages); + + if self.session_id().is_some() { + user_text + } else { + let filtered_system = filter_extensions_from_system_prompt(system); + if filtered_system.is_empty() { + user_text + } else { + format!("{filtered_system}\n\n{user_text}") + } + } + } + + fn build_command(&self, prompt: &str, model_name: &str) -> Command { + let mut cmd = Command::new(&self.command); + configure_subprocess(&mut cmd); + + if let Ok(path) = SearchPaths::builder().with_npm().path() { + cmd.env("PATH", path); + } + + cmd.arg("-m").arg(model_name); + + if let Some(sid) = self.session_id() { + cmd.arg("-r").arg(sid); + } + + cmd.arg("-p") + .arg(prompt) + .arg("--output-format") + .arg("stream-json") + .arg("--yolo"); + + cmd.stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + cmd + } + + async fn execute_command( + &self, + system: &str, + messages: &[Message], + _tools: &[Tool], + model_name: &str, + ) -> Result, ProviderError> { + let prompt = self.build_prompt(system, messages); + + tracing::debug!(command = ?self.command, "Executing Gemini CLI command"); + + let mut cmd = self.build_command(&prompt, model_name); + + let mut child = cmd.kill_on_drop(true).spawn().map_err(|e| { + ProviderError::RequestFailed(format!( + "Failed to spawn Gemini CLI command '{}': {e}. \ + Make sure the Gemini CLI is installed and available in the configured search paths.", + self.command.display() + )) + })?; + + let stdout = child + .stdout + .take() + .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?; + + // Drain stderr concurrently to avoid pipe deadlock + let stderr_task = tokio::spawn(async move { + let mut buf = String::new(); + if let Some(mut stderr) = child.stderr.take() { + let _ = stderr.read_to_string(&mut buf).await; + } + (child, buf) + }); + + let mut reader = BufReader::new(stdout); + let mut events = Vec::new(); + let mut line = String::new(); + + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + + match serde_json::from_str::(trimmed) { + Ok(parsed) => { + if parsed.get("type").and_then(|t| t.as_str()) == Some("init") { + if let Some(sid) = parsed.get("session_id").and_then(|s| s.as_str()) + { + self.set_session_id(sid.to_string()); + } + } + events.push(parsed); + } + Err(_) => { + tracing::warn!(line = trimmed, "Non-JSON line in stream-json output"); + } + } + } + Err(e) => { + return Err(ProviderError::RequestFailed(format!( + "Failed to read output: {e}" + ))); + } + } + } + + let (mut child, stderr_text) = stderr_task + .await + .map_err(|e| ProviderError::RequestFailed(format!("Failed to read stderr: {e}")))?; + + let exit_status = child.wait().await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to wait for command: {e}")) + })?; + + if !exit_status.success() { + let stderr_snippet = stderr_text.trim(); + let detail = if stderr_snippet.is_empty() { + format!("exit code {:?}", exit_status.code()) + } else { + format!("exit code {:?}: {stderr_snippet}", exit_status.code()) + }; + return Err(ProviderError::RequestFailed(format!( + "Gemini CLI command failed ({detail})" + ))); + } + + tracing::debug!( + "Gemini CLI executed successfully, got {} events", + events.len() + ); + + Ok(events) + } + + fn parse_stream_json_response(events: &[Value]) -> Result<(Message, Usage), ProviderError> { + let mut all_text_content = Vec::new(); + let mut usage = Usage::default(); + + for parsed in events { + match parsed.get("type").and_then(|t| t.as_str()) { + Some("message") => { + if parsed.get("role").and_then(|r| r.as_str()) == Some("assistant") { + if let Some(content) = parsed.get("content").and_then(|c| c.as_str()) { + if !content.is_empty() { + all_text_content.push(content.to_string()); + } + } + } + } + Some("result") => { + if let Some(stats) = parsed.get("stats") { + usage = extract_usage_from_stats(stats); + } + } + Some("error") => { + return Err(error_from_event(parsed)); + } + _ => {} + } + } + + let combined_text = all_text_content.join(""); + if combined_text.is_empty() { + return Err(ProviderError::RequestFailed( + "No text content found in response".to_string(), + )); } let message = Message::new( Role::Assistant, chrono::Utc::now().timestamp(), - vec![MessageContent::text(description.clone())], + vec![MessageContent::text(combined_text)], ); - let usage = Usage::default(); - - Ok(( - message, - ProviderUsage::new(self.model.model_name.clone(), usage), - )) + Ok((message, usage)) } } @@ -267,7 +358,6 @@ impl Provider for GeminiCliProvider { } fn get_model_config(&self) -> ModelConfig { - // Return the model config with appropriate context limit for Gemini models self.model.clone() } @@ -279,50 +369,117 @@ impl Provider for GeminiCliProvider { } #[tracing::instrument( - skip(self, _model_config, system, messages, tools), + skip(self, model_config, system, messages, tools), fields(model_config, input, output, input_tokens, output_tokens, total_tokens) )] async fn complete_with_model( &self, - _session_id: Option<&str>, // CLI has no external session-id flag to propagate. - _model_config: &ModelConfig, + _session_id: Option<&str>, + model_config: &ModelConfig, system: &str, messages: &[Message], tools: &[Tool], ) -> Result<(Message, ProviderUsage), ProviderError> { - // Check if this is a session description request (short system prompt asking for 4 words or less) - if system.contains("four words or less") || system.contains("4 words or less") { + if Self::is_session_description_request(system) { return self.generate_simple_session_description(messages); } - // Create a dummy payload for debug tracing let payload = json!({ "command": self.command, - "model": self.model.model_name, + "model": model_config.model_name, "system": system, "messages": messages.len() }); - let mut log = RequestLog::start(&self.model, &payload).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to start request log: {}", e)) + let mut log = RequestLog::start(model_config, &payload).map_err(|e| { + ProviderError::RequestFailed(format!("Failed to start request log: {e}")) })?; - let lines = self.execute_command(system, messages, tools).await?; - - let (message, usage) = self.parse_response(&lines)?; + let events = self + .execute_command(system, messages, tools, &model_config.model_name) + .await?; + let (message, usage) = Self::parse_stream_json_response(&events)?; let response = json!({ - "lines": lines.len(), + "events": events.len(), "usage": usage }); log.write(&response, Some(&usage)).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to write request log: {}", e)) + ProviderError::RequestFailed(format!("Failed to write request log: {e}")) })?; Ok(( message, - ProviderUsage::new(self.model.model_name.clone(), usage), + ProviderUsage::new(model_config.model_name.clone(), usage), )) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn make_provider() -> GeminiCliProvider { + GeminiCliProvider { + command: PathBuf::from("gemini"), + model: ModelConfig::new("gemini-2.5-pro").unwrap(), + name: "gemini-cli".to_string(), + cli_session_id: OnceLock::new(), + } + } + + #[test] + fn test_parse_stream_json_response() { + let events = vec![ + json!({"type":"init","session_id":"abc","model":"gemini-2.5-pro"}), + json!({"type":"message","role":"user","content":"Hi"}), + json!({"type":"message","role":"assistant","content":"Hello ","delta":true}), + json!({"type":"message","role":"assistant","content":"there!","delta":true}), + json!({"type":"result","status":"success","stats":{"input_tokens":20,"output_tokens":5,"total_tokens":25}}), + ]; + let (message, usage) = GeminiCliProvider::parse_stream_json_response(&events).unwrap(); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.as_concat_text(), "Hello there!"); + assert_eq!(usage.input_tokens, Some(20)); + assert_eq!(usage.output_tokens, Some(5)); + + let error_events = vec![ + json!({"type":"init","session_id":"abc"}), + json!({"type":"error","error":"Rate limit exceeded"}), + ]; + let err = GeminiCliProvider::parse_stream_json_response(&error_events).unwrap_err(); + assert!(err.to_string().contains("Rate limit exceeded")); + + let empty: Vec = vec![]; + assert!(GeminiCliProvider::parse_stream_json_response(&empty).is_err()); + } + + #[test] + fn test_build_prompt_first_and_resume() { + let provider = make_provider(); + let messages = vec![Message::new( + Role::User, + 0, + vec![MessageContent::text("Hello")], + )]; + + let prompt = provider.build_prompt("You are helpful.", &messages); + assert!(prompt.contains("You are helpful.")); + assert!(prompt.contains("Hello")); + + provider.set_session_id("session-123".to_string()); + let messages = vec![ + Message::new(Role::User, 0, vec![MessageContent::text("Hello")]), + Message::new(Role::Assistant, 0, vec![MessageContent::text("Hi!")]), + Message::new( + Role::User, + 0, + vec![MessageContent::text("Follow up question")], + ), + ]; + let prompt = provider.build_prompt("You are helpful.", &messages); + assert_eq!(prompt, "Follow up question"); + } +}