Revert message flush & test (#7966)
All checks were successful
Unused Dependencies / machete (push) Has been skipped

Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
Douwe Osinga 2026-03-18 04:26:12 -04:00 committed by jh-block
parent d26e41bea0
commit 7b64ef1cba
2 changed files with 284 additions and 16 deletions

View file

@ -1187,6 +1187,7 @@ impl Agent {
).await?;
let mut no_tools_called = true;
let mut messages_to_add = Conversation::default();
let mut tools_updated = false;
let mut did_recovery_compact_this_iteration = false;
let mut exit_chat = false;
@ -1241,8 +1242,7 @@ impl Agent {
if !text.is_empty() {
last_assistant_text = text;
}
session_manager.add_message(&session_config.id, &response).await?;
conversation.push(response);
messages_to_add.push(response);
continue;
}
@ -1438,8 +1438,7 @@ impl Agent {
response.created,
thinking_content,
).with_id(format!("msg_{}", Uuid::new_v4()));
session_manager.add_message(&session_config.id, &thinking_msg).await?;
conversation.push(thinking_msg);
messages_to_add.push(thinking_msg);
}
// Collect reasoning content to attach to tool request messages
@ -1467,14 +1466,11 @@ impl Agent {
request.metadata.as_ref(),
request.tool_meta.clone(),
);
messages_to_add.push(request_msg);
let final_response = tool_response_messages[idx]
.lock().await.clone();
// Persist the tool request and response as a pair
session_manager.add_message(&session_config.id, &request_msg).await?;
session_manager.add_message(&session_config.id, &final_response).await?;
conversation.push(request_msg);
conversation.push(final_response.clone());
yield AgentEvent::Message(final_response);
yield AgentEvent::Message(final_response.clone());
messages_to_add.push(final_response);
} else {
error!(
"Tool call could not be parsed: {}",
@ -1618,14 +1614,12 @@ impl Agent {
Some(None) => {
warn!("Final output tool has not been called yet. Continuing agent loop.");
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
session_manager.add_message(&session_config.id, &message).await?;
conversation.push(message.clone());
messages_to_add.push(message.clone());
yield AgentEvent::Message(message);
}
Some(Some(output)) => {
let message = Message::assistant().with_text(output);
session_manager.add_message(&session_config.id, &message).await?;
conversation.push(message.clone());
messages_to_add.push(message.clone());
yield AgentEvent::Message(message);
exit_chat = true;
}
@ -1637,6 +1631,7 @@ impl Agent {
Ok(should_retry) => {
if should_retry {
info!("Retry logic triggered, restarting agent loop");
messages_to_add = Conversation::default();
session_manager.replace_conversation(&session_config.id, &conversation).await?;
yield AgentEvent::HistoryReplaced(conversation.clone());
} else {
@ -1680,14 +1675,17 @@ impl Agent {
}).await?;
}
conversation = Conversation::new_unvalidated(updated_messages);
session_manager.add_message(&session_config.id, &summary_msg).await?;
conversation.push(summary_msg);
messages_to_add.push(summary_msg);
} else {
warn!("Expected a tool request/reply pair, but found {} matching messages",
matching.len());
}
}
for msg in &messages_to_add {
session_manager.add_message(&session_config.id, msg).await?;
}
conversation.extend(messages_to_add);
if exit_chat {
break;
}

View file

@ -591,4 +591,274 @@ mod tests {
);
}
}
#[cfg(test)]
mod streaming_persistence_tests {
use super::*;
use async_trait::async_trait;
use goose::agents::{AgentConfig, SessionConfig};
use goose::config::permission::PermissionManager;
use goose::config::GooseMode;
use goose::conversation::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::{
MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
};
use goose::providers::errors::ProviderError;
use goose::session::session_manager::SessionType;
use goose::session::SessionManager;
use rmcp::model::{CallToolRequestParams, Role, Tool};
use rmcp::object;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio_util::sync::CancellationToken;
/// Mock provider that scripts three successive `stream()` responses
/// (see the `Provider` impl below): a tool call, a plain streamed text
/// reply, and a text stream that fires the cancellation token mid-way.
struct MultiStepProvider {
    // Number of prior `stream()` invocations; selects which scripted
    // response the next call returns.
    call_count: AtomicUsize,
    // Token the third scripted response cancels to simulate a user
    // interrupt arriving while tokens are still streaming.
    cancel_token: CancellationToken,
}
impl MultiStepProvider {
    /// Builds a mock provider whose call counter starts at zero and which
    /// will fire `cancel_token` during its third scripted stream.
    fn new(cancel_token: CancellationToken) -> Self {
        let call_count = AtomicUsize::new(0);
        Self { cancel_token, call_count }
    }
}
impl ProviderDef for MultiStepProvider {
    type Provider = Self;

    /// Registry metadata for the mock. The values are placeholders; the
    /// tests only rely on the provider being constructible and identifiable
    /// by name.
    fn metadata() -> ProviderMetadata {
        ProviderMetadata {
            name: "multi-step-mock".to_string(),
            display_name: "Multi-Step Mock".to_string(),
            description: "Mock provider for streaming persistence tests".to_string(),
            default_model: "mock-model".to_string(),
            known_models: vec![],
            model_doc_link: "".to_string(),
            config_keys: vec![],
            setup_steps: vec![],
        }
    }

    /// Environment-based construction is intentionally unsupported: the
    /// tests build the mock directly via `MultiStepProvider::new`.
    fn from_env(
        _model: ModelConfig,
        _extensions: Vec<goose::config::ExtensionConfig>,
    ) -> futures::future::BoxFuture<'static, anyhow::Result<Self>> {
        unimplemented!()
    }
}
#[async_trait]
impl Provider for MultiStepProvider {
    /// Returns a scripted stream keyed off how many times `stream` was
    /// called before: call 0 yields a single tool request, call 1 yields
    /// five text deltas, and calls 2+ yield text deltas that fire the
    /// cancellation token after the first delta.
    async fn stream(
        &self,
        _model_config: &ModelConfig,
        _session_id: &str,
        _system_prompt: &str,
        _messages: &[Message],
        _tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        // fetch_add returns the pre-increment value, so the first call sees 0.
        let call = self.call_count.fetch_add(1, Ordering::SeqCst);
        let usage = ProviderUsage::new(
            "mock-model".to_string(),
            Usage::new(Some(10), Some(5), Some(15)),
        );
        match call {
            // Call 0: one assistant message carrying a single tool request.
            0 => {
                let tool_call = CallToolRequestParams::new("test_tool")
                    .with_arguments(object!({"param": "value"}));
                let message =
                    Message::assistant().with_tool_request("call_1", Ok(tool_call));
                let stream =
                    futures::stream::once(async move { Ok((Some(message), Some(usage))) });
                Ok(Box::pin(stream))
            }
            // Call 1: five text deltas sharing one message id (i.e. one
            // streamed assistant reply); usage is attached only to the
            // final delta, mirroring how real providers report usage.
            1 => {
                let msg_id = format!("msg_{}", uuid::Uuid::new_v4());
                let tokens = vec!["Hello", " world", ", how", " are", " you?"];
                let stream = futures::stream::iter(tokens.into_iter().enumerate().map(
                    move |(i, token)| {
                        let msg = Message::assistant()
                            .with_text(token)
                            .with_id(msg_id.clone());
                        // Only the last of the five deltas carries usage.
                        let u = if i == 4 { Some(usage.clone()) } else { None };
                        Ok((Some(msg), u))
                    },
                ));
                Ok(Box::pin(stream))
            }
            // Calls 2+: same delta shape, but the cancellation token is
            // fired while producing the second delta, simulating a user
            // interrupt arriving mid-stream.
            _ => {
                let cancel = self.cancel_token.clone();
                let msg_id = format!("msg_{}", uuid::Uuid::new_v4());
                let tokens = vec!["This ", "should ", "be ", "cancelled ", "soon."];
                let stream = futures::stream::iter(tokens.into_iter().enumerate().map(
                    move |(i, token)| {
                        if i == 1 {
                            cancel.cancel();
                        }
                        let msg = Message::assistant()
                            .with_text(token)
                            .with_id(msg_id.clone());
                        let u = if i == 4 { Some(usage.clone()) } else { None };
                        Ok((Some(msg), u))
                    },
                ));
                Ok(Box::pin(stream))
            }
        }
    }

    fn get_model_config(&self) -> ModelConfig {
        ModelConfig::new("mock-model").unwrap()
    }

    fn get_name(&self) -> &str {
        "multi-step-mock"
    }
}
/// End-to-end check that streamed text deltas are persisted as a single
/// assistant message (not one row per token), and that a cancelled stream
/// persists at most one partial assistant message.
#[tokio::test]
async fn test_streaming_text_not_persisted_per_token() -> Result<()> {
    let cancel_token = CancellationToken::new();
    let temp_dir = tempfile::tempdir()?;
    let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf()));
    let config = AgentConfig::new(
        session_manager.clone(),
        PermissionManager::instance(),
        None,
        GooseMode::Auto,
        true, // disable session naming so it doesn't consume a provider call
        GoosePlatform::GooseCli,
    );
    let agent = Agent::with_config(config);
    let provider = Arc::new(MultiStepProvider::new(cancel_token.clone()));
    let session = session_manager
        .create_session(
            PathBuf::default(),
            "streaming-test".to_string(),
            SessionType::Hidden,
            GooseMode::default(),
        )
        .await?;
    let session_id = session.id.clone();
    agent.update_provider(provider, &session_id).await?;
    // ── Reply 1: tool call (provider call 0) → streamed text (call 1) ──
    // call 0: tool request → agent executes the tool, loops
    // call 1: 5 text deltas → no tools called, agent exits loop
    //
    // max_turns is Some(2), so only provider calls 0 and 1 can happen in
    // this reply(); the cancellation script (call 2) is never reached here.
    // We issue a second reply() below, passing the cancel token, so the
    // provider can trigger cancellation mid-stream.
    let session_config = SessionConfig {
        id: session_id.clone(),
        schedule_id: None,
        max_turns: Some(2),
        retry_config: None,
    };
    let reply_stream = agent
        .reply(
            Message::user().with_text("Do something then say hello"),
            session_config,
            None,
        )
        .await?;
    tokio::pin!(reply_stream);
    // Drain all events; this test only inspects what got persisted.
    while let Some(event) = reply_stream.next().await {
        match event {
            Ok(AgentEvent::Message(_)) => {}
            Ok(_) => {}
            Err(e) => return Err(e),
        }
    }
    // ── Check persisted state after reply 1 ─────────────────
    let reloaded = session_manager.get_session(&session_id, true).await?;
    let messages = reloaded
        .conversation
        .expect("should have conversation")
        .messages()
        .to_vec();
    let user_count = messages.iter().filter(|m| m.role == Role::User).count();
    let asst_count = messages
        .iter()
        .filter(|m| m.role == Role::Assistant)
        .count();
    // Expected: user(prompt) + assistant(tool-req) + user(tool-resp) + assistant(text)
    assert_eq!(
        user_count, 2,
        "Expected 2 user messages (prompt + tool response), got {user_count}",
    );
    assert_eq!(
        asst_count, 2,
        "Expected 2 assistant messages (tool request + text reply), got {asst_count} \
        streaming text deltas are being persisted as separate messages",
    );
    // ── Reply 2: text stream with provider-triggered cancellation (call 2)
    let session_config2 = SessionConfig {
        id: session_id.clone(),
        schedule_id: None,
        max_turns: Some(2),
        retry_config: None,
    };
    let reply_stream2 = agent
        .reply(
            Message::user().with_text("Tell me more"),
            session_config2,
            Some(cancel_token),
        )
        .await?;
    tokio::pin!(reply_stream2);
    while let Some(event) = reply_stream2.next().await {
        match event {
            Ok(_) => {}
            Err(e) => return Err(e),
        }
    }
    // ── Check persisted state after cancellation ────────────
    let reloaded2 = session_manager.get_session(&session_id, true).await?;
    let messages2 = reloaded2
        .conversation
        .expect("should have conversation")
        .messages()
        .to_vec();
    let user_count2 = messages2.iter().filter(|m| m.role == Role::User).count();
    let asst_count2 = messages2
        .iter()
        .filter(|m| m.role == Role::Assistant)
        .count();
    // Reply 2 added 1 user message. The cancelled stream should
    // have persisted at most 1 (partial) assistant message.
    assert_eq!(
        user_count2, 3,
        "Expected 3 user messages (2 from reply 1 + follow-up), got {user_count2}",
    );
    assert!(
        asst_count2 <= 3,
        "Expected at most 3 assistant messages (2 from reply 1 + at most 1 partial \
        from cancelled reply 2), got {asst_count2} \
        streaming deltas are leaking into persistence",
    );
    Ok(())
}
}
}