This commit is contained in:
taggart_comet 2026-02-03 18:20:22 +02:00
parent bd21ae6296
commit 3fd019bd7a
35 changed files with 633 additions and 343 deletions

View file

@ -82,6 +82,38 @@ impl TracingFacade {
}
}
pub fn set_model_config(&mut self, name: impl AsRef<str>, config: HashMap<String, serde_json::Value>) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.model_config = Some(config);
}
}
}
pub fn set_usage(&mut self, name: impl AsRef<str>, input_tokens: u32, output_tokens: u32) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.usage = Some(crate::types::UsageData::new(input_tokens, output_tokens));
}
}
}
pub fn set_input_json(&mut self, name: impl AsRef<str>, input: serde_json::Value) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.input = Some(vec![input]);
}
}
}
pub fn set_output_json(&mut self, name: impl AsRef<str>, output: serde_json::Value) {
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
if let SpanData::Generation(ref mut data) = span.span_data {
data.output = Some(vec![output]);
}
}
}
pub async fn end(&mut self) {
for (_, mut span) in self.open_spans.drain() {
span.mark_ended();

View file

@ -17,6 +17,7 @@ pub use session::{Session, SessionRequest};
pub use startup::StartupService;
pub use user_settings::UserSettings;
pub use workflow::{CancellationToken, Chain};
#[allow(unused_imports)]
pub use todo::{TodoList, TodoItem};
/// Model type enum matching the inference engine types
@ -41,4 +42,4 @@ impl ModelType {
_ => None,
}
}
}
}

View file

@ -401,12 +401,17 @@ mod tests {
"read_only"
}
fn parse_input(&self, _input: String) -> Option<Error> {
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
None
}
fn work(&self, _request: &dyn Request) -> ToolResult {
ToolResult::ok("read_only".to_string(), String::new(), String::new())
ToolResult::ok(
"read_only".to_string(),
String::new(),
String::new(),
String::new(),
)
}
fn parameters(&self) -> serde_json::Value {
@ -439,12 +444,17 @@ mod tests {
"write_tool"
}
fn parse_input(&self, _input: String) -> Option<Error> {
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
None
}
fn work(&self, _request: &dyn Request) -> ToolResult {
ToolResult::ok("write_tool".to_string(), String::new(), String::new())
ToolResult::ok(
"write_tool".to_string(),
String::new(),
String::new(),
String::new(),
)
}
fn parameters(&self) -> serde_json::Value {
@ -474,12 +484,17 @@ mod tests {
"command_tool"
}
fn parse_input(&self, _input: String) -> Option<Error> {
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
None
}
fn work(&self, _request: &dyn Request) -> ToolResult {
ToolResult::ok("command_tool".to_string(), String::new(), String::new())
ToolResult::ok(
"command_tool".to_string(),
String::new(),
String::new(),
String::new(),
)
}
fn parameters(&self) -> serde_json::Value {
@ -955,4 +970,4 @@ mod tests {
assert_eq!(decision, PermissionDecision::AlwaysAllow);
}
}
}

View file

@ -1,7 +1,6 @@
use super::types::{Permission, PermissionDecision, PermissionScope};
use crate::infrastructure::db::DbPool;
use rusqlite::params;
use std::path::PathBuf;
use thiserror::Error;
#[derive(Debug, Error)]
@ -121,16 +120,19 @@ impl PermissionStore for SqlitePermissionStore {
) -> Result<Option<Permission>, StoreError> {
// Build query dynamically to handle NULL values properly
// In SQL, NULL != '', so we need to use IS NULL for empty strings
let mut param_num = 3;
let command_clause = if command_pattern.is_empty() {
"command_pattern IS NULL"
"command_pattern IS NULL".to_string()
} else {
"command_pattern = ?3"
let clause = format!("command_pattern = ?{}", param_num);
param_num += 1;
clause
};
let resource_clause = if resource_pattern.is_empty() {
"resource_pattern IS NULL"
"resource_pattern IS NULL".to_string()
} else {
"resource_pattern = ?4"
format!("resource_pattern = ?{}", param_num)
};
let query = format!(
@ -151,32 +153,28 @@ impl PermissionStore for SqlitePermissionStore {
)))
})?;
// Build params based on whether patterns are empty
let result = if command_pattern.is_empty() && resource_pattern.is_empty() {
conn.query_row(
&query,
params![project_id, tool],
|row| self.row_to_permission(row),
)
} else if command_pattern.is_empty() {
conn.query_row(
&query,
params![project_id, tool, resource_pattern],
|row| self.row_to_permission(row),
)
} else if resource_pattern.is_empty() {
conn.query_row(
&query,
params![project_id, tool, command_pattern],
|row| self.row_to_permission(row),
)
} else {
conn.query_row(
&query,
params![project_id, tool, command_pattern, resource_pattern],
|row| self.row_to_permission(row),
)
};
// Build params list dynamically to match the query
let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = vec![
Box::new(project_id),
Box::new(tool.to_string()),
];
if !command_pattern.is_empty() {
param_values.push(Box::new(command_pattern.to_string()));
}
if !resource_pattern.is_empty() {
param_values.push(Box::new(resource_pattern.to_string()));
}
let params_refs: Vec<&dyn rusqlite::ToSql> = param_values
.iter()
.map(|p| p.as_ref())
.collect();
let result = conn.query_row(&query, params_refs.as_slice(), |row| {
self.row_to_permission(row)
});
match result {
Ok(permission) => Ok(Some(permission)),
@ -186,48 +184,140 @@ impl PermissionStore for SqlitePermissionStore {
}
}
impl SqlitePermissionStore {
fn find_matching_command_permission(
&self,
rows: rusqlite::MappedRows<
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>,
>,
tool: &str,
command: &str,
) -> Result<Option<Permission>, StoreError> {
for row_result in rows {
match row_result {
Ok(permission) => {
if permission.matches(tool, Some(command), None::<&PathBuf>) {
return Ok(Some(permission));
}
}
Err(e) => return Err(StoreError::Database(e)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use r2d2_sqlite::SqliteConnectionManager;
Ok(None)
fn setup_test_db() -> DbPool {
let manager = SqliteConnectionManager::memory();
let pool = r2d2::Pool::new(manager).unwrap();
let conn = pool.get().unwrap();
// Create permissions table
conn.execute(
"CREATE TABLE IF NOT EXISTS permissions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tool_name TEXT NOT NULL,
command_pattern TEXT,
resource_pattern TEXT,
decision TEXT NOT NULL,
scope TEXT NOT NULL,
project_id INTEGER,
created_at TEXT NOT NULL
)",
[],
).unwrap();
pool
}
fn find_matching_path_permission(
&self,
rows: rusqlite::MappedRows<
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>,
>,
tool: &str,
path: &PathBuf,
) -> Result<Option<Permission>, StoreError> {
for row_result in rows {
match row_result {
Ok(permission) => {
if permission.matches(tool, None, Some(path)) {
return Ok(Some(permission));
}
}
Err(e) => return Err(StoreError::Database(e)),
}
}
#[test]
fn test_find_permission_with_null_patterns() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
Ok(None)
// Create permission with NULL command and resource patterns
let permission = Permission::new(
"test_tool".to_string(),
None,
None,
PermissionDecision::AlwaysAllow,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with empty strings
let result = store.find_permission("test_tool", 1, "", "").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
}
}
#[test]
fn test_find_permission_with_command_only() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with command but NULL resource
let permission = Permission::new(
"test_tool".to_string(),
Some("echo test".to_string()),
None,
PermissionDecision::AlwaysDeny,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with command and empty resource
let result = store.find_permission("test_tool", 1, "echo test", "").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
}
#[test]
fn test_find_permission_with_resource_only() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with resource but NULL command
let permission = Permission::new(
"test_tool".to_string(),
None,
Some("/path/to/file".to_string()),
PermissionDecision::AlwaysAllow,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with empty command and resource
let result = store.find_permission("test_tool", 1, "", "/path/to/file").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
}
#[test]
fn test_find_permission_with_both_patterns() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with both patterns
let permission = Permission::new(
"test_tool".to_string(),
Some("rm -rf".to_string()),
Some("/etc/passwd".to_string()),
PermissionDecision::AlwaysDeny,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should find the permission when searching with both patterns
let result = store.find_permission("test_tool", 1, "rm -rf", "/etc/passwd").unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
}
#[test]
fn test_find_permission_no_match() {
let pool = setup_test_db();
let store = SqlitePermissionStore::new(pool.clone());
// Create permission with command
let permission = Permission::new(
"test_tool".to_string(),
Some("echo test".to_string()),
None,
PermissionDecision::AlwaysAllow,
PermissionScope::Project,
Some(1),
);
store.create_permission(permission).unwrap();
// Should NOT find when searching with different command
let result = store.find_permission("test_tool", 1, "rm -rf", "").unwrap();
assert!(result.is_none());
}
}

View file

@ -29,6 +29,7 @@ pub struct Permission {
pub created_at: DateTime<Utc>,
}
#[allow(dead_code)]
impl Permission {
pub fn new(
tool_name: String,
@ -173,4 +174,4 @@ impl Default for PermissionConfig {
require_confirmation: true,
}
}
}
}

View file

@ -32,33 +32,38 @@ pub fn get_bt_tree_step_prompt(
pub fn get_system_prompt(model_type: ModelType, agent_mode: AgentModeType, remaining_calls: usize) -> String {
let (os_name, shell_name) = get_runtime_environment();
let mut system_prompt = "".to_string();
if agent_mode == AgentModeType::Plan {
system_prompt = _system_prompt_for_plan(model_type);
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
let mut system_prompt = _system_prompt_for_plan(model_type);
if remaining_calls < 3 {
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
}
return system_prompt;
}
if agent_mode == AgentModeType::BuildFromPlan {
system_prompt = _system_prompt_for_build_from_plan(model_type);
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
let mut system_prompt = _system_prompt_for_build_from_plan(model_type);
if remaining_calls < 3 {
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
}
return system_prompt;
}
if model_type == ModelType::OpenAI {
system_prompt = format!(
let mut system_prompt = if model_type == ModelType::OpenAI {
format!(
"You are Drastis, a coding agent. \n\
Use the available tools to gather context and make changes. \
When using tools, pass JSON arguments that match their parameters. \n\
Runtime: os={}, shell={}.",
os_name, shell_name
);
)
} else {
system_prompt = format!(
format!(
"You are Drastis, a coding agent. Use available tools to gather context and make changes. Be concise and accurate. Runtime: os={}, shell={}.",
os_name, shell_name
);
)
};
if remaining_calls < 3 {
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
}
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
return system_prompt;
system_prompt
}
fn _system_prompt_for_plan(model_type: ModelType) -> String {

View file

@ -8,6 +8,7 @@ use std::path::Path;
/// without being tightly coupled to the Session entity.
pub trait Request {
/// Get the history of previous requests
#[allow(dead_code)]
fn history(&self) -> &[SessionRequest];
/// Get the current request prompt
@ -40,4 +41,4 @@ pub trait Request {
/// Get the session's TODO list (plan)
fn get_session_plan(&self) -> Option<crate::domain::todo::TodoList>;
}
}

View file

@ -4,6 +4,7 @@ use crate::repository::SessionRequestRow;
/// Domain entity representing a user request within a session.
/// Each request contains the user's prompt and the resulting summary.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct SessionRequest {
prompt: String,
result_summary: Option<String>,
@ -32,4 +33,4 @@ impl SessionRequest {
mode: row.mode,
}
}
}
}

View file

@ -1,6 +1,7 @@
use super::Error;
use crate::domain::prompting::session_naming_prompt;
use crate::domain::workflow::Chain;
use crate::domain::workflow::ChainStep;
use crate::domain::{Project, Session, SessionRequest};
use crate::infrastructure::db::DbPool;
use crate::infrastructure::inference::InferenceEngine;
@ -83,38 +84,37 @@ impl StartupService {
let prompt_preview: String = first_prompt.chars().take(100).collect();
let naming_prompt = session_naming_prompt(self.engine.get_type(), &prompt_preview);
let chain = Chain::new();
let session_name =
match self
.engine
.generate("", &naming_prompt, 15, &[], &chain, &[], None)
{
Ok(raw) => {
log::debug!("Raw session name response: {:?}", raw.summary);
// Clean up the response
let cleaned = raw
.summary
.lines()
.next()
.unwrap_or("")
.trim()
.trim_matches('"')
.trim_matches('*')
.replace("<|im_end|>", "")
.trim()
.to_string();
let mut chain = Chain::new();
chain
.steps
.push(ChainStep::user_message(naming_prompt.parse().unwrap(), Vec::new()));
let session_name = match self.engine.generate(&[], &chain, &[], None) {
Ok(raw) => {
log::debug!("Raw session name response: {:?}", raw.summary);
// Clean up the response
let cleaned = raw
.summary
.lines()
.next()
.unwrap_or("")
.trim()
.trim_matches('"')
.trim_matches('*')
.replace("<|im_end|>", "")
.trim()
.to_string();
if cleaned.is_empty() || cleaned.len() > 50 {
Self::fallback_session_name(first_prompt)
} else {
cleaned
}
}
Err(e) => {
log::warn!("Session naming failed: {}", e);
if cleaned.is_empty() || cleaned.len() > 50 {
Self::fallback_session_name(first_prompt)
} else {
cleaned
}
};
}
Err(e) => {
log::warn!("Session naming failed: {}", e);
Self::fallback_session_name(first_prompt)
}
};
// Create session in database
let conn = self
@ -225,4 +225,4 @@ mod tests {
assert_eq!(StartupService::fallback_session_name(""), "New Session");
}
}
}

View file

@ -15,6 +15,7 @@ pub struct DiscoverObjects {
struct DiscoverObjectsInput {
raw: String,
full_path_to_file: String,
call_id: String,
}
#[derive(Debug, Deserialize)]
@ -69,6 +70,7 @@ impl DiscoverObjects {
Ok(DiscoverObjectsInput {
raw: raw.to_string(),
full_path_to_file: parsed.full_path_to_file,
call_id: String::new(),
})
}
@ -86,12 +88,13 @@ impl Tool for DiscoverObjects {
"discover_objects"
}
fn parse_input(&self, input: String) -> Option<Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed);
match parsed {
Ok(parsed) => {
Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed);
None
}
@ -103,17 +106,18 @@ impl Tool for DiscoverObjects {
let input = match self.load_input() {
Ok(input) => input,
Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
}
};
match Self::parse_file(&input.full_path_to_file) {
Ok((lang, objects)) => ToolResult::ok(
self.name().to_string(),
input.raw,
input.raw.clone(),
Self::format_output(lang, &objects),
input.call_id,
),
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()),
Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
}
}

View file

@ -19,6 +19,7 @@ struct FindFilesInput {
query: String,
root: Option<String>,
max_results: usize,
call_id: String,
}
#[derive(Debug, Deserialize)]
@ -132,6 +133,7 @@ impl FindFiles {
query: parsed.query,
root: parsed.root,
max_results: parsed.max_results.unwrap_or(20),
call_id: String::new(),
})
}
@ -149,12 +151,13 @@ impl Tool for FindFiles {
"find_files"
}
fn parse_input(&self, input: String) -> Option<Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed);
match parsed {
Ok(parsed) => {
Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed);
None
}
@ -166,7 +169,7 @@ impl Tool for FindFiles {
let input = match self.load_input() {
Ok(input) => input,
Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
}
};
@ -174,15 +177,16 @@ impl Tool for FindFiles {
let search_root =
match Self::resolve_search_root(input.root.as_deref(), request.project_root()) {
Ok(root) => root,
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e),
Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
};
// Check if search root exists
if !search_root.exists() {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Search root does not exist: {}", search_root.display()),
input.call_id,
);
}
@ -204,7 +208,7 @@ impl Tool for FindFiles {
output
};
ToolResult::ok(self.name().to_string(), input.raw, output)
ToolResult::ok(self.name().to_string(), input.raw, output, input.call_id)
}
fn parameters(&self) -> serde_json::Value {

View file

@ -45,6 +45,7 @@ pub struct FileChange {
pub struct ToolResult {
tool_name: String,
call_id: String,
pub(crate) input: String,
is_successful: bool,
output: String,
@ -53,7 +54,7 @@ pub struct ToolResult {
}
impl ToolResult {
pub fn ok(tool_name: String, input: String, output: String) -> Self {
pub fn ok(tool_name: String, input: String, output: String, call_id: String) -> Self {
Self {
tool_name,
input,
@ -61,10 +62,11 @@ impl ToolResult {
output,
error_message: String::new(),
file_changes: None,
call_id
}
}
pub fn error(tool_name: String, input: String, message: String) -> Self {
pub fn error(tool_name: String, input: String, message: String, call_id: String) -> Self {
Self {
tool_name,
input,
@ -72,6 +74,7 @@ impl ToolResult {
output: String::new(),
error_message: message.into(),
file_changes: None,
call_id
}
}
@ -128,11 +131,15 @@ impl ToolResult {
pub fn file_changes(&self) -> Option<&[FileChange]> {
self.file_changes.as_deref()
}
pub(crate) fn call_id(&self) -> String {
self.call_id.clone()
}
}
pub trait Tool: Send + Sync {
fn name(&self) -> &'static str;
fn parse_input(&self, input: String) -> Option<Error>;
fn parse_input(&self, input: String, call_id: String) -> Option<Error>;
fn work(&self, request: &dyn Request) -> ToolResult;
fn parameters(&self) -> Value;
fn desc(&self) -> String;

View file

@ -18,6 +18,7 @@ pub struct PatchFiles {
struct PatchFilesInput {
raw: String,
patch: String,
call_id: String,
}
#[derive(Debug, Deserialize)]
@ -41,6 +42,7 @@ impl PatchFiles {
Ok(PatchFilesInput {
raw: raw.to_string(),
patch: parsed.patch,
call_id: String::new(),
})
}
@ -58,12 +60,13 @@ impl Tool for PatchFiles {
"patch_files"
}
fn parse_input(&self, input: String) -> Option<Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed);
match parsed {
Ok(parsed) => {
Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed);
None
}
@ -75,7 +78,7 @@ impl Tool for PatchFiles {
let input = match self.load_input() {
Ok(input) => input,
Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
}
};
@ -84,8 +87,9 @@ impl Tool for PatchFiles {
Err(e) => {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Failed to parse patch: {}", e),
input.call_id.clone(),
)
}
};
@ -115,8 +119,9 @@ impl Tool for PatchFiles {
{
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Invalid path outside project root: {}", path),
input.call_id.clone(),
);
}
let fs_path = project_root.join(rel_path);
@ -129,8 +134,9 @@ impl Tool for PatchFiles {
Err(e) => {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Failed to read file '{}': {}", fs_path.display(), e),
input.call_id.clone(),
)
}
}
@ -142,8 +148,9 @@ impl Tool for PatchFiles {
Err(e) => {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Failed to apply patch: {}", e),
input.call_id.clone(),
)
}
};
@ -155,28 +162,31 @@ impl Tool for PatchFiles {
if let Err(e) = fs::create_dir_all(parent) {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!(
"Failed to create parent directories for '{}': {}",
fs_path.display(),
e
),
input.call_id.clone(),
);
}
}
if let Err(e) = fs::write(&fs_path, content) {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Failed to write file '{}': {}", fs_path.display(), e),
input.call_id.clone(),
);
}
} else if fs_path.exists() {
if let Err(e) = fs::remove_file(&fs_path) {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Failed to delete file '{}': {}", fs_path.display(), e),
input.call_id.clone(),
);
}
}
@ -224,6 +234,7 @@ impl Tool for PatchFiles {
self.name().to_string(),
input.raw,
"Patch applied successfully".to_string(),
input.call_id,
);
if !file_changes.is_empty() {
@ -531,7 +542,7 @@ mod tests {
"{{\"patch\":\"{}\"}}",
patch.replace('\n', "\\n").replace('"', "\\\"")
);
assert!(tool.parse_input(input).is_none());
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
let result = tool.work(&request);
assert!(
@ -571,7 +582,7 @@ mod tests {
"{{\"patch\":\"{}\"}}",
patch.replace('\n', "\\n").replace('"', "\\\"")
);
assert!(tool.parse_input(input).is_none());
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
let result = tool.work(&request);
assert!(
@ -592,7 +603,7 @@ mod tests {
let tool = PatchFiles::new();
let input = "{\"patch\":\"not a patch\"}".to_string();
assert!(tool.parse_input(input).is_none());
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
let result = tool.work(&request);
assert!(
@ -601,4 +612,4 @@ mod tests {
result.output_string()
);
}
}
}

View file

@ -18,6 +18,7 @@ struct ReadObjectsInput {
raw: String,
full_path_to_file: String,
queries: Vec<ObjectQuery>,
call_id: String,
}
#[derive(Debug, Deserialize)]
@ -138,6 +139,7 @@ impl ReadObjects {
raw: raw.to_string(),
full_path_to_file: trimmed_path.to_string(),
queries,
call_id: String::new(),
})
}
@ -182,12 +184,13 @@ impl Tool for ReadObjects {
"read_objects"
}
fn parse_input(&self, input: String) -> Option<Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed);
match parsed {
Ok(parsed) => {
Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed);
None
}
@ -199,17 +202,18 @@ impl Tool for ReadObjects {
let input = match self.load_input() {
Ok(input) => input,
Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
}
};
match Self::read_objects(&input.full_path_to_file, &input.queries) {
Ok((lang, results)) => ToolResult::ok(
self.name().to_string(),
input.raw,
input.raw.clone(),
Self::format_output(lang, results),
input.call_id,
),
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()),
Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
}
}
@ -224,7 +228,7 @@ impl Tool for ReadObjects {
},
"query": {
"type": "string",
"description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\")",
"description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\"), grepping is not working here, use exact names",
"minLength": 1
}
},
@ -235,8 +239,8 @@ impl Tool for ReadObjects {
fn desc(&self) -> String {
format!(
"Use the `{}` tool to read source code of specific objects from a file. To determine correct properties to use for `{}`, use the `discover_objects` tool first.",
self.name(), self.name()
"Use the `{}` tool to read source code of specific objects from a file. To determine correct properties, use the `discover_objects` tool first.",
self.name(),
)
}
@ -293,7 +297,7 @@ mod tests {
fn test_parse_input_valid() {
let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": "main, Config"}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_none(), "Expected no error, got: {:?}", result);
// Verify the parsed input
@ -308,7 +312,7 @@ mod tests {
fn test_parse_input_empty_query() {
let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": ""}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for empty query");
if let Some(Error::Parse(msg)) = result {
@ -326,7 +330,7 @@ mod tests {
fn test_parse_input_whitespace_query() {
let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": " "}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for whitespace query");
if let Some(Error::Parse(msg)) = result {
@ -344,7 +348,7 @@ mod tests {
fn test_parse_input_comma_only_query() {
let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": ",,,"}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for comma-only query");
if let Some(Error::Parse(msg)) = result {
@ -362,7 +366,7 @@ mod tests {
fn test_parse_input_missing_path() {
let tool = ReadObjects::new();
let input = r#"{"path": "", "query": "main"}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for empty path");
if let Some(Error::Parse(msg)) = result {
@ -407,7 +411,7 @@ mod tests {
fn test_parse_input_malformed_json() {
let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": "main"#; // Missing closing brace
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(result.is_some(), "Expected error for malformed JSON");
}
@ -415,7 +419,7 @@ mod tests {
fn test_parse_input_path_with_whitespace() {
let tool = ReadObjects::new();
let input = r#"{"path": " src/main.rs ", "query": "main"}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(
result.is_none(),
"Expected no error for path with whitespace"
@ -432,7 +436,7 @@ mod tests {
fn test_parse_input_query_with_extra_whitespace() {
let tool = ReadObjects::new();
let input = r#"{"path": "src/main.rs", "query": " main , Config "}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(
result.is_none(),
"Expected no error for query with extra whitespace"
@ -448,7 +452,7 @@ mod tests {
fn test_parse_input_single_object() {
let tool = ReadObjects::new();
let input = r#"{"path": "src/lib.rs", "query": "MyStruct"}"#;
let result = tool.parse_input(input.to_string());
let result = tool.parse_input(input.to_string(), "call-id".to_string());
assert!(
result.is_none(),
"Expected no error for single object query"
@ -481,4 +485,4 @@ mod tests {
};
assert!(query.matches(&obj_partial), "Should match partial name");
}
}
}

View file

@ -25,6 +25,7 @@ struct ShellExecInputParsed {
raw: String,
command: String,
working_dir: Option<String>,
call_id: String,
}
impl Tool for ShellExec {
@ -32,7 +33,7 @@ impl Tool for ShellExec {
"shell_exec"
}
fn parse_input(&self, input: String) -> Option<Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim();
let parsed = serde_json::from_str::<ShellExecInput>(trimmed)
.map_err(|e| Error::Parse(e.to_string()));
@ -46,6 +47,7 @@ impl Tool for ShellExec {
raw: trimmed.to_string(),
command: parsed.command,
working_dir: parsed.working_dir,
call_id,
});
None
}
@ -61,6 +63,7 @@ impl Tool for ShellExec {
self.name().to_string(),
String::new(),
"input not parsed".to_string(),
String::new(),
)
}
};
@ -72,15 +75,17 @@ impl Tool for ShellExec {
if !path.exists() {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Working directory does not exist: {}", dir),
input.call_id.clone(),
);
}
if !crate::utils::paths::is_within_root(path, request.project_root()) {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
"Working directory is outside project root".to_string(),
input.call_id.clone(),
);
}
path.to_path_buf()
@ -99,8 +104,9 @@ impl Tool for ShellExec {
Err(e) => {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Failed to execute command: {}", e),
input.call_id.clone(),
)
}
};
@ -129,10 +135,10 @@ impl Tool for ShellExec {
if !stderr.is_empty() {
result.push_str(&format!("[stderr]: {}", stderr));
}
return ToolResult::error(self.name().to_string(), input.raw, result);
return ToolResult::error(self.name().to_string(), input.raw.clone(), result, input.call_id.clone());
};
ToolResult::ok(self.name().to_string(), input.raw, result)
ToolResult::ok(self.name().to_string(), input.raw, result, input.call_id)
}
fn parameters(&self) -> serde_json::Value {
@ -155,8 +161,9 @@ impl Tool for ShellExec {
fn desc(&self) -> String {
format!(
r#"Use `{}` tool to execute shell commands.
Please DO NOT use it to read the full content of a file, this is not efficient, use `read_objects` tool for this."#,
"Use `{}` tool to execute shell commands.
DO NOT use it to read code, this is not efficient, use `read_objects` tool for this. \n\
DO NOT use it to change files, use `patch_files` tool for this.",
self.name()
)
}
@ -339,7 +346,7 @@ mod tests {
let tool = ShellExec::new();
assert!(tool
.parse_input(r#"{"command":"echo hello"}"#.to_string())
.parse_input(r#"{"command":"echo hello"}"#.to_string(), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -357,7 +364,7 @@ mod tests {
let tool = ShellExec::new();
assert!(tool
.parse_input(r#"{"command":"pwd"}"#.to_string())
.parse_input(r#"{"command":"pwd"}"#.to_string(), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -384,7 +391,7 @@ mod tests {
.parse_input(format!(
r#"{{"command":"pwd","working_dir":"{}"}}"#,
subdir.display()
))
), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -405,7 +412,7 @@ mod tests {
let tool = ShellExec::new();
assert!(tool
.parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string())
.parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string(), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -424,7 +431,7 @@ mod tests {
let tool = ShellExec::new();
assert!(tool
.parse_input(r#"{"command":"exit 1"}"#.to_string())
.parse_input(r#"{"command":"exit 1"}"#.to_string(), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -443,7 +450,7 @@ mod tests {
let tool = ShellExec::new();
assert!(tool
.parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string())
.parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string(), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -457,7 +464,7 @@ mod tests {
let request = TestRequest::new(temp.path());
let tool = ShellExec::new();
let err = tool.parse_input(r#"{"command":""}"#.to_string());
let err = tool.parse_input(r#"{"command":""}"#.to_string(), "call-id".to_string());
assert!(err.is_some());
let result = tool.work(&request);
assert!(result.output_string().contains("Error"));
@ -474,7 +481,7 @@ mod tests {
.parse_input(format!(
r#"{{"command":"echo 'test content' > {}"}}"#,
file_path.display()
))
), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -496,7 +503,7 @@ mod tests {
let tool = ShellExec::new();
assert!(tool
.parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string())
.parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string(), "call-id".to_string())
.is_none());
let result = tool.work(&request);
@ -506,4 +513,4 @@ mod tests {
result.output_string()
);
}
}
}

View file

@ -16,6 +16,7 @@ struct StructureInput {
raw: String,
path: String,
max_depth: usize,
call_id: String,
}
#[derive(Debug, Deserialize)]
@ -127,6 +128,7 @@ impl Structure {
raw: raw.to_string(),
path: parsed.path.unwrap_or_else(|| ".".to_string()),
max_depth: parsed.max_depth.unwrap_or(3),
call_id: String::new(),
})
}
@ -144,12 +146,13 @@ impl Tool for Structure {
"structure"
}
fn parse_input(&self, input: String) -> Option<Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let trimmed = input.trim();
let parsed = Self::parse_input_json(trimmed);
match parsed {
Ok(parsed) => {
Ok(mut parsed) => {
parsed.call_id = call_id;
*self.input.lock().unwrap() = Some(parsed);
None
}
@ -161,34 +164,36 @@ impl Tool for Structure {
let input = match self.load_input() {
Ok(input) => input,
Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
}
};
let target_path = match Self::resolve_path(&input.path, request.project_root()) {
Ok(p) => p,
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e),
Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
};
if !target_path.exists() {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Path does not exist: {}", target_path.display()),
input.call_id.clone(),
);
}
if !target_path.is_dir() {
return ToolResult::error(
self.name().to_string(),
input.raw,
input.raw.clone(),
format!("Path is not a directory: {}", target_path.display()),
input.call_id,
);
}
let tree = Self::build_tree(&target_path, input.max_depth);
ToolResult::ok(self.name().to_string(), input.raw, tree)
ToolResult::ok(self.name().to_string(), input.raw, tree, input.call_id)
}
fn parameters(&self) -> serde_json::Value {

View file

@ -18,6 +18,7 @@ pub struct UpdateTodoList {
#[derive(Debug, Clone)]
struct UpdateTodoListInput {
raw: String,
call_id: String,
}
impl UpdateTodoList {
@ -36,12 +37,12 @@ impl Tool for UpdateTodoList {
"update_todo_list"
}
fn parse_input(&self, input: String) -> Option<Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
let parsed: Result<TodoList, _> = serde_json::from_str(&input);
match parsed {
Ok(_) => {
let mut lock = self.input.lock().unwrap();
*lock = Some(UpdateTodoListInput { raw: input });
*lock = Some(UpdateTodoListInput { raw: input, call_id });
None
}
Err(e) => Some(Error::Parse(format!("Failed to parse input: {}", e))),
@ -57,6 +58,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(),
String::new(),
"No input provided".to_string(),
String::new(),
)
}
};
@ -70,6 +72,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(),
input.raw.clone(),
format!("Failed to get database connection: {}", e),
input.call_id.clone(),
)
}
};
@ -83,6 +86,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(),
input.raw.clone(),
format!("Failed to get/create TODO list: {}", e),
input.call_id.clone(),
)
}
};
@ -92,6 +96,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(),
input.raw.clone(),
format!("Failed to update TODO list: {}", e),
input.call_id.clone(),
);
}
@ -106,6 +111,7 @@ impl Tool for UpdateTodoList {
self.name().to_string(),
input.raw,
"TODO list updated successfully.".to_string(),
input.call_id,
)
}

View file

@ -16,6 +16,7 @@ struct WebSearchInput {
raw: String,
query: String,
max_results: u32,
call_id: String,
}
#[derive(Debug, Deserialize)]
@ -94,7 +95,7 @@ impl Tool for WebSearch {
"web_search"
}
fn parse_input(&self, input: String) -> Option<crate::domain::tools::Error> {
fn parse_input(&self, input: String, call_id: String) -> Option<crate::domain::tools::Error> {
let trimmed = input.trim();
match serde_json::from_str::<WebSearchInputJson>(trimmed) {
Ok(parsed) => {
@ -103,6 +104,7 @@ impl Tool for WebSearch {
raw: input,
query: parsed.query,
max_results,
call_id,
};
*self.input.lock().unwrap() = Some(parsed_input);
None
@ -118,7 +120,7 @@ impl Tool for WebSearch {
let input = match self.load_input() {
Ok(input) => input,
Err(e) => {
return ToolResult::error(self.name().to_string(), String::new(), e);
return ToolResult::error(self.name().to_string(), String::new(), e, String::new());
}
};
@ -130,6 +132,7 @@ impl Tool for WebSearch {
self.name().to_string(),
input.raw.clone(),
"User settings not available".to_string(),
input.call_id.clone(),
);
}
};
@ -139,6 +142,7 @@ impl Tool for WebSearch {
self.name().to_string(),
input.raw.clone(),
"Web search is not enabled in settings".to_string(),
input.call_id.clone(),
);
}
@ -149,6 +153,7 @@ impl Tool for WebSearch {
self.name().to_string(),
input.raw.clone(),
"Brave API key not configured".to_string(),
input.call_id.clone(),
);
}
};
@ -159,12 +164,13 @@ impl Tool for WebSearch {
let output = WebSearchOutput { results };
match serde_json::to_string(&output) {
Ok(json_output) => {
ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output)
ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output, input.call_id)
}
Err(e) => ToolResult::error(
self.name().to_string(),
input.raw.clone(),
format!("Failed to serialize output: {}", e),
input.call_id.clone(),
),
}
}
@ -172,6 +178,7 @@ impl Tool for WebSearch {
self.name().to_string(),
input.raw.clone(),
format!("Search failed: {}", e),
input.call_id,
),
}
}
@ -593,8 +600,10 @@ mod tests {
user_settings,
};
let web_search = WebSearch::new();
let _ =
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string());
let _ = web_search.parse_input(
r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
"call-id".to_string(),
);
let _allowed = checker.check(&web_search, &request, Some(1)).unwrap();
@ -629,8 +638,10 @@ mod tests {
user_settings,
};
let web_search = WebSearch::new();
let _ =
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string());
let _ = web_search.parse_input(
r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
"call-id".to_string(),
);
let allowed = checker.check(&web_search, &request, Some(1)).unwrap();
@ -640,4 +651,4 @@ mod tests {
assert_eq!(created[0].tool_name, "web_search");
assert_eq!(created[0].resource_pattern, None);
}
}
}

View file

@ -17,6 +17,7 @@ pub struct Chain {
pub fail_reason: String,
#[serde(default)]
pub final_message: Option<String>,
pub system_prompt: String,
}
impl Chain {
@ -28,6 +29,7 @@ impl Chain {
is_failed: false,
fail_reason: String::new(),
final_message: None,
system_prompt: String::new(),
}
}
@ -120,6 +122,10 @@ impl Chain {
.collect::<Vec<_>>()
.join("\n")
}
pub fn set_system_prompt(&mut self, system_prompt: String) {
self.system_prompt = system_prompt;
}
#[allow(dead_code)]
pub fn total_payload_len_chars(&self) -> usize {

View file

@ -39,6 +39,8 @@ pub struct ChainStep {
#[serde(default)]
pub tool_name: Option<String>,
#[serde(default)]
pub call_id: Option<String>,
#[serde(default)]
pub tool_output: Option<String>,
#[serde(default)]
pub is_successful: Option<bool>,
@ -60,6 +62,7 @@ impl ChainStep {
let mut tool_output = None;
let mut is_successful = None;
let mut file_changes = None;
let mut call_id = None;
if let Some(tr) = tool_result {
summary = tr.summary();
context_payload = tr.output_string();
@ -68,6 +71,7 @@ impl ChainStep {
tool_output = Some(tr.output_raw().to_string());
is_successful = Some(tr.is_successful());
file_changes = tr.file_changes().map(|fc| fc.to_vec());
call_id = Some(tr.call_id());
}
Self {
@ -76,6 +80,7 @@ impl ChainStep {
context_payload,
input_payload,
tool_name,
call_id,
tool_output,
is_successful,
file_changes,
@ -91,6 +96,7 @@ impl ChainStep {
context_payload: raw_output.clone(),
input_payload: String::new(),
tool_name: None,
call_id: None,
tool_output: Some(raw_output),
is_successful: Some(true),
file_changes: None,
@ -118,6 +124,7 @@ impl ChainStep {
context_payload: prompt.clone(),
input_payload: prompt,
tool_name: None,
call_id: None,
tool_output: None,
is_successful: Some(true),
file_changes: None,

View file

@ -41,6 +41,7 @@ impl ToolRunner {
tool_call.name.clone(),
tool_call.arguments.clone(),
error_msg,
tool_call.call_id.clone(),
);
}
};
@ -73,11 +74,13 @@ impl ToolRunner {
tool.name().to_string(),
String::new(),
"Permission denied".to_string(),
String::new(),
),
Err(err) => ToolResult::error(
tool.name().to_string(),
String::new(),
format!("Permission check error: {}", err),
String::new(),
),
};

View file

@ -91,7 +91,7 @@ pub trait Toolset {
.map(|t| t.as_ref())
.ok_or_else(|| Error::Parse(format!("Tool not found: {}", tool_call.name)))?;
if let Some(err) = tool.parse_input(tool_call.arguments.clone()) {
if let Some(err) = tool.parse_input(tool_call.arguments.clone(), tool_call.call_id.clone()) {
return Err(err);
}

View file

@ -5,6 +5,7 @@ use super::{
use crate::domain::bt::GeneralTree;
use crate::domain::prompting;
use crate::domain::session::Request;
use crate::domain::todo::TodoListStatus;
use crate::infrastructure::db::DbPool;
use crate::infrastructure::event_bus::{AgentToUiEvent, StepPhase};
use crate::infrastructure::inference::InferenceEngine;
@ -16,7 +17,6 @@ use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use crossbeam_channel::Sender;
use crate::domain::AgentModeType;
use crate::domain::todo::TodoListStatus;
/// Main workflow orchestrator that runs LLM-driven coding tasks
/// Implements an eternal agent loop that:
@ -94,7 +94,6 @@ impl Workflow {
self.chain.set_todo_list(request.get_session_plan());
let result = self._run(request, cancel, None, None, mode);
self.chain.steps.push(ChainStep::user_message(request.current_request().to_string(), request.images().to_vec()));
self._end_tracing();
result
}
@ -145,6 +144,7 @@ impl Workflow {
context_payload: String::new(),
input_payload: step_user_prompt.to_string(),
tool_name: None,
call_id: None,
tool_output: None,
is_successful: Some(true),
file_changes: None,
@ -171,6 +171,13 @@ impl Workflow {
mode: AgentModeType,
) -> Result<(), Error> {
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request);
let user_prompt = user_prompt_override
.as_deref()
.unwrap_or(&base_user_prompt)
.to_string();
self.chain.steps.push(ChainStep::user_message(user_prompt.clone(), request.images().to_vec()));
// Get max_tool_calls from override (for BT mode) or user settings
let max_tool_calls = max_tool_calls_override.unwrap_or_else(|| {
request.user_settings()
@ -198,11 +205,7 @@ impl Workflow {
// Get base system prompt and inject remaining count
let system_prompt = prompting::get_system_prompt(self.engine.get_type(), request.mode(), remaining_calls);
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request);
let user_prompt = user_prompt_override
.as_deref()
.unwrap_or(&base_user_prompt)
.to_string();
self.chain.set_system_prompt(system_prompt.clone());
// Switch to finishing toolset if approaching limit
if !in_finishing_mode && tool_call_count >= finishing_threshold {
@ -220,13 +223,9 @@ impl Workflow {
}
self._emit_inference_progress();
self._trace_llm_start(user_prompt.clone());
// Ask LLM to choose next tool
let llm_output = match self.engine.generate(
&system_prompt,
&user_prompt,
1024,
&self.toolset.tool_refs(),
&self.chain,
request.images(),
@ -238,8 +237,6 @@ impl Workflow {
}
};
self._trace_llm_end(llm_output.raw_output.clone());
// Always capture the assistant's response in the chain
if !llm_output.raw_output.is_empty() {
self.chain.steps.push(ChainStep::assistant_response(
@ -347,13 +344,6 @@ impl Workflow {
self.chain = Chain::new();
}
fn _trace_llm_start(&mut self, prompt: String) {
if let Some(tracer) = &mut self.tracer {
tracer.start_span("LLM generation", SpanKind::Generation);
tracer.add_input("LLM generation", prompt);
}
}
fn _emit_inference_progress(&self) {
let options = [
"Thinking.. well kinda..",
@ -383,18 +373,10 @@ impl Workflow {
});
}
fn _trace_llm_end(&mut self, output: String) {
if let Some(tracer) = &mut self.tracer {
tracer.add_output("LLM generation", output);
tracer.end_span("LLM generation");
}
}
/// Detects if the active TODO item has changed
/// Returns true if the first non-completed item's title is different
fn _did_todo_item_change(&self, previous_todo: &Option<crate::domain::todo::TodoList>) -> bool {
use crate::domain::todo::TodoListStatus;
// Get first non-completed item from previous TODO list
let prev_active = previous_todo.as_ref().and_then(|list| {
list.items.iter()

View file

@ -64,8 +64,6 @@ impl OpenAIClient {
pub fn call_responses_api(
&self,
system_prompt: &str,
user_prompt: &str,
tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain,
images: &[String],
@ -74,8 +72,6 @@ impl OpenAIClient {
let max_attempts = 3;
for attempt in 1..=max_attempts {
match self.call_responses_api_inner(
system_prompt,
user_prompt,
tools,
chain,
images,
@ -121,25 +117,31 @@ impl OpenAIClient {
fn call_responses_api_inner(
&self,
system_prompt: &str,
user_prompt: &str,
tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain,
images: &[String],
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, Box<dyn Error + Send + Sync>> {
let url = "https://api.openai.com/v1/responses";
let request_body = build_request_dto(
&self.model,
system_prompt,
user_prompt,
images,
tools,
chain,
tracer,
tracer.as_deref_mut(),
);
// Start span with model name and add request as JSON
if let Some(tracer) = &mut tracer {
tracer.start_span(&self.model, openai_agents_tracing::SpanKind::Generation);
// Convert request_body to JSON Value and set as input
if let Ok(request_json) = serde_json::to_value(&request_body) {
tracer.set_input_json(&self.model, request_json);
}
}
let response = self
.client
.post(url)
@ -151,18 +153,35 @@ impl OpenAIClient {
let body = response.text()?;
if !status.is_success() {
if let Some(t) = tracer {
t.end_span(&self.model);
}
return Err(Box::new(OpenAIClientError::Api { status, body }));
}
let dto = match serde_json::from_str::<ResponseDTO>(&body) {
Ok(v) => v,
Err(e) => return Err(Box::new(OpenAIClientError::Deserialize { source: e, body })),
Err(e) => {
if let Some(t) = tracer {
t.end_span(&self.model);
}
return Err(Box::new(OpenAIClientError::Deserialize { source: e, body }));
}
};
// Add response as JSON and end span
if let Some(tracer) = &mut tracer {
if let Ok(response_json) = serde_json::to_value(&dto) {
tracer.set_output_json(&self.model, response_json);
}
tracer.end_span(&self.model);
}
let result = build_llm_result(dto, tools);
if result.summary.is_empty() && result.tool_call.is_none() {
return Err(Box::new(OpenAIClientError::NoText { body }));
}
Ok(result)
}
}

View file

@ -3,13 +3,13 @@ use crate::domain::{Chain, ModelType};
use crate::domain::prompting::format_todo_list_message;
use serde::Serialize;
use serde_json::Value;
use crate::domain::workflow::step::StepType::{AssistantResponse, ToolCall, UserMessage};
use crate::domain::workflow::step::StepType;
#[derive(Debug, Serialize)]
pub struct RequestDTO {
model: String,
instructions: String,
input: Vec<InputDto>,
input: Vec<InputMessageDto>,
tools: Vec<ToolDto>,
tool_choice: String,
parallel_tool_calls: bool,
@ -18,12 +18,37 @@ pub struct RequestDTO {
}
#[derive(Debug, Serialize)]
pub(super) struct InputDto {
pub(super) struct MessageDto {
content: Vec<InputContent>,
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
#[serde(rename = "type")]
kind: String,
status: String,
}
#[derive(Debug, Serialize)]
pub(super) struct FunctionOutputDto {
output: String,
#[serde(rename = "type")]
kind: String,
call_id: String,
}
#[derive(Debug, Serialize)]
pub(super) struct FunctionCallDto {
arguments: String,
name: String,
#[serde(rename = "type")]
kind: String,
call_id: String,
}
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum InputMessageDto {
Message(MessageDto),
FunctionOutput(FunctionOutputDto),
FunctionCall(FunctionCallDto),
}
#[derive(Debug, Serialize)]
@ -63,10 +88,25 @@ impl InputContent {
}
}
pub fn function(text: String) -> Self {
Self::Text {
kind: "function".to_string(),
text,
}
impl FunctionOutputDto {
pub fn new(output: String, call_id: String) -> Self {
Self {
kind: "function_call_output".to_string(),
output,
call_id,
}
}
}
impl FunctionCallDto {
pub fn new(name: String, arguments: String, call_id: String) -> Self {
Self {
name,
kind: "function_call".to_string(),
arguments,
call_id,
}
}
}
@ -83,28 +123,23 @@ pub(super) struct ToolDto {
const ROLE_USER: &str = "user";
const ROLE_SYSTEM: &str = "system";
const ROLE_ASSISTANT: &str = "assistant";
const INPUT_STATUS_IN_PROGRESS: &str = "in_progress";
const INPUT_STATUS_COMPLETED: &str = "completed";
const INPUT_STATUS_FAILED: &str = "failed";
impl RequestDTO {
pub(crate) fn new(
model: String,
system_prompt: String,
user_prompt: String,
tools: &[&dyn Tool],
chain: &crate::domain::workflow::Chain,
chain: &Chain,
) -> Self {
// User request is now part of the chain, no need to add separately
let input = InputDto::build(user_prompt, chain);
let input = MessageDto::build(chain);
Self {
model,
instructions: system_prompt,
instructions: chain.system_prompt.clone(),
input,
tools: tools.iter().map(|tool| ToolDto::from_tool(*tool)).collect(),
tool_choice: "auto".to_string(),
parallel_tool_calls: true,
parallel_tool_calls: false,
store: false,
stream: false,
}
@ -123,50 +158,46 @@ impl ToolDto {
}
}
impl InputDto {
fn build(user_prompt: String, chain: &Chain) -> Vec<Self> {
impl MessageDto {
fn build(chain: &Chain) -> Vec<InputMessageDto> {
let steps = chain.get_steps_with_history();
let mut result: Vec<Self> = steps
.iter()
.enumerate()
.map(|(idx, step)| {
// Determine status
let is_user_message = step.step_type == UserMessage.as_str();
let mut result: Vec<InputMessageDto> = Vec::new();
let status = if is_user_message || step.is_successful.unwrap_or(false) {
INPUT_STATUS_COMPLETED
} else {
INPUT_STATUS_FAILED
};
let mut role = ROLE_ASSISTANT.to_string();
for step in steps.iter() {
let is_user_message = step.step_type == StepType::UserMessage.as_str();
// Build content items
let content_items = if is_user_message {
role = ROLE_USER.to_string();
// For user messages, include text and images
if is_user_message {
// User message: text + optional images
let mut items = vec![InputContent::text(step.input_payload.clone())];
// Add image content items if present
if let Some(ref images) = step.images {
for image_url in images {
items.push(InputContent::image(image_url.clone()));
}
}
items
} else {
vec![InputContent::output_text(step.get_output(ModelType::OpenAI))]
};
result.push(InputMessageDto::Message(Self {
content: items,
role: Some(ROLE_USER.to_string()),
kind: "message".to_string(),
}));
} else if step.step_type == StepType::ToolCall.as_str() {
Self {
content: content_items,
role,
kind: "message".to_string(),
status: status.to_string(),
let tool_name = step.tool_name.clone().unwrap();
let call_id = step.call_id.clone().unwrap();
// Tool call output is a separate DTO type
result.push(InputMessageDto::FunctionCall(FunctionCallDto::new(tool_name, step.input_payload.clone(), call_id.clone())));
result.push(InputMessageDto::FunctionOutput(FunctionOutputDto::new(step.get_output(ModelType::OpenAI), call_id)));
} else {
// Assistant message
result.push(InputMessageDto::Message(Self {
content: vec![InputContent::output_text(step.get_output(ModelType::OpenAI))],
role: Some(ROLE_ASSISTANT.to_string()),
kind: "message".to_string(),
}));
}
})
.collect();
}
// Add the plan as system message at the beginning if it exists and is not completed
if let Some(ref todo_list) = chain.todo_list {
@ -178,22 +209,15 @@ impl InputDto {
let todo_input = Self {
content: vec![InputContent::text(todo_message)],
role: ROLE_SYSTEM.to_string(),
role: Some(ROLE_SYSTEM.to_string()),
kind: "message".to_string(),
status: INPUT_STATUS_COMPLETED.to_string(),
};
result.push(todo_input);
// Put it first
result.insert(0, InputMessageDto::Message(todo_input));
}
}
// and adding the current user message at the end
result.push(Self {
content: vec![InputContent::text(user_prompt)],
role: ROLE_USER.to_string(),
kind: "message".to_string(),
status: INPUT_STATUS_IN_PROGRESS.to_string(),
});
result
}
}
}

View file

@ -1,20 +1,21 @@
use serde::Deserialize;
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
pub struct ResponseDTO {
output: Vec<OpenAIOutputItem>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
struct OpenAIOutputItem {
#[serde(rename = "type")]
kind: String,
content: Option<Vec<OpenAIContentItem>>,
name: Option<String>,
arguments: Option<String>,
call_id: Option<String>,
}
#[derive(Debug, Deserialize)]
#[derive(Debug, Deserialize, Serialize)]
struct OpenAIContentItem {
#[serde(rename = "type")]
kind: String,
@ -41,11 +42,12 @@ impl ResponseDTO {
}
}
} else if item.kind == "function_call" {
if let (Some(name), Some(arguments)) = (item.name.as_ref(), item.arguments.as_ref())
if let (Some(name), Some(arguments), Some(call_id)) = (item.name.as_ref(), item.arguments.as_ref(), item.call_id.as_ref())
{
tool_call = Some(FunctionCall {
name: name.to_string(),
arguments: arguments.to_string(),
call_id: call_id.to_string(),
});
}
}
@ -59,4 +61,5 @@ impl ResponseDTO {
pub(crate) struct FunctionCall {
pub name: String,
pub arguments: String,
pub call_id: String,
}

View file

@ -3,8 +3,6 @@ use crate::infrastructure::inference::{LLMInferenceResult, ToolCall};
pub fn build_request_dto(
model: &str,
system_prompt: &str,
user_prompt: &str,
images: &[String],
tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain,
@ -25,8 +23,6 @@ pub fn build_request_dto(
let dto = RequestDTO::new(
model.to_string(),
system_prompt.to_string(),
user_prompt.to_string(),
tools,
chain,
);
@ -54,9 +50,10 @@ pub fn build_llm_result(
let tool_call = if let Some(call) = tool_call_dto {
// Create ToolCall for the workflow to use
Some(ToolCall {
Some(ToolCall{
name: call.name.clone(),
arguments: call.arguments.clone(),
call_id: call.call_id.clone(),
})
} else {
None

View file

@ -51,4 +51,4 @@ pub fn render(frame: &mut Frame, area: Rect, state: &UiState, theme: Theme) {
let y = area.y.saturating_add(INPUT_PADDING.top + adjusted_line);
frame.set_cursor(x, y);
}
}
}

View file

@ -159,6 +159,7 @@ fn handle_agent_event(
});
state.request_status = None;
state.request_progress = None;
state.last_progress_update = None; // Reset timer for new request
state.file_changes = None;
// Display the user's message
@ -175,7 +176,22 @@ fn handle_agent_event(
summary,
} => {
if matches!(phase, StepPhase::Before) {
state.request_progress = Some(summary.clone());
// Check if enough time has passed since the last progress update
let can_update = state
.last_progress_update
.map(|last| {
std::time::Instant::now()
.duration_since(last)
.as_millis()
>= crate::infrastructure::cli::state::MIN_PROGRESS_DISPLAY_MS
})
.unwrap_or(true); // Update immediately if no previous update
if can_update {
state.request_progress = Some(summary.clone());
state.last_progress_update = Some(std::time::Instant::now());
}
// Otherwise skip this update - too soon since last one
} else {
state.request_progress = None;
}
@ -207,6 +223,7 @@ fn handle_agent_event(
});
}
state.request_progress = None;
state.last_progress_update = None; // Reset timer when request finishes
// Update the user message color based on request status
for entry in state.progress.iter_mut() {

View file

@ -33,8 +33,8 @@ impl AttachedImage {
}
}
pub const INPUT_MIN_HEIGHT: usize = 3;
pub const INPUT_MAX_HEIGHT: usize = 6;
pub const INPUT_MIN_HEIGHT: usize = 1;
pub const INPUT_MAX_HEIGHT: usize = 5;
pub const PROGRESS_HISTORY_LIMIT: usize = 200;
pub const MAIN_BODY_SCROLL_STEP: usize = 3;
@ -51,6 +51,7 @@ pub enum ProgressKind {
}
pub const REQUEST_STATUS_DISPLAY_DURATION: Duration = Duration::from_secs(2);
pub const MIN_PROGRESS_DISPLAY_MS: u128 = 1000; // 1 second minimum display time
#[derive(Debug, Clone)]
#[allow(dead_code)]
@ -260,6 +261,7 @@ pub struct UiState {
pub request_in_flight: Option<RequestIndicator>,
pub request_status: Option<RequestStatusDisplay>,
pub request_progress: Option<String>,
pub last_progress_update: Option<Instant>,
pub file_changes: Option<FileChangesDisplay>,
pub agent_mode: AgentModeType,
pub todo_list: Option<TodoListDisplay>,
@ -300,6 +302,7 @@ impl UiState {
request_in_flight: None,
request_status: None,
request_progress: None,
last_progress_update: None,
file_changes: None,
agent_mode: AgentModeType::Build, // Default to build mode
todo_list: None,
@ -353,4 +356,4 @@ impl UiState {
Some(crate::domain::ModelType::OpenAI)
)
}
}
}

View file

@ -31,4 +31,4 @@ impl Theme {
}
pub const PANEL_PADDING: Padding = Padding::new(2, 2, 1, 1);
pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1);
pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1);

View file

@ -6,7 +6,7 @@ use crate::infrastructure::cli::helpers::{
centered_rect, cursor_position, list_state, panel_block,
};
use crate::infrastructure::cli::state::{LoadStatus, PopupState, UiMode, UiState};
use crate::infrastructure::cli::theme::{Theme, PANEL_PADDING};
use crate::infrastructure::cli::theme::{Theme, INPUT_PADDING, PANEL_PADDING};
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
use ratatui::style::{Modifier, Style};
use ratatui::text::{Line, Span, Text};
@ -28,7 +28,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
input_lines,
crate::infrastructure::cli::state::INPUT_MAX_HEIGHT,
);
let mut input_box_height = (input_lines + 2) as u16;
let mut input_box_height = (input_lines + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16;
let indicator_height = 1u16;
let attachment_indicator_height = if state.attached_images.is_empty() {
@ -45,7 +45,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
let overflow = fixed_height - size.height;
let reduced = input_box_height.saturating_sub(overflow);
input_box_height =
reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + 2) as u16);
reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16);
}
let constraints = if state.attached_images.is_empty() {
@ -557,4 +557,4 @@ pub fn openai_available_filtered(state: &UiState, filter: &str) -> Vec<String> {
.filter(|name| name.to_lowercase().contains(&query))
.cloned()
.collect()
}
}

View file

@ -89,17 +89,39 @@ impl LocalEngine {
}
impl InferenceEngine for LocalEngine {
// so far very fucked up implementation, even user_prompt is not properly passed
// need to be refactored using proper request builder same way as for openai inference
fn generate(
&self,
system_prompt: &str,
user_prompt: &str,
max_tokens: u32,
_tools: &[&dyn crate::domain::tools::Tool],
_chain: &crate::domain::workflow::Chain,
chain: &crate::domain::workflow::Chain,
_images: &[String],
_tracer: Option<&mut openai_agents_tracing::TracingFacade>,
mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, InfaError> {
let prompt = format!("{}\n\n{}", system_prompt, user_prompt);
let model_name = "local";
// Start span with model name and add request as JSON
if let Some(tracer) = &mut tracer {
tracer.start_span(model_name, openai_agents_tracing::SpanKind::Generation);
// Set request as structured JSON input
let request_json = serde_json::json!({
"system_prompt": chain.system_prompt.clone(),
"user_prompt": chain.get_steps_with_history()[0].context_payload.clone(),
"max_tokens": 1000,
});
tracer.set_input_json(model_name, request_json);
// Set model configuration
let mut model_config = std::collections::HashMap::new();
model_config.insert("temperature".to_string(), serde_json::json!(self.params.temperature));
model_config.insert("top_p".to_string(), serde_json::json!(self.params.top_p));
model_config.insert("ctx_size".to_string(), serde_json::json!(self.params.ctx_size));
model_config.insert("threads".to_string(), serde_json::json!(self.params.threads));
tracer.set_model_config(model_name, model_config);
}
let prompt = format!("{}\n\n{}", chain.system_prompt.clone(), chain.get_steps_with_history()[0].context_payload.clone());
let to_error = |msg: String| -> InfaError {
std::io::Error::new(std::io::ErrorKind::Other, msg).into()
};
@ -145,7 +167,7 @@ impl InferenceEngine for LocalEngine {
let mut output = String::new();
let mut n_cur = tokens.len();
for _ in 0..max_tokens {
for _ in 0..2000 {
let new_token = sampler.sample(&ctx, -1);
if self.model.is_eog_token(new_token) {
@ -169,6 +191,15 @@ impl InferenceEngine for LocalEngine {
n_cur += 1;
}
// Add output as JSON and end span
if let Some(tracer) = tracer {
let response_json = serde_json::json!({
"text": &output,
});
tracer.set_output_json(model_name, response_json);
tracer.end_span(model_name);
}
Ok(LLMInferenceResult {
summary: output.trim().to_string(),
raw_output: output,

View file

@ -9,6 +9,7 @@ pub mod openai;
pub struct ToolCall {
pub name: String,
pub arguments: String,
pub call_id: String,
}
pub struct LLMInferenceResult {
@ -22,9 +23,6 @@ pub trait InferenceEngine: Send + Sync {
/// Generate text without streaming output
fn generate(
&self,
system_prompt: &str,
user_prompt: &str,
max_tokens: u32,
tools: &[&dyn Tool],
chain: &Chain,
images: &[String],

View file

@ -66,15 +66,12 @@ impl OpenAIEngine {
impl InferenceEngine for OpenAIEngine {
fn generate(
&self,
system_prompt: &str,
user_prompt: &str,
_max_tokens: u32,
tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain,
images: &[String],
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, InfaError> {
self.generate_with_responses_api(system_prompt, user_prompt, tools, chain, images, tracer)
self.generate_with_responses_api(tools, chain, images, tracer)
}
fn get_type(&self) -> ModelType {
ModelType::OpenAI
@ -85,15 +82,13 @@ impl OpenAIEngine {
/// Generate using the Responses API (for newer models like codex, o-series)
fn generate_with_responses_api(
&self,
system_prompt: &str,
user_prompt: &str,
tools: &[&dyn crate::domain::tools::Tool],
chain: &crate::domain::workflow::Chain,
images: &[String],
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
) -> Result<LLMInferenceResult, InfaError> {
self.responses_client
.call_responses_api(system_prompt, user_prompt, tools, chain, images, tracer)
.call_responses_api(tools, chain, images, tracer)
.map_err(|e| e)
}
}