This commit is contained in:
taggart_comet 2026-02-03 18:20:22 +02:00
parent bd21ae6296
commit 3fd019bd7a
35 changed files with 633 additions and 343 deletions

View file

@ -82,6 +82,38 @@ impl TracingFacade {
} }
} }
pub fn set_model_config(&mut self, name: impl AsRef<str>, config: HashMap<String, serde_json::Value>) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.model_config = Some(config);
}
}
}
pub fn set_usage(&mut self, name: impl AsRef<str>, input_tokens: u32, output_tokens: u32) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.usage = Some(crate::types::UsageData::new(input_tokens, output_tokens));
}
}
}
pub fn set_input_json(&mut self, name: impl AsRef<str>, input: serde_json::Value) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.input = Some(vec![input]);
}
}
}
pub fn set_output_json(&mut self, name: impl AsRef<str>, output: serde_json::Value) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.output = Some(vec![output]);
}
}
}
pub async fn end(&mut self) { pub async fn end(&mut self) {
for (_, mut span) in self.open_spans.drain() { for (_, mut span) in self.open_spans.drain() {
span.mark_ended(); span.mark_ended();

View file

@ -17,6 +17,7 @@ pub use session::{Session, SessionRequest};
pub use startup::StartupService; pub use startup::StartupService;
pub use user_settings::UserSettings; pub use user_settings::UserSettings;
pub use workflow::{CancellationToken, Chain}; pub use workflow::{CancellationToken, Chain};
#[allow(unused_imports)]
pub use todo::{TodoList, TodoItem}; pub use todo::{TodoList, TodoItem};
/// Model type enum matching the inference engine types /// Model type enum matching the inference engine types
@ -41,4 +42,4 @@ impl ModelType {
_ => None, _ => None,
} }
} }
} }

View file

@ -401,12 +401,17 @@ mod tests {
"read_only" "read_only"
} }
fn parse_input(&self, _input: String) -> Option<Error> { fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
None None
} }
fn work(&self, _request: &dyn Request) -> ToolResult { fn work(&self, _request: &dyn Request) -> ToolResult {
ToolResult::ok("read_only".to_string(), String::new(), String::new()) ToolResult::ok(
"read_only".to_string(),
String::new(),
String::new(),
String::new(),
)
} }
fn parameters(&self) -> serde_json::Value { fn parameters(&self) -> serde_json::Value {
@ -439,12 +444,17 @@ mod tests {
"write_tool" "write_tool"
} }
fn parse_input(&self, _input: String) -> Option<Error> { fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
None None
} }
fn work(&self, _request: &dyn Request) -> ToolResult { fn work(&self, _request: &dyn Request) -> ToolResult {
ToolResult::ok("write_tool".to_string(), String::new(), String::new()) ToolResult::ok(
"write_tool".to_string(),
String::new(),
String::new(),
String::new(),
)
} }
fn parameters(&self) -> serde_json::Value { fn parameters(&self) -> serde_json::Value {
@ -474,12 +484,17 @@ mod tests {
"command_tool" "command_tool"
} }
fn parse_input(&self, _input: String) -> Option<Error> { fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
None None
} }
fn work(&self, _request: &dyn Request) -> ToolResult { fn work(&self, _request: &dyn Request) -> ToolResult {
ToolResult::ok("command_tool".to_string(), String::new(), String::new()) ToolResult::ok(
"command_tool".to_string(),
String::new(),
String::new(),
String::new(),
)
} }
fn parameters(&self) -> serde_json::Value { fn parameters(&self) -> serde_json::Value {
@ -955,4 +970,4 @@ mod tests {
assert_eq!(decision, PermissionDecision::AlwaysAllow); assert_eq!(decision, PermissionDecision::AlwaysAllow);
} }
} }

View file

@ -1,7 +1,6 @@
use super::types::{Permission, PermissionDecision, PermissionScope}; use super::types::{Permission, PermissionDecision, PermissionScope};
use crate::infrastructure::db::DbPool; use crate::infrastructure::db::DbPool;
use rusqlite::params; use rusqlite::params;
use std::path::PathBuf;
use thiserror::Error; use thiserror::Error;
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -121,16 +120,19 @@ impl PermissionStore for SqlitePermissionStore {
) -> Result<Option<Permission>, StoreError> { ) -> Result<Option<Permission>, StoreError> {
// Build query dynamically to handle NULL values properly // Build query dynamically to handle NULL values properly
// In SQL, NULL != '', so we need to use IS NULL for empty strings // In SQL, NULL != '', so we need to use IS NULL for empty strings
let mut param_num = 3;
let command_clause = if command_pattern.is_empty() { let command_clause = if command_pattern.is_empty() {
"command_pattern IS NULL" "command_pattern IS NULL".to_string()
} else { } else {
"command_pattern = ?3" let clause = format!("command_pattern = ?{}", param_num);
param_num += 1;
clause
}; };
let resource_clause = if resource_pattern.is_empty() { let resource_clause = if resource_pattern.is_empty() {
"resource_pattern IS NULL" "resource_pattern IS NULL".to_string()
} else { } else {
"resource_pattern = ?4" format!("resource_pattern = ?{}", param_num)
}; };
let query = format!( let query = format!(
@ -151,32 +153,28 @@ impl PermissionStore for SqlitePermissionStore {
))) )))
})?; })?;
// Build params based on whether patterns are empty // Build params list dynamically to match the query
let result = if command_pattern.is_empty() && resource_pattern.is_empty() { let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = vec![
conn.query_row( Box::new(project_id),
&query, Box::new(tool.to_string()),
params![project_id, tool], ];
|row| self.row_to_permission(row),
) if !command_pattern.is_empty() {
} else if command_pattern.is_empty() { param_values.push(Box::new(command_pattern.to_string()));
conn.query_row( }
&query,
params![project_id, tool, resource_pattern], if !resource_pattern.is_empty() {
|row| self.row_to_permission(row), param_values.push(Box::new(resource_pattern.to_string()));
) }
} else if resource_pattern.is_empty() {
conn.query_row( let params_refs: Vec<&dyn rusqlite::ToSql> = param_values
&query, .iter()
params![project_id, tool, command_pattern], .map(|p| p.as_ref())
|row| self.row_to_permission(row), .collect();
)
} else { let result = conn.query_row(&query, params_refs.as_slice(), |row| {
conn.query_row( self.row_to_permission(row)
&query, });
params![project_id, tool, command_pattern, resource_pattern],
|row| self.row_to_permission(row),
)
};
match result { match result {
Ok(permission) => Ok(Some(permission)), Ok(permission) => Ok(Some(permission)),
@ -186,48 +184,140 @@ impl PermissionStore for SqlitePermissionStore {
} }
} }
impl SqlitePermissionStore { #[cfg(test)]
fn find_matching_command_permission( mod tests {
&self, use super::*;
rows: rusqlite::MappedRows< use r2d2_sqlite::SqliteConnectionManager;
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>,
>,
tool: &str,
command: &str,
) -> Result<Option<Permission>, StoreError> {
for row_result in rows {
match row_result {
Ok(permission) => {
if permission.matches(tool, Some(command), None::<&PathBuf>) {
return Ok(Some(permission));
}
}
Err(e) => return Err(StoreError::Database(e)),
}
}
Ok(None) fn setup_test_db() -> DbPool {
let manager = SqliteConnectionManager::memory();
let pool = r2d2::Pool::new(manager).unwrap();
let conn = pool.get().unwrap();
// Create permissions table
conn.execute(
"CREATE TABLE IF NOT EXISTS permissions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tool_name TEXT NOT NULL,
command_pattern TEXT,
resource_pattern TEXT,
decision TEXT NOT NULL,
scope TEXT NOT NULL,
project_id INTEGER,
created_at TEXT NOT NULL
)",
[],
).unwrap();
pool
} }
fn find_matching_path_permission( #[test]
&self, fn test_find_permission_with_null_patterns() {
rows: rusqlite::MappedRows< let pool = setup_test_db();
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>, let store = SqlitePermissionStore::new(pool.clone());
>,
tool: &str,
path: &PathBuf,
) -> Result<Option<Permission>, StoreError> {
for row_result in rows {
match row_result {
Ok(permission) => {
if permission.matches(tool, None, Some(path)) {
return Ok(Some(permission));
}
}
Err(e) => return Err(StoreError::Database(e)),
}
}
Ok(None) // Create permission with NULL command and resource patterns
let permission = Permission::new(
"test_tool".to_string(),
None,
None,
PermissionDecision::AlwaysAllow,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with empty strings
let result = store.find_permission("test_tool", 1, "", "").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
} }
}
#[test]
fn test_find_permission_with_command_only() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with command but NULL resource
let permission = Permission::new(
"test_tool".to_string(),
Some("echo test".to_string()),
None,
PermissionDecision::AlwaysDeny,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with command and empty resource
let result = store.find_permission("test_tool", 1, "echo test", "").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
}
#[test]
fn test_find_permission_with_resource_only() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with resource but NULL command
let permission = Permission::new(
"test_tool".to_string(),
None,
Some("/path/to/file".to_string()),
PermissionDecision::AlwaysAllow,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with empty command and resource
let result = store.find_permission("test_tool", 1, "", "/path/to/file").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
}
#[test]
fn test_find_permission_with_both_patterns() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with both patterns
let permission = Permission::new(
"test_tool".to_string(),
Some("rm -rf".to_string()),
Some("/etc/passwd".to_string()),
PermissionDecision::AlwaysDeny,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with both patterns
let result = store.find_permission("test_tool", 1, "rm -rf", "/etc/passwd").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
}
#[test]
fn test_find_permission_no_match() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with command
let permission = Permission::new(
"test_tool".to_string(),
Some("echo test".to_string()),
None,
PermissionDecision::AlwaysAllow,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should NOT find when searching with different command
let result = store.find_permission("test_tool", 1, "rm -rf", "").unwrap();
assert!(result.is_none());
}
}

View file

@ -29,6 +29,7 @@ pub struct Permission {
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
} }
#[allow(dead_code)]
impl Permission { impl Permission {
pub fn new( pub fn new(
tool_name: String, tool_name: String,
@ -173,4 +174,4 @@ impl Default for PermissionConfig {
require_confirmation: true, require_confirmation: true,
} }
} }
} }

View file

@ -32,33 +32,38 @@ pub fn get_bt_tree_step_prompt(
pub fn get_system_prompt(model_type: ModelType, agent_mode: AgentModeType, remaining_calls: usize) -> String { pub fn get_system_prompt(model_type: ModelType, agent_mode: AgentModeType, remaining_calls: usize) -> String {
let (os_name, shell_name) = get_runtime_environment(); let (os_name, shell_name) = get_runtime_environment();
let mut system_prompt = "".to_string();
if agent_mode == AgentModeType::Plan { if agent_mode == AgentModeType::Plan {
system_prompt = _system_prompt_for_plan(model_type); let mut system_prompt = _system_prompt_for_plan(model_type);
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); if remaining_calls < 3 {
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
}
return system_prompt; return system_prompt;
} }
if agent_mode == AgentModeType::BuildFromPlan { if agent_mode == AgentModeType::BuildFromPlan {
system_prompt = _system_prompt_for_build_from_plan(model_type); let mut system_prompt = _system_prompt_for_build_from_plan(model_type);
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); if remaining_calls < 3 {
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
}
return system_prompt; return system_prompt;
} }
if model_type == ModelType::OpenAI { let mut system_prompt = if model_type == ModelType::OpenAI {
system_prompt = format!( format!(
"You are Drastis, a coding agent. \n\ "You are Drastis, a coding agent. \n\
Use the available tools to gather context and make changes. \ Use the available tools to gather context and make changes. \
When using tools, pass JSON arguments that match their parameters. \n\ When using tools, pass JSON arguments that match their parameters. \n\
Runtime: os={}, shell={}.", Runtime: os={}, shell={}.",
os_name, shell_name os_name, shell_name
); )
} else { } else {
system_prompt = format!( format!(
"You are Drastis, a coding agent. Use available tools to gather context and make changes. Be concise and accurate. Runtime: os={}, shell={}.", "You are Drastis, a coding agent. Use available tools to gather context and make changes. Be concise and accurate. Runtime: os={}, shell={}.",
os_name, shell_name os_name, shell_name
); )
};
if remaining_calls < 3 {
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
} }
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); system_prompt
return system_prompt;
} }
fn _system_prompt_for_plan(model_type: ModelType) -> String { fn _system_prompt_for_plan(model_type: ModelType) -> String {

View file

@ -8,6 +8,7 @@ use std::path::Path;
/// without being tightly coupled to the Session entity. /// without being tightly coupled to the Session entity.
pub trait Request { pub trait Request {
/// Get the history of previous requests /// Get the history of previous requests
#[allow(dead_code)]
fn history(&self) -> &[SessionRequest]; fn history(&self) -> &[SessionRequest];
/// Get the current request prompt /// Get the current request prompt
@ -40,4 +41,4 @@ pub trait Request {
/// Get the session's TODO list (plan) /// Get the session's TODO list (plan)
fn get_session_plan(&self) -> Option<crate::domain::todo::TodoList>; fn get_session_plan(&self) -> Option<crate::domain::todo::TodoList>;
} }

View file

@ -4,6 +4,7 @@ use crate::repository::SessionRequestRow;
/// Domain entity representing a user request within a session. /// Domain entity representing a user request within a session.
/// Each request contains the user's prompt and the resulting summary. /// Each request contains the user's prompt and the resulting summary.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SessionRequest { pub struct SessionRequest {
prompt: String, prompt: String,
result_summary: Option<String>, result_summary: Option<String>,
@ -32,4 +33,4 @@ impl SessionRequest {
mode: row.mode, mode: row.mode,
} }
} }
} }

View file

@ -1,6 +1,7 @@
use super::Error; use super::Error;
use crate::domain::prompting::session_naming_prompt; use crate::domain::prompting::session_naming_prompt;
use crate::domain::workflow::Chain; use crate::domain::workflow::Chain;
use crate::domain::workflow::ChainStep;
use crate::domain::{Project, Session, SessionRequest}; use crate::domain::{Project, Session, SessionRequest};
use crate::infrastructure::db::DbPool; use crate::infrastructure::db::DbPool;
use crate::infrastructure::inference::InferenceEngine; use crate::infrastructure::inference::InferenceEngine;
@ -83,38 +84,37 @@ impl StartupService {
let prompt_preview: String = first_prompt.chars().take(100).collect(); let prompt_preview: String = first_prompt.chars().take(100).collect();
let naming_prompt = session_naming_prompt(self.engine.get_type(), &prompt_preview); let naming_prompt = session_naming_prompt(self.engine.get_type(), &prompt_preview);
let chain = Chain::new(); let mut chain = Chain::new();
let session_name = chain
match self .steps
.engine .push(ChainStep::user_message(naming_prompt.parse().unwrap(), Vec::new()));
.generate("", &naming_prompt, 15, &[], &chain, &[], None) let session_name = match self.engine.generate(&[], &chain, &[], None) {
{ Ok(raw) => {
Ok(raw) => { log::debug!("Raw session name response: {:?}", raw.summary);
log::debug!("Raw session name response: {:?}", raw.summary); // Clean up the response
// Clean up the response let cleaned = raw
let cleaned = raw .summary
.summary .lines()
.lines() .next()
.next() .unwrap_or("")
.unwrap_or("") .trim()
.trim() .trim_matches('"')
.trim_matches('"') .trim_matches('*')
.trim_matches('*') .replace("<|im_end|>", "")
.replace("<|im_end|>", "") .trim()
.trim() .to_string();
.to_string();
if cleaned.is_empty() || cleaned.len() > 50 { if cleaned.is_empty() || cleaned.len() > 50 {
Self::fallback_session_name(first_prompt)
} else {
cleaned
}
}
Err(e) => {
log::warn!("Session naming failed: {}", e);
Self::fallback_session_name(first_prompt) Self::fallback_session_name(first_prompt)
} else {
cleaned
} }
}; }
Err(e) => {
log::warn!("Session naming failed: {}", e);
Self::fallback_session_name(first_prompt)
}
};
// Create session in database // Create session in database
let conn = self let conn = self
@ -225,4 +225,4 @@ mod tests {
assert_eq!(StartupService::fallback_session_name(""), "New Session"); assert_eq!(StartupService::fallback_session_name(""), "New Session");
} }
} }

View file

@ -15,6 +15,7 @@ pub struct DiscoverObjects {
struct DiscoverObjectsInput { struct DiscoverObjectsInput {
raw: String, raw: String,
full_path_to_file: String, full_path_to_file: String,
call_id: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -69,6 +70,7 @@ impl DiscoverObjects {
Ok(DiscoverObjectsInput { Ok(DiscoverObjectsInput {
raw: raw.to_string(), raw: raw.to_string(),
full_path_to_file: parsed.full_path_to_file, full_path_to_file: parsed.full_path_to_file,
call_id: String::new(),
}) })
} }
@ -86,12 +88,13 @@ impl Tool for DiscoverObjects {
"discover_objects" "discover_objects"
} }
fn parse_input(&self, input: String) -> Option<Error> { fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim(); let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed); let parsed = Self::parse_input_json(trimmed);
match parsed { match parsed {
Ok(parsed) => { Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed); *self.input.lock().unwrap() = Some(parsed);
None None
} }
@ -103,17 +106,18 @@ impl Tool for DiscoverObjects {
let input = match self.load_input() { let input = match self.load_input() {
Ok(input) => input, Ok(input) => input,
Err(e) => { Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
} }
}; };
match Self::parse_file(&input.full_path_to_file) { match Self::parse_file(&input.full_path_to_file) {
Ok((lang, objects)) => ToolResult::ok( Ok((lang, objects)) => ToolResult::ok(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
Self::format_output(lang, &objects), Self::format_output(lang, &objects),
input.call_id,
), ),
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()), Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
} }
} }

View file

@ -19,6 +19,7 @@ struct FindFilesInput {
query: String, query: String,
root: Option<String>, root: Option<String>,
max_results: usize, max_results: usize,
call_id: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -132,6 +133,7 @@ impl FindFiles {
query: parsed.query, query: parsed.query,
root: parsed.root, root: parsed.root,
max_results: parsed.max_results.unwrap_or(20), max_results: parsed.max_results.unwrap_or(20),
call_id: String::new(),
}) })
} }
@ -149,12 +151,13 @@ impl Tool for FindFiles {
"find_files" "find_files"
} }
fn parse_input(&self, input: String) -> Option<Error> { fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim(); let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed); let parsed = Self::parse_input_json(trimmed);
match parsed { match parsed {
Ok(parsed) => { Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed); *self.input.lock().unwrap() = Some(parsed);
None None
} }
@ -166,7 +169,7 @@ impl Tool for FindFiles {
let input = match self.load_input() { let input = match self.load_input() {
Ok(input) => input, Ok(input) => input,
Err(e) => { Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
} }
}; };
@ -174,15 +177,16 @@ impl Tool for FindFiles {
let search_root = let search_root =
match Self::resolve_search_root(input.root.as_deref(), request.project_root()) { match Self::resolve_search_root(input.root.as_deref(), request.project_root()) {
Ok(root) => root, Ok(root) => root,
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e), Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
}; };
// Check if search root exists // Check if search root exists
if !search_root.exists() { if !search_root.exists() {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Search root does not exist: {}", search_root.display()), format!("Search root does not exist: {}", search_root.display()),
input.call_id,
); );
} }
@ -204,7 +208,7 @@ impl Tool for FindFiles {
output output
}; };
ToolResult::ok(self.name().to_string(), input.raw, output) ToolResult::ok(self.name().to_string(), input.raw, output, input.call_id)
} }
fn parameters(&self) -> serde_json::Value { fn parameters(&self) -> serde_json::Value {

View file

@ -45,6 +45,7 @@ pub struct FileChange {
pub struct ToolResult { pub struct ToolResult {
tool_name: String, tool_name: String,
call_id: String,
pub(crate) input: String, pub(crate) input: String,
is_successful: bool, is_successful: bool,
output: String, output: String,
@ -53,7 +54,7 @@ pub struct ToolResult {
} }
impl ToolResult { impl ToolResult {
pub fn ok(tool_name: String, input: String, output: String) -> Self { pub fn ok(tool_name: String, input: String, output: String, call_id: String) -> Self {
Self { Self {
tool_name, tool_name,
input, input,
@ -61,10 +62,11 @@ impl ToolResult {
output, output,
error_message: String::new(), error_message: String::new(),
file_changes: None, file_changes: None,
call_id
} }
} }
pub fn error(tool_name: String, input: String, message: String) -> Self { pub fn error(tool_name: String, input: String, message: String, call_id: String) -> Self {
Self { Self {
tool_name, tool_name,
input, input,
@ -72,6 +74,7 @@ impl ToolResult {
output: String::new(), output: String::new(),
error_message: message.into(), error_message: message.into(),
file_changes: None, file_changes: None,
call_id
} }
} }
@ -128,11 +131,15 @@ impl ToolResult {
pub fn file_changes(&self) -> Option<&[FileChange]> { pub fn file_changes(&self) -> Option<&[FileChange]> {
self.file_changes.as_deref() self.file_changes.as_deref()
} }
pub(crate) fn call_id(&self) -> String {
self.call_id.clone()
}
} }
pub trait Tool: Send + Sync { pub trait Tool: Send + Sync {
fn name(&self) -> &'static str; fn name(&self) -> &'static str;
fn parse_input(&self, input: String) -> Option<Error>; fn parse_input(&self, input: String, call_id: String) -> Option<Error>;
fn work(&self, request: &dyn Request) -> ToolResult; fn work(&self, request: &dyn Request) -> ToolResult;
fn parameters(&self) -> Value; fn parameters(&self) -> Value;
fn desc(&self) -> String; fn desc(&self) -> String;

View file

@ -18,6 +18,7 @@ pub struct PatchFiles {
struct PatchFilesInput { struct PatchFilesInput {
raw: String, raw: String,
patch: String, patch: String,
call_id: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -41,6 +42,7 @@ impl PatchFiles {
Ok(PatchFilesInput { Ok(PatchFilesInput {
raw: raw.to_string(), raw: raw.to_string(),
patch: parsed.patch, patch: parsed.patch,
call_id: String::new(),
}) })
} }
@ -58,12 +60,13 @@ impl Tool for PatchFiles {
"patch_files" "patch_files"
} }
fn parse_input(&self, input: String) -> Option<Error> { fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim(); let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed); let parsed = Self::parse_input_json(trimmed);
match parsed { match parsed {
Ok(parsed) => { Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed); *self.input.lock().unwrap() = Some(parsed);
None None
} }
@ -75,7 +78,7 @@ impl Tool for PatchFiles {
let input = match self.load_input() { let input = match self.load_input() {
Ok(input) => input, Ok(input) => input,
Err(e) => { Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
} }
}; };
@ -84,8 +87,9 @@ impl Tool for PatchFiles {
Err(e) => { Err(e) => {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Failed to parse patch: {}", e), format!("Failed to parse patch: {}", e),
input.call_id.clone(),
) )
} }
}; };
@ -115,8 +119,9 @@ impl Tool for PatchFiles {
{ {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Invalid path outside project root: {}", path), format!("Invalid path outside project root: {}", path),
input.call_id.clone(),
); );
} }
let fs_path = project_root.join(rel_path); let fs_path = project_root.join(rel_path);
@ -129,8 +134,9 @@ impl Tool for PatchFiles {
Err(e) => { Err(e) => {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Failed to read file '{}': {}", fs_path.display(), e), format!("Failed to read file '{}': {}", fs_path.display(), e),
input.call_id.clone(),
) )
} }
} }
@ -142,8 +148,9 @@ impl Tool for PatchFiles {
Err(e) => { Err(e) => {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Failed to apply patch: {}", e), format!("Failed to apply patch: {}", e),
input.call_id.clone(),
) )
} }
}; };
@ -155,28 +162,31 @@ impl Tool for PatchFiles {
if let Err(e) = fs::create_dir_all(parent) { if let Err(e) = fs::create_dir_all(parent) {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!( format!(
"Failed to create parent directories for '{}': {}", "Failed to create parent directories for '{}': {}",
fs_path.display(), fs_path.display(),
e e
), ),
input.call_id.clone(),
); );
} }
} }
if let Err(e) = fs::write(&fs_path, content) { if let Err(e) = fs::write(&fs_path, content) {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Failed to write file '{}': {}", fs_path.display(), e), format!("Failed to write file '{}': {}", fs_path.display(), e),
input.call_id.clone(),
); );
} }
} else if fs_path.exists() { } else if fs_path.exists() {
if let Err(e) = fs::remove_file(&fs_path) { if let Err(e) = fs::remove_file(&fs_path) {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Failed to delete file '{}': {}", fs_path.display(), e), format!("Failed to delete file '{}': {}", fs_path.display(), e),
input.call_id.clone(),
); );
} }
} }
@ -224,6 +234,7 @@ impl Tool for PatchFiles {
self.name().to_string(), self.name().to_string(),
input.raw, input.raw,
"Patch applied successfully".to_string(), "Patch applied successfully".to_string(),
input.call_id,
); );
if !file_changes.is_empty() { if !file_changes.is_empty() {
@ -531,7 +542,7 @@ mod tests {
"{{\"patch\":\"{}\"}}", "{{\"patch\":\"{}\"}}",
patch.replace('\n', "\\n").replace('"', "\\\"") patch.replace('\n', "\\n").replace('"', "\\\"")
); );
assert!(tool.parse_input(input).is_none()); assert!(tool.parse_input(input, "call-id".to_string()).is_none());
let result = tool.work(&request); let result = tool.work(&request);
assert!( assert!(
@ -571,7 +582,7 @@ mod tests {
"{{\"patch\":\"{}\"}}", "{{\"patch\":\"{}\"}}",
patch.replace('\n', "\\n").replace('"', "\\\"") patch.replace('\n', "\\n").replace('"', "\\\"")
); );
assert!(tool.parse_input(input).is_none()); assert!(tool.parse_input(input, "call-id".to_string()).is_none());
let result = tool.work(&request); let result = tool.work(&request);
assert!( assert!(
@ -592,7 +603,7 @@ mod tests {
let tool = PatchFiles::new(); let tool = PatchFiles::new();
let input = "{\"patch\":\"not a patch\"}".to_string(); let input = "{\"patch\":\"not a patch\"}".to_string();
assert!(tool.parse_input(input).is_none()); assert!(tool.parse_input(input, "call-id".to_string()).is_none());
let result = tool.work(&request); let result = tool.work(&request);
assert!( assert!(
@ -601,4 +612,4 @@ mod tests {
result.output_string() result.output_string()
); );
} }
} }

View file

@ -18,6 +18,7 @@ struct ReadObjectsInput {
raw: String, raw: String,
full_path_to_file: String, full_path_to_file: String,
queries: Vec<ObjectQuery>, queries: Vec<ObjectQuery>,
call_id: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -138,6 +139,7 @@ impl ReadObjects {
raw: raw.to_string(), raw: raw.to_string(),
full_path_to_file: trimmed_path.to_string(), full_path_to_file: trimmed_path.to_string(),
queries, queries,
call_id: String::new(),
}) })
} }
@ -182,12 +184,13 @@ impl Tool for ReadObjects {
"read_objects" "read_objects"
} }
fn parse_input(&self, input: String) -> Option<Error> { fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim(); let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed); let parsed = Self::parse_input_json(trimmed);
match parsed { match parsed {
Ok(parsed) => { Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed); *self.input.lock().unwrap() = Some(parsed);
None None
} }
@ -199,17 +202,18 @@ impl Tool for ReadObjects {
let input = match self.load_input() { let input = match self.load_input() {
Ok(input) => input, Ok(input) => input,
Err(e) => { Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
} }
}; };
match Self::read_objects(&input.full_path_to_file, &input.queries) { match Self::read_objects(&input.full_path_to_file, &input.queries) {
Ok((lang, results)) => ToolResult::ok( Ok((lang, results)) => ToolResult::ok(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
Self::format_output(lang, results), Self::format_output(lang, results),
input.call_id,
), ),
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()), Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
} }
} }
@ -224,7 +228,7 @@ impl Tool for ReadObjects {
}, },
"query": { "query": {
"type": "string", "type": "string",
"description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\")", "description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\"), grepping is not working here, use exact names",
"minLength": 1 "minLength": 1
} }
}, },
@ -235,8 +239,8 @@ impl Tool for ReadObjects {
fn desc(&self) -> String { fn desc(&self) -> String {
format!( format!(
"Use the `{}` tool to read source code of specific objects from a file. To determine correct properties to use for `{}`, use the `discover_objects` tool first.", "Use the `{}` tool to read source code of specific objects from a file. To determine correct properties, use the `discover_objects` tool first.",
self.name(), self.name() self.name(),
) )
} }
@ -293,7 +297,7 @@ mod tests {
fn test_parse_input_valid() { fn test_parse_input_valid() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": "main, Config"}"#; let input = r#"{"path": "src/main.rs", "query": "main, Config"}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_none(), "Expected no error, got: {:?}", result); assert!(result.is_none(), "Expected no error, got: {:?}", result);
// Verify the parsed input // Verify the parsed input
@ -308,7 +312,7 @@ mod tests {
fn test_parse_input_empty_query() { fn test_parse_input_empty_query() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": ""}"#; let input = r#"{"path": "src/main.rs", "query": ""}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for empty query"); assert!(result.is_some(), "Expected error for empty query");
if let Some(Error::Parse(msg)) = result { if let Some(Error::Parse(msg)) = result {
@ -326,7 +330,7 @@ mod tests {
fn test_parse_input_whitespace_query() { fn test_parse_input_whitespace_query() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": " "}"#; let input = r#"{"path": "src/main.rs", "query": " "}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for whitespace query"); assert!(result.is_some(), "Expected error for whitespace query");
if let Some(Error::Parse(msg)) = result { if let Some(Error::Parse(msg)) = result {
@ -344,7 +348,7 @@ mod tests {
fn test_parse_input_comma_only_query() { fn test_parse_input_comma_only_query() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": ",,,"}"#; let input = r#"{"path": "src/main.rs", "query": ",,,"}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for comma-only query"); assert!(result.is_some(), "Expected error for comma-only query");
if let Some(Error::Parse(msg)) = result { if let Some(Error::Parse(msg)) = result {
@ -362,7 +366,7 @@ mod tests {
fn test_parse_input_missing_path() { fn test_parse_input_missing_path() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "", "query": "main"}"#; let input = r#"{"path": "", "query": "main"}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for empty path"); assert!(result.is_some(), "Expected error for empty path");
if let Some(Error::Parse(msg)) = result { if let Some(Error::Parse(msg)) = result {
@ -407,7 +411,7 @@ mod tests {
fn test_parse_input_malformed_json() { fn test_parse_input_malformed_json() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": "main"#; // Missing closing brace let input = r#"{"path": "src/main.rs", "query": "main"#; // Missing closing brace
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for malformed JSON"); assert!(result.is_some(), "Expected error for malformed JSON");
} }
@ -415,7 +419,7 @@ mod tests {
fn test_parse_input_path_with_whitespace() { fn test_parse_input_path_with_whitespace() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": " src/main.rs ", "query": "main"}"#; let input = r#"{"path": " src/main.rs ", "query": "main"}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!( assert!(
result.is_none(), result.is_none(),
"Expected no error for path with whitespace" "Expected no error for path with whitespace"
@ -432,7 +436,7 @@ mod tests {
fn test_parse_input_query_with_extra_whitespace() { fn test_parse_input_query_with_extra_whitespace() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": " main , Config "}"#; let input = r#"{"path": "src/main.rs", "query": " main , Config "}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!( assert!(
result.is_none(), result.is_none(),
"Expected no error for query with extra whitespace" "Expected no error for query with extra whitespace"
@ -448,7 +452,7 @@ mod tests {
fn test_parse_input_single_object() { fn test_parse_input_single_object() {
let tool = ReadObjects::new(); let tool = ReadObjects::new();
let input = r#"{"path": "src/lib.rs", "query": "MyStruct"}"#; let input = r#"{"path": "src/lib.rs", "query": "MyStruct"}"#;
let result = tool.parse_input(input.to_string()); let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!( assert!(
result.is_none(), result.is_none(),
"Expected no error for single object query" "Expected no error for single object query"
@ -481,4 +485,4 @@ mod tests {
}; };
assert!(query.matches(&obj_partial), "Should match partial name"); assert!(query.matches(&obj_partial), "Should match partial name");
} }
} }

View file

@ -25,6 +25,7 @@ struct ShellExecInputParsed {
raw: String, raw: String,
command: String, command: String,
working_dir: Option<String>, working_dir: Option<String>,
call_id: String,
} }
impl Tool for ShellExec { impl Tool for ShellExec {
@ -32,7 +33,7 @@ impl Tool for ShellExec {
"shell_exec" "shell_exec"
} }
fn parse_input(&self, input: String) -> Option<Error> { fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim(); let trimmed = input.trim();
let parsed = serde_json::from_str::<ShellExecInput>(trimmed) let parsed = serde_json::from_str::<ShellExecInput>(trimmed)
.map_err(|e| Error::Parse(e.to_string())); .map_err(|e| Error::Parse(e.to_string()));
@ -46,6 +47,7 @@ impl Tool for ShellExec {
raw: trimmed.to_string(), raw: trimmed.to_string(),
command: parsed.command, command: parsed.command,
working_dir: parsed.working_dir, working_dir: parsed.working_dir,
call_id,
}); });
None None
} }
@ -61,6 +63,7 @@ impl Tool for ShellExec {
self.name().to_string(), self.name().to_string(),
String::new(), String::new(),
"input not parsed".to_string(), "input not parsed".to_string(),
String::new(),
) )
} }
}; };
@ -72,15 +75,17 @@ impl Tool for ShellExec {
if !path.exists() { if !path.exists() {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Working directory does not exist: {}", dir), format!("Working directory does not exist: {}", dir),
input.call_id.clone(),
); );
} }
if !crate::utils::paths::is_within_root(path, request.project_root()) { if !crate::utils::paths::is_within_root(path, request.project_root()) {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
"Working directory is outside project root".to_string(), "Working directory is outside project root".to_string(),
input.call_id.clone(),
); );
} }
path.to_path_buf() path.to_path_buf()
@ -99,8 +104,9 @@ impl Tool for ShellExec {
Err(e) => { Err(e) => {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Failed to execute command: {}", e), format!("Failed to execute command: {}", e),
input.call_id.clone(),
) )
} }
}; };
@ -129,10 +135,10 @@ impl Tool for ShellExec {
if !stderr.is_empty() { if !stderr.is_empty() {
result.push_str(&format!("[stderr]: {}", stderr)); result.push_str(&format!("[stderr]: {}", stderr));
} }
return ToolResult::error(self.name().to_string(), input.raw, result); return ToolResult::error(self.name().to_string(), input.raw.clone(), result, input.call_id.clone());
}; };
ToolResult::ok(self.name().to_string(), input.raw, result) ToolResult::ok(self.name().to_string(), input.raw, result, input.call_id)
} }
fn parameters(&self) -> serde_json::Value { fn parameters(&self) -> serde_json::Value {
@ -155,8 +161,9 @@ impl Tool for ShellExec {
fn desc(&self) -> String { fn desc(&self) -> String {
format!( format!(
r#"Use `{}` tool to execute shell commands. "Use `{}` tool to execute shell commands.
Please DO NOT use it to read the full content of a file, this is not efficient, use `read_objects` tool for this."#, DO NOT use it to read code, this is not efficient, use `read_objects` tool for this. \n\
DO NOT use it to change files, use `patch_files` tool for this.",
self.name() self.name()
) )
} }
@ -339,7 +346,7 @@ mod tests {
let tool = ShellExec::new(); let tool = ShellExec::new();
assert!(tool assert!(tool
.parse_input(r#"{"command":"echo hello"}"#.to_string()) .parse_input(r#"{"command":"echo hello"}"#.to_string(), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -357,7 +364,7 @@ mod tests {
let tool = ShellExec::new(); let tool = ShellExec::new();
assert!(tool assert!(tool
.parse_input(r#"{"command":"pwd"}"#.to_string()) .parse_input(r#"{"command":"pwd"}"#.to_string(), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -384,7 +391,7 @@ mod tests {
.parse_input(format!( .parse_input(format!(
r#"{{"command":"pwd","working_dir":"{}"}}"#, r#"{{"command":"pwd","working_dir":"{}"}}"#,
subdir.display() subdir.display()
)) ), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -405,7 +412,7 @@ mod tests {
let tool = ShellExec::new(); let tool = ShellExec::new();
assert!(tool assert!(tool
.parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string()) .parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string(), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -424,7 +431,7 @@ mod tests {
let tool = ShellExec::new(); let tool = ShellExec::new();
assert!(tool assert!(tool
.parse_input(r#"{"command":"exit 1"}"#.to_string()) .parse_input(r#"{"command":"exit 1"}"#.to_string(), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -443,7 +450,7 @@ mod tests {
let tool = ShellExec::new(); let tool = ShellExec::new();
assert!(tool assert!(tool
.parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string()) .parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string(), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -457,7 +464,7 @@ mod tests {
let request = TestRequest::new(temp.path()); let request = TestRequest::new(temp.path());
let tool = ShellExec::new(); let tool = ShellExec::new();
let err = tool.parse_input(r#"{"command":""}"#.to_string()); let err = tool.parse_input(r#"{"command":""}"#.to_string(), "call-id".to_string());
assert!(err.is_some()); assert!(err.is_some());
let result = tool.work(&request); let result = tool.work(&request);
assert!(result.output_string().contains("Error")); assert!(result.output_string().contains("Error"));
@ -474,7 +481,7 @@ mod tests {
.parse_input(format!( .parse_input(format!(
r#"{{"command":"echo 'test content' > {}"}}"#, r#"{{"command":"echo 'test content' > {}"}}"#,
file_path.display() file_path.display()
)) ), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -496,7 +503,7 @@ mod tests {
let tool = ShellExec::new(); let tool = ShellExec::new();
assert!(tool assert!(tool
.parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string()) .parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string(), "call-id".to_string())
.is_none()); .is_none());
let result = tool.work(&request); let result = tool.work(&request);
@ -506,4 +513,4 @@ mod tests {
result.output_string() result.output_string()
); );
} }
} }

View file

@ -16,6 +16,7 @@ struct StructureInput {
raw: String, raw: String,
path: String, path: String,
max_depth: usize, max_depth: usize,
call_id: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -127,6 +128,7 @@ impl Structure {
raw: raw.to_string(), raw: raw.to_string(),
path: parsed.path.unwrap_or_else(|| ".".to_string()), path: parsed.path.unwrap_or_else(|| ".".to_string()),
max_depth: parsed.max_depth.unwrap_or(3), max_depth: parsed.max_depth.unwrap_or(3),
call_id: String::new(),
}) })
} }
@ -144,12 +146,13 @@ impl Tool for Structure {
"structure" "structure"
} }
fn parse_input(&self, input: String) -> Option<Error> { fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim(); let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed); let parsed = Self::parse_input_json(trimmed);
match parsed { match parsed {
Ok(parsed) => { Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed); *self.input.lock().unwrap() = Some(parsed);
None None
} }
@ -161,34 +164,36 @@ impl Tool for Structure {
let input = match self.load_input() { let input = match self.load_input() {
Ok(input) => input, Ok(input) => input,
Err(e) => { Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
} }
}; };
let target_path = match Self::resolve_path(&input.path, request.project_root()) { let target_path = match Self::resolve_path(&input.path, request.project_root()) {
Ok(p) => p, Ok(p) => p,
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e), Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
}; };
if !target_path.exists() { if !target_path.exists() {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Path does not exist: {}", target_path.display()), format!("Path does not exist: {}", target_path.display()),
input.call_id.clone(),
); );
} }
if !target_path.is_dir() { if !target_path.is_dir() {
return ToolResult::error( return ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw, input.raw.clone(),
format!("Path is not a directory: {}", target_path.display()), format!("Path is not a directory: {}", target_path.display()),
input.call_id,
); );
} }
let tree = Self::build_tree(&target_path, input.max_depth); let tree = Self::build_tree(&target_path, input.max_depth);
ToolResult::ok(self.name().to_string(), input.raw, tree) ToolResult::ok(self.name().to_string(), input.raw, tree, input.call_id)
} }
fn parameters(&self) -> serde_json::Value { fn parameters(&self) -> serde_json::Value {

View file

@ -18,6 +18,7 @@ pub struct UpdateTodoList {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct UpdateTodoListInput { struct UpdateTodoListInput {
raw: String, raw: String,
call_id: String,
} }
impl UpdateTodoList { impl UpdateTodoList {
@ -36,12 +37,12 @@ impl Tool for UpdateTodoList {
"update_todo_list" "update_todo_list"
} }
fn parse_input(&self, input: String) -> Option<Error> { fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let parsed: Result<TodoList, _> = serde_json::from_str(&input); let parsed: Result<TodoList, _> = serde_json::from_str(&input);
match parsed { match parsed {
Ok(_) => { Ok(_) => {
let mut lock = self.input.lock().unwrap(); let mut lock = self.input.lock().unwrap();
*lock = Some(UpdateTodoListInput { raw: input }); *lock = Some(UpdateTodoListInput { raw: input, call_id });
None None
} }
Err(e) => Some(Error::Parse(format!("Failed to parse input: {}", e))), Err(e) => Some(Error::Parse(format!("Failed to parse input: {}", e))),
@ -57,6 +58,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(), self.name().to_string(),
String::new(), String::new(),
"No input provided".to_string(), "No input provided".to_string(),
String::new(),
) )
} }
}; };
@ -70,6 +72,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
format!("Failed to get database connection: {}", e), format!("Failed to get database connection: {}", e),
input.call_id.clone(),
) )
} }
}; };
@ -83,6 +86,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
format!("Failed to get/create TODO list: {}", e), format!("Failed to get/create TODO list: {}", e),
input.call_id.clone(),
) )
} }
}; };
@ -92,6 +96,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
format!("Failed to update TODO list: {}", e), format!("Failed to update TODO list: {}", e),
input.call_id.clone(),
); );
} }
@ -106,6 +111,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(), self.name().to_string(),
input.raw, input.raw,
"TODO list updated successfully.".to_string(), "TODO list updated successfully.".to_string(),
input.call_id,
) )
} }

View file

@ -16,6 +16,7 @@ struct WebSearchInput {
raw: String, raw: String,
query: String, query: String,
max_results: u32, max_results: u32,
call_id: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -94,7 +95,7 @@ impl Tool for WebSearch {
"web_search" "web_search"
} }
fn parse_input(&self, input: String) -> Option<crate::domain::tools::Error> { fn parse_input(&self, input: String, call_id: String) -> Option<crate::domain::tools::Error> {
let trimmed = input.trim(); let trimmed = input.trim();
match serde_json::from_str::<WebSearchInputJson>(trimmed) { match serde_json::from_str::<WebSearchInputJson>(trimmed) {
Ok(parsed) => { Ok(parsed) => {
@ -103,6 +104,7 @@ impl Tool for WebSearch {
raw: input, raw: input,
query: parsed.query, query: parsed.query,
max_results, max_results,
call_id,
}; };
*self.input.lock().unwrap() = Some(parsed_input); *self.input.lock().unwrap() = Some(parsed_input);
None None
@ -118,7 +120,7 @@ impl Tool for WebSearch {
let input = match self.load_input() { let input = match self.load_input() {
Ok(input) => input, Ok(input) => input,
Err(e) => { Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e); return ToolResult::error(self.name().to_string(), String::new(), e, String::new());
} }
}; };
@ -130,6 +132,7 @@ impl Tool for WebSearch {
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
"User settings not available".to_string(), "User settings not available".to_string(),
input.call_id.clone(),
); );
} }
}; };
@ -139,6 +142,7 @@ impl Tool for WebSearch {
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
"Web search is not enabled in settings".to_string(), "Web search is not enabled in settings".to_string(),
input.call_id.clone(),
); );
} }
@ -149,6 +153,7 @@ impl Tool for WebSearch {
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
"Brave API key not configured".to_string(), "Brave API key not configured".to_string(),
input.call_id.clone(),
); );
} }
}; };
@ -159,12 +164,13 @@ impl Tool for WebSearch {
let output = WebSearchOutput { results }; let output = WebSearchOutput { results };
match serde_json::to_string(&output) { match serde_json::to_string(&output) {
Ok(json_output) => { Ok(json_output) => {
ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output) ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output, input.call_id)
} }
Err(e) => ToolResult::error( Err(e) => ToolResult::error(
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
format!("Failed to serialize output: {}", e), format!("Failed to serialize output: {}", e),
input.call_id.clone(),
), ),
} }
} }
@ -172,6 +178,7 @@ impl Tool for WebSearch {
self.name().to_string(), self.name().to_string(),
input.raw.clone(), input.raw.clone(),
format!("Search failed: {}", e), format!("Search failed: {}", e),
input.call_id,
), ),
} }
} }
@ -593,8 +600,10 @@ mod tests {
user_settings, user_settings,
}; };
let web_search = WebSearch::new(); let web_search = WebSearch::new();
let _ = let _ = web_search.parse_input(
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string()); r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
"call-id".to_string(),
);
let _allowed = checker.check(&web_search, &request, Some(1)).unwrap(); let _allowed = checker.check(&web_search, &request, Some(1)).unwrap();
@ -629,8 +638,10 @@ mod tests {
user_settings, user_settings,
}; };
let web_search = WebSearch::new(); let web_search = WebSearch::new();
let _ = let _ = web_search.parse_input(
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string()); r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
"call-id".to_string(),
);
let allowed = checker.check(&web_search, &request, Some(1)).unwrap(); let allowed = checker.check(&web_search, &request, Some(1)).unwrap();
@ -640,4 +651,4 @@ mod tests {
assert_eq!(created[0].tool_name, "web_search"); assert_eq!(created[0].tool_name, "web_search");
assert_eq!(created[0].resource_pattern, None); assert_eq!(created[0].resource_pattern, None);
} }
} }

View file

@ -17,6 +17,7 @@ pub struct Chain {
pub fail_reason: String, pub fail_reason: String,
#[serde(default)] #[serde(default)]
pub final_message: Option<String>, pub final_message: Option<String>,
pub system_prompt: String,
} }
impl Chain { impl Chain {
@ -28,6 +29,7 @@ impl Chain {
is_failed: false, is_failed: false,
fail_reason: String::new(), fail_reason: String::new(),
final_message: None, final_message: None,
system_prompt: String::new(),
} }
} }
@ -120,6 +122,10 @@ impl Chain {
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n") .join("\n")
} }
pub fn set_system_prompt(&mut self, system_prompt: String) {
self.system_prompt = system_prompt;
}
#[allow(dead_code)] #[allow(dead_code)]
pub fn total_payload_len_chars(&self) -> usize { pub fn total_payload_len_chars(&self) -> usize {

View file

@ -39,6 +39,8 @@ pub struct ChainStep {
#[serde(default)] #[serde(default)]
pub tool_name: Option<String>, pub tool_name: Option<String>,
#[serde(default)] #[serde(default)]
pub call_id: Option<String>,
#[serde(default)]
pub tool_output: Option<String>, pub tool_output: Option<String>,
#[serde(default)] #[serde(default)]
pub is_successful: Option<bool>, pub is_successful: Option<bool>,
@ -60,6 +62,7 @@ impl ChainStep {
let mut tool_output = None; let mut tool_output = None;
let mut is_successful = None; let mut is_successful = None;
let mut file_changes = None; let mut file_changes = None;
let mut call_id = None;
if let Some(tr) = tool_result { if let Some(tr) = tool_result {
summary = tr.summary(); summary = tr.summary();
context_payload = tr.output_string(); context_payload = tr.output_string();
@ -68,6 +71,7 @@ impl ChainStep {
tool_output = Some(tr.output_raw().to_string()); tool_output = Some(tr.output_raw().to_string());
is_successful = Some(tr.is_successful()); is_successful = Some(tr.is_successful());
file_changes = tr.file_changes().map(|fc| fc.to_vec()); file_changes = tr.file_changes().map(|fc| fc.to_vec());
call_id = Some(tr.call_id());
} }
Self { Self {
@ -76,6 +80,7 @@ impl ChainStep {
context_payload, context_payload,
input_payload, input_payload,
tool_name, tool_name,
call_id,
tool_output, tool_output,
is_successful, is_successful,
file_changes, file_changes,
@ -91,6 +96,7 @@ impl ChainStep {
context_payload: raw_output.clone(), context_payload: raw_output.clone(),
input_payload: String::new(), input_payload: String::new(),
tool_name: None, tool_name: None,
call_id: None,
tool_output: Some(raw_output), tool_output: Some(raw_output),
is_successful: Some(true), is_successful: Some(true),
file_changes: None, file_changes: None,
@ -118,6 +124,7 @@ impl ChainStep {
context_payload: prompt.clone(), context_payload: prompt.clone(),
input_payload: prompt, input_payload: prompt,
tool_name: None, tool_name: None,
call_id: None,
tool_output: None, tool_output: None,
is_successful: Some(true), is_successful: Some(true),
file_changes: None, file_changes: None,

View file

@ -41,6 +41,7 @@ impl ToolRunner {
tool_call.name.clone(), tool_call.name.clone(),
tool_call.arguments.clone(), tool_call.arguments.clone(),
error_msg, error_msg,
tool_call.call_id.clone(),
); );
} }
}; };
@ -73,11 +74,13 @@ impl ToolRunner {
tool.name().to_string(), tool.name().to_string(),
String::new(), String::new(),
"Permission denied".to_string(), "Permission denied".to_string(),
String::new(),
), ),
Err(err) => ToolResult::error( Err(err) => ToolResult::error(
tool.name().to_string(), tool.name().to_string(),
String::new(), String::new(),
format!("Permission check error: {}", err), format!("Permission check error: {}", err),
String::new(),
), ),
}; };

View file

@ -91,7 +91,7 @@ pub trait Toolset {
.map(|t| t.as_ref()) .map(|t| t.as_ref())
.ok_or_else(|| Error::Parse(format!("Tool not found: {}", tool_call.name)))?; .ok_or_else(|| Error::Parse(format!("Tool not found: {}", tool_call.name)))?;
if let Some(err) = tool.parse_input(tool_call.arguments.clone()) { if let Some(err) = tool.parse_input(tool_call.arguments.clone(), tool_call.call_id.clone()) {
return Err(err); return Err(err);
} }

View file

@ -5,6 +5,7 @@ use super::{
use crate::domain::bt::GeneralTree; use crate::domain::bt::GeneralTree;
use crate::domain::prompting; use crate::domain::prompting;
use crate::domain::session::Request; use crate::domain::session::Request;
use crate::domain::todo::TodoListStatus;
use crate::infrastructure::db::DbPool; use crate::infrastructure::db::DbPool;
use crate::infrastructure::event_bus::{AgentToUiEvent, StepPhase}; use crate::infrastructure::event_bus::{AgentToUiEvent, StepPhase};
use crate::infrastructure::inference::InferenceEngine; use crate::infrastructure::inference::InferenceEngine;
@ -16,7 +17,6 @@ use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use crossbeam_channel::Sender; use crossbeam_channel::Sender;
use crate::domain::AgentModeType; use crate::domain::AgentModeType;
use crate::domain::todo::TodoListStatus;
/// Main workflow orchestrator that runs LLM-driven coding tasks /// Main workflow orchestrator that runs LLM-driven coding tasks
/// Implements an eternal agent loop that: /// Implements an eternal agent loop that:
@ -94,7 +94,6 @@ impl Workflow {
self.chain.set_todo_list(request.get_session_plan()); self.chain.set_todo_list(request.get_session_plan());
let result = self._run(request, cancel, None, None, mode); let result = self._run(request, cancel, None, None, mode);
self.chain.steps.push(ChainStep::user_message(request.current_request().to_string(), request.images().to_vec()));
self._end_tracing(); self._end_tracing();
result result
} }
@ -145,6 +144,7 @@ impl Workflow {
context_payload: String::new(), context_payload: String::new(),
input_payload: step_user_prompt.to_string(), input_payload: step_user_prompt.to_string(),
tool_name: None, tool_name: None,
call_id: None,
tool_output: None, tool_output: None,
is_successful: Some(true), is_successful: Some(true),
file_changes: None, file_changes: None,
@ -171,6 +171,13 @@ impl Workflow {
mode: AgentModeType, mode: AgentModeType,
) -> Result<(), Error> { ) -> Result<(), Error> {
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request);
let user_prompt = user_prompt_override
.as_deref()
.unwrap_or(&base_user_prompt)
.to_string();
self.chain.steps.push(ChainStep::user_message(user_prompt.clone(), request.images().to_vec()));
// Get max_tool_calls from override (for BT mode) or user settings // Get max_tool_calls from override (for BT mode) or user settings
let max_tool_calls = max_tool_calls_override.unwrap_or_else(|| { let max_tool_calls = max_tool_calls_override.unwrap_or_else(|| {
request.user_settings() request.user_settings()
@ -198,11 +205,7 @@ impl Workflow {
// Get base system prompt and inject remaining count // Get base system prompt and inject remaining count
let system_prompt = prompting::get_system_prompt(self.engine.get_type(), request.mode(), remaining_calls); let system_prompt = prompting::get_system_prompt(self.engine.get_type(), request.mode(), remaining_calls);
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request); self.chain.set_system_prompt(system_prompt.clone());
let user_prompt = user_prompt_override
.as_deref()
.unwrap_or(&base_user_prompt)
.to_string();
// Switch to finishing toolset if approaching limit // Switch to finishing toolset if approaching limit
if !in_finishing_mode && tool_call_count >= finishing_threshold { if !in_finishing_mode && tool_call_count >= finishing_threshold {
@ -220,13 +223,9 @@ impl Workflow {
} }
self._emit_inference_progress(); self._emit_inference_progress();
self._trace_llm_start(user_prompt.clone());
// Ask LLM to choose next tool // Ask LLM to choose next tool
let llm_output = match self.engine.generate( let llm_output = match self.engine.generate(
&system_prompt,
&user_prompt,
1024,
&self.toolset.tool_refs(), &self.toolset.tool_refs(),
&self.chain, &self.chain,
request.images(), request.images(),
@ -238,8 +237,6 @@ impl Workflow {
} }
}; };
self._trace_llm_end(llm_output.raw_output.clone());
// Always capture the assistant's response in the chain // Always capture the assistant's response in the chain
if !llm_output.raw_output.is_empty() { if !llm_output.raw_output.is_empty() {
self.chain.steps.push(ChainStep::assistant_response( self.chain.steps.push(ChainStep::assistant_response(
@ -347,13 +344,6 @@ impl Workflow {
self.chain = Chain::new(); self.chain = Chain::new();
} }
fn _trace_llm_start(&mut self, prompt: String) {
if let Some(tracer) = &mut self.tracer {
tracer.start_span("LLM generation", SpanKind::Generation);
tracer.add_input("LLM generation", prompt);
}
}
fn _emit_inference_progress(&self) { fn _emit_inference_progress(&self) {
let options = [ let options = [
"Thinking.. well kinda..", "Thinking.. well kinda..",
@ -383,18 +373,10 @@ impl Workflow {
}); });
} }
fn _trace_llm_end(&mut self, output: String) {
if let Some(tracer) = &mut self.tracer {
tracer.add_output("LLM generation", output);
tracer.end_span("LLM generation");
}
}
/// Detects if the active TODO item has changed /// Detects if the active TODO item has changed
/// Returns true if the first non-completed item's title is different /// Returns true if the first non-completed item's title is different
fn _did_todo_item_change(&self, previous_todo: &Option<crate::domain::todo::TodoList>) -> bool { fn _did_todo_item_change(&self, previous_todo: &Option<crate::domain::todo::TodoList>) -> bool {
use crate::domain::todo::TodoListStatus;
// Get first non-completed item from previous TODO list // Get first non-completed item from previous TODO list
let prev_active = previous_todo.as_ref().and_then(|list| { let prev_active = previous_todo.as_ref().and_then(|list| {
list.items.iter() list.items.iter()

View file

@ -64,8 +64,6 @@ impl OpenAIClient {
pub fn call_responses_api( pub fn call_responses_api(
&self, &self,
system_prompt: &str,
user_prompt: &str,
tools: &[&dyn crate::domain::tools::Tool], tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain, chain: &crate::domain::workflow::Chain,
images: &[String], images: &[String],
@ -74,8 +72,6 @@ impl OpenAIClient {
let max_attempts = 3; let max_attempts = 3;
for attempt in 1..=max_attempts { for attempt in 1..=max_attempts {
match self.call_responses_api_inner( match self.call_responses_api_inner(
system_prompt,
user_prompt,
tools, tools,
chain, chain,
images, images,
@ -121,25 +117,31 @@ impl OpenAIClient {
fn call_responses_api_inner( fn call_responses_api_inner(
&self, &self,
system_prompt: &str,
user_prompt: &str,
tools: &[&dyn crate::domain::tools::Tool], tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain, chain: &crate::domain::workflow::Chain,
images: &[String], images: &[String],
tracer: Option<&mut openai_agents_tracing::TracingFacade>, mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, Box<dyn Error + Send + Sync>> { ) -> Result<LLMInferenceResult, Box<dyn Error + Send + Sync>> {
let url = "https://api.openai.com/v1/responses"; let url = "https://api.openai.com/v1/responses";
let request_body = build_request_dto( let request_body = build_request_dto(
&self.model, &self.model,
system_prompt,
user_prompt,
images, images,
tools, tools,
chain, chain,
tracer, tracer.as_deref_mut(),
); );
// Start span with model name and add request as JSON
if let Some(tracer) = &mut tracer {
tracer.start_span(&self.model, openai_agents_tracing::SpanKind::Generation);
// Convert request_body to JSON Value and set as input
if let Ok(request_json) = serde_json::to_value(&request_body) {
tracer.set_input_json(&self.model, request_json);
}
}
let response = self let response = self
.client .client
.post(url) .post(url)
@ -151,18 +153,35 @@ impl OpenAIClient {
let body = response.text()?; let body = response.text()?;
if !status.is_success() { if !status.is_success() {
if let Some(t) = tracer {
t.end_span(&self.model);
}
return Err(Box::new(OpenAIClientError::Api { status, body })); return Err(Box::new(OpenAIClientError::Api { status, body }));
} }
let dto = match serde_json::from_str::<ResponseDTO>(&body) { let dto = match serde_json::from_str::<ResponseDTO>(&body) {
Ok(v) => v, Ok(v) => v,
Err(e) => return Err(Box::new(OpenAIClientError::Deserialize { source: e, body })), Err(e) => {
if let Some(t) = tracer {
t.end_span(&self.model);
}
return Err(Box::new(OpenAIClientError::Deserialize { source: e, body }));
}
}; };
// Add response as JSON and end span
if let Some(tracer) = &mut tracer {
if let Ok(response_json) = serde_json::to_value(&dto) {
tracer.set_output_json(&self.model, response_json);
}
tracer.end_span(&self.model);
}
let result = build_llm_result(dto, tools); let result = build_llm_result(dto, tools);
if result.summary.is_empty() && result.tool_call.is_none() { if result.summary.is_empty() && result.tool_call.is_none() {
return Err(Box::new(OpenAIClientError::NoText { body })); return Err(Box::new(OpenAIClientError::NoText { body }));
} }
Ok(result) Ok(result)
} }
} }

View file

@ -3,13 +3,13 @@ use crate::domain::{Chain, ModelType};
use crate::domain::prompting::format_todo_list_message; use crate::domain::prompting::format_todo_list_message;
use serde::Serialize; use serde::Serialize;
use serde_json::Value; use serde_json::Value;
use crate::domain::workflow::step::StepType::{AssistantResponse, ToolCall, UserMessage}; use crate::domain::workflow::step::StepType;
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct RequestDTO { pub struct RequestDTO {
model: String, model: String,
instructions: String, instructions: String,
input: Vec<InputDto>, input: Vec<InputMessageDto>,
tools: Vec<ToolDto>, tools: Vec<ToolDto>,
tool_choice: String, tool_choice: String,
parallel_tool_calls: bool, parallel_tool_calls: bool,
@ -18,12 +18,37 @@ pub struct RequestDTO {
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub(super) struct InputDto { pub(super) struct MessageDto {
content: Vec<InputContent>, content: Vec<InputContent>,
role: String, #[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
#[serde(rename = "type")] #[serde(rename = "type")]
kind: String, kind: String,
status: String, }
#[derive(Debug, Serialize)]
pub(super) struct FunctionOutputDto {
output: String,
#[serde(rename = "type")]
kind: String,
call_id: String,
}
#[derive(Debug, Serialize)]
pub(super) struct FunctionCallDto {
arguments: String,
name: String,
#[serde(rename = "type")]
kind: String,
call_id: String,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum InputMessageDto {
Message(MessageDto),
FunctionOutput(FunctionOutputDto),
FunctionCall(FunctionCallDto),
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -63,10 +88,25 @@ impl InputContent {
} }
} }
pub fn function(text: String) -> Self { }
Self::Text {
kind: "function".to_string(), impl FunctionOutputDto {
text, pub fn new(output: String, call_id: String) -> Self {
Self {
kind: "function_call_output".to_string(),
output,
call_id,
}
}
}
impl FunctionCallDto {
pub fn new(name: String, arguments: String, call_id: String) -> Self {
Self {
name,
kind: "function_call".to_string(),
arguments,
call_id,
} }
} }
} }
@ -83,28 +123,23 @@ pub(super) struct ToolDto {
const ROLE_USER: &str = "user"; const ROLE_USER: &str = "user";
const ROLE_SYSTEM: &str = "system"; const ROLE_SYSTEM: &str = "system";
const ROLE_ASSISTANT: &str = "assistant"; const ROLE_ASSISTANT: &str = "assistant";
const INPUT_STATUS_IN_PROGRESS: &str = "in_progress";
const INPUT_STATUS_COMPLETED: &str = "completed";
const INPUT_STATUS_FAILED: &str = "failed";
impl RequestDTO { impl RequestDTO {
pub(crate) fn new( pub(crate) fn new(
model: String, model: String,
system_prompt: String,
user_prompt: String,
tools: &[&dyn Tool], tools: &[&dyn Tool],
chain: &crate::domain::workflow::Chain, chain: &Chain,
) -> Self { ) -> Self {
// User request is now part of the chain, no need to add separately // User request is now part of the chain, no need to add separately
let input = InputDto::build(user_prompt, chain); let input = MessageDto::build(chain);
Self { Self {
model, model,
instructions: system_prompt, instructions: chain.system_prompt.clone(),
input, input,
tools: tools.iter().map(|tool| ToolDto::from_tool(*tool)).collect(), tools: tools.iter().map(|tool| ToolDto::from_tool(*tool)).collect(),
tool_choice: "auto".to_string(), tool_choice: "auto".to_string(),
parallel_tool_calls: true, parallel_tool_calls: false,
store: false, store: false,
stream: false, stream: false,
} }
@ -123,50 +158,46 @@ impl ToolDto {
} }
} }
impl InputDto { impl MessageDto {
fn build(user_prompt: String, chain: &Chain) -> Vec<Self> { fn build(chain: &Chain) -> Vec<InputMessageDto> {
let steps = chain.get_steps_with_history(); let steps = chain.get_steps_with_history();
let mut result: Vec<Self> = steps let mut result: Vec<InputMessageDto> = Vec::new();
.iter()
.enumerate()
.map(|(idx, step)| {
// Determine status
let is_user_message = step.step_type == UserMessage.as_str();
let status = if is_user_message || step.is_successful.unwrap_or(false) { for step in steps.iter() {
INPUT_STATUS_COMPLETED let is_user_message = step.step_type == StepType::UserMessage.as_str();
} else {
INPUT_STATUS_FAILED
};
let mut role = ROLE_ASSISTANT.to_string();
// Build content items if is_user_message {
let content_items = if is_user_message { // User message: text + optional images
role = ROLE_USER.to_string();
// For user messages, include text and images
let mut items = vec![InputContent::text(step.input_payload.clone())]; let mut items = vec![InputContent::text(step.input_payload.clone())];
// Add image content items if present
if let Some(ref images) = step.images { if let Some(ref images) = step.images {
for image_url in images { for image_url in images {
items.push(InputContent::image(image_url.clone())); items.push(InputContent::image(image_url.clone()));
} }
} }
items result.push(InputMessageDto::Message(Self {
} else { content: items,
vec![InputContent::output_text(step.get_output(ModelType::OpenAI))] role: Some(ROLE_USER.to_string()),
}; kind: "message".to_string(),
}));
} else if step.step_type == StepType::ToolCall.as_str() {
Self { let tool_name = step.tool_name.clone().unwrap();
content: content_items, let call_id = step.call_id.clone().unwrap();
role, // Tool call output is a separate DTO type
kind: "message".to_string(), result.push(InputMessageDto::FunctionCall(FunctionCallDto::new(tool_name, step.input_payload.clone(), call_id.clone())));
status: status.to_string(), result.push(InputMessageDto::FunctionOutput(FunctionOutputDto::new(step.get_output(ModelType::OpenAI), call_id)));
} else {
// Assistant message
result.push(InputMessageDto::Message(Self {
content: vec![InputContent::output_text(step.get_output(ModelType::OpenAI))],
role: Some(ROLE_ASSISTANT.to_string()),
kind: "message".to_string(),
}));
} }
}) }
.collect();
// Add the plan as system message at the beginning if it exists and is not completed // Add the plan as system message at the beginning if it exists and is not completed
if let Some(ref todo_list) = chain.todo_list { if let Some(ref todo_list) = chain.todo_list {
@ -178,22 +209,15 @@ impl InputDto {
let todo_input = Self { let todo_input = Self {
content: vec![InputContent::text(todo_message)], content: vec![InputContent::text(todo_message)],
role: ROLE_SYSTEM.to_string(), role: Some(ROLE_SYSTEM.to_string()),
kind: "message".to_string(), kind: "message".to_string(),
status: INPUT_STATUS_COMPLETED.to_string(),
}; };
result.push(todo_input);
// Put it first
result.insert(0, InputMessageDto::Message(todo_input));
} }
} }
// and adding the current user message at the end
result.push(Self {
content: vec![InputContent::text(user_prompt)],
role: ROLE_USER.to_string(),
kind: "message".to_string(),
status: INPUT_STATUS_IN_PROGRESS.to_string(),
});
result result
} }
} }

View file

@ -1,20 +1,21 @@
use serde::Deserialize; use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, Serialize)]
pub struct ResponseDTO { pub struct ResponseDTO {
output: Vec<OpenAIOutputItem>, output: Vec<OpenAIOutputItem>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, Serialize)]
struct OpenAIOutputItem { struct OpenAIOutputItem {
#[serde(rename = "type")] #[serde(rename = "type")]
kind: String, kind: String,
content: Option<Vec<OpenAIContentItem>>, content: Option<Vec<OpenAIContentItem>>,
name: Option<String>, name: Option<String>,
arguments: Option<String>, arguments: Option<String>,
call_id: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, Serialize)]
struct OpenAIContentItem { struct OpenAIContentItem {
#[serde(rename = "type")] #[serde(rename = "type")]
kind: String, kind: String,
@ -41,11 +42,12 @@ impl ResponseDTO {
} }
} }
} else if item.kind == "function_call" { } else if item.kind == "function_call" {
if let (Some(name), Some(arguments)) = (item.name.as_ref(), item.arguments.as_ref()) if let (Some(name), Some(arguments), Some(call_id)) = (item.name.as_ref(), item.arguments.as_ref(), item.call_id.as_ref())
{ {
tool_call = Some(FunctionCall { tool_call = Some(FunctionCall {
name: name.to_string(), name: name.to_string(),
arguments: arguments.to_string(), arguments: arguments.to_string(),
call_id: call_id.to_string(),
}); });
} }
} }
@ -59,4 +61,5 @@ impl ResponseDTO {
pub(crate) struct FunctionCall { pub(crate) struct FunctionCall {
pub name: String, pub name: String,
pub arguments: String, pub arguments: String,
pub call_id: String,
} }

View file

@ -3,8 +3,6 @@ use crate::infrastructure::inference::{LLMInferenceResult, ToolCall};
pub fn build_request_dto( pub fn build_request_dto(
model: &str, model: &str,
system_prompt: &str,
user_prompt: &str,
images: &[String], images: &[String],
tools: &[&dyn crate::domain::tools::Tool], tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain, chain: &crate::domain::workflow::Chain,
@ -25,8 +23,6 @@ pub fn build_request_dto(
let dto = RequestDTO::new( let dto = RequestDTO::new(
model.to_string(), model.to_string(),
system_prompt.to_string(),
user_prompt.to_string(),
tools, tools,
chain, chain,
); );
@ -54,9 +50,10 @@ pub fn build_llm_result(
let tool_call = if let Some(call) = tool_call_dto { let tool_call = if let Some(call) = tool_call_dto {
// Create ToolCall for the workflow to use // Create ToolCall for the workflow to use
Some(ToolCall { Some(ToolCall{
name: call.name.clone(), name: call.name.clone(),
arguments: call.arguments.clone(), arguments: call.arguments.clone(),
call_id: call.call_id.clone(),
}) })
} else { } else {
None None

View file

@ -51,4 +51,4 @@ pub fn render(frame: &mut Frame, area: Rect, state: &UiState, theme: Theme) {
let y = area.y.saturating_add(INPUT_PADDING.top + adjusted_line); let y = area.y.saturating_add(INPUT_PADDING.top + adjusted_line);
frame.set_cursor(x, y); frame.set_cursor(x, y);
} }
} }

View file

@ -159,6 +159,7 @@ fn handle_agent_event(
}); });
state.request_status = None; state.request_status = None;
state.request_progress = None; state.request_progress = None;
state.last_progress_update = None; // Reset timer for new request
state.file_changes = None; state.file_changes = None;
// Display the user's message // Display the user's message
@ -175,7 +176,22 @@ fn handle_agent_event(
summary, summary,
} => { } => {
if matches!(phase, StepPhase::Before) { if matches!(phase, StepPhase::Before) {
state.request_progress = Some(summary.clone()); // Check if enough time has passed since the last progress update
let can_update = state
.last_progress_update
.map(|last| {
std::time::Instant::now()
.duration_since(last)
.as_millis()
>= crate::infrastructure::cli::state::MIN_PROGRESS_DISPLAY_MS
})
.unwrap_or(true); // Update immediately if no previous update
if can_update {
state.request_progress = Some(summary.clone());
state.last_progress_update = Some(std::time::Instant::now());
}
// Otherwise skip this update - too soon since last one
} else { } else {
state.request_progress = None; state.request_progress = None;
} }
@ -207,6 +223,7 @@ fn handle_agent_event(
}); });
} }
state.request_progress = None; state.request_progress = None;
state.last_progress_update = None; // Reset timer when request finishes
// Update the user message color based on request status // Update the user message color based on request status
for entry in state.progress.iter_mut() { for entry in state.progress.iter_mut() {

View file

@ -33,8 +33,8 @@ impl AttachedImage {
} }
} }
pub const INPUT_MIN_HEIGHT: usize = 3; pub const INPUT_MIN_HEIGHT: usize = 1;
pub const INPUT_MAX_HEIGHT: usize = 6; pub const INPUT_MAX_HEIGHT: usize = 5;
pub const PROGRESS_HISTORY_LIMIT: usize = 200; pub const PROGRESS_HISTORY_LIMIT: usize = 200;
pub const MAIN_BODY_SCROLL_STEP: usize = 3; pub const MAIN_BODY_SCROLL_STEP: usize = 3;
@ -51,6 +51,7 @@ pub enum ProgressKind {
} }
pub const REQUEST_STATUS_DISPLAY_DURATION: Duration = Duration::from_secs(2); pub const REQUEST_STATUS_DISPLAY_DURATION: Duration = Duration::from_secs(2);
pub const MIN_PROGRESS_DISPLAY_MS: u128 = 1000; // 1 second minimum display time
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
#[allow(dead_code)] #[allow(dead_code)]
@ -260,6 +261,7 @@ pub struct UiState {
pub request_in_flight: Option<RequestIndicator>, pub request_in_flight: Option<RequestIndicator>,
pub request_status: Option<RequestStatusDisplay>, pub request_status: Option<RequestStatusDisplay>,
pub request_progress: Option<String>, pub request_progress: Option<String>,
pub last_progress_update: Option<Instant>,
pub file_changes: Option<FileChangesDisplay>, pub file_changes: Option<FileChangesDisplay>,
pub agent_mode: AgentModeType, pub agent_mode: AgentModeType,
pub todo_list: Option<TodoListDisplay>, pub todo_list: Option<TodoListDisplay>,
@ -300,6 +302,7 @@ impl UiState {
request_in_flight: None, request_in_flight: None,
request_status: None, request_status: None,
request_progress: None, request_progress: None,
last_progress_update: None,
file_changes: None, file_changes: None,
agent_mode: AgentModeType::Build, // Default to build mode agent_mode: AgentModeType::Build, // Default to build mode
todo_list: None, todo_list: None,
@ -353,4 +356,4 @@ impl UiState {
Some(crate::domain::ModelType::OpenAI) Some(crate::domain::ModelType::OpenAI)
) )
} }
} }

View file

@ -31,4 +31,4 @@ impl Theme {
} }
pub const PANEL_PADDING: Padding = Padding::new(2, 2, 1, 1); pub const PANEL_PADDING: Padding = Padding::new(2, 2, 1, 1);
pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1); pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1);

View file

@ -6,7 +6,7 @@ use crate::infrastructure::cli::helpers::{
centered_rect, cursor_position, list_state, panel_block, centered_rect, cursor_position, list_state, panel_block,
}; };
use crate::infrastructure::cli::state::{LoadStatus, PopupState, UiMode, UiState}; use crate::infrastructure::cli::state::{LoadStatus, PopupState, UiMode, UiState};
use crate::infrastructure::cli::theme::{Theme, PANEL_PADDING}; use crate::infrastructure::cli::theme::{Theme, INPUT_PADDING, PANEL_PADDING};
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect}; use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
use ratatui::style::{Modifier, Style}; use ratatui::style::{Modifier, Style};
use ratatui::text::{Line, Span, Text}; use ratatui::text::{Line, Span, Text};
@ -28,7 +28,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
input_lines, input_lines,
crate::infrastructure::cli::state::INPUT_MAX_HEIGHT, crate::infrastructure::cli::state::INPUT_MAX_HEIGHT,
); );
let mut input_box_height = (input_lines + 2) as u16; let mut input_box_height = (input_lines + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16;
let indicator_height = 1u16; let indicator_height = 1u16;
let attachment_indicator_height = if state.attached_images.is_empty() { let attachment_indicator_height = if state.attached_images.is_empty() {
@ -45,7 +45,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
let overflow = fixed_height - size.height; let overflow = fixed_height - size.height;
let reduced = input_box_height.saturating_sub(overflow); let reduced = input_box_height.saturating_sub(overflow);
input_box_height = input_box_height =
reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + 2) as u16); reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16);
} }
let constraints = if state.attached_images.is_empty() { let constraints = if state.attached_images.is_empty() {
@ -557,4 +557,4 @@ pub fn openai_available_filtered(state: &UiState, filter: &str) -> Vec<String> {
.filter(|name| name.to_lowercase().contains(&query)) .filter(|name| name.to_lowercase().contains(&query))
.cloned() .cloned()
.collect() .collect()
} }

View file

@ -89,17 +89,39 @@ impl LocalEngine {
} }
impl InferenceEngine for LocalEngine { impl InferenceEngine for LocalEngine {
// so far very fucked up implementation, even user_prompt is not properly passed
// need to be refactored using proper request builder same way as for openai inference
fn generate( fn generate(
&self, &self,
system_prompt: &str,
user_prompt: &str,
max_tokens: u32,
_tools: &[&dyn crate::domain::tools::Tool], _tools: &[&dyn crate::domain::tools::Tool],
_chain: &crate::domain::workflow::Chain, chain: &crate::domain::workflow::Chain,
_images: &[String], _images: &[String],
_tracer: Option<&mut openai_agents_tracing::TracingFacade>, mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, InfaError> { ) -> Result<LLMInferenceResult, InfaError> {
let prompt = format!("{}\n\n{}", system_prompt, user_prompt); let model_name = "local";
// Start span with model name and add request as JSON
if let Some(tracer) = &mut tracer {
tracer.start_span(model_name, openai_agents_tracing::SpanKind::Generation);
// Set request as structured JSON input
let request_json = serde_json::json!({
"system_prompt": chain.system_prompt.clone(),
"user_prompt": chain.get_steps_with_history()[0].context_payload.clone(),
"max_tokens": 1000,
});
tracer.set_input_json(model_name, request_json);
// Set model configuration
let mut model_config = std::collections::HashMap::new();
model_config.insert("temperature".to_string(), serde_json::json!(self.params.temperature));
model_config.insert("top_p".to_string(), serde_json::json!(self.params.top_p));
model_config.insert("ctx_size".to_string(), serde_json::json!(self.params.ctx_size));
model_config.insert("threads".to_string(), serde_json::json!(self.params.threads));
tracer.set_model_config(model_name, model_config);
}
let prompt = format!("{}\n\n{}", chain.system_prompt.clone(), chain.get_steps_with_history()[0].context_payload.clone());
let to_error = |msg: String| -> InfaError { let to_error = |msg: String| -> InfaError {
std::io::Error::new(std::io::ErrorKind::Other, msg).into() std::io::Error::new(std::io::ErrorKind::Other, msg).into()
}; };
@ -145,7 +167,7 @@ impl InferenceEngine for LocalEngine {
let mut output = String::new(); let mut output = String::new();
let mut n_cur = tokens.len(); let mut n_cur = tokens.len();
for _ in 0..max_tokens { for _ in 0..2000 {
let new_token = sampler.sample(&ctx, -1); let new_token = sampler.sample(&ctx, -1);
if self.model.is_eog_token(new_token) { if self.model.is_eog_token(new_token) {
@ -169,6 +191,15 @@ impl InferenceEngine for LocalEngine {
n_cur += 1; n_cur += 1;
} }
// Add output as JSON and end span
if let Some(tracer) = tracer {
let response_json = serde_json::json!({
"text": &output,
});
tracer.set_output_json(model_name, response_json);
tracer.end_span(model_name);
}
Ok(LLMInferenceResult { Ok(LLMInferenceResult {
summary: output.trim().to_string(), summary: output.trim().to_string(),
raw_output: output, raw_output: output,

View file

@ -9,6 +9,7 @@ pub mod openai;
pub struct ToolCall { pub struct ToolCall {
pub name: String, pub name: String,
pub arguments: String, pub arguments: String,
pub call_id: String,
} }
pub struct LLMInferenceResult { pub struct LLMInferenceResult {
@ -22,9 +23,6 @@ pub trait InferenceEngine: Send + Sync {
/// Generate text without streaming output /// Generate text without streaming output
fn generate( fn generate(
&self, &self,
system_prompt: &str,
user_prompt: &str,
max_tokens: u32,
tools: &[&dyn Tool], tools: &[&dyn Tool],
chain: &Chain, chain: &Chain,
images: &[String], images: &[String],

View file

@ -66,15 +66,12 @@ impl OpenAIEngine {
impl InferenceEngine for OpenAIEngine { impl InferenceEngine for OpenAIEngine {
fn generate( fn generate(
&self, &self,
system_prompt: &str,
user_prompt: &str,
_max_tokens: u32,
tools: &[&dyn crate::domain::tools::Tool], tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain, chain: &crate::domain::workflow::Chain,
images: &[String], images: &[String],
tracer: Option<&mut openai_agents_tracing::TracingFacade>, tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, InfaError> { ) -> Result<LLMInferenceResult, InfaError> {
self.generate_with_responses_api(system_prompt, user_prompt, tools, chain, images, tracer) self.generate_with_responses_api(tools, chain, images, tracer)
} }
fn get_type(&self) -> ModelType { fn get_type(&self) -> ModelType {
ModelType::OpenAI ModelType::OpenAI
@ -85,15 +82,13 @@ impl OpenAIEngine {
/// Generate using the Responses API (for newer models like codex, o-series) /// Generate using the Responses API (for newer models like codex, o-series)
fn generate_with_responses_api( fn generate_with_responses_api(
&self, &self,
system_prompt: &str,
user_prompt: &str,
tools: &[&dyn crate::domain::tools::Tool], tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain, chain: &crate::domain::workflow::Chain,
images: &[String], images: &[String],
tracer: Option<&mut openai_agents_tracing::TracingFacade>, tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, InfaError> { ) -> Result<LLMInferenceResult, InfaError> {
self.responses_client self.responses_client
.call_responses_api(system_prompt, user_prompt, tools, chain, images, tracer) .call_responses_api(tools, chain, images, tracer)
.map_err(|e| e) .map_err(|e| e)
} }
} }