feat(gemini-cli): use stream-json output and re-use session (#7118)

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
Rabi Mishra 2026-02-13 07:22:15 +05:30 committed by GitHub
parent ae98e503ff
commit 860c7d7b97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,9 +1,10 @@
use anyhow::Result;
use async_trait::async_trait;
use serde_json::json;
use serde_json::{json, Value};
use std::path::PathBuf;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, BufReader};
use std::sync::OnceLock;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use tokio::process::Command;
use super::base::{Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage};
@ -30,12 +31,37 @@ pub const GEMINI_CLI_KNOWN_MODELS: &[&str] = &[
pub const GEMINI_CLI_DOC_URL: &str = "https://ai.google.dev/gemini-api/docs";
/// Builds a `Usage` from the `stats` object of a stream-json event.
///
/// Reads `input_tokens`, `output_tokens` and `total_tokens`. A counter that is
/// absent, not an integer, or outside the `i32` range is passed on as `None`.
fn extract_usage_from_stats(stats: &Value) -> Usage {
    /// Looks up one counter and narrows it to `i32` without truncating.
    fn token_count(stats: &Value, field: &str) -> Option<i32> {
        let raw = stats.get(field)?.as_i64()?;
        i32::try_from(raw).ok()
    }

    let input = token_count(stats, "input_tokens");
    let output = token_count(stats, "output_tokens");
    let total = token_count(stats, "total_tokens");
    Usage::new(input, output, total)
}
fn error_from_event(parsed: &Value) -> ProviderError {
let error_msg = parsed
.get("error")
.and_then(|e| e.as_str())
.or_else(|| parsed.get("message").and_then(|m| m.as_str()))
.unwrap_or("Unknown error");
ProviderError::RequestFailed(format!("Gemini CLI error: {error_msg}"))
}
/// Provider that answers completions by spawning the Gemini CLI as a
/// subprocess and reading its output.
#[derive(Debug, serde::Serialize)]
pub struct GeminiCliProvider {
/// Path of the CLI executable that is spawned for each request.
command: PathBuf,
/// Model configuration handed back by `get_model_config`.
model: ModelConfig,
/// Provider name; left out of the serialized form.
#[serde(skip)]
name: String,
/// Session id reported by the CLI in its `init` event, passed back with `-r`
/// on later invocations. Being a `OnceLock`, only the first id stored is
/// kept. Left out of the serialized form.
#[serde(skip)]
cli_session_id: OnceLock<String>,
}
impl GeminiCliProvider {
@ -48,154 +74,35 @@ impl GeminiCliProvider {
command: resolved_command,
model,
name: GEMINI_CLI_PROVIDER_NAME.to_string(),
cli_session_id: OnceLock::new(),
})
}
/// Execute gemini CLI command with simple text prompt
async fn execute_command(
&self,
system: &str,
messages: &[Message],
_tools: &[Tool],
) -> Result<Vec<String>, ProviderError> {
// Create a simple prompt combining system + conversation
let mut full_prompt = String::new();
let filtered_system = filter_extensions_from_system_prompt(system);
full_prompt.push_str(&filtered_system);
full_prompt.push_str("\n\n");
// Add conversation history
for message in messages.iter().filter(|m| m.is_agent_visible()) {
let role_prefix = match message.role {
Role::User => "Human: ",
Role::Assistant => "Assistant: ",
};
full_prompt.push_str(role_prefix);
for content in &message.content {
if let MessageContent::Text(text_content) = content {
full_prompt.push_str(&text_content.text);
full_prompt.push('\n');
}
}
full_prompt.push('\n');
}
full_prompt.push_str("Assistant: ");
if std::env::var("GOOSE_GEMINI_CLI_DEBUG").is_ok() {
println!("=== GEMINI CLI PROVIDER DEBUG ===");
println!("Command: {:?}", self.command);
println!("Full prompt: {}", full_prompt);
println!("================================");
}
let mut cmd = Command::new(&self.command);
configure_subprocess(&mut cmd);
if let Ok(path) = SearchPaths::builder().with_npm().path() {
cmd.env("PATH", path);
}
// Only pass model parameter if it's in the known models list
if GEMINI_CLI_KNOWN_MODELS.contains(&self.model.model_name.as_str()) {
cmd.arg("-m").arg(&self.model.model_name);
}
if cfg!(windows) {
let sanitized_prompt = full_prompt.replace("\r\n", "\\n").replace('\n', "\\n");
cmd.arg("-p").arg(&sanitized_prompt).arg("--yolo");
} else {
cmd.arg("-p").arg(&full_prompt).arg("--yolo");
}
cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
let mut child = cmd.spawn().map_err(|e| {
ProviderError::RequestFailed(format!(
"Failed to spawn Gemini CLI command '{:?}': {}. \
Make sure the Gemini CLI is installed and available in the configured search paths.",
self.command, e
))
})?;
let stdout = child
.stdout
.take()
.ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?;
let mut reader = BufReader::new(stdout);
let mut lines = Vec::new();
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => break, // EOF
Ok(_) => {
let trimmed = line.trim();
if !trimmed.is_empty() && !trimmed.starts_with("Loaded cached credentials") {
lines.push(trimmed.to_string());
}
}
Err(e) => {
return Err(ProviderError::RequestFailed(format!(
"Failed to read output: {}",
e
)));
}
}
}
let exit_status = child.wait().await.map_err(|e| {
ProviderError::RequestFailed(format!("Failed to wait for command: {}", e))
})?;
if !exit_status.success() {
return Err(ProviderError::RequestFailed(format!(
"Command failed with exit code: {:?}",
exit_status.code()
)));
}
tracing::debug!(
"Gemini CLI executed successfully, got {} lines",
lines.len()
);
Ok(lines)
/// Returns the CLI session id captured so far, or `None` if no id has been
/// stored yet.
fn session_id(&self) -> Option<&str> {
    let stored = self.cli_session_id.get()?;
    Some(stored.as_str())
}
/// Parse simple text response
fn parse_response(&self, lines: &[String]) -> Result<(Message, Usage), ProviderError> {
// Join all lines into a single response
let response_text = lines.join("\n");
if response_text.trim().is_empty() {
return Err(ProviderError::RequestFailed(
"Empty response from gemini command".to_string(),
));
}
let message = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
vec![MessageContent::text(response_text)],
);
let usage = Usage::default(); // No usage info available for gemini CLI
Ok((message, usage))
/// Stores the CLI session id. The slot is write-once: if an id is already
/// present the new value is silently discarded.
fn set_session_id(&self, sid: String) {
    // `OnceLock::set` returns the rejected value as `Err`; it is dropped here
    // on purpose because the first id wins.
    let _first_wins = self.cli_session_id.set(sid);
}
/// Returns the concatenated text of the most recent user message, or an empty
/// string when the slice holds no user message.
fn last_user_message_text(messages: &[Message]) -> String {
    let latest_user = messages.iter().rfind(|msg| msg.role == Role::User);
    match latest_user {
        Some(msg) => msg.as_concat_text(),
        None => String::default(),
    }
}
/// Reports whether the system prompt asks for a short session description,
/// detected by the phrases "four words or less" or "4 words or less"
/// (case-sensitive substring match).
fn is_session_description_request(system: &str) -> bool {
    const MARKERS: [&str; 2] = ["four words or less", "4 words or less"];
    MARKERS.iter().any(|marker| system.contains(*marker))
}
/// Generate a simple session description without calling subprocess
fn generate_simple_session_description(
&self,
messages: &[Message],
) -> Result<(Message, ProviderUsage), ProviderError> {
// Extract the first user message text
let description = messages
.iter()
.find(|m| m.role == Role::User)
@ -206,7 +113,6 @@ impl GeminiCliProvider {
})
})
.map(|text| {
// Take first few words, limit to 4 words
text.split_whitespace()
.take(4)
.collect::<Vec<_>>()
@ -214,25 +120,210 @@ impl GeminiCliProvider {
})
.unwrap_or_else(|| "Simple task".to_string());
if std::env::var("GOOSE_GEMINI_CLI_DEBUG").is_ok() {
println!("=== GEMINI CLI PROVIDER DEBUG ===");
println!("Generated simple session description: {}", description);
println!("Skipped subprocess call for session description");
println!("================================");
tracing::debug!(
description = %description,
"Generated simple session description, skipped subprocess"
);
let message = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
vec![MessageContent::text(description)],
);
Ok((
message,
ProviderUsage::new(self.model.model_name.clone(), Usage::default()),
))
}
/// Assembles the text passed to the CLI with `-p`.
///
/// Once a CLI session id is known, only the latest user message is sent,
/// because the resumed session already carries the earlier conversation. On
/// the first turn the filtered system prompt is placed in front of the user
/// message, separated by a blank line, unless filtering leaves it empty.
fn build_prompt(&self, system: &str, messages: &[Message]) -> String {
    let latest = Self::last_user_message_text(messages);

    // Resuming: the CLI keeps the context, nothing else to add.
    if self.session_id().is_some() {
        return latest;
    }

    let preamble = filter_extensions_from_system_prompt(system);
    if preamble.is_empty() {
        return latest;
    }
    format!("{preamble}\n\n{latest}")
}
/// Prepares the CLI invocation for one turn without spawning it.
///
/// Arguments, in order: `-m <model_name>`, `-r <session id>` when a session
/// id is known, then `-p <prompt> --output-format stream-json --yolo`.
/// stdin is closed; stdout and stderr are piped. `PATH` is overridden only
/// when the npm-aware search path can be resolved.
fn build_command(&self, prompt: &str, model_name: &str) -> Command {
    let mut invocation = Command::new(&self.command);
    configure_subprocess(&mut invocation);

    if let Ok(search_path) = SearchPaths::builder().with_npm().path() {
        invocation.env("PATH", search_path);
    }

    invocation.args(["-m", model_name]);
    if let Some(resume_id) = self.session_id() {
        invocation.args(["-r", resume_id]);
    }
    invocation.args(["-p", prompt, "--output-format", "stream-json", "--yolo"]);

    invocation.stdin(Stdio::null());
    invocation.stdout(Stdio::piped());
    invocation.stderr(Stdio::piped());
    invocation
}
/// Runs the Gemini CLI once and collects every JSON event it prints.
///
/// The prompt comes from `build_prompt` and the argument list from
/// `build_command`. Each non-empty stdout line is parsed as JSON; lines that
/// fail to parse are logged and skipped. The `session_id` of an `init` event
/// is remembered through `set_session_id`.
///
/// # Errors
///
/// Returns `ProviderError::RequestFailed` when the process cannot be spawned,
/// stdout cannot be captured or read, the stderr task fails, waiting on the
/// process fails, or the process exits unsuccessfully (the message then
/// carries the exit code and any stderr text).
async fn execute_command(
    &self,
    system: &str,
    messages: &[Message],
    _tools: &[Tool],
    model_name: &str,
) -> Result<Vec<Value>, ProviderError> {
    let prompt = self.build_prompt(system, messages);

    tracing::debug!(command = ?self.command, "Executing Gemini CLI command");

    let mut cmd = self.build_command(&prompt, model_name);
    let mut child = cmd.kill_on_drop(true).spawn().map_err(|e| {
        ProviderError::RequestFailed(format!(
            "Failed to spawn Gemini CLI command '{}': {e}. \
            Make sure the Gemini CLI is installed and available in the configured search paths.",
            self.command.display()
        ))
    })?;

    let stdout = child
        .stdout
        .take()
        .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?;
    let stderr = child.stderr.take();

    // Drain stderr concurrently to avoid pipe deadlock. Only the pipe moves
    // into the task: `child` stays owned by this future so that
    // `kill_on_drop` terminates the process on an early return or when the
    // future is dropped.
    let stderr_task = tokio::spawn(async move {
        let mut raw = Vec::new();
        if let Some(mut stderr) = stderr {
            // Best effort: whatever was read before a failure is still kept.
            let _ = stderr.read_to_end(&mut raw).await;
        }
        // Lossy decoding keeps diagnostics that are not valid UTF-8.
        String::from_utf8_lossy(&raw).into_owned()
    });

    let mut reader = BufReader::new(stdout);
    let mut events = Vec::new();
    let mut line = String::new();

    loop {
        line.clear();
        match reader.read_line(&mut line).await {
            Ok(0) => break,
            Ok(_) => {
                let trimmed = line.trim();
                if trimmed.is_empty() {
                    continue;
                }
                match serde_json::from_str::<Value>(trimmed) {
                    Ok(parsed) => {
                        if parsed.get("type").and_then(|t| t.as_str()) == Some("init") {
                            if let Some(sid) = parsed.get("session_id").and_then(|s| s.as_str())
                            {
                                self.set_session_id(sid.to_string());
                            }
                        }
                        events.push(parsed);
                    }
                    Err(_) => {
                        tracing::warn!(line = trimmed, "Non-JSON line in stream-json output");
                    }
                }
            }
            Err(e) => {
                return Err(ProviderError::RequestFailed(format!(
                    "Failed to read output: {e}"
                )));
            }
        }
    }

    let stderr_text = stderr_task
        .await
        .map_err(|e| ProviderError::RequestFailed(format!("Failed to read stderr: {e}")))?;

    let exit_status = child.wait().await.map_err(|e| {
        ProviderError::RequestFailed(format!("Failed to wait for command: {e}"))
    })?;

    if !exit_status.success() {
        let stderr_snippet = stderr_text.trim();
        let detail = if stderr_snippet.is_empty() {
            format!("exit code {:?}", exit_status.code())
        } else {
            format!("exit code {:?}: {stderr_snippet}", exit_status.code())
        };
        return Err(ProviderError::RequestFailed(format!(
            "Gemini CLI command failed ({detail})"
        )));
    }

    tracing::debug!(
        "Gemini CLI executed successfully, got {} events",
        events.len()
    );

    Ok(events)
}
fn parse_stream_json_response(events: &[Value]) -> Result<(Message, Usage), ProviderError> {
let mut all_text_content = Vec::new();
let mut usage = Usage::default();
for parsed in events {
match parsed.get("type").and_then(|t| t.as_str()) {
Some("message") => {
if parsed.get("role").and_then(|r| r.as_str()) == Some("assistant") {
if let Some(content) = parsed.get("content").and_then(|c| c.as_str()) {
if !content.is_empty() {
all_text_content.push(content.to_string());
}
}
}
}
Some("result") => {
if let Some(stats) = parsed.get("stats") {
usage = extract_usage_from_stats(stats);
}
}
Some("error") => {
return Err(error_from_event(parsed));
}
_ => {}
}
}
let combined_text = all_text_content.join("");
if combined_text.is_empty() {
return Err(ProviderError::RequestFailed(
"No text content found in response".to_string(),
));
}
let message = Message::new(
Role::Assistant,
chrono::Utc::now().timestamp(),
vec![MessageContent::text(description.clone())],
vec![MessageContent::text(combined_text)],
);
let usage = Usage::default();
Ok((
message,
ProviderUsage::new(self.model.model_name.clone(), usage),
))
Ok((message, usage))
}
}
@ -267,7 +358,6 @@ impl Provider for GeminiCliProvider {
}
fn get_model_config(&self) -> ModelConfig {
// Return the model config with appropriate context limit for Gemini models
self.model.clone()
}
@ -279,50 +369,117 @@ impl Provider for GeminiCliProvider {
}
#[tracing::instrument(
skip(self, _model_config, system, messages, tools),
skip(self, model_config, system, messages, tools),
fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
)]
async fn complete_with_model(
&self,
_session_id: Option<&str>, // CLI has no external session-id flag to propagate.
_model_config: &ModelConfig,
_session_id: Option<&str>,
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage), ProviderError> {
// Check if this is a session description request (short system prompt asking for 4 words or less)
if system.contains("four words or less") || system.contains("4 words or less") {
if Self::is_session_description_request(system) {
return self.generate_simple_session_description(messages);
}
// Create a dummy payload for debug tracing
let payload = json!({
"command": self.command,
"model": self.model.model_name,
"model": model_config.model_name,
"system": system,
"messages": messages.len()
});
let mut log = RequestLog::start(&self.model, &payload).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to start request log: {}", e))
let mut log = RequestLog::start(model_config, &payload).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to start request log: {e}"))
})?;
let lines = self.execute_command(system, messages, tools).await?;
let (message, usage) = self.parse_response(&lines)?;
let events = self
.execute_command(system, messages, tools, &model_config.model_name)
.await?;
let (message, usage) = Self::parse_stream_json_response(&events)?;
let response = json!({
"lines": lines.len(),
"events": events.len(),
"usage": usage
});
log.write(&response, Some(&usage)).map_err(|e| {
ProviderError::RequestFailed(format!("Failed to write request log: {}", e))
ProviderError::RequestFailed(format!("Failed to write request log: {e}"))
})?;
Ok((
message,
ProviderUsage::new(self.model.model_name.clone(), usage),
ProviderUsage::new(model_config.model_name.clone(), usage),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Provider fixture with no CLI session recorded yet.
    fn make_provider() -> GeminiCliProvider {
        GeminiCliProvider {
            command: PathBuf::from("gemini"),
            model: ModelConfig::new("gemini-2.5-pro").unwrap(),
            name: "gemini-cli".to_string(),
            cli_session_id: OnceLock::new(),
        }
    }

    #[test]
    fn test_parse_stream_json_response() {
        // Happy path: assistant deltas are concatenated and stats are read.
        let success_stream = vec![
            json!({"type":"init","session_id":"abc","model":"gemini-2.5-pro"}),
            json!({"type":"message","role":"user","content":"Hi"}),
            json!({"type":"message","role":"assistant","content":"Hello ","delta":true}),
            json!({"type":"message","role":"assistant","content":"there!","delta":true}),
            json!({"type":"result","status":"success","stats":{"input_tokens":20,"output_tokens":5,"total_tokens":25}}),
        ];
        let (reply, counts) =
            GeminiCliProvider::parse_stream_json_response(&success_stream).unwrap();
        assert_eq!(reply.role, Role::Assistant);
        assert_eq!(reply.as_concat_text(), "Hello there!");
        assert_eq!(counts.input_tokens, Some(20));
        assert_eq!(counts.output_tokens, Some(5));

        // An error event aborts parsing and surfaces its description.
        let failing_stream = vec![
            json!({"type":"init","session_id":"abc"}),
            json!({"type":"error","error":"Rate limit exceeded"}),
        ];
        let failure = GeminiCliProvider::parse_stream_json_response(&failing_stream).unwrap_err();
        assert!(failure.to_string().contains("Rate limit exceeded"));

        // No events means no assistant text, which is an error.
        let no_events: Vec<Value> = vec![];
        assert!(GeminiCliProvider::parse_stream_json_response(&no_events).is_err());
    }

    #[test]
    fn test_build_prompt_first_and_resume() {
        let provider = make_provider();

        // First turn: system prompt and user text are both present.
        let first_turn = vec![Message::new(
            Role::User,
            0,
            vec![MessageContent::text("Hello")],
        )];
        let opening_prompt = provider.build_prompt("You are helpful.", &first_turn);
        assert!(opening_prompt.contains("You are helpful."));
        assert!(opening_prompt.contains("Hello"));

        // After a session id is stored, only the latest user message is sent.
        provider.set_session_id("session-123".to_string());
        let later_turns = vec![
            Message::new(Role::User, 0, vec![MessageContent::text("Hello")]),
            Message::new(Role::Assistant, 0, vec![MessageContent::text("Hi!")]),
            Message::new(
                Role::User,
                0,
                vec![MessageContent::text("Follow up question")],
            ),
        ];
        let resumed_prompt = provider.build_prompt("You are helpful.", &later_turns);
        assert_eq!(resumed_prompt, "Follow up question");
    }
}