feat(acp): support loading sessions in acp (#5942)

This commit is contained in:
Prem Pillai 2025-12-02 22:20:21 +11:00 committed by GitHub
parent 82686f4465
commit bf188cd9e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 299 additions and 61 deletions

View file

@ -11,7 +11,7 @@ use goose::mcp_utils::ToolResult;
use goose::providers::create; use goose::providers::create;
use goose::session::session_manager::SessionType; use goose::session::session_manager::SessionType;
use goose::session::SessionManager; use goose::session::SessionManager;
use rmcp::model::{Content, RawContent, ResourceContents}; use rmcp::model::{Content, RawContent, ResourceContents, Role};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fs; use std::fs;
use std::sync::Arc; use std::sync::Arc;
@ -570,7 +570,7 @@ impl acp::Agent for GooseAcpAgent {
// Advertise Goose's capabilities // Advertise Goose's capabilities
let agent_capabilities = acp::AgentCapabilities { let agent_capabilities = acp::AgentCapabilities {
load_session: false, // TODO: Implement session persistence load_session: true,
prompt_capabilities: acp::PromptCapabilities { prompt_capabilities: acp::PromptCapabilities {
image: true, // Goose supports image inputs via providers image: true, // Goose supports image inputs via providers
audio: false, // TODO: Add audio support when providers support it audio: false, // TODO: Add audio support when providers support it
@ -638,19 +638,108 @@ impl acp::Agent for GooseAcpAgent {
args: acp::LoadSessionRequest, args: acp::LoadSessionRequest,
) -> Result<acp::LoadSessionResponse, acp::Error> { ) -> Result<acp::LoadSessionResponse, acp::Error> {
info!("ACP: Received load session request {:?}", args); info!("ACP: Received load session request {:?}", args);
// For now, will start a new session. We could use goose session storage as an enhancement
// we would need to map ACP session IDs to goose session ids (which by default are auto generated)
// normal goose session restore in CLI doesn't load conversation visually.
//
// Example flow:
// - Load session file by session_id (might need to map ACP session IDs to Goose session paths)
// - For each message in history:
// - If user message: send user_message_chunk notification
// - If assistant message: send agent_message_chunk notification
// - If tool calls/responses: send appropriate notifications
// For now, we don't support loading previous sessions let session_id = args.session_id.0.to_string();
Err(acp::Error::method_not_found())
let goose_session = SessionManager::get_session(&session_id, true)
.await
.map_err(|e| {
error!("Failed to load session {}: {}", session_id, e);
acp::Error::invalid_params()
})?;
let conversation = goose_session.conversation.ok_or_else(|| {
error!("Session {} has no conversation data", session_id);
acp::Error::internal_error()
})?;
SessionManager::update_session(&session_id)
.working_dir(args.cwd.clone())
.apply()
.await
.map_err(|e| {
error!("Failed to update session working directory: {}", e);
acp::Error::internal_error()
})?;
let mut session = GooseAcpSession {
messages: conversation.clone(),
tool_call_ids: HashMap::new(),
tool_requests: HashMap::new(),
cancel_token: None,
};
// Replay conversation history to client
for message in conversation.messages() {
// Only replay user-visible messages
if !message.metadata.user_visible {
continue;
}
for content_item in &message.content {
match content_item {
MessageContent::Text(text) => {
let update = match message.role {
Role::User => acp::SessionUpdate::UserMessageChunk {
content: text.text.clone().into(),
},
Role::Assistant => acp::SessionUpdate::AgentMessageChunk {
content: text.text.clone().into(),
},
};
let (tx, rx) = oneshot::channel();
self.session_update_tx
.send((
SessionNotification {
session_id: args.session_id.clone(),
update,
meta: None,
},
tx,
))
.map_err(|_| acp::Error::internal_error())?;
rx.await.map_err(|_| acp::Error::internal_error())?;
}
MessageContent::ToolRequest(tool_request) => {
self.handle_tool_request(tool_request, &args.session_id, &mut session)
.await?;
}
MessageContent::ToolResponse(tool_response) => {
self.handle_tool_response(tool_response, &args.session_id, &mut session)
.await?;
}
MessageContent::Thinking(thinking) => {
let (tx, rx) = oneshot::channel();
self.session_update_tx
.send((
SessionNotification {
session_id: args.session_id.clone(),
update: acp::SessionUpdate::AgentThoughtChunk {
content: thinking.thinking.clone().into(),
},
meta: None,
},
tx,
))
.map_err(|_| acp::Error::internal_error())?;
rx.await.map_err(|_| acp::Error::internal_error())?;
}
_ => {
// Ignore other content types
}
}
}
}
let mut sessions = self.sessions.lock().await;
sessions.insert(session_id.clone(), session);
info!("Loaded ACP session {}", session_id);
Ok(acp::LoadSessionResponse {
modes: None,
meta: None,
})
} }
async fn prompt(&self, args: acp::PromptRequest) -> Result<acp::PromptResponse, acp::Error> { async fn prompt(&self, args: acp::PromptRequest) -> Result<acp::PromptResponse, acp::Error> {

View file

@ -2,14 +2,23 @@
""" """
Simple ACP client to test the goose ACP agent. Simple ACP client to test the goose ACP agent.
Connects to goose acp running on stdio. Connects to goose acp running on stdio.
Tests:
1. Initialize - Establish connection and verify capabilities
2. session/new - Create a new session
3. session/prompt - Send a prompt to the session
4. session/load - Load an existing session (new feature)
""" """
import subprocess import subprocess
import json import json
import os
import sys
import time
class AcpClient: class AcpClient:
def __init__(self): def __init__(self):
# Start the goose acp process
self.process = subprocess.Popen( self.process = subprocess.Popen(
['cargo', 'run', '-p', 'goose-cli', '--', 'acp'], ['cargo', 'run', '-p', 'goose-cli', '--', 'acp'],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
@ -19,8 +28,19 @@ class AcpClient:
bufsize=0 bufsize=0
) )
self.request_id = 0 self.request_id = 0
def send_request(self, method, params=None): def send_request(self, method, params=None, collect_notifications=False):
"""Send a request and wait for the response.
Args:
method: The JSON-RPC method name
params: Optional parameters for the request
collect_notifications: If True, collect notifications until response arrives
Returns:
Tuple of (response, notifications) if collect_notifications is True,
otherwise just the response.
"""
self.request_id += 1 self.request_id += 1
request = { request = {
"jsonrpc": "2.0", "jsonrpc": "2.0",
@ -29,22 +49,42 @@ class AcpClient:
} }
if params: if params:
request["params"] = params request["params"] = params
# Send the request
request_str = json.dumps(request) request_str = json.dumps(request)
print(f">>> Sending: {request_str}") print(f">>> Sending: {request_str}")
self.process.stdin.write(request_str + '\n') self.process.stdin.write(request_str + '\n')
self.process.stdin.flush() self.process.stdin.flush()
# Read response notifications = []
response_line = self.process.stdout.readline()
if not response_line: # Read responses until we get one with our request ID
return None while True:
response_line = self.process.stdout.readline()
print(f"<<< Response: {response_line}") if not response_line:
return json.loads(response_line) if collect_notifications:
return None, notifications
return None
response = json.loads(response_line)
# Check if this is a notification (has 'method' but no 'id')
if 'method' in response and 'id' not in response:
print(f"<<< Notification: {response['method']}: {response.get('params', {}).get('update', {}).get('sessionUpdate', 'unknown')}")
if collect_notifications:
notifications.append(response)
continue
if response.get('id') == self.request_id:
print(f"<<< Response: {response_line.strip()}")
if collect_notifications:
return response, notifications
return response
else:
# Response for a different request ID, skip
print(f"<<< Unexpected response ID: {response}")
def initialize(self): def initialize(self):
"""Initialize the ACP connection and verify capabilities."""
return self.send_request("initialize", { return self.send_request("initialize", {
"protocolVersion": "v1", "protocolVersion": "v1",
"clientCapabilities": {}, "clientCapabilities": {},
@ -53,14 +93,33 @@ class AcpClient:
"version": "1.0.0" "version": "1.0.0"
} }
}) })
def new_session(self): def new_session(self, cwd=None):
return self.send_request("newSession", { """Create a new session (session/new)."""
"context": {} params = {
}) "mcpServers": [],
"cwd": cwd or os.getcwd()
}
return self.send_request("session/new", params)
def load_session(self, session_id, cwd=None):
"""Load an existing session (session/load).
Returns: (response, notifications) tuple with session history notifications.
"""
params = {
"sessionId": session_id,
"mcpServers": [],
"cwd": cwd or os.getcwd()
}
return self.send_request("session/load", params, collect_notifications=True)
def prompt(self, session_id, text): def prompt(self, session_id, text):
return self.send_request("prompt", { """Send a prompt to the session (session/prompt).
Returns: (response, notifications) tuple with streaming notifications.
"""
return self.send_request("session/prompt", {
"sessionId": session_id, "sessionId": session_id,
"prompt": [ "prompt": [
{ {
@ -68,48 +127,138 @@ class AcpClient:
"text": text "text": text
} }
] ]
}) }, collect_notifications=True)
def close(self): def close(self):
if self.process: if self.process:
self.process.terminate() self.process.terminate()
self.process.wait() self.process.wait()
def test_new_session(client):
"""Test creating a new session and sending a prompt."""
print("\n" + "="*60)
print("TEST: New Session Flow")
print("="*60)
print("\n2. Creating new session (session/new)...")
session_response = client.new_session()
if session_response and 'result' in session_response:
session_id = session_response['result']['sessionId']
print(f" ✓ Created session: {session_id}")
return session_id
else:
print(f" ✗ Failed to create session: {session_response}")
return None
def test_load_session(client, session_id):
"""Test loading an existing session."""
print("\n" + "="*60)
print("TEST: Load Session Flow")
print("="*60)
print(f"\n4. Loading existing session (session/load) with ID: {session_id}")
load_response, notifications = client.load_session(session_id)
# Show notifications received (these are the session history)
if notifications:
print(f" 📝 Received {len(notifications)} notification(s) (session history replay):")
for n in notifications:
update = n.get('params', {}).get('update', {})
update_type = update.get('sessionUpdate', 'unknown')
content = update.get('content', {})
if isinstance(content, dict):
text = content.get('text', '')[:50]
else:
text = str(content)[:50]
print(f" - {update_type}: {text}...")
if load_response and 'result' in load_response:
print(f" ✓ Session loaded successfully")
print(f" Response: {load_response['result']}")
return True
else:
print(f" ✗ Failed to load session: {load_response}")
return False
def main(): def main():
print("Starting ACP client test...") print("="*60)
print("ACP Client Test Suite")
print("="*60)
print("\nStarting ACP client test...")
client = AcpClient() client = AcpClient()
try: try:
# Initialize the agent
print("\n1. Initializing agent...") print("\n1. Initializing agent...")
init_response = client.initialize() init_response = client.initialize()
if init_response and 'result' in init_response: if init_response and 'result' in init_response:
print(f" Initialized successfully: {init_response['result']}") capabilities = init_response['result'].get('agentCapabilities', {})
print(f" ✓ Initialized successfully")
print(f" - loadSession capability: {capabilities.get('loadSession', False)}")
print(f" - promptCapabilities: {capabilities.get('promptCapabilities', {})}")
if not capabilities.get('loadSession'):
print(" ⚠ Warning: loadSession capability is not advertised")
else: else:
print(f" Failed to initialize: {init_response}") print(f" ✗ Failed to initialize: {init_response}")
return return 1
# Create a new session session_id = test_new_session(client)
print("\n2. Creating new session...") if not session_id:
session_response = client.new_session() return 1
if session_response and 'result' in session_response:
session_id = session_response['result']['sessionId'] print("\n3. Sending prompt (session/prompt)...")
print(f" Created session: {session_id}") prompt_response, notifications = client.prompt(session_id, "Hello! Say 'test successful' if you can hear me.")
if notifications:
print(f" 📝 Received {len(notifications)} streaming notification(s)")
if prompt_response and 'result' in prompt_response:
print(f" ✓ Got response: {prompt_response['result']}")
elif prompt_response and 'error' in prompt_response:
print(f" ✗ Error: {prompt_response['error']}")
else: else:
print(f" Failed to create session: {session_response}") print(f" ✗ Failed to get prompt response: {prompt_response}")
return
# Close the client and start a new one to simulate reconnection
# Send a prompt print("\n--- Simulating client restart ---")
print("\n3. Sending prompt...") client.close()
prompt_response = client.prompt(session_id, "Hello! What is 2 + 2?") time.sleep(1)
if prompt_response:
print(f" Got response: {prompt_response}") client = AcpClient()
print("\n5. Re-initializing after restart...")
init_response = client.initialize()
if init_response and 'result' in init_response:
print(f" ✓ Re-initialized successfully")
else: else:
print(" Failed to get prompt response") print(f" ✗ Failed to re-initialize: {init_response}")
return 1
if not test_load_session(client, session_id):
return 1
print("\n6. Sending prompt to loaded session...")
prompt_response, notifications = client.prompt(session_id, "What was my previous message?")
if notifications:
print(f" 📝 Received {len(notifications)} streaming notification(s)")
if prompt_response and 'result' in prompt_response:
print(f" ✓ Got response: {prompt_response['result']}")
elif prompt_response and 'error' in prompt_response:
print(f" ✗ Error: {prompt_response['error']}")
else:
print(f" ✗ Failed to get prompt response: {prompt_response}")
print("\n" + "="*60)
print("All tests completed!")
print("="*60)
return 0
finally: finally:
client.close() client.close()
print("\nTest complete.") print("\nTest complete.")
if __name__ == "__main__": if __name__ == "__main__":
main() sys.exit(main())