diff --git a/crates/openai-agents-tracing-rust/src/facade.rs b/crates/openai-agents-tracing-rust/src/facade.rs index 2a48c1a..704b227 100644 --- a/crates/openai-agents-tracing-rust/src/facade.rs +++ b/crates/openai-agents-tracing-rust/src/facade.rs @@ -82,6 +82,38 @@ impl TracingFacade { } } + pub fn set_model_config(&mut self, name: impl AsRef, config: HashMap) { + if let Some(span) = self.open_spans.get_mut(name.as_ref()) { + if let SpanData::Generation(ref mut data) = span.span_data { + data.model_config = Some(config); + } + } + } + + pub fn set_usage(&mut self, name: impl AsRef, input_tokens: u32, output_tokens: u32) { + if let Some(span) = self.open_spans.get_mut(name.as_ref()) { + if let SpanData::Generation(ref mut data) = span.span_data { + data.usage = Some(crate::types::UsageData::new(input_tokens, output_tokens)); + } + } + } + + pub fn set_input_json(&mut self, name: impl AsRef, input: serde_json::Value) { + if let Some(span) = self.open_spans.get_mut(name.as_ref()) { + if let SpanData::Generation(ref mut data) = span.span_data { + data.input = Some(vec![input]); + } + } + } + + pub fn set_output_json(&mut self, name: impl AsRef, output: serde_json::Value) { + if let Some(span) = self.open_spans.get_mut(name.as_ref()) { + if let SpanData::Generation(ref mut data) = span.span_data { + data.output = Some(vec![output]); + } + } + } + pub async fn end(&mut self) { for (_, mut span) in self.open_spans.drain() { span.mark_ended(); diff --git a/src/domain/mod.rs b/src/domain/mod.rs index 412dd03..3013724 100644 --- a/src/domain/mod.rs +++ b/src/domain/mod.rs @@ -17,6 +17,7 @@ pub use session::{Session, SessionRequest}; pub use startup::StartupService; pub use user_settings::UserSettings; pub use workflow::{CancellationToken, Chain}; +#[allow(unused_imports)] pub use todo::{TodoList, TodoItem}; /// Model type enum matching the inference engine types @@ -41,4 +42,4 @@ impl ModelType { _ => None, } } -} +} \ No newline at end of file diff --git a/src/domain/permissions/checker.rs b/src/domain/permissions/checker.rs index b982c04..955012c 100644 --- a/src/domain/permissions/checker.rs +++ b/src/domain/permissions/checker.rs @@ -401,12 +401,17 @@ mod tests { "read_only" } - fn parse_input(&self, _input: String) -> Option { + fn parse_input(&self, _input: String, _call_id: String) -> Option { None } fn work(&self, _request: &dyn Request) -> ToolResult { - ToolResult::ok("read_only".to_string(), String::new(), String::new()) + ToolResult::ok( + "read_only".to_string(), + String::new(), + String::new(), + String::new(), + ) } fn parameters(&self) -> serde_json::Value { @@ -439,12 +444,17 @@ mod tests { "write_tool" } - fn parse_input(&self, _input: String) -> Option { + fn parse_input(&self, _input: String, _call_id: String) -> Option { None } fn work(&self, _request: &dyn Request) -> ToolResult { - ToolResult::ok("write_tool".to_string(), String::new(), String::new()) + ToolResult::ok( + "write_tool".to_string(), + String::new(), + String::new(), + String::new(), + ) } fn parameters(&self) -> serde_json::Value { @@ -474,12 +484,17 @@ mod tests { "command_tool" } - fn parse_input(&self, _input: String) -> Option { + fn parse_input(&self, _input: String, _call_id: String) -> Option { None } fn work(&self, _request: &dyn Request) -> ToolResult { - ToolResult::ok("command_tool".to_string(), String::new(), String::new()) + ToolResult::ok( + "command_tool".to_string(), + String::new(), + String::new(), + String::new(), + ) } fn parameters(&self) -> serde_json::Value { @@ -955,4 +970,4 @@ mod tests { assert_eq!(decision, PermissionDecision::AlwaysAllow); } -} +} \ No newline at end of file diff --git a/src/domain/permissions/store.rs b/src/domain/permissions/store.rs index 2268e40..515f504 100644 --- a/src/domain/permissions/store.rs +++ b/src/domain/permissions/store.rs @@ -1,7 +1,6 @@ use super::types::{Permission, PermissionDecision, PermissionScope}; use crate::infrastructure::db::DbPool; use rusqlite::params; -use std::path::PathBuf; use thiserror::Error; #[derive(Debug, Error)] @@ -121,16 +120,19 @@ impl PermissionStore for SqlitePermissionStore { ) -> Result, StoreError> { // Build query dynamically to handle NULL values properly // In SQL, NULL != '', so we need to use IS NULL for empty strings + let mut param_num = 3; let command_clause = if command_pattern.is_empty() { - "command_pattern IS NULL" + "command_pattern IS NULL".to_string() } else { - "command_pattern = ?3" + let clause = format!("command_pattern = ?{}", param_num); + param_num += 1; + clause }; let resource_clause = if resource_pattern.is_empty() { - "resource_pattern IS NULL" + "resource_pattern IS NULL".to_string() } else { - "resource_pattern = ?4" + format!("resource_pattern = ?{}", param_num) }; let query = format!( @@ -151,32 +153,28 @@ impl PermissionStore for SqlitePermissionStore { ))) })?; - // Build params based on whether patterns are empty - let result = if command_pattern.is_empty() && resource_pattern.is_empty() { - conn.query_row( - &query, - params![project_id, tool], - |row| self.row_to_permission(row), - ) - } else if command_pattern.is_empty() { - conn.query_row( - &query, - params![project_id, tool, resource_pattern], - |row| self.row_to_permission(row), - ) - } else if resource_pattern.is_empty() { - conn.query_row( - &query, - params![project_id, tool, command_pattern], - |row| self.row_to_permission(row), - ) - } else { - conn.query_row( - &query, - params![project_id, tool, command_pattern, resource_pattern], - |row| self.row_to_permission(row), - ) - }; + // Build params list dynamically to match the query + let mut param_values: Vec> = vec![ + Box::new(project_id), + Box::new(tool.to_string()), + ]; + + if !command_pattern.is_empty() { + param_values.push(Box::new(command_pattern.to_string())); + } + + if !resource_pattern.is_empty() { + param_values.push(Box::new(resource_pattern.to_string())); + } + + let params_refs: Vec<&dyn rusqlite::ToSql> = param_values + .iter() + .map(|p| p.as_ref()) + .collect(); + + let result = conn.query_row(&query, params_refs.as_slice(), |row| { + self.row_to_permission(row) + }); match result { Ok(permission) => Ok(Some(permission)), @@ -186,48 +184,140 @@ impl PermissionStore for SqlitePermissionStore { } } -impl SqlitePermissionStore { - fn find_matching_command_permission( - &self, - rows: rusqlite::MappedRows< - impl FnMut(&rusqlite::Row) -> Result, - >, - tool: &str, - command: &str, - ) -> Result, StoreError> { - for row_result in rows { - match row_result { - Ok(permission) => { - if permission.matches(tool, Some(command), None::<&PathBuf>) { - return Ok(Some(permission)); - } - } - Err(e) => return Err(StoreError::Database(e)), - } - } +#[cfg(test)] +mod tests { + use super::*; + use r2d2_sqlite::SqliteConnectionManager; - Ok(None) + fn setup_test_db() -> DbPool { + let manager = SqliteConnectionManager::memory(); + let pool = r2d2::Pool::new(manager).unwrap(); + let conn = pool.get().unwrap(); + + // Create permissions table + conn.execute( + "CREATE TABLE IF NOT EXISTS permissions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tool_name TEXT NOT NULL, + command_pattern TEXT, + resource_pattern TEXT, + decision TEXT NOT NULL, + scope TEXT NOT NULL, + project_id INTEGER, + created_at TEXT NOT NULL + )", + [], + ).unwrap(); + + pool } - fn find_matching_path_permission( - &self, - rows: rusqlite::MappedRows< - impl FnMut(&rusqlite::Row) -> Result, - >, - tool: &str, - path: &PathBuf, - ) -> Result, StoreError> { - for row_result in rows { - match row_result { - Ok(permission) => { - if permission.matches(tool, None, Some(path)) { - return Ok(Some(permission)); - } - } - Err(e) => return Err(StoreError::Database(e)), - } - } + #[test] + fn test_find_permission_with_null_patterns() { + let pool = setup_test_db(); + let store = SqlitePermissionStore::new(pool.clone()); - Ok(None) + // Create permission with NULL command and resource patterns + let permission = Permission::new( + "test_tool".to_string(), + None, + None, + PermissionDecision::AlwaysAllow, + PermissionScope::Project, + Some(1), + ); + store.create_permission(permission).unwrap(); + + // Should find the permission when searching with empty strings + let result = store.find_permission("test_tool", 1, "", "").unwrap(); + assert!(result.is_some()); + assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow); } -} + + #[test] + fn test_find_permission_with_command_only() { + let pool = setup_test_db(); + let store = SqlitePermissionStore::new(pool.clone()); + + // Create permission with command but NULL resource + let permission = Permission::new( + "test_tool".to_string(), + Some("echo test".to_string()), + None, + PermissionDecision::AlwaysDeny, + PermissionScope::Project, + Some(1), + ); + store.create_permission(permission).unwrap(); + + // Should find the permission when searching with command and empty resource + let result = store.find_permission("test_tool", 1, "echo test", "").unwrap(); + assert!(result.is_some()); + assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny); + } + + #[test] + fn test_find_permission_with_resource_only() { + let pool = setup_test_db(); + let store = SqlitePermissionStore::new(pool.clone()); + + // Create permission with resource but NULL command + let permission = Permission::new( + "test_tool".to_string(), + None, + Some("/path/to/file".to_string()), + PermissionDecision::AlwaysAllow, + PermissionScope::Project, + Some(1), + ); + store.create_permission(permission).unwrap(); + + // Should find the permission when searching with empty command and resource + let result = store.find_permission("test_tool", 1, "", "/path/to/file").unwrap(); + assert!(result.is_some()); + assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow); + } + + #[test] + fn test_find_permission_with_both_patterns() { + let pool = setup_test_db(); + let store = SqlitePermissionStore::new(pool.clone()); + + // Create permission with both patterns + let permission = Permission::new( + "test_tool".to_string(), + Some("rm -rf".to_string()), + Some("/etc/passwd".to_string()), + PermissionDecision::AlwaysDeny, + PermissionScope::Project, + Some(1), + ); + store.create_permission(permission).unwrap(); + + // Should find the permission when searching with both patterns + let result = store.find_permission("test_tool", 1, "rm -rf", "/etc/passwd").unwrap(); + assert!(result.is_some()); + assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny); + } + + #[test] + fn test_find_permission_no_match() { + let pool = setup_test_db(); + let store = SqlitePermissionStore::new(pool.clone()); + + // Create permission with command + let permission = Permission::new( + "test_tool".to_string(), + Some("echo test".to_string()), + None, + PermissionDecision::AlwaysAllow, + PermissionScope::Project, + Some(1), + ); + store.create_permission(permission).unwrap(); + + // Should NOT find when searching with different command + let result = store.find_permission("test_tool", 1, "rm -rf", "").unwrap(); + assert!(result.is_none()); + } +} \ No newline at end of file diff --git a/src/domain/permissions/types.rs b/src/domain/permissions/types.rs index f0d032e..01083c2 100644 --- a/src/domain/permissions/types.rs +++ b/src/domain/permissions/types.rs @@ -29,6 +29,7 @@ pub struct Permission { pub created_at: DateTime, } +#[allow(dead_code)] impl Permission { pub fn new( tool_name: String, @@ -173,4 +174,4 @@ impl Default for PermissionConfig { require_confirmation: true, } } -} +} \ No newline at end of file diff --git a/src/domain/prompting/general.rs b/src/domain/prompting/general.rs index a387fb7..3e92dc2 100644 --- a/src/domain/prompting/general.rs +++ b/src/domain/prompting/general.rs @@ -32,33 +32,38 @@ pub fn get_bt_tree_step_prompt( pub fn get_system_prompt(model_type: ModelType, agent_mode: AgentModeType, remaining_calls: usize) -> String { let (os_name, shell_name) = get_runtime_environment(); - let mut system_prompt = "".to_string(); if agent_mode == AgentModeType::Plan { - system_prompt = _system_prompt_for_plan(model_type); - system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); + let mut system_prompt = _system_prompt_for_plan(model_type); + if remaining_calls < 3 { + system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); + } return system_prompt; } if agent_mode == AgentModeType::BuildFromPlan { - system_prompt = _system_prompt_for_build_from_plan(model_type); - system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); + let mut system_prompt = _system_prompt_for_build_from_plan(model_type); + if remaining_calls < 3 { + system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); + } return system_prompt; } - if model_type == ModelType::OpenAI { - system_prompt = format!( + let mut system_prompt = if model_type == ModelType::OpenAI { + format!( "You are Drastis, a coding agent. \n\ Use the available tools to gather context and make changes. \ When using tools, pass JSON arguments that match their parameters. \n\ Runtime: os={}, shell={}.", os_name, shell_name - ); + ) } else { - system_prompt = format!( + format!( "You are Drastis, a coding agent. Use available tools to gather context and make changes. Be concise and accurate. Runtime: os={}, shell={}.", os_name, shell_name - ); + ) + }; + if remaining_calls < 3 { + system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); } - system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls)); - return system_prompt; + system_prompt } fn _system_prompt_for_plan(model_type: ModelType) -> String { diff --git a/src/domain/session/request.rs b/src/domain/session/request.rs index 0a94a97..0910558 100644 --- a/src/domain/session/request.rs +++ b/src/domain/session/request.rs @@ -8,6 +8,7 @@ use std::path::Path; /// without being tightly coupled to the Session entity. pub trait Request { /// Get the history of previous requests + #[allow(dead_code)] fn history(&self) -> &[SessionRequest]; /// Get the current request prompt @@ -40,4 +41,4 @@ pub trait Request { /// Get the session's TODO list (plan) fn get_session_plan(&self) -> Option; -} +} \ No newline at end of file diff --git a/src/domain/session/session_request.rs b/src/domain/session/session_request.rs index 1879b5d..7893c67 100644 --- a/src/domain/session/session_request.rs +++ b/src/domain/session/session_request.rs @@ -4,6 +4,7 @@ use crate::repository::SessionRequestRow; /// Domain entity representing a user request within a session. /// Each request contains the user's prompt and the resulting summary. #[derive(Debug, Clone)] +#[allow(dead_code)] pub struct SessionRequest { prompt: String, result_summary: Option, @@ -32,4 +33,4 @@ impl SessionRequest { mode: row.mode, } } -} +} \ No newline at end of file diff --git a/src/domain/startup/service.rs b/src/domain/startup/service.rs index 4d25a15..53d5964 100644 --- a/src/domain/startup/service.rs +++ b/src/domain/startup/service.rs @@ -1,6 +1,7 @@ use super::Error; use crate::domain::prompting::session_naming_prompt; use crate::domain::workflow::Chain; +use crate::domain::workflow::ChainStep; use crate::domain::{Project, Session, SessionRequest}; use crate::infrastructure::db::DbPool; use crate::infrastructure::inference::InferenceEngine; @@ -83,38 +84,37 @@ impl StartupService { let prompt_preview: String = first_prompt.chars().take(100).collect(); let naming_prompt = session_naming_prompt(self.engine.get_type(), &prompt_preview); - let chain = Chain::new(); - let session_name = - match self - .engine - .generate("", &naming_prompt, 15, &[], &chain, &[], None) - { - Ok(raw) => { - log::debug!("Raw session name response: {:?}", raw.summary); - // Clean up the response - let cleaned = raw - .summary - .lines() - .next() - .unwrap_or("") - .trim() - .trim_matches('"') - .trim_matches('*') - .replace("<|im_end|>", "") - .trim() - .to_string(); + let mut chain = Chain::new(); + chain + .steps + .push(ChainStep::user_message(naming_prompt.parse().unwrap(), Vec::new())); + let session_name = match self.engine.generate(&[], &chain, &[], None) { + Ok(raw) => { + log::debug!("Raw session name response: {:?}", raw.summary); + // Clean up the response + let cleaned = raw + .summary + .lines() + .next() + .unwrap_or("") + .trim() + .trim_matches('"') + .trim_matches('*') + .replace("<|im_end|>", "") + .trim() + .to_string(); - if cleaned.is_empty() || cleaned.len() > 50 { - Self::fallback_session_name(first_prompt) - } else { - cleaned - } - } - Err(e) => { - log::warn!("Session naming failed: {}", e); + if cleaned.is_empty() || cleaned.len() > 50 { Self::fallback_session_name(first_prompt) + } else { + cleaned } - }; + } + Err(e) => { + log::warn!("Session naming failed: {}", e); + Self::fallback_session_name(first_prompt) + } + }; // Create session in database let conn = self @@ -225,4 +225,4 @@ mod tests { assert_eq!(StartupService::fallback_session_name(""), "New Session"); } -} +} \ No newline at end of file diff --git a/src/domain/tools/discover_objects.rs b/src/domain/tools/discover_objects.rs index a8d547f..f06664c 100644 --- a/src/domain/tools/discover_objects.rs +++ b/src/domain/tools/discover_objects.rs @@ -15,6 +15,7 @@ pub struct DiscoverObjects { struct DiscoverObjectsInput { raw: String, full_path_to_file: String, + call_id: String, } #[derive(Debug, Deserialize)] @@ -69,6 +70,7 @@ impl DiscoverObjects { Ok(DiscoverObjectsInput { raw: raw.to_string(), full_path_to_file: parsed.full_path_to_file, + call_id: String::new(), }) } @@ -86,12 +88,13 @@ impl Tool for DiscoverObjects { "discover_objects" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let trimmed = input.trim(); let parsed = Self::parse_input_json(trimmed); match parsed { - Ok(parsed) => { + Ok(mut parsed) => { + parsed.call_id = call_id; *self.input.lock().unwrap() = Some(parsed); None } @@ -103,17 +106,18 @@ impl Tool for DiscoverObjects { let input = match self.load_input() { Ok(input) => input, Err(e) => { - return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) + return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new()) } }; match Self::parse_file(&input.full_path_to_file) { Ok((lang, objects)) => ToolResult::ok( self.name().to_string(), - input.raw, + input.raw.clone(), Self::format_output(lang, &objects), + input.call_id, ), - Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()), + Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id), } } diff --git a/src/domain/tools/find_files.rs b/src/domain/tools/find_files.rs index be4ea57..b033da3 100644 --- a/src/domain/tools/find_files.rs +++ b/src/domain/tools/find_files.rs @@ -19,6 +19,7 @@ struct FindFilesInput { query: String, root: Option, max_results: usize, + call_id: String, } #[derive(Debug, Deserialize)] @@ -132,6 +133,7 @@ impl FindFiles { query: parsed.query, root: parsed.root, max_results: parsed.max_results.unwrap_or(20), + call_id: String::new(), }) } @@ -149,12 +151,13 @@ impl Tool for FindFiles { "find_files" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let trimmed = input.trim(); let parsed = Self::parse_input_json(trimmed); match parsed { - Ok(parsed) => { + Ok(mut parsed) => { + parsed.call_id = call_id; *self.input.lock().unwrap() = Some(parsed); None } @@ -166,7 +169,7 @@ impl Tool for FindFiles { let input = match self.load_input() { Ok(input) => input, Err(e) => { - return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) + return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new()) } }; @@ -174,15 +177,16 @@ impl Tool for FindFiles { let search_root = match Self::resolve_search_root(input.root.as_deref(), request.project_root()) { Ok(root) => root, - Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e), + Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()), }; // Check if search root exists if !search_root.exists() { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Search root does not exist: {}", search_root.display()), + input.call_id, ); } @@ -204,7 +208,7 @@ impl Tool for FindFiles { output }; - ToolResult::ok(self.name().to_string(), input.raw, output) + ToolResult::ok(self.name().to_string(), input.raw, output, input.call_id) } fn parameters(&self) -> serde_json::Value { diff --git a/src/domain/tools/mod.rs b/src/domain/tools/mod.rs index 39dcb34..01080a4 100644 --- a/src/domain/tools/mod.rs +++ b/src/domain/tools/mod.rs @@ -45,6 +45,7 @@ pub struct FileChange { pub struct ToolResult { tool_name: String, + call_id: String, pub(crate) input: String, is_successful: bool, output: String, @@ -53,7 +54,7 @@ pub struct ToolResult { } impl ToolResult { - pub fn ok(tool_name: String, input: String, output: String) -> Self { + pub fn ok(tool_name: String, input: String, output: String, call_id: String) -> Self { Self { tool_name, input, @@ -61,10 +62,11 @@ impl ToolResult { output, error_message: String::new(), file_changes: None, + call_id } } - pub fn error(tool_name: String, input: String, message: String) -> Self { + pub fn error(tool_name: String, input: String, message: String, call_id: String) -> Self { Self { tool_name, input, @@ -72,6 +74,7 @@ impl ToolResult { output: String::new(), error_message: message.into(), file_changes: None, + call_id } } @@ -128,11 +131,15 @@ impl ToolResult { pub fn file_changes(&self) -> Option<&[FileChange]> { self.file_changes.as_deref() } + + pub(crate) fn call_id(&self) -> String { + self.call_id.clone() + } } pub trait Tool: Send + Sync { fn name(&self) -> &'static str; - fn parse_input(&self, input: String) -> Option; + fn parse_input(&self, input: String, call_id: String) -> Option; fn work(&self, request: &dyn Request) -> ToolResult; fn parameters(&self) -> Value; fn desc(&self) -> String; diff --git a/src/domain/tools/patch_files.rs b/src/domain/tools/patch_files.rs index 3896326..ff138df 100644 --- a/src/domain/tools/patch_files.rs +++ b/src/domain/tools/patch_files.rs @@ -18,6 +18,7 @@ pub struct PatchFiles { struct PatchFilesInput { raw: String, patch: String, + call_id: String, } #[derive(Debug, Deserialize)] @@ -41,6 +42,7 @@ impl PatchFiles { Ok(PatchFilesInput { raw: raw.to_string(), patch: parsed.patch, + call_id: String::new(), }) } @@ -58,12 +60,13 @@ impl Tool for PatchFiles { "patch_files" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let trimmed = input.trim(); let parsed = Self::parse_input_json(trimmed); match parsed { - Ok(parsed) => { + Ok(mut parsed) => { + parsed.call_id = call_id; *self.input.lock().unwrap() = Some(parsed); None } @@ -75,7 +78,7 @@ impl Tool for PatchFiles { let input = match self.load_input() { Ok(input) => input, Err(e) => { - return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) + return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new()) } }; @@ -84,8 +87,9 @@ impl Tool for PatchFiles { Err(e) => { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Failed to parse patch: {}", e), + input.call_id.clone(), ) } }; @@ -115,8 +119,9 @@ impl Tool for PatchFiles { { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Invalid path outside project root: {}", path), + input.call_id.clone(), ); } let fs_path = project_root.join(rel_path); @@ -129,8 +134,9 @@ impl Tool for PatchFiles { Err(e) => { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Failed to read file '{}': {}", fs_path.display(), e), + input.call_id.clone(), ) } } @@ -142,8 +148,9 @@ impl Tool for PatchFiles { Err(e) => { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Failed to apply patch: {}", e), + input.call_id.clone(), ) } }; @@ -155,28 +162,31 @@ impl Tool for PatchFiles { if let Err(e) = fs::create_dir_all(parent) { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!( "Failed to create parent directories for '{}': {}", fs_path.display(), e ), + input.call_id.clone(), ); } } if let Err(e) = fs::write(&fs_path, content) { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Failed to write file '{}': {}", fs_path.display(), e), + input.call_id.clone(), ); } } else if fs_path.exists() { if let Err(e) = fs::remove_file(&fs_path) { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Failed to delete file '{}': {}", fs_path.display(), e), + input.call_id.clone(), ); } } @@ -224,6 +234,7 @@ impl Tool for PatchFiles { self.name().to_string(), input.raw, "Patch applied successfully".to_string(), + input.call_id, ); if !file_changes.is_empty() { @@ -531,7 +542,7 @@ mod tests { "{{\"patch\":\"{}\"}}", patch.replace('\n', "\\n").replace('"', "\\\"") ); - assert!(tool.parse_input(input).is_none()); + assert!(tool.parse_input(input, "call-id".to_string()).is_none()); let result = tool.work(&request); assert!( @@ -571,7 +582,7 @@ mod tests { "{{\"patch\":\"{}\"}}", patch.replace('\n', "\\n").replace('"', "\\\"") ); - assert!(tool.parse_input(input).is_none()); + assert!(tool.parse_input(input, "call-id".to_string()).is_none()); let result = tool.work(&request); assert!( @@ -592,7 +603,7 @@ mod tests { let tool = PatchFiles::new(); let input = "{\"patch\":\"not a patch\"}".to_string(); - assert!(tool.parse_input(input).is_none()); + assert!(tool.parse_input(input, "call-id".to_string()).is_none()); let result = tool.work(&request); assert!( @@ -601,4 +612,4 @@ mod tests { result.output_string() ); } -} +} \ No newline at end of file diff --git a/src/domain/tools/read_objects.rs b/src/domain/tools/read_objects.rs index 211dd59..ef9e66f 100644 --- a/src/domain/tools/read_objects.rs +++ b/src/domain/tools/read_objects.rs @@ -18,6 +18,7 @@ struct ReadObjectsInput { raw: String, full_path_to_file: String, queries: Vec, + call_id: String, } #[derive(Debug, Deserialize)] @@ -138,6 +139,7 @@ impl ReadObjects { raw: raw.to_string(), full_path_to_file: trimmed_path.to_string(), queries, + call_id: String::new(), }) } @@ -182,12 +184,13 @@ impl Tool for ReadObjects { "read_objects" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let trimmed = input.trim(); let parsed = Self::parse_input_json(trimmed); match parsed { - Ok(parsed) => { + Ok(mut parsed) => { + parsed.call_id = call_id; *self.input.lock().unwrap() = Some(parsed); None } @@ -199,17 +202,18 @@ impl Tool for ReadObjects { let input = match self.load_input() { Ok(input) => input, Err(e) => { - return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) + return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new()) } }; match Self::read_objects(&input.full_path_to_file, &input.queries) { Ok((lang, results)) => ToolResult::ok( self.name().to_string(), - input.raw, + input.raw.clone(), Self::format_output(lang, results), + input.call_id, ), - Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()), + Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id), } } @@ -224,7 +228,7 @@ impl Tool for ReadObjects { }, "query": { "type": "string", - "description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\")", + "description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\"), grepping is not working here, use exact names", "minLength": 1 } }, @@ -235,8 +239,8 @@ impl Tool for ReadObjects { fn desc(&self) -> String { format!( - "Use the `{}` tool to read source code of specific objects from a file. To determine correct properties to use for `{}`, use the `discover_objects` tool first.", - self.name(), self.name() + "Use the `{}` tool to read source code of specific objects from a file. To determine correct properties, use the `discover_objects` tool first.", + self.name(), ) } @@ -293,7 +297,7 @@ mod tests { fn test_parse_input_valid() { let tool = ReadObjects::new(); let input = r#"{"path": "src/main.rs", "query": "main, Config"}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!(result.is_none(), "Expected no error, got: {:?}", result); // Verify the parsed input @@ -308,7 +312,7 @@ mod tests { fn test_parse_input_empty_query() { let tool = ReadObjects::new(); let input = r#"{"path": "src/main.rs", "query": ""}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!(result.is_some(), "Expected error for empty query"); if let Some(Error::Parse(msg)) = result { @@ -326,7 +330,7 @@ mod tests { fn test_parse_input_whitespace_query() { let tool = ReadObjects::new(); let input = r#"{"path": "src/main.rs", "query": " "}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!(result.is_some(), "Expected error for whitespace query"); if let Some(Error::Parse(msg)) = result { @@ -344,7 +348,7 @@ mod tests { fn test_parse_input_comma_only_query() { let tool = ReadObjects::new(); let input = r#"{"path": "src/main.rs", "query": ",,,"}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!(result.is_some(), "Expected error for comma-only query"); if let Some(Error::Parse(msg)) = result { @@ -362,7 +366,7 @@ mod tests { fn test_parse_input_missing_path() { let tool = ReadObjects::new(); let input = r#"{"path": "", "query": "main"}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!(result.is_some(), "Expected error for empty path"); if let Some(Error::Parse(msg)) = result { @@ -407,7 +411,7 @@ mod tests { fn test_parse_input_malformed_json() { let tool = ReadObjects::new(); let input = r#"{"path": "src/main.rs", "query": "main"#; // Missing closing brace - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!(result.is_some(), "Expected error for malformed JSON"); } @@ -415,7 +419,7 @@ mod tests { fn test_parse_input_path_with_whitespace() { let tool = ReadObjects::new(); let input = r#"{"path": " src/main.rs ", "query": "main"}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!( result.is_none(), "Expected no error for path with whitespace" @@ -432,7 +436,7 @@ mod tests { fn test_parse_input_query_with_extra_whitespace() { let tool = ReadObjects::new(); let input = r#"{"path": "src/main.rs", "query": " main , Config "}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!( result.is_none(), "Expected no error for query with extra whitespace" @@ -448,7 +452,7 @@ mod tests { fn test_parse_input_single_object() { let tool = ReadObjects::new(); let input = r#"{"path": "src/lib.rs", "query": "MyStruct"}"#; - let result = tool.parse_input(input.to_string()); + let result = tool.parse_input(input.to_string(), "call-id".to_string()); assert!( result.is_none(), "Expected no error for single object query" @@ -481,4 +485,4 @@ mod tests { }; assert!(query.matches(&obj_partial), "Should match partial name"); } -} +} \ No newline at end of file diff --git a/src/domain/tools/shell_exec.rs b/src/domain/tools/shell_exec.rs index 681d0da..387e7ff 100644 --- a/src/domain/tools/shell_exec.rs +++ b/src/domain/tools/shell_exec.rs @@ -25,6 +25,7 @@ struct ShellExecInputParsed { raw: String, command: String, working_dir: Option, + call_id: String, } impl Tool for ShellExec { @@ -32,7 +33,7 @@ impl Tool for ShellExec { "shell_exec" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let trimmed = input.trim(); let parsed = serde_json::from_str::(trimmed) .map_err(|e| Error::Parse(e.to_string())); @@ -46,6 +47,7 @@ impl Tool for ShellExec { raw: trimmed.to_string(), command: parsed.command, working_dir: parsed.working_dir, + call_id, }); None } @@ -61,6 +63,7 @@ impl Tool for ShellExec { self.name().to_string(), String::new(), "input not parsed".to_string(), + String::new(), ) } }; @@ -72,15 +75,17 @@ impl Tool for ShellExec { if !path.exists() { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Working directory does not exist: {}", dir), + input.call_id.clone(), ); } if !crate::utils::paths::is_within_root(path, request.project_root()) { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), "Working directory is outside project root".to_string(), + input.call_id.clone(), ); } path.to_path_buf() @@ -99,8 +104,9 @@ impl Tool for ShellExec { Err(e) => { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Failed to execute command: {}", e), + input.call_id.clone(), ) } }; @@ -129,10 +135,10 @@ impl Tool for ShellExec { if !stderr.is_empty() { result.push_str(&format!("[stderr]: {}", stderr)); } - return ToolResult::error(self.name().to_string(), input.raw, result); + return ToolResult::error(self.name().to_string(), input.raw.clone(), result, input.call_id.clone()); }; - ToolResult::ok(self.name().to_string(), input.raw, result) + ToolResult::ok(self.name().to_string(), input.raw, result, input.call_id) } fn parameters(&self) -> serde_json::Value { @@ -155,8 +161,9 @@ impl Tool for ShellExec { fn desc(&self) -> String { format!( - r#"Use `{}` tool to execute shell commands. -Please DO NOT use it to read the full content of a file, this is not efficient, use `read_objects` tool for this."#, + "Use `{}` tool to execute shell commands. +DO NOT use it to read code, this is not efficient, use `read_objects` tool for this. \n\ +DO NOT use it to change files, use `patch_files` tool for this.", self.name() ) } @@ -339,7 +346,7 @@ mod tests { let tool = ShellExec::new(); assert!(tool - .parse_input(r#"{"command":"echo hello"}"#.to_string()) + .parse_input(r#"{"command":"echo hello"}"#.to_string(), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -357,7 +364,7 @@ mod tests { let tool = ShellExec::new(); assert!(tool - .parse_input(r#"{"command":"pwd"}"#.to_string()) + .parse_input(r#"{"command":"pwd"}"#.to_string(), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -384,7 +391,7 @@ mod tests { .parse_input(format!( r#"{{"command":"pwd","working_dir":"{}"}}"#, subdir.display() - )) + ), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -405,7 +412,7 @@ mod tests { let tool = ShellExec::new(); assert!(tool - .parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string()) + .parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string(), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -424,7 +431,7 @@ mod tests { let tool = ShellExec::new(); assert!(tool - .parse_input(r#"{"command":"exit 1"}"#.to_string()) + .parse_input(r#"{"command":"exit 1"}"#.to_string(), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -443,7 +450,7 @@ mod tests { let tool = ShellExec::new(); assert!(tool - .parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string()) + .parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string(), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -457,7 +464,7 @@ mod tests { let request = TestRequest::new(temp.path()); let tool = ShellExec::new(); - let err = tool.parse_input(r#"{"command":""}"#.to_string()); + let err = tool.parse_input(r#"{"command":""}"#.to_string(), "call-id".to_string()); assert!(err.is_some()); let result = tool.work(&request); assert!(result.output_string().contains("Error")); @@ -474,7 +481,7 @@ mod tests { .parse_input(format!( r#"{{"command":"echo 'test content' > {}"}}"#, file_path.display() - )) + ), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -496,7 +503,7 @@ mod tests { let tool = ShellExec::new(); assert!(tool - .parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string()) + .parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string(), "call-id".to_string()) .is_none()); let result = tool.work(&request); @@ -506,4 +513,4 @@ mod tests { result.output_string() ); } -} +} \ No newline at end of file diff --git a/src/domain/tools/structure.rs b/src/domain/tools/structure.rs index 6b0f53a..82f6c16 100644 --- a/src/domain/tools/structure.rs +++ b/src/domain/tools/structure.rs @@ -16,6 +16,7 @@ struct StructureInput { raw: String, path: String, max_depth: usize, + call_id: String, } #[derive(Debug, Deserialize)] @@ -127,6 +128,7 @@ impl Structure { raw: raw.to_string(), path: parsed.path.unwrap_or_else(|| ".".to_string()), max_depth: parsed.max_depth.unwrap_or(3), + call_id: String::new(), }) } @@ -144,12 +146,13 @@ impl Tool for Structure { "structure" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let trimmed = input.trim(); let parsed = Self::parse_input_json(trimmed); match parsed { - Ok(parsed) => { + Ok(mut parsed) => { + parsed.call_id = call_id; *self.input.lock().unwrap() = Some(parsed); None } @@ -161,34 +164,36 @@ impl Tool for Structure { let input = match self.load_input() { Ok(input) => input, Err(e) => { - return ToolResult::error(self.name().to_string(), String::new(), e.to_string()) + return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new()) } }; let target_path = match Self::resolve_path(&input.path, request.project_root()) { Ok(p) => p, - Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e), + Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()), }; if !target_path.exists() { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Path does not exist: {}", target_path.display()), + input.call_id.clone(), ); } if !target_path.is_dir() { return ToolResult::error( self.name().to_string(), - input.raw, + input.raw.clone(), format!("Path is not a directory: {}", target_path.display()), + input.call_id, ); } let tree = Self::build_tree(&target_path, input.max_depth); - ToolResult::ok(self.name().to_string(), input.raw, tree) + ToolResult::ok(self.name().to_string(), input.raw, tree, input.call_id) } fn parameters(&self) -> serde_json::Value { diff --git a/src/domain/tools/update_todo_list.rs b/src/domain/tools/update_todo_list.rs index 89ab92e..6c2e6d1 100644 --- a/src/domain/tools/update_todo_list.rs +++ b/src/domain/tools/update_todo_list.rs @@ -18,6 +18,7 @@ pub struct UpdateTodoList { #[derive(Debug, Clone)] struct UpdateTodoListInput { raw: String, + call_id: String, } impl UpdateTodoList { @@ -36,12 +37,12 @@ impl Tool for UpdateTodoList { "update_todo_list" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let parsed: Result = serde_json::from_str(&input); match parsed { Ok(_) => { let mut lock = self.input.lock().unwrap(); - *lock = Some(UpdateTodoListInput { raw: input }); + *lock = Some(UpdateTodoListInput { raw: input, call_id }); None } Err(e) => Some(Error::Parse(format!("Failed to parse input: {}", e))), @@ -57,6 +58,7 @@ impl Tool for UpdateTodoList { self.name().to_string(), String::new(), "No input provided".to_string(), + String::new(), ) } }; @@ -70,6 +72,7 @@ impl Tool for UpdateTodoList { self.name().to_string(), input.raw.clone(), format!("Failed to get database connection: {}", e), + input.call_id.clone(), ) } }; @@ -83,6 +86,7 @@ impl Tool for UpdateTodoList { self.name().to_string(), input.raw.clone(), format!("Failed to get/create TODO list: {}", e), + input.call_id.clone(), ) } }; @@ -92,6 +96,7 @@ impl Tool for UpdateTodoList { self.name().to_string(), input.raw.clone(), format!("Failed to update TODO list: {}", e), + input.call_id.clone(), ); } @@ -106,6 +111,7 @@ impl Tool for UpdateTodoList { self.name().to_string(), input.raw, "TODO list updated successfully.".to_string(), + input.call_id, ) } diff --git a/src/domain/tools/web_search.rs b/src/domain/tools/web_search.rs index b125c61..de177d0 100644 --- a/src/domain/tools/web_search.rs +++ b/src/domain/tools/web_search.rs @@ -16,6 +16,7 @@ struct WebSearchInput { raw: String, query: String, max_results: u32, + call_id: String, } #[derive(Debug, Deserialize)] @@ -94,7 +95,7 @@ impl Tool for WebSearch { "web_search" } - fn parse_input(&self, input: String) -> Option { + fn parse_input(&self, input: String, call_id: String) -> Option { let trimmed = input.trim(); match serde_json::from_str::(trimmed) { Ok(parsed) => { @@ -103,6 +104,7 @@ impl Tool for WebSearch { raw: input, query: parsed.query, max_results, + call_id, }; *self.input.lock().unwrap() = Some(parsed_input); None @@ -118,7 +120,7 @@ impl Tool for WebSearch { let input = match self.load_input() { Ok(input) => input, Err(e) => { - return ToolResult::error(self.name().to_string(), String::new(), e); + return ToolResult::error(self.name().to_string(), String::new(), e, String::new()); } }; @@ -130,6 +132,7 @@ impl Tool for WebSearch { self.name().to_string(), input.raw.clone(), "User settings not available".to_string(), + input.call_id.clone(), ); } }; @@ -139,6 +142,7 @@ impl Tool for WebSearch { self.name().to_string(), input.raw.clone(), "Web search is not enabled in settings".to_string(), + input.call_id.clone(), ); } @@ -149,6 +153,7 @@ impl Tool for WebSearch { self.name().to_string(), input.raw.clone(), "Brave API key not configured".to_string(), + input.call_id.clone(), ); } }; @@ -159,12 +164,13 @@ impl Tool for WebSearch { let output = WebSearchOutput { results }; match serde_json::to_string(&output) { Ok(json_output) => { - ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output) + ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output, input.call_id) } Err(e) => ToolResult::error( self.name().to_string(), input.raw.clone(), format!("Failed to serialize output: {}", e), + input.call_id.clone(), ), } } @@ -172,6 +178,7 @@ impl Tool for WebSearch { self.name().to_string(), input.raw.clone(), format!("Search failed: {}", e), + input.call_id, ), } } @@ -593,8 +600,10 @@ mod tests { user_settings, }; let web_search = WebSearch::new(); - let _ = - web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string()); + let _ = web_search.parse_input( + r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(), + "call-id".to_string(), + ); let _allowed = checker.check(&web_search, &request, Some(1)).unwrap(); @@ -629,8 +638,10 @@ mod tests { user_settings, }; let web_search = WebSearch::new(); - let _ = - web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string()); + let _ = web_search.parse_input( + r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(), + "call-id".to_string(), + ); let allowed = checker.check(&web_search, &request, Some(1)).unwrap(); @@ -640,4 +651,4 @@ mod tests { assert_eq!(created[0].tool_name, "web_search"); assert_eq!(created[0].resource_pattern, None); } -} +} \ No newline at end of file diff --git a/src/domain/workflow/chain.rs b/src/domain/workflow/chain.rs index c9732d6..be07c0a 100644 --- a/src/domain/workflow/chain.rs +++ b/src/domain/workflow/chain.rs @@ -17,6 +17,7 @@ pub struct Chain { pub fail_reason: String, #[serde(default)] pub final_message: Option, + pub system_prompt: String, } impl Chain { @@ -28,6 +29,7 @@ impl Chain { is_failed: false, fail_reason: String::new(), final_message: None, + system_prompt: String::new(), } } @@ -120,6 +122,10 @@ impl Chain { .collect::>() .join("\n") } + + pub fn set_system_prompt(&mut self, system_prompt: String) { + self.system_prompt = system_prompt; + } #[allow(dead_code)] pub fn total_payload_len_chars(&self) -> usize { diff --git a/src/domain/workflow/step.rs b/src/domain/workflow/step.rs index 7038a9c..656f7a6 100644 --- a/src/domain/workflow/step.rs +++ b/src/domain/workflow/step.rs @@ -39,6 +39,8 @@ pub struct ChainStep { #[serde(default)] pub tool_name: Option, #[serde(default)] + pub call_id: Option, + #[serde(default)] pub tool_output: Option, #[serde(default)] pub is_successful: Option, @@ -60,6 +62,7 @@ impl ChainStep { let mut tool_output = None; let mut is_successful = None; let mut file_changes = None; + let mut call_id = None; if let Some(tr) = tool_result { summary = tr.summary(); context_payload = tr.output_string(); @@ -68,6 +71,7 @@ impl ChainStep { tool_output = Some(tr.output_raw().to_string()); is_successful = Some(tr.is_successful()); file_changes = tr.file_changes().map(|fc| fc.to_vec()); + call_id = Some(tr.call_id()); } Self { @@ -76,6 +80,7 @@ impl ChainStep { context_payload, input_payload, tool_name, + call_id, tool_output, is_successful, file_changes, @@ -91,6 +96,7 @@ impl ChainStep { context_payload: raw_output.clone(), input_payload: String::new(), tool_name: None, + call_id: None, tool_output: Some(raw_output), is_successful: Some(true), file_changes: None, @@ -118,6 +124,7 @@ impl ChainStep { context_payload: prompt.clone(), input_payload: prompt, tool_name: None, + call_id: None, tool_output: None, is_successful: Some(true), file_changes: None, diff --git a/src/domain/workflow/tool_runner.rs b/src/domain/workflow/tool_runner.rs index af2af9b..6feffd8 100644 --- a/src/domain/workflow/tool_runner.rs +++ b/src/domain/workflow/tool_runner.rs @@ -41,6 +41,7 @@ impl ToolRunner { tool_call.name.clone(), tool_call.arguments.clone(), error_msg, + tool_call.call_id.clone(), ); } }; @@ -73,11 +74,13 @@ impl ToolRunner { tool.name().to_string(), String::new(), "Permission denied".to_string(), + String::new(), ), Err(err) => ToolResult::error( tool.name().to_string(), String::new(), format!("Permission check error: {}", err), + String::new(), ), }; diff --git a/src/domain/workflow/toolset/mod.rs b/src/domain/workflow/toolset/mod.rs index 3f8267f..ff1c934 100644 --- a/src/domain/workflow/toolset/mod.rs +++ b/src/domain/workflow/toolset/mod.rs @@ -91,7 +91,7 @@ pub trait Toolset { .map(|t| t.as_ref()) .ok_or_else(|| Error::Parse(format!("Tool not found: {}", tool_call.name)))?; - if let Some(err) = tool.parse_input(tool_call.arguments.clone()) { + if let Some(err) = tool.parse_input(tool_call.arguments.clone(), tool_call.call_id.clone()) { return Err(err); } diff --git a/src/domain/workflow/workflow.rs b/src/domain/workflow/workflow.rs index 70b1436..6d03a7a 100644 --- a/src/domain/workflow/workflow.rs +++ b/src/domain/workflow/workflow.rs @@ -5,6 +5,7 @@ use super::{ use crate::domain::bt::GeneralTree; use crate::domain::prompting; use crate::domain::session::Request; +use crate::domain::todo::TodoListStatus; use crate::infrastructure::db::DbPool; use crate::infrastructure::event_bus::{AgentToUiEvent, StepPhase}; use crate::infrastructure::inference::InferenceEngine; @@ -16,7 +17,6 @@ use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use crossbeam_channel::Sender; use crate::domain::AgentModeType; -use crate::domain::todo::TodoListStatus; /// Main workflow orchestrator that runs LLM-driven coding tasks /// Implements an eternal agent loop that: @@ -94,7 +94,6 @@ impl Workflow { self.chain.set_todo_list(request.get_session_plan()); let result = self._run(request, cancel, None, None, mode); - self.chain.steps.push(ChainStep::user_message(request.current_request().to_string(), request.images().to_vec())); self._end_tracing(); result } @@ -145,6 +144,7 @@ impl Workflow { context_payload: String::new(), input_payload: step_user_prompt.to_string(), tool_name: None, + call_id: None, tool_output: None, is_successful: Some(true), file_changes: None, @@ -171,6 +171,13 @@ impl Workflow { mode: AgentModeType, ) -> Result<(), Error> { + let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request); + let user_prompt = user_prompt_override + .as_deref() + .unwrap_or(&base_user_prompt) + .to_string(); + self.chain.steps.push(ChainStep::user_message(user_prompt.clone(), request.images().to_vec())); + // Get max_tool_calls from override (for BT mode) or user settings let max_tool_calls = max_tool_calls_override.unwrap_or_else(|| { request.user_settings() @@ -198,11 +205,7 @@ impl Workflow { // Get base system prompt and inject remaining count let system_prompt = prompting::get_system_prompt(self.engine.get_type(), request.mode(), remaining_calls); - let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request); - let user_prompt = user_prompt_override - .as_deref() - .unwrap_or(&base_user_prompt) - .to_string(); + self.chain.set_system_prompt(system_prompt.clone()); // Switch to finishing toolset if approaching limit if !in_finishing_mode && tool_call_count >= finishing_threshold { @@ -220,13 +223,9 @@ impl Workflow { } self._emit_inference_progress(); - self._trace_llm_start(user_prompt.clone()); // Ask LLM to choose next tool let llm_output = match self.engine.generate( - &system_prompt, - &user_prompt, - 1024, &self.toolset.tool_refs(), &self.chain, request.images(), @@ -238,8 +237,6 @@ impl Workflow { } }; - self._trace_llm_end(llm_output.raw_output.clone()); - // Always capture the assistant's response in the chain if !llm_output.raw_output.is_empty() { self.chain.steps.push(ChainStep::assistant_response( @@ -347,13 +344,6 @@ impl Workflow { self.chain = Chain::new(); } - fn _trace_llm_start(&mut self, prompt: String) { - if let Some(tracer) = &mut self.tracer { - tracer.start_span("LLM generation", SpanKind::Generation); - tracer.add_input("LLM generation", prompt); - } - } - fn _emit_inference_progress(&self) { let options = [ "Thinking.. well kinda..", @@ -383,18 +373,10 @@ impl Workflow { }); } - fn _trace_llm_end(&mut self, output: String) { - if let Some(tracer) = &mut self.tracer { - tracer.add_output("LLM generation", output); - tracer.end_span("LLM generation"); - } - } - /// Detects if the active TODO item has changed /// Returns true if the first non-completed item's title is different fn _did_todo_item_change(&self, previous_todo: &Option) -> bool { - use crate::domain::todo::TodoListStatus; - + // Get first non-completed item from previous TODO list let prev_active = previous_todo.as_ref().and_then(|list| { list.items.iter() diff --git a/src/infrastructure/api_clients/openai/client.rs b/src/infrastructure/api_clients/openai/client.rs index 0f2ec05..d2eb83f 100644 --- a/src/infrastructure/api_clients/openai/client.rs +++ b/src/infrastructure/api_clients/openai/client.rs @@ -64,8 +64,6 @@ impl OpenAIClient { pub fn call_responses_api( &self, - system_prompt: &str, - user_prompt: &str, tools: &[&dyn crate::domain::tools::Tool], chain: &crate::domain::workflow::Chain, images: &[String], @@ -74,8 +72,6 @@ impl OpenAIClient { let max_attempts = 3; for attempt in 1..=max_attempts { match self.call_responses_api_inner( - system_prompt, - user_prompt, tools, chain, images, @@ -121,25 +117,31 @@ impl OpenAIClient { fn call_responses_api_inner( &self, - system_prompt: &str, - user_prompt: &str, tools: &[&dyn crate::domain::tools::Tool], chain: &crate::domain::workflow::Chain, images: &[String], - tracer: Option<&mut openai_agents_tracing::TracingFacade>, + mut tracer: Option<&mut openai_agents_tracing::TracingFacade>, ) -> Result> { let url = "https://api.openai.com/v1/responses"; let request_body = build_request_dto( &self.model, - system_prompt, - user_prompt, images, tools, chain, - tracer, + tracer.as_deref_mut(), ); + // Start span with model name and add request as JSON + if let Some(tracer) = &mut tracer { + tracer.start_span(&self.model, openai_agents_tracing::SpanKind::Generation); + + // Convert request_body to JSON Value and set as input + if let Ok(request_json) = serde_json::to_value(&request_body) { + tracer.set_input_json(&self.model, request_json); + } + } + let response = self .client .post(url) @@ -151,18 +153,35 @@ impl OpenAIClient { let body = response.text()?; if !status.is_success() { + if let Some(t) = tracer { + t.end_span(&self.model); + } return Err(Box::new(OpenAIClientError::Api { status, body })); } let dto = match serde_json::from_str::(&body) { Ok(v) => v, - Err(e) => return Err(Box::new(OpenAIClientError::Deserialize { source: e, body })), + Err(e) => { + if let Some(t) = tracer { + t.end_span(&self.model); + } + return Err(Box::new(OpenAIClientError::Deserialize { source: e, body })); + } }; + // Add response as JSON and end span + if let Some(tracer) = &mut tracer { + if let Ok(response_json) = serde_json::to_value(&dto) { + tracer.set_output_json(&self.model, response_json); + } + tracer.end_span(&self.model); + } + let result = build_llm_result(dto, tools); if result.summary.is_empty() && result.tool_call.is_none() { return Err(Box::new(OpenAIClientError::NoText { body })); } + Ok(result) } } diff --git a/src/infrastructure/api_clients/openai/dto/request.rs b/src/infrastructure/api_clients/openai/dto/request.rs index cacaa74..78b89ad 100644 --- a/src/infrastructure/api_clients/openai/dto/request.rs +++ b/src/infrastructure/api_clients/openai/dto/request.rs @@ -3,13 +3,13 @@ use crate::domain::{Chain, ModelType}; use crate::domain::prompting::format_todo_list_message; use serde::Serialize; use serde_json::Value; -use crate::domain::workflow::step::StepType::{AssistantResponse, ToolCall, UserMessage}; +use crate::domain::workflow::step::StepType; #[derive(Debug, Serialize)] pub struct RequestDTO { model: String, instructions: String, - input: Vec, + input: Vec, tools: Vec, tool_choice: String, parallel_tool_calls: bool, @@ -18,12 +18,37 @@ pub struct RequestDTO { } #[derive(Debug, Serialize)] -pub(super) struct InputDto { +pub(super) struct MessageDto { content: Vec, - role: String, + #[serde(skip_serializing_if = "Option::is_none")] + role: Option, #[serde(rename = "type")] kind: String, - status: String, +} + +#[derive(Debug, Serialize)] +pub(super) struct FunctionOutputDto { + output: String, + #[serde(rename = "type")] + kind: String, + call_id: String, +} + +#[derive(Debug, Serialize)] +pub(super) struct FunctionCallDto { + arguments: String, + name: String, + #[serde(rename = "type")] + kind: String, + call_id: String, +} + +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum InputMessageDto { + Message(MessageDto), + FunctionOutput(FunctionOutputDto), + FunctionCall(FunctionCallDto), } #[derive(Debug, Serialize)] @@ -63,10 +88,25 @@ impl InputContent { } } - pub fn function(text: String) -> Self { - Self::Text { - kind: "function".to_string(), - text, +} + +impl FunctionOutputDto { + pub fn new(output: String, call_id: String) -> Self { + Self { + kind: "function_call_output".to_string(), + output, + call_id, + } + } +} + +impl FunctionCallDto { + pub fn new(name: String, arguments: String, call_id: String) -> Self { + Self { + name, + kind: "function_call".to_string(), + arguments, + call_id, } } } @@ -83,28 +123,23 @@ pub(super) struct ToolDto { const ROLE_USER: &str = "user"; const ROLE_SYSTEM: &str = "system"; const ROLE_ASSISTANT: &str = "assistant"; -const INPUT_STATUS_IN_PROGRESS: &str = "in_progress"; -const INPUT_STATUS_COMPLETED: &str = "completed"; -const INPUT_STATUS_FAILED: &str = "failed"; impl RequestDTO { pub(crate) fn new( model: String, - system_prompt: String, - user_prompt: String, tools: &[&dyn Tool], - chain: &crate::domain::workflow::Chain, + chain: &Chain, ) -> Self { // User request is now part of the chain, no need to add separately - let input = InputDto::build(user_prompt, chain); + let input = MessageDto::build(chain); Self { model, - instructions: system_prompt, + instructions: chain.system_prompt.clone(), input, tools: tools.iter().map(|tool| ToolDto::from_tool(*tool)).collect(), tool_choice: "auto".to_string(), - parallel_tool_calls: true, + parallel_tool_calls: false, store: false, stream: false, } @@ -123,50 +158,46 @@ impl ToolDto { } } -impl InputDto { - fn build(user_prompt: String, chain: &Chain) -> Vec { +impl MessageDto { + fn build(chain: &Chain) -> Vec { let steps = chain.get_steps_with_history(); - let mut result: Vec = steps - .iter() - .enumerate() - .map(|(idx, step)| { - // Determine status - let is_user_message = step.step_type == UserMessage.as_str(); + let mut result: Vec = Vec::new(); - let status = if is_user_message || step.is_successful.unwrap_or(false) { - INPUT_STATUS_COMPLETED - } else { - INPUT_STATUS_FAILED - }; - let mut role = ROLE_ASSISTANT.to_string(); + for step in steps.iter() { + let is_user_message = step.step_type == StepType::UserMessage.as_str(); - // Build content items - let content_items = if is_user_message { - role = ROLE_USER.to_string(); - // For user messages, include text and images + if is_user_message { + // User message: text + optional images let mut items = vec![InputContent::text(step.input_payload.clone())]; - // Add image content items if present if let Some(ref images) = step.images { for image_url in images { items.push(InputContent::image(image_url.clone())); } } - items - } else { - vec![InputContent::output_text(step.get_output(ModelType::OpenAI))] - }; + result.push(InputMessageDto::Message(Self { + content: items, + role: Some(ROLE_USER.to_string()), + kind: "message".to_string(), + })); + } else if step.step_type == StepType::ToolCall.as_str() { - Self { - content: content_items, - role, - kind: "message".to_string(), - status: status.to_string(), + let tool_name = step.tool_name.clone().unwrap(); + let call_id = step.call_id.clone().unwrap(); + // Tool call output is a separate DTO type + result.push(InputMessageDto::FunctionCall(FunctionCallDto::new(tool_name, step.input_payload.clone(), call_id.clone()))); + result.push(InputMessageDto::FunctionOutput(FunctionOutputDto::new(step.get_output(ModelType::OpenAI), call_id))); + } else { + // Assistant message + result.push(InputMessageDto::Message(Self { + content: vec![InputContent::output_text(step.get_output(ModelType::OpenAI))], + role: Some(ROLE_ASSISTANT.to_string()), + kind: "message".to_string(), + })); } - }) - .collect(); + } // Add the plan as system message at the beginning if it exists and is not completed if let Some(ref todo_list) = chain.todo_list { @@ -178,22 +209,15 @@ impl InputDto { let todo_input = Self { content: vec![InputContent::text(todo_message)], - role: ROLE_SYSTEM.to_string(), + role: Some(ROLE_SYSTEM.to_string()), kind: "message".to_string(), - status: INPUT_STATUS_COMPLETED.to_string(), }; - result.push(todo_input); + + // Put it first + result.insert(0, InputMessageDto::Message(todo_input)); } } - // and adding the current user message at the end - result.push(Self { - content: vec![InputContent::text(user_prompt)], - role: ROLE_USER.to_string(), - kind: "message".to_string(), - status: INPUT_STATUS_IN_PROGRESS.to_string(), - }); - result } -} +} \ No newline at end of file diff --git a/src/infrastructure/api_clients/openai/dto/response.rs b/src/infrastructure/api_clients/openai/dto/response.rs index 0ce62fb..dcba7d3 100644 --- a/src/infrastructure/api_clients/openai/dto/response.rs +++ b/src/infrastructure/api_clients/openai/dto/response.rs @@ -1,20 +1,21 @@ -use serde::Deserialize; +use serde::{Deserialize, Serialize}; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct ResponseDTO { output: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] struct OpenAIOutputItem { #[serde(rename = "type")] kind: String, content: Option>, name: Option, arguments: Option, + call_id: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] struct OpenAIContentItem { #[serde(rename = "type")] kind: String, @@ -41,11 +42,12 @@ impl ResponseDTO { } } } else if item.kind == "function_call" { - if let (Some(name), Some(arguments)) = (item.name.as_ref(), item.arguments.as_ref()) + if let (Some(name), Some(arguments), Some(call_id)) = (item.name.as_ref(), item.arguments.as_ref(), item.call_id.as_ref()) { tool_call = Some(FunctionCall { name: name.to_string(), arguments: arguments.to_string(), + call_id: call_id.to_string(), }); } } @@ -59,4 +61,5 @@ impl ResponseDTO { pub(crate) struct FunctionCall { pub name: String, pub arguments: String, + pub call_id: String, } diff --git a/src/infrastructure/api_clients/openai/translator.rs b/src/infrastructure/api_clients/openai/translator.rs index 058fb41..7dafe99 100644 --- a/src/infrastructure/api_clients/openai/translator.rs +++ b/src/infrastructure/api_clients/openai/translator.rs @@ -3,8 +3,6 @@ use crate::infrastructure::inference::{LLMInferenceResult, ToolCall}; pub fn build_request_dto( model: &str, - system_prompt: &str, - user_prompt: &str, images: &[String], tools: &[&dyn crate::domain::tools::Tool], chain: &crate::domain::workflow::Chain, @@ -25,8 +23,6 @@ pub fn build_request_dto( let dto = RequestDTO::new( model.to_string(), - system_prompt.to_string(), - user_prompt.to_string(), tools, chain, ); @@ -54,9 +50,10 @@ pub fn build_llm_result( let tool_call = if let Some(call) = tool_call_dto { // Create ToolCall for the workflow to use - Some(ToolCall { + Some(ToolCall{ name: call.name.clone(), arguments: call.arguments.clone(), + call_id: call.call_id.clone(), }) } else { None diff --git a/src/infrastructure/cli/components/input.rs b/src/infrastructure/cli/components/input.rs index 0555498..cf6e716 100644 --- a/src/infrastructure/cli/components/input.rs +++ b/src/infrastructure/cli/components/input.rs @@ -51,4 +51,4 @@ pub fn render(frame: &mut Frame, area: Rect, state: &UiState, theme: Theme) { let y = area.y.saturating_add(INPUT_PADDING.top + adjusted_line); frame.set_cursor(x, y); } -} +} \ No newline at end of file diff --git a/src/infrastructure/cli/repl.rs b/src/infrastructure/cli/repl.rs index 7928d2e..6a1647b 100644 --- a/src/infrastructure/cli/repl.rs +++ b/src/infrastructure/cli/repl.rs @@ -159,6 +159,7 @@ fn handle_agent_event( }); state.request_status = None; state.request_progress = None; + state.last_progress_update = None; // Reset timer for new request state.file_changes = None; // Display the user's message @@ -175,7 +176,22 @@ fn handle_agent_event( summary, } => { if matches!(phase, StepPhase::Before) { - state.request_progress = Some(summary.clone()); + // Check if enough time has passed since the last progress update + let can_update = state + .last_progress_update + .map(|last| { + std::time::Instant::now() + .duration_since(last) + .as_millis() + >= crate::infrastructure::cli::state::MIN_PROGRESS_DISPLAY_MS + }) + .unwrap_or(true); // Update immediately if no previous update + + if can_update { + state.request_progress = Some(summary.clone()); + state.last_progress_update = Some(std::time::Instant::now()); + } + // Otherwise skip this update - too soon since last one } else { state.request_progress = None; } @@ -207,6 +223,7 @@ fn handle_agent_event( }); } state.request_progress = None; + state.last_progress_update = None; // Reset timer when request finishes // Update the user message color based on request status for entry in state.progress.iter_mut() { diff --git a/src/infrastructure/cli/state.rs b/src/infrastructure/cli/state.rs index c9cec88..60a5a37 100644 --- a/src/infrastructure/cli/state.rs +++ b/src/infrastructure/cli/state.rs @@ -33,8 +33,8 @@ impl AttachedImage { } } -pub const INPUT_MIN_HEIGHT: usize = 3; -pub const INPUT_MAX_HEIGHT: usize = 6; +pub const INPUT_MIN_HEIGHT: usize = 1; +pub const INPUT_MAX_HEIGHT: usize = 5; pub const PROGRESS_HISTORY_LIMIT: usize = 200; pub const MAIN_BODY_SCROLL_STEP: usize = 3; @@ -51,6 +51,7 @@ pub enum ProgressKind { } pub const REQUEST_STATUS_DISPLAY_DURATION: Duration = Duration::from_secs(2); +pub const MIN_PROGRESS_DISPLAY_MS: u128 = 1000; // 1 second minimum display time #[derive(Debug, Clone)] #[allow(dead_code)] @@ -260,6 +261,7 @@ pub struct UiState { pub request_in_flight: Option, pub request_status: Option, pub request_progress: Option, + pub last_progress_update: Option, pub file_changes: Option, pub agent_mode: AgentModeType, pub todo_list: Option, @@ -300,6 +302,7 @@ impl UiState { request_in_flight: None, request_status: None, request_progress: None, + last_progress_update: None, file_changes: None, agent_mode: AgentModeType::Build, // Default to build mode todo_list: None, @@ -353,4 +356,4 @@ impl UiState { Some(crate::domain::ModelType::OpenAI) ) } -} +} \ No newline at end of file diff --git a/src/infrastructure/cli/theme.rs b/src/infrastructure/cli/theme.rs index 8249996..071c4d6 100644 --- a/src/infrastructure/cli/theme.rs +++ b/src/infrastructure/cli/theme.rs @@ -31,4 +31,4 @@ impl Theme { } pub const PANEL_PADDING: Padding = Padding::new(2, 2, 1, 1); -pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1); +pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1); \ No newline at end of file diff --git a/src/infrastructure/cli/views/main_view.rs b/src/infrastructure/cli/views/main_view.rs index 668c630..d178cb3 100644 --- a/src/infrastructure/cli/views/main_view.rs +++ b/src/infrastructure/cli/views/main_view.rs @@ -6,7 +6,7 @@ use crate::infrastructure::cli::helpers::{ centered_rect, cursor_position, list_state, panel_block, }; use crate::infrastructure::cli::state::{LoadStatus, PopupState, UiMode, UiState}; -use crate::infrastructure::cli::theme::{Theme, PANEL_PADDING}; +use crate::infrastructure::cli::theme::{Theme, INPUT_PADDING, PANEL_PADDING}; use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect}; use ratatui::style::{Modifier, Style}; use ratatui::text::{Line, Span, Text}; @@ -28,7 +28,7 @@ pub fn render(frame: &mut Frame, state: &UiState) { input_lines, crate::infrastructure::cli::state::INPUT_MAX_HEIGHT, ); - let mut input_box_height = (input_lines + 2) as u16; + let mut input_box_height = (input_lines + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16; let indicator_height = 1u16; let attachment_indicator_height = if state.attached_images.is_empty() { @@ -45,7 +45,7 @@ pub fn render(frame: &mut Frame, state: &UiState) { let overflow = fixed_height - size.height; let reduced = input_box_height.saturating_sub(overflow); input_box_height = - reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + 2) as u16); + reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16); } let constraints = if state.attached_images.is_empty() { @@ -557,4 +557,4 @@ pub fn openai_available_filtered(state: &UiState, filter: &str) -> Vec { .filter(|name| name.to_lowercase().contains(&query)) .cloned() .collect() -} +} \ No newline at end of file diff --git a/src/infrastructure/inference/local.rs b/src/infrastructure/inference/local.rs index 0234a1b..8901486 100644 --- a/src/infrastructure/inference/local.rs +++ b/src/infrastructure/inference/local.rs @@ -89,17 +89,39 @@ impl LocalEngine { } impl InferenceEngine for LocalEngine { + // so far very fucked up implementation, even user_prompt is not properly passed + // need to be refactored using proper request builder same way as for openai inference fn generate( &self, - system_prompt: &str, - user_prompt: &str, - max_tokens: u32, _tools: &[&dyn crate::domain::tools::Tool], - _chain: &crate::domain::workflow::Chain, + chain: &crate::domain::workflow::Chain, _images: &[String], - _tracer: Option<&mut openai_agents_tracing::TracingFacade>, + mut tracer: Option<&mut openai_agents_tracing::TracingFacade>, ) -> Result { - let prompt = format!("{}\n\n{}", system_prompt, user_prompt); + let model_name = "local"; + + // Start span with model name and add request as JSON + if let Some(tracer) = &mut tracer { + tracer.start_span(model_name, openai_agents_tracing::SpanKind::Generation); + + // Set request as structured JSON input + let request_json = serde_json::json!({ + "system_prompt": chain.system_prompt.clone(), + "user_prompt": chain.get_steps_with_history()[0].context_payload.clone(), + "max_tokens": 1000, + }); + tracer.set_input_json(model_name, request_json); + + // Set model configuration + let mut model_config = std::collections::HashMap::new(); + model_config.insert("temperature".to_string(), serde_json::json!(self.params.temperature)); + model_config.insert("top_p".to_string(), serde_json::json!(self.params.top_p)); + model_config.insert("ctx_size".to_string(), serde_json::json!(self.params.ctx_size)); + model_config.insert("threads".to_string(), serde_json::json!(self.params.threads)); + tracer.set_model_config(model_name, model_config); + } + + let prompt = format!("{}\n\n{}", chain.system_prompt.clone(), chain.get_steps_with_history()[0].context_payload.clone()); let to_error = |msg: String| -> InfaError { std::io::Error::new(std::io::ErrorKind::Other, msg).into() }; @@ -145,7 +167,7 @@ impl InferenceEngine for LocalEngine { let mut output = String::new(); let mut n_cur = tokens.len(); - for _ in 0..max_tokens { + for _ in 0..2000 { let new_token = sampler.sample(&ctx, -1); if self.model.is_eog_token(new_token) { @@ -169,6 +191,15 @@ impl InferenceEngine for LocalEngine { n_cur += 1; } + // Add output as JSON and end span + if let Some(tracer) = tracer { + let response_json = serde_json::json!({ + "text": &output, + }); + tracer.set_output_json(model_name, response_json); + tracer.end_span(model_name); + } + Ok(LLMInferenceResult { summary: output.trim().to_string(), raw_output: output, diff --git a/src/infrastructure/inference/mod.rs b/src/infrastructure/inference/mod.rs index d23fd19..bf2e49e 100644 --- a/src/infrastructure/inference/mod.rs +++ b/src/infrastructure/inference/mod.rs @@ -9,6 +9,7 @@ pub mod openai; pub struct ToolCall { pub name: String, pub arguments: String, + pub call_id: String, } pub struct LLMInferenceResult { @@ -22,9 +23,6 @@ pub trait InferenceEngine: Send + Sync { /// Generate text without streaming output fn generate( &self, - system_prompt: &str, - user_prompt: &str, - max_tokens: u32, tools: &[&dyn Tool], chain: &Chain, images: &[String], diff --git a/src/infrastructure/inference/openai.rs b/src/infrastructure/inference/openai.rs index 643b78c..6761e29 100644 --- a/src/infrastructure/inference/openai.rs +++ b/src/infrastructure/inference/openai.rs @@ -66,15 +66,12 @@ impl OpenAIEngine { impl InferenceEngine for OpenAIEngine { fn generate( &self, - system_prompt: &str, - user_prompt: &str, - _max_tokens: u32, tools: &[&dyn crate::domain::tools::Tool], chain: &crate::domain::workflow::Chain, images: &[String], tracer: Option<&mut openai_agents_tracing::TracingFacade>, ) -> Result { - self.generate_with_responses_api(system_prompt, user_prompt, tools, chain, images, tracer) + self.generate_with_responses_api(tools, chain, images, tracer) } fn get_type(&self) -> ModelType { ModelType::OpenAI @@ -85,15 +82,13 @@ impl OpenAIEngine { /// Generate using the Responses API (for newer models like codex, o-series) fn generate_with_responses_api( &self, - system_prompt: &str, - user_prompt: &str, tools: &[&dyn crate::domain::tools::Tool], chain: &crate::domain::workflow::Chain, images: &[String], tracer: Option<&mut openai_agents_tracing::TracingFacade>, ) -> Result { self.responses_client - .call_responses_api(system_prompt, user_prompt, tools, chain, images, tracer) + .call_responses_api(tools, chain, images, tracer) .map_err(|e| e) } }