mirror of
https://github.com/block/goose.git
synced 2026-04-28 03:29:36 +00:00
Optimize tool summarization (#7938)
Some checks failed
Unused Dependencies / machete (push) Has been skipped
Canary / Prepare Version (push) Failing after 4s
Canary / Upload Install Script (push) Has been skipped
Live Provider Tests / check-fork (push) Successful in 2s
CI / changes (push) Failing after 3s
Live Provider Tests / changes (push) Failing after 3s
Canary / bundle-desktop (push) Has been skipped
Canary / bundle-desktop-intel (push) Has been skipped
Canary / bundle-desktop-linux (push) Has been skipped
Canary / bundle-desktop-windows (push) Has been skipped
Canary / build-cli (push) Has been skipped
Canary / Release (push) Has been skipped
Scorecard supply-chain security / Scorecard analysis (push) Has been skipped
Publish Docker Image / docker (push) Failing after 4s
CI / Check Rust Code Format (push) Has been skipped
CI / Build and Test Rust Project (push) Has been skipped
CI / Lint Rust Code (push) Has been skipped
CI / Check OpenAPI Schema is Up-to-Date (push) Has been skipped
Live Provider Tests / Build Binary (push) Has been skipped
Live Provider Tests / Smoke Tests (push) Has been skipped
Live Provider Tests / Smoke Tests (Code Execution) (push) Has been skipped
Live Provider Tests / Compaction Tests (push) Has been skipped
Live Provider Tests / goose server HTTP integration tests (push) Has been skipped
CI / Test and Lint Electron Desktop App (push) Has been cancelled
Some checks failed
Unused Dependencies / machete (push) Has been skipped
Canary / Prepare Version (push) Failing after 4s
Canary / Upload Install Script (push) Has been skipped
Live Provider Tests / check-fork (push) Successful in 2s
CI / changes (push) Failing after 3s
Live Provider Tests / changes (push) Failing after 3s
Canary / bundle-desktop (push) Has been skipped
Canary / bundle-desktop-intel (push) Has been skipped
Canary / bundle-desktop-linux (push) Has been skipped
Canary / bundle-desktop-windows (push) Has been skipped
Canary / build-cli (push) Has been skipped
Canary / Release (push) Has been skipped
Scorecard supply-chain security / Scorecard analysis (push) Has been skipped
Publish Docker Image / docker (push) Failing after 4s
CI / Check Rust Code Format (push) Has been skipped
CI / Build and Test Rust Project (push) Has been skipped
CI / Lint Rust Code (push) Has been skipped
CI / Check OpenAPI Schema is Up-to-Date (push) Has been skipped
Live Provider Tests / Build Binary (push) Has been skipped
Live Provider Tests / Smoke Tests (push) Has been skipped
Live Provider Tests / Smoke Tests (Code Execution) (push) Has been skipped
Live Provider Tests / Compaction Tests (push) Has been skipped
Live Provider Tests / goose server HTTP integration tests (push) Has been skipped
CI / Test and Lint Electron Desktop App (push) Has been cancelled
Co-authored-by: Michael Neale <michael.neale@gmail.com> Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
parent
493566dff2
commit
475968db64
7 changed files with 458 additions and 177 deletions
|
|
@ -363,15 +363,29 @@ impl Agent {
|
|||
self.tool_inspection_manager.apply_tool_annotations(&tools);
|
||||
}
|
||||
|
||||
let tool_call_cut_off = match Config::global().get_param::<usize>("GOOSE_TOOL_CALL_CUTOFF")
|
||||
{
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
let context_limit = self
|
||||
.provider()
|
||||
.await
|
||||
.map(|p| p.get_model_config().context_limit())
|
||||
.unwrap_or(crate::model::DEFAULT_CONTEXT_LIMIT);
|
||||
let compaction_threshold = Config::global()
|
||||
.get_param::<f64>("GOOSE_AUTO_COMPACT_THRESHOLD")
|
||||
.unwrap_or(crate::context_mgmt::DEFAULT_COMPACTION_THRESHOLD);
|
||||
crate::context_mgmt::compute_tool_call_cutoff(context_limit, compaction_threshold)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ReplyContext {
|
||||
conversation,
|
||||
tools,
|
||||
toolshim_tools,
|
||||
system_prompt,
|
||||
goose_mode,
|
||||
tool_call_cut_off: Config::global()
|
||||
.get_param::<usize>("GOOSE_TOOL_CALL_CUTOFF")
|
||||
.unwrap_or(10),
|
||||
tool_call_cut_off,
|
||||
initial_messages,
|
||||
})
|
||||
}
|
||||
|
|
@ -1127,6 +1141,15 @@ impl Agent {
|
|||
});
|
||||
}
|
||||
|
||||
// Count tool calls present before this reply — everything added during
|
||||
// the reply loop is part of the current turn and should not be summarized.
|
||||
let pre_turn_tool_count = conversation
|
||||
.messages()
|
||||
.iter()
|
||||
.flat_map(|m| m.content.iter())
|
||||
.filter(|c| matches!(c, MessageContent::ToolRequest(_)))
|
||||
.count();
|
||||
|
||||
let working_dir = session.working_dir.clone();
|
||||
let reply_stream_span = tracing::info_span!(target: "goose::agents::agent", "reply_stream", session.id = %session_config.id);
|
||||
let inner = Box::pin(async_stream::try_stream! {
|
||||
|
|
@ -1162,13 +1185,6 @@ impl Agent {
|
|||
break;
|
||||
}
|
||||
|
||||
let tool_pair_summarization_task = crate::context_mgmt::maybe_summarize_tool_pair(
|
||||
self.provider().await?,
|
||||
session_config.id.clone(),
|
||||
conversation.clone(),
|
||||
tool_call_cut_off,
|
||||
);
|
||||
|
||||
let conversation_with_moim = super::moim::inject_moim(
|
||||
&session_config.id,
|
||||
conversation.clone(),
|
||||
|
|
@ -1185,6 +1201,20 @@ impl Agent {
|
|||
&toolshim_tools,
|
||||
).await?;
|
||||
|
||||
let current_turn_tool_count = conversation.messages().iter()
|
||||
.flat_map(|m| m.content.iter())
|
||||
.filter(|c| matches!(c, MessageContent::ToolRequest(_)))
|
||||
.count()
|
||||
.saturating_sub(pre_turn_tool_count);
|
||||
|
||||
let tool_pair_summarization_task = crate::context_mgmt::maybe_summarize_tool_pairs(
|
||||
self.provider().await?,
|
||||
session_config.id.clone(),
|
||||
conversation.clone(),
|
||||
tool_call_cut_off,
|
||||
current_turn_tool_count,
|
||||
);
|
||||
|
||||
let mut no_tools_called = true;
|
||||
let mut messages_to_add = Conversation::default();
|
||||
let mut tools_updated = false;
|
||||
|
|
@ -1641,34 +1671,40 @@ impl Agent {
|
|||
}
|
||||
}
|
||||
|
||||
if let Ok(Some((summary_msg, tool_id))) = tool_pair_summarization_task.await {
|
||||
if is_token_cancelled(&cancel_token) {
|
||||
tool_pair_summarization_task.abort();
|
||||
}
|
||||
|
||||
if let Ok(summaries) = tool_pair_summarization_task.await {
|
||||
let mut updated_messages = conversation.messages().clone();
|
||||
|
||||
let matching: Vec<&mut Message> = updated_messages
|
||||
.iter_mut()
|
||||
.filter(|msg| {
|
||||
msg.id.is_some() && msg.content.iter().any(|c| match c {
|
||||
MessageContent::ToolRequest(req) => req.id == tool_id,
|
||||
MessageContent::ToolResponse(resp) => resp.id == tool_id,
|
||||
_ => false,
|
||||
for (summary_msg, tool_id) in summaries {
|
||||
let matching: Vec<&mut Message> = updated_messages
|
||||
.iter_mut()
|
||||
.filter(|msg| {
|
||||
msg.id.is_some() && msg.content.iter().any(|c| match c {
|
||||
MessageContent::ToolRequest(req) => req.id == tool_id,
|
||||
MessageContent::ToolResponse(resp) => resp.id == tool_id,
|
||||
_ => false,
|
||||
})
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect();
|
||||
|
||||
if matching.len() == 2 {
|
||||
for msg in matching {
|
||||
let id = msg.id.as_ref().unwrap();
|
||||
msg.metadata = msg.metadata.with_agent_invisible();
|
||||
SessionManager::update_message_metadata(&session_config.id, id, |metadata| {
|
||||
metadata.with_agent_invisible()
|
||||
}).await?;
|
||||
if matching.len() == 2 {
|
||||
for msg in matching {
|
||||
let id = msg.id.as_ref().unwrap();
|
||||
msg.metadata = msg.metadata.with_agent_invisible();
|
||||
SessionManager::update_message_metadata(&session_config.id, id, |metadata| {
|
||||
metadata.with_agent_invisible()
|
||||
}).await?;
|
||||
}
|
||||
messages_to_add.push(summary_msg);
|
||||
} else {
|
||||
warn!("Expected a tool request/reply pair, but found {} matching messages",
|
||||
matching.len());
|
||||
}
|
||||
conversation = Conversation::new_unvalidated(updated_messages);
|
||||
messages_to_add.push(summary_msg);
|
||||
} else {
|
||||
warn!("Expected a tool request/reply pair, but found {} matching messages",
|
||||
matching.len());
|
||||
}
|
||||
conversation = Conversation::new_unvalidated(updated_messages);
|
||||
}
|
||||
|
||||
for msg in &messages_to_add {
|
||||
|
|
|
|||
|
|
@ -18,10 +18,13 @@ use tracing::log::warn;
|
|||
|
||||
pub const DEFAULT_COMPACTION_THRESHOLD: f64 = 0.8;
|
||||
|
||||
/// Feature flag to enable/disable tool pair summarization.
|
||||
/// Set to `false` to disable summarizing old tool call/response pairs.
|
||||
/// TODO: Re-enable once tool summarization stability issues are resolved.
|
||||
const ENABLE_TOOL_PAIR_SUMMARIZATION: bool = false;
|
||||
const TOOLCALL_SUMMARIZATION_BATCH_SIZE: usize = 10;
|
||||
|
||||
fn tool_pair_summarization_enabled() -> bool {
|
||||
Config::global()
|
||||
.get_param::<bool>("GOOSE_TOOL_PAIR_SUMMARIZATION")
|
||||
.unwrap_or(true)
|
||||
}
|
||||
|
||||
const CONVERSATION_CONTINUATION_TEXT: &str =
|
||||
"Your context was compacted. The previous message contains a summary of the conversation so far.
|
||||
|
|
@ -418,13 +421,24 @@ fn format_message_for_compacting(msg: &Message) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
/// Find the id of a tool call to summarize. We only do this if we have more than
|
||||
/// cutoff tool calls that aren't summarized yet
|
||||
pub fn tool_id_to_summarize(conversation: &Conversation, cutoff: usize) -> Option<String> {
|
||||
pub fn compute_tool_call_cutoff(context_limit: usize, compaction_threshold: f64) -> usize {
|
||||
let threshold = if compaction_threshold > 0.0 && compaction_threshold <= 1.0 {
|
||||
compaction_threshold
|
||||
} else {
|
||||
DEFAULT_COMPACTION_THRESHOLD
|
||||
};
|
||||
let effective_limit = (context_limit as f64 * threshold) as usize;
|
||||
(3 * effective_limit / 20_000).clamp(10, 500)
|
||||
}
|
||||
|
||||
pub fn tool_ids_to_summarize(
|
||||
conversation: &Conversation,
|
||||
cutoff: usize,
|
||||
protect_last_n: usize,
|
||||
) -> Vec<String> {
|
||||
let messages = conversation.messages();
|
||||
|
||||
let mut tool_call_count = 0;
|
||||
let mut first_tool_call_id = None;
|
||||
let mut tool_call_ids: Vec<String> = Vec::new();
|
||||
|
||||
for msg in messages.iter() {
|
||||
if !msg.is_agent_visible() {
|
||||
|
|
@ -433,17 +447,21 @@ pub fn tool_id_to_summarize(conversation: &Conversation, cutoff: usize) -> Optio
|
|||
|
||||
for content in &msg.content {
|
||||
if let MessageContent::ToolRequest(req) = content {
|
||||
if first_tool_call_id.is_none() {
|
||||
first_tool_call_id = Some(req.id.clone());
|
||||
}
|
||||
tool_call_count += 1;
|
||||
if tool_call_count > cutoff {
|
||||
return first_tool_call_id;
|
||||
}
|
||||
tool_call_ids.push(req.id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
|
||||
// Never summarize the last N tool calls (current turn)
|
||||
let eligible = tool_call_ids.len().saturating_sub(protect_last_n);
|
||||
if eligible <= cutoff + TOOLCALL_SUMMARIZATION_BATCH_SIZE {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
tool_call_ids
|
||||
.into_iter()
|
||||
.take(TOOLCALL_SUMMARIZATION_BATCH_SIZE)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn summarize_tool_call(
|
||||
|
|
@ -482,17 +500,16 @@ pub async fn summarize_tool_call(
|
|||
let summarization_request = vec![user_message];
|
||||
|
||||
let system_prompt = indoc! {r#"
|
||||
Your task is to summarize a tool call & response pair to save tokens
|
||||
Your task is to summarize a tool call & response pair to save tokens.
|
||||
|
||||
reply with a single message that describe what happened. Typically a toolcall
|
||||
is asks for something using a bunch of parameters and then the result is also some
|
||||
Reply with a single message that describes what happened. Typically a tool call
|
||||
asks for something using a bunch of parameters and then the result is also some
|
||||
structured output. So the tool might ask to look up something on github and the
|
||||
reply might be a json document. So you could reply with something like:
|
||||
|
||||
"A call to github was made to get the project status"
|
||||
|
||||
if that is what it was.
|
||||
|
||||
"#};
|
||||
|
||||
let (mut response, _) = provider
|
||||
|
|
@ -506,31 +523,30 @@ pub async fn summarize_tool_call(
|
|||
Ok(response.with_generated_id())
|
||||
}
|
||||
|
||||
pub fn maybe_summarize_tool_pair(
|
||||
pub fn maybe_summarize_tool_pairs(
|
||||
provider: Arc<dyn Provider>,
|
||||
session_id: String,
|
||||
conversation: Conversation,
|
||||
cutoff: usize,
|
||||
) -> JoinHandle<Option<(Message, String)>> {
|
||||
protect_last_n: usize,
|
||||
) -> JoinHandle<Vec<(Message, String)>> {
|
||||
tokio::spawn(async move {
|
||||
// Tool pair summarization is currently disabled via feature flag.
|
||||
// See ENABLE_TOOL_PAIR_SUMMARIZATION constant above.
|
||||
if !ENABLE_TOOL_PAIR_SUMMARIZATION {
|
||||
return None;
|
||||
if !tool_pair_summarization_enabled() || provider.manages_own_context() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
if let Some(tool_id) = tool_id_to_summarize(&conversation, cutoff) {
|
||||
let tool_ids = tool_ids_to_summarize(&conversation, cutoff, protect_last_n);
|
||||
let mut results = Vec::new();
|
||||
for tool_id in tool_ids {
|
||||
match summarize_tool_call(provider.as_ref(), &session_id, &conversation, &tool_id).await
|
||||
{
|
||||
Ok(summary) => Some((summary, tool_id)),
|
||||
Ok(summary) => results.push((summary, tool_id)),
|
||||
Err(e) => {
|
||||
warn!("Failed to summarize tool pair: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
results
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -544,6 +560,30 @@ mod tests {
|
|||
use async_trait::async_trait;
|
||||
use rmcp::model::{AnnotateAble, CallToolRequestParams, RawContent, Tool};
|
||||
|
||||
fn create_tool_pair(
|
||||
call_id: &str,
|
||||
response_id: &str,
|
||||
tool_name: &str,
|
||||
response_text: &str,
|
||||
) -> Vec<Message> {
|
||||
vec![
|
||||
Message::assistant()
|
||||
.with_tool_request(
|
||||
call_id,
|
||||
Ok(CallToolRequestParams::new(tool_name.to_string())),
|
||||
)
|
||||
.with_id(call_id),
|
||||
Message::user()
|
||||
.with_tool_response(
|
||||
call_id,
|
||||
Ok(rmcp::model::CallToolResult::success(vec![
|
||||
RawContent::text(response_text).no_annotation(),
|
||||
])),
|
||||
)
|
||||
.with_id(response_id),
|
||||
]
|
||||
}
|
||||
|
||||
struct MockProvider {
|
||||
message: Message,
|
||||
config: ModelConfig,
|
||||
|
|
@ -677,125 +717,86 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_pair_summarization_workflow() {
|
||||
fn create_tool_pair(
|
||||
call_id: &str,
|
||||
response_id: &str,
|
||||
tool_name: &str,
|
||||
response_text: &str,
|
||||
) -> Vec<Message> {
|
||||
vec![
|
||||
Message::assistant()
|
||||
.with_tool_request(
|
||||
call_id,
|
||||
Ok(CallToolRequestParams::new(tool_name.to_string())),
|
||||
)
|
||||
.with_id(call_id),
|
||||
Message::user()
|
||||
.with_tool_response(
|
||||
call_id,
|
||||
Ok(rmcp::model::CallToolResult::success(vec![
|
||||
RawContent::text(response_text).no_annotation(),
|
||||
])),
|
||||
)
|
||||
.with_id(response_id),
|
||||
]
|
||||
#[test]
|
||||
fn test_compute_tool_call_cutoff_scales_with_context() {
|
||||
// Default threshold (0.8)
|
||||
assert_eq!(compute_tool_call_cutoff(128_000, 0.8), 15); // 102K effective
|
||||
assert_eq!(compute_tool_call_cutoff(200_000, 0.8), 24); // 160K effective
|
||||
assert_eq!(compute_tool_call_cutoff(1_000_000, 0.8), 120); // 800K effective
|
||||
// Clamp at minimum
|
||||
assert_eq!(compute_tool_call_cutoff(50_000, 0.8), 10);
|
||||
assert_eq!(compute_tool_call_cutoff(10_000, 0.8), 10);
|
||||
// Clamp at maximum (500)
|
||||
assert_eq!(compute_tool_call_cutoff(10_000_000, 0.8), 500);
|
||||
// Lower compaction threshold means earlier summarization
|
||||
assert_eq!(compute_tool_call_cutoff(200_000, 0.3), 10); // 60K effective
|
||||
assert_eq!(compute_tool_call_cutoff(1_000_000, 0.5), 75); // 500K effective
|
||||
// Invalid threshold falls back to default 0.8
|
||||
assert_eq!(compute_tool_call_cutoff(200_000, 0.0), 24); // falls back to 0.8
|
||||
assert_eq!(compute_tool_call_cutoff(200_000, -1.0), 24); // falls back to 0.8
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_ids_to_summarize_triggers_at_cutoff_plus_batch() {
|
||||
// cutoff=5, so we need >5+10=15 to trigger. 15 exactly should NOT trigger.
|
||||
let mut messages = vec![Message::user().with_text("hello")];
|
||||
for i in 0..15 {
|
||||
messages.extend(create_tool_pair(
|
||||
&format!("call{}", i),
|
||||
&format!("resp{}", i),
|
||||
"read_file",
|
||||
"content",
|
||||
));
|
||||
}
|
||||
let conversation = Conversation::new_unvalidated(messages);
|
||||
let result = tool_ids_to_summarize(&conversation, 5, 0);
|
||||
assert!(result.is_empty(), "Exactly cutoff+batch should not trigger");
|
||||
|
||||
let summary_response = Message::assistant()
|
||||
.with_text("Tool call to list files and response with file listing");
|
||||
let provider = MockProvider::new(summary_response, 1000);
|
||||
|
||||
let mut messages = vec![Message::user().with_text("list files").with_id("msg_1")];
|
||||
messages.extend(create_tool_pair(
|
||||
"call1",
|
||||
"response1",
|
||||
"shell",
|
||||
"file1.txt\nfile2.txt",
|
||||
));
|
||||
messages.extend(create_tool_pair(
|
||||
"call2",
|
||||
"response2",
|
||||
"read_file",
|
||||
"content of file1",
|
||||
));
|
||||
messages.extend(create_tool_pair(
|
||||
"call3",
|
||||
"response3",
|
||||
"read_file",
|
||||
"content of file2",
|
||||
));
|
||||
// 16 tool calls: now exceeds cutoff+10, should return a batch of 10
|
||||
let mut messages = vec![Message::user().with_text("hello")];
|
||||
for i in 0..16 {
|
||||
messages.extend(create_tool_pair(
|
||||
&format!("call{}", i),
|
||||
&format!("resp{}", i),
|
||||
"read_file",
|
||||
"content",
|
||||
));
|
||||
}
|
||||
let conversation = Conversation::new_unvalidated(messages);
|
||||
let result = tool_ids_to_summarize(&conversation, 5, 0);
|
||||
assert_eq!(result.len(), TOOLCALL_SUMMARIZATION_BATCH_SIZE);
|
||||
assert_eq!(result[0], "call0");
|
||||
assert_eq!(result[9], "call9");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_ids_to_summarize_protects_current_turn() {
|
||||
// 20 tool pairs, cutoff=2 → 20 > 12, would normally trigger
|
||||
let mut messages = vec![Message::user().with_text("hello")];
|
||||
for i in 0..20 {
|
||||
messages.extend(create_tool_pair(
|
||||
&format!("call{}", i),
|
||||
&format!("resp{}", i),
|
||||
"read_file",
|
||||
"content",
|
||||
));
|
||||
}
|
||||
let conversation = Conversation::new_unvalidated(messages);
|
||||
|
||||
let result = tool_id_to_summarize(&conversation, 2);
|
||||
// No protection: 20 eligible, 20 > 12 → batch of 10
|
||||
let result = tool_ids_to_summarize(&conversation, 2, 0);
|
||||
assert_eq!(result.len(), TOOLCALL_SUMMARIZATION_BATCH_SIZE);
|
||||
|
||||
// Protect last 8: 12 eligible, 12 <= 12 → nothing
|
||||
let result = tool_ids_to_summarize(&conversation, 2, 8);
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"Should return a pair to summarize when tool calls exceed cutoff"
|
||||
result.is_empty(),
|
||||
"Should not summarize when protected count leaves eligible <= cutoff + batch"
|
||||
);
|
||||
|
||||
let tool_call_id = result.unwrap();
|
||||
assert_eq!(tool_call_id, "call1");
|
||||
|
||||
let summary = summarize_tool_call(&provider, "test-session", &conversation, &tool_call_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(summary.role, Role::User);
|
||||
assert!(summary.metadata.agent_visible);
|
||||
assert!(!summary.metadata.user_visible);
|
||||
|
||||
let mut updated_messages = conversation.messages().clone();
|
||||
for msg in updated_messages.iter_mut() {
|
||||
let has_matching_content = msg.content.iter().any(|c| match c {
|
||||
MessageContent::ToolRequest(req) => req.id == tool_call_id,
|
||||
MessageContent::ToolResponse(resp) => resp.id == tool_call_id,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
if has_matching_content {
|
||||
msg.metadata = msg.metadata.with_agent_invisible();
|
||||
}
|
||||
}
|
||||
|
||||
updated_messages.push(summary);
|
||||
|
||||
let updated_conversation = Conversation::new_unvalidated(updated_messages);
|
||||
let messages = updated_conversation.messages();
|
||||
|
||||
let call1_msg = messages
|
||||
.iter()
|
||||
.find(|m| m.id.as_deref() == Some("call1"))
|
||||
.unwrap();
|
||||
assert!(
|
||||
!call1_msg.is_agent_visible(),
|
||||
"Original call should not be agent visible"
|
||||
);
|
||||
|
||||
let response1_msg = messages
|
||||
.iter()
|
||||
.find(|m| m.id.as_deref() == Some("response1"))
|
||||
.unwrap();
|
||||
assert!(
|
||||
!response1_msg.is_agent_visible(),
|
||||
"Original response should not be agent visible"
|
||||
);
|
||||
|
||||
let summary_msg = messages
|
||||
.iter()
|
||||
.find(|m| {
|
||||
m.metadata.agent_visible
|
||||
&& !m.metadata.user_visible
|
||||
&& m.as_concat_text().contains("Tool call")
|
||||
})
|
||||
.unwrap();
|
||||
assert!(
|
||||
!summary_msg.is_user_visible(),
|
||||
"Summary should not be user visible"
|
||||
);
|
||||
|
||||
let result = tool_id_to_summarize(&updated_conversation, 3);
|
||||
assert!(result.is_none(), "Nothing left to summarize");
|
||||
// Protect last 7: 13 eligible, 13 > 12 → batch of 10
|
||||
let result = tool_ids_to_summarize(&conversation, 2, 7);
|
||||
assert_eq!(result.len(), TOOLCALL_SUMMARIZATION_BATCH_SIZE);
|
||||
assert_eq!(result[0], "call0");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use std::collections::HashMap;
|
|||
use thiserror::Error;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
const DEFAULT_CONTEXT_LIMIT: usize = 128_000;
|
||||
pub const DEFAULT_CONTEXT_LIMIT: usize = 128_000;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct PredefinedModel {
|
||||
|
|
|
|||
|
|
@ -622,6 +622,14 @@ pub trait Provider: Send + Sync {
|
|||
false
|
||||
}
|
||||
|
||||
/// Whether the provider manages its own conversation context (e.g. CLI
|
||||
/// wrappers like Claude Code or Gemini CLI). When true, goose-side
|
||||
/// context management such as tool-pair summarization is skipped because
|
||||
/// the provider's internal state is the source of truth.
|
||||
fn manages_own_context(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
async fn supports_cache_control(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -637,6 +637,10 @@ impl Provider for ClaudeCodeProvider {
|
|||
&self.name
|
||||
}
|
||||
|
||||
fn manages_own_context(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -185,6 +185,10 @@ impl Provider for GeminiCliProvider {
|
|||
&self.name
|
||||
}
|
||||
|
||||
fn manages_own_context(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
self.model.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -495,6 +495,234 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tool_pair_summarization_tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use goose::agents::SessionConfig;
|
||||
use goose::config::base::Config;
|
||||
use goose::config::GooseMode;
|
||||
use goose::conversation::message::Message;
|
||||
use goose::model::ModelConfig;
|
||||
use goose::providers::base::{
|
||||
stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata,
|
||||
ProviderUsage, Usage,
|
||||
};
|
||||
use goose::providers::errors::ProviderError;
|
||||
use goose::session::session_manager::SessionType;
|
||||
use rmcp::model::{AnnotateAble, CallToolRequestParams, CallToolResult, RawContent, Tool};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Mock provider that returns text for the main reply and summaries for
|
||||
/// summarization calls. Distinguishes by checking if tools are empty
|
||||
/// (summarization calls pass no tools).
|
||||
struct SummarizationTestProvider {
|
||||
summary_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl SummarizationTestProvider {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
summary_count: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderDef for SummarizationTestProvider {
|
||||
type Provider = Self;
|
||||
|
||||
fn metadata() -> ProviderMetadata {
|
||||
ProviderMetadata {
|
||||
name: "mock-summarization".to_string(),
|
||||
display_name: "Mock Summarization Provider".to_string(),
|
||||
description: "Mock provider for summarization tests".to_string(),
|
||||
default_model: "mock-model".to_string(),
|
||||
known_models: vec![],
|
||||
model_doc_link: "".to_string(),
|
||||
config_keys: vec![],
|
||||
setup_steps: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
fn from_env(
|
||||
_model: ModelConfig,
|
||||
_extensions: Vec<goose::config::ExtensionConfig>,
|
||||
) -> futures::future::BoxFuture<'static, anyhow::Result<Self>> {
|
||||
Box::pin(async { Ok(Self::new()) })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for SummarizationTestProvider {
|
||||
async fn stream(
|
||||
&self,
|
||||
_model_config: &ModelConfig,
|
||||
_session_id: &str,
|
||||
_system_prompt: &str,
|
||||
_messages: &[Message],
|
||||
tools: &[Tool],
|
||||
) -> Result<MessageStream, ProviderError> {
|
||||
let message = if tools.is_empty() {
|
||||
// Summarization call — return a unique summary
|
||||
let n = self.summary_count.fetch_add(1, Ordering::SeqCst);
|
||||
Message::assistant().with_text(format!("Summary of tool call #{}", n))
|
||||
} else {
|
||||
// Main agent reply — return plain text so the loop exits
|
||||
Message::assistant().with_text("Done processing.")
|
||||
};
|
||||
|
||||
let usage = ProviderUsage::new(
|
||||
"mock-model".to_string(),
|
||||
Usage::new(Some(10), Some(5), Some(15)),
|
||||
);
|
||||
Ok(stream_from_single_message(message, usage))
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
ModelConfig::new("mock-model").unwrap()
|
||||
}
|
||||
|
||||
fn get_name(&self) -> &str {
|
||||
"mock-summarization"
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that batch tool pair summarization preserves all summaries.
|
||||
///
|
||||
/// Pre-populates a session with enough tool call/response pairs to trigger
|
||||
/// batch summarization, runs agent.reply(), then verifies:
|
||||
/// - All 10 summaries are present in the final conversation
|
||||
/// - The original tool pairs are marked invisible
|
||||
#[tokio::test]
|
||||
async fn test_batch_summarization_preserves_all_summaries() -> Result<()> {
|
||||
// Set a low cutoff so we don't need hundreds of tool pairs.
|
||||
// cutoff=2 means we need >2+10=12 visible tool pairs to trigger.
|
||||
Config::global()
|
||||
.set_param("GOOSE_TOOL_CALL_CUTOFF", 2)
|
||||
.unwrap();
|
||||
|
||||
let agent = Agent::new();
|
||||
let session_manager = agent.config.session_manager.clone();
|
||||
let provider = Arc::new(SummarizationTestProvider::new());
|
||||
|
||||
let session = session_manager
|
||||
.create_session(
|
||||
PathBuf::from("."),
|
||||
"summarization-test".to_string(),
|
||||
SessionType::Hidden,
|
||||
GooseMode::Auto,
|
||||
)
|
||||
.await?;
|
||||
|
||||
agent.update_provider(provider, &session.id).await?;
|
||||
|
||||
// Pre-populate: start with a user message, then 13 tool call/response pairs
|
||||
// (need > cutoff + 10 = 12 to trigger batch summarization)
|
||||
let initial_msg = Message::user().with_text("help me read some files");
|
||||
session_manager
|
||||
.add_message(&session.id, &initial_msg)
|
||||
.await?;
|
||||
|
||||
for i in 0..13 {
|
||||
let call_id = format!("precall_{}", i);
|
||||
let req_msg = Message::assistant()
|
||||
.with_tool_request(&call_id, Ok(CallToolRequestParams::new("read_file")))
|
||||
.with_generated_id();
|
||||
session_manager.add_message(&session.id, &req_msg).await?;
|
||||
|
||||
let resp_msg = Message::user()
|
||||
.with_tool_response(
|
||||
&call_id,
|
||||
Ok(CallToolResult::success(vec![RawContent::text(format!(
|
||||
"content of file {}",
|
||||
i
|
||||
))
|
||||
.no_annotation()])),
|
||||
)
|
||||
.with_generated_id();
|
||||
session_manager.add_message(&session.id, &resp_msg).await?;
|
||||
}
|
||||
|
||||
// Send a user message to trigger the reply loop
|
||||
let user_message = Message::user().with_text("summarize what you found");
|
||||
|
||||
let session_config = SessionConfig {
|
||||
id: session.id.clone(),
|
||||
schedule_id: None,
|
||||
max_turns: Some(1),
|
||||
retry_config: None,
|
||||
};
|
||||
|
||||
let reply_stream = agent.reply(user_message, session_config, None).await?;
|
||||
tokio::pin!(reply_stream);
|
||||
|
||||
// Drain the stream
|
||||
while let Some(event) = reply_stream.next().await {
|
||||
match event {
|
||||
Ok(AgentEvent::Message(_)) => {}
|
||||
Ok(_) => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
// Load the final session and inspect the conversation
|
||||
let final_session = session_manager.get_session(&session.id, true).await?;
|
||||
let conversation = final_session
|
||||
.conversation
|
||||
.expect("Session should have a conversation");
|
||||
let messages = conversation.messages();
|
||||
|
||||
// Count summaries: messages that are agent-visible, not user-visible,
|
||||
// and contain our summary text pattern
|
||||
let summaries: Vec<&Message> = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.metadata.agent_visible
|
||||
&& !m.metadata.user_visible
|
||||
&& m.as_concat_text().starts_with("Summary of tool call #")
|
||||
})
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
summaries.len(),
|
||||
10,
|
||||
"Expected 10 summaries (one full batch), got {}. Summary texts: {:?}",
|
||||
summaries.len(),
|
||||
summaries
|
||||
.iter()
|
||||
.map(|m| m.as_concat_text())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
// Verify each summary is unique
|
||||
let summary_texts: std::collections::HashSet<String> =
|
||||
summaries.iter().map(|m| m.as_concat_text()).collect();
|
||||
assert_eq!(summary_texts.len(), 10, "All 10 summaries should be unique");
|
||||
|
||||
// Count invisible tool pairs: original pairs that were summarized
|
||||
// should have agent_visible=false
|
||||
let invisible_tool_msgs: Vec<&Message> = messages
|
||||
.iter()
|
||||
.filter(|m| !m.metadata.agent_visible && (m.is_tool_call() || m.is_tool_response()))
|
||||
.collect();
|
||||
|
||||
// Each summarized pair = 2 invisible messages (request + response)
|
||||
assert_eq!(
|
||||
invisible_tool_msgs.len(),
|
||||
20, // 10 pairs × 2 messages
|
||||
"Expected 20 invisible tool messages (10 summarized pairs), got {}",
|
||||
invisible_tool_msgs.len()
|
||||
);
|
||||
|
||||
// Clean up the config override
|
||||
Config::global().delete("GOOSE_TOOL_CALL_CUTOFF").unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod extension_manager_tests {
|
||||
use super::*;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue