mirror of
https://github.com/taggart-comet/quill-code.git
synced 2026-04-28 03:19:33 +00:00
fix
This commit is contained in:
parent
bd21ae6296
commit
3fd019bd7a
35 changed files with 633 additions and 343 deletions
|
|
@ -82,6 +82,38 @@ impl TracingFacade {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_model_config(&mut self, name: impl AsRef<str>, config: HashMap<String, serde_json::Value>) {
|
||||||
|
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||||
|
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||||
|
data.model_config = Some(config);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_usage(&mut self, name: impl AsRef<str>, input_tokens: u32, output_tokens: u32) {
|
||||||
|
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||||
|
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||||
|
data.usage = Some(crate::types::UsageData::new(input_tokens, output_tokens));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_input_json(&mut self, name: impl AsRef<str>, input: serde_json::Value) {
|
||||||
|
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||||
|
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||||
|
data.input = Some(vec![input]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_output_json(&mut self, name: impl AsRef<str>, output: serde_json::Value) {
|
||||||
|
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||||
|
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||||
|
data.output = Some(vec![output]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn end(&mut self) {
|
pub async fn end(&mut self) {
|
||||||
for (_, mut span) in self.open_spans.drain() {
|
for (_, mut span) in self.open_spans.drain() {
|
||||||
span.mark_ended();
|
span.mark_ended();
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ pub use session::{Session, SessionRequest};
|
||||||
pub use startup::StartupService;
|
pub use startup::StartupService;
|
||||||
pub use user_settings::UserSettings;
|
pub use user_settings::UserSettings;
|
||||||
pub use workflow::{CancellationToken, Chain};
|
pub use workflow::{CancellationToken, Chain};
|
||||||
|
#[allow(unused_imports)]
|
||||||
pub use todo::{TodoList, TodoItem};
|
pub use todo::{TodoList, TodoItem};
|
||||||
|
|
||||||
/// Model type enum matching the inference engine types
|
/// Model type enum matching the inference engine types
|
||||||
|
|
@ -41,4 +42,4 @@ impl ModelType {
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -401,12 +401,17 @@ mod tests {
|
||||||
"read_only"
|
"read_only"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, _input: String) -> Option<Error> {
|
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn work(&self, _request: &dyn Request) -> ToolResult {
|
fn work(&self, _request: &dyn Request) -> ToolResult {
|
||||||
ToolResult::ok("read_only".to_string(), String::new(), String::new())
|
ToolResult::ok(
|
||||||
|
"read_only".to_string(),
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters(&self) -> serde_json::Value {
|
fn parameters(&self) -> serde_json::Value {
|
||||||
|
|
@ -439,12 +444,17 @@ mod tests {
|
||||||
"write_tool"
|
"write_tool"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, _input: String) -> Option<Error> {
|
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn work(&self, _request: &dyn Request) -> ToolResult {
|
fn work(&self, _request: &dyn Request) -> ToolResult {
|
||||||
ToolResult::ok("write_tool".to_string(), String::new(), String::new())
|
ToolResult::ok(
|
||||||
|
"write_tool".to_string(),
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters(&self) -> serde_json::Value {
|
fn parameters(&self) -> serde_json::Value {
|
||||||
|
|
@ -474,12 +484,17 @@ mod tests {
|
||||||
"command_tool"
|
"command_tool"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, _input: String) -> Option<Error> {
|
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn work(&self, _request: &dyn Request) -> ToolResult {
|
fn work(&self, _request: &dyn Request) -> ToolResult {
|
||||||
ToolResult::ok("command_tool".to_string(), String::new(), String::new())
|
ToolResult::ok(
|
||||||
|
"command_tool".to_string(),
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
String::new(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters(&self) -> serde_json::Value {
|
fn parameters(&self) -> serde_json::Value {
|
||||||
|
|
@ -955,4 +970,4 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(decision, PermissionDecision::AlwaysAllow);
|
assert_eq!(decision, PermissionDecision::AlwaysAllow);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
use super::types::{Permission, PermissionDecision, PermissionScope};
|
use super::types::{Permission, PermissionDecision, PermissionScope};
|
||||||
use crate::infrastructure::db::DbPool;
|
use crate::infrastructure::db::DbPool;
|
||||||
use rusqlite::params;
|
use rusqlite::params;
|
||||||
use std::path::PathBuf;
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
|
|
@ -121,16 +120,19 @@ impl PermissionStore for SqlitePermissionStore {
|
||||||
) -> Result<Option<Permission>, StoreError> {
|
) -> Result<Option<Permission>, StoreError> {
|
||||||
// Build query dynamically to handle NULL values properly
|
// Build query dynamically to handle NULL values properly
|
||||||
// In SQL, NULL != '', so we need to use IS NULL for empty strings
|
// In SQL, NULL != '', so we need to use IS NULL for empty strings
|
||||||
|
let mut param_num = 3;
|
||||||
let command_clause = if command_pattern.is_empty() {
|
let command_clause = if command_pattern.is_empty() {
|
||||||
"command_pattern IS NULL"
|
"command_pattern IS NULL".to_string()
|
||||||
} else {
|
} else {
|
||||||
"command_pattern = ?3"
|
let clause = format!("command_pattern = ?{}", param_num);
|
||||||
|
param_num += 1;
|
||||||
|
clause
|
||||||
};
|
};
|
||||||
|
|
||||||
let resource_clause = if resource_pattern.is_empty() {
|
let resource_clause = if resource_pattern.is_empty() {
|
||||||
"resource_pattern IS NULL"
|
"resource_pattern IS NULL".to_string()
|
||||||
} else {
|
} else {
|
||||||
"resource_pattern = ?4"
|
format!("resource_pattern = ?{}", param_num)
|
||||||
};
|
};
|
||||||
|
|
||||||
let query = format!(
|
let query = format!(
|
||||||
|
|
@ -151,32 +153,28 @@ impl PermissionStore for SqlitePermissionStore {
|
||||||
)))
|
)))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Build params based on whether patterns are empty
|
// Build params list dynamically to match the query
|
||||||
let result = if command_pattern.is_empty() && resource_pattern.is_empty() {
|
let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = vec![
|
||||||
conn.query_row(
|
Box::new(project_id),
|
||||||
&query,
|
Box::new(tool.to_string()),
|
||||||
params![project_id, tool],
|
];
|
||||||
|row| self.row_to_permission(row),
|
|
||||||
)
|
if !command_pattern.is_empty() {
|
||||||
} else if command_pattern.is_empty() {
|
param_values.push(Box::new(command_pattern.to_string()));
|
||||||
conn.query_row(
|
}
|
||||||
&query,
|
|
||||||
params![project_id, tool, resource_pattern],
|
if !resource_pattern.is_empty() {
|
||||||
|row| self.row_to_permission(row),
|
param_values.push(Box::new(resource_pattern.to_string()));
|
||||||
)
|
}
|
||||||
} else if resource_pattern.is_empty() {
|
|
||||||
conn.query_row(
|
let params_refs: Vec<&dyn rusqlite::ToSql> = param_values
|
||||||
&query,
|
.iter()
|
||||||
params![project_id, tool, command_pattern],
|
.map(|p| p.as_ref())
|
||||||
|row| self.row_to_permission(row),
|
.collect();
|
||||||
)
|
|
||||||
} else {
|
let result = conn.query_row(&query, params_refs.as_slice(), |row| {
|
||||||
conn.query_row(
|
self.row_to_permission(row)
|
||||||
&query,
|
});
|
||||||
params![project_id, tool, command_pattern, resource_pattern],
|
|
||||||
|row| self.row_to_permission(row),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(permission) => Ok(Some(permission)),
|
Ok(permission) => Ok(Some(permission)),
|
||||||
|
|
@ -186,48 +184,140 @@ impl PermissionStore for SqlitePermissionStore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SqlitePermissionStore {
|
#[cfg(test)]
|
||||||
fn find_matching_command_permission(
|
mod tests {
|
||||||
&self,
|
use super::*;
|
||||||
rows: rusqlite::MappedRows<
|
use r2d2_sqlite::SqliteConnectionManager;
|
||||||
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>,
|
|
||||||
>,
|
|
||||||
tool: &str,
|
|
||||||
command: &str,
|
|
||||||
) -> Result<Option<Permission>, StoreError> {
|
|
||||||
for row_result in rows {
|
|
||||||
match row_result {
|
|
||||||
Ok(permission) => {
|
|
||||||
if permission.matches(tool, Some(command), None::<&PathBuf>) {
|
|
||||||
return Ok(Some(permission));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => return Err(StoreError::Database(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(None)
|
fn setup_test_db() -> DbPool {
|
||||||
|
let manager = SqliteConnectionManager::memory();
|
||||||
|
let pool = r2d2::Pool::new(manager).unwrap();
|
||||||
|
let conn = pool.get().unwrap();
|
||||||
|
|
||||||
|
// Create permissions table
|
||||||
|
conn.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS permissions (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
tool_name TEXT NOT NULL,
|
||||||
|
command_pattern TEXT,
|
||||||
|
resource_pattern TEXT,
|
||||||
|
decision TEXT NOT NULL,
|
||||||
|
scope TEXT NOT NULL,
|
||||||
|
project_id INTEGER,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
)",
|
||||||
|
[],
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
pool
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_matching_path_permission(
|
#[test]
|
||||||
&self,
|
fn test_find_permission_with_null_patterns() {
|
||||||
rows: rusqlite::MappedRows<
|
let pool = setup_test_db();
|
||||||
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>,
|
let store = SqlitePermissionStore::new(pool.clone());
|
||||||
>,
|
|
||||||
tool: &str,
|
|
||||||
path: &PathBuf,
|
|
||||||
) -> Result<Option<Permission>, StoreError> {
|
|
||||||
for row_result in rows {
|
|
||||||
match row_result {
|
|
||||||
Ok(permission) => {
|
|
||||||
if permission.matches(tool, None, Some(path)) {
|
|
||||||
return Ok(Some(permission));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => return Err(StoreError::Database(e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(None)
|
// Create permission with NULL command and resource patterns
|
||||||
|
let permission = Permission::new(
|
||||||
|
"test_tool".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
PermissionDecision::AlwaysAllow,
|
||||||
|
PermissionScope::Project,
|
||||||
|
Some(1),
|
||||||
|
);
|
||||||
|
store.create_permission(permission).unwrap();
|
||||||
|
|
||||||
|
// Should find the permission when searching with empty strings
|
||||||
|
let result = store.find_permission("test_tool", 1, "", "").unwrap();
|
||||||
|
assert!(result.is_some());
|
||||||
|
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_permission_with_command_only() {
|
||||||
|
let pool = setup_test_db();
|
||||||
|
let store = SqlitePermissionStore::new(pool.clone());
|
||||||
|
|
||||||
|
// Create permission with command but NULL resource
|
||||||
|
let permission = Permission::new(
|
||||||
|
"test_tool".to_string(),
|
||||||
|
Some("echo test".to_string()),
|
||||||
|
None,
|
||||||
|
PermissionDecision::AlwaysDeny,
|
||||||
|
PermissionScope::Project,
|
||||||
|
Some(1),
|
||||||
|
);
|
||||||
|
store.create_permission(permission).unwrap();
|
||||||
|
|
||||||
|
// Should find the permission when searching with command and empty resource
|
||||||
|
let result = store.find_permission("test_tool", 1, "echo test", "").unwrap();
|
||||||
|
assert!(result.is_some());
|
||||||
|
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_permission_with_resource_only() {
|
||||||
|
let pool = setup_test_db();
|
||||||
|
let store = SqlitePermissionStore::new(pool.clone());
|
||||||
|
|
||||||
|
// Create permission with resource but NULL command
|
||||||
|
let permission = Permission::new(
|
||||||
|
"test_tool".to_string(),
|
||||||
|
None,
|
||||||
|
Some("/path/to/file".to_string()),
|
||||||
|
PermissionDecision::AlwaysAllow,
|
||||||
|
PermissionScope::Project,
|
||||||
|
Some(1),
|
||||||
|
);
|
||||||
|
store.create_permission(permission).unwrap();
|
||||||
|
|
||||||
|
// Should find the permission when searching with empty command and resource
|
||||||
|
let result = store.find_permission("test_tool", 1, "", "/path/to/file").unwrap();
|
||||||
|
assert!(result.is_some());
|
||||||
|
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_permission_with_both_patterns() {
|
||||||
|
let pool = setup_test_db();
|
||||||
|
let store = SqlitePermissionStore::new(pool.clone());
|
||||||
|
|
||||||
|
// Create permission with both patterns
|
||||||
|
let permission = Permission::new(
|
||||||
|
"test_tool".to_string(),
|
||||||
|
Some("rm -rf".to_string()),
|
||||||
|
Some("/etc/passwd".to_string()),
|
||||||
|
PermissionDecision::AlwaysDeny,
|
||||||
|
PermissionScope::Project,
|
||||||
|
Some(1),
|
||||||
|
);
|
||||||
|
store.create_permission(permission).unwrap();
|
||||||
|
|
||||||
|
// Should find the permission when searching with both patterns
|
||||||
|
let result = store.find_permission("test_tool", 1, "rm -rf", "/etc/passwd").unwrap();
|
||||||
|
assert!(result.is_some());
|
||||||
|
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_find_permission_no_match() {
|
||||||
|
let pool = setup_test_db();
|
||||||
|
let store = SqlitePermissionStore::new(pool.clone());
|
||||||
|
|
||||||
|
// Create permission with command
|
||||||
|
let permission = Permission::new(
|
||||||
|
"test_tool".to_string(),
|
||||||
|
Some("echo test".to_string()),
|
||||||
|
None,
|
||||||
|
PermissionDecision::AlwaysAllow,
|
||||||
|
PermissionScope::Project,
|
||||||
|
Some(1),
|
||||||
|
);
|
||||||
|
store.create_permission(permission).unwrap();
|
||||||
|
|
||||||
|
// Should NOT find when searching with different command
|
||||||
|
let result = store.find_permission("test_tool", 1, "rm -rf", "").unwrap();
|
||||||
|
assert!(result.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -29,6 +29,7 @@ pub struct Permission {
|
||||||
pub created_at: DateTime<Utc>,
|
pub created_at: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
impl Permission {
|
impl Permission {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
|
|
@ -173,4 +174,4 @@ impl Default for PermissionConfig {
|
||||||
require_confirmation: true,
|
require_confirmation: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -32,33 +32,38 @@ pub fn get_bt_tree_step_prompt(
|
||||||
pub fn get_system_prompt(model_type: ModelType, agent_mode: AgentModeType, remaining_calls: usize) -> String {
|
pub fn get_system_prompt(model_type: ModelType, agent_mode: AgentModeType, remaining_calls: usize) -> String {
|
||||||
let (os_name, shell_name) = get_runtime_environment();
|
let (os_name, shell_name) = get_runtime_environment();
|
||||||
|
|
||||||
let mut system_prompt = "".to_string();
|
|
||||||
if agent_mode == AgentModeType::Plan {
|
if agent_mode == AgentModeType::Plan {
|
||||||
system_prompt = _system_prompt_for_plan(model_type);
|
let mut system_prompt = _system_prompt_for_plan(model_type);
|
||||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
if remaining_calls < 3 {
|
||||||
|
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||||
|
}
|
||||||
return system_prompt;
|
return system_prompt;
|
||||||
}
|
}
|
||||||
if agent_mode == AgentModeType::BuildFromPlan {
|
if agent_mode == AgentModeType::BuildFromPlan {
|
||||||
system_prompt = _system_prompt_for_build_from_plan(model_type);
|
let mut system_prompt = _system_prompt_for_build_from_plan(model_type);
|
||||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
if remaining_calls < 3 {
|
||||||
|
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||||
|
}
|
||||||
return system_prompt;
|
return system_prompt;
|
||||||
}
|
}
|
||||||
if model_type == ModelType::OpenAI {
|
let mut system_prompt = if model_type == ModelType::OpenAI {
|
||||||
system_prompt = format!(
|
format!(
|
||||||
"You are Drastis, a coding agent. \n\
|
"You are Drastis, a coding agent. \n\
|
||||||
Use the available tools to gather context and make changes. \
|
Use the available tools to gather context and make changes. \
|
||||||
When using tools, pass JSON arguments that match their parameters. \n\
|
When using tools, pass JSON arguments that match their parameters. \n\
|
||||||
Runtime: os={}, shell={}.",
|
Runtime: os={}, shell={}.",
|
||||||
os_name, shell_name
|
os_name, shell_name
|
||||||
);
|
)
|
||||||
} else {
|
} else {
|
||||||
system_prompt = format!(
|
format!(
|
||||||
"You are Drastis, a coding agent. Use available tools to gather context and make changes. Be concise and accurate. Runtime: os={}, shell={}.",
|
"You are Drastis, a coding agent. Use available tools to gather context and make changes. Be concise and accurate. Runtime: os={}, shell={}.",
|
||||||
os_name, shell_name
|
os_name, shell_name
|
||||||
);
|
)
|
||||||
|
};
|
||||||
|
if remaining_calls < 3 {
|
||||||
|
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||||
}
|
}
|
||||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
system_prompt
|
||||||
return system_prompt;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _system_prompt_for_plan(model_type: ModelType) -> String {
|
fn _system_prompt_for_plan(model_type: ModelType) -> String {
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ use std::path::Path;
|
||||||
/// without being tightly coupled to the Session entity.
|
/// without being tightly coupled to the Session entity.
|
||||||
pub trait Request {
|
pub trait Request {
|
||||||
/// Get the history of previous requests
|
/// Get the history of previous requests
|
||||||
|
#[allow(dead_code)]
|
||||||
fn history(&self) -> &[SessionRequest];
|
fn history(&self) -> &[SessionRequest];
|
||||||
|
|
||||||
/// Get the current request prompt
|
/// Get the current request prompt
|
||||||
|
|
@ -40,4 +41,4 @@ pub trait Request {
|
||||||
|
|
||||||
/// Get the session's TODO list (plan)
|
/// Get the session's TODO list (plan)
|
||||||
fn get_session_plan(&self) -> Option<crate::domain::todo::TodoList>;
|
fn get_session_plan(&self) -> Option<crate::domain::todo::TodoList>;
|
||||||
}
|
}
|
||||||
|
|
@ -4,6 +4,7 @@ use crate::repository::SessionRequestRow;
|
||||||
/// Domain entity representing a user request within a session.
|
/// Domain entity representing a user request within a session.
|
||||||
/// Each request contains the user's prompt and the resulting summary.
|
/// Each request contains the user's prompt and the resulting summary.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(dead_code)]
|
||||||
pub struct SessionRequest {
|
pub struct SessionRequest {
|
||||||
prompt: String,
|
prompt: String,
|
||||||
result_summary: Option<String>,
|
result_summary: Option<String>,
|
||||||
|
|
@ -32,4 +33,4 @@ impl SessionRequest {
|
||||||
mode: row.mode,
|
mode: row.mode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
use super::Error;
|
use super::Error;
|
||||||
use crate::domain::prompting::session_naming_prompt;
|
use crate::domain::prompting::session_naming_prompt;
|
||||||
use crate::domain::workflow::Chain;
|
use crate::domain::workflow::Chain;
|
||||||
|
use crate::domain::workflow::ChainStep;
|
||||||
use crate::domain::{Project, Session, SessionRequest};
|
use crate::domain::{Project, Session, SessionRequest};
|
||||||
use crate::infrastructure::db::DbPool;
|
use crate::infrastructure::db::DbPool;
|
||||||
use crate::infrastructure::inference::InferenceEngine;
|
use crate::infrastructure::inference::InferenceEngine;
|
||||||
|
|
@ -83,38 +84,37 @@ impl StartupService {
|
||||||
let prompt_preview: String = first_prompt.chars().take(100).collect();
|
let prompt_preview: String = first_prompt.chars().take(100).collect();
|
||||||
let naming_prompt = session_naming_prompt(self.engine.get_type(), &prompt_preview);
|
let naming_prompt = session_naming_prompt(self.engine.get_type(), &prompt_preview);
|
||||||
|
|
||||||
let chain = Chain::new();
|
let mut chain = Chain::new();
|
||||||
let session_name =
|
chain
|
||||||
match self
|
.steps
|
||||||
.engine
|
.push(ChainStep::user_message(naming_prompt.parse().unwrap(), Vec::new()));
|
||||||
.generate("", &naming_prompt, 15, &[], &chain, &[], None)
|
let session_name = match self.engine.generate(&[], &chain, &[], None) {
|
||||||
{
|
Ok(raw) => {
|
||||||
Ok(raw) => {
|
log::debug!("Raw session name response: {:?}", raw.summary);
|
||||||
log::debug!("Raw session name response: {:?}", raw.summary);
|
// Clean up the response
|
||||||
// Clean up the response
|
let cleaned = raw
|
||||||
let cleaned = raw
|
.summary
|
||||||
.summary
|
.lines()
|
||||||
.lines()
|
.next()
|
||||||
.next()
|
.unwrap_or("")
|
||||||
.unwrap_or("")
|
.trim()
|
||||||
.trim()
|
.trim_matches('"')
|
||||||
.trim_matches('"')
|
.trim_matches('*')
|
||||||
.trim_matches('*')
|
.replace("<|im_end|>", "")
|
||||||
.replace("<|im_end|>", "")
|
.trim()
|
||||||
.trim()
|
.to_string();
|
||||||
.to_string();
|
|
||||||
|
|
||||||
if cleaned.is_empty() || cleaned.len() > 50 {
|
if cleaned.is_empty() || cleaned.len() > 50 {
|
||||||
Self::fallback_session_name(first_prompt)
|
|
||||||
} else {
|
|
||||||
cleaned
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
log::warn!("Session naming failed: {}", e);
|
|
||||||
Self::fallback_session_name(first_prompt)
|
Self::fallback_session_name(first_prompt)
|
||||||
|
} else {
|
||||||
|
cleaned
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Session naming failed: {}", e);
|
||||||
|
Self::fallback_session_name(first_prompt)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Create session in database
|
// Create session in database
|
||||||
let conn = self
|
let conn = self
|
||||||
|
|
@ -225,4 +225,4 @@ mod tests {
|
||||||
|
|
||||||
assert_eq!(StartupService::fallback_session_name(""), "New Session");
|
assert_eq!(StartupService::fallback_session_name(""), "New Session");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -15,6 +15,7 @@ pub struct DiscoverObjects {
|
||||||
struct DiscoverObjectsInput {
|
struct DiscoverObjectsInput {
|
||||||
raw: String,
|
raw: String,
|
||||||
full_path_to_file: String,
|
full_path_to_file: String,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
@ -69,6 +70,7 @@ impl DiscoverObjects {
|
||||||
Ok(DiscoverObjectsInput {
|
Ok(DiscoverObjectsInput {
|
||||||
raw: raw.to_string(),
|
raw: raw.to_string(),
|
||||||
full_path_to_file: parsed.full_path_to_file,
|
full_path_to_file: parsed.full_path_to_file,
|
||||||
|
call_id: String::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -86,12 +88,13 @@ impl Tool for DiscoverObjects {
|
||||||
"discover_objects"
|
"discover_objects"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
let parsed = Self::parse_input_json(trimmed);
|
let parsed = Self::parse_input_json(trimmed);
|
||||||
|
|
||||||
match parsed {
|
match parsed {
|
||||||
Ok(parsed) => {
|
Ok(mut parsed) => {
|
||||||
|
parsed.call_id = call_id;
|
||||||
*self.input.lock().unwrap() = Some(parsed);
|
*self.input.lock().unwrap() = Some(parsed);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
@ -103,17 +106,18 @@ impl Tool for DiscoverObjects {
|
||||||
let input = match self.load_input() {
|
let input = match self.load_input() {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match Self::parse_file(&input.full_path_to_file) {
|
match Self::parse_file(&input.full_path_to_file) {
|
||||||
Ok((lang, objects)) => ToolResult::ok(
|
Ok((lang, objects)) => ToolResult::ok(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
Self::format_output(lang, &objects),
|
Self::format_output(lang, &objects),
|
||||||
|
input.call_id,
|
||||||
),
|
),
|
||||||
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()),
|
Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ struct FindFilesInput {
|
||||||
query: String,
|
query: String,
|
||||||
root: Option<String>,
|
root: Option<String>,
|
||||||
max_results: usize,
|
max_results: usize,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
@ -132,6 +133,7 @@ impl FindFiles {
|
||||||
query: parsed.query,
|
query: parsed.query,
|
||||||
root: parsed.root,
|
root: parsed.root,
|
||||||
max_results: parsed.max_results.unwrap_or(20),
|
max_results: parsed.max_results.unwrap_or(20),
|
||||||
|
call_id: String::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -149,12 +151,13 @@ impl Tool for FindFiles {
|
||||||
"find_files"
|
"find_files"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
let parsed = Self::parse_input_json(trimmed);
|
let parsed = Self::parse_input_json(trimmed);
|
||||||
|
|
||||||
match parsed {
|
match parsed {
|
||||||
Ok(parsed) => {
|
Ok(mut parsed) => {
|
||||||
|
parsed.call_id = call_id;
|
||||||
*self.input.lock().unwrap() = Some(parsed);
|
*self.input.lock().unwrap() = Some(parsed);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
@ -166,7 +169,7 @@ impl Tool for FindFiles {
|
||||||
let input = match self.load_input() {
|
let input = match self.load_input() {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -174,15 +177,16 @@ impl Tool for FindFiles {
|
||||||
let search_root =
|
let search_root =
|
||||||
match Self::resolve_search_root(input.root.as_deref(), request.project_root()) {
|
match Self::resolve_search_root(input.root.as_deref(), request.project_root()) {
|
||||||
Ok(root) => root,
|
Ok(root) => root,
|
||||||
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e),
|
Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check if search root exists
|
// Check if search root exists
|
||||||
if !search_root.exists() {
|
if !search_root.exists() {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Search root does not exist: {}", search_root.display()),
|
format!("Search root does not exist: {}", search_root.display()),
|
||||||
|
input.call_id,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -204,7 +208,7 @@ impl Tool for FindFiles {
|
||||||
output
|
output
|
||||||
};
|
};
|
||||||
|
|
||||||
ToolResult::ok(self.name().to_string(), input.raw, output)
|
ToolResult::ok(self.name().to_string(), input.raw, output, input.call_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters(&self) -> serde_json::Value {
|
fn parameters(&self) -> serde_json::Value {
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,7 @@ pub struct FileChange {
|
||||||
|
|
||||||
pub struct ToolResult {
|
pub struct ToolResult {
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
|
call_id: String,
|
||||||
pub(crate) input: String,
|
pub(crate) input: String,
|
||||||
is_successful: bool,
|
is_successful: bool,
|
||||||
output: String,
|
output: String,
|
||||||
|
|
@ -53,7 +54,7 @@ pub struct ToolResult {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolResult {
|
impl ToolResult {
|
||||||
pub fn ok(tool_name: String, input: String, output: String) -> Self {
|
pub fn ok(tool_name: String, input: String, output: String, call_id: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tool_name,
|
tool_name,
|
||||||
input,
|
input,
|
||||||
|
|
@ -61,10 +62,11 @@ impl ToolResult {
|
||||||
output,
|
output,
|
||||||
error_message: String::new(),
|
error_message: String::new(),
|
||||||
file_changes: None,
|
file_changes: None,
|
||||||
|
call_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn error(tool_name: String, input: String, message: String) -> Self {
|
pub fn error(tool_name: String, input: String, message: String, call_id: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tool_name,
|
tool_name,
|
||||||
input,
|
input,
|
||||||
|
|
@ -72,6 +74,7 @@ impl ToolResult {
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error_message: message.into(),
|
error_message: message.into(),
|
||||||
file_changes: None,
|
file_changes: None,
|
||||||
|
call_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -128,11 +131,15 @@ impl ToolResult {
|
||||||
pub fn file_changes(&self) -> Option<&[FileChange]> {
|
pub fn file_changes(&self) -> Option<&[FileChange]> {
|
||||||
self.file_changes.as_deref()
|
self.file_changes.as_deref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn call_id(&self) -> String {
|
||||||
|
self.call_id.clone()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Tool: Send + Sync {
|
pub trait Tool: Send + Sync {
|
||||||
fn name(&self) -> &'static str;
|
fn name(&self) -> &'static str;
|
||||||
fn parse_input(&self, input: String) -> Option<Error>;
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error>;
|
||||||
fn work(&self, request: &dyn Request) -> ToolResult;
|
fn work(&self, request: &dyn Request) -> ToolResult;
|
||||||
fn parameters(&self) -> Value;
|
fn parameters(&self) -> Value;
|
||||||
fn desc(&self) -> String;
|
fn desc(&self) -> String;
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ pub struct PatchFiles {
|
||||||
struct PatchFilesInput {
|
struct PatchFilesInput {
|
||||||
raw: String,
|
raw: String,
|
||||||
patch: String,
|
patch: String,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
@ -41,6 +42,7 @@ impl PatchFiles {
|
||||||
Ok(PatchFilesInput {
|
Ok(PatchFilesInput {
|
||||||
raw: raw.to_string(),
|
raw: raw.to_string(),
|
||||||
patch: parsed.patch,
|
patch: parsed.patch,
|
||||||
|
call_id: String::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -58,12 +60,13 @@ impl Tool for PatchFiles {
|
||||||
"patch_files"
|
"patch_files"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
let parsed = Self::parse_input_json(trimmed);
|
let parsed = Self::parse_input_json(trimmed);
|
||||||
|
|
||||||
match parsed {
|
match parsed {
|
||||||
Ok(parsed) => {
|
Ok(mut parsed) => {
|
||||||
|
parsed.call_id = call_id;
|
||||||
*self.input.lock().unwrap() = Some(parsed);
|
*self.input.lock().unwrap() = Some(parsed);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
@ -75,7 +78,7 @@ impl Tool for PatchFiles {
|
||||||
let input = match self.load_input() {
|
let input = match self.load_input() {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -84,8 +87,9 @@ impl Tool for PatchFiles {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Failed to parse patch: {}", e),
|
format!("Failed to parse patch: {}", e),
|
||||||
|
input.call_id.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -115,8 +119,9 @@ impl Tool for PatchFiles {
|
||||||
{
|
{
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Invalid path outside project root: {}", path),
|
format!("Invalid path outside project root: {}", path),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let fs_path = project_root.join(rel_path);
|
let fs_path = project_root.join(rel_path);
|
||||||
|
|
@ -129,8 +134,9 @@ impl Tool for PatchFiles {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Failed to read file '{}': {}", fs_path.display(), e),
|
format!("Failed to read file '{}': {}", fs_path.display(), e),
|
||||||
|
input.call_id.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -142,8 +148,9 @@ impl Tool for PatchFiles {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Failed to apply patch: {}", e),
|
format!("Failed to apply patch: {}", e),
|
||||||
|
input.call_id.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -155,28 +162,31 @@ impl Tool for PatchFiles {
|
||||||
if let Err(e) = fs::create_dir_all(parent) {
|
if let Err(e) = fs::create_dir_all(parent) {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!(
|
format!(
|
||||||
"Failed to create parent directories for '{}': {}",
|
"Failed to create parent directories for '{}': {}",
|
||||||
fs_path.display(),
|
fs_path.display(),
|
||||||
e
|
e
|
||||||
),
|
),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Err(e) = fs::write(&fs_path, content) {
|
if let Err(e) = fs::write(&fs_path, content) {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Failed to write file '{}': {}", fs_path.display(), e),
|
format!("Failed to write file '{}': {}", fs_path.display(), e),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else if fs_path.exists() {
|
} else if fs_path.exists() {
|
||||||
if let Err(e) = fs::remove_file(&fs_path) {
|
if let Err(e) = fs::remove_file(&fs_path) {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Failed to delete file '{}': {}", fs_path.display(), e),
|
format!("Failed to delete file '{}': {}", fs_path.display(), e),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -224,6 +234,7 @@ impl Tool for PatchFiles {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw,
|
||||||
"Patch applied successfully".to_string(),
|
"Patch applied successfully".to_string(),
|
||||||
|
input.call_id,
|
||||||
);
|
);
|
||||||
|
|
||||||
if !file_changes.is_empty() {
|
if !file_changes.is_empty() {
|
||||||
|
|
@ -531,7 +542,7 @@ mod tests {
|
||||||
"{{\"patch\":\"{}\"}}",
|
"{{\"patch\":\"{}\"}}",
|
||||||
patch.replace('\n', "\\n").replace('"', "\\\"")
|
patch.replace('\n', "\\n").replace('"', "\\\"")
|
||||||
);
|
);
|
||||||
assert!(tool.parse_input(input).is_none());
|
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -571,7 +582,7 @@ mod tests {
|
||||||
"{{\"patch\":\"{}\"}}",
|
"{{\"patch\":\"{}\"}}",
|
||||||
patch.replace('\n', "\\n").replace('"', "\\\"")
|
patch.replace('\n', "\\n").replace('"', "\\\"")
|
||||||
);
|
);
|
||||||
assert!(tool.parse_input(input).is_none());
|
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -592,7 +603,7 @@ mod tests {
|
||||||
|
|
||||||
let tool = PatchFiles::new();
|
let tool = PatchFiles::new();
|
||||||
let input = "{\"patch\":\"not a patch\"}".to_string();
|
let input = "{\"patch\":\"not a patch\"}".to_string();
|
||||||
assert!(tool.parse_input(input).is_none());
|
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -601,4 +612,4 @@ mod tests {
|
||||||
result.output_string()
|
result.output_string()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -18,6 +18,7 @@ struct ReadObjectsInput {
|
||||||
raw: String,
|
raw: String,
|
||||||
full_path_to_file: String,
|
full_path_to_file: String,
|
||||||
queries: Vec<ObjectQuery>,
|
queries: Vec<ObjectQuery>,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
@ -138,6 +139,7 @@ impl ReadObjects {
|
||||||
raw: raw.to_string(),
|
raw: raw.to_string(),
|
||||||
full_path_to_file: trimmed_path.to_string(),
|
full_path_to_file: trimmed_path.to_string(),
|
||||||
queries,
|
queries,
|
||||||
|
call_id: String::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -182,12 +184,13 @@ impl Tool for ReadObjects {
|
||||||
"read_objects"
|
"read_objects"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
let parsed = Self::parse_input_json(trimmed);
|
let parsed = Self::parse_input_json(trimmed);
|
||||||
|
|
||||||
match parsed {
|
match parsed {
|
||||||
Ok(parsed) => {
|
Ok(mut parsed) => {
|
||||||
|
parsed.call_id = call_id;
|
||||||
*self.input.lock().unwrap() = Some(parsed);
|
*self.input.lock().unwrap() = Some(parsed);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
@ -199,17 +202,18 @@ impl Tool for ReadObjects {
|
||||||
let input = match self.load_input() {
|
let input = match self.load_input() {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match Self::read_objects(&input.full_path_to_file, &input.queries) {
|
match Self::read_objects(&input.full_path_to_file, &input.queries) {
|
||||||
Ok((lang, results)) => ToolResult::ok(
|
Ok((lang, results)) => ToolResult::ok(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
Self::format_output(lang, results),
|
Self::format_output(lang, results),
|
||||||
|
input.call_id,
|
||||||
),
|
),
|
||||||
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()),
|
Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -224,7 +228,7 @@ impl Tool for ReadObjects {
|
||||||
},
|
},
|
||||||
"query": {
|
"query": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\")",
|
"description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\"), grepping is not working here, use exact names",
|
||||||
"minLength": 1
|
"minLength": 1
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
@ -235,8 +239,8 @@ impl Tool for ReadObjects {
|
||||||
|
|
||||||
fn desc(&self) -> String {
|
fn desc(&self) -> String {
|
||||||
format!(
|
format!(
|
||||||
"Use the `{}` tool to read source code of specific objects from a file. To determine correct properties to use for `{}`, use the `discover_objects` tool first.",
|
"Use the `{}` tool to read source code of specific objects from a file. To determine correct properties, use the `discover_objects` tool first.",
|
||||||
self.name(), self.name()
|
self.name(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -293,7 +297,7 @@ mod tests {
|
||||||
fn test_parse_input_valid() {
|
fn test_parse_input_valid() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "src/main.rs", "query": "main, Config"}"#;
|
let input = r#"{"path": "src/main.rs", "query": "main, Config"}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(result.is_none(), "Expected no error, got: {:?}", result);
|
assert!(result.is_none(), "Expected no error, got: {:?}", result);
|
||||||
|
|
||||||
// Verify the parsed input
|
// Verify the parsed input
|
||||||
|
|
@ -308,7 +312,7 @@ mod tests {
|
||||||
fn test_parse_input_empty_query() {
|
fn test_parse_input_empty_query() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "src/main.rs", "query": ""}"#;
|
let input = r#"{"path": "src/main.rs", "query": ""}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(result.is_some(), "Expected error for empty query");
|
assert!(result.is_some(), "Expected error for empty query");
|
||||||
|
|
||||||
if let Some(Error::Parse(msg)) = result {
|
if let Some(Error::Parse(msg)) = result {
|
||||||
|
|
@ -326,7 +330,7 @@ mod tests {
|
||||||
fn test_parse_input_whitespace_query() {
|
fn test_parse_input_whitespace_query() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "src/main.rs", "query": " "}"#;
|
let input = r#"{"path": "src/main.rs", "query": " "}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(result.is_some(), "Expected error for whitespace query");
|
assert!(result.is_some(), "Expected error for whitespace query");
|
||||||
|
|
||||||
if let Some(Error::Parse(msg)) = result {
|
if let Some(Error::Parse(msg)) = result {
|
||||||
|
|
@ -344,7 +348,7 @@ mod tests {
|
||||||
fn test_parse_input_comma_only_query() {
|
fn test_parse_input_comma_only_query() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "src/main.rs", "query": ",,,"}"#;
|
let input = r#"{"path": "src/main.rs", "query": ",,,"}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(result.is_some(), "Expected error for comma-only query");
|
assert!(result.is_some(), "Expected error for comma-only query");
|
||||||
|
|
||||||
if let Some(Error::Parse(msg)) = result {
|
if let Some(Error::Parse(msg)) = result {
|
||||||
|
|
@ -362,7 +366,7 @@ mod tests {
|
||||||
fn test_parse_input_missing_path() {
|
fn test_parse_input_missing_path() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "", "query": "main"}"#;
|
let input = r#"{"path": "", "query": "main"}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(result.is_some(), "Expected error for empty path");
|
assert!(result.is_some(), "Expected error for empty path");
|
||||||
|
|
||||||
if let Some(Error::Parse(msg)) = result {
|
if let Some(Error::Parse(msg)) = result {
|
||||||
|
|
@ -407,7 +411,7 @@ mod tests {
|
||||||
fn test_parse_input_malformed_json() {
|
fn test_parse_input_malformed_json() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "src/main.rs", "query": "main"#; // Missing closing brace
|
let input = r#"{"path": "src/main.rs", "query": "main"#; // Missing closing brace
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(result.is_some(), "Expected error for malformed JSON");
|
assert!(result.is_some(), "Expected error for malformed JSON");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -415,7 +419,7 @@ mod tests {
|
||||||
fn test_parse_input_path_with_whitespace() {
|
fn test_parse_input_path_with_whitespace() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": " src/main.rs ", "query": "main"}"#;
|
let input = r#"{"path": " src/main.rs ", "query": "main"}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(
|
assert!(
|
||||||
result.is_none(),
|
result.is_none(),
|
||||||
"Expected no error for path with whitespace"
|
"Expected no error for path with whitespace"
|
||||||
|
|
@ -432,7 +436,7 @@ mod tests {
|
||||||
fn test_parse_input_query_with_extra_whitespace() {
|
fn test_parse_input_query_with_extra_whitespace() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "src/main.rs", "query": " main , Config "}"#;
|
let input = r#"{"path": "src/main.rs", "query": " main , Config "}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(
|
assert!(
|
||||||
result.is_none(),
|
result.is_none(),
|
||||||
"Expected no error for query with extra whitespace"
|
"Expected no error for query with extra whitespace"
|
||||||
|
|
@ -448,7 +452,7 @@ mod tests {
|
||||||
fn test_parse_input_single_object() {
|
fn test_parse_input_single_object() {
|
||||||
let tool = ReadObjects::new();
|
let tool = ReadObjects::new();
|
||||||
let input = r#"{"path": "src/lib.rs", "query": "MyStruct"}"#;
|
let input = r#"{"path": "src/lib.rs", "query": "MyStruct"}"#;
|
||||||
let result = tool.parse_input(input.to_string());
|
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||||
assert!(
|
assert!(
|
||||||
result.is_none(),
|
result.is_none(),
|
||||||
"Expected no error for single object query"
|
"Expected no error for single object query"
|
||||||
|
|
@ -481,4 +485,4 @@ mod tests {
|
||||||
};
|
};
|
||||||
assert!(query.matches(&obj_partial), "Should match partial name");
|
assert!(query.matches(&obj_partial), "Should match partial name");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -25,6 +25,7 @@ struct ShellExecInputParsed {
|
||||||
raw: String,
|
raw: String,
|
||||||
command: String,
|
command: String,
|
||||||
working_dir: Option<String>,
|
working_dir: Option<String>,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Tool for ShellExec {
|
impl Tool for ShellExec {
|
||||||
|
|
@ -32,7 +33,7 @@ impl Tool for ShellExec {
|
||||||
"shell_exec"
|
"shell_exec"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
let parsed = serde_json::from_str::<ShellExecInput>(trimmed)
|
let parsed = serde_json::from_str::<ShellExecInput>(trimmed)
|
||||||
.map_err(|e| Error::Parse(e.to_string()));
|
.map_err(|e| Error::Parse(e.to_string()));
|
||||||
|
|
@ -46,6 +47,7 @@ impl Tool for ShellExec {
|
||||||
raw: trimmed.to_string(),
|
raw: trimmed.to_string(),
|
||||||
command: parsed.command,
|
command: parsed.command,
|
||||||
working_dir: parsed.working_dir,
|
working_dir: parsed.working_dir,
|
||||||
|
call_id,
|
||||||
});
|
});
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
@ -61,6 +63,7 @@ impl Tool for ShellExec {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
String::new(),
|
String::new(),
|
||||||
"input not parsed".to_string(),
|
"input not parsed".to_string(),
|
||||||
|
String::new(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -72,15 +75,17 @@ impl Tool for ShellExec {
|
||||||
if !path.exists() {
|
if !path.exists() {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Working directory does not exist: {}", dir),
|
format!("Working directory does not exist: {}", dir),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if !crate::utils::paths::is_within_root(path, request.project_root()) {
|
if !crate::utils::paths::is_within_root(path, request.project_root()) {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
"Working directory is outside project root".to_string(),
|
"Working directory is outside project root".to_string(),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
path.to_path_buf()
|
path.to_path_buf()
|
||||||
|
|
@ -99,8 +104,9 @@ impl Tool for ShellExec {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Failed to execute command: {}", e),
|
format!("Failed to execute command: {}", e),
|
||||||
|
input.call_id.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -129,10 +135,10 @@ impl Tool for ShellExec {
|
||||||
if !stderr.is_empty() {
|
if !stderr.is_empty() {
|
||||||
result.push_str(&format!("[stderr]: {}", stderr));
|
result.push_str(&format!("[stderr]: {}", stderr));
|
||||||
}
|
}
|
||||||
return ToolResult::error(self.name().to_string(), input.raw, result);
|
return ToolResult::error(self.name().to_string(), input.raw.clone(), result, input.call_id.clone());
|
||||||
};
|
};
|
||||||
|
|
||||||
ToolResult::ok(self.name().to_string(), input.raw, result)
|
ToolResult::ok(self.name().to_string(), input.raw, result, input.call_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters(&self) -> serde_json::Value {
|
fn parameters(&self) -> serde_json::Value {
|
||||||
|
|
@ -155,8 +161,9 @@ impl Tool for ShellExec {
|
||||||
|
|
||||||
fn desc(&self) -> String {
|
fn desc(&self) -> String {
|
||||||
format!(
|
format!(
|
||||||
r#"Use `{}` tool to execute shell commands.
|
"Use `{}` tool to execute shell commands.
|
||||||
Please DO NOT use it to read the full content of a file, this is not efficient, use `read_objects` tool for this."#,
|
DO NOT use it to read code, this is not efficient, use `read_objects` tool for this. \n\
|
||||||
|
DO NOT use it to change files, use `patch_files` tool for this.",
|
||||||
self.name()
|
self.name()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -339,7 +346,7 @@ mod tests {
|
||||||
|
|
||||||
let tool = ShellExec::new();
|
let tool = ShellExec::new();
|
||||||
assert!(tool
|
assert!(tool
|
||||||
.parse_input(r#"{"command":"echo hello"}"#.to_string())
|
.parse_input(r#"{"command":"echo hello"}"#.to_string(), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -357,7 +364,7 @@ mod tests {
|
||||||
|
|
||||||
let tool = ShellExec::new();
|
let tool = ShellExec::new();
|
||||||
assert!(tool
|
assert!(tool
|
||||||
.parse_input(r#"{"command":"pwd"}"#.to_string())
|
.parse_input(r#"{"command":"pwd"}"#.to_string(), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -384,7 +391,7 @@ mod tests {
|
||||||
.parse_input(format!(
|
.parse_input(format!(
|
||||||
r#"{{"command":"pwd","working_dir":"{}"}}"#,
|
r#"{{"command":"pwd","working_dir":"{}"}}"#,
|
||||||
subdir.display()
|
subdir.display()
|
||||||
))
|
), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -405,7 +412,7 @@ mod tests {
|
||||||
|
|
||||||
let tool = ShellExec::new();
|
let tool = ShellExec::new();
|
||||||
assert!(tool
|
assert!(tool
|
||||||
.parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string())
|
.parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string(), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -424,7 +431,7 @@ mod tests {
|
||||||
|
|
||||||
let tool = ShellExec::new();
|
let tool = ShellExec::new();
|
||||||
assert!(tool
|
assert!(tool
|
||||||
.parse_input(r#"{"command":"exit 1"}"#.to_string())
|
.parse_input(r#"{"command":"exit 1"}"#.to_string(), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -443,7 +450,7 @@ mod tests {
|
||||||
|
|
||||||
let tool = ShellExec::new();
|
let tool = ShellExec::new();
|
||||||
assert!(tool
|
assert!(tool
|
||||||
.parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string())
|
.parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string(), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -457,7 +464,7 @@ mod tests {
|
||||||
let request = TestRequest::new(temp.path());
|
let request = TestRequest::new(temp.path());
|
||||||
|
|
||||||
let tool = ShellExec::new();
|
let tool = ShellExec::new();
|
||||||
let err = tool.parse_input(r#"{"command":""}"#.to_string());
|
let err = tool.parse_input(r#"{"command":""}"#.to_string(), "call-id".to_string());
|
||||||
assert!(err.is_some());
|
assert!(err.is_some());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
assert!(result.output_string().contains("Error"));
|
assert!(result.output_string().contains("Error"));
|
||||||
|
|
@ -474,7 +481,7 @@ mod tests {
|
||||||
.parse_input(format!(
|
.parse_input(format!(
|
||||||
r#"{{"command":"echo 'test content' > {}"}}"#,
|
r#"{{"command":"echo 'test content' > {}"}}"#,
|
||||||
file_path.display()
|
file_path.display()
|
||||||
))
|
), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -496,7 +503,7 @@ mod tests {
|
||||||
|
|
||||||
let tool = ShellExec::new();
|
let tool = ShellExec::new();
|
||||||
assert!(tool
|
assert!(tool
|
||||||
.parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string())
|
.parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string(), "call-id".to_string())
|
||||||
.is_none());
|
.is_none());
|
||||||
let result = tool.work(&request);
|
let result = tool.work(&request);
|
||||||
|
|
||||||
|
|
@ -506,4 +513,4 @@ mod tests {
|
||||||
result.output_string()
|
result.output_string()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -16,6 +16,7 @@ struct StructureInput {
|
||||||
raw: String,
|
raw: String,
|
||||||
path: String,
|
path: String,
|
||||||
max_depth: usize,
|
max_depth: usize,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
@ -127,6 +128,7 @@ impl Structure {
|
||||||
raw: raw.to_string(),
|
raw: raw.to_string(),
|
||||||
path: parsed.path.unwrap_or_else(|| ".".to_string()),
|
path: parsed.path.unwrap_or_else(|| ".".to_string()),
|
||||||
max_depth: parsed.max_depth.unwrap_or(3),
|
max_depth: parsed.max_depth.unwrap_or(3),
|
||||||
|
call_id: String::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -144,12 +146,13 @@ impl Tool for Structure {
|
||||||
"structure"
|
"structure"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
let parsed = Self::parse_input_json(trimmed);
|
let parsed = Self::parse_input_json(trimmed);
|
||||||
|
|
||||||
match parsed {
|
match parsed {
|
||||||
Ok(parsed) => {
|
Ok(mut parsed) => {
|
||||||
|
parsed.call_id = call_id;
|
||||||
*self.input.lock().unwrap() = Some(parsed);
|
*self.input.lock().unwrap() = Some(parsed);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
@ -161,34 +164,36 @@ impl Tool for Structure {
|
||||||
let input = match self.load_input() {
|
let input = match self.load_input() {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let target_path = match Self::resolve_path(&input.path, request.project_root()) {
|
let target_path = match Self::resolve_path(&input.path, request.project_root()) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e),
|
Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
if !target_path.exists() {
|
if !target_path.exists() {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Path does not exist: {}", target_path.display()),
|
format!("Path does not exist: {}", target_path.display()),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if !target_path.is_dir() {
|
if !target_path.is_dir() {
|
||||||
return ToolResult::error(
|
return ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw.clone(),
|
||||||
format!("Path is not a directory: {}", target_path.display()),
|
format!("Path is not a directory: {}", target_path.display()),
|
||||||
|
input.call_id,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let tree = Self::build_tree(&target_path, input.max_depth);
|
let tree = Self::build_tree(&target_path, input.max_depth);
|
||||||
|
|
||||||
ToolResult::ok(self.name().to_string(), input.raw, tree)
|
ToolResult::ok(self.name().to_string(), input.raw, tree, input.call_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parameters(&self) -> serde_json::Value {
|
fn parameters(&self) -> serde_json::Value {
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ pub struct UpdateTodoList {
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct UpdateTodoListInput {
|
struct UpdateTodoListInput {
|
||||||
raw: String,
|
raw: String,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UpdateTodoList {
|
impl UpdateTodoList {
|
||||||
|
|
@ -36,12 +37,12 @@ impl Tool for UpdateTodoList {
|
||||||
"update_todo_list"
|
"update_todo_list"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||||
let parsed: Result<TodoList, _> = serde_json::from_str(&input);
|
let parsed: Result<TodoList, _> = serde_json::from_str(&input);
|
||||||
match parsed {
|
match parsed {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
let mut lock = self.input.lock().unwrap();
|
let mut lock = self.input.lock().unwrap();
|
||||||
*lock = Some(UpdateTodoListInput { raw: input });
|
*lock = Some(UpdateTodoListInput { raw: input, call_id });
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
Err(e) => Some(Error::Parse(format!("Failed to parse input: {}", e))),
|
Err(e) => Some(Error::Parse(format!("Failed to parse input: {}", e))),
|
||||||
|
|
@ -57,6 +58,7 @@ impl Tool for UpdateTodoList {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
String::new(),
|
String::new(),
|
||||||
"No input provided".to_string(),
|
"No input provided".to_string(),
|
||||||
|
String::new(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -70,6 +72,7 @@ impl Tool for UpdateTodoList {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
format!("Failed to get database connection: {}", e),
|
format!("Failed to get database connection: {}", e),
|
||||||
|
input.call_id.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -83,6 +86,7 @@ impl Tool for UpdateTodoList {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
format!("Failed to get/create TODO list: {}", e),
|
format!("Failed to get/create TODO list: {}", e),
|
||||||
|
input.call_id.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -92,6 +96,7 @@ impl Tool for UpdateTodoList {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
format!("Failed to update TODO list: {}", e),
|
format!("Failed to update TODO list: {}", e),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -106,6 +111,7 @@ impl Tool for UpdateTodoList {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw,
|
input.raw,
|
||||||
"TODO list updated successfully.".to_string(),
|
"TODO list updated successfully.".to_string(),
|
||||||
|
input.call_id,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ struct WebSearchInput {
|
||||||
raw: String,
|
raw: String,
|
||||||
query: String,
|
query: String,
|
||||||
max_results: u32,
|
max_results: u32,
|
||||||
|
call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
|
|
@ -94,7 +95,7 @@ impl Tool for WebSearch {
|
||||||
"web_search"
|
"web_search"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_input(&self, input: String) -> Option<crate::domain::tools::Error> {
|
fn parse_input(&self, input: String, call_id: String) -> Option<crate::domain::tools::Error> {
|
||||||
let trimmed = input.trim();
|
let trimmed = input.trim();
|
||||||
match serde_json::from_str::<WebSearchInputJson>(trimmed) {
|
match serde_json::from_str::<WebSearchInputJson>(trimmed) {
|
||||||
Ok(parsed) => {
|
Ok(parsed) => {
|
||||||
|
|
@ -103,6 +104,7 @@ impl Tool for WebSearch {
|
||||||
raw: input,
|
raw: input,
|
||||||
query: parsed.query,
|
query: parsed.query,
|
||||||
max_results,
|
max_results,
|
||||||
|
call_id,
|
||||||
};
|
};
|
||||||
*self.input.lock().unwrap() = Some(parsed_input);
|
*self.input.lock().unwrap() = Some(parsed_input);
|
||||||
None
|
None
|
||||||
|
|
@ -118,7 +120,7 @@ impl Tool for WebSearch {
|
||||||
let input = match self.load_input() {
|
let input = match self.load_input() {
|
||||||
Ok(input) => input,
|
Ok(input) => input,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return ToolResult::error(self.name().to_string(), String::new(), e);
|
return ToolResult::error(self.name().to_string(), String::new(), e, String::new());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -130,6 +132,7 @@ impl Tool for WebSearch {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
"User settings not available".to_string(),
|
"User settings not available".to_string(),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -139,6 +142,7 @@ impl Tool for WebSearch {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
"Web search is not enabled in settings".to_string(),
|
"Web search is not enabled in settings".to_string(),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -149,6 +153,7 @@ impl Tool for WebSearch {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
"Brave API key not configured".to_string(),
|
"Brave API key not configured".to_string(),
|
||||||
|
input.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -159,12 +164,13 @@ impl Tool for WebSearch {
|
||||||
let output = WebSearchOutput { results };
|
let output = WebSearchOutput { results };
|
||||||
match serde_json::to_string(&output) {
|
match serde_json::to_string(&output) {
|
||||||
Ok(json_output) => {
|
Ok(json_output) => {
|
||||||
ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output)
|
ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output, input.call_id)
|
||||||
}
|
}
|
||||||
Err(e) => ToolResult::error(
|
Err(e) => ToolResult::error(
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
format!("Failed to serialize output: {}", e),
|
format!("Failed to serialize output: {}", e),
|
||||||
|
input.call_id.clone(),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -172,6 +178,7 @@ impl Tool for WebSearch {
|
||||||
self.name().to_string(),
|
self.name().to_string(),
|
||||||
input.raw.clone(),
|
input.raw.clone(),
|
||||||
format!("Search failed: {}", e),
|
format!("Search failed: {}", e),
|
||||||
|
input.call_id,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -593,8 +600,10 @@ mod tests {
|
||||||
user_settings,
|
user_settings,
|
||||||
};
|
};
|
||||||
let web_search = WebSearch::new();
|
let web_search = WebSearch::new();
|
||||||
let _ =
|
let _ = web_search.parse_input(
|
||||||
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string());
|
r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
|
||||||
|
"call-id".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
let _allowed = checker.check(&web_search, &request, Some(1)).unwrap();
|
let _allowed = checker.check(&web_search, &request, Some(1)).unwrap();
|
||||||
|
|
||||||
|
|
@ -629,8 +638,10 @@ mod tests {
|
||||||
user_settings,
|
user_settings,
|
||||||
};
|
};
|
||||||
let web_search = WebSearch::new();
|
let web_search = WebSearch::new();
|
||||||
let _ =
|
let _ = web_search.parse_input(
|
||||||
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string());
|
r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
|
||||||
|
"call-id".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
let allowed = checker.check(&web_search, &request, Some(1)).unwrap();
|
let allowed = checker.check(&web_search, &request, Some(1)).unwrap();
|
||||||
|
|
||||||
|
|
@ -640,4 +651,4 @@ mod tests {
|
||||||
assert_eq!(created[0].tool_name, "web_search");
|
assert_eq!(created[0].tool_name, "web_search");
|
||||||
assert_eq!(created[0].resource_pattern, None);
|
assert_eq!(created[0].resource_pattern, None);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -17,6 +17,7 @@ pub struct Chain {
|
||||||
pub fail_reason: String,
|
pub fail_reason: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub final_message: Option<String>,
|
pub final_message: Option<String>,
|
||||||
|
pub system_prompt: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Chain {
|
impl Chain {
|
||||||
|
|
@ -28,6 +29,7 @@ impl Chain {
|
||||||
is_failed: false,
|
is_failed: false,
|
||||||
fail_reason: String::new(),
|
fail_reason: String::new(),
|
||||||
final_message: None,
|
final_message: None,
|
||||||
|
system_prompt: String::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -120,6 +122,10 @@ impl Chain {
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join("\n")
|
.join("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_system_prompt(&mut self, system_prompt: String) {
|
||||||
|
self.system_prompt = system_prompt;
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn total_payload_len_chars(&self) -> usize {
|
pub fn total_payload_len_chars(&self) -> usize {
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,8 @@ pub struct ChainStep {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub tool_name: Option<String>,
|
pub tool_name: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
pub call_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
pub tool_output: Option<String>,
|
pub tool_output: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub is_successful: Option<bool>,
|
pub is_successful: Option<bool>,
|
||||||
|
|
@ -60,6 +62,7 @@ impl ChainStep {
|
||||||
let mut tool_output = None;
|
let mut tool_output = None;
|
||||||
let mut is_successful = None;
|
let mut is_successful = None;
|
||||||
let mut file_changes = None;
|
let mut file_changes = None;
|
||||||
|
let mut call_id = None;
|
||||||
if let Some(tr) = tool_result {
|
if let Some(tr) = tool_result {
|
||||||
summary = tr.summary();
|
summary = tr.summary();
|
||||||
context_payload = tr.output_string();
|
context_payload = tr.output_string();
|
||||||
|
|
@ -68,6 +71,7 @@ impl ChainStep {
|
||||||
tool_output = Some(tr.output_raw().to_string());
|
tool_output = Some(tr.output_raw().to_string());
|
||||||
is_successful = Some(tr.is_successful());
|
is_successful = Some(tr.is_successful());
|
||||||
file_changes = tr.file_changes().map(|fc| fc.to_vec());
|
file_changes = tr.file_changes().map(|fc| fc.to_vec());
|
||||||
|
call_id = Some(tr.call_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -76,6 +80,7 @@ impl ChainStep {
|
||||||
context_payload,
|
context_payload,
|
||||||
input_payload,
|
input_payload,
|
||||||
tool_name,
|
tool_name,
|
||||||
|
call_id,
|
||||||
tool_output,
|
tool_output,
|
||||||
is_successful,
|
is_successful,
|
||||||
file_changes,
|
file_changes,
|
||||||
|
|
@ -91,6 +96,7 @@ impl ChainStep {
|
||||||
context_payload: raw_output.clone(),
|
context_payload: raw_output.clone(),
|
||||||
input_payload: String::new(),
|
input_payload: String::new(),
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
call_id: None,
|
||||||
tool_output: Some(raw_output),
|
tool_output: Some(raw_output),
|
||||||
is_successful: Some(true),
|
is_successful: Some(true),
|
||||||
file_changes: None,
|
file_changes: None,
|
||||||
|
|
@ -118,6 +124,7 @@ impl ChainStep {
|
||||||
context_payload: prompt.clone(),
|
context_payload: prompt.clone(),
|
||||||
input_payload: prompt,
|
input_payload: prompt,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
call_id: None,
|
||||||
tool_output: None,
|
tool_output: None,
|
||||||
is_successful: Some(true),
|
is_successful: Some(true),
|
||||||
file_changes: None,
|
file_changes: None,
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ impl ToolRunner {
|
||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
tool_call.arguments.clone(),
|
tool_call.arguments.clone(),
|
||||||
error_msg,
|
error_msg,
|
||||||
|
tool_call.call_id.clone(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -73,11 +74,13 @@ impl ToolRunner {
|
||||||
tool.name().to_string(),
|
tool.name().to_string(),
|
||||||
String::new(),
|
String::new(),
|
||||||
"Permission denied".to_string(),
|
"Permission denied".to_string(),
|
||||||
|
String::new(),
|
||||||
),
|
),
|
||||||
Err(err) => ToolResult::error(
|
Err(err) => ToolResult::error(
|
||||||
tool.name().to_string(),
|
tool.name().to_string(),
|
||||||
String::new(),
|
String::new(),
|
||||||
format!("Permission check error: {}", err),
|
format!("Permission check error: {}", err),
|
||||||
|
String::new(),
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ pub trait Toolset {
|
||||||
.map(|t| t.as_ref())
|
.map(|t| t.as_ref())
|
||||||
.ok_or_else(|| Error::Parse(format!("Tool not found: {}", tool_call.name)))?;
|
.ok_or_else(|| Error::Parse(format!("Tool not found: {}", tool_call.name)))?;
|
||||||
|
|
||||||
if let Some(err) = tool.parse_input(tool_call.arguments.clone()) {
|
if let Some(err) = tool.parse_input(tool_call.arguments.clone(), tool_call.call_id.clone()) {
|
||||||
return Err(err);
|
return Err(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ use super::{
|
||||||
use crate::domain::bt::GeneralTree;
|
use crate::domain::bt::GeneralTree;
|
||||||
use crate::domain::prompting;
|
use crate::domain::prompting;
|
||||||
use crate::domain::session::Request;
|
use crate::domain::session::Request;
|
||||||
|
use crate::domain::todo::TodoListStatus;
|
||||||
use crate::infrastructure::db::DbPool;
|
use crate::infrastructure::db::DbPool;
|
||||||
use crate::infrastructure::event_bus::{AgentToUiEvent, StepPhase};
|
use crate::infrastructure::event_bus::{AgentToUiEvent, StepPhase};
|
||||||
use crate::infrastructure::inference::InferenceEngine;
|
use crate::infrastructure::inference::InferenceEngine;
|
||||||
|
|
@ -16,7 +17,6 @@ use std::sync::Arc;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use crossbeam_channel::Sender;
|
use crossbeam_channel::Sender;
|
||||||
use crate::domain::AgentModeType;
|
use crate::domain::AgentModeType;
|
||||||
use crate::domain::todo::TodoListStatus;
|
|
||||||
|
|
||||||
/// Main workflow orchestrator that runs LLM-driven coding tasks
|
/// Main workflow orchestrator that runs LLM-driven coding tasks
|
||||||
/// Implements an eternal agent loop that:
|
/// Implements an eternal agent loop that:
|
||||||
|
|
@ -94,7 +94,6 @@ impl Workflow {
|
||||||
self.chain.set_todo_list(request.get_session_plan());
|
self.chain.set_todo_list(request.get_session_plan());
|
||||||
|
|
||||||
let result = self._run(request, cancel, None, None, mode);
|
let result = self._run(request, cancel, None, None, mode);
|
||||||
self.chain.steps.push(ChainStep::user_message(request.current_request().to_string(), request.images().to_vec()));
|
|
||||||
self._end_tracing();
|
self._end_tracing();
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
@ -145,6 +144,7 @@ impl Workflow {
|
||||||
context_payload: String::new(),
|
context_payload: String::new(),
|
||||||
input_payload: step_user_prompt.to_string(),
|
input_payload: step_user_prompt.to_string(),
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
call_id: None,
|
||||||
tool_output: None,
|
tool_output: None,
|
||||||
is_successful: Some(true),
|
is_successful: Some(true),
|
||||||
file_changes: None,
|
file_changes: None,
|
||||||
|
|
@ -171,6 +171,13 @@ impl Workflow {
|
||||||
mode: AgentModeType,
|
mode: AgentModeType,
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
|
|
||||||
|
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request);
|
||||||
|
let user_prompt = user_prompt_override
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or(&base_user_prompt)
|
||||||
|
.to_string();
|
||||||
|
self.chain.steps.push(ChainStep::user_message(user_prompt.clone(), request.images().to_vec()));
|
||||||
|
|
||||||
// Get max_tool_calls from override (for BT mode) or user settings
|
// Get max_tool_calls from override (for BT mode) or user settings
|
||||||
let max_tool_calls = max_tool_calls_override.unwrap_or_else(|| {
|
let max_tool_calls = max_tool_calls_override.unwrap_or_else(|| {
|
||||||
request.user_settings()
|
request.user_settings()
|
||||||
|
|
@ -198,11 +205,7 @@ impl Workflow {
|
||||||
|
|
||||||
// Get base system prompt and inject remaining count
|
// Get base system prompt and inject remaining count
|
||||||
let system_prompt = prompting::get_system_prompt(self.engine.get_type(), request.mode(), remaining_calls);
|
let system_prompt = prompting::get_system_prompt(self.engine.get_type(), request.mode(), remaining_calls);
|
||||||
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request);
|
self.chain.set_system_prompt(system_prompt.clone());
|
||||||
let user_prompt = user_prompt_override
|
|
||||||
.as_deref()
|
|
||||||
.unwrap_or(&base_user_prompt)
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
// Switch to finishing toolset if approaching limit
|
// Switch to finishing toolset if approaching limit
|
||||||
if !in_finishing_mode && tool_call_count >= finishing_threshold {
|
if !in_finishing_mode && tool_call_count >= finishing_threshold {
|
||||||
|
|
@ -220,13 +223,9 @@ impl Workflow {
|
||||||
}
|
}
|
||||||
|
|
||||||
self._emit_inference_progress();
|
self._emit_inference_progress();
|
||||||
self._trace_llm_start(user_prompt.clone());
|
|
||||||
|
|
||||||
// Ask LLM to choose next tool
|
// Ask LLM to choose next tool
|
||||||
let llm_output = match self.engine.generate(
|
let llm_output = match self.engine.generate(
|
||||||
&system_prompt,
|
|
||||||
&user_prompt,
|
|
||||||
1024,
|
|
||||||
&self.toolset.tool_refs(),
|
&self.toolset.tool_refs(),
|
||||||
&self.chain,
|
&self.chain,
|
||||||
request.images(),
|
request.images(),
|
||||||
|
|
@ -238,8 +237,6 @@ impl Workflow {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self._trace_llm_end(llm_output.raw_output.clone());
|
|
||||||
|
|
||||||
// Always capture the assistant's response in the chain
|
// Always capture the assistant's response in the chain
|
||||||
if !llm_output.raw_output.is_empty() {
|
if !llm_output.raw_output.is_empty() {
|
||||||
self.chain.steps.push(ChainStep::assistant_response(
|
self.chain.steps.push(ChainStep::assistant_response(
|
||||||
|
|
@ -347,13 +344,6 @@ impl Workflow {
|
||||||
self.chain = Chain::new();
|
self.chain = Chain::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _trace_llm_start(&mut self, prompt: String) {
|
|
||||||
if let Some(tracer) = &mut self.tracer {
|
|
||||||
tracer.start_span("LLM generation", SpanKind::Generation);
|
|
||||||
tracer.add_input("LLM generation", prompt);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn _emit_inference_progress(&self) {
|
fn _emit_inference_progress(&self) {
|
||||||
let options = [
|
let options = [
|
||||||
"Thinking.. well kinda..",
|
"Thinking.. well kinda..",
|
||||||
|
|
@ -383,18 +373,10 @@ impl Workflow {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn _trace_llm_end(&mut self, output: String) {
|
|
||||||
if let Some(tracer) = &mut self.tracer {
|
|
||||||
tracer.add_output("LLM generation", output);
|
|
||||||
tracer.end_span("LLM generation");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Detects if the active TODO item has changed
|
/// Detects if the active TODO item has changed
|
||||||
/// Returns true if the first non-completed item's title is different
|
/// Returns true if the first non-completed item's title is different
|
||||||
fn _did_todo_item_change(&self, previous_todo: &Option<crate::domain::todo::TodoList>) -> bool {
|
fn _did_todo_item_change(&self, previous_todo: &Option<crate::domain::todo::TodoList>) -> bool {
|
||||||
use crate::domain::todo::TodoListStatus;
|
|
||||||
|
|
||||||
// Get first non-completed item from previous TODO list
|
// Get first non-completed item from previous TODO list
|
||||||
let prev_active = previous_todo.as_ref().and_then(|list| {
|
let prev_active = previous_todo.as_ref().and_then(|list| {
|
||||||
list.items.iter()
|
list.items.iter()
|
||||||
|
|
|
||||||
|
|
@ -64,8 +64,6 @@ impl OpenAIClient {
|
||||||
|
|
||||||
pub fn call_responses_api(
|
pub fn call_responses_api(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: &str,
|
|
||||||
user_prompt: &str,
|
|
||||||
tools: &[&dyn crate::domain::tools::Tool],
|
tools: &[&dyn crate::domain::tools::Tool],
|
||||||
chain: &crate::domain::workflow::Chain,
|
chain: &crate::domain::workflow::Chain,
|
||||||
images: &[String],
|
images: &[String],
|
||||||
|
|
@ -74,8 +72,6 @@ impl OpenAIClient {
|
||||||
let max_attempts = 3;
|
let max_attempts = 3;
|
||||||
for attempt in 1..=max_attempts {
|
for attempt in 1..=max_attempts {
|
||||||
match self.call_responses_api_inner(
|
match self.call_responses_api_inner(
|
||||||
system_prompt,
|
|
||||||
user_prompt,
|
|
||||||
tools,
|
tools,
|
||||||
chain,
|
chain,
|
||||||
images,
|
images,
|
||||||
|
|
@ -121,25 +117,31 @@ impl OpenAIClient {
|
||||||
|
|
||||||
fn call_responses_api_inner(
|
fn call_responses_api_inner(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: &str,
|
|
||||||
user_prompt: &str,
|
|
||||||
tools: &[&dyn crate::domain::tools::Tool],
|
tools: &[&dyn crate::domain::tools::Tool],
|
||||||
chain: &crate::domain::workflow::Chain,
|
chain: &crate::domain::workflow::Chain,
|
||||||
images: &[String],
|
images: &[String],
|
||||||
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||||
) -> Result<LLMInferenceResult, Box<dyn Error + Send + Sync>> {
|
) -> Result<LLMInferenceResult, Box<dyn Error + Send + Sync>> {
|
||||||
let url = "https://api.openai.com/v1/responses";
|
let url = "https://api.openai.com/v1/responses";
|
||||||
|
|
||||||
let request_body = build_request_dto(
|
let request_body = build_request_dto(
|
||||||
&self.model,
|
&self.model,
|
||||||
system_prompt,
|
|
||||||
user_prompt,
|
|
||||||
images,
|
images,
|
||||||
tools,
|
tools,
|
||||||
chain,
|
chain,
|
||||||
tracer,
|
tracer.as_deref_mut(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Start span with model name and add request as JSON
|
||||||
|
if let Some(tracer) = &mut tracer {
|
||||||
|
tracer.start_span(&self.model, openai_agents_tracing::SpanKind::Generation);
|
||||||
|
|
||||||
|
// Convert request_body to JSON Value and set as input
|
||||||
|
if let Ok(request_json) = serde_json::to_value(&request_body) {
|
||||||
|
tracer.set_input_json(&self.model, request_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post(url)
|
.post(url)
|
||||||
|
|
@ -151,18 +153,35 @@ impl OpenAIClient {
|
||||||
let body = response.text()?;
|
let body = response.text()?;
|
||||||
|
|
||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
|
if let Some(t) = tracer {
|
||||||
|
t.end_span(&self.model);
|
||||||
|
}
|
||||||
return Err(Box::new(OpenAIClientError::Api { status, body }));
|
return Err(Box::new(OpenAIClientError::Api { status, body }));
|
||||||
}
|
}
|
||||||
|
|
||||||
let dto = match serde_json::from_str::<ResponseDTO>(&body) {
|
let dto = match serde_json::from_str::<ResponseDTO>(&body) {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
Err(e) => return Err(Box::new(OpenAIClientError::Deserialize { source: e, body })),
|
Err(e) => {
|
||||||
|
if let Some(t) = tracer {
|
||||||
|
t.end_span(&self.model);
|
||||||
|
}
|
||||||
|
return Err(Box::new(OpenAIClientError::Deserialize { source: e, body }));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Add response as JSON and end span
|
||||||
|
if let Some(tracer) = &mut tracer {
|
||||||
|
if let Ok(response_json) = serde_json::to_value(&dto) {
|
||||||
|
tracer.set_output_json(&self.model, response_json);
|
||||||
|
}
|
||||||
|
tracer.end_span(&self.model);
|
||||||
|
}
|
||||||
|
|
||||||
let result = build_llm_result(dto, tools);
|
let result = build_llm_result(dto, tools);
|
||||||
if result.summary.is_empty() && result.tool_call.is_none() {
|
if result.summary.is_empty() && result.tool_call.is_none() {
|
||||||
return Err(Box::new(OpenAIClientError::NoText { body }));
|
return Err(Box::new(OpenAIClientError::NoText { body }));
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,13 @@ use crate::domain::{Chain, ModelType};
|
||||||
use crate::domain::prompting::format_todo_list_message;
|
use crate::domain::prompting::format_todo_list_message;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use crate::domain::workflow::step::StepType::{AssistantResponse, ToolCall, UserMessage};
|
use crate::domain::workflow::step::StepType;
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct RequestDTO {
|
pub struct RequestDTO {
|
||||||
model: String,
|
model: String,
|
||||||
instructions: String,
|
instructions: String,
|
||||||
input: Vec<InputDto>,
|
input: Vec<InputMessageDto>,
|
||||||
tools: Vec<ToolDto>,
|
tools: Vec<ToolDto>,
|
||||||
tool_choice: String,
|
tool_choice: String,
|
||||||
parallel_tool_calls: bool,
|
parallel_tool_calls: bool,
|
||||||
|
|
@ -18,12 +18,37 @@ pub struct RequestDTO {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub(super) struct InputDto {
|
pub(super) struct MessageDto {
|
||||||
content: Vec<InputContent>,
|
content: Vec<InputContent>,
|
||||||
role: String,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
role: Option<String>,
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
kind: String,
|
kind: String,
|
||||||
status: String,
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub(super) struct FunctionOutputDto {
|
||||||
|
output: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
call_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub(super) struct FunctionCallDto {
|
||||||
|
arguments: String,
|
||||||
|
name: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
call_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum InputMessageDto {
|
||||||
|
Message(MessageDto),
|
||||||
|
FunctionOutput(FunctionOutputDto),
|
||||||
|
FunctionCall(FunctionCallDto),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
|
@ -63,10 +88,25 @@ impl InputContent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn function(text: String) -> Self {
|
}
|
||||||
Self::Text {
|
|
||||||
kind: "function".to_string(),
|
impl FunctionOutputDto {
|
||||||
text,
|
pub fn new(output: String, call_id: String) -> Self {
|
||||||
|
Self {
|
||||||
|
kind: "function_call_output".to_string(),
|
||||||
|
output,
|
||||||
|
call_id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FunctionCallDto {
|
||||||
|
pub fn new(name: String, arguments: String, call_id: String) -> Self {
|
||||||
|
Self {
|
||||||
|
name,
|
||||||
|
kind: "function_call".to_string(),
|
||||||
|
arguments,
|
||||||
|
call_id,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -83,28 +123,23 @@ pub(super) struct ToolDto {
|
||||||
const ROLE_USER: &str = "user";
|
const ROLE_USER: &str = "user";
|
||||||
const ROLE_SYSTEM: &str = "system";
|
const ROLE_SYSTEM: &str = "system";
|
||||||
const ROLE_ASSISTANT: &str = "assistant";
|
const ROLE_ASSISTANT: &str = "assistant";
|
||||||
const INPUT_STATUS_IN_PROGRESS: &str = "in_progress";
|
|
||||||
const INPUT_STATUS_COMPLETED: &str = "completed";
|
|
||||||
const INPUT_STATUS_FAILED: &str = "failed";
|
|
||||||
|
|
||||||
impl RequestDTO {
|
impl RequestDTO {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
model: String,
|
model: String,
|
||||||
system_prompt: String,
|
|
||||||
user_prompt: String,
|
|
||||||
tools: &[&dyn Tool],
|
tools: &[&dyn Tool],
|
||||||
chain: &crate::domain::workflow::Chain,
|
chain: &Chain,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// User request is now part of the chain, no need to add separately
|
// User request is now part of the chain, no need to add separately
|
||||||
let input = InputDto::build(user_prompt, chain);
|
let input = MessageDto::build(chain);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
instructions: system_prompt,
|
instructions: chain.system_prompt.clone(),
|
||||||
input,
|
input,
|
||||||
tools: tools.iter().map(|tool| ToolDto::from_tool(*tool)).collect(),
|
tools: tools.iter().map(|tool| ToolDto::from_tool(*tool)).collect(),
|
||||||
tool_choice: "auto".to_string(),
|
tool_choice: "auto".to_string(),
|
||||||
parallel_tool_calls: true,
|
parallel_tool_calls: false,
|
||||||
store: false,
|
store: false,
|
||||||
stream: false,
|
stream: false,
|
||||||
}
|
}
|
||||||
|
|
@ -123,50 +158,46 @@ impl ToolDto {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InputDto {
|
impl MessageDto {
|
||||||
fn build(user_prompt: String, chain: &Chain) -> Vec<Self> {
|
fn build(chain: &Chain) -> Vec<InputMessageDto> {
|
||||||
let steps = chain.get_steps_with_history();
|
let steps = chain.get_steps_with_history();
|
||||||
|
|
||||||
let mut result: Vec<Self> = steps
|
let mut result: Vec<InputMessageDto> = Vec::new();
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(idx, step)| {
|
|
||||||
// Determine status
|
|
||||||
let is_user_message = step.step_type == UserMessage.as_str();
|
|
||||||
|
|
||||||
let status = if is_user_message || step.is_successful.unwrap_or(false) {
|
for step in steps.iter() {
|
||||||
INPUT_STATUS_COMPLETED
|
let is_user_message = step.step_type == StepType::UserMessage.as_str();
|
||||||
} else {
|
|
||||||
INPUT_STATUS_FAILED
|
|
||||||
};
|
|
||||||
let mut role = ROLE_ASSISTANT.to_string();
|
|
||||||
|
|
||||||
// Build content items
|
if is_user_message {
|
||||||
let content_items = if is_user_message {
|
// User message: text + optional images
|
||||||
role = ROLE_USER.to_string();
|
|
||||||
// For user messages, include text and images
|
|
||||||
let mut items = vec![InputContent::text(step.input_payload.clone())];
|
let mut items = vec![InputContent::text(step.input_payload.clone())];
|
||||||
|
|
||||||
// Add image content items if present
|
|
||||||
if let Some(ref images) = step.images {
|
if let Some(ref images) = step.images {
|
||||||
for image_url in images {
|
for image_url in images {
|
||||||
items.push(InputContent::image(image_url.clone()));
|
items.push(InputContent::image(image_url.clone()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
items
|
result.push(InputMessageDto::Message(Self {
|
||||||
} else {
|
content: items,
|
||||||
vec![InputContent::output_text(step.get_output(ModelType::OpenAI))]
|
role: Some(ROLE_USER.to_string()),
|
||||||
};
|
kind: "message".to_string(),
|
||||||
|
}));
|
||||||
|
} else if step.step_type == StepType::ToolCall.as_str() {
|
||||||
|
|
||||||
Self {
|
let tool_name = step.tool_name.clone().unwrap();
|
||||||
content: content_items,
|
let call_id = step.call_id.clone().unwrap();
|
||||||
role,
|
// Tool call output is a separate DTO type
|
||||||
kind: "message".to_string(),
|
result.push(InputMessageDto::FunctionCall(FunctionCallDto::new(tool_name, step.input_payload.clone(), call_id.clone())));
|
||||||
status: status.to_string(),
|
result.push(InputMessageDto::FunctionOutput(FunctionOutputDto::new(step.get_output(ModelType::OpenAI), call_id)));
|
||||||
|
} else {
|
||||||
|
// Assistant message
|
||||||
|
result.push(InputMessageDto::Message(Self {
|
||||||
|
content: vec![InputContent::output_text(step.get_output(ModelType::OpenAI))],
|
||||||
|
role: Some(ROLE_ASSISTANT.to_string()),
|
||||||
|
kind: "message".to_string(),
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Add the plan as system message at the beginning if it exists and is not completed
|
// Add the plan as system message at the beginning if it exists and is not completed
|
||||||
if let Some(ref todo_list) = chain.todo_list {
|
if let Some(ref todo_list) = chain.todo_list {
|
||||||
|
|
@ -178,22 +209,15 @@ impl InputDto {
|
||||||
|
|
||||||
let todo_input = Self {
|
let todo_input = Self {
|
||||||
content: vec![InputContent::text(todo_message)],
|
content: vec![InputContent::text(todo_message)],
|
||||||
role: ROLE_SYSTEM.to_string(),
|
role: Some(ROLE_SYSTEM.to_string()),
|
||||||
kind: "message".to_string(),
|
kind: "message".to_string(),
|
||||||
status: INPUT_STATUS_COMPLETED.to_string(),
|
|
||||||
};
|
};
|
||||||
result.push(todo_input);
|
|
||||||
|
// Put it first
|
||||||
|
result.insert(0, InputMessageDto::Message(todo_input));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// and adding the current user message at the end
|
|
||||||
result.push(Self {
|
|
||||||
content: vec![InputContent::text(user_prompt)],
|
|
||||||
role: ROLE_USER.to_string(),
|
|
||||||
kind: "message".to_string(),
|
|
||||||
status: INPUT_STATUS_IN_PROGRESS.to_string(),
|
|
||||||
});
|
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1,20 +1,21 @@
|
||||||
use serde::Deserialize;
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
pub struct ResponseDTO {
|
pub struct ResponseDTO {
|
||||||
output: Vec<OpenAIOutputItem>,
|
output: Vec<OpenAIOutputItem>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
struct OpenAIOutputItem {
|
struct OpenAIOutputItem {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
kind: String,
|
kind: String,
|
||||||
content: Option<Vec<OpenAIContentItem>>,
|
content: Option<Vec<OpenAIContentItem>>,
|
||||||
name: Option<String>,
|
name: Option<String>,
|
||||||
arguments: Option<String>,
|
arguments: Option<String>,
|
||||||
|
call_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
struct OpenAIContentItem {
|
struct OpenAIContentItem {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
kind: String,
|
kind: String,
|
||||||
|
|
@ -41,11 +42,12 @@ impl ResponseDTO {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if item.kind == "function_call" {
|
} else if item.kind == "function_call" {
|
||||||
if let (Some(name), Some(arguments)) = (item.name.as_ref(), item.arguments.as_ref())
|
if let (Some(name), Some(arguments), Some(call_id)) = (item.name.as_ref(), item.arguments.as_ref(), item.call_id.as_ref())
|
||||||
{
|
{
|
||||||
tool_call = Some(FunctionCall {
|
tool_call = Some(FunctionCall {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
arguments: arguments.to_string(),
|
arguments: arguments.to_string(),
|
||||||
|
call_id: call_id.to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -59,4 +61,5 @@ impl ResponseDTO {
|
||||||
pub(crate) struct FunctionCall {
|
pub(crate) struct FunctionCall {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub arguments: String,
|
pub arguments: String,
|
||||||
|
pub call_id: String,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,6 @@ use crate::infrastructure::inference::{LLMInferenceResult, ToolCall};
|
||||||
|
|
||||||
pub fn build_request_dto(
|
pub fn build_request_dto(
|
||||||
model: &str,
|
model: &str,
|
||||||
system_prompt: &str,
|
|
||||||
user_prompt: &str,
|
|
||||||
images: &[String],
|
images: &[String],
|
||||||
tools: &[&dyn crate::domain::tools::Tool],
|
tools: &[&dyn crate::domain::tools::Tool],
|
||||||
chain: &crate::domain::workflow::Chain,
|
chain: &crate::domain::workflow::Chain,
|
||||||
|
|
@ -25,8 +23,6 @@ pub fn build_request_dto(
|
||||||
|
|
||||||
let dto = RequestDTO::new(
|
let dto = RequestDTO::new(
|
||||||
model.to_string(),
|
model.to_string(),
|
||||||
system_prompt.to_string(),
|
|
||||||
user_prompt.to_string(),
|
|
||||||
tools,
|
tools,
|
||||||
chain,
|
chain,
|
||||||
);
|
);
|
||||||
|
|
@ -54,9 +50,10 @@ pub fn build_llm_result(
|
||||||
|
|
||||||
let tool_call = if let Some(call) = tool_call_dto {
|
let tool_call = if let Some(call) = tool_call_dto {
|
||||||
// Create ToolCall for the workflow to use
|
// Create ToolCall for the workflow to use
|
||||||
Some(ToolCall {
|
Some(ToolCall{
|
||||||
name: call.name.clone(),
|
name: call.name.clone(),
|
||||||
arguments: call.arguments.clone(),
|
arguments: call.arguments.clone(),
|
||||||
|
call_id: call.call_id.clone(),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
|
||||||
|
|
@ -51,4 +51,4 @@ pub fn render(frame: &mut Frame, area: Rect, state: &UiState, theme: Theme) {
|
||||||
let y = area.y.saturating_add(INPUT_PADDING.top + adjusted_line);
|
let y = area.y.saturating_add(INPUT_PADDING.top + adjusted_line);
|
||||||
frame.set_cursor(x, y);
|
frame.set_cursor(x, y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -159,6 +159,7 @@ fn handle_agent_event(
|
||||||
});
|
});
|
||||||
state.request_status = None;
|
state.request_status = None;
|
||||||
state.request_progress = None;
|
state.request_progress = None;
|
||||||
|
state.last_progress_update = None; // Reset timer for new request
|
||||||
state.file_changes = None;
|
state.file_changes = None;
|
||||||
|
|
||||||
// Display the user's message
|
// Display the user's message
|
||||||
|
|
@ -175,7 +176,22 @@ fn handle_agent_event(
|
||||||
summary,
|
summary,
|
||||||
} => {
|
} => {
|
||||||
if matches!(phase, StepPhase::Before) {
|
if matches!(phase, StepPhase::Before) {
|
||||||
state.request_progress = Some(summary.clone());
|
// Check if enough time has passed since the last progress update
|
||||||
|
let can_update = state
|
||||||
|
.last_progress_update
|
||||||
|
.map(|last| {
|
||||||
|
std::time::Instant::now()
|
||||||
|
.duration_since(last)
|
||||||
|
.as_millis()
|
||||||
|
>= crate::infrastructure::cli::state::MIN_PROGRESS_DISPLAY_MS
|
||||||
|
})
|
||||||
|
.unwrap_or(true); // Update immediately if no previous update
|
||||||
|
|
||||||
|
if can_update {
|
||||||
|
state.request_progress = Some(summary.clone());
|
||||||
|
state.last_progress_update = Some(std::time::Instant::now());
|
||||||
|
}
|
||||||
|
// Otherwise skip this update - too soon since last one
|
||||||
} else {
|
} else {
|
||||||
state.request_progress = None;
|
state.request_progress = None;
|
||||||
}
|
}
|
||||||
|
|
@ -207,6 +223,7 @@ fn handle_agent_event(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
state.request_progress = None;
|
state.request_progress = None;
|
||||||
|
state.last_progress_update = None; // Reset timer when request finishes
|
||||||
|
|
||||||
// Update the user message color based on request status
|
// Update the user message color based on request status
|
||||||
for entry in state.progress.iter_mut() {
|
for entry in state.progress.iter_mut() {
|
||||||
|
|
|
||||||
|
|
@ -33,8 +33,8 @@ impl AttachedImage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const INPUT_MIN_HEIGHT: usize = 3;
|
pub const INPUT_MIN_HEIGHT: usize = 1;
|
||||||
pub const INPUT_MAX_HEIGHT: usize = 6;
|
pub const INPUT_MAX_HEIGHT: usize = 5;
|
||||||
pub const PROGRESS_HISTORY_LIMIT: usize = 200;
|
pub const PROGRESS_HISTORY_LIMIT: usize = 200;
|
||||||
pub const MAIN_BODY_SCROLL_STEP: usize = 3;
|
pub const MAIN_BODY_SCROLL_STEP: usize = 3;
|
||||||
|
|
||||||
|
|
@ -51,6 +51,7 @@ pub enum ProgressKind {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const REQUEST_STATUS_DISPLAY_DURATION: Duration = Duration::from_secs(2);
|
pub const REQUEST_STATUS_DISPLAY_DURATION: Duration = Duration::from_secs(2);
|
||||||
|
pub const MIN_PROGRESS_DISPLAY_MS: u128 = 1000; // 1 second minimum display time
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
|
@ -260,6 +261,7 @@ pub struct UiState {
|
||||||
pub request_in_flight: Option<RequestIndicator>,
|
pub request_in_flight: Option<RequestIndicator>,
|
||||||
pub request_status: Option<RequestStatusDisplay>,
|
pub request_status: Option<RequestStatusDisplay>,
|
||||||
pub request_progress: Option<String>,
|
pub request_progress: Option<String>,
|
||||||
|
pub last_progress_update: Option<Instant>,
|
||||||
pub file_changes: Option<FileChangesDisplay>,
|
pub file_changes: Option<FileChangesDisplay>,
|
||||||
pub agent_mode: AgentModeType,
|
pub agent_mode: AgentModeType,
|
||||||
pub todo_list: Option<TodoListDisplay>,
|
pub todo_list: Option<TodoListDisplay>,
|
||||||
|
|
@ -300,6 +302,7 @@ impl UiState {
|
||||||
request_in_flight: None,
|
request_in_flight: None,
|
||||||
request_status: None,
|
request_status: None,
|
||||||
request_progress: None,
|
request_progress: None,
|
||||||
|
last_progress_update: None,
|
||||||
file_changes: None,
|
file_changes: None,
|
||||||
agent_mode: AgentModeType::Build, // Default to build mode
|
agent_mode: AgentModeType::Build, // Default to build mode
|
||||||
todo_list: None,
|
todo_list: None,
|
||||||
|
|
@ -353,4 +356,4 @@ impl UiState {
|
||||||
Some(crate::domain::ModelType::OpenAI)
|
Some(crate::domain::ModelType::OpenAI)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -31,4 +31,4 @@ impl Theme {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const PANEL_PADDING: Padding = Padding::new(2, 2, 1, 1);
|
pub const PANEL_PADDING: Padding = Padding::new(2, 2, 1, 1);
|
||||||
pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1);
|
pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1);
|
||||||
|
|
@ -6,7 +6,7 @@ use crate::infrastructure::cli::helpers::{
|
||||||
centered_rect, cursor_position, list_state, panel_block,
|
centered_rect, cursor_position, list_state, panel_block,
|
||||||
};
|
};
|
||||||
use crate::infrastructure::cli::state::{LoadStatus, PopupState, UiMode, UiState};
|
use crate::infrastructure::cli::state::{LoadStatus, PopupState, UiMode, UiState};
|
||||||
use crate::infrastructure::cli::theme::{Theme, PANEL_PADDING};
|
use crate::infrastructure::cli::theme::{Theme, INPUT_PADDING, PANEL_PADDING};
|
||||||
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
|
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
|
||||||
use ratatui::style::{Modifier, Style};
|
use ratatui::style::{Modifier, Style};
|
||||||
use ratatui::text::{Line, Span, Text};
|
use ratatui::text::{Line, Span, Text};
|
||||||
|
|
@ -28,7 +28,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
|
||||||
input_lines,
|
input_lines,
|
||||||
crate::infrastructure::cli::state::INPUT_MAX_HEIGHT,
|
crate::infrastructure::cli::state::INPUT_MAX_HEIGHT,
|
||||||
);
|
);
|
||||||
let mut input_box_height = (input_lines + 2) as u16;
|
let mut input_box_height = (input_lines + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16;
|
||||||
|
|
||||||
let indicator_height = 1u16;
|
let indicator_height = 1u16;
|
||||||
let attachment_indicator_height = if state.attached_images.is_empty() {
|
let attachment_indicator_height = if state.attached_images.is_empty() {
|
||||||
|
|
@ -45,7 +45,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
|
||||||
let overflow = fixed_height - size.height;
|
let overflow = fixed_height - size.height;
|
||||||
let reduced = input_box_height.saturating_sub(overflow);
|
let reduced = input_box_height.saturating_sub(overflow);
|
||||||
input_box_height =
|
input_box_height =
|
||||||
reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + 2) as u16);
|
reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16);
|
||||||
}
|
}
|
||||||
|
|
||||||
let constraints = if state.attached_images.is_empty() {
|
let constraints = if state.attached_images.is_empty() {
|
||||||
|
|
@ -557,4 +557,4 @@ pub fn openai_available_filtered(state: &UiState, filter: &str) -> Vec<String> {
|
||||||
.filter(|name| name.to_lowercase().contains(&query))
|
.filter(|name| name.to_lowercase().contains(&query))
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
@ -89,17 +89,39 @@ impl LocalEngine {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InferenceEngine for LocalEngine {
|
impl InferenceEngine for LocalEngine {
|
||||||
|
// so far very fucked up implementation, even user_prompt is not properly passed
|
||||||
|
// need to be refactored using proper request builder same way as for openai inference
|
||||||
fn generate(
|
fn generate(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: &str,
|
|
||||||
user_prompt: &str,
|
|
||||||
max_tokens: u32,
|
|
||||||
_tools: &[&dyn crate::domain::tools::Tool],
|
_tools: &[&dyn crate::domain::tools::Tool],
|
||||||
_chain: &crate::domain::workflow::Chain,
|
chain: &crate::domain::workflow::Chain,
|
||||||
_images: &[String],
|
_images: &[String],
|
||||||
_tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||||
) -> Result<LLMInferenceResult, InfaError> {
|
) -> Result<LLMInferenceResult, InfaError> {
|
||||||
let prompt = format!("{}\n\n{}", system_prompt, user_prompt);
|
let model_name = "local";
|
||||||
|
|
||||||
|
// Start span with model name and add request as JSON
|
||||||
|
if let Some(tracer) = &mut tracer {
|
||||||
|
tracer.start_span(model_name, openai_agents_tracing::SpanKind::Generation);
|
||||||
|
|
||||||
|
// Set request as structured JSON input
|
||||||
|
let request_json = serde_json::json!({
|
||||||
|
"system_prompt": chain.system_prompt.clone(),
|
||||||
|
"user_prompt": chain.get_steps_with_history()[0].context_payload.clone(),
|
||||||
|
"max_tokens": 1000,
|
||||||
|
});
|
||||||
|
tracer.set_input_json(model_name, request_json);
|
||||||
|
|
||||||
|
// Set model configuration
|
||||||
|
let mut model_config = std::collections::HashMap::new();
|
||||||
|
model_config.insert("temperature".to_string(), serde_json::json!(self.params.temperature));
|
||||||
|
model_config.insert("top_p".to_string(), serde_json::json!(self.params.top_p));
|
||||||
|
model_config.insert("ctx_size".to_string(), serde_json::json!(self.params.ctx_size));
|
||||||
|
model_config.insert("threads".to_string(), serde_json::json!(self.params.threads));
|
||||||
|
tracer.set_model_config(model_name, model_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
let prompt = format!("{}\n\n{}", chain.system_prompt.clone(), chain.get_steps_with_history()[0].context_payload.clone());
|
||||||
let to_error = |msg: String| -> InfaError {
|
let to_error = |msg: String| -> InfaError {
|
||||||
std::io::Error::new(std::io::ErrorKind::Other, msg).into()
|
std::io::Error::new(std::io::ErrorKind::Other, msg).into()
|
||||||
};
|
};
|
||||||
|
|
@ -145,7 +167,7 @@ impl InferenceEngine for LocalEngine {
|
||||||
let mut output = String::new();
|
let mut output = String::new();
|
||||||
let mut n_cur = tokens.len();
|
let mut n_cur = tokens.len();
|
||||||
|
|
||||||
for _ in 0..max_tokens {
|
for _ in 0..2000 {
|
||||||
let new_token = sampler.sample(&ctx, -1);
|
let new_token = sampler.sample(&ctx, -1);
|
||||||
|
|
||||||
if self.model.is_eog_token(new_token) {
|
if self.model.is_eog_token(new_token) {
|
||||||
|
|
@ -169,6 +191,15 @@ impl InferenceEngine for LocalEngine {
|
||||||
n_cur += 1;
|
n_cur += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add output as JSON and end span
|
||||||
|
if let Some(tracer) = tracer {
|
||||||
|
let response_json = serde_json::json!({
|
||||||
|
"text": &output,
|
||||||
|
});
|
||||||
|
tracer.set_output_json(model_name, response_json);
|
||||||
|
tracer.end_span(model_name);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(LLMInferenceResult {
|
Ok(LLMInferenceResult {
|
||||||
summary: output.trim().to_string(),
|
summary: output.trim().to_string(),
|
||||||
raw_output: output,
|
raw_output: output,
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ pub mod openai;
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub arguments: String,
|
pub arguments: String,
|
||||||
|
pub call_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LLMInferenceResult {
|
pub struct LLMInferenceResult {
|
||||||
|
|
@ -22,9 +23,6 @@ pub trait InferenceEngine: Send + Sync {
|
||||||
/// Generate text without streaming output
|
/// Generate text without streaming output
|
||||||
fn generate(
|
fn generate(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: &str,
|
|
||||||
user_prompt: &str,
|
|
||||||
max_tokens: u32,
|
|
||||||
tools: &[&dyn Tool],
|
tools: &[&dyn Tool],
|
||||||
chain: &Chain,
|
chain: &Chain,
|
||||||
images: &[String],
|
images: &[String],
|
||||||
|
|
|
||||||
|
|
@ -66,15 +66,12 @@ impl OpenAIEngine {
|
||||||
impl InferenceEngine for OpenAIEngine {
|
impl InferenceEngine for OpenAIEngine {
|
||||||
fn generate(
|
fn generate(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: &str,
|
|
||||||
user_prompt: &str,
|
|
||||||
_max_tokens: u32,
|
|
||||||
tools: &[&dyn crate::domain::tools::Tool],
|
tools: &[&dyn crate::domain::tools::Tool],
|
||||||
chain: &crate::domain::workflow::Chain,
|
chain: &crate::domain::workflow::Chain,
|
||||||
images: &[String],
|
images: &[String],
|
||||||
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||||
) -> Result<LLMInferenceResult, InfaError> {
|
) -> Result<LLMInferenceResult, InfaError> {
|
||||||
self.generate_with_responses_api(system_prompt, user_prompt, tools, chain, images, tracer)
|
self.generate_with_responses_api(tools, chain, images, tracer)
|
||||||
}
|
}
|
||||||
fn get_type(&self) -> ModelType {
|
fn get_type(&self) -> ModelType {
|
||||||
ModelType::OpenAI
|
ModelType::OpenAI
|
||||||
|
|
@ -85,15 +82,13 @@ impl OpenAIEngine {
|
||||||
/// Generate using the Responses API (for newer models like codex, o-series)
|
/// Generate using the Responses API (for newer models like codex, o-series)
|
||||||
fn generate_with_responses_api(
|
fn generate_with_responses_api(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: &str,
|
|
||||||
user_prompt: &str,
|
|
||||||
tools: &[&dyn crate::domain::tools::Tool],
|
tools: &[&dyn crate::domain::tools::Tool],
|
||||||
chain: &crate::domain::workflow::Chain,
|
chain: &crate::domain::workflow::Chain,
|
||||||
images: &[String],
|
images: &[String],
|
||||||
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||||
) -> Result<LLMInferenceResult, InfaError> {
|
) -> Result<LLMInferenceResult, InfaError> {
|
||||||
self.responses_client
|
self.responses_client
|
||||||
.call_responses_api(system_prompt, user_prompt, tools, chain, images, tracer)
|
.call_responses_api(tools, chain, images, tracer)
|
||||||
.map_err(|e| e)
|
.map_err(|e| e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue