mirror of
https://github.com/block/goose.git
synced 2026-04-28 03:29:36 +00:00
Revert message flush & test (#7966)
All checks were successful
Unused Dependencies / machete (push) Has been skipped
All checks were successful
Unused Dependencies / machete (push) Has been skipped
Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
parent
d26e41bea0
commit
7b64ef1cba
2 changed files with 284 additions and 16 deletions
|
|
@ -1187,6 +1187,7 @@ impl Agent {
|
|||
).await?;
|
||||
|
||||
let mut no_tools_called = true;
|
||||
let mut messages_to_add = Conversation::default();
|
||||
let mut tools_updated = false;
|
||||
let mut did_recovery_compact_this_iteration = false;
|
||||
let mut exit_chat = false;
|
||||
|
|
@ -1241,8 +1242,7 @@ impl Agent {
|
|||
if !text.is_empty() {
|
||||
last_assistant_text = text;
|
||||
}
|
||||
session_manager.add_message(&session_config.id, &response).await?;
|
||||
conversation.push(response);
|
||||
messages_to_add.push(response);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -1438,8 +1438,7 @@ impl Agent {
|
|||
response.created,
|
||||
thinking_content,
|
||||
).with_id(format!("msg_{}", Uuid::new_v4()));
|
||||
session_manager.add_message(&session_config.id, &thinking_msg).await?;
|
||||
conversation.push(thinking_msg);
|
||||
messages_to_add.push(thinking_msg);
|
||||
}
|
||||
|
||||
// Collect reasoning content to attach to tool request messages
|
||||
|
|
@ -1467,14 +1466,11 @@ impl Agent {
|
|||
request.metadata.as_ref(),
|
||||
request.tool_meta.clone(),
|
||||
);
|
||||
messages_to_add.push(request_msg);
|
||||
let final_response = tool_response_messages[idx]
|
||||
.lock().await.clone();
|
||||
// Persist the tool request and response as a pair
|
||||
session_manager.add_message(&session_config.id, &request_msg).await?;
|
||||
session_manager.add_message(&session_config.id, &final_response).await?;
|
||||
conversation.push(request_msg);
|
||||
conversation.push(final_response.clone());
|
||||
yield AgentEvent::Message(final_response);
|
||||
yield AgentEvent::Message(final_response.clone());
|
||||
messages_to_add.push(final_response);
|
||||
} else {
|
||||
error!(
|
||||
"Tool call could not be parsed: {}",
|
||||
|
|
@ -1618,14 +1614,12 @@ impl Agent {
|
|||
Some(None) => {
|
||||
warn!("Final output tool has not been called yet. Continuing agent loop.");
|
||||
let message = Message::user().with_text(FINAL_OUTPUT_CONTINUATION_MESSAGE);
|
||||
session_manager.add_message(&session_config.id, &message).await?;
|
||||
conversation.push(message.clone());
|
||||
messages_to_add.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
}
|
||||
Some(Some(output)) => {
|
||||
let message = Message::assistant().with_text(output);
|
||||
session_manager.add_message(&session_config.id, &message).await?;
|
||||
conversation.push(message.clone());
|
||||
messages_to_add.push(message.clone());
|
||||
yield AgentEvent::Message(message);
|
||||
exit_chat = true;
|
||||
}
|
||||
|
|
@ -1637,6 +1631,7 @@ impl Agent {
|
|||
Ok(should_retry) => {
|
||||
if should_retry {
|
||||
info!("Retry logic triggered, restarting agent loop");
|
||||
messages_to_add = Conversation::default();
|
||||
session_manager.replace_conversation(&session_config.id, &conversation).await?;
|
||||
yield AgentEvent::HistoryReplaced(conversation.clone());
|
||||
} else {
|
||||
|
|
@ -1680,14 +1675,17 @@ impl Agent {
|
|||
}).await?;
|
||||
}
|
||||
conversation = Conversation::new_unvalidated(updated_messages);
|
||||
session_manager.add_message(&session_config.id, &summary_msg).await?;
|
||||
conversation.push(summary_msg);
|
||||
messages_to_add.push(summary_msg);
|
||||
} else {
|
||||
warn!("Expected a tool request/reply pair, but found {} matching messages",
|
||||
matching.len());
|
||||
}
|
||||
}
|
||||
|
||||
for msg in &messages_to_add {
|
||||
session_manager.add_message(&session_config.id, msg).await?;
|
||||
}
|
||||
conversation.extend(messages_to_add);
|
||||
if exit_chat {
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -591,4 +591,274 @@ mod tests {
|
|||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod streaming_persistence_tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use goose::agents::{AgentConfig, SessionConfig};
|
||||
use goose::config::permission::PermissionManager;
|
||||
use goose::config::GooseMode;
|
||||
use goose::conversation::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::{
|
||||
MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage,
|
||||
};
|
||||
use goose::providers::errors::ProviderError;
|
||||
use goose::session::session_manager::SessionType;
|
||||
use goose::session::SessionManager;
|
||||
use rmcp::model::{CallToolRequestParams, Role, Tool};
|
||||
use rmcp::object;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
struct MultiStepProvider {
|
||||
call_count: AtomicUsize,
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl MultiStepProvider {
|
||||
fn new(cancel_token: CancellationToken) -> Self {
|
||||
Self {
|
||||
call_count: AtomicUsize::new(0),
|
||||
cancel_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderDef for MultiStepProvider {
|
||||
type Provider = Self;
|
||||
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata {
|
||||
name: "multi-step-mock".to_string(),
|
||||
display_name: "Multi-Step Mock".to_string(),
|
||||
description: "Mock provider for streaming persistence tests".to_string(),
|
||||
default_model: "mock-model".to_string(),
|
||||
known_models: vec![],
|
||||
model_doc_link: "".to_string(),
|
||||
config_keys: vec![],
|
||||
setup_steps: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn from_env(
|
||||
_model: ModelConfig,
|
||||
_extensions: Vec<goose::config::ExtensionConfig>,
|
||||
) -> futures::future::BoxFuture<'static, anyhow::Result<Self>> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for MultiStepProvider {
|
||||
async fn stream(
|
||||
&self,
|
||||
_model_config: &ModelConfig,
|
||||
_session_id: &str,
|
||||
_system_prompt: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let call = self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
let usage = ProviderUsage::new(
|
||||
"mock-model".to_string(),
|
||||
Usage::new(Some(10), Some(5), Some(15)),
|
||||
);
|
||||
|
||||
match call {
|
||||
0 => {
|
||||
let tool_call = CallToolRequestParams::new("test_tool")
|
||||
.with_arguments(object!({"param": "value"}));
|
||||
let message =
|
||||
Message::assistant().with_tool_request("call_1", Ok(tool_call));
|
||||
let stream =
|
||||
futures::stream::once(async move { Ok((Some(message), Some(usage))) });
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
1 => {
|
||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4());
|
||||
let tokens = vec!["Hello", " world", ", how", " are", " you?"];
|
||||
let stream = futures::stream::iter(tokens.into_iter().enumerate().map(
|
||||
move |(i, token)| {
|
||||
let msg = Message::assistant()
|
||||
.with_text(token)
|
||||
.with_id(msg_id.clone());
|
||||
let u = if i == 4 { Some(usage.clone()) } else { None };
|
||||
Ok((Some(msg), u))
|
||||
},
|
||||
));
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
_ => {
|
||||
let cancel = self.cancel_token.clone();
|
||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4());
|
||||
let tokens = vec!["This ", "should ", "be ", "cancelled ", "soon."];
|
||||
let stream = futures::stream::iter(tokens.into_iter().enumerate().map(
|
||||
move |(i, token)| {
|
||||
if i == 1 {
|
||||
cancel.cancel();
|
||||
}
|
||||
let msg = Message::assistant()
|
||||
.with_text(token)
|
||||
.with_id(msg_id.clone());
|
||||
let u = if i == 4 { Some(usage.clone()) } else { None };
|
||||
Ok((Some(msg), u))
|
||||
},
|
||||
));
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
ModelConfig::new("mock-model").unwrap()
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
"multi-step-mock"
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_text_not_persisted_per_token() -> Result<()> {
|
||||
let cancel_token = CancellationToken::new();
|
||||
let temp_dir = tempfile::tempdir()?;
|
||||
let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf()));
|
||||
let config = AgentConfig::new(
|
||||
session_manager.clone(),
|
||||
PermissionManager::instance(),
|
||||
None,
|
||||
GooseMode::Auto,
|
||||
true, // disable session naming so it doesn't consume a provider call
|
||||
GoosePlatform::GooseCli,
|
||||
);
|
||||
let agent = Agent::with_config(config);
|
||||
let provider = Arc::new(MultiStepProvider::new(cancel_token.clone()));
|
||||
|
||||
let session = session_manager
|
||||
.create_session(
|
||||
PathBuf::default(),
|
||||
"streaming-test".to_string(),
|
||||
SessionType::Hidden,
|
||||
GooseMode::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let session_id = session.id.clone();
|
||||
agent.update_provider(provider, &session_id).await?;
|
||||
|
||||
// ── Single reply: tool call (call 0) → text stream (call 1) → cancelled text (call 2)
|
||||
// max_turns=3 allows all three provider calls within one reply().
|
||||
// call 0: tool call → agent executes tool, loops
|
||||
// call 1: 5 text deltas → no tools called, agent exits loop
|
||||
// call 2: 5 text deltas, cancel token fired after 1st → agent interrupted
|
||||
//
|
||||
// Because call 1 ends the agent loop (no_tools_called=true → exit),
|
||||
// call 2 is NOT reached in the same reply. We issue a second reply()
|
||||
// with the cancel token so the provider triggers cancellation.
|
||||
let session_config = SessionConfig {
|
||||
id: session_id.clone(),
|
||||
schedule_id: None,
|
||||
max_turns: Some(2),
|
||||
retry_config: None,
|
||||
};
|
||||
|
||||
let reply_stream = agent
|
||||
.reply(
|
||||
Message::user().with_text("Do something then say hello"),
|
||||
session_config,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
while let Some(event) = reply_stream.next().await {
|
||||
match event {
|
||||
Ok(AgentEvent::Message(_)) => {}
|
||||
Ok(_) => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Check persisted state after reply 1 ─────────────────
|
||||
let reloaded = session_manager.get_session(&session_id, true).await?;
|
||||
let messages = reloaded
|
||||
.conversation
|
||||
.expect("should have conversation")
|
||||
.messages()
|
||||
.to_vec();
|
||||
|
||||
let user_count = messages.iter().filter(|m| m.role == Role::User).count();
|
||||
let asst_count = messages
|
||||
.iter()
|
||||
.filter(|m| m.role == Role::Assistant)
|
||||
.count();
|
||||
|
||||
// Expected: user(prompt) + assistant(tool-req) + user(tool-resp) + assistant(text)
|
||||
assert_eq!(
|
||||
user_count, 2,
|
||||
"Expected 2 user messages (prompt + tool response), got {user_count}",
|
||||
);
|
||||
assert_eq!(
|
||||
asst_count, 2,
|
||||
"Expected 2 assistant messages (tool request + text reply), got {asst_count} \
|
||||
— streaming text deltas are being persisted as separate messages",
|
||||
);
|
||||
|
||||
// ── Reply 2: text stream with provider-triggered cancellation (call 2)
|
||||
let session_config2 = SessionConfig {
|
||||
id: session_id.clone(),
|
||||
schedule_id: None,
|
||||
max_turns: Some(2),
|
||||
retry_config: None,
|
||||
};
|
||||
|
||||
let reply_stream2 = agent
|
||||
.reply(
|
||||
Message::user().with_text("Tell me more"),
|
||||
session_config2,
|
||||
Some(cancel_token),
|
||||
)
|
||||
.await?;
|
||||
tokio::pin!(reply_stream2);
|
||||
|
||||
while let Some(event) = reply_stream2.next().await {
|
||||
match event {
|
||||
Ok(_) => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Check persisted state after cancellation ────────────
|
||||
let reloaded2 = session_manager.get_session(&session_id, true).await?;
|
||||
let messages2 = reloaded2
|
||||
.conversation
|
||||
.expect("should have conversation")
|
||||
.messages()
|
||||
.to_vec();
|
||||
|
||||
let user_count2 = messages2.iter().filter(|m| m.role == Role::User).count();
|
||||
let asst_count2 = messages2
|
||||
.iter()
|
||||
.filter(|m| m.role == Role::Assistant)
|
||||
.count();
|
||||
|
||||
// Reply 2 added 1 user message. The cancelled stream should
|
||||
// have persisted at most 1 (partial) assistant message.
|
||||
assert_eq!(
|
||||
user_count2, 3,
|
||||
"Expected 3 user messages (2 from reply 1 + follow-up), got {user_count2}",
|
||||
);
|
||||
assert!(
|
||||
asst_count2 <= 3,
|
||||
"Expected at most 3 assistant messages (2 from reply 1 + at most 1 partial \
|
||||
from cancelled reply 2), got {asst_count2} \
|
||||
— streaming deltas are leaking into persistence",
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue