mirror of
https://github.com/zed-industries/zed.git
synced 2026-05-27 08:34:11 +00:00
agent: Improve the subagent task structure (#49629)
Removes tool filtering since this was throwing off certain models, and also allows for more generic task prompts that don't always require summaries. Since the models usually provide a wrap-up message, we don't have to wait for another turn. This also sets us up to allow the agent to re-interact with an existing subagent thread. Release Notes: - N/A --------- Co-authored-by: Jakub Konka <kubkon@jakubkonka.com>
This commit is contained in:
parent
bc31ad4a8c
commit
85c23d0d0b
5 changed files with 64 additions and 597 deletions
|
|
@ -342,7 +342,7 @@ impl NativeAgent {
|
|||
fn register_session(
|
||||
&mut self,
|
||||
thread_handle: Entity<Thread>,
|
||||
allowed_tool_names: Option<Vec<&str>>,
|
||||
allowed_tool_names: Option<Vec<SharedString>>,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Entity<AcpThread> {
|
||||
let connection = Rc::new(NativeAgentConnection(cx.entity()));
|
||||
|
|
@ -1590,7 +1590,6 @@ impl NativeThreadEnvironment {
|
|||
label: String,
|
||||
initial_prompt: String,
|
||||
timeout: Option<Duration>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
cx: &mut App,
|
||||
) -> Result<Rc<dyn SubagentHandle>> {
|
||||
let parent_thread = parent_thread_entity.read(cx);
|
||||
|
|
@ -1602,20 +1601,7 @@ impl NativeThreadEnvironment {
|
|||
MAX_SUBAGENT_DEPTH
|
||||
));
|
||||
}
|
||||
|
||||
let allowed_tools = match allowed_tools {
|
||||
Some(tools) => {
|
||||
let parent_tool_names: std::collections::HashSet<&str> =
|
||||
parent_thread.tools.keys().map(|s| s.as_str()).collect();
|
||||
Some(
|
||||
tools
|
||||
.into_iter()
|
||||
.filter(|t| parent_tool_names.contains(t.as_str()))
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
None => Some(parent_thread.tools.keys().map(|s| s.to_string()).collect()),
|
||||
};
|
||||
let allowed_tool_names = Some(parent_thread.tools.keys().cloned().collect::<Vec<_>>());
|
||||
|
||||
let subagent_thread: Entity<Thread> = cx.new(|cx| {
|
||||
let mut thread = Thread::new_subagent(&parent_thread_entity, cx);
|
||||
|
|
@ -1626,13 +1612,7 @@ impl NativeThreadEnvironment {
|
|||
let session_id = subagent_thread.read(cx).id().clone();
|
||||
|
||||
let acp_thread = agent.update(cx, |agent, cx| {
|
||||
agent.register_session(
|
||||
subagent_thread.clone(),
|
||||
allowed_tools
|
||||
.as_ref()
|
||||
.map(|v| v.iter().map(|s| s.as_str()).collect()),
|
||||
cx,
|
||||
)
|
||||
agent.register_session(subagent_thread.clone(), allowed_tool_names, cx)
|
||||
})?;
|
||||
|
||||
parent_thread_entity.update(cx, |parent_thread, _cx| {
|
||||
|
|
@ -1676,7 +1656,6 @@ impl NativeThreadEnvironment {
|
|||
session_id,
|
||||
subagent_thread,
|
||||
parent_thread: parent_thread_entity.downgrade(),
|
||||
acp_thread,
|
||||
wait_for_prompt_to_complete,
|
||||
}) as _)
|
||||
}
|
||||
|
|
@ -1722,7 +1701,6 @@ impl ThreadEnvironment for NativeThreadEnvironment {
|
|||
label: String,
|
||||
initial_prompt: String,
|
||||
timeout: Option<Duration>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
cx: &mut App,
|
||||
) -> Result<Rc<dyn SubagentHandle>> {
|
||||
Self::create_subagent_thread(
|
||||
|
|
@ -1731,7 +1709,6 @@ impl ThreadEnvironment for NativeThreadEnvironment {
|
|||
label,
|
||||
initial_prompt,
|
||||
timeout,
|
||||
allowed_tools,
|
||||
cx,
|
||||
)
|
||||
}
|
||||
|
|
@ -1748,7 +1725,6 @@ pub struct NativeSubagentHandle {
|
|||
session_id: acp::SessionId,
|
||||
parent_thread: WeakEntity<Thread>,
|
||||
subagent_thread: Entity<Thread>,
|
||||
acp_thread: Entity<AcpThread>,
|
||||
wait_for_prompt_to_complete: Shared<Task<SubagentInitialPromptResult>>,
|
||||
}
|
||||
|
||||
|
|
@ -1757,51 +1733,35 @@ impl SubagentHandle for NativeSubagentHandle {
|
|||
self.session_id.clone()
|
||||
}
|
||||
|
||||
fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
|
||||
fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
|
||||
let thread = self.subagent_thread.clone();
|
||||
let acp_thread = self.acp_thread.clone();
|
||||
let wait_for_prompt = self.wait_for_prompt_to_complete.clone();
|
||||
|
||||
let wait_for_summary_task = cx.spawn(async move |cx| {
|
||||
let timed_out = match wait_for_prompt.await {
|
||||
SubagentInitialPromptResult::Completed => false,
|
||||
SubagentInitialPromptResult::Timeout => true,
|
||||
let subagent_session_id = self.session_id.clone();
|
||||
let parent_thread = self.parent_thread.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
match wait_for_prompt.await {
|
||||
SubagentInitialPromptResult::Completed => {}
|
||||
SubagentInitialPromptResult::Timeout => {
|
||||
return Err(anyhow!("The time to complete the task was exceeded."));
|
||||
}
|
||||
SubagentInitialPromptResult::Cancelled => return Err(anyhow!("User cancelled")),
|
||||
};
|
||||
|
||||
let summary_prompt = if timed_out {
|
||||
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
||||
format!("{}\n{}", "The time to complete the task was exceeded. Stop with the task and follow the directions below:", summary_prompt)
|
||||
} else {
|
||||
summary_prompt
|
||||
};
|
||||
|
||||
let response = acp_thread
|
||||
.update(cx, |thread, cx| thread.send(vec![summary_prompt.into()], cx))
|
||||
.await?;
|
||||
|
||||
let was_canceled = response.is_some_and(|r| r.stop_reason == acp::StopReason::Cancelled);
|
||||
if was_canceled {
|
||||
return Err(anyhow!("User cancelled"));
|
||||
}
|
||||
|
||||
thread.read_with(cx, |thread, _cx| {
|
||||
let result = thread.read_with(cx, |thread, _cx| {
|
||||
thread
|
||||
.last_message()
|
||||
.map(|m| m.to_markdown())
|
||||
.context("No response from subagent")
|
||||
})
|
||||
});
|
||||
});
|
||||
|
||||
let subagent_session_id = self.session_id.clone();
|
||||
let parent_thread = self.parent_thread.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
let result = wait_for_summary_task.await;
|
||||
parent_thread
|
||||
.update(cx, |parent_thread, cx| {
|
||||
parent_thread.unregister_running_subagent(&subagent_session_id, cx)
|
||||
})
|
||||
.ok();
|
||||
|
||||
result
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ impl SubagentHandle for FakeSubagentHandle {
|
|||
self.session_id.clone()
|
||||
}
|
||||
|
||||
fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
|
||||
fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
|
||||
let task = self.wait_for_summary_task.clone();
|
||||
cx.background_spawn(async move { Ok(task.await) })
|
||||
}
|
||||
|
|
@ -208,7 +208,6 @@ impl crate::ThreadEnvironment for FakeThreadEnvironment {
|
|||
_label: String,
|
||||
_initial_prompt: String,
|
||||
_timeout_ms: Option<Duration>,
|
||||
_allowed_tools: Option<Vec<String>>,
|
||||
_cx: &mut App,
|
||||
) -> Result<Rc<dyn SubagentHandle>> {
|
||||
Ok(self
|
||||
|
|
@ -255,7 +254,6 @@ impl crate::ThreadEnvironment for MultiTerminalEnvironment {
|
|||
_label: String,
|
||||
_initial_prompt: String,
|
||||
_timeout: Option<Duration>,
|
||||
_allowed_tools: Option<Vec<String>>,
|
||||
_cx: &mut App,
|
||||
) -> Result<Rc<dyn SubagentHandle>> {
|
||||
unimplemented!()
|
||||
|
|
@ -4234,10 +4232,8 @@ async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
|
|||
model.send_last_completion_stream_text_chunk("spawning subagent");
|
||||
let subagent_tool_input = SubagentToolInput {
|
||||
label: "label".to_string(),
|
||||
task_prompt: "subagent task prompt".to_string(),
|
||||
summary_prompt: "subagent summary prompt".to_string(),
|
||||
timeout_ms: None,
|
||||
allowed_tools: None,
|
||||
prompt: "subagent task prompt".to_string(),
|
||||
timeout: None,
|
||||
};
|
||||
let subagent_tool_use = LanguageModelToolUse {
|
||||
id: "subagent_1".into(),
|
||||
|
|
@ -4276,11 +4272,6 @@ async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
|
|||
|
||||
cx.run_until_parked();
|
||||
|
||||
model.send_last_completion_stream_text_chunk("subagent summary response");
|
||||
model.end_last_completion_stream();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
assert_eq!(
|
||||
subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
|
||||
indoc! {"
|
||||
|
|
@ -4292,14 +4283,6 @@ async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
|
|||
|
||||
subagent task response
|
||||
|
||||
## User
|
||||
|
||||
subagent summary prompt
|
||||
|
||||
## Assistant
|
||||
|
||||
subagent summary response
|
||||
|
||||
"}
|
||||
);
|
||||
|
||||
|
|
@ -4325,8 +4308,8 @@ async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
|
|||
|
||||
```json
|
||||
{{
|
||||
"subagent_session_id": "{}",
|
||||
"summary": "subagent summary response\n"
|
||||
"session_id": "{}",
|
||||
"output": "subagent task response\n"
|
||||
}}
|
||||
```
|
||||
|
||||
|
|
@ -4399,10 +4382,8 @@ async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAp
|
|||
model.send_last_completion_stream_text_chunk("spawning subagent");
|
||||
let subagent_tool_input = SubagentToolInput {
|
||||
label: "label".to_string(),
|
||||
task_prompt: "subagent task prompt".to_string(),
|
||||
summary_prompt: "subagent summary prompt".to_string(),
|
||||
timeout_ms: None,
|
||||
allowed_tools: None,
|
||||
prompt: "subagent task prompt".to_string(),
|
||||
timeout: None,
|
||||
};
|
||||
let subagent_tool_use = LanguageModelToolUse {
|
||||
id: "subagent_1".into(),
|
||||
|
|
@ -4479,153 +4460,6 @@ async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAp
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_subagent_tool_call_cancellation_during_summary_prompt(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
cx.update(|cx| {
|
||||
LanguageModelRegistry::test(cx);
|
||||
});
|
||||
cx.update(|cx| {
|
||||
cx.update_flags(true, vec!["subagents".to_string()]);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/",
|
||||
json!({
|
||||
"a": {
|
||||
"b.md": "Lorem"
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
|
||||
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
||||
let agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
thread_store.clone(),
|
||||
Templates::new(),
|
||||
None,
|
||||
fs.clone(),
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let connection = Rc::new(NativeAgentConnection(agent.clone()));
|
||||
|
||||
let acp_thread = cx
|
||||
.update(|cx| {
|
||||
connection
|
||||
.clone()
|
||||
.new_session(project.clone(), Path::new(""), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
|
||||
let thread = agent.read_with(cx, |agent, _| {
|
||||
agent.sessions.get(&session_id).unwrap().thread.clone()
|
||||
});
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
|
||||
// Ensure empty threads are not saved, even if they get mutated.
|
||||
thread.update(cx, |thread, cx| {
|
||||
thread.set_model(model.clone(), cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
|
||||
let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
|
||||
cx.run_until_parked();
|
||||
model.send_last_completion_stream_text_chunk("spawning subagent");
|
||||
let subagent_tool_input = SubagentToolInput {
|
||||
label: "label".to_string(),
|
||||
task_prompt: "subagent task prompt".to_string(),
|
||||
summary_prompt: "subagent summary prompt".to_string(),
|
||||
timeout_ms: None,
|
||||
allowed_tools: None,
|
||||
};
|
||||
let subagent_tool_use = LanguageModelToolUse {
|
||||
id: "subagent_1".into(),
|
||||
name: SubagentTool::NAME.into(),
|
||||
raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
|
||||
input: serde_json::to_value(&subagent_tool_input).unwrap(),
|
||||
is_input_complete: true,
|
||||
thought_signature: None,
|
||||
};
|
||||
model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
|
||||
subagent_tool_use,
|
||||
));
|
||||
model.end_last_completion_stream();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let subagent_session_id = thread.read_with(cx, |thread, cx| {
|
||||
thread
|
||||
.running_subagent_ids(cx)
|
||||
.get(0)
|
||||
.expect("subagent thread should be running")
|
||||
.clone()
|
||||
});
|
||||
let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
|
||||
agent
|
||||
.sessions
|
||||
.get(&subagent_session_id)
|
||||
.expect("subagent session should exist")
|
||||
.acp_thread
|
||||
.clone()
|
||||
});
|
||||
|
||||
model.send_last_completion_stream_text_chunk("subagent task response");
|
||||
model.end_last_completion_stream();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
send.await.unwrap();
|
||||
|
||||
acp_thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(thread.status(), ThreadStatus::Idle);
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
Prompt
|
||||
|
||||
## Assistant
|
||||
|
||||
spawning subagent
|
||||
|
||||
**Tool Call: label**
|
||||
Status: Canceled
|
||||
|
||||
"}
|
||||
);
|
||||
});
|
||||
subagent_acp_thread.read_with(cx, |thread, cx| {
|
||||
assert_eq!(thread.status(), ThreadStatus::Idle);
|
||||
assert_eq!(
|
||||
thread.to_markdown(cx),
|
||||
indoc! {"
|
||||
## User
|
||||
|
||||
subagent task prompt
|
||||
|
||||
## Assistant
|
||||
|
||||
subagent task response
|
||||
|
||||
## User
|
||||
|
||||
subagent summary prompt
|
||||
|
||||
"}
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
|
@ -4818,105 +4652,6 @@ async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
|
|||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
always_allow_tools(cx);
|
||||
|
||||
cx.update(|cx| {
|
||||
cx.update_flags(true, vec!["subagents".to_string()]);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
||||
let native_agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
thread_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs,
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let parent_thread = cx.new(|cx| {
|
||||
Thread::new(
|
||||
project.clone(),
|
||||
project_context,
|
||||
context_server_registry,
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
let subagent_handle = cx
|
||||
.update(|cx| {
|
||||
NativeThreadEnvironment::create_subagent_thread(
|
||||
native_agent.downgrade(),
|
||||
parent_thread.clone(),
|
||||
"some title".to_string(),
|
||||
"task prompt".to_string(),
|
||||
Some(Duration::from_millis(10)),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.expect("Failed to create subagent");
|
||||
|
||||
let summary_task =
|
||||
subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
{
|
||||
let messages = model.pending_completions().last().unwrap().messages.clone();
|
||||
// Ensure that model received a system prompt
|
||||
assert_eq!(messages[0].role, Role::System);
|
||||
// Ensure that model received a task prompt
|
||||
assert_eq!(messages[1].role, Role::User);
|
||||
assert_eq!(
|
||||
messages[1].content,
|
||||
vec![MessageContent::Text("task prompt".to_string())]
|
||||
);
|
||||
}
|
||||
|
||||
model.send_last_completion_stream_text_chunk("Some task response...");
|
||||
model.end_last_completion_stream();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
{
|
||||
let messages = model.pending_completions().last().unwrap().messages.clone();
|
||||
assert_eq!(messages[2].role, Role::Assistant);
|
||||
assert_eq!(
|
||||
messages[2].content,
|
||||
vec![MessageContent::Text("Some task response...".to_string())]
|
||||
);
|
||||
// Ensure that model received a summary prompt
|
||||
assert_eq!(messages[3].role, Role::User);
|
||||
assert_eq!(
|
||||
messages[3].content,
|
||||
vec![MessageContent::Text("summary prompt".to_string())]
|
||||
);
|
||||
}
|
||||
|
||||
model.send_last_completion_stream_text_chunk("Some summary...");
|
||||
model.end_last_completion_stream();
|
||||
|
||||
let result = summary_task.await;
|
||||
assert_eq!(result.unwrap(), "Some summary...\n");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
|
||||
cx: &mut TestAppContext,
|
||||
|
|
@ -4967,15 +4702,13 @@ async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceede
|
|||
parent_thread.clone(),
|
||||
"some title".to_string(),
|
||||
"task prompt".to_string(),
|
||||
Some(Duration::from_millis(100)),
|
||||
None,
|
||||
Some(Duration::from_secs(1)),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.expect("Failed to create subagent");
|
||||
|
||||
let summary_task =
|
||||
subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
|
||||
let summary_task = subagent_handle.wait_for_output(&cx.to_async());
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
|
|
@ -4991,29 +4724,16 @@ async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceede
|
|||
}
|
||||
|
||||
// Don't complete the initial model stream — let the timeout expire instead.
|
||||
cx.executor().advance_clock(Duration::from_millis(200));
|
||||
cx.executor().advance_clock(Duration::from_secs(2));
|
||||
cx.run_until_parked();
|
||||
|
||||
// After the timeout fires, the thread is cancelled and context_low_prompt is sent
|
||||
// instead of the summary_prompt.
|
||||
{
|
||||
let messages = model.pending_completions().last().unwrap().messages.clone();
|
||||
let last_user_message = messages
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|m| m.role == Role::User)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
last_user_message.content,
|
||||
vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
|
||||
);
|
||||
}
|
||||
|
||||
model.send_last_completion_stream_text_chunk("Some context low response...");
|
||||
model.end_last_completion_stream();
|
||||
|
||||
let result = summary_task.await;
|
||||
assert_eq!(result.unwrap(), "Some context low response...\n");
|
||||
let error = summary_task.await.unwrap_err();
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"The time to complete the task was exceeded."
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
|
@ -5068,7 +4788,6 @@ async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
|
|||
"some title".to_string(),
|
||||
"task prompt".to_string(),
|
||||
Some(Duration::from_millis(10)),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
|
|
@ -5089,77 +4808,6 @@ async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
|
|||
assert!(tools.contains(&"list_directory".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
always_allow_tools(cx);
|
||||
|
||||
cx.update(|cx| {
|
||||
cx.update_flags(true, vec!["subagents".to_string()]);
|
||||
});
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
|
||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||
let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
||||
cx.update(LanguageModelRegistry::test);
|
||||
let model = Arc::new(FakeLanguageModel::default());
|
||||
let thread_store = cx.new(|cx| ThreadStore::new(cx));
|
||||
let native_agent = NativeAgent::new(
|
||||
project.clone(),
|
||||
thread_store,
|
||||
Templates::new(),
|
||||
None,
|
||||
fs,
|
||||
&mut cx.to_async(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let parent_thread = cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
project.clone(),
|
||||
project_context,
|
||||
context_server_registry,
|
||||
Templates::new(),
|
||||
Some(model.clone()),
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(ListDirectoryTool::new(project.clone()), None);
|
||||
thread.add_tool(GrepTool::new(project.clone()), None);
|
||||
thread
|
||||
});
|
||||
|
||||
let _subagent_handle = cx
|
||||
.update(|cx| {
|
||||
NativeThreadEnvironment::create_subagent_thread(
|
||||
native_agent.downgrade(),
|
||||
parent_thread.clone(),
|
||||
"some title".to_string(),
|
||||
"task prompt".to_string(),
|
||||
Some(Duration::from_millis(10)),
|
||||
Some(vec!["grep".to_string()]),
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.expect("Failed to create subagent");
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let tools = model
|
||||
.pending_completions()
|
||||
.last()
|
||||
.unwrap()
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(tools, vec!["grep"]);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
|
|
|||
|
|
@ -601,7 +601,7 @@ pub trait TerminalHandle {
|
|||
|
||||
pub trait SubagentHandle {
|
||||
fn id(&self) -> acp::SessionId;
|
||||
fn wait_for_summary(&self, summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>>;
|
||||
fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>>;
|
||||
}
|
||||
|
||||
pub trait ThreadEnvironment {
|
||||
|
|
@ -619,7 +619,6 @@ pub trait ThreadEnvironment {
|
|||
label: String,
|
||||
initial_prompt: String,
|
||||
timeout: Option<Duration>,
|
||||
allowed_tools: Option<Vec<String>>,
|
||||
cx: &mut App,
|
||||
) -> Result<Rc<dyn SubagentHandle>>;
|
||||
}
|
||||
|
|
@ -1327,7 +1326,7 @@ impl Thread {
|
|||
|
||||
pub fn add_default_tools(
|
||||
&mut self,
|
||||
allowed_tool_names: Option<Vec<&str>>,
|
||||
allowed_tool_names: Option<Vec<SharedString>>,
|
||||
environment: Rc<dyn ThreadEnvironment>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
|
|
@ -1421,8 +1420,14 @@ impl Thread {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn add_tool<T: AgentTool>(&mut self, tool: T, allowed_tool_names: Option<&Vec<&str>>) {
|
||||
if allowed_tool_names.is_some_and(|tool_names| !tool_names.contains(&T::NAME)) {
|
||||
pub fn add_tool<T: AgentTool>(
|
||||
&mut self,
|
||||
tool: T,
|
||||
allowed_tool_names: Option<&Vec<SharedString>>,
|
||||
) {
|
||||
if allowed_tool_names
|
||||
.is_some_and(|tool_names| !tool_names.iter().any(|x| x.as_str() == T::NAME))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use acp_thread::SUBAGENT_SESSION_ID_META_KEY;
|
||||
use agent_client_protocol as acp;
|
||||
use anyhow::{Result, anyhow};
|
||||
use gpui::{App, Entity, SharedString, Task, WeakEntity};
|
||||
use gpui::{App, SharedString, Task, WeakEntity};
|
||||
use language_model::LanguageModelToolResultContent;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -10,64 +10,42 @@ use std::{rc::Rc, time::Duration};
|
|||
|
||||
use crate::{AgentTool, Thread, ThreadEnvironment, ToolCallEventStream};
|
||||
|
||||
/// Spawns a subagent with its own context window to perform a delegated task.
|
||||
/// Spawns an agent to perform a delegated task.
|
||||
///
|
||||
/// Use this tool when you want to do any of the following:
|
||||
/// - Perform an investigation where all you need to know is the outcome, not the research that led to that outcome.
|
||||
/// - Complete a self-contained task where you need to know if it succeeded or failed (and how), but none of its intermediate output.
|
||||
/// - Run multiple tasks in parallel that would take significantly longer to run sequentially.
|
||||
///
|
||||
/// You control what the subagent does by providing:
|
||||
/// 1. A task prompt describing what the subagent should do
|
||||
/// 2. A summary prompt that tells the subagent how to summarize its work when done
|
||||
/// 3. A "context running out" prompt for when the subagent is low on tokens
|
||||
/// You control what the agent does by providing a prompt describing what the agent should do. The agent has access to the same tools you do.
|
||||
///
|
||||
/// Each subagent has access to the same tools you do. You can optionally restrict
|
||||
/// which tools each subagent can use.
|
||||
/// You will receive the agent's final message.
|
||||
///
|
||||
/// Note:
|
||||
/// - Maximum 8 subagents can run in parallel
|
||||
/// - Subagents cannot use tools you don't have access to
|
||||
/// - If spawning multiple subagents that might write to the filesystem, provide
|
||||
/// guidance on how to avoid conflicts (e.g. assign each to different directories)
|
||||
/// - Instruct subagents to be concise in their summaries to conserve your context
|
||||
/// - Agents cannot use tools you don't have access to.
|
||||
/// - If spawning multiple agents that might write to the filesystem, provide guidance on how to avoid conflicts (e.g. assign each to different directories)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SubagentToolInput {
|
||||
/// Short label displayed in the UI while the subagent runs (e.g., "Researching alternatives")
|
||||
/// Short label displayed in the UI while the agent runs (e.g., "Researching alternatives")
|
||||
pub label: String,
|
||||
|
||||
/// The initial prompt that tells the subagent what task to perform.
|
||||
/// Be specific about what you want the subagent to accomplish.
|
||||
pub task_prompt: String,
|
||||
|
||||
/// The prompt sent to the subagent when it completes its task, asking it
|
||||
/// to summarize what it did and return results. This summary becomes the
|
||||
/// tool result you receive.
|
||||
///
|
||||
/// Example: "Summarize what you found, listing the top 3 alternatives with pros/cons."
|
||||
pub summary_prompt: String,
|
||||
|
||||
/// Optional: Maximum runtime in milliseconds. If exceeded, the subagent is
|
||||
/// asked to summarize and return. No timeout by default.
|
||||
/// The prompt that tells the agent what task to perform. Be specific about what you want the agent to accomplish.
|
||||
pub prompt: String,
|
||||
/// Optional: Maximum runtime in seconds. No timeout by default.
|
||||
#[serde(default)]
|
||||
pub timeout_ms: Option<u64>,
|
||||
|
||||
/// Optional: List of tool names the subagent is allowed to use.
|
||||
/// If not provided, the subagent can use all tools available to the parent.
|
||||
/// Tools listed here must be a subset of the parent's available tools.
|
||||
#[serde(default)]
|
||||
pub allowed_tools: Option<Vec<String>>,
|
||||
pub timeout: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct SubagentToolOutput {
|
||||
pub subagent_session_id: acp::SessionId,
|
||||
pub summary: String,
|
||||
pub session_id: acp::SessionId,
|
||||
pub output: String,
|
||||
}
|
||||
|
||||
impl From<SubagentToolOutput> for LanguageModelToolResultContent {
|
||||
fn from(output: SubagentToolOutput) -> Self {
|
||||
output.summary.into()
|
||||
serde_json::to_string(&output)
|
||||
.expect("Failed to serialize SubagentToolOutput")
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -84,32 +62,6 @@ impl SubagentTool {
|
|||
environment,
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_allowed_tools(
|
||||
allowed_tools: &Option<Vec<String>>,
|
||||
parent_thread: &Entity<Thread>,
|
||||
cx: &App,
|
||||
) -> Result<()> {
|
||||
let Some(allowed_tools) = allowed_tools else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let thread = parent_thread.read(cx);
|
||||
let invalid_tools: Vec<_> = allowed_tools
|
||||
.iter()
|
||||
.filter(|tool| !thread.tools.contains_key(tool.as_str()))
|
||||
.map(|s| format!("'{s}'"))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if !invalid_tools.is_empty() {
|
||||
return Err(anyhow!(
|
||||
"The following tools do not exist: {}",
|
||||
invalid_tools.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentTool for SubagentTool {
|
||||
|
|
@ -142,18 +94,11 @@ impl AgentTool for SubagentTool {
|
|||
return Task::ready(Err(anyhow!("Parent thread no longer exists")));
|
||||
};
|
||||
|
||||
if let Err(e) =
|
||||
Self::validate_allowed_tools(&input.allowed_tools, &parent_thread_entity, cx)
|
||||
{
|
||||
return Task::ready(Err(e));
|
||||
}
|
||||
|
||||
let subagent = match self.environment.create_subagent(
|
||||
parent_thread_entity,
|
||||
input.label,
|
||||
input.task_prompt,
|
||||
input.timeout_ms.map(|ms| Duration::from_millis(ms)),
|
||||
input.allowed_tools,
|
||||
input.prompt,
|
||||
input.timeout.map(|secs| Duration::from_secs(secs)),
|
||||
cx,
|
||||
) {
|
||||
Ok(subagent) => subagent,
|
||||
|
|
@ -170,10 +115,10 @@ impl AgentTool for SubagentTool {
|
|||
event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta));
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let summary = subagent.wait_for_summary(input.summary_prompt, cx).await?;
|
||||
let output = subagent.wait_for_output(cx).await?;
|
||||
Ok(SubagentToolOutput {
|
||||
subagent_session_id,
|
||||
summary,
|
||||
session_id: subagent_session_id,
|
||||
output,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
@ -185,102 +130,12 @@ impl AgentTool for SubagentTool {
|
|||
event_stream: ToolCallEventStream,
|
||||
_cx: &mut App,
|
||||
) -> Result<()> {
|
||||
event_stream.subagent_spawned(output.subagent_session_id.clone());
|
||||
event_stream.subagent_spawned(output.session_id.clone());
|
||||
let meta = acp::Meta::from_iter([(
|
||||
SUBAGENT_SESSION_ID_META_KEY.into(),
|
||||
output.subagent_session_id.to_string().into(),
|
||||
output.session_id.to_string().into(),
|
||||
)]);
|
||||
event_stream.update_fields_with_meta(acp::ToolCallUpdateFields::new(), Some(meta));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ContextServerRegistry, Templates, Thread};
|
||||
use fs::FakeFs;
|
||||
use gpui::{AppContext as _, TestAppContext};
|
||||
use project::Project;
|
||||
use prompt_store::ProjectContext;
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
async fn create_thread_with_tools(cx: &mut TestAppContext) -> Entity<Thread> {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
});
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(path!("/test"), json!({})).await;
|
||||
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
|
||||
let project_context = cx.new(|_cx| ProjectContext::default());
|
||||
let context_server_store =
|
||||
project.read_with(cx, |project, _| project.context_server_store());
|
||||
let context_server_registry =
|
||||
cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
|
||||
|
||||
cx.new(|cx| {
|
||||
let mut thread = Thread::new(
|
||||
project,
|
||||
project_context,
|
||||
context_server_registry,
|
||||
Templates::new(),
|
||||
None,
|
||||
cx,
|
||||
);
|
||||
thread.add_tool(crate::NowTool, None);
|
||||
thread.add_tool(crate::WebSearchTool, None);
|
||||
thread
|
||||
})
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_validate_allowed_tools_succeeds_for_valid_tools(cx: &mut TestAppContext) {
|
||||
let thread = create_thread_with_tools(cx).await;
|
||||
|
||||
cx.update(|cx| {
|
||||
assert!(SubagentTool::validate_allowed_tools(&None, &thread, cx).is_ok());
|
||||
|
||||
let valid_tools = Some(vec!["now".to_string()]);
|
||||
assert!(SubagentTool::validate_allowed_tools(&valid_tools, &thread, cx).is_ok());
|
||||
|
||||
let both_tools = Some(vec!["now".to_string(), "web_search".to_string()]);
|
||||
assert!(SubagentTool::validate_allowed_tools(&both_tools, &thread, cx).is_ok());
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_validate_allowed_tools_fails_for_unknown_tools(cx: &mut TestAppContext) {
|
||||
let thread = create_thread_with_tools(cx).await;
|
||||
|
||||
cx.update(|cx| {
|
||||
let unknown_tools = Some(vec!["nonexistent_tool".to_string()]);
|
||||
let result = SubagentTool::validate_allowed_tools(&unknown_tools, &thread, cx);
|
||||
assert!(result.is_err());
|
||||
let error_message = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
error_message.contains("'nonexistent_tool'"),
|
||||
"Expected error to mention the invalid tool name, got: {error_message}"
|
||||
);
|
||||
|
||||
let mixed_tools = Some(vec![
|
||||
"now".to_string(),
|
||||
"fake_tool_a".to_string(),
|
||||
"fake_tool_b".to_string(),
|
||||
]);
|
||||
let result = SubagentTool::validate_allowed_tools(&mixed_tools, &thread, cx);
|
||||
assert!(result.is_err());
|
||||
let error_message = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
error_message.contains("'fake_tool_a'") && error_message.contains("'fake_tool_b'"),
|
||||
"Expected error to mention both invalid tool names, got: {error_message}"
|
||||
);
|
||||
assert!(
|
||||
!error_message.contains("'now'"),
|
||||
"Expected error to not mention valid tool 'now', got: {error_message}"
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -686,7 +686,6 @@ impl agent::ThreadEnvironment for EvalThreadEnvironment {
|
|||
_label: String,
|
||||
_initial_prompt: String,
|
||||
_timeout_ms: Option<Duration>,
|
||||
_allowed_tools: Option<Vec<String>>,
|
||||
_cx: &mut App,
|
||||
) -> Result<Rc<dyn agent::SubagentHandle>> {
|
||||
unimplemented!()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue