From 7b64ef1cbae8404cc9e040150d3423a300ce946d Mon Sep 17 00:00:00 2001 From: Douwe Osinga Date: Wed, 18 Mar 2026 04:26:12 -0400 Subject: [PATCH] Revert message flush & test (#7966) Co-authored-by: Douwe Osinga --- crates/goose/src/agents/agent.rs | 30 ++-- crates/goose/tests/agent.rs | 270 +++++++++++++++++++++++++++++++ 2 files changed, 284 insertions(+), 16 deletions(-) diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 67cb59b26a..222f148974 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1187,6 +1187,7 @@ impl Agent { ).await?; let mut no_tools_called = true; + let mut messages_to_add = Conversation::default(); let mut tools_updated = false; let mut did_recovery_compact_this_iteration = false; let mut exit_chat = false; @@ -1241,8 +1242,7 @@ impl Agent { if !text.is_empty() { last_assistant_text = text; } - session_manager.add_message(&session_config.id, &response).await?; - conversation.push(response); + messages_to_add.push(response); continue; } @@ -1438,8 +1438,7 @@ impl Agent { response.created, thinking_content, ).with_id(format!("msg_{}", Uuid::new_v4())); - session_manager.add_message(&session_config.id, &thinking_msg).await?; - conversation.push(thinking_msg); + messages_to_add.push(thinking_msg); } // Collect reasoning content to attach to tool request messages @@ -1467,14 +1466,11 @@ impl Agent { request.metadata.as_ref(), request.tool_meta.clone(), ); + messages_to_add.push(request_msg); let final_response = tool_response_messages[idx] .lock().await.clone(); - // Persist the tool request and response as a pair - session_manager.add_message(&session_config.id, &request_msg).await?; - session_manager.add_message(&session_config.id, &final_response).await?; - conversation.push(request_msg); - conversation.push(final_response.clone()); - yield AgentEvent::Message(final_response); + yield AgentEvent::Message(final_response.clone()); + messages_to_add.push(final_response); } else 
{ error!( "Tool call could not be parsed: {}", @@ -1618,14 +1614,12 @@ impl Agent { Some(None) => { warn!("Final output tool has not been called yet. Continuing agent loop."); let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE); - session_manager.add_message(&session_config.id, &message).await?; - conversation.push(message.clone()); + messages_to_add.push(message.clone()); yield AgentEvent::Message(message); } Some(Some(output)) => { let message = Message::assistant().with_text(output); - session_manager.add_message(&session_config.id, &message).await?; - conversation.push(message.clone()); + messages_to_add.push(message.clone()); yield AgentEvent::Message(message); exit_chat = true; } @@ -1637,6 +1631,7 @@ impl Agent { Ok(should_retry) => { if should_retry { info!("Retry logic triggered, restarting agent loop"); + messages_to_add = Conversation::default(); session_manager.replace_conversation(&session_config.id, &conversation).await?; yield AgentEvent::HistoryReplaced(conversation.clone()); } else { @@ -1680,14 +1675,17 @@ impl Agent { }).await?; } conversation = Conversation::new_unvalidated(updated_messages); - session_manager.add_message(&session_config.id, &summary_msg).await?; - conversation.push(summary_msg); + messages_to_add.push(summary_msg); } else { warn!("Expected a tool request/reply pair, but found {} matching messages", matching.len()); } } + for msg in &messages_to_add { + session_manager.add_message(&session_config.id, msg).await?; + } + conversation.extend(messages_to_add); if exit_chat { break; } diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 515ab9b6e9..c08fe092c1 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -591,4 +591,274 @@ mod tests { ); } } + + #[cfg(test)] + mod streaming_persistence_tests { + use super::*; + use async_trait::async_trait; + use goose::agents::{AgentConfig, SessionConfig}; + use goose::config::permission::PermissionManager; + use 
goose::config::GooseMode; + use goose::conversation::message::Message; + use goose::model::ModelConfig; + use goose::providers::base::{ + MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, + }; + use goose::providers::errors::ProviderError; + use goose::session::session_manager::SessionType; + use goose::session::SessionManager; + use rmcp::model::{CallToolRequestParams, Role, Tool}; + use rmcp::object; + use std::path::PathBuf; + use std::sync::atomic::{AtomicUsize, Ordering}; + use tokio_util::sync::CancellationToken; + + struct MultiStepProvider { + call_count: AtomicUsize, + cancel_token: CancellationToken, + } + + impl MultiStepProvider { + fn new(cancel_token: CancellationToken) -> Self { + Self { + call_count: AtomicUsize::new(0), + cancel_token, + } + } + } + + impl ProviderDef for MultiStepProvider { + type Provider = Self; + + fn metadata() -> ProviderMetadata { + ProviderMetadata { + name: "multi-step-mock".to_string(), + display_name: "Multi-Step Mock".to_string(), + description: "Mock provider for streaming persistence tests".to_string(), + default_model: "mock-model".to_string(), + known_models: vec![], + model_doc_link: "".to_string(), + config_keys: vec![], + setup_steps: vec![], + } + } + + fn from_env( + _model: ModelConfig, + _extensions: Vec, + ) -> futures::future::BoxFuture<'static, anyhow::Result> { + unimplemented!() + } + } + + #[async_trait] + impl Provider for MultiStepProvider { + async fn stream( + &self, + _model_config: &ModelConfig, + _session_id: &str, + _system_prompt: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result { + let call = self.call_count.fetch_add(1, Ordering::SeqCst); + let usage = ProviderUsage::new( + "mock-model".to_string(), + Usage::new(Some(10), Some(5), Some(15)), + ); + + match call { + 0 => { + let tool_call = CallToolRequestParams::new("test_tool") + .with_arguments(object!({"param": "value"})); + let message = + Message::assistant().with_tool_request("call_1", 
Ok(tool_call)); + let stream = + futures::stream::once(async move { Ok((Some(message), Some(usage))) }); + Ok(Box::pin(stream)) + } + 1 => { + let msg_id = format!("msg_{}", uuid::Uuid::new_v4()); + let tokens = vec!["Hello", " world", ", how", " are", " you?"]; + let stream = futures::stream::iter(tokens.into_iter().enumerate().map( + move |(i, token)| { + let msg = Message::assistant() + .with_text(token) + .with_id(msg_id.clone()); + let u = if i == 4 { Some(usage.clone()) } else { None }; + Ok((Some(msg), u)) + }, + )); + Ok(Box::pin(stream)) + } + _ => { + let cancel = self.cancel_token.clone(); + let msg_id = format!("msg_{}", uuid::Uuid::new_v4()); + let tokens = vec!["This ", "should ", "be ", "cancelled ", "soon."]; + let stream = futures::stream::iter(tokens.into_iter().enumerate().map( + move |(i, token)| { + if i == 1 { + cancel.cancel(); + } + let msg = Message::assistant() + .with_text(token) + .with_id(msg_id.clone()); + let u = if i == 4 { Some(usage.clone()) } else { None }; + Ok((Some(msg), u)) + }, + )); + Ok(Box::pin(stream)) + } + } + } + + fn get_model_config(&self) -> ModelConfig { + ModelConfig::new("mock-model").unwrap() + } + + fn get_name(&self) -> &str { + "multi-step-mock" + } + } + + #[tokio::test] + async fn test_streaming_text_not_persisted_per_token() -> Result<()> { + let cancel_token = CancellationToken::new(); + let temp_dir = tempfile::tempdir()?; + let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf())); + let config = AgentConfig::new( + session_manager.clone(), + PermissionManager::instance(), + None, + GooseMode::Auto, + true, // disable session naming so it doesn't consume a provider call + GoosePlatform::GooseCli, + ); + let agent = Agent::with_config(config); + let provider = Arc::new(MultiStepProvider::new(cancel_token.clone())); + + let session = session_manager + .create_session( + PathBuf::default(), + "streaming-test".to_string(), + SessionType::Hidden, + GooseMode::default(), + ) + 
.await?; + + let session_id = session.id.clone(); + agent.update_provider(provider, &session_id).await?; + + // ── Reply 1: tool call (call 0) → text stream (call 1); call 2 is left for reply 2 + // max_turns=2 matches the two provider calls expected within this reply(). + // call 0: tool call → agent executes tool, loops + // call 1: 5 text deltas → no tools called, agent exits loop + // call 2: 5 text deltas, cancel token fired after 1st → agent interrupted + // + // Because call 1 ends the agent loop (no_tools_called=true → exit), + // call 2 is NOT reached in the same reply. We issue a second reply() + // with the cancel token so the provider triggers cancellation. + let session_config = SessionConfig { + id: session_id.clone(), + schedule_id: None, + max_turns: Some(2), + retry_config: None, + }; + + let reply_stream = agent + .reply( + Message::user().with_text("Do something then say hello"), + session_config, + None, + ) + .await?; + tokio::pin!(reply_stream); + + while let Some(event) = reply_stream.next().await { + match event { + Ok(AgentEvent::Message(_)) => {} + Ok(_) => {} + Err(e) => return Err(e), + } + } + + // ── Check persisted state after reply 1 ───────────────── + let reloaded = session_manager.get_session(&session_id, true).await?; + let messages = reloaded + .conversation + .expect("should have conversation") + .messages() + .to_vec(); + + let user_count = messages.iter().filter(|m| m.role == Role::User).count(); + let asst_count = messages + .iter() + .filter(|m| m.role == Role::Assistant) + .count(); + + // Expected: user(prompt) + assistant(tool-req) + user(tool-resp) + assistant(text) + assert_eq!( + user_count, 2, + "Expected 2 user messages (prompt + tool response), got {user_count}", + ); + assert_eq!( + asst_count, 2, + "Expected 2 assistant messages (tool request + text reply), got {asst_count} \ + — streaming text deltas are being persisted as separate messages", + ); + + // ── Reply 2: text stream with provider-triggered cancellation 
(call 2) + let session_config2 = SessionConfig { + id: session_id.clone(), + schedule_id: None, + max_turns: Some(2), + retry_config: None, + }; + + let reply_stream2 = agent + .reply( + Message::user().with_text("Tell me more"), + session_config2, + Some(cancel_token), + ) + .await?; + tokio::pin!(reply_stream2); + + while let Some(event) = reply_stream2.next().await { + match event { + Ok(_) => {} + Err(e) => return Err(e), + } + } + + // ── Check persisted state after cancellation ──────────── + let reloaded2 = session_manager.get_session(&session_id, true).await?; + let messages2 = reloaded2 + .conversation + .expect("should have conversation") + .messages() + .to_vec(); + + let user_count2 = messages2.iter().filter(|m| m.role == Role::User).count(); + let asst_count2 = messages2 + .iter() + .filter(|m| m.role == Role::Assistant) + .count(); + + // Reply 2 added 1 user message. The cancelled stream should + // have persisted at most 1 (partial) assistant message. + assert_eq!( + user_count2, 3, + "Expected 3 user messages (2 from reply 1 + follow-up), got {user_count2}", + ); + assert!( + asst_count2 <= 3, + "Expected at most 3 assistant messages (2 from reply 1 + at most 1 partial \ + from cancelled reply 2), got {asst_count2} \ + — streaming deltas are leaking into persistence", + ); + + Ok(()) + } + } }