mirror of
https://github.com/taggart-comet/quill-code.git
synced 2026-04-28 03:19:33 +00:00
fix
This commit is contained in:
parent
bd21ae6296
commit
3fd019bd7a
35 changed files with 633 additions and 343 deletions
|
|
@ -82,6 +82,38 @@ impl TracingFacade {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn set_model_config(&mut self, name: impl AsRef<str>, config: HashMap<String, serde_json::Value>) {
|
||||
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||
data.model_config = Some(config);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_usage(&mut self, name: impl AsRef<str>, input_tokens: u32, output_tokens: u32) {
|
||||
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||
data.usage = Some(crate::types::UsageData::new(input_tokens, output_tokens));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_input_json(&mut self, name: impl AsRef<str>, input: serde_json::Value) {
|
||||
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||
data.input = Some(vec![input]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_output_json(&mut self, name: impl AsRef<str>, output: serde_json::Value) {
|
||||
if let Some(span) = self.open_spans.get_mut(name.as_ref()) {
|
||||
if let SpanData::Generation(ref mut data) = span.span_data {
|
||||
data.output = Some(vec![output]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn end(&mut self) {
|
||||
for (_, mut span) in self.open_spans.drain() {
|
||||
span.mark_ended();
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ pub use session::{Session, SessionRequest};
|
|||
pub use startup::StartupService;
|
||||
pub use user_settings::UserSettings;
|
||||
pub use workflow::{CancellationToken, Chain};
|
||||
#[allow(unused_imports)]
|
||||
pub use todo::{TodoList, TodoItem};
|
||||
|
||||
/// Model type enum matching the inference engine types
|
||||
|
|
@ -41,4 +42,4 @@ impl ModelType {
|
|||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -401,12 +401,17 @@ mod tests {
|
|||
"read_only"
|
||||
}
|
||||
|
||||
fn parse_input(&self, _input: String) -> Option<Error> {
|
||||
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
|
||||
None
|
||||
}
|
||||
|
||||
fn work(&self, _request: &dyn Request) -> ToolResult {
|
||||
ToolResult::ok("read_only".to_string(), String::new(), String::new())
|
||||
ToolResult::ok(
|
||||
"read_only".to_string(),
|
||||
String::new(),
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
}
|
||||
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
|
|
@ -439,12 +444,17 @@ mod tests {
|
|||
"write_tool"
|
||||
}
|
||||
|
||||
fn parse_input(&self, _input: String) -> Option<Error> {
|
||||
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
|
||||
None
|
||||
}
|
||||
|
||||
fn work(&self, _request: &dyn Request) -> ToolResult {
|
||||
ToolResult::ok("write_tool".to_string(), String::new(), String::new())
|
||||
ToolResult::ok(
|
||||
"write_tool".to_string(),
|
||||
String::new(),
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
}
|
||||
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
|
|
@ -474,12 +484,17 @@ mod tests {
|
|||
"command_tool"
|
||||
}
|
||||
|
||||
fn parse_input(&self, _input: String) -> Option<Error> {
|
||||
fn parse_input(&self, _input: String, _call_id: String) -> Option<Error> {
|
||||
None
|
||||
}
|
||||
|
||||
fn work(&self, _request: &dyn Request) -> ToolResult {
|
||||
ToolResult::ok("command_tool".to_string(), String::new(), String::new())
|
||||
ToolResult::ok(
|
||||
"command_tool".to_string(),
|
||||
String::new(),
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
}
|
||||
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
|
|
@ -955,4 +970,4 @@ mod tests {
|
|||
|
||||
assert_eq!(decision, PermissionDecision::AlwaysAllow);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
use super::types::{Permission, PermissionDecision, PermissionScope};
|
||||
use crate::infrastructure::db::DbPool;
|
||||
use rusqlite::params;
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -121,16 +120,19 @@ impl PermissionStore for SqlitePermissionStore {
|
|||
) -> Result<Option<Permission>, StoreError> {
|
||||
// Build query dynamically to handle NULL values properly
|
||||
// In SQL, NULL != '', so we need to use IS NULL for empty strings
|
||||
let mut param_num = 3;
|
||||
let command_clause = if command_pattern.is_empty() {
|
||||
"command_pattern IS NULL"
|
||||
"command_pattern IS NULL".to_string()
|
||||
} else {
|
||||
"command_pattern = ?3"
|
||||
let clause = format!("command_pattern = ?{}", param_num);
|
||||
param_num += 1;
|
||||
clause
|
||||
};
|
||||
|
||||
let resource_clause = if resource_pattern.is_empty() {
|
||||
"resource_pattern IS NULL"
|
||||
"resource_pattern IS NULL".to_string()
|
||||
} else {
|
||||
"resource_pattern = ?4"
|
||||
format!("resource_pattern = ?{}", param_num)
|
||||
};
|
||||
|
||||
let query = format!(
|
||||
|
|
@ -151,32 +153,28 @@ impl PermissionStore for SqlitePermissionStore {
|
|||
)))
|
||||
})?;
|
||||
|
||||
// Build params based on whether patterns are empty
|
||||
let result = if command_pattern.is_empty() && resource_pattern.is_empty() {
|
||||
conn.query_row(
|
||||
&query,
|
||||
params![project_id, tool],
|
||||
|row| self.row_to_permission(row),
|
||||
)
|
||||
} else if command_pattern.is_empty() {
|
||||
conn.query_row(
|
||||
&query,
|
||||
params![project_id, tool, resource_pattern],
|
||||
|row| self.row_to_permission(row),
|
||||
)
|
||||
} else if resource_pattern.is_empty() {
|
||||
conn.query_row(
|
||||
&query,
|
||||
params![project_id, tool, command_pattern],
|
||||
|row| self.row_to_permission(row),
|
||||
)
|
||||
} else {
|
||||
conn.query_row(
|
||||
&query,
|
||||
params![project_id, tool, command_pattern, resource_pattern],
|
||||
|row| self.row_to_permission(row),
|
||||
)
|
||||
};
|
||||
// Build params list dynamically to match the query
|
||||
let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = vec![
|
||||
Box::new(project_id),
|
||||
Box::new(tool.to_string()),
|
||||
];
|
||||
|
||||
if !command_pattern.is_empty() {
|
||||
param_values.push(Box::new(command_pattern.to_string()));
|
||||
}
|
||||
|
||||
if !resource_pattern.is_empty() {
|
||||
param_values.push(Box::new(resource_pattern.to_string()));
|
||||
}
|
||||
|
||||
let params_refs: Vec<&dyn rusqlite::ToSql> = param_values
|
||||
.iter()
|
||||
.map(|p| p.as_ref())
|
||||
.collect();
|
||||
|
||||
let result = conn.query_row(&query, params_refs.as_slice(), |row| {
|
||||
self.row_to_permission(row)
|
||||
});
|
||||
|
||||
match result {
|
||||
Ok(permission) => Ok(Some(permission)),
|
||||
|
|
@ -186,48 +184,140 @@ impl PermissionStore for SqlitePermissionStore {
|
|||
}
|
||||
}
|
||||
|
||||
impl SqlitePermissionStore {
|
||||
fn find_matching_command_permission(
|
||||
&self,
|
||||
rows: rusqlite::MappedRows<
|
||||
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>,
|
||||
>,
|
||||
tool: &str,
|
||||
command: &str,
|
||||
) -> Result<Option<Permission>, StoreError> {
|
||||
for row_result in rows {
|
||||
match row_result {
|
||||
Ok(permission) => {
|
||||
if permission.matches(tool, Some(command), None::<&PathBuf>) {
|
||||
return Ok(Some(permission));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(StoreError::Database(e)),
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use r2d2_sqlite::SqliteConnectionManager;
|
||||
|
||||
Ok(None)
|
||||
fn setup_test_db() -> DbPool {
|
||||
let manager = SqliteConnectionManager::memory();
|
||||
let pool = r2d2::Pool::new(manager).unwrap();
|
||||
let conn = pool.get().unwrap();
|
||||
|
||||
// Create permissions table
|
||||
conn.execute(
|
||||
"CREATE TABLE IF NOT EXISTS permissions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
tool_name TEXT NOT NULL,
|
||||
command_pattern TEXT,
|
||||
resource_pattern TEXT,
|
||||
decision TEXT NOT NULL,
|
||||
scope TEXT NOT NULL,
|
||||
project_id INTEGER,
|
||||
created_at TEXT NOT NULL
|
||||
)",
|
||||
[],
|
||||
).unwrap();
|
||||
|
||||
pool
|
||||
}
|
||||
|
||||
fn find_matching_path_permission(
|
||||
&self,
|
||||
rows: rusqlite::MappedRows<
|
||||
impl FnMut(&rusqlite::Row) -> Result<Permission, rusqlite::Error>,
|
||||
>,
|
||||
tool: &str,
|
||||
path: &PathBuf,
|
||||
) -> Result<Option<Permission>, StoreError> {
|
||||
for row_result in rows {
|
||||
match row_result {
|
||||
Ok(permission) => {
|
||||
if permission.matches(tool, None, Some(path)) {
|
||||
return Ok(Some(permission));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(StoreError::Database(e)),
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_find_permission_with_null_patterns() {
|
||||
let pool = setup_test_db();
|
||||
let store = SqlitePermissionStore::new(pool.clone());
|
||||
|
||||
Ok(None)
|
||||
// Create permission with NULL command and resource patterns
|
||||
let permission = Permission::new(
|
||||
"test_tool".to_string(),
|
||||
None,
|
||||
None,
|
||||
PermissionDecision::AlwaysAllow,
|
||||
PermissionScope::Project,
|
||||
Some(1),
|
||||
);
|
||||
store.create_permission(permission).unwrap();
|
||||
|
||||
// Should find the permission when searching with empty strings
|
||||
let result = store.find_permission("test_tool", 1, "", "").unwrap();
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_permission_with_command_only() {
|
||||
let pool = setup_test_db();
|
||||
let store = SqlitePermissionStore::new(pool.clone());
|
||||
|
||||
// Create permission with command but NULL resource
|
||||
let permission = Permission::new(
|
||||
"test_tool".to_string(),
|
||||
Some("echo test".to_string()),
|
||||
None,
|
||||
PermissionDecision::AlwaysDeny,
|
||||
PermissionScope::Project,
|
||||
Some(1),
|
||||
);
|
||||
store.create_permission(permission).unwrap();
|
||||
|
||||
// Should find the permission when searching with command and empty resource
|
||||
let result = store.find_permission("test_tool", 1, "echo test", "").unwrap();
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_permission_with_resource_only() {
|
||||
let pool = setup_test_db();
|
||||
let store = SqlitePermissionStore::new(pool.clone());
|
||||
|
||||
// Create permission with resource but NULL command
|
||||
let permission = Permission::new(
|
||||
"test_tool".to_string(),
|
||||
None,
|
||||
Some("/path/to/file".to_string()),
|
||||
PermissionDecision::AlwaysAllow,
|
||||
PermissionScope::Project,
|
||||
Some(1),
|
||||
);
|
||||
store.create_permission(permission).unwrap();
|
||||
|
||||
// Should find the permission when searching with empty command and resource
|
||||
let result = store.find_permission("test_tool", 1, "", "/path/to/file").unwrap();
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysAllow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_permission_with_both_patterns() {
|
||||
let pool = setup_test_db();
|
||||
let store = SqlitePermissionStore::new(pool.clone());
|
||||
|
||||
// Create permission with both patterns
|
||||
let permission = Permission::new(
|
||||
"test_tool".to_string(),
|
||||
Some("rm -rf".to_string()),
|
||||
Some("/etc/passwd".to_string()),
|
||||
PermissionDecision::AlwaysDeny,
|
||||
PermissionScope::Project,
|
||||
Some(1),
|
||||
);
|
||||
store.create_permission(permission).unwrap();
|
||||
|
||||
// Should find the permission when searching with both patterns
|
||||
let result = store.find_permission("test_tool", 1, "rm -rf", "/etc/passwd").unwrap();
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().decision, PermissionDecision::AlwaysDeny);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_permission_no_match() {
|
||||
let pool = setup_test_db();
|
||||
let store = SqlitePermissionStore::new(pool.clone());
|
||||
|
||||
// Create permission with command
|
||||
let permission = Permission::new(
|
||||
"test_tool".to_string(),
|
||||
Some("echo test".to_string()),
|
||||
None,
|
||||
PermissionDecision::AlwaysAllow,
|
||||
PermissionScope::Project,
|
||||
Some(1),
|
||||
);
|
||||
store.create_permission(permission).unwrap();
|
||||
|
||||
// Should NOT find when searching with different command
|
||||
let result = store.find_permission("test_tool", 1, "rm -rf", "").unwrap();
|
||||
assert!(result.is_none());
|
||||
}
|
||||
}
|
||||
|
|
@ -29,6 +29,7 @@ pub struct Permission {
|
|||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl Permission {
|
||||
pub fn new(
|
||||
tool_name: String,
|
||||
|
|
@ -173,4 +174,4 @@ impl Default for PermissionConfig {
|
|||
require_confirmation: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -32,33 +32,38 @@ pub fn get_bt_tree_step_prompt(
|
|||
pub fn get_system_prompt(model_type: ModelType, agent_mode: AgentModeType, remaining_calls: usize) -> String {
|
||||
let (os_name, shell_name) = get_runtime_environment();
|
||||
|
||||
let mut system_prompt = "".to_string();
|
||||
if agent_mode == AgentModeType::Plan {
|
||||
system_prompt = _system_prompt_for_plan(model_type);
|
||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||
let mut system_prompt = _system_prompt_for_plan(model_type);
|
||||
if remaining_calls < 3 {
|
||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||
}
|
||||
return system_prompt;
|
||||
}
|
||||
if agent_mode == AgentModeType::BuildFromPlan {
|
||||
system_prompt = _system_prompt_for_build_from_plan(model_type);
|
||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||
let mut system_prompt = _system_prompt_for_build_from_plan(model_type);
|
||||
if remaining_calls < 3 {
|
||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||
}
|
||||
return system_prompt;
|
||||
}
|
||||
if model_type == ModelType::OpenAI {
|
||||
system_prompt = format!(
|
||||
let mut system_prompt = if model_type == ModelType::OpenAI {
|
||||
format!(
|
||||
"You are Drastis, a coding agent. \n\
|
||||
Use the available tools to gather context and make changes. \
|
||||
When using tools, pass JSON arguments that match their parameters. \n\
|
||||
Runtime: os={}, shell={}.",
|
||||
os_name, shell_name
|
||||
);
|
||||
)
|
||||
} else {
|
||||
system_prompt = format!(
|
||||
format!(
|
||||
"You are Drastis, a coding agent. Use available tools to gather context and make changes. Be concise and accurate. Runtime: os={}, shell={}.",
|
||||
os_name, shell_name
|
||||
);
|
||||
)
|
||||
};
|
||||
if remaining_calls < 3 {
|
||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||
}
|
||||
system_prompt.push_str(&format!("\n\nYou have {} tool calls left to process this request.", remaining_calls));
|
||||
return system_prompt;
|
||||
system_prompt
|
||||
}
|
||||
|
||||
fn _system_prompt_for_plan(model_type: ModelType) -> String {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ use std::path::Path;
|
|||
/// without being tightly coupled to the Session entity.
|
||||
pub trait Request {
|
||||
/// Get the history of previous requests
|
||||
#[allow(dead_code)]
|
||||
fn history(&self) -> &[SessionRequest];
|
||||
|
||||
/// Get the current request prompt
|
||||
|
|
@ -40,4 +41,4 @@ pub trait Request {
|
|||
|
||||
/// Get the session's TODO list (plan)
|
||||
fn get_session_plan(&self) -> Option<crate::domain::todo::TodoList>;
|
||||
}
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ use crate::repository::SessionRequestRow;
|
|||
/// Domain entity representing a user request within a session.
|
||||
/// Each request contains the user's prompt and the resulting summary.
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct SessionRequest {
|
||||
prompt: String,
|
||||
result_summary: Option<String>,
|
||||
|
|
@ -32,4 +33,4 @@ impl SessionRequest {
|
|||
mode: row.mode,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
use super::Error;
|
||||
use crate::domain::prompting::session_naming_prompt;
|
||||
use crate::domain::workflow::Chain;
|
||||
use crate::domain::workflow::ChainStep;
|
||||
use crate::domain::{Project, Session, SessionRequest};
|
||||
use crate::infrastructure::db::DbPool;
|
||||
use crate::infrastructure::inference::InferenceEngine;
|
||||
|
|
@ -83,38 +84,37 @@ impl StartupService {
|
|||
let prompt_preview: String = first_prompt.chars().take(100).collect();
|
||||
let naming_prompt = session_naming_prompt(self.engine.get_type(), &prompt_preview);
|
||||
|
||||
let chain = Chain::new();
|
||||
let session_name =
|
||||
match self
|
||||
.engine
|
||||
.generate("", &naming_prompt, 15, &[], &chain, &[], None)
|
||||
{
|
||||
Ok(raw) => {
|
||||
log::debug!("Raw session name response: {:?}", raw.summary);
|
||||
// Clean up the response
|
||||
let cleaned = raw
|
||||
.summary
|
||||
.lines()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.trim()
|
||||
.trim_matches('"')
|
||||
.trim_matches('*')
|
||||
.replace("<|im_end|>", "")
|
||||
.trim()
|
||||
.to_string();
|
||||
let mut chain = Chain::new();
|
||||
chain
|
||||
.steps
|
||||
.push(ChainStep::user_message(naming_prompt.parse().unwrap(), Vec::new()));
|
||||
let session_name = match self.engine.generate(&[], &chain, &[], None) {
|
||||
Ok(raw) => {
|
||||
log::debug!("Raw session name response: {:?}", raw.summary);
|
||||
// Clean up the response
|
||||
let cleaned = raw
|
||||
.summary
|
||||
.lines()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.trim()
|
||||
.trim_matches('"')
|
||||
.trim_matches('*')
|
||||
.replace("<|im_end|>", "")
|
||||
.trim()
|
||||
.to_string();
|
||||
|
||||
if cleaned.is_empty() || cleaned.len() > 50 {
|
||||
Self::fallback_session_name(first_prompt)
|
||||
} else {
|
||||
cleaned
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Session naming failed: {}", e);
|
||||
if cleaned.is_empty() || cleaned.len() > 50 {
|
||||
Self::fallback_session_name(first_prompt)
|
||||
} else {
|
||||
cleaned
|
||||
}
|
||||
};
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Session naming failed: {}", e);
|
||||
Self::fallback_session_name(first_prompt)
|
||||
}
|
||||
};
|
||||
|
||||
// Create session in database
|
||||
let conn = self
|
||||
|
|
@ -225,4 +225,4 @@ mod tests {
|
|||
|
||||
assert_eq!(StartupService::fallback_session_name(""), "New Session");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -15,6 +15,7 @@ pub struct DiscoverObjects {
|
|||
struct DiscoverObjectsInput {
|
||||
raw: String,
|
||||
full_path_to_file: String,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
@ -69,6 +70,7 @@ impl DiscoverObjects {
|
|||
Ok(DiscoverObjectsInput {
|
||||
raw: raw.to_string(),
|
||||
full_path_to_file: parsed.full_path_to_file,
|
||||
call_id: String::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -86,12 +88,13 @@ impl Tool for DiscoverObjects {
|
|||
"discover_objects"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||
let trimmed = input.trim();
|
||||
let parsed = Self::parse_input_json(trimmed);
|
||||
|
||||
match parsed {
|
||||
Ok(parsed) => {
|
||||
Ok(mut parsed) => {
|
||||
parsed.call_id = call_id;
|
||||
*self.input.lock().unwrap() = Some(parsed);
|
||||
None
|
||||
}
|
||||
|
|
@ -103,17 +106,18 @@ impl Tool for DiscoverObjects {
|
|||
let input = match self.load_input() {
|
||||
Ok(input) => input,
|
||||
Err(e) => {
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||
}
|
||||
};
|
||||
|
||||
match Self::parse_file(&input.full_path_to_file) {
|
||||
Ok((lang, objects)) => ToolResult::ok(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
Self::format_output(lang, &objects),
|
||||
input.call_id,
|
||||
),
|
||||
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()),
|
||||
Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ struct FindFilesInput {
|
|||
query: String,
|
||||
root: Option<String>,
|
||||
max_results: usize,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
@ -132,6 +133,7 @@ impl FindFiles {
|
|||
query: parsed.query,
|
||||
root: parsed.root,
|
||||
max_results: parsed.max_results.unwrap_or(20),
|
||||
call_id: String::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -149,12 +151,13 @@ impl Tool for FindFiles {
|
|||
"find_files"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||
let trimmed = input.trim();
|
||||
let parsed = Self::parse_input_json(trimmed);
|
||||
|
||||
match parsed {
|
||||
Ok(parsed) => {
|
||||
Ok(mut parsed) => {
|
||||
parsed.call_id = call_id;
|
||||
*self.input.lock().unwrap() = Some(parsed);
|
||||
None
|
||||
}
|
||||
|
|
@ -166,7 +169,7 @@ impl Tool for FindFiles {
|
|||
let input = match self.load_input() {
|
||||
Ok(input) => input,
|
||||
Err(e) => {
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -174,15 +177,16 @@ impl Tool for FindFiles {
|
|||
let search_root =
|
||||
match Self::resolve_search_root(input.root.as_deref(), request.project_root()) {
|
||||
Ok(root) => root,
|
||||
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e),
|
||||
Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
|
||||
};
|
||||
|
||||
// Check if search root exists
|
||||
if !search_root.exists() {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Search root does not exist: {}", search_root.display()),
|
||||
input.call_id,
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -204,7 +208,7 @@ impl Tool for FindFiles {
|
|||
output
|
||||
};
|
||||
|
||||
ToolResult::ok(self.name().to_string(), input.raw, output)
|
||||
ToolResult::ok(self.name().to_string(), input.raw, output, input.call_id)
|
||||
}
|
||||
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ pub struct FileChange {
|
|||
|
||||
pub struct ToolResult {
|
||||
tool_name: String,
|
||||
call_id: String,
|
||||
pub(crate) input: String,
|
||||
is_successful: bool,
|
||||
output: String,
|
||||
|
|
@ -53,7 +54,7 @@ pub struct ToolResult {
|
|||
}
|
||||
|
||||
impl ToolResult {
|
||||
pub fn ok(tool_name: String, input: String, output: String) -> Self {
|
||||
pub fn ok(tool_name: String, input: String, output: String, call_id: String) -> Self {
|
||||
Self {
|
||||
tool_name,
|
||||
input,
|
||||
|
|
@ -61,10 +62,11 @@ impl ToolResult {
|
|||
output,
|
||||
error_message: String::new(),
|
||||
file_changes: None,
|
||||
call_id
|
||||
}
|
||||
}
|
||||
|
||||
pub fn error(tool_name: String, input: String, message: String) -> Self {
|
||||
pub fn error(tool_name: String, input: String, message: String, call_id: String) -> Self {
|
||||
Self {
|
||||
tool_name,
|
||||
input,
|
||||
|
|
@ -72,6 +74,7 @@ impl ToolResult {
|
|||
output: String::new(),
|
||||
error_message: message.into(),
|
||||
file_changes: None,
|
||||
call_id
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -128,11 +131,15 @@ impl ToolResult {
|
|||
pub fn file_changes(&self) -> Option<&[FileChange]> {
|
||||
self.file_changes.as_deref()
|
||||
}
|
||||
|
||||
pub(crate) fn call_id(&self) -> String {
|
||||
self.call_id.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Tool: Send + Sync {
|
||||
fn name(&self) -> &'static str;
|
||||
fn parse_input(&self, input: String) -> Option<Error>;
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error>;
|
||||
fn work(&self, request: &dyn Request) -> ToolResult;
|
||||
fn parameters(&self) -> Value;
|
||||
fn desc(&self) -> String;
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ pub struct PatchFiles {
|
|||
struct PatchFilesInput {
|
||||
raw: String,
|
||||
patch: String,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
@ -41,6 +42,7 @@ impl PatchFiles {
|
|||
Ok(PatchFilesInput {
|
||||
raw: raw.to_string(),
|
||||
patch: parsed.patch,
|
||||
call_id: String::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -58,12 +60,13 @@ impl Tool for PatchFiles {
|
|||
"patch_files"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||
let trimmed = input.trim();
|
||||
let parsed = Self::parse_input_json(trimmed);
|
||||
|
||||
match parsed {
|
||||
Ok(parsed) => {
|
||||
Ok(mut parsed) => {
|
||||
parsed.call_id = call_id;
|
||||
*self.input.lock().unwrap() = Some(parsed);
|
||||
None
|
||||
}
|
||||
|
|
@ -75,7 +78,7 @@ impl Tool for PatchFiles {
|
|||
let input = match self.load_input() {
|
||||
Ok(input) => input,
|
||||
Err(e) => {
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -84,8 +87,9 @@ impl Tool for PatchFiles {
|
|||
Err(e) => {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Failed to parse patch: {}", e),
|
||||
input.call_id.clone(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
@ -115,8 +119,9 @@ impl Tool for PatchFiles {
|
|||
{
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Invalid path outside project root: {}", path),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
let fs_path = project_root.join(rel_path);
|
||||
|
|
@ -129,8 +134,9 @@ impl Tool for PatchFiles {
|
|||
Err(e) => {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Failed to read file '{}': {}", fs_path.display(), e),
|
||||
input.call_id.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -142,8 +148,9 @@ impl Tool for PatchFiles {
|
|||
Err(e) => {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Failed to apply patch: {}", e),
|
||||
input.call_id.clone(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
@ -155,28 +162,31 @@ impl Tool for PatchFiles {
|
|||
if let Err(e) = fs::create_dir_all(parent) {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!(
|
||||
"Failed to create parent directories for '{}': {}",
|
||||
fs_path.display(),
|
||||
e
|
||||
),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Err(e) = fs::write(&fs_path, content) {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Failed to write file '{}': {}", fs_path.display(), e),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
} else if fs_path.exists() {
|
||||
if let Err(e) = fs::remove_file(&fs_path) {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Failed to delete file '{}': {}", fs_path.display(), e),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -224,6 +234,7 @@ impl Tool for PatchFiles {
|
|||
self.name().to_string(),
|
||||
input.raw,
|
||||
"Patch applied successfully".to_string(),
|
||||
input.call_id,
|
||||
);
|
||||
|
||||
if !file_changes.is_empty() {
|
||||
|
|
@ -531,7 +542,7 @@ mod tests {
|
|||
"{{\"patch\":\"{}\"}}",
|
||||
patch.replace('\n', "\\n").replace('"', "\\\"")
|
||||
);
|
||||
assert!(tool.parse_input(input).is_none());
|
||||
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
assert!(
|
||||
|
|
@ -571,7 +582,7 @@ mod tests {
|
|||
"{{\"patch\":\"{}\"}}",
|
||||
patch.replace('\n', "\\n").replace('"', "\\\"")
|
||||
);
|
||||
assert!(tool.parse_input(input).is_none());
|
||||
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
assert!(
|
||||
|
|
@ -592,7 +603,7 @@ mod tests {
|
|||
|
||||
let tool = PatchFiles::new();
|
||||
let input = "{\"patch\":\"not a patch\"}".to_string();
|
||||
assert!(tool.parse_input(input).is_none());
|
||||
assert!(tool.parse_input(input, "call-id".to_string()).is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
assert!(
|
||||
|
|
@ -601,4 +612,4 @@ mod tests {
|
|||
result.output_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -18,6 +18,7 @@ struct ReadObjectsInput {
|
|||
raw: String,
|
||||
full_path_to_file: String,
|
||||
queries: Vec<ObjectQuery>,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
@ -138,6 +139,7 @@ impl ReadObjects {
|
|||
raw: raw.to_string(),
|
||||
full_path_to_file: trimmed_path.to_string(),
|
||||
queries,
|
||||
call_id: String::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -182,12 +184,13 @@ impl Tool for ReadObjects {
|
|||
"read_objects"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||
let trimmed = input.trim();
|
||||
let parsed = Self::parse_input_json(trimmed);
|
||||
|
||||
match parsed {
|
||||
Ok(parsed) => {
|
||||
Ok(mut parsed) => {
|
||||
parsed.call_id = call_id;
|
||||
*self.input.lock().unwrap() = Some(parsed);
|
||||
None
|
||||
}
|
||||
|
|
@ -199,17 +202,18 @@ impl Tool for ReadObjects {
|
|||
let input = match self.load_input() {
|
||||
Ok(input) => input,
|
||||
Err(e) => {
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||
}
|
||||
};
|
||||
|
||||
match Self::read_objects(&input.full_path_to_file, &input.queries) {
|
||||
Ok((lang, results)) => ToolResult::ok(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
Self::format_output(lang, results),
|
||||
input.call_id,
|
||||
),
|
||||
Err(e) => ToolResult::error(self.name().to_string(), input.raw, e.to_string()),
|
||||
Err(e) => ToolResult::error(self.name().to_string(), input.raw.clone(), e.to_string(), input.call_id),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -224,7 +228,7 @@ impl Tool for ReadObjects {
|
|||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\")",
|
||||
"description": "comma- or space-separated object names (e.g., \"main\", \"Config\", or \"main, Config, Parser\"), grepping is not working here, use exact names",
|
||||
"minLength": 1
|
||||
}
|
||||
},
|
||||
|
|
@ -235,8 +239,8 @@ impl Tool for ReadObjects {
|
|||
|
||||
fn desc(&self) -> String {
|
||||
format!(
|
||||
"Use the `{}` tool to read source code of specific objects from a file. To determine correct properties to use for `{}`, use the `discover_objects` tool first.",
|
||||
self.name(), self.name()
|
||||
"Use the `{}` tool to read source code of specific objects from a file. To determine correct properties, use the `discover_objects` tool first.",
|
||||
self.name(),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
@ -293,7 +297,7 @@ mod tests {
|
|||
fn test_parse_input_valid() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "src/main.rs", "query": "main, Config"}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(result.is_none(), "Expected no error, got: {:?}", result);
|
||||
|
||||
// Verify the parsed input
|
||||
|
|
@ -308,7 +312,7 @@ mod tests {
|
|||
fn test_parse_input_empty_query() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "src/main.rs", "query": ""}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(result.is_some(), "Expected error for empty query");
|
||||
|
||||
if let Some(Error::Parse(msg)) = result {
|
||||
|
|
@ -326,7 +330,7 @@ mod tests {
|
|||
fn test_parse_input_whitespace_query() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "src/main.rs", "query": " "}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(result.is_some(), "Expected error for whitespace query");
|
||||
|
||||
if let Some(Error::Parse(msg)) = result {
|
||||
|
|
@ -344,7 +348,7 @@ mod tests {
|
|||
fn test_parse_input_comma_only_query() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "src/main.rs", "query": ",,,"}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(result.is_some(), "Expected error for comma-only query");
|
||||
|
||||
if let Some(Error::Parse(msg)) = result {
|
||||
|
|
@ -362,7 +366,7 @@ mod tests {
|
|||
fn test_parse_input_missing_path() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "", "query": "main"}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(result.is_some(), "Expected error for empty path");
|
||||
|
||||
if let Some(Error::Parse(msg)) = result {
|
||||
|
|
@ -407,7 +411,7 @@ mod tests {
|
|||
fn test_parse_input_malformed_json() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "src/main.rs", "query": "main"#; // Missing closing brace
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(result.is_some(), "Expected error for malformed JSON");
|
||||
}
|
||||
|
||||
|
|
@ -415,7 +419,7 @@ mod tests {
|
|||
fn test_parse_input_path_with_whitespace() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": " src/main.rs ", "query": "main"}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(
|
||||
result.is_none(),
|
||||
"Expected no error for path with whitespace"
|
||||
|
|
@ -432,7 +436,7 @@ mod tests {
|
|||
fn test_parse_input_query_with_extra_whitespace() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "src/main.rs", "query": " main , Config "}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(
|
||||
result.is_none(),
|
||||
"Expected no error for query with extra whitespace"
|
||||
|
|
@ -448,7 +452,7 @@ mod tests {
|
|||
fn test_parse_input_single_object() {
|
||||
let tool = ReadObjects::new();
|
||||
let input = r#"{"path": "src/lib.rs", "query": "MyStruct"}"#;
|
||||
let result = tool.parse_input(input.to_string());
|
||||
let result = tool.parse_input(input.to_string(), "call-id".to_string());
|
||||
assert!(
|
||||
result.is_none(),
|
||||
"Expected no error for single object query"
|
||||
|
|
@ -481,4 +485,4 @@ mod tests {
|
|||
};
|
||||
assert!(query.matches(&obj_partial), "Should match partial name");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -25,6 +25,7 @@ struct ShellExecInputParsed {
|
|||
raw: String,
|
||||
command: String,
|
||||
working_dir: Option<String>,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
impl Tool for ShellExec {
|
||||
|
|
@ -32,7 +33,7 @@ impl Tool for ShellExec {
|
|||
"shell_exec"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||
let trimmed = input.trim();
|
||||
let parsed = serde_json::from_str::<ShellExecInput>(trimmed)
|
||||
.map_err(|e| Error::Parse(e.to_string()));
|
||||
|
|
@ -46,6 +47,7 @@ impl Tool for ShellExec {
|
|||
raw: trimmed.to_string(),
|
||||
command: parsed.command,
|
||||
working_dir: parsed.working_dir,
|
||||
call_id,
|
||||
});
|
||||
None
|
||||
}
|
||||
|
|
@ -61,6 +63,7 @@ impl Tool for ShellExec {
|
|||
self.name().to_string(),
|
||||
String::new(),
|
||||
"input not parsed".to_string(),
|
||||
String::new(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
@ -72,15 +75,17 @@ impl Tool for ShellExec {
|
|||
if !path.exists() {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Working directory does not exist: {}", dir),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
if !crate::utils::paths::is_within_root(path, request.project_root()) {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
"Working directory is outside project root".to_string(),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
path.to_path_buf()
|
||||
|
|
@ -99,8 +104,9 @@ impl Tool for ShellExec {
|
|||
Err(e) => {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Failed to execute command: {}", e),
|
||||
input.call_id.clone(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
@ -129,10 +135,10 @@ impl Tool for ShellExec {
|
|||
if !stderr.is_empty() {
|
||||
result.push_str(&format!("[stderr]: {}", stderr));
|
||||
}
|
||||
return ToolResult::error(self.name().to_string(), input.raw, result);
|
||||
return ToolResult::error(self.name().to_string(), input.raw.clone(), result, input.call_id.clone());
|
||||
};
|
||||
|
||||
ToolResult::ok(self.name().to_string(), input.raw, result)
|
||||
ToolResult::ok(self.name().to_string(), input.raw, result, input.call_id)
|
||||
}
|
||||
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
|
|
@ -155,8 +161,9 @@ impl Tool for ShellExec {
|
|||
|
||||
fn desc(&self) -> String {
|
||||
format!(
|
||||
r#"Use `{}` tool to execute shell commands.
|
||||
Please DO NOT use it to read the full content of a file, this is not efficient, use `read_objects` tool for this."#,
|
||||
"Use `{}` tool to execute shell commands.
|
||||
DO NOT use it to read code, this is not efficient, use `read_objects` tool for this. \n\
|
||||
DO NOT use it to change files, use `patch_files` tool for this.",
|
||||
self.name()
|
||||
)
|
||||
}
|
||||
|
|
@ -339,7 +346,7 @@ mod tests {
|
|||
|
||||
let tool = ShellExec::new();
|
||||
assert!(tool
|
||||
.parse_input(r#"{"command":"echo hello"}"#.to_string())
|
||||
.parse_input(r#"{"command":"echo hello"}"#.to_string(), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -357,7 +364,7 @@ mod tests {
|
|||
|
||||
let tool = ShellExec::new();
|
||||
assert!(tool
|
||||
.parse_input(r#"{"command":"pwd"}"#.to_string())
|
||||
.parse_input(r#"{"command":"pwd"}"#.to_string(), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -384,7 +391,7 @@ mod tests {
|
|||
.parse_input(format!(
|
||||
r#"{{"command":"pwd","working_dir":"{}"}}"#,
|
||||
subdir.display()
|
||||
))
|
||||
), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -405,7 +412,7 @@ mod tests {
|
|||
|
||||
let tool = ShellExec::new();
|
||||
assert!(tool
|
||||
.parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string())
|
||||
.parse_input(r#"{"command":"pwd","working_dir":"/tmp"}"#.to_string(), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -424,7 +431,7 @@ mod tests {
|
|||
|
||||
let tool = ShellExec::new();
|
||||
assert!(tool
|
||||
.parse_input(r#"{"command":"exit 1"}"#.to_string())
|
||||
.parse_input(r#"{"command":"exit 1"}"#.to_string(), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -443,7 +450,7 @@ mod tests {
|
|||
|
||||
let tool = ShellExec::new();
|
||||
assert!(tool
|
||||
.parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string())
|
||||
.parse_input(r#"{"command":"echo error >&2 && exit 1"}"#.to_string(), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -457,7 +464,7 @@ mod tests {
|
|||
let request = TestRequest::new(temp.path());
|
||||
|
||||
let tool = ShellExec::new();
|
||||
let err = tool.parse_input(r#"{"command":""}"#.to_string());
|
||||
let err = tool.parse_input(r#"{"command":""}"#.to_string(), "call-id".to_string());
|
||||
assert!(err.is_some());
|
||||
let result = tool.work(&request);
|
||||
assert!(result.output_string().contains("Error"));
|
||||
|
|
@ -474,7 +481,7 @@ mod tests {
|
|||
.parse_input(format!(
|
||||
r#"{{"command":"echo 'test content' > {}"}}"#,
|
||||
file_path.display()
|
||||
))
|
||||
), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -496,7 +503,7 @@ mod tests {
|
|||
|
||||
let tool = ShellExec::new();
|
||||
assert!(tool
|
||||
.parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string())
|
||||
.parse_input(r#"{"command":"echo 'hello world' | tr 'a-z' 'A-Z'"}"#.to_string(), "call-id".to_string())
|
||||
.is_none());
|
||||
let result = tool.work(&request);
|
||||
|
||||
|
|
@ -506,4 +513,4 @@ mod tests {
|
|||
result.output_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -16,6 +16,7 @@ struct StructureInput {
|
|||
raw: String,
|
||||
path: String,
|
||||
max_depth: usize,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
@ -127,6 +128,7 @@ impl Structure {
|
|||
raw: raw.to_string(),
|
||||
path: parsed.path.unwrap_or_else(|| ".".to_string()),
|
||||
max_depth: parsed.max_depth.unwrap_or(3),
|
||||
call_id: String::new(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -144,12 +146,13 @@ impl Tool for Structure {
|
|||
"structure"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||
let trimmed = input.trim();
|
||||
let parsed = Self::parse_input_json(trimmed);
|
||||
|
||||
match parsed {
|
||||
Ok(parsed) => {
|
||||
Ok(mut parsed) => {
|
||||
parsed.call_id = call_id;
|
||||
*self.input.lock().unwrap() = Some(parsed);
|
||||
None
|
||||
}
|
||||
|
|
@ -161,34 +164,36 @@ impl Tool for Structure {
|
|||
let input = match self.load_input() {
|
||||
Ok(input) => input,
|
||||
Err(e) => {
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string())
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e.to_string(), String::new())
|
||||
}
|
||||
};
|
||||
|
||||
let target_path = match Self::resolve_path(&input.path, request.project_root()) {
|
||||
Ok(p) => p,
|
||||
Err(e) => return ToolResult::error(self.name().to_string(), input.raw, e),
|
||||
Err(e) => return ToolResult::error(self.name().to_string(), input.raw.clone(), e, input.call_id.clone()),
|
||||
};
|
||||
|
||||
if !target_path.exists() {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Path does not exist: {}", target_path.display()),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
if !target_path.is_dir() {
|
||||
return ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw,
|
||||
input.raw.clone(),
|
||||
format!("Path is not a directory: {}", target_path.display()),
|
||||
input.call_id,
|
||||
);
|
||||
}
|
||||
|
||||
let tree = Self::build_tree(&target_path, input.max_depth);
|
||||
|
||||
ToolResult::ok(self.name().to_string(), input.raw, tree)
|
||||
ToolResult::ok(self.name().to_string(), input.raw, tree, input.call_id)
|
||||
}
|
||||
|
||||
fn parameters(&self) -> serde_json::Value {
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ pub struct UpdateTodoList {
|
|||
#[derive(Debug, Clone)]
|
||||
struct UpdateTodoListInput {
|
||||
raw: String,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
impl UpdateTodoList {
|
||||
|
|
@ -36,12 +37,12 @@ impl Tool for UpdateTodoList {
|
|||
"update_todo_list"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<Error> {
|
||||
let parsed: Result<TodoList, _> = serde_json::from_str(&input);
|
||||
match parsed {
|
||||
Ok(_) => {
|
||||
let mut lock = self.input.lock().unwrap();
|
||||
*lock = Some(UpdateTodoListInput { raw: input });
|
||||
*lock = Some(UpdateTodoListInput { raw: input, call_id });
|
||||
None
|
||||
}
|
||||
Err(e) => Some(Error::Parse(format!("Failed to parse input: {}", e))),
|
||||
|
|
@ -57,6 +58,7 @@ impl Tool for UpdateTodoList {
|
|||
self.name().to_string(),
|
||||
String::new(),
|
||||
"No input provided".to_string(),
|
||||
String::new(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
@ -70,6 +72,7 @@ impl Tool for UpdateTodoList {
|
|||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
format!("Failed to get database connection: {}", e),
|
||||
input.call_id.clone(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
@ -83,6 +86,7 @@ impl Tool for UpdateTodoList {
|
|||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
format!("Failed to get/create TODO list: {}", e),
|
||||
input.call_id.clone(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
|
@ -92,6 +96,7 @@ impl Tool for UpdateTodoList {
|
|||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
format!("Failed to update TODO list: {}", e),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -106,6 +111,7 @@ impl Tool for UpdateTodoList {
|
|||
self.name().to_string(),
|
||||
input.raw,
|
||||
"TODO list updated successfully.".to_string(),
|
||||
input.call_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ struct WebSearchInput {
|
|||
raw: String,
|
||||
query: String,
|
||||
max_results: u32,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
@ -94,7 +95,7 @@ impl Tool for WebSearch {
|
|||
"web_search"
|
||||
}
|
||||
|
||||
fn parse_input(&self, input: String) -> Option<crate::domain::tools::Error> {
|
||||
fn parse_input(&self, input: String, call_id: String) -> Option<crate::domain::tools::Error> {
|
||||
let trimmed = input.trim();
|
||||
match serde_json::from_str::<WebSearchInputJson>(trimmed) {
|
||||
Ok(parsed) => {
|
||||
|
|
@ -103,6 +104,7 @@ impl Tool for WebSearch {
|
|||
raw: input,
|
||||
query: parsed.query,
|
||||
max_results,
|
||||
call_id,
|
||||
};
|
||||
*self.input.lock().unwrap() = Some(parsed_input);
|
||||
None
|
||||
|
|
@ -118,7 +120,7 @@ impl Tool for WebSearch {
|
|||
let input = match self.load_input() {
|
||||
Ok(input) => input,
|
||||
Err(e) => {
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e);
|
||||
return ToolResult::error(self.name().to_string(), String::new(), e, String::new());
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -130,6 +132,7 @@ impl Tool for WebSearch {
|
|||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
"User settings not available".to_string(),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
|
@ -139,6 +142,7 @@ impl Tool for WebSearch {
|
|||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
"Web search is not enabled in settings".to_string(),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -149,6 +153,7 @@ impl Tool for WebSearch {
|
|||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
"Brave API key not configured".to_string(),
|
||||
input.call_id.clone(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
|
@ -159,12 +164,13 @@ impl Tool for WebSearch {
|
|||
let output = WebSearchOutput { results };
|
||||
match serde_json::to_string(&output) {
|
||||
Ok(json_output) => {
|
||||
ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output)
|
||||
ToolResult::ok(self.name().to_string(), input.raw.clone(), json_output, input.call_id)
|
||||
}
|
||||
Err(e) => ToolResult::error(
|
||||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
format!("Failed to serialize output: {}", e),
|
||||
input.call_id.clone(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
|
@ -172,6 +178,7 @@ impl Tool for WebSearch {
|
|||
self.name().to_string(),
|
||||
input.raw.clone(),
|
||||
format!("Search failed: {}", e),
|
||||
input.call_id,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
|
@ -593,8 +600,10 @@ mod tests {
|
|||
user_settings,
|
||||
};
|
||||
let web_search = WebSearch::new();
|
||||
let _ =
|
||||
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string());
|
||||
let _ = web_search.parse_input(
|
||||
r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
|
||||
"call-id".to_string(),
|
||||
);
|
||||
|
||||
let _allowed = checker.check(&web_search, &request, Some(1)).unwrap();
|
||||
|
||||
|
|
@ -629,8 +638,10 @@ mod tests {
|
|||
user_settings,
|
||||
};
|
||||
let web_search = WebSearch::new();
|
||||
let _ =
|
||||
web_search.parse_input(r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string());
|
||||
let _ = web_search.parse_input(
|
||||
r#"{"query":"site:docs.rs serde","max_results":5}"#.to_string(),
|
||||
"call-id".to_string(),
|
||||
);
|
||||
|
||||
let allowed = checker.check(&web_search, &request, Some(1)).unwrap();
|
||||
|
||||
|
|
@ -640,4 +651,4 @@ mod tests {
|
|||
assert_eq!(created[0].tool_name, "web_search");
|
||||
assert_eq!(created[0].resource_pattern, None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ pub struct Chain {
|
|||
pub fail_reason: String,
|
||||
#[serde(default)]
|
||||
pub final_message: Option<String>,
|
||||
pub system_prompt: String,
|
||||
}
|
||||
|
||||
impl Chain {
|
||||
|
|
@ -28,6 +29,7 @@ impl Chain {
|
|||
is_failed: false,
|
||||
fail_reason: String::new(),
|
||||
final_message: None,
|
||||
system_prompt: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -120,6 +122,10 @@ impl Chain {
|
|||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub fn set_system_prompt(&mut self, system_prompt: String) {
|
||||
self.system_prompt = system_prompt;
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn total_payload_len_chars(&self) -> usize {
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ pub struct ChainStep {
|
|||
#[serde(default)]
|
||||
pub tool_name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub call_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub tool_output: Option<String>,
|
||||
#[serde(default)]
|
||||
pub is_successful: Option<bool>,
|
||||
|
|
@ -60,6 +62,7 @@ impl ChainStep {
|
|||
let mut tool_output = None;
|
||||
let mut is_successful = None;
|
||||
let mut file_changes = None;
|
||||
let mut call_id = None;
|
||||
if let Some(tr) = tool_result {
|
||||
summary = tr.summary();
|
||||
context_payload = tr.output_string();
|
||||
|
|
@ -68,6 +71,7 @@ impl ChainStep {
|
|||
tool_output = Some(tr.output_raw().to_string());
|
||||
is_successful = Some(tr.is_successful());
|
||||
file_changes = tr.file_changes().map(|fc| fc.to_vec());
|
||||
call_id = Some(tr.call_id());
|
||||
}
|
||||
|
||||
Self {
|
||||
|
|
@ -76,6 +80,7 @@ impl ChainStep {
|
|||
context_payload,
|
||||
input_payload,
|
||||
tool_name,
|
||||
call_id,
|
||||
tool_output,
|
||||
is_successful,
|
||||
file_changes,
|
||||
|
|
@ -91,6 +96,7 @@ impl ChainStep {
|
|||
context_payload: raw_output.clone(),
|
||||
input_payload: String::new(),
|
||||
tool_name: None,
|
||||
call_id: None,
|
||||
tool_output: Some(raw_output),
|
||||
is_successful: Some(true),
|
||||
file_changes: None,
|
||||
|
|
@ -118,6 +124,7 @@ impl ChainStep {
|
|||
context_payload: prompt.clone(),
|
||||
input_payload: prompt,
|
||||
tool_name: None,
|
||||
call_id: None,
|
||||
tool_output: None,
|
||||
is_successful: Some(true),
|
||||
file_changes: None,
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ impl ToolRunner {
|
|||
tool_call.name.clone(),
|
||||
tool_call.arguments.clone(),
|
||||
error_msg,
|
||||
tool_call.call_id.clone(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
|
@ -73,11 +74,13 @@ impl ToolRunner {
|
|||
tool.name().to_string(),
|
||||
String::new(),
|
||||
"Permission denied".to_string(),
|
||||
String::new(),
|
||||
),
|
||||
Err(err) => ToolResult::error(
|
||||
tool.name().to_string(),
|
||||
String::new(),
|
||||
format!("Permission check error: {}", err),
|
||||
String::new(),
|
||||
),
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ pub trait Toolset {
|
|||
.map(|t| t.as_ref())
|
||||
.ok_or_else(|| Error::Parse(format!("Tool not found: {}", tool_call.name)))?;
|
||||
|
||||
if let Some(err) = tool.parse_input(tool_call.arguments.clone()) {
|
||||
if let Some(err) = tool.parse_input(tool_call.arguments.clone(), tool_call.call_id.clone()) {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use super::{
|
|||
use crate::domain::bt::GeneralTree;
|
||||
use crate::domain::prompting;
|
||||
use crate::domain::session::Request;
|
||||
use crate::domain::todo::TodoListStatus;
|
||||
use crate::infrastructure::db::DbPool;
|
||||
use crate::infrastructure::event_bus::{AgentToUiEvent, StepPhase};
|
||||
use crate::infrastructure::inference::InferenceEngine;
|
||||
|
|
@ -16,7 +17,6 @@ use std::sync::Arc;
|
|||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use crossbeam_channel::Sender;
|
||||
use crate::domain::AgentModeType;
|
||||
use crate::domain::todo::TodoListStatus;
|
||||
|
||||
/// Main workflow orchestrator that runs LLM-driven coding tasks
|
||||
/// Implements an eternal agent loop that:
|
||||
|
|
@ -94,7 +94,6 @@ impl Workflow {
|
|||
self.chain.set_todo_list(request.get_session_plan());
|
||||
|
||||
let result = self._run(request, cancel, None, None, mode);
|
||||
self.chain.steps.push(ChainStep::user_message(request.current_request().to_string(), request.images().to_vec()));
|
||||
self._end_tracing();
|
||||
result
|
||||
}
|
||||
|
|
@ -145,6 +144,7 @@ impl Workflow {
|
|||
context_payload: String::new(),
|
||||
input_payload: step_user_prompt.to_string(),
|
||||
tool_name: None,
|
||||
call_id: None,
|
||||
tool_output: None,
|
||||
is_successful: Some(true),
|
||||
file_changes: None,
|
||||
|
|
@ -171,6 +171,13 @@ impl Workflow {
|
|||
mode: AgentModeType,
|
||||
) -> Result<(), Error> {
|
||||
|
||||
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request);
|
||||
let user_prompt = user_prompt_override
|
||||
.as_deref()
|
||||
.unwrap_or(&base_user_prompt)
|
||||
.to_string();
|
||||
self.chain.steps.push(ChainStep::user_message(user_prompt.clone(), request.images().to_vec()));
|
||||
|
||||
// Get max_tool_calls from override (for BT mode) or user settings
|
||||
let max_tool_calls = max_tool_calls_override.unwrap_or_else(|| {
|
||||
request.user_settings()
|
||||
|
|
@ -198,11 +205,7 @@ impl Workflow {
|
|||
|
||||
// Get base system prompt and inject remaining count
|
||||
let system_prompt = prompting::get_system_prompt(self.engine.get_type(), request.mode(), remaining_calls);
|
||||
let base_user_prompt = prompting::get_user_prompt(self.engine.get_type(), request);
|
||||
let user_prompt = user_prompt_override
|
||||
.as_deref()
|
||||
.unwrap_or(&base_user_prompt)
|
||||
.to_string();
|
||||
self.chain.set_system_prompt(system_prompt.clone());
|
||||
|
||||
// Switch to finishing toolset if approaching limit
|
||||
if !in_finishing_mode && tool_call_count >= finishing_threshold {
|
||||
|
|
@ -220,13 +223,9 @@ impl Workflow {
|
|||
}
|
||||
|
||||
self._emit_inference_progress();
|
||||
self._trace_llm_start(user_prompt.clone());
|
||||
|
||||
// Ask LLM to choose next tool
|
||||
let llm_output = match self.engine.generate(
|
||||
&system_prompt,
|
||||
&user_prompt,
|
||||
1024,
|
||||
&self.toolset.tool_refs(),
|
||||
&self.chain,
|
||||
request.images(),
|
||||
|
|
@ -238,8 +237,6 @@ impl Workflow {
|
|||
}
|
||||
};
|
||||
|
||||
self._trace_llm_end(llm_output.raw_output.clone());
|
||||
|
||||
// Always capture the assistant's response in the chain
|
||||
if !llm_output.raw_output.is_empty() {
|
||||
self.chain.steps.push(ChainStep::assistant_response(
|
||||
|
|
@ -347,13 +344,6 @@ impl Workflow {
|
|||
self.chain = Chain::new();
|
||||
}
|
||||
|
||||
fn _trace_llm_start(&mut self, prompt: String) {
|
||||
if let Some(tracer) = &mut self.tracer {
|
||||
tracer.start_span("LLM generation", SpanKind::Generation);
|
||||
tracer.add_input("LLM generation", prompt);
|
||||
}
|
||||
}
|
||||
|
||||
fn _emit_inference_progress(&self) {
|
||||
let options = [
|
||||
"Thinking.. well kinda..",
|
||||
|
|
@ -383,18 +373,10 @@ impl Workflow {
|
|||
});
|
||||
}
|
||||
|
||||
fn _trace_llm_end(&mut self, output: String) {
|
||||
if let Some(tracer) = &mut self.tracer {
|
||||
tracer.add_output("LLM generation", output);
|
||||
tracer.end_span("LLM generation");
|
||||
}
|
||||
}
|
||||
|
||||
/// Detects if the active TODO item has changed
|
||||
/// Returns true if the first non-completed item's title is different
|
||||
fn _did_todo_item_change(&self, previous_todo: &Option<crate::domain::todo::TodoList>) -> bool {
|
||||
use crate::domain::todo::TodoListStatus;
|
||||
|
||||
|
||||
// Get first non-completed item from previous TODO list
|
||||
let prev_active = previous_todo.as_ref().and_then(|list| {
|
||||
list.items.iter()
|
||||
|
|
|
|||
|
|
@ -64,8 +64,6 @@ impl OpenAIClient {
|
|||
|
||||
pub fn call_responses_api(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
tools: &[&dyn crate::domain::tools::Tool],
|
||||
chain: &crate::domain::workflow::Chain,
|
||||
images: &[String],
|
||||
|
|
@ -74,8 +72,6 @@ impl OpenAIClient {
|
|||
let max_attempts = 3;
|
||||
for attempt in 1..=max_attempts {
|
||||
match self.call_responses_api_inner(
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
tools,
|
||||
chain,
|
||||
images,
|
||||
|
|
@ -121,25 +117,31 @@ impl OpenAIClient {
|
|||
|
||||
fn call_responses_api_inner(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
tools: &[&dyn crate::domain::tools::Tool],
|
||||
chain: &crate::domain::workflow::Chain,
|
||||
images: &[String],
|
||||
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||
mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||
) -> Result<LLMInferenceResult, Box<dyn Error + Send + Sync>> {
|
||||
let url = "https://api.openai.com/v1/responses";
|
||||
|
||||
let request_body = build_request_dto(
|
||||
&self.model,
|
||||
system_prompt,
|
||||
user_prompt,
|
||||
images,
|
||||
tools,
|
||||
chain,
|
||||
tracer,
|
||||
tracer.as_deref_mut(),
|
||||
);
|
||||
|
||||
// Start span with model name and add request as JSON
|
||||
if let Some(tracer) = &mut tracer {
|
||||
tracer.start_span(&self.model, openai_agents_tracing::SpanKind::Generation);
|
||||
|
||||
// Convert request_body to JSON Value and set as input
|
||||
if let Ok(request_json) = serde_json::to_value(&request_body) {
|
||||
tracer.set_input_json(&self.model, request_json);
|
||||
}
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(url)
|
||||
|
|
@ -151,18 +153,35 @@ impl OpenAIClient {
|
|||
let body = response.text()?;
|
||||
|
||||
if !status.is_success() {
|
||||
if let Some(t) = tracer {
|
||||
t.end_span(&self.model);
|
||||
}
|
||||
return Err(Box::new(OpenAIClientError::Api { status, body }));
|
||||
}
|
||||
|
||||
let dto = match serde_json::from_str::<ResponseDTO>(&body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => return Err(Box::new(OpenAIClientError::Deserialize { source: e, body })),
|
||||
Err(e) => {
|
||||
if let Some(t) = tracer {
|
||||
t.end_span(&self.model);
|
||||
}
|
||||
return Err(Box::new(OpenAIClientError::Deserialize { source: e, body }));
|
||||
}
|
||||
};
|
||||
|
||||
// Add response as JSON and end span
|
||||
if let Some(tracer) = &mut tracer {
|
||||
if let Ok(response_json) = serde_json::to_value(&dto) {
|
||||
tracer.set_output_json(&self.model, response_json);
|
||||
}
|
||||
tracer.end_span(&self.model);
|
||||
}
|
||||
|
||||
let result = build_llm_result(dto, tools);
|
||||
if result.summary.is_empty() && result.tool_call.is_none() {
|
||||
return Err(Box::new(OpenAIClientError::NoText { body }));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ use crate::domain::{Chain, ModelType};
|
|||
use crate::domain::prompting::format_todo_list_message;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use crate::domain::workflow::step::StepType::{AssistantResponse, ToolCall, UserMessage};
|
||||
use crate::domain::workflow::step::StepType;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RequestDTO {
|
||||
model: String,
|
||||
instructions: String,
|
||||
input: Vec<InputDto>,
|
||||
input: Vec<InputMessageDto>,
|
||||
tools: Vec<ToolDto>,
|
||||
tool_choice: String,
|
||||
parallel_tool_calls: bool,
|
||||
|
|
@ -18,12 +18,37 @@ pub struct RequestDTO {
|
|||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(super) struct InputDto {
|
||||
pub(super) struct MessageDto {
|
||||
content: Vec<InputContent>,
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
role: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(super) struct FunctionOutputDto {
|
||||
output: String,
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub(super) struct FunctionCallDto {
|
||||
arguments: String,
|
||||
name: String,
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
call_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum InputMessageDto {
|
||||
Message(MessageDto),
|
||||
FunctionOutput(FunctionOutputDto),
|
||||
FunctionCall(FunctionCallDto),
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
|
@ -63,10 +88,25 @@ impl InputContent {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn function(text: String) -> Self {
|
||||
Self::Text {
|
||||
kind: "function".to_string(),
|
||||
text,
|
||||
}
|
||||
|
||||
impl FunctionOutputDto {
|
||||
pub fn new(output: String, call_id: String) -> Self {
|
||||
Self {
|
||||
kind: "function_call_output".to_string(),
|
||||
output,
|
||||
call_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FunctionCallDto {
|
||||
pub fn new(name: String, arguments: String, call_id: String) -> Self {
|
||||
Self {
|
||||
name,
|
||||
kind: "function_call".to_string(),
|
||||
arguments,
|
||||
call_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -83,28 +123,23 @@ pub(super) struct ToolDto {
|
|||
const ROLE_USER: &str = "user";
|
||||
const ROLE_SYSTEM: &str = "system";
|
||||
const ROLE_ASSISTANT: &str = "assistant";
|
||||
const INPUT_STATUS_IN_PROGRESS: &str = "in_progress";
|
||||
const INPUT_STATUS_COMPLETED: &str = "completed";
|
||||
const INPUT_STATUS_FAILED: &str = "failed";
|
||||
|
||||
impl RequestDTO {
|
||||
pub(crate) fn new(
|
||||
model: String,
|
||||
system_prompt: String,
|
||||
user_prompt: String,
|
||||
tools: &[&dyn Tool],
|
||||
chain: &crate::domain::workflow::Chain,
|
||||
chain: &Chain,
|
||||
) -> Self {
|
||||
// User request is now part of the chain, no need to add separately
|
||||
let input = InputDto::build(user_prompt, chain);
|
||||
let input = MessageDto::build(chain);
|
||||
|
||||
Self {
|
||||
model,
|
||||
instructions: system_prompt,
|
||||
instructions: chain.system_prompt.clone(),
|
||||
input,
|
||||
tools: tools.iter().map(|tool| ToolDto::from_tool(*tool)).collect(),
|
||||
tool_choice: "auto".to_string(),
|
||||
parallel_tool_calls: true,
|
||||
parallel_tool_calls: false,
|
||||
store: false,
|
||||
stream: false,
|
||||
}
|
||||
|
|
@ -123,50 +158,46 @@ impl ToolDto {
|
|||
}
|
||||
}
|
||||
|
||||
impl InputDto {
|
||||
fn build(user_prompt: String, chain: &Chain) -> Vec<Self> {
|
||||
impl MessageDto {
|
||||
fn build(chain: &Chain) -> Vec<InputMessageDto> {
|
||||
let steps = chain.get_steps_with_history();
|
||||
|
||||
let mut result: Vec<Self> = steps
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, step)| {
|
||||
// Determine status
|
||||
let is_user_message = step.step_type == UserMessage.as_str();
|
||||
let mut result: Vec<InputMessageDto> = Vec::new();
|
||||
|
||||
let status = if is_user_message || step.is_successful.unwrap_or(false) {
|
||||
INPUT_STATUS_COMPLETED
|
||||
} else {
|
||||
INPUT_STATUS_FAILED
|
||||
};
|
||||
let mut role = ROLE_ASSISTANT.to_string();
|
||||
for step in steps.iter() {
|
||||
let is_user_message = step.step_type == StepType::UserMessage.as_str();
|
||||
|
||||
// Build content items
|
||||
let content_items = if is_user_message {
|
||||
role = ROLE_USER.to_string();
|
||||
// For user messages, include text and images
|
||||
if is_user_message {
|
||||
// User message: text + optional images
|
||||
let mut items = vec![InputContent::text(step.input_payload.clone())];
|
||||
|
||||
// Add image content items if present
|
||||
if let Some(ref images) = step.images {
|
||||
for image_url in images {
|
||||
items.push(InputContent::image(image_url.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
items
|
||||
} else {
|
||||
vec![InputContent::output_text(step.get_output(ModelType::OpenAI))]
|
||||
};
|
||||
result.push(InputMessageDto::Message(Self {
|
||||
content: items,
|
||||
role: Some(ROLE_USER.to_string()),
|
||||
kind: "message".to_string(),
|
||||
}));
|
||||
} else if step.step_type == StepType::ToolCall.as_str() {
|
||||
|
||||
Self {
|
||||
content: content_items,
|
||||
role,
|
||||
kind: "message".to_string(),
|
||||
status: status.to_string(),
|
||||
let tool_name = step.tool_name.clone().unwrap();
|
||||
let call_id = step.call_id.clone().unwrap();
|
||||
// Tool call output is a separate DTO type
|
||||
result.push(InputMessageDto::FunctionCall(FunctionCallDto::new(tool_name, step.input_payload.clone(), call_id.clone())));
|
||||
result.push(InputMessageDto::FunctionOutput(FunctionOutputDto::new(step.get_output(ModelType::OpenAI), call_id)));
|
||||
} else {
|
||||
// Assistant message
|
||||
result.push(InputMessageDto::Message(Self {
|
||||
content: vec![InputContent::output_text(step.get_output(ModelType::OpenAI))],
|
||||
role: Some(ROLE_ASSISTANT.to_string()),
|
||||
kind: "message".to_string(),
|
||||
}));
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
// Add the plan as system message at the beginning if it exists and is not completed
|
||||
if let Some(ref todo_list) = chain.todo_list {
|
||||
|
|
@ -178,22 +209,15 @@ impl InputDto {
|
|||
|
||||
let todo_input = Self {
|
||||
content: vec![InputContent::text(todo_message)],
|
||||
role: ROLE_SYSTEM.to_string(),
|
||||
role: Some(ROLE_SYSTEM.to_string()),
|
||||
kind: "message".to_string(),
|
||||
status: INPUT_STATUS_COMPLETED.to_string(),
|
||||
};
|
||||
result.push(todo_input);
|
||||
|
||||
// Put it first
|
||||
result.insert(0, InputMessageDto::Message(todo_input));
|
||||
}
|
||||
}
|
||||
|
||||
// and adding the current user message at the end
|
||||
result.push(Self {
|
||||
content: vec![InputContent::text(user_prompt)],
|
||||
role: ROLE_USER.to_string(),
|
||||
kind: "message".to_string(),
|
||||
status: INPUT_STATUS_IN_PROGRESS.to_string(),
|
||||
});
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,20 +1,21 @@
|
|||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ResponseDTO {
|
||||
output: Vec<OpenAIOutputItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OpenAIOutputItem {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
content: Option<Vec<OpenAIContentItem>>,
|
||||
name: Option<String>,
|
||||
arguments: Option<String>,
|
||||
call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct OpenAIContentItem {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
|
|
@ -41,11 +42,12 @@ impl ResponseDTO {
|
|||
}
|
||||
}
|
||||
} else if item.kind == "function_call" {
|
||||
if let (Some(name), Some(arguments)) = (item.name.as_ref(), item.arguments.as_ref())
|
||||
if let (Some(name), Some(arguments), Some(call_id)) = (item.name.as_ref(), item.arguments.as_ref(), item.call_id.as_ref())
|
||||
{
|
||||
tool_call = Some(FunctionCall {
|
||||
name: name.to_string(),
|
||||
arguments: arguments.to_string(),
|
||||
call_id: call_id.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
@ -59,4 +61,5 @@ impl ResponseDTO {
|
|||
pub(crate) struct FunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
pub call_id: String,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,8 +3,6 @@ use crate::infrastructure::inference::{LLMInferenceResult, ToolCall};
|
|||
|
||||
pub fn build_request_dto(
|
||||
model: &str,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
images: &[String],
|
||||
tools: &[&dyn crate::domain::tools::Tool],
|
||||
chain: &crate::domain::workflow::Chain,
|
||||
|
|
@ -25,8 +23,6 @@ pub fn build_request_dto(
|
|||
|
||||
let dto = RequestDTO::new(
|
||||
model.to_string(),
|
||||
system_prompt.to_string(),
|
||||
user_prompt.to_string(),
|
||||
tools,
|
||||
chain,
|
||||
);
|
||||
|
|
@ -54,9 +50,10 @@ pub fn build_llm_result(
|
|||
|
||||
let tool_call = if let Some(call) = tool_call_dto {
|
||||
// Create ToolCall for the workflow to use
|
||||
Some(ToolCall {
|
||||
Some(ToolCall{
|
||||
name: call.name.clone(),
|
||||
arguments: call.arguments.clone(),
|
||||
call_id: call.call_id.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
|
|
|||
|
|
@ -51,4 +51,4 @@ pub fn render(frame: &mut Frame, area: Rect, state: &UiState, theme: Theme) {
|
|||
let y = area.y.saturating_add(INPUT_PADDING.top + adjusted_line);
|
||||
frame.set_cursor(x, y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -159,6 +159,7 @@ fn handle_agent_event(
|
|||
});
|
||||
state.request_status = None;
|
||||
state.request_progress = None;
|
||||
state.last_progress_update = None; // Reset timer for new request
|
||||
state.file_changes = None;
|
||||
|
||||
// Display the user's message
|
||||
|
|
@ -175,7 +176,22 @@ fn handle_agent_event(
|
|||
summary,
|
||||
} => {
|
||||
if matches!(phase, StepPhase::Before) {
|
||||
state.request_progress = Some(summary.clone());
|
||||
// Check if enough time has passed since the last progress update
|
||||
let can_update = state
|
||||
.last_progress_update
|
||||
.map(|last| {
|
||||
std::time::Instant::now()
|
||||
.duration_since(last)
|
||||
.as_millis()
|
||||
>= crate::infrastructure::cli::state::MIN_PROGRESS_DISPLAY_MS
|
||||
})
|
||||
.unwrap_or(true); // Update immediately if no previous update
|
||||
|
||||
if can_update {
|
||||
state.request_progress = Some(summary.clone());
|
||||
state.last_progress_update = Some(std::time::Instant::now());
|
||||
}
|
||||
// Otherwise skip this update - too soon since last one
|
||||
} else {
|
||||
state.request_progress = None;
|
||||
}
|
||||
|
|
@ -207,6 +223,7 @@ fn handle_agent_event(
|
|||
});
|
||||
}
|
||||
state.request_progress = None;
|
||||
state.last_progress_update = None; // Reset timer when request finishes
|
||||
|
||||
// Update the user message color based on request status
|
||||
for entry in state.progress.iter_mut() {
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ impl AttachedImage {
|
|||
}
|
||||
}
|
||||
|
||||
pub const INPUT_MIN_HEIGHT: usize = 3;
|
||||
pub const INPUT_MAX_HEIGHT: usize = 6;
|
||||
pub const INPUT_MIN_HEIGHT: usize = 1;
|
||||
pub const INPUT_MAX_HEIGHT: usize = 5;
|
||||
pub const PROGRESS_HISTORY_LIMIT: usize = 200;
|
||||
pub const MAIN_BODY_SCROLL_STEP: usize = 3;
|
||||
|
||||
|
|
@ -51,6 +51,7 @@ pub enum ProgressKind {
|
|||
}
|
||||
|
||||
pub const REQUEST_STATUS_DISPLAY_DURATION: Duration = Duration::from_secs(2);
|
||||
pub const MIN_PROGRESS_DISPLAY_MS: u128 = 1000; // 1 second minimum display time
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
|
|
@ -260,6 +261,7 @@ pub struct UiState {
|
|||
pub request_in_flight: Option<RequestIndicator>,
|
||||
pub request_status: Option<RequestStatusDisplay>,
|
||||
pub request_progress: Option<String>,
|
||||
pub last_progress_update: Option<Instant>,
|
||||
pub file_changes: Option<FileChangesDisplay>,
|
||||
pub agent_mode: AgentModeType,
|
||||
pub todo_list: Option<TodoListDisplay>,
|
||||
|
|
@ -300,6 +302,7 @@ impl UiState {
|
|||
request_in_flight: None,
|
||||
request_status: None,
|
||||
request_progress: None,
|
||||
last_progress_update: None,
|
||||
file_changes: None,
|
||||
agent_mode: AgentModeType::Build, // Default to build mode
|
||||
todo_list: None,
|
||||
|
|
@ -353,4 +356,4 @@ impl UiState {
|
|||
Some(crate::domain::ModelType::OpenAI)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -31,4 +31,4 @@ impl Theme {
|
|||
}
|
||||
|
||||
pub const PANEL_PADDING: Padding = Padding::new(2, 2, 1, 1);
|
||||
pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1);
|
||||
pub const INPUT_PADDING: Padding = Padding::new(1, 1, 1, 1);
|
||||
|
|
@ -6,7 +6,7 @@ use crate::infrastructure::cli::helpers::{
|
|||
centered_rect, cursor_position, list_state, panel_block,
|
||||
};
|
||||
use crate::infrastructure::cli::state::{LoadStatus, PopupState, UiMode, UiState};
|
||||
use crate::infrastructure::cli::theme::{Theme, PANEL_PADDING};
|
||||
use crate::infrastructure::cli::theme::{Theme, INPUT_PADDING, PANEL_PADDING};
|
||||
use ratatui::layout::{Alignment, Constraint, Direction, Layout, Rect};
|
||||
use ratatui::style::{Modifier, Style};
|
||||
use ratatui::text::{Line, Span, Text};
|
||||
|
|
@ -28,7 +28,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
|
|||
input_lines,
|
||||
crate::infrastructure::cli::state::INPUT_MAX_HEIGHT,
|
||||
);
|
||||
let mut input_box_height = (input_lines + 2) as u16;
|
||||
let mut input_box_height = (input_lines + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16;
|
||||
|
||||
let indicator_height = 1u16;
|
||||
let attachment_indicator_height = if state.attached_images.is_empty() {
|
||||
|
|
@ -45,7 +45,7 @@ pub fn render(frame: &mut Frame, state: &UiState) {
|
|||
let overflow = fixed_height - size.height;
|
||||
let reduced = input_box_height.saturating_sub(overflow);
|
||||
input_box_height =
|
||||
reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + 2) as u16);
|
||||
reduced.max((crate::infrastructure::cli::state::INPUT_MIN_HEIGHT + INPUT_PADDING.top as usize + INPUT_PADDING.bottom as usize) as u16);
|
||||
}
|
||||
|
||||
let constraints = if state.attached_images.is_empty() {
|
||||
|
|
@ -557,4 +557,4 @@ pub fn openai_available_filtered(state: &UiState, filter: &str) -> Vec<String> {
|
|||
.filter(|name| name.to_lowercase().contains(&query))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
|
@ -89,17 +89,39 @@ impl LocalEngine {
|
|||
}
|
||||
|
||||
impl InferenceEngine for LocalEngine {
|
||||
// so far very fucked up implementation, even user_prompt is not properly passed
|
||||
// need to be refactored using proper request builder same way as for openai inference
|
||||
fn generate(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
max_tokens: u32,
|
||||
_tools: &[&dyn crate::domain::tools::Tool],
|
||||
_chain: &crate::domain::workflow::Chain,
|
||||
chain: &crate::domain::workflow::Chain,
|
||||
_images: &[String],
|
||||
_tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||
mut tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||
) -> Result<LLMInferenceResult, InfaError> {
|
||||
let prompt = format!("{}\n\n{}", system_prompt, user_prompt);
|
||||
let model_name = "local";
|
||||
|
||||
// Start span with model name and add request as JSON
|
||||
if let Some(tracer) = &mut tracer {
|
||||
tracer.start_span(model_name, openai_agents_tracing::SpanKind::Generation);
|
||||
|
||||
// Set request as structured JSON input
|
||||
let request_json = serde_json::json!({
|
||||
"system_prompt": chain.system_prompt.clone(),
|
||||
"user_prompt": chain.get_steps_with_history()[0].context_payload.clone(),
|
||||
"max_tokens": 1000,
|
||||
});
|
||||
tracer.set_input_json(model_name, request_json);
|
||||
|
||||
// Set model configuration
|
||||
let mut model_config = std::collections::HashMap::new();
|
||||
model_config.insert("temperature".to_string(), serde_json::json!(self.params.temperature));
|
||||
model_config.insert("top_p".to_string(), serde_json::json!(self.params.top_p));
|
||||
model_config.insert("ctx_size".to_string(), serde_json::json!(self.params.ctx_size));
|
||||
model_config.insert("threads".to_string(), serde_json::json!(self.params.threads));
|
||||
tracer.set_model_config(model_name, model_config);
|
||||
}
|
||||
|
||||
let prompt = format!("{}\n\n{}", chain.system_prompt.clone(), chain.get_steps_with_history()[0].context_payload.clone());
|
||||
let to_error = |msg: String| -> InfaError {
|
||||
std::io::Error::new(std::io::ErrorKind::Other, msg).into()
|
||||
};
|
||||
|
|
@ -145,7 +167,7 @@ impl InferenceEngine for LocalEngine {
|
|||
let mut output = String::new();
|
||||
let mut n_cur = tokens.len();
|
||||
|
||||
for _ in 0..max_tokens {
|
||||
for _ in 0..2000 {
|
||||
let new_token = sampler.sample(&ctx, -1);
|
||||
|
||||
if self.model.is_eog_token(new_token) {
|
||||
|
|
@ -169,6 +191,15 @@ impl InferenceEngine for LocalEngine {
|
|||
n_cur += 1;
|
||||
}
|
||||
|
||||
// Add output as JSON and end span
|
||||
if let Some(tracer) = tracer {
|
||||
let response_json = serde_json::json!({
|
||||
"text": &output,
|
||||
});
|
||||
tracer.set_output_json(model_name, response_json);
|
||||
tracer.end_span(model_name);
|
||||
}
|
||||
|
||||
Ok(LLMInferenceResult {
|
||||
summary: output.trim().to_string(),
|
||||
raw_output: output,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ pub mod openai;
|
|||
pub struct ToolCall {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
pub call_id: String,
|
||||
}
|
||||
|
||||
pub struct LLMInferenceResult {
|
||||
|
|
@ -22,9 +23,6 @@ pub trait InferenceEngine: Send + Sync {
|
|||
/// Generate text without streaming output
|
||||
fn generate(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
max_tokens: u32,
|
||||
tools: &[&dyn Tool],
|
||||
chain: &Chain,
|
||||
images: &[String],
|
||||
|
|
|
|||
|
|
@ -66,15 +66,12 @@ impl OpenAIEngine {
|
|||
impl InferenceEngine for OpenAIEngine {
|
||||
fn generate(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
_max_tokens: u32,
|
||||
tools: &[&dyn crate::domain::tools::Tool],
|
||||
chain: &crate::domain::workflow::Chain,
|
||||
images: &[String],
|
||||
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||
) -> Result<LLMInferenceResult, InfaError> {
|
||||
self.generate_with_responses_api(system_prompt, user_prompt, tools, chain, images, tracer)
|
||||
self.generate_with_responses_api(tools, chain, images, tracer)
|
||||
}
|
||||
fn get_type(&self) -> ModelType {
|
||||
ModelType::OpenAI
|
||||
|
|
@ -85,15 +82,13 @@ impl OpenAIEngine {
|
|||
/// Generate using the Responses API (for newer models like codex, o-series)
|
||||
fn generate_with_responses_api(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
tools: &[&dyn crate::domain::tools::Tool],
|
||||
chain: &crate::domain::workflow::Chain,
|
||||
images: &[String],
|
||||
tracer: Option<&mut openai_agents_tracing::TracingFacade>,
|
||||
) -> Result<LLMInferenceResult, InfaError> {
|
||||
self.responses_client
|
||||
.call_responses_api(system_prompt, user_prompt, tools, chain, images, tracer)
|
||||
.call_responses_api(tools, chain, images, tracer)
|
||||
.map_err(|e| e)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue