diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index cac687d077e..f972c6f72ab 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -15,10 +15,13 @@ use futures::channel::mpsc; use futures::future::Shared; use futures::io::BufReader; use futures::{AsyncBufReadExt as _, Future, FutureExt as _, StreamExt as _}; -use project::agent_server_store::{AgentServerCommand, AgentServerStore}; +use project::agent_server_store::{ + AgentServerCommand, AgentServerStore, AllAgentServersSettings, CustomAgentServerSettings, +}; use project::{AgentId, Project}; use remote::remote_client::Interactive; use serde::Deserialize; +use settings::SettingsStore; use std::path::PathBuf; use std::process::{ExitStatus, Stdio}; use std::rc::Rc; @@ -32,7 +35,7 @@ use util::path_list::PathList; use util::process::Child; use anyhow::{Context as _, Result}; -use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Task, WeakEntity}; +use gpui::{App, AppContext as _, AsyncApp, Entity, SharedString, Subscription, Task, WeakEntity}; use acp_thread::{AcpThread, AuthRequired, LoadError, TerminalProviderEvent}; use terminal::TerminalBuilder; @@ -421,18 +424,101 @@ pub struct AcpConnection { auth_methods: Vec, agent_server_store: WeakEntity, agent_capabilities: acp::AgentCapabilities, - default_mode: Option, - default_model: Option, - default_config_options: HashMap, + defaults: AcpConnectionDefaults, child: Option, session_list: Option>, debug_log: AcpDebugLog, + _settings_subscription: Subscription, _io_task: Task<()>, _dispatch_task: Task<()>, _wait_task: Task>, _stderr_task: Task>, } +#[derive(Clone, Default)] +struct AcpConnectionDefaults { + mode: Rc>>, + model: Rc>>, + config_options: Rc>>, +} + +impl AcpConnectionDefaults { + fn new( + mode: Option, + model: Option, + config_options: HashMap, + ) -> Self { + Self { + mode: Rc::new(RefCell::new(mode)), + model: Rc::new(RefCell::new(model)), + config_options: Rc::new(RefCell::new(config_options)), + } + } + + fn mode(&self) -> Option { + self.mode.borrow().clone() + } + + fn model(&self) -> Option { + self.model.borrow().clone() + } + + fn config_option(&self, config_id: &str) -> Option { + self.config_options.borrow().get(config_id).cloned() + } + + fn set( + &self, + mode: Option, + model: Option, + config_options: HashMap, + ) { + *self.mode.borrow_mut() = mode; + *self.model.borrow_mut() = model; + *self.config_options.borrow_mut() = config_options; + } + + fn refresh_from_settings(&self, agent_id: &AgentId, cx: &App) { + let Some(settings_store) = cx.try_global::() else { + self.set(None, None, HashMap::default()); + return; + }; + let settings = settings_store.get::(None); + let Some(agent_settings) = settings.get(agent_id.as_ref()) else { + self.set(None, None, HashMap::default()); + return; + }; + + let default_config_options = match agent_settings { + CustomAgentServerSettings::Custom { + default_config_options, + .. + } + | CustomAgentServerSettings::Registry { + default_config_options, + .. + } => default_config_options.clone(), + }; + self.set( + agent_settings.default_mode().map(acp::SessionModeId::new), + agent_settings.default_model().map(acp::ModelId::new), + default_config_options, + ); + } + + fn observe_settings(&self, agent_id: AgentId, cx: &mut App) -> Subscription { + if cx.try_global::().is_none() { + return Subscription::new(|| {}); + } + + self.refresh_from_settings(&agent_id, cx); + let defaults = self.clone(); + cx.observe_global::(move |cx| { + defaults.refresh_from_settings(&agent_id, cx); + }) + } +} + struct PendingAcpSession { task: Shared, Arc>>>, ref_count: usize, @@ -996,6 +1082,14 @@ impl AcpConnection { } else { response.auth_methods }; + let defaults = + AcpConnectionDefaults::new(default_mode, default_model, default_config_options); + let settings_subscription = cx.update({ + let agent_id = agent_id.clone(); + let defaults = defaults.clone(); + move |cx| defaults.observe_settings(agent_id, cx) + }); + Ok(Self { id: agent_id, auth_methods, @@ -1006,11 +1100,10 @@ impl AcpConnection { sessions, pending_sessions: Rc::new(RefCell::new(HashMap::default())), agent_capabilities: response.agent_capabilities, - default_mode, - default_model, - default_config_options, + defaults, session_list, debug_log, + _settings_subscription: settings_subscription, _io_task: io_task, _dispatch_task: dispatch_task, _wait_task: wait_task, @@ -1031,10 +1124,14 @@ impl AcpConnection { agent_server_store: WeakEntity, io_task: Task<()>, dispatch_task: Task<()>, - _cx: &mut App, + cx: &mut App, ) -> Self { + let agent_id = AgentId::new("test"); + let defaults = AcpConnectionDefaults::default(); + let settings_subscription = defaults.observe_settings(agent_id.clone(), cx); + Self { - id: AgentId::new("test"), + id: agent_id, telemetry_id: "test".into(), agent_version: None, connection, @@ -1043,12 +1140,11 @@ impl AcpConnection { auth_methods: vec![], agent_server_store, agent_capabilities, - default_mode: None, - default_model: None, - default_config_options: HashMap::default(), + defaults, child: None, session_list: None, debug_log: AcpDebugLog::default(), + _settings_subscription: settings_subscription, _io_task: io_task, _dispatch_task: dispatch_task, _wait_task: Task::ready(Ok(())), @@ -1215,7 +1311,7 @@ impl AcpConnection { config_opts_ref .iter() .filter_map(|config_option| { - let default_value = self.default_config_options.get(&*config_option.id.0)?; + let default_value = self.defaults.config_option(config_option.id.0.as_ref())?; let is_valid = match &config_option.kind { acp::SessionConfigKind::Select(select) => match &select.options { @@ -1241,11 +1337,7 @@ impl AcpConnection { } _ => None, }; - Some(( - config_option.id.clone(), - default_value.clone(), - initial_value, - )) + Some((config_option.id.clone(), default_value, initial_value)) } else { log::warn!( "`{}` is not a valid value for config option `{}` in {}", @@ -1488,7 +1580,8 @@ impl AgentConnection for AcpConnection { let (modes, models, config_options) = config_state(response.modes, response.models, response.config_options); - if let Some(default_mode) = self.default_mode.clone() { + let default_mode = self.defaults.mode(); + if let Some(default_mode) = default_mode { if let Some(modes) = modes.as_ref() { let mut modes_ref = modes.borrow_mut(); let has_mode = modes_ref @@ -1537,7 +1630,8 @@ impl AgentConnection for AcpConnection { } } - if let Some(default_model) = self.default_model.clone() { + let default_model = self.defaults.model(); + if let Some(default_model) = default_model { if let Some(models) = models.as_ref() { let mut models_ref = models.borrow_mut(); let has_model = models_ref @@ -2501,6 +2595,7 @@ mod tests { use super::*; use gpui::UpdateGlobal as _; + use settings::Settings as _; #[test] fn terminal_auth_task_builds_spawn_from_prebuilt_command() { @@ -2970,6 +3065,68 @@ mod tests { .expect("failed to receive ACP connection") } + #[gpui::test] + async fn settings_changes_refresh_active_connection_defaults(cx: &mut gpui::TestAppContext) { + cx.update(|cx| { + let store = settings::SettingsStore::test(cx); + cx.set_global(store); + }); + + let fs = fs::FakeFs::new(cx.executor()); + fs.insert_tree("/", serde_json::json!({ "a": {} })).await; + let project = project::Project::test(fs, [std::path::Path::new("/a")], cx).await; + let harness = test_support::connect_fake_acp_connection(project, cx).await; + + cx.update(|cx| { + AllAgentServersSettings::override_global( + AllAgentServersSettings(HashMap::from_iter([( + "test".to_string(), + settings::CustomAgentServerSettings::Custom { + path: PathBuf::from("test-agent"), + args: Vec::new(), + env: HashMap::default(), + default_mode: Some("manual".to_string()), + default_model: Some("claude-sonnet-4".to_string()), + favorite_models: Vec::new(), + default_config_options: HashMap::from_iter([( + "mode".to_string(), + "manual".to_string(), + )]), + favorite_config_option_values: HashMap::default(), + } + .into(), + )])), + cx, + ); + }); + cx.run_until_parked(); + + assert_eq!( + harness.connection.defaults.mode(), + Some(acp::SessionModeId::new("manual")) + ); + assert_eq!( + harness.connection.defaults.model(), + Some(acp::ModelId::new("claude-sonnet-4")) + ); + assert_eq!( + harness.connection.defaults.config_option("mode").as_deref(), + Some("manual") + ); + + cx.update(|cx| { + AllAgentServersSettings::override_global( + AllAgentServersSettings(HashMap::default()), + cx, + ); + }); + cx.run_until_parked(); + + assert_eq!(harness.connection.defaults.mode(), None); + assert_eq!(harness.connection.defaults.model(), None); + assert_eq!(harness.connection.defaults.config_option("mode"), None); + } + #[gpui::test] async fn session_list_delete_sends_session_delete_when_supported( cx: &mut gpui::TestAppContext, diff --git a/crates/agent_ui/src/config_options.rs b/crates/agent_ui/src/config_options.rs index c1f9a09c22f..b980e2e0c97 100644 --- a/crates/agent_ui/src/config_options.rs +++ b/crates/agent_ui/src/config_options.rs @@ -19,7 +19,7 @@ use ui::{ }; use util::ResultExt as _; -use crate::ui::{HoldForDefault, documentation_aside_side}; +use crate::ui::documentation_aside_side; const PICKER_THRESHOLD: usize = 5; @@ -101,6 +101,13 @@ impl ConfigOptionsView { return false; }; + self.agent_server.set_default_config_option( + config_id.0.as_ref(), + Some(next_value.0.as_ref()), + self.fs.clone(), + cx, + ); + let task = self .config_options .set_config_option(config_id, next_value, cx); @@ -412,7 +419,7 @@ struct ConfigOptionPickerDelegate { filtered_entries: Vec, all_options: Vec, selected_index: usize, - selected_description: Option<(usize, SharedString, bool)>, + selected_description: Option<(usize, SharedString)>, favorites: HashSet, _settings_subscription: Subscription, } @@ -544,28 +551,16 @@ impl PickerDelegate for ConfigOptionPickerDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context>) { if let Some(ConfigOptionPickerEntry::Option(option)) = self.filtered_entries.get(self.selected_index) { - if window.modifiers().secondary() { - let default_value = self - .agent_server - .default_config_option(self.config_id.0.as_ref(), cx); - let is_default = default_value.as_deref() == Some(&*option.value.0); - - self.agent_server.set_default_config_option( - self.config_id.0.as_ref(), - if is_default { - None - } else { - Some(option.value.0.as_ref()) - }, - self.fs.clone(), - cx, - ); - } - + self.agent_server.set_default_config_option( + self.config_id.0.as_ref(), + Some(option.value.0.as_ref()), + self.fs.clone(), + cx, + ); let task = self.config_options.set_config_option( self.config_id.clone(), option.value.clone(), @@ -614,11 +609,6 @@ impl PickerDelegate for ConfigOptionPickerDelegate { let current_value = self.current_value(); let is_selected = current_value.as_ref() == Some(&option.value); - let default_value = self - .agent_server - .default_config_option(self.config_id.0.as_ref(), cx); - let is_default = default_value.as_deref() == Some(&*option.value.0); - let is_favorite = self.favorites.contains(&option.value); let option_name = option.name.clone(); @@ -631,9 +621,8 @@ impl PickerDelegate for ConfigOptionPickerDelegate { let desc: SharedString = desc.into(); this.on_hover(cx.listener(move |menu, hovered, _, cx| { if *hovered { - menu.delegate.selected_description = - Some((ix, desc.clone(), is_default)); - } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) + menu.delegate.selected_description = Some((ix, desc.clone())); + } else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) { menu.delegate.selected_description = None; } @@ -688,29 +677,20 @@ impl PickerDelegate for ConfigOptionPickerDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - self.selected_description - .as_ref() - .map(|(_, description, is_default)| { - let description = description.clone(); - let is_default = *is_default; + self.selected_description.as_ref().map(|(_, description)| { + let description = description.clone(); - let side = documentation_aside_side(cx); + let side = documentation_aside_side(cx); - ui::DocumentationAside::new( - side, - Rc::new(move |_| { - v_flex() - .gap_1() - .child(Label::new(description.clone())) - .child(HoldForDefault::new(is_default)) - .into_any_element() - }), - ) - }) + ui::DocumentationAside::new( + side, + Rc::new(move |_| Label::new(description.clone()).into_any_element()), + ) + }) } fn documentation_aside_index(&self) -> Option { - self.selected_description.as_ref().map(|(ix, _, _)| *ix) + self.selected_description.as_ref().map(|(ix, _)| *ix) } } @@ -878,3 +858,143 @@ fn count_config_options(option: &acp::SessionConfigOption) -> usize { _ => 0, } } + +#[cfg(test)] +mod tests { + use super::*; + use acp_thread::AgentConnection; + use fs::FakeFs; + use gpui::TestAppContext; + use parking_lot::Mutex; + use project::{AgentId, Project}; + use std::{any::Any, cell::RefCell}; + + #[gpui::test] + fn cycling_config_option_saves_selected_value_as_default(cx: &mut TestAppContext) { + let agent_server = Rc::new(TestAgentServer::default()); + let config_options = Rc::new(TestSessionConfigOptions::new(vec![ + acp::SessionConfigOption::select( + "mode", + "Mode", + "auto", + vec![ + acp::SessionConfigSelectOption::new("auto", "Auto"), + acp::SessionConfigSelectOption::new("manual", "Manual"), + ], + ) + .category(acp::SessionConfigOptionCategory::Mode), + ])); + let fs: Arc = FakeFs::new(cx.executor()); + + cx.update(|cx| { + let config_options: Rc = config_options.clone(); + let agent_server: Rc = agent_server.clone(); + let fs = fs.clone(); + let view = cx.new(|_| ConfigOptionsView { + config_option_ids: ConfigOptionsView::config_option_ids(&config_options), + config_options, + selectors: Vec::new(), + agent_server, + fs, + _refresh_task: Task::ready(()), + }); + + assert!(view.update(cx, |view, cx| { + view.cycle_category_option(acp::SessionConfigOptionCategory::Mode, false, cx) + })); + }); + + assert_eq!( + agent_server.saved_defaults.lock().as_slice(), + &[("mode".to_string(), Some("manual".to_string()))] + ); + assert_eq!( + config_options.set_values.borrow().as_slice(), + &[("mode".to_string(), "manual".to_string())] + ); + } + + #[derive(Default)] + struct TestAgentServer { + saved_defaults: Arc)>>>, + } + + impl AgentServer for TestAgentServer { + fn logo(&self) -> IconName { + IconName::ZedAssistant + } + + fn agent_id(&self) -> AgentId { + AgentId::new("test-agent") + } + + fn connect( + &self, + _delegate: agent_servers::AgentServerDelegate, + _project: Entity, + _cx: &mut App, + ) -> Task>> { + Task::ready(Err(anyhow::anyhow!("test agent server cannot connect"))) + } + + fn into_any(self: Rc) -> Rc { + self + } + + fn set_default_config_option( + &self, + config_id: &str, + value_id: Option<&str>, + _fs: Arc, + _cx: &mut App, + ) { + self.saved_defaults.lock().push(( + config_id.to_string(), + value_id.map(|value| value.to_string()), + )); + } + } + + struct TestSessionConfigOptions { + options: RefCell>, + set_values: RefCell>, + } + + impl TestSessionConfigOptions { + fn new(options: Vec) -> Self { + Self { + options: RefCell::new(options), + set_values: RefCell::new(Vec::new()), + } + } + } + + impl AgentSessionConfigOptions for TestSessionConfigOptions { + fn config_options(&self) -> Vec { + self.options.borrow().clone() + } + + fn set_config_option( + &self, + config_id: acp::SessionConfigId, + value: acp::SessionConfigValueId, + _cx: &mut App, + ) -> Task>> { + self.set_values + .borrow_mut() + .push((config_id.0.to_string(), value.0.to_string())); + + let options = { + let mut options = self.options.borrow_mut(); + if let Some(option) = options.iter_mut().find(|option| option.id == config_id) + && let acp::SessionConfigKind::Select(select) = &mut option.kind + { + select.current_value = value; + } + options.clone() + }; + + Task::ready(Ok(options)) + } + } +} diff --git a/crates/agent_ui/src/mode_selector.rs b/crates/agent_ui/src/mode_selector.rs index 9e4464517c2..cea60af7aa7 100644 --- a/crates/agent_ui/src/mode_selector.rs +++ b/crates/agent_ui/src/mode_selector.rs @@ -11,10 +11,7 @@ use ui::{ prelude::*, }; -use crate::{ - CycleModeSelector, ToggleProfileSelector, - ui::{HoldForDefault, documentation_aside_side}, -}; +use crate::{CycleModeSelector, ToggleProfileSelector, ui::documentation_aside_side}; pub struct ModeSelector { connection: Rc, @@ -45,6 +42,10 @@ impl ModeSelector { pub fn cycle_mode(&mut self, _window: &mut Window, cx: &mut Context) { let all_modes = self.connection.all_modes(); + if all_modes.is_empty() { + return; + } + let current_mode = self.connection.current_mode(); let current_index = all_modes @@ -52,8 +53,9 @@ impl ModeSelector { .position(|mode| mode.id.0 == current_mode.0) .unwrap_or(0); - let next_index = (current_index + 1) % all_modes.len(); - self.set_mode(all_modes[next_index].id.clone(), cx); + if let Some(next_mode) = all_modes.get((current_index + 1) % all_modes.len()) { + self.set_mode(next_mode.id.clone(), cx); + } } pub fn mode(&self) -> acp::SessionModeId { @@ -61,6 +63,9 @@ impl ModeSelector { } pub fn set_mode(&mut self, mode: acp::SessionModeId, cx: &mut Context) { + self.agent_server + .set_default_mode(Some(mode.clone()), self.fs.clone(), cx); + let task = self.connection.set_mode(mode, cx); self.setting_mode = true; cx.notify(); @@ -88,13 +93,11 @@ impl ModeSelector { ContextMenu::build(window, cx, move |mut menu, _window, cx| { let all_modes = self.connection.all_modes(); let current_mode = self.connection.current_mode(); - let default_mode = self.agent_server.default_mode(cx); let side = documentation_aside_side(cx); for mode in all_modes { let is_selected = &mode.id == ¤t_mode; - let is_default = Some(&mode.id) == default_mode.as_ref(); let entry = ContextMenuEntry::new(mode.name.clone()) .toggleable(IconPosition::End, is_selected); @@ -102,13 +105,7 @@ impl ModeSelector { entry.documentation_aside(side, { let description = description.clone(); - move |_| { - v_flex() - .gap_1() - .child(Label::new(description.clone())) - .child(HoldForDefault::new(is_default)) - .into_any_element() - } + move |_| Label::new(description.clone()).into_any_element() }) } else { entry @@ -117,21 +114,9 @@ impl ModeSelector { menu.push_item(entry.handler({ let mode_id = mode.id.clone(); let weak_self = weak_self.clone(); - move |window, cx| { + move |_window, cx| { weak_self .update(cx, |this, cx| { - if window.modifiers().secondary() { - this.agent_server.set_default_mode( - if is_default { - None - } else { - Some(mode_id.clone()) - }, - this.fs.clone(), - cx, - ); - } - this.set_mode(mode_id.clone(), cx); }) .ok(); @@ -209,3 +194,110 @@ impl Render for ModeSelector { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use acp_thread::AgentConnection; + use fs::FakeFs; + use gpui::{App, Task, TestAppContext}; + use parking_lot::Mutex; + use project::{AgentId, Project}; + use std::{any::Any, cell::RefCell}; + + #[gpui::test] + fn setting_mode_saves_selected_mode_as_default(cx: &mut TestAppContext) { + let agent_server = Rc::new(TestAgentServer::default()); + let session_modes = Rc::new(TestSessionModes::new()); + let fs: Arc = FakeFs::new(cx.executor()); + + cx.update(|cx| { + let session_modes: Rc = session_modes.clone(); + let agent_server: Rc = agent_server.clone(); + let selector = cx.new(|_| ModeSelector::new(session_modes, agent_server, fs)); + + selector.update(cx, |selector, cx| { + selector.set_mode(acp::SessionModeId::new("manual"), cx); + }); + }); + + assert_eq!( + agent_server.saved_defaults.lock().as_slice(), + &[Some(acp::SessionModeId::new("manual"))] + ); + assert_eq!( + session_modes.set_modes.borrow().as_slice(), + &[acp::SessionModeId::new("manual")] + ); + } + + #[derive(Default)] + struct TestAgentServer { + saved_defaults: Arc>>>, + } + + impl AgentServer for TestAgentServer { + fn logo(&self) -> IconName { + IconName::ZedAssistant + } + + fn agent_id(&self) -> AgentId { + AgentId::new("test-agent") + } + + fn connect( + &self, + _delegate: agent_servers::AgentServerDelegate, + _project: Entity, + _cx: &mut App, + ) -> Task>> { + Task::ready(Err(anyhow::anyhow!("test agent server cannot connect"))) + } + + fn into_any(self: Rc) -> Rc { + self + } + + fn set_default_mode( + &self, + mode_id: Option, + _fs: Arc, + _cx: &mut App, + ) { + self.saved_defaults.lock().push(mode_id); + } + } + + struct TestSessionModes { + current_mode: RefCell, + set_modes: RefCell>, + } + + impl TestSessionModes { + fn new() -> Self { + Self { + current_mode: RefCell::new(acp::SessionModeId::new("auto")), + set_modes: RefCell::new(Vec::new()), + } + } + } + + impl AgentSessionModes for TestSessionModes { + fn current_mode(&self) -> acp::SessionModeId { + self.current_mode.borrow().clone() + } + + fn all_modes(&self) -> Vec { + vec![ + acp::SessionMode::new("auto", "Auto"), + acp::SessionMode::new("manual", "Manual"), + ] + } + + fn set_mode(&self, mode: acp::SessionModeId, _cx: &mut App) -> Task> { + *self.current_mode.borrow_mut() = mode.clone(); + self.set_modes.borrow_mut().push(mode); + Task::ready(Ok(())) + } + } +} diff --git a/crates/agent_ui/src/model_selector.rs b/crates/agent_ui/src/model_selector.rs index 47171979496..a04a82793ac 100644 --- a/crates/agent_ui/src/model_selector.rs +++ b/crates/agent_ui/src/model_selector.rs @@ -22,8 +22,7 @@ use util::ResultExt; use zed_actions::agent::OpenSettings; use crate::ui::{ - HoldForDefault, ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem, - documentation_aside_side, + ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem, documentation_aside_side, }; pub type ModelSelector = Picker; @@ -55,7 +54,7 @@ pub struct ModelPickerDelegate { filtered_entries: Vec, models: Option, selected_index: usize, - selected_description: Option<(usize, SharedString, bool)>, + selected_description: Option<(usize, SharedString)>, selected_model: Option, favorites: HashSet, _refresh_models_task: Task<()>, @@ -182,6 +181,9 @@ impl ModelPickerDelegate { let next_model = favorite_models[next_index].clone(); + self.agent_server + .set_default_model(Some(next_model.id.clone()), self.fs.clone(), cx); + self.selector .select_model(next_model.id.clone(), cx) .detach_and_log_err(cx); @@ -277,20 +279,8 @@ impl PickerDelegate for ModelPickerDelegate { if let Some(ModelPickerEntry::Model(model_info, _)) = self.filtered_entries.get(self.selected_index) { - if window.modifiers().secondary() { - let default_model = self.agent_server.default_model(cx); - let is_default = default_model.as_ref() == Some(&model_info.id); - - self.agent_server.set_default_model( - if is_default { - None - } else { - Some(model_info.id.clone()) - }, - self.fs.clone(), - cx, - ); - } + self.agent_server + .set_default_model(Some(model_info.id.clone()), self.fs.clone(), cx); self.selector .select_model(model_info.id.clone(), cx) @@ -322,8 +312,6 @@ impl PickerDelegate for ModelPickerDelegate { } ModelPickerEntry::Model(model_info, is_favorite) => { let is_selected = Some(model_info) == self.selected_model.as_ref(); - let default_model = self.agent_server.default_model(cx); - let is_default = default_model.as_ref() == Some(&model_info.id); let is_favorite = *is_favorite; let handle_action_click = { @@ -350,8 +338,8 @@ impl PickerDelegate for ModelPickerDelegate { this.on_hover(cx.listener(move |menu, hovered, _, cx| { if *hovered { menu.delegate.selected_description = - Some((ix, description.clone(), is_default)); - } else if matches!(menu.delegate.selected_description, Some((id, _, _)) if id == ix) { + Some((ix, description.clone())); + } else if matches!(menu.delegate.selected_description, Some((id, _)) if id == ix) { menu.delegate.selected_description = None; } cx.notify(); @@ -382,29 +370,20 @@ impl PickerDelegate for ModelPickerDelegate { _window: &mut Window, cx: &mut Context>, ) -> Option { - self.selected_description - .as_ref() - .map(|(_, description, is_default)| { - let description = description.clone(); - let is_default = *is_default; + self.selected_description.as_ref().map(|(_, description)| { + let description = description.clone(); - let side = documentation_aside_side(cx); + let side = documentation_aside_side(cx); - DocumentationAside::new( - side, - Rc::new(move |_| { - v_flex() - .gap_1() - .child(Label::new(description.clone())) - .child(HoldForDefault::new(is_default)) - .into_any_element() - }), - ) - }) + DocumentationAside::new( + side, + Rc::new(move |_| Label::new(description.clone()).into_any_element()), + ) + }) } fn documentation_aside_index(&self) -> Option { - self.selected_description.as_ref().map(|(ix, _, _)| *ix) + self.selected_description.as_ref().map(|(ix, _)| *ix) } fn render_footer( @@ -530,7 +509,12 @@ async fn fuzzy_search( #[cfg(test)] mod tests { - use gpui::TestAppContext; + use acp_thread::AgentConnection; + use fs::FakeFs; + use gpui::{App, Entity, TestAppContext, VisualTestContext}; + use parking_lot::Mutex; + use project::{AgentId, Project}; + use std::{any::Any, cell::RefCell}; use super::*; @@ -608,6 +592,138 @@ mod tests { .collect() } + #[gpui::test] + fn confirming_model_saves_selected_model_as_default(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = settings::SettingsStore::test(cx); + cx.set_global(settings_store); + theme_settings::init(theme::LoadThemes::JustBase, cx); + editor::init(cx); + }); + + let agent_server = Rc::new(TestAgentServer::default()); + let model_selector = Rc::new(TestModelSelector::new()); + let fs: Arc = FakeFs::new(cx.executor()); + + let window_handle = cx.add_window({ + let agent_server = agent_server.clone(); + let model_selector = model_selector.clone(); + move |window, cx| { + let selector: Rc = model_selector.clone(); + let agent_server: Rc = agent_server.clone(); + acp_model_selector(selector, agent_server, fs, cx.focus_handle(), window, cx) + } + }); + cx.run_until_parked(); + + let mut cx = VisualTestContext::from_window(window_handle.into(), cx); + window_handle + .update(&mut cx, |picker, window, cx| { + picker.delegate.set_selected_index(1, window, cx); + picker.delegate.confirm(false, window, cx); + }) + .unwrap(); + + assert_eq!( + agent_server.saved_defaults.lock().as_slice(), + &[Some(acp::ModelId::new("manual"))] + ); + assert_eq!( + model_selector.selected_models.borrow().as_slice(), + &[acp::ModelId::new("manual")] + ); + } + + #[derive(Default)] + struct TestAgentServer { + saved_defaults: Arc>>>, + } + + impl AgentServer for TestAgentServer { + fn logo(&self) -> IconName { + IconName::ZedAssistant + } + + fn agent_id(&self) -> AgentId { + AgentId::new("test-agent") + } + + fn connect( + &self, + _delegate: agent_servers::AgentServerDelegate, + _project: Entity, + _cx: &mut App, + ) -> Task>> { + Task::ready(Err(anyhow::anyhow!("test agent server cannot connect"))) + } + + fn into_any(self: Rc) -> Rc { + self + } + + fn set_default_model( + &self, + model_id: Option, + _fs: Arc, + _cx: &mut App, + ) { + self.saved_defaults.lock().push(model_id); + } + } + + struct TestModelSelector { + models: Vec, + selected_model: RefCell, + selected_models: RefCell>, + } + + impl TestModelSelector { + fn new() -> Self { + let models = vec![ + AgentModelInfo { + id: acp::ModelId::new("auto"), + name: "Auto".into(), + description: None, + icon: None, + is_latest: false, + cost: None, + }, + AgentModelInfo { + id: acp::ModelId::new("manual"), + name: "Manual".into(), + description: None, + icon: None, + is_latest: false, + cost: None, + }, + ]; + + Self { + selected_model: RefCell::new(models[0].clone()), + models, + selected_models: RefCell::new(Vec::new()), + } + } + } + + impl AgentModelSelector for TestModelSelector { + fn list_models(&self, _cx: &mut App) -> Task> { + Task::ready(Ok(AgentModelList::Flat(self.models.clone()))) + } + + fn select_model(&self, model_id: acp::ModelId, _cx: &mut App) -> Task> { + self.selected_models.borrow_mut().push(model_id.clone()); + if let Some(model) = self.models.iter().find(|model| model.id == model_id) { + *self.selected_model.borrow_mut() = model.clone(); + } + Task::ready(Ok(())) + } + + fn selected_model(&self, _cx: &mut App) -> Task> { + Task::ready(Ok(self.selected_model.borrow().clone())) + } + } + fn get_entry_labels(entries: &[ModelPickerEntry]) -> Vec<&str> { entries .iter() diff --git a/crates/agent_ui/src/ui.rs b/crates/agent_ui/src/ui.rs index d69cfe86a63..4d355de28e4 100644 --- a/crates/agent_ui/src/ui.rs +++ b/crates/agent_ui/src/ui.rs @@ -1,13 +1,11 @@ mod agent_notification; mod end_trial_upsell; -mod hold_for_default; mod mention_crease; mod model_selector_components; mod undo_reject_toast; pub use agent_notification::*; pub use end_trial_upsell::*; -pub use hold_for_default::*; pub use mention_crease::*; pub use model_selector_components::*; pub use undo_reject_toast::*; diff --git a/crates/agent_ui/src/ui/hold_for_default.rs b/crates/agent_ui/src/ui/hold_for_default.rs deleted file mode 100644 index 972f61f0057..00000000000 --- a/crates/agent_ui/src/ui/hold_for_default.rs +++ /dev/null @@ -1,52 +0,0 @@ -use gpui::{App, IntoElement, Modifiers, RenderOnce, Window}; -use ui::{prelude::*, render_modifiers}; - -#[derive(IntoElement)] -pub struct HoldForDefault { - is_default: bool, - more_content: bool, -} - -impl HoldForDefault { - pub fn new(is_default: bool) -> Self { - Self { - is_default, - more_content: true, - } - } - - #[allow(dead_code)] - pub fn more_content(mut self, more_content: bool) -> Self { - self.more_content = more_content; - self - } -} - -impl RenderOnce for HoldForDefault { - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - h_flex() - .when(self.more_content, |this| { - this.pt_1() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - }) - .gap_0p5() - .text_sm() - .text_color(Color::Muted.color(cx)) - .child("Hold") - .child(h_flex().flex_shrink_0().children(render_modifiers( - &Modifiers::secondary_key(), - PlatformStyle::platform(), - None, - Some(TextSize::Default.rems(cx).into()), - false, - ))) - .child(div().map(|this| { - if self.is_default { - this.child("to unset as default") - } else { - this.child("to set as default") - } - })) - } -}