overhaul provider inventory and agent/model selection (#8652)
Some checks are pending
Canary / Prepare Version (push) Waiting to run
Canary / build-cli (push) Blocked by required conditions
Canary / Upload Install Script (push) Blocked by required conditions
Canary / bundle-desktop (push) Blocked by required conditions
Canary / bundle-desktop-intel (push) Blocked by required conditions
Canary / bundle-desktop-linux (push) Blocked by required conditions
Canary / bundle-desktop-windows (push) Blocked by required conditions
Canary / Release (push) Blocked by required conditions
Cargo Deny / deny (push) Waiting to run
Unused Dependencies / machete (push) Waiting to run
CI / changes (push) Waiting to run
CI / Check Rust Code Format (push) Blocked by required conditions
CI / Build and Test Rust Project (push) Blocked by required conditions
CI / Build Rust Project on Windows (push) Waiting to run
CI / Lint Rust Code (push) Blocked by required conditions
CI / Check Generated Schemas are Up-to-Date (push) Blocked by required conditions
CI / Test and Lint Electron Desktop App (push) Blocked by required conditions
Goose 2 CI / Lint & Format (push) Waiting to run
Goose 2 CI / Unit Tests (push) Waiting to run
Goose 2 CI / Desktop Build & E2E (push) Waiting to run
Goose 2 CI / Rust Lint (push) Waiting to run
Live Provider Tests / check-fork (push) Waiting to run
Live Provider Tests / changes (push) Blocked by required conditions
Live Provider Tests / Build Binary (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (Code Execution) (push) Blocked by required conditions
Live Provider Tests / Compaction Tests (push) Blocked by required conditions
Live Provider Tests / goose server HTTP integration tests (push) Blocked by required conditions
Publish Docker Image / docker (push) Waiting to run
Scorecard supply-chain security / Scorecard analysis (push) Waiting to run

Signed-off-by: Bradley Axen <baxen@squareup.com>
This commit is contained in:
Bradley Axen 2026-04-20 15:00:17 -07:00 committed by GitHub
parent 3d582943fd
commit 8eda6fdabc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
70 changed files with 5321 additions and 2123 deletions

View file

@ -51,9 +51,14 @@
"responseType": "GetProviderDetailsResponse"
},
{
"method": "_goose/providers/models",
"requestType": "GetProviderModelsRequest",
"responseType": "GetProviderModelsResponse"
"method": "_goose/providers/inventory",
"requestType": "GetProviderInventoryRequest",
"responseType": "GetProviderInventoryResponse"
},
{
"method": "_goose/providers/inventory/refresh",
"requestType": "RefreshProviderInventoryRequest",
"responseType": "RefreshProviderInventoryResponse"
},
{
"method": "_goose/config/read",

View file

@ -362,36 +362,224 @@
"contextLimit"
]
},
"GetProviderModelsRequest": {
"GetProviderInventoryRequest": {
"type": "object",
"properties": {
"providerName": {
"type": "string"
}
},
"required": [
"providerName"
],
"description": "Fetch the full list of models available for a specific provider.",
"x-side": "agent",
"x-method": "_goose/providers/models"
},
"GetProviderModelsResponse": {
"type": "object",
"properties": {
"models": {
"providerIds": {
"type": "array",
"items": {
"type": "string"
},
"description": "Only return entries for these providers. Empty means all.",
"default": []
}
},
"description": "Read per-provider inventory. Always returns immediately from stored state.",
"x-side": "agent",
"x-method": "_goose/providers/inventory"
},
"GetProviderInventoryResponse": {
"type": "object",
"properties": {
"entries": {
"type": "array",
"items": {
"$ref": "#/$defs/ProviderInventoryEntryDto"
}
}
},
"required": [
"models"
"entries"
],
"description": "Provider models response.",
"description": "Provider inventory response.",
"x-side": "agent",
"x-method": "_goose/providers/models"
"x-method": "_goose/providers/inventory"
},
"ProviderInventoryEntryDto": {
"type": "object",
"properties": {
"providerId": {
"type": "string",
"description": "Provider identifier."
},
"providerName": {
"type": "string",
"description": "Human-readable provider name."
},
"configured": {
"type": "boolean",
"description": "Whether Goose has enough configuration to use this provider."
},
"supportsRefresh": {
"type": "boolean",
"description": "Whether this provider supports background inventory refresh."
},
"refreshing": {
"type": "boolean",
"description": "Whether a refresh is currently in flight."
},
"models": {
"type": "array",
"items": {
"$ref": "#/$defs/ProviderInventoryModelDto"
},
"description": "The list of available models."
},
"lastUpdatedAt": {
"type": [
"string",
"null"
],
"description": "When this entry was last successfully refreshed (ISO 8601)."
},
"lastRefreshAttemptAt": {
"type": [
"string",
"null"
],
"description": "When a refresh was most recently attempted (ISO 8601)."
},
"lastRefreshError": {
"type": [
"string",
"null"
],
"description": "The last refresh failure message, if any."
},
"stale": {
"type": "boolean",
"description": "Whether we believe this data may be outdated."
},
"modelSelectionHint": {
"type": [
"string",
"null"
],
"description": "Guidance message shown when this provider manages its own model selection externally."
}
},
"required": [
"providerId",
"providerName",
"configured",
"supportsRefresh",
"refreshing",
"models",
"stale"
],
"description": "Provider inventory entry."
},
"ProviderInventoryModelDto": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "Model identifier as the provider knows it."
},
"name": {
"type": "string",
"description": "Human-readable display name."
},
"family": {
"type": [
"string",
"null"
],
"description": "Model family for grouping in UI."
},
"contextLimit": {
"type": [
"integer",
"null"
],
"format": "uint",
"minimum": 0,
"description": "Context window size in tokens."
},
"reasoning": {
"type": [
"boolean",
"null"
],
"description": "Whether the model supports reasoning/extended thinking."
},
"recommended": {
"type": "boolean",
"description": "Whether this model should appear in the compact recommended picker.",
"default": false
}
},
"required": [
"id",
"name"
],
"description": "A single model in provider inventory."
},
"RefreshProviderInventoryRequest": {
"type": "object",
"properties": {
"providerIds": {
"type": "array",
"items": {
"type": "string"
},
"description": "Which providers to refresh. Empty means all known providers.",
"default": []
}
},
"description": "Trigger a background refresh of provider inventories.",
"x-side": "agent",
"x-method": "_goose/providers/inventory/refresh"
},
"RefreshProviderInventoryResponse": {
"type": "object",
"properties": {
"started": {
"type": "array",
"items": {
"type": "string"
},
"description": "Which providers will be refreshed."
},
"skipped": {
"type": "array",
"items": {
"$ref": "#/$defs/RefreshProviderInventorySkipDto"
},
"description": "Which providers were skipped and why.",
"default": []
}
},
"required": [
"started"
],
"description": "Refresh acknowledgement.",
"x-side": "agent",
"x-method": "_goose/providers/inventory/refresh"
},
"RefreshProviderInventorySkipDto": {
"type": "object",
"properties": {
"providerId": {
"type": "string"
},
"reason": {
"$ref": "#/$defs/RefreshProviderInventorySkipReasonDto"
}
},
"required": [
"providerId",
"reason"
]
},
"RefreshProviderInventorySkipReasonDto": {
"type": "string",
"enum": [
"unknown_provider",
"not_configured",
"does_not_support_refresh",
"already_refreshing"
]
},
"ReadConfigRequest": {
"type": "object",
@ -1035,11 +1223,20 @@
{
"allOf": [
{
"$ref": "#/$defs/GetProviderModelsRequest"
"$ref": "#/$defs/GetProviderInventoryRequest"
}
],
"description": "Params for _goose/providers/models",
"title": "GetProviderModelsRequest"
"description": "Params for _goose/providers/inventory",
"title": "GetProviderInventoryRequest"
},
{
"allOf": [
{
"$ref": "#/$defs/RefreshProviderInventoryRequest"
}
],
"description": "Params for _goose/providers/inventory/refresh",
"title": "RefreshProviderInventoryRequest"
},
{
"allOf": [
@ -1292,10 +1489,18 @@
{
"allOf": [
{
"$ref": "#/$defs/GetProviderModelsResponse"
"$ref": "#/$defs/GetProviderInventoryResponse"
}
],
"title": "GetProviderModelsResponse"
"title": "GetProviderInventoryResponse"
},
{
"allOf": [
{
"$ref": "#/$defs/RefreshProviderInventoryResponse"
}
],
"title": "RefreshProviderInventoryResponse"
},
{
"allOf": [

View file

@ -27,6 +27,9 @@ use goose::mcp_utils::ToolResult;
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::{Permission, PermissionConfirmation};
use goose::providers::base::Provider;
use goose::providers::inventory::{
ProviderInventoryEntry, ProviderInventoryService, RefreshSkipReason,
};
use goose::session::session_manager::SessionType;
use goose::session::{EnabledExtensionsState, Session, SessionManager};
use goose_acp_macros::custom_methods;
@ -125,6 +128,8 @@ struct AgentSetupRequest {
/// Pre-resolved provider name + model config (from config, no network).
/// When present the spawn skips re-deriving these from config.
resolved_provider: Option<(String, goose::model::ModelConfig)>,
/// Pre-instantiated provider reused from synchronous session initialization.
prebuilt_provider: Option<Arc<dyn Provider>>,
}
pub struct GooseAcpAgent {
@ -139,6 +144,7 @@ pub struct GooseAcpAgent {
permission_manager: Arc<PermissionManager>,
goose_mode: GooseMode,
disable_session_naming: bool,
provider_inventory: ProviderInventoryService,
}
/// Shorten a session/thread id for perf log correlation.
@ -415,19 +421,50 @@ fn builtin_to_extension_config(name: &str) -> ExtensionConfig {
}
}
async fn build_model_state(provider: &dyn Provider) -> Result<SessionModelState, sacp::Error> {
let models = provider
.fetch_recommended_models()
.await
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
let current_model = &provider.get_model_config().model_name;
Ok(SessionModelState::new(
ModelId::new(current_model.as_str()),
models
.iter()
.map(|name| ModelInfo::new(ModelId::new(&**name), &**name))
fn inventory_entry_to_dto(entry: ProviderInventoryEntry) -> ProviderInventoryEntryDto {
let stale = ProviderInventoryService::is_stale(&entry);
ProviderInventoryEntryDto {
provider_id: entry.provider_id,
provider_name: entry.provider_name,
configured: entry.configured,
supports_refresh: entry.supports_refresh,
refreshing: entry.refreshing,
models: entry
.models
.into_iter()
.map(|m| ProviderInventoryModelDto {
id: m.id,
name: m.name,
family: m.family,
context_limit: m.context_limit,
reasoning: m.reasoning,
recommended: m.recommended,
})
.collect(),
))
last_updated_at: entry.last_updated_at.map(|t| t.to_rfc3339()),
last_refresh_attempt_at: entry.last_refresh_attempt_at.map(|t| t.to_rfc3339()),
last_refresh_error: entry.last_refresh_error,
stale,
model_selection_hint: entry.model_selection_hint,
}
}
fn build_model_state(current_model: &str, inventory: &ProviderInventoryEntry) -> SessionModelState {
let mut available_models = inventory
.models
.iter()
.map(|model| ModelInfo::new(ModelId::new(model.id.as_str()), model.name.as_str()))
.collect::<Vec<_>>();
if !available_models
.iter()
.any(|model| model.model_id.0.as_ref() == current_model)
{
available_models.insert(
0,
ModelInfo::new(ModelId::new(current_model), current_model),
);
}
SessionModelState::new(ModelId::new(current_model), available_models)
}
async fn list_provider_entries(current_provider: Option<&str>) -> Vec<ProviderListEntry> {
@ -546,31 +583,25 @@ fn build_mode_state(current_mode: GooseMode) -> Result<SessionModeState, sacp::E
))
}
/// Build model state and config options eagerly from the canonical registry.
///
/// TODO: This trades speed for correctness — the canonical registry may not perfectly
/// match what the provider API returns (new models not yet in the registry, deprecated
/// models still listed, or locally-installed models for providers like Ollama). Consider
/// whether to reconcile with a live API call in the background.
async fn build_eager_config(
resolved: &Result<(String, goose::model::ModelConfig), String>,
fn should_refresh_inventory_for_session_init(entry: &ProviderInventoryEntry) -> bool {
entry.configured
&& entry.supports_refresh
&& (entry.last_updated_at.is_none() || ProviderInventoryService::is_stale(entry))
}
async fn build_eager_config_from_inventory(
provider_name: &str,
current_model: &str,
inventory: &ProviderInventoryEntry,
mode_state: &SessionModeState,
goose_session: &Session,
) -> (Option<SessionModelState>, Option<Vec<SessionConfigOption>>) {
let Ok((ref provider_name, ref mc)) = resolved else {
return (None, None);
};
let recommended = goose::providers::canonical::recommended_models_from_registry(provider_name);
let available: Vec<ModelInfo> = recommended
.iter()
.map(|name| ModelInfo::new(ModelId::new(&**name), &**name))
.collect();
let ms = SessionModelState::new(ModelId::new(mc.model_name.as_str()), available);
) -> (SessionModelState, Vec<SessionConfigOption>) {
let ms = build_model_state(current_model, inventory);
let provider_selection = session_provider_selection(goose_session);
let provider_options = build_provider_options(Some(provider_name.as_str())).await;
let provider_options = build_provider_options(Some(provider_name)).await;
let config_options =
build_config_options(mode_state, &ms, provider_selection, provider_options);
(Some(ms), Some(config_options))
(ms, config_options)
}
fn build_config_options(
@ -651,6 +682,7 @@ impl GooseAcpAgent {
session_manager.storage().clone(),
));
let permission_manager = Arc::new(PermissionManager::new(config_dir.clone()));
let provider_inventory = ProviderInventoryService::new(session_manager.storage().clone());
Ok(Self {
sessions: Arc::new(Mutex::new(HashMap::new())),
@ -664,6 +696,7 @@ impl GooseAcpAgent {
permission_manager,
goose_mode,
disable_session_naming,
provider_inventory,
})
}
@ -680,6 +713,125 @@ impl GooseAcpAgent {
(self.provider_factory)(provider_name.to_string(), model_config, extensions).await
}
async fn prepare_session_init_config(
&self,
resolved: &Result<(String, goose::model::ModelConfig), String>,
mode_state: &SessionModeState,
goose_session: &Session,
) -> (
Option<SessionModelState>,
Option<Vec<SessionConfigOption>>,
Option<Arc<dyn Provider>>,
) {
let Ok((provider_name, model_config)) = resolved else {
return (None, None, None);
};
let Some(mut inventory) = self
.provider_inventory
.entry_for_provider(provider_name)
.await
.ok()
.flatten()
else {
return (None, None, None);
};
let mut prebuilt_provider = None;
if should_refresh_inventory_for_session_init(&inventory) {
match self.load_config() {
Ok(config) => {
let ext_state = EnabledExtensionsState::extensions_or_default(
Some(&goose_session.extension_data),
&config,
);
match self
.create_provider(provider_name, model_config.clone(), ext_state)
.await
{
Ok(provider) => {
let provider_id = provider_name.clone();
prebuilt_provider = Some(provider.clone());
match self
.provider_inventory
.plan_refresh(std::slice::from_ref(&provider_id))
.await
{
Ok(plan) if plan.started.iter().any(|id| id == &provider_id) => {
match provider.fetch_recommended_models().await {
Ok(models) => {
if let Err(error) = self
.provider_inventory
.store_refreshed_models(&provider_id, &models)
.await
{
warn!(
provider = %provider_id,
error = %error,
"failed to store refreshed provider inventory during session init"
);
}
}
Err(error) => {
if let Err(store_error) = self
.provider_inventory
.store_refresh_error(
&provider_id,
error.to_string(),
)
.await
{
warn!(
provider = %provider_id,
error = %store_error,
"failed to store provider inventory refresh error during session init"
);
}
}
}
}
Ok(_) => {}
Err(error) => warn!(
provider = %provider_id,
error = %error,
"failed to plan provider inventory refresh during session init"
),
}
if let Ok(Some(refreshed_inventory)) = self
.provider_inventory
.entry_for_provider(provider_name)
.await
{
inventory = refreshed_inventory;
}
}
Err(error) => warn!(
provider = %provider_name,
error = %error,
"failed to initialize provider during synchronous inventory refresh"
),
}
}
Err(error) => warn!(
provider = %provider_name,
error = %error,
"failed to load config during synchronous inventory refresh"
),
}
}
let (model_state, config_options) = build_eager_config_from_inventory(
provider_name,
model_config.model_name.as_str(),
&inventory,
mode_state,
goose_session,
)
.await;
(Some(model_state), Some(config_options), prebuilt_provider)
}
fn spawn_agent_setup(
&self,
cx: &ConnectionTo<Client>,
@ -691,6 +843,7 @@ impl GooseAcpAgent {
goose_session,
mcp_servers,
resolved_provider,
prebuilt_provider,
} = req;
let goose_mode = goose_session.goose_mode;
@ -845,9 +998,12 @@ impl GooseAcpAgent {
Some(&goose_session.extension_data),
&config,
);
let provider = provider_factory(provider_name.to_string(), model_config, ext_state)
.await
.map_err(|e| e.to_string())?;
let provider = match prebuilt_provider {
Some(provider) => provider,
None => provider_factory(provider_name.to_string(), model_config, ext_state)
.await
.map_err(|e| e.to_string())?,
};
agent
.update_provider(provider.clone(), &goose_session.id)
.await
@ -1416,9 +1572,10 @@ impl GooseAcpAgent {
.as_ref()
.ok()
.map(|(_, mc)| build_usage_update(&goose_session, mc.context_limit()));
let (model_state, config_options) =
build_eager_config(&resolved, &mode_state, &goose_session).await;
let session_id = SessionId::new(thread_id.clone());
let (model_state, config_options, prebuilt_provider) = self
.prepare_session_init_config(&resolved, &mode_state, &goose_session)
.await;
self.spawn_agent_setup(
cx,
@ -1428,6 +1585,7 @@ impl GooseAcpAgent {
goose_session,
mcp_servers: args.mcp_servers,
resolved_provider: resolved.as_ref().ok().cloned(),
prebuilt_provider,
},
);
@ -1798,8 +1956,9 @@ impl GooseAcpAgent {
.as_ref()
.map(|mc| build_usage_update(&goose_session, mc.context_limit()))
});
let (model_state, config_options) =
build_eager_config(&resolved, &mode_state, &goose_session).await;
let (model_state, config_options, prebuilt_provider) = self
.prepare_session_init_config(&resolved, &mode_state, &goose_session)
.await;
self.spawn_agent_setup(
cx,
@ -1809,6 +1968,7 @@ impl GooseAcpAgent {
goose_session,
mcp_servers: args.mcp_servers,
resolved_provider: None,
prebuilt_provider,
},
);
@ -2116,10 +2276,21 @@ impl GooseAcpAgent {
let provider = agent.provider().await.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to get provider: {}", e))
})?;
let provider_name = provider.get_name().to_string();
let current_model = provider.get_model_config().model_name.clone();
let goose_mode = agent.goose_mode().await;
let model_state = build_model_state(&*provider).await?;
let inventory = self
.provider_inventory
.entry_for_provider(&provider_name)
.await
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
let Some(inventory) = inventory else {
return Err(sacp::Error::internal_error()
.data(format!("Unknown provider inventory: {}", provider_name)));
};
let model_state = build_model_state(current_model.as_str(), &inventory);
let mode_state = build_mode_state(goose_mode)?;
let provider_options = build_provider_options(Some(provider.get_name())).await;
let provider_options = build_provider_options(Some(&provider_name)).await;
let config_options = build_config_options(
&mode_state,
&model_state,
@ -2399,8 +2570,9 @@ impl GooseAcpAgent {
let mode_state = build_mode_state(self.goose_mode)?;
let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await;
let (model_state, config_options) =
build_eager_config(&resolved, &mode_state, &goose_session).await;
let (model_state, config_options, prebuilt_provider) = self
.prepare_session_init_config(&resolved, &mode_state, &goose_session)
.await;
self.spawn_agent_setup(
cx,
@ -2410,6 +2582,7 @@ impl GooseAcpAgent {
goose_session,
mcp_servers: args.mcp_servers,
resolved_provider: resolved.ok(),
prebuilt_provider,
},
);
@ -2677,58 +2850,82 @@ impl GooseAcpAgent {
Ok(GetProviderDetailsResponse { providers: entries })
}
#[custom_method(GetProviderModelsRequest)]
async fn on_get_provider_models(
#[custom_method(GetProviderInventoryRequest)]
async fn on_get_provider_inventory(
&self,
req: GetProviderModelsRequest,
) -> Result<GetProviderModelsResponse, sacp::Error> {
let config = self.load_config().ok();
let all = goose::providers::providers().await;
req: GetProviderInventoryRequest,
) -> Result<GetProviderInventoryResponse, sacp::Error> {
let entries = self
.provider_inventory
.entries(&req.provider_ids)
.await
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
Ok(GetProviderInventoryResponse {
entries: entries.into_iter().map(inventory_entry_to_dto).collect(),
})
}
let Some((metadata, _provider_type)) =
all.into_iter().find(|(m, _)| m.name == req.provider_name)
else {
return Err(sacp::Error::invalid_params()
.data(format!("Unknown provider: {}", req.provider_name)));
};
let is_configured = config
.as_ref()
.map(|c| {
metadata.config_keys.iter().all(|k| {
if !k.required {
return true;
}
if k.secret {
c.get_secret::<String>(&k.name).is_ok()
} else {
c.get_param::<String>(&k.name).is_ok()
}
})
})
.unwrap_or(false);
if !is_configured {
return Err(sacp::Error::invalid_params().data(format!(
"Provider '{}' is not configured",
req.provider_name
)));
#[custom_method(RefreshProviderInventoryRequest)]
async fn on_refresh_provider_inventory(
&self,
req: RefreshProviderInventoryRequest,
) -> Result<RefreshProviderInventoryResponse, sacp::Error> {
let refresh_plan = self
.provider_inventory
.plan_refresh(&req.provider_ids)
.await;
let refresh_plan =
refresh_plan.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
for provider_id in &refresh_plan.started {
let provider_inventory = self.provider_inventory.clone();
let provider_factory = Arc::clone(&self.provider_factory);
let provider_id = provider_id.clone();
tokio::spawn(async move {
let result = async {
let metadata = goose::providers::get_from_registry(&provider_id).await?;
let model_config =
goose::model::ModelConfig::new(&metadata.metadata().default_model)?
.with_canonical_limits(&provider_id);
let provider =
provider_factory(provider_id.clone(), model_config, Vec::new()).await?;
let models = provider.fetch_recommended_models().await?;
provider_inventory
.store_refreshed_models(&provider_id, &models)
.await
}
.await;
if let Err(error) = result {
let _ = provider_inventory
.store_refresh_error(&provider_id, error.to_string())
.await;
warn!(provider = %provider_id, error = %error, "provider inventory refresh failed");
}
});
}
let model_config = goose::model::ModelConfig::new(&metadata.default_model)
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?
.with_canonical_limits(&req.provider_name);
let provider = (self.provider_factory)(req.provider_name.clone(), model_config, Vec::new())
.await
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
let models = provider
.fetch_recommended_models()
.await
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
Ok(GetProviderModelsResponse { models })
Ok(RefreshProviderInventoryResponse {
started: refresh_plan.started,
skipped: refresh_plan
.skipped
.into_iter()
.map(|entry| RefreshProviderInventorySkipDto {
provider_id: entry.provider_id,
reason: match entry.reason {
RefreshSkipReason::UnknownProvider => {
RefreshProviderInventorySkipReasonDto::UnknownProvider
}
RefreshSkipReason::NotConfigured => {
RefreshProviderInventorySkipReasonDto::NotConfigured
}
RefreshSkipReason::DoesNotSupportRefresh => {
RefreshProviderInventorySkipReasonDto::DoesNotSupportRefresh
}
RefreshSkipReason::AlreadyRefreshing => {
RefreshProviderInventorySkipReasonDto::AlreadyRefreshing
}
},
})
.collect(),
})
}
#[custom_method(ReadConfigRequest)]
@ -3450,11 +3647,77 @@ impl HandleDispatchFrom<Client> for GooseAcpHandler {
return Ok(());
}
}
// Respond immediately using the current provider inventory snapshot.
let t_tail = std::time::Instant::now();
let (notification, config_options) = agent.build_config_update(&session_id).await?;
cx.send_notification(notification)?;
responder.respond(SetSessionConfigOptionResponse::new(config_options))?;
debug!(target: "perf", sid = %sid, ms = t_tail.elapsed().as_millis() as u64, "perf: set_config_option notification_and_respond");
debug!(target: "perf", sid = %sid, ms = t_tail.elapsed().as_millis() as u64, "perf: set_config_option inventory_respond");
let maybe_refresh = if config_id == "provider" {
let provider_id = value_id.0.to_string();
agent
.provider_inventory
.plan_refresh(std::slice::from_ref(&provider_id))
.await
.ok()
.filter(|plan| plan.started.iter().any(|id| id == &provider_id))
} else {
None
};
if maybe_refresh.is_some() {
let agent_bg = agent.clone();
let cx_bg = cx.clone();
let session_id_bg = session_id.clone();
let sid_bg = sid.clone();
tokio::spawn(async move {
let t_bg = std::time::Instant::now();
let refreshed = async {
let session_agent =
agent_bg.get_session_agent(&session_id_bg.0, None).await?;
let provider = session_agent
.provider()
.await
.map_err(|e| anyhow::anyhow!(e.to_string()))?;
let provider_name = provider.get_name().to_string();
let models = provider
.fetch_recommended_models()
.await
.map_err(|e| anyhow::anyhow!(e.to_string()))?;
agent_bg
.provider_inventory
.store_refreshed_models(&provider_name, &models)
.await?;
agent_bg
.build_config_update(&session_id_bg)
.await
.map_err(|e| anyhow::anyhow!(e.to_string()))
}
.await;
match refreshed {
Ok((fresh_notification, _)) => {
let _ = cx_bg.send_notification(fresh_notification);
debug!(target: "perf", sid = %sid_bg, ms = t_bg.elapsed().as_millis() as u64, "perf: set_config_option background_refresh done");
}
Err(e) => {
if let Ok(session_agent) =
agent_bg.get_session_agent(&session_id_bg.0, None).await
{
if let Ok(provider) = session_agent.provider().await {
let provider_name = provider.get_name().to_string();
let _ = agent_bg
.provider_inventory
.store_refresh_error(&provider_name, e.to_string())
.await;
}
}
debug!(target: "perf", sid = %sid_bg, error = %e, ms = t_bg.elapsed().as_millis() as u64, "perf: set_config_option background_refresh failed");
}
}
});
}
debug!(target: "perf", sid = %sid, ms = t_handler.elapsed().as_millis() as u64, config_id = %config_id, "perf: set_config_option done");
Ok(())
}
@ -3597,7 +3860,6 @@ pub async fn run(builtins: Vec<String>) -> Result<()> {
mod tests {
use super::*;
use goose::conversation::message::{ToolRequest, ToolResponse};
use goose::providers::errors::ProviderError;
use rmcp::model::{CallToolRequestParams, Content as RmcpContent};
use sacp::schema::{
EnvVariable, HttpHeader, McpServer, McpServerHttp, McpServerSse, McpServerStdio,
@ -3787,61 +4049,48 @@ print(\"hello, world\")
assert_eq!(outcome_to_confirmation(&input), expected);
}
struct MockModelProvider {
models: Result<Vec<String>, ProviderError>,
}
#[async_trait::async_trait]
impl Provider for MockModelProvider {
fn get_name(&self) -> &str {
"mock"
}
async fn stream(
&self,
_model_config: &goose::model::ModelConfig,
_session_id: &str,
_system: &str,
_messages: &[goose::conversation::message::Message],
_tools: &[rmcp::model::Tool],
) -> Result<goose::providers::base::MessageStream, ProviderError> {
unimplemented!()
}
fn get_model_config(&self) -> goose::model::ModelConfig {
goose::model::ModelConfig::new_or_fail("unused")
}
async fn fetch_recommended_models(&self) -> Result<Vec<String>, ProviderError> {
self.models.clone()
}
}
#[test_case(
Ok(vec!["model-a".into(), "model-b".into()])
=> Ok(SessionModelState::new(
vec!["model-a".into(), "model-b".into()]
=> SessionModelState::new(
ModelId::new("unused"),
vec![ModelInfo::new(ModelId::new("model-a"), "model-a"),
vec![ModelInfo::new(ModelId::new("unused"), "unused"),
ModelInfo::new(ModelId::new("model-a"), "model-a"),
ModelInfo::new(ModelId::new("model-b"), "model-b")],
))
)
; "returns current and available models"
)]
#[test_case(
Ok(vec![])
=> Ok(SessionModelState::new(ModelId::new("unused"), vec![]))
vec![]
=> SessionModelState::new(
ModelId::new("unused"),
vec![ModelInfo::new(ModelId::new("unused"), "unused")],
)
; "empty model list"
)]
#[test_case(
Err(ProviderError::ExecutionError("fail".into()))
=> Err(sacp::Error::internal_error().data("Execution error: fail".to_string()))
; "fetch error propagates"
)]
#[tokio::test]
async fn test_build_model_state(
models: Result<Vec<String>, ProviderError>,
) -> Result<SessionModelState, sacp::Error> {
let provider = MockModelProvider { models };
build_model_state(&provider).await
fn test_build_model_state(models: Vec<String>) -> SessionModelState {
let inventory = ProviderInventoryEntry {
provider_id: "mock".to_string(),
provider_name: "Mock".to_string(),
configured: true,
supports_refresh: true,
refreshing: false,
models: models
.into_iter()
.map(|id| goose::providers::inventory::InventoryModel {
name: id.clone(),
id,
family: None,
context_limit: None,
reasoning: None,
recommended: false,
})
.collect(),
last_updated_at: None,
last_refresh_attempt_at: None,
last_refresh_error: None,
model_selection_hint: None,
};
build_model_state("unused", &inventory)
}
fn json_object(pairs: Vec<(&str, serde_json::Value)>) -> rmcp::model::JsonObject {

View file

@ -257,20 +257,6 @@ pub struct GetProviderDetailsResponse {
pub providers: Vec<ProviderDetailEntry>,
}
/// Fetch the full list of models available for a specific provider.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcRequest)]
#[request(method = "_goose/providers/models", response = GetProviderModelsResponse)]
#[serde(rename_all = "camelCase")]
pub struct GetProviderModelsRequest {
pub provider_name: String,
}
/// Provider models response.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
pub struct GetProviderModelsResponse {
pub models: Vec<String>,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ProviderDetailEntry {
@ -370,6 +356,120 @@ pub struct DictationConfigResponse {
pub providers: HashMap<String, DictationProviderStatusEntry>,
}
/// Read per-provider inventory. Always returns immediately from stored state.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcRequest)]
#[request(
method = "_goose/providers/inventory",
response = GetProviderInventoryResponse
)]
#[serde(rename_all = "camelCase")]
pub struct GetProviderInventoryRequest {
/// Only return entries for these providers. Empty means all.
#[serde(default)]
pub provider_ids: Vec<String>,
}
/// Provider inventory response.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
pub struct GetProviderInventoryResponse {
pub entries: Vec<ProviderInventoryEntryDto>,
}
/// Trigger a background refresh of provider inventories.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcRequest)]
#[request(
method = "_goose/providers/inventory/refresh",
response = RefreshProviderInventoryResponse
)]
#[serde(rename_all = "camelCase")]
pub struct RefreshProviderInventoryRequest {
/// Which providers to refresh. Empty means all known providers.
#[serde(default)]
pub provider_ids: Vec<String>,
}
/// Refresh acknowledgement.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
#[serde(rename_all = "camelCase")]
pub struct RefreshProviderInventoryResponse {
/// Which providers will be refreshed.
pub started: Vec<String>,
/// Which providers were skipped and why.
#[serde(default)]
pub skipped: Vec<RefreshProviderInventorySkipDto>,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct RefreshProviderInventorySkipDto {
pub provider_id: String,
pub reason: RefreshProviderInventorySkipReasonDto,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum RefreshProviderInventorySkipReasonDto {
#[default]
UnknownProvider,
NotConfigured,
DoesNotSupportRefresh,
AlreadyRefreshing,
}
/// A single model in provider inventory.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ProviderInventoryModelDto {
/// Model identifier as the provider knows it.
pub id: String,
/// Human-readable display name.
pub name: String,
/// Model family for grouping in UI.
#[serde(skip_serializing_if = "Option::is_none")]
pub family: Option<String>,
/// Context window size in tokens.
#[serde(skip_serializing_if = "Option::is_none")]
pub context_limit: Option<usize>,
/// Whether the model supports reasoning/extended thinking.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<bool>,
/// Whether this model should appear in the compact recommended picker.
#[serde(default)]
pub recommended: bool,
}
/// Provider inventory entry.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ProviderInventoryEntryDto {
/// Provider identifier.
pub provider_id: String,
/// Human-readable provider name.
pub provider_name: String,
/// Whether Goose has enough configuration to use this provider.
pub configured: bool,
/// Whether this provider supports background inventory refresh.
pub supports_refresh: bool,
/// Whether a refresh is currently in flight.
pub refreshing: bool,
/// The list of available models.
pub models: Vec<ProviderInventoryModelDto>,
/// When this entry was last successfully refreshed (ISO 8601).
#[serde(skip_serializing_if = "Option::is_none")]
pub last_updated_at: Option<String>,
/// When a refresh was most recently attempted (ISO 8601).
#[serde(skip_serializing_if = "Option::is_none")]
pub last_refresh_attempt_at: Option<String>,
/// The last refresh failure message, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub last_refresh_error: Option<String>,
/// Whether we believe this data may be outdated.
pub stale: bool,
/// Guidance message shown when this provider manages its own model selection externally.
#[serde(skip_serializing_if = "Option::is_none")]
pub model_selection_hint: Option<String>,
}
/// Empty success response for operations that return no data.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
pub struct EmptyResponse {}

View file

@ -416,6 +416,11 @@ mod tests {
use std::sync::Arc;
let tmp_dir = tempfile::tempdir().unwrap();
let temp_root = tmp_dir.path().display().to_string();
let _guard = env_lock::lock_env([
("HOME", Some(temp_root.as_str())),
("GOOSE_PATH_ROOT", Some(temp_root.as_str())),
]);
let session_manager = Arc::new(SessionManager::new(tmp_dir.path().to_path_buf()));
let session = session_manager
.create_session(

View file

@ -2,6 +2,7 @@ use crate::config::paths::Paths;
use crate::config::Config;
use crate::providers::anthropic::AnthropicProvider;
use crate::providers::base::{ModelInfo, ProviderType};
use crate::providers::inventory::declarative_inventory_identity;
use crate::providers::ollama::OllamaProvider;
use crate::providers::openai::OpenAiProvider;
use anyhow::Result;
@ -460,38 +461,59 @@ pub fn register_declarative_provider(
match config.engine {
ProviderEngine::OpenAI => {
let captured = config.clone();
registry.register_with_name::<OpenAiProvider, _>(
let identity_config = config.clone();
registry.register_with_name::<OpenAiProvider, _, _>(
&config,
provider_type,
config.dynamic_models.unwrap_or(false),
move |model| {
let mut cfg = captured.clone();
resolve_config(&mut cfg)?;
OpenAiProvider::from_custom_config(model, cfg)
},
move || {
let mut cfg = identity_config.clone();
resolve_config(&mut cfg)?;
declarative_inventory_identity(&cfg)
},
);
}
ProviderEngine::Ollama => {
let captured = config.clone();
registry.register_with_name::<OllamaProvider, _>(
let identity_config = config.clone();
registry.register_with_name::<OllamaProvider, _, _>(
&config,
provider_type,
config.dynamic_models.unwrap_or(false),
move |model| {
let mut cfg = captured.clone();
resolve_config(&mut cfg)?;
OllamaProvider::from_custom_config(model, cfg)
},
move || {
let mut cfg = identity_config.clone();
resolve_config(&mut cfg)?;
declarative_inventory_identity(&cfg)
},
);
}
ProviderEngine::Anthropic => {
let captured = config.clone();
registry.register_with_name::<AnthropicProvider, _>(
let identity_config = config.clone();
registry.register_with_name::<AnthropicProvider, _, _>(
&config,
provider_type,
config.dynamic_models.unwrap_or(false),
move |model| {
let mut cfg = captured.clone();
resolve_config(&mut cfg)?;
AnthropicProvider::from_custom_config(model, cfg)
},
move || {
let mut cfg = identity_config.clone();
resolve_config(&mut cfg)?;
declarative_inventory_identity(&cfg)
},
);
}
}

View file

@ -0,0 +1,18 @@
use crate::config::search_path::SearchPaths;
use crate::providers::inventory::InventoryIdentityInput;
use anyhow::Result;
use std::path::PathBuf;
pub fn acp_adapter_installed(command: &str) -> bool {
resolve_acp_command(command).is_ok()
}
pub fn acp_inventory_identity(provider_id: &str, command: &str) -> Result<InventoryIdentityInput> {
let resolved_command = resolve_acp_command(command)?;
Ok(InventoryIdentityInput::new(provider_id, provider_id)
.with_public("command", resolved_command.display().to_string()))
}
fn resolve_acp_command(command: &str) -> Result<PathBuf> {
SearchPaths::builder().with_npm().resolve(command)
}

View file

@ -9,7 +9,9 @@ use crate::acp::{
use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;
const AMP_ACP_PROVIDER_NAME: &str = "amp-acp";
const AMP_ACP_DOC_URL: &str = "https://ampcode.com";
@ -37,6 +39,7 @@ impl ProviderDef for AmpAcpProvider {
"Set in your goose config file (`~/.config/goose/config.yaml` on macOS/Linux):\n GOOSE_PROVIDER: amp-acp\n GOOSE_MODEL: current",
"Restart goose for changes to take effect",
])
.with_model_selection_hint("Use the Amp CLI to configure models")
}
fn from_env(
@ -49,10 +52,12 @@ impl ProviderDef for AmpAcpProvider {
let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto);
let mode_mapping = HashMap::from([
(GooseMode::Auto, "auto".to_string()),
(GooseMode::Approve, "approve".to_string()),
(GooseMode::SmartApprove, "smart-approve".to_string()),
(GooseMode::Chat, "chat".to_string()),
// "bypass" skips confirmations, closest to autonomous mode.
(GooseMode::Auto, "bypass".to_string()),
// "default" prompts before risky actions.
(GooseMode::Approve, "default".to_string()),
(GooseMode::SmartApprove, "default".to_string()),
(GooseMode::Chat, "default".to_string()),
]);
let provider_config = AcpProviderConfig {
@ -71,4 +76,16 @@ impl ProviderDef for AmpAcpProvider {
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
})
}
fn supports_inventory_refresh() -> bool {
false
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
acp_inventory_identity(AMP_ACP_PROVIDER_NAME, AMP_ACP_BINARY)
}
fn inventory_configured() -> bool {
acp_adapter_installed(AMP_ACP_BINARY)
}
}

View file

@ -14,6 +14,7 @@ use super::errors::ProviderError;
use super::formats::anthropic::{
create_request, response_to_streaming_message, thinking_type, ThinkingType,
};
use super::inventory::{config_secret_value, serialize_string_map, InventoryIdentityInput};
use super::openai_compatible::handle_status_openai_compat;
use super::openai_compatible::map_http_error_to_provider_error;
use super::retry::ProviderRetry;
@ -235,6 +236,33 @@ impl ProviderDef for AnthropicProvider {
) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(Self::from_env(model))
}
fn supports_inventory_refresh() -> bool {
true
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
let config = crate::config::Config::global();
let mut identity =
InventoryIdentityInput::new(ANTHROPIC_PROVIDER_NAME, ANTHROPIC_PROVIDER_NAME)
.with_public(
"host",
config
.get_param::<String>("ANTHROPIC_HOST")
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
);
if let Some(api_key) = config_secret_value(config, "ANTHROPIC_API_KEY") {
identity = identity.with_secret("api_key", api_key);
}
if let Ok(headers) = config
.get_secret::<std::collections::HashMap<String, String>>("ANTHROPIC_CUSTOM_HEADERS")
{
identity = identity.with_secret("headers", serialize_string_map(&headers)?);
}
Ok(identity)
}
}
#[async_trait]

View file

@ -6,9 +6,10 @@ use serde::{Deserialize, Serialize};
use super::canonical::{map_to_canonical_model, CanonicalModelRegistry};
use super::errors::ProviderError;
use super::inventory::{default_inventory_identity, InventoryIdentityInput};
use super::retry::RetryConfig;
use crate::config::base::ConfigValue;
use crate::config::{ExtensionConfig, GooseMode};
use crate::config::{Config, ExtensionConfig, GooseMode};
use crate::conversation::message::{Message, MessageContent};
use crate::conversation::Conversation;
use crate::model::ModelConfig;
@ -179,6 +180,9 @@ pub struct ProviderMetadata {
/// step-by-step instructions for set up providers eg: api key
#[serde(default)]
pub setup_steps: Vec<String>,
/// Hint shown in the model picker when this provider manages its own model selection.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_selection_hint: Option<String>,
}
impl ProviderMetadata {
@ -212,6 +216,7 @@ impl ProviderMetadata {
model_doc_link: model_doc_link.to_string(),
config_keys,
setup_steps: vec![],
model_selection_hint: None,
}
}
@ -233,6 +238,7 @@ impl ProviderMetadata {
model_doc_link: model_doc_link.to_string(),
config_keys,
setup_steps: vec![],
model_selection_hint: None,
}
}
@ -246,6 +252,7 @@ impl ProviderMetadata {
model_doc_link: "".to_string(),
config_keys: vec![],
setup_steps: vec![],
model_selection_hint: None,
}
}
@ -253,6 +260,11 @@ impl ProviderMetadata {
self.setup_steps = steps.into_iter().map(|s| s.to_string()).collect();
self
}
pub fn with_model_selection_hint(mut self, hint: &str) -> Self {
self.model_selection_hint = Some(hint.to_string());
self
}
}
/// Configuration key metadata for provider setup
@ -492,6 +504,34 @@ pub trait ProviderDef: Send + Sync {
) -> BoxFuture<'static, Result<Self::Provider>>
where
Self: Sized;
fn supports_inventory_refresh() -> bool
where
Self: Sized,
{
false
}
fn inventory_identity() -> Result<InventoryIdentityInput>
where
Self: Sized,
{
let metadata = Self::metadata();
Ok(default_inventory_identity(
&metadata.name,
&metadata.name,
&metadata.config_keys,
Config::global(),
))
}
fn inventory_configured() -> bool
where
Self: Sized,
{
let metadata = Self::metadata();
super::inventory::default_inventory_configured(&metadata.config_keys, Config::global())
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
@ -588,7 +628,7 @@ pub trait Provider: Send + Sync {
false
}
/// Fetch models filtered by canonical registry and usability
/// Fetch inventory models filtered by canonical registry and usability.
async fn fetch_recommended_models(&self) -> Result<Vec<String>, ProviderError> {
let all_models = self.fetch_supported_models().await?;
@ -637,15 +677,15 @@ pub trait Provider: Send + Sync {
(None, None) => a.0.cmp(&b.0),
});
let recommended_models: Vec<String> = models_with_dates
let inventory_models: Vec<String> = models_with_dates
.into_iter()
.map(|(name, _)| name)
.collect();
if recommended_models.is_empty() {
if inventory_models.is_empty() {
Ok(all_models)
} else {
Ok(recommended_models)
Ok(inventory_models)
}
}

View file

@ -9,7 +9,9 @@ use crate::acp::{
use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;
const CLAUDE_ACP_PROVIDER_NAME: &str = "claude-acp";
const CLAUDE_ACP_DOC_URL: &str = "https://github.com/zed-industries/claude-agent-acp";
@ -78,4 +80,16 @@ impl ProviderDef for ClaudeAcpProvider {
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
})
}
fn supports_inventory_refresh() -> bool {
true
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
acp_inventory_identity(CLAUDE_ACP_PROVIDER_NAME, CLAUDE_ACP_BINARY)
}
fn inventory_configured() -> bool {
acp_adapter_installed(CLAUDE_ACP_BINARY)
}
}

View file

@ -9,7 +9,9 @@ use crate::acp::{
use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;
const CODEX_ACP_PROVIDER_NAME: &str = "codex-acp";
const CODEX_ACP_DOC_URL: &str = "https://github.com/zed-industries/codex-acp";
@ -98,6 +100,18 @@ impl ProviderDef for CodexAcpProvider {
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
})
}
fn supports_inventory_refresh() -> bool {
true
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
acp_inventory_identity(CODEX_ACP_PROVIDER_NAME, CODEX_ACP_PROVIDER_NAME)
}
fn inventory_configured() -> bool {
acp_adapter_installed(CODEX_ACP_PROVIDER_NAME)
}
}
// Codex sandbox scope determines what needs approval: operations within the

View file

@ -9,7 +9,9 @@ use crate::acp::{
use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;
const COPILOT_ACP_PROVIDER_NAME: &str = "copilot-acp";
const COPILOT_ACP_DOC_URL: &str = "https://github.com/github/copilot-cli";
@ -84,4 +86,16 @@ impl ProviderDef for CopilotAcpProvider {
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
})
}
fn supports_inventory_refresh() -> bool {
true
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
acp_inventory_identity(COPILOT_ACP_PROVIDER_NAME, COPILOT_ACP_BINARY)
}
fn inventory_configured() -> bool {
acp_adapter_installed(COPILOT_ACP_BINARY)
}
}

View file

@ -340,6 +340,10 @@ impl ProviderDef for DatabricksProvider {
) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(Self::from_env(model))
}
fn supports_inventory_refresh() -> bool {
true
}
}
#[async_trait]

View file

@ -1100,12 +1100,16 @@ mod tests {
fn test_create_request_enabled_thinking_with_budget() -> anyhow::Result<()> {
let _guard = env_lock::lock_env([
("CLAUDE_THINKING_TYPE", None::<&str>),
("CLAUDE_THINKING_ENABLED", Some("1")),
("CLAUDE_THINKING_ENABLED", None::<&str>),
("CLAUDE_THINKING_BUDGET", Some("10000")),
]);
let mut model_config = ModelConfig::new_or_fail("databricks-claude-3-7-sonnet");
model_config.max_tokens = Some(4096);
model_config = model_config.with_request_params(Some(std::collections::HashMap::from([(
"thinking_type".to_string(),
json!("enabled"),
)])));
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;

View file

@ -149,6 +149,10 @@ pub async fn get_from_registry(name: &str) -> Result<ProviderEntry> {
.cloned()
}
pub async fn inventory_identity(name: &str) -> Result<super::inventory::InventoryIdentityInput> {
get_from_registry(name).await?.inventory_identity()
}
pub async fn create(
name: &str,
model: ModelConfig,

View file

@ -0,0 +1,970 @@
use super::base::{ConfigKey, ModelInfo};
use super::canonical::{map_provider_name, map_to_canonical_model, CanonicalModelRegistry};
use crate::config::declarative_providers::{DeclarativeProviderConfig, ProviderEngine};
use crate::config::Config;
use crate::session::session_manager::SessionStorage;
use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use sqlx::{Pool, Row, Sqlite, Transaction};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
const STALE_AFTER_HOURS: i64 = 24;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProviderInventoryEntry {
pub provider_id: String,
pub provider_name: String,
pub configured: bool,
pub supports_refresh: bool,
pub refreshing: bool,
pub models: Vec<InventoryModel>,
pub last_updated_at: Option<DateTime<Utc>>,
pub last_refresh_attempt_at: Option<DateTime<Utc>>,
pub last_refresh_error: Option<String>,
pub model_selection_hint: Option<String>,
}
/// Families whose latest model should be surfaced in the compact picker.
/// Each entry is matched against the `family` field of enriched models.
const RECOMMENDED_FAMILIES: &[&str] = &[
"claude-opus",
"claude-sonnet",
"gpt",
"gpt-mini",
"glm",
"gemini-pro",
"gemini-flash",
"gemma",
];
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InventoryModel {
pub id: String,
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub family: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub context_limit: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<bool>,
/// Whether this model should appear in the compact recommended picker.
pub recommended: bool,
}
#[derive(Debug, Clone)]
pub struct InventoryIdentity {
pub provider_id: String,
pub provider_family: String,
pub inventory_key: String,
}
#[derive(Debug, Clone, Default)]
pub struct InventoryIdentityInput {
pub provider_id: String,
pub provider_family: String,
pub public_inputs: BTreeMap<String, String>,
pub secret_inputs: BTreeMap<String, String>,
}
impl InventoryIdentityInput {
pub fn new(
provider_id: impl Into<String>,
provider_family: impl Into<String>,
) -> InventoryIdentityInput {
InventoryIdentityInput {
provider_id: provider_id.into(),
provider_family: provider_family.into(),
public_inputs: BTreeMap::new(),
secret_inputs: BTreeMap::new(),
}
}
pub fn with_public(
mut self,
key: impl Into<String>,
value: impl Into<String>,
) -> InventoryIdentityInput {
self.public_inputs.insert(key.into(), value.into());
self
}
pub fn with_secret(
mut self,
key: impl Into<String>,
value: impl Into<String>,
) -> InventoryIdentityInput {
self.secret_inputs.insert(key.into(), value.into());
self
}
pub fn into_identity(self) -> Result<InventoryIdentity> {
let InventoryIdentityInput {
provider_id,
provider_family,
public_inputs,
secret_inputs,
} = self;
let payload = serde_json::json!({
"provider_family": provider_family,
"public_inputs": public_inputs,
"secret_inputs": secret_inputs,
});
let digest = Sha256::digest(serde_json::to_vec(&payload)?);
Ok(InventoryIdentity {
provider_id,
provider_family,
inventory_key: format!("{digest:x}"),
})
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RefreshSkipReason {
UnknownProvider,
NotConfigured,
DoesNotSupportRefresh,
AlreadyRefreshing,
}
#[derive(Debug, Clone)]
pub struct RefreshSkip {
pub provider_id: String,
pub reason: RefreshSkipReason,
}
#[derive(Debug, Clone, Default)]
pub struct RefreshPlan {
pub started: Vec<String>,
pub skipped: Vec<RefreshSkip>,
}
#[derive(Clone)]
pub struct ProviderInventoryService {
storage: Arc<SessionStorage>,
refreshing_keys: Arc<RwLock<HashSet<String>>>,
}
#[derive(Debug, Clone)]
struct InventorySnapshot {
models: Vec<InventoryModel>,
last_updated_at: Option<DateTime<Utc>>,
last_refresh_attempt_at: Option<DateTime<Utc>>,
last_refresh_error: Option<String>,
}
#[derive(Debug, Clone)]
struct ProviderDescriptor {
provider_id: String,
provider_name: String,
identity: InventoryIdentity,
configured: bool,
supports_refresh: bool,
static_models: Vec<ModelInfo>,
model_selection_hint: Option<String>,
}
impl ProviderInventoryService {
pub fn new(storage: Arc<SessionStorage>) -> ProviderInventoryService {
ProviderInventoryService {
storage,
refreshing_keys: Arc::new(RwLock::new(HashSet::new())),
}
}
pub async fn entry_for_provider(
&self,
provider_id: &str,
) -> Result<Option<ProviderInventoryEntry>> {
let Some(descriptor) = self.describe_provider(provider_id).await? else {
return Ok(None);
};
let snapshot = self.read_snapshot(&descriptor.identity).await?;
let refreshing = self
.refreshing_keys
.read()
.await
.contains(&descriptor.identity.inventory_key);
let models = inventory_models_from_snapshot(
snapshot.as_ref(),
&descriptor.identity.provider_family,
&descriptor.static_models,
);
Ok(Some(ProviderInventoryEntry {
provider_id: descriptor.provider_id,
provider_name: descriptor.provider_name,
configured: descriptor.configured,
supports_refresh: descriptor.supports_refresh,
refreshing,
models,
last_updated_at: snapshot
.as_ref()
.and_then(|snapshot| snapshot.last_updated_at),
last_refresh_attempt_at: snapshot
.as_ref()
.and_then(|snapshot| snapshot.last_refresh_attempt_at),
last_refresh_error: snapshot.and_then(|snapshot| snapshot.last_refresh_error),
model_selection_hint: descriptor.model_selection_hint,
}))
}
pub async fn entries(&self, provider_ids: &[String]) -> Result<Vec<ProviderInventoryEntry>> {
let ids = self.resolve_provider_ids(provider_ids).await;
let mut entries = Vec::with_capacity(ids.len());
for provider_id in ids {
if let Some(entry) = self.entry_for_provider(&provider_id).await? {
entries.push(entry);
}
}
Ok(entries)
}
pub async fn plan_refresh(&self, provider_ids: &[String]) -> Result<RefreshPlan> {
let ids = self.resolve_provider_ids(provider_ids).await;
let mut plan = RefreshPlan::default();
for provider_id in ids {
let Some(descriptor) = self.describe_provider(&provider_id).await? else {
plan.skipped.push(RefreshSkip {
provider_id,
reason: RefreshSkipReason::UnknownProvider,
});
continue;
};
if !descriptor.supports_refresh {
plan.skipped.push(RefreshSkip {
provider_id: descriptor.provider_id,
reason: RefreshSkipReason::DoesNotSupportRefresh,
});
continue;
}
if !descriptor.configured {
plan.skipped.push(RefreshSkip {
provider_id: descriptor.provider_id,
reason: RefreshSkipReason::NotConfigured,
});
continue;
}
let mut refreshing_keys = self.refreshing_keys.write().await;
if refreshing_keys.contains(&descriptor.identity.inventory_key) {
plan.skipped.push(RefreshSkip {
provider_id: descriptor.provider_id,
reason: RefreshSkipReason::AlreadyRefreshing,
});
continue;
}
refreshing_keys.insert(descriptor.identity.inventory_key.clone());
drop(refreshing_keys);
self.mark_refresh_started(&descriptor.identity).await?;
plan.started.push(descriptor.provider_id);
}
Ok(plan)
}
pub async fn store_refreshed_models(
&self,
provider_id: &str,
model_ids: &[String],
) -> Result<()> {
let descriptor = self.require_provider(provider_id).await?;
let models =
enrich_model_ids_with_canonical(&descriptor.identity.provider_family, model_ids);
let now = Utc::now();
let pool = self.storage.pool().await?;
let mut tx = pool.begin().await?;
sqlx::query(
r#"
INSERT INTO provider_inventory_entries (
inventory_key,
provider_id,
provider_family,
last_updated_at,
last_refresh_attempt_at,
last_refresh_error,
updated_at
) VALUES (?, ?, ?, ?, ?, NULL, CURRENT_TIMESTAMP)
ON CONFLICT(inventory_key) DO UPDATE SET
provider_id = excluded.provider_id,
provider_family = excluded.provider_family,
last_updated_at = excluded.last_updated_at,
last_refresh_attempt_at = excluded.last_refresh_attempt_at,
last_refresh_error = NULL,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&descriptor.identity.inventory_key)
.bind(&descriptor.identity.provider_id)
.bind(&descriptor.identity.provider_family)
.bind(now.to_rfc3339())
.bind(now.to_rfc3339())
.execute(&mut *tx)
.await?;
sqlx::query("DELETE FROM provider_inventory_models WHERE inventory_key = ?")
.bind(&descriptor.identity.inventory_key)
.execute(&mut *tx)
.await?;
for (ordinal, model) in models.iter().enumerate() {
sqlx::query(
r#"
INSERT INTO provider_inventory_models (
inventory_key,
ordinal,
model_id,
name,
family,
context_limit,
reasoning,
recommended
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&descriptor.identity.inventory_key)
.bind(i64::try_from(ordinal)?)
.bind(&model.id)
.bind(&model.name)
.bind(&model.family)
.bind(model.context_limit.map(i64::try_from).transpose()?)
.bind(model.reasoning)
.bind(model.recommended)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
self.refreshing_keys
.write()
.await
.remove(&descriptor.identity.inventory_key);
Ok(())
}
pub async fn store_refresh_error(
&self,
provider_id: &str,
error: impl Into<String>,
) -> Result<()> {
let descriptor = self.require_provider(provider_id).await?;
let error = error.into();
let existing = self.read_snapshot(&descriptor.identity).await?;
sqlx::query(
r#"
INSERT INTO provider_inventory_entries (
inventory_key,
provider_id,
provider_family,
last_updated_at,
last_refresh_attempt_at,
last_refresh_error,
updated_at
) VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(inventory_key) DO UPDATE SET
provider_id = excluded.provider_id,
provider_family = excluded.provider_family,
last_updated_at = excluded.last_updated_at,
last_refresh_attempt_at = excluded.last_refresh_attempt_at,
last_refresh_error = excluded.last_refresh_error,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&descriptor.identity.inventory_key)
.bind(&descriptor.identity.provider_id)
.bind(&descriptor.identity.provider_family)
.bind(existing.and_then(|snapshot| snapshot.last_updated_at.map(|time| time.to_rfc3339())))
.bind(Utc::now().to_rfc3339())
.bind(error)
.execute(self.storage.pool().await?)
.await?;
self.refreshing_keys
.write()
.await
.remove(&descriptor.identity.inventory_key);
Ok(())
}
pub fn is_stale(entry: &ProviderInventoryEntry) -> bool {
let Some(last_updated_at) = entry.last_updated_at else {
return false;
};
entry.supports_refresh && Utc::now() - last_updated_at > Duration::hours(STALE_AFTER_HOURS)
}
async fn describe_provider(&self, provider_id: &str) -> Result<Option<ProviderDescriptor>> {
let entry = match crate::providers::get_from_registry(provider_id).await {
Ok(entry) => entry,
Err(_) => return Ok(None),
};
let metadata = entry.metadata().clone();
let identity = crate::providers::inventory_identity(provider_id)
.await
.unwrap_or_else(|_| fallback_inventory_identity(provider_id))
.into_identity()?;
Ok(Some(ProviderDescriptor {
provider_id: metadata.name.clone(),
provider_name: metadata.display_name.clone(),
identity,
configured: entry.inventory_configured(),
supports_refresh: entry.supports_inventory_refresh(),
static_models: metadata.known_models,
model_selection_hint: metadata.model_selection_hint,
}))
}
async fn require_provider(&self, provider_id: &str) -> Result<ProviderDescriptor> {
self.describe_provider(provider_id)
.await?
.ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", provider_id))
}
async fn mark_refresh_started(&self, identity: &InventoryIdentity) -> Result<()> {
let existing = self.read_snapshot(identity).await?;
sqlx::query(
r#"
INSERT INTO provider_inventory_entries (
inventory_key,
provider_id,
provider_family,
last_updated_at,
last_refresh_attempt_at,
last_refresh_error,
updated_at
) VALUES (?, ?, ?, ?, ?, NULL, CURRENT_TIMESTAMP)
ON CONFLICT(inventory_key) DO UPDATE SET
provider_id = excluded.provider_id,
provider_family = excluded.provider_family,
last_updated_at = excluded.last_updated_at,
last_refresh_attempt_at = excluded.last_refresh_attempt_at,
last_refresh_error = NULL,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&identity.inventory_key)
.bind(&identity.provider_id)
.bind(&identity.provider_family)
.bind(existing.and_then(|snapshot| snapshot.last_updated_at.map(|time| time.to_rfc3339())))
.bind(Utc::now().to_rfc3339())
.execute(self.storage.pool().await?)
.await?;
Ok(())
}
async fn read_snapshot(
&self,
identity: &InventoryIdentity,
) -> Result<Option<InventorySnapshot>> {
let pool = self.storage.pool().await?;
let entry = sqlx::query(
r#"
SELECT last_updated_at, last_refresh_attempt_at, last_refresh_error
FROM provider_inventory_entries
WHERE inventory_key = ?
"#,
)
.bind(&identity.inventory_key)
.fetch_optional(pool)
.await?;
let Some(entry) = entry else {
return Ok(None);
};
let last_updated_at = parse_optional_datetime(entry.try_get("last_updated_at")?)?;
let last_refresh_attempt_at =
parse_optional_datetime(entry.try_get("last_refresh_attempt_at")?)?;
let last_refresh_error = entry.try_get("last_refresh_error")?;
let rows = sqlx::query(
r#"
SELECT model_id, name, family, context_limit, reasoning, recommended
FROM provider_inventory_models
WHERE inventory_key = ?
ORDER BY ordinal
"#,
)
.bind(&identity.inventory_key)
.fetch_all(pool)
.await?;
let models = rows
.into_iter()
.map(|row| {
Ok(InventoryModel {
id: row.try_get("model_id")?,
name: row.try_get("name")?,
family: row.try_get("family")?,
context_limit: row
.try_get::<Option<i64>, _>("context_limit")?
.map(usize::try_from)
.transpose()?,
reasoning: row.try_get("reasoning")?,
recommended: row
.try_get::<Option<bool>, _>("recommended")?
.unwrap_or(false),
})
})
.collect::<Result<Vec<_>, anyhow::Error>>()?;
Ok(Some(InventorySnapshot {
models,
last_updated_at,
last_refresh_attempt_at,
last_refresh_error,
}))
}
async fn resolve_provider_ids(&self, provider_ids: &[String]) -> Vec<String> {
let mut ids = if provider_ids.is_empty() {
crate::providers::providers()
.await
.into_iter()
.map(|(metadata, _)| metadata.name)
.collect::<Vec<_>>()
} else {
provider_ids.to_vec()
};
ids.sort();
ids.dedup();
ids
}
}
pub fn default_inventory_identity(
provider_id: &str,
provider_family: &str,
config_keys: &[ConfigKey],
config: &Config,
) -> InventoryIdentityInput {
let mut identity = InventoryIdentityInput::new(provider_id, provider_family);
for key in config_keys {
if key.secret {
if let Some(value) = config_secret_value(config, &key.name) {
identity.secret_inputs.insert(key.name.clone(), value);
}
} else if let Some(value) = config_param_value(config, &key.name) {
identity.public_inputs.insert(key.name.clone(), value);
}
}
identity
}
pub fn default_inventory_configured(config_keys: &[ConfigKey], config: &Config) -> bool {
config_keys.iter().all(|key| {
if !key.required {
return true;
}
if key.default.is_some() {
return true;
}
if key.secret {
config.get_secret::<serde_json::Value>(&key.name).is_ok()
} else {
config.get_param::<serde_json::Value>(&key.name).is_ok()
}
})
}
pub fn declarative_inventory_identity(
config: &DeclarativeProviderConfig,
) -> Result<InventoryIdentityInput> {
let global = Config::global();
let mut identity = InventoryIdentityInput::new(
config.name.clone(),
config
.catalog_provider_id
.clone()
.unwrap_or_else(|| match config.engine {
ProviderEngine::OpenAI => "openai".to_string(),
ProviderEngine::Anthropic => "anthropic".to_string(),
ProviderEngine::Ollama => "ollama".to_string(),
}),
);
identity
.public_inputs
.insert("base_url".to_string(), config.base_url.clone());
if let Some(base_path) = &config.base_path {
identity
.public_inputs
.insert("base_path".to_string(), base_path.clone());
}
if let Some(catalog_provider_id) = &config.catalog_provider_id {
identity.public_inputs.insert(
"catalog_provider_id".to_string(),
catalog_provider_id.clone(),
);
}
if let Some(dynamic_models) = config.dynamic_models {
identity
.public_inputs
.insert("dynamic_models".to_string(), dynamic_models.to_string());
}
identity.public_inputs.insert(
"skip_canonical_filtering".to_string(),
config.skip_canonical_filtering.to_string(),
);
if !config.models.is_empty() {
identity.public_inputs.insert(
"models".to_string(),
serde_json::to_string(
&config
.models
.iter()
.map(|model| &model.name)
.collect::<Vec<_>>(),
)?,
);
}
if let Some(headers) = &config.headers {
identity
.public_inputs
.insert("headers".to_string(), serialize_string_map(headers)?);
}
if config.requires_auth && !config.api_key_env.is_empty() {
if let Some(value) = config_secret_value(global, &config.api_key_env) {
identity
.secret_inputs
.insert(config.api_key_env.clone(), value);
}
}
Ok(identity)
}
pub fn config_param_value(config: &Config, key: &str) -> Option<String> {
config
.get_param::<serde_json::Value>(key)
.ok()
.and_then(|value| normalize_json_value(&value))
}
pub fn config_secret_value(config: &Config, key: &str) -> Option<String> {
config
.get_secret::<serde_json::Value>(key)
.ok()
.and_then(|value| normalize_json_value(&value))
}
pub fn serialize_string_map(map: &HashMap<String, String>) -> Result<String> {
let ordered = map
.iter()
.map(|(key, value)| (key.clone(), value.clone()))
.collect::<BTreeMap<_, _>>();
Ok(serde_json::to_string(&ordered)?)
}
fn parse_optional_datetime(value: Option<String>) -> Result<Option<DateTime<Utc>>> {
value
.map(|value| value.parse::<DateTime<Utc>>())
.transpose()
.map_err(Into::into)
}
fn normalize_json_value(value: &serde_json::Value) -> Option<String> {
match value {
serde_json::Value::Null => None,
serde_json::Value::String(value) if value.is_empty() => None,
serde_json::Value::String(value) => Some(value.clone()),
other => serde_json::to_string(other).ok(),
}
}
fn fallback_inventory_identity(provider_id: &str) -> InventoryIdentityInput {
InventoryIdentityInput::new(
provider_id.to_string(),
map_provider_name(provider_id).to_string(),
)
}
fn enrich_model_ids_with_canonical(
provider_family: &str,
model_ids: &[String],
) -> Vec<InventoryModel> {
let mut models: Vec<InventoryModel> = Vec::new();
let mut seen_names: HashSet<String> = HashSet::new();
for id in model_ids {
let model = enriched_model(provider_family, id, None);
if !seen_names.insert(model.name.clone()) {
continue;
}
models.push(model);
}
// For databricks, prefer goose- prefixed model_ids when there are duplicates.
// Re-scan: if a later model_id with "goose-" prefix maps to the same display name,
// swap it in.
if provider_family == "databricks" {
let mut name_to_idx: HashMap<String, usize> = HashMap::new();
for (idx, model) in models.iter().enumerate() {
name_to_idx.insert(model.name.clone(), idx);
}
for id in model_ids {
if !id.starts_with("goose-") {
continue;
}
let candidate = enriched_model(provider_family, id, None);
if let Some(&idx) = name_to_idx.get(&candidate.name) {
if !models[idx].id.starts_with("goose-") {
models[idx].id = candidate.id;
}
}
}
}
// Mark the latest model per recommended family.
let mut seen_recommended_families: HashSet<String> = HashSet::new();
for model in &mut models {
if let Some(family) = &model.family {
if RECOMMENDED_FAMILIES.contains(&family.as_str())
&& seen_recommended_families.insert(family.clone())
{
model.recommended = true;
}
}
}
models
}
fn configured_models_to_inventory(
provider_family: &str,
models: &[ModelInfo],
) -> Vec<InventoryModel> {
let mut result: Vec<InventoryModel> = Vec::new();
let mut seen_names: HashSet<String> = HashSet::new();
for model in models {
let enriched = enriched_model(provider_family, &model.name, Some(model.context_limit));
if seen_names.insert(enriched.name.clone()) {
result.push(enriched);
}
}
let mut seen_recommended_families: HashSet<String> = HashSet::new();
for model in &mut result {
if let Some(family) = &model.family {
if RECOMMENDED_FAMILIES.contains(&family.as_str())
&& seen_recommended_families.insert(family.clone())
{
model.recommended = true;
}
}
}
result
}
fn inventory_models_from_snapshot(
snapshot: Option<&InventorySnapshot>,
provider_family: &str,
configured_models: &[ModelInfo],
) -> Vec<InventoryModel> {
match snapshot {
Some(snapshot) if !snapshot.models.is_empty() || snapshot.last_updated_at.is_some() => {
snapshot.models.clone()
}
_ => configured_models_to_inventory(provider_family, configured_models),
}
}
fn enriched_model(
provider_family: &str,
model_id: &str,
fallback_context_limit: Option<usize>,
) -> InventoryModel {
let registry = CanonicalModelRegistry::bundled().ok();
let canonical = registry.as_ref().and_then(|registry| {
let canonical_id = map_to_canonical_model(provider_family, model_id, registry)?;
let (provider, model) = canonical_id.split_once('/')?;
registry.get(provider, model).cloned()
});
InventoryModel {
id: model_id.to_string(),
name: canonical
.as_ref()
.map(|model| model.name.clone())
.unwrap_or_else(|| model_id.to_string()),
family: canonical.as_ref().and_then(|model| model.family.clone()),
context_limit: canonical
.as_ref()
.map(|model| model.limit.context)
.or(fallback_context_limit),
reasoning: canonical.as_ref().and_then(|model| model.reasoning),
recommended: false,
}
}
pub async fn create_tables(pool: &Pool<Sqlite>) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provider_inventory_entries (
inventory_key TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
provider_family TEXT NOT NULL,
last_updated_at TEXT,
last_refresh_attempt_at TEXT,
last_refresh_error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provider_inventory_models (
inventory_key TEXT NOT NULL REFERENCES provider_inventory_entries(inventory_key) ON DELETE CASCADE,
ordinal INTEGER NOT NULL,
model_id TEXT NOT NULL,
name TEXT NOT NULL,
family TEXT,
context_limit INTEGER,
reasoning BOOLEAN,
recommended BOOLEAN,
PRIMARY KEY (inventory_key, ordinal)
)
"#,
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_provider_inventory_provider_id ON provider_inventory_entries(provider_id)",
)
.execute(pool)
.await?;
Ok(())
}
pub async fn create_tables_in_tx(tx: &mut Transaction<'_, Sqlite>) -> Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provider_inventory_entries (
inventory_key TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
provider_family TEXT NOT NULL,
last_updated_at TEXT,
last_refresh_attempt_at TEXT,
last_refresh_error TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"#,
)
.execute(&mut **tx)
.await?;
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provider_inventory_models (
inventory_key TEXT NOT NULL REFERENCES provider_inventory_entries(inventory_key) ON DELETE CASCADE,
ordinal INTEGER NOT NULL,
model_id TEXT NOT NULL,
name TEXT NOT NULL,
family TEXT,
context_limit INTEGER,
reasoning BOOLEAN,
recommended BOOLEAN,
PRIMARY KEY (inventory_key, ordinal)
)
"#,
)
.execute(&mut **tx)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_provider_inventory_provider_id ON provider_inventory_entries(provider_id)",
)
.execute(&mut **tx)
.await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn inventory_identity_hash_changes_with_secret_inputs() {
let left = InventoryIdentityInput::new("openai", "openai")
.with_public("host", "https://api.openai.com")
.with_secret("api_key", "secret-a")
.into_identity()
.unwrap();
let right = InventoryIdentityInput::new("openai", "openai")
.with_public("host", "https://api.openai.com")
.with_secret("api_key", "secret-b")
.into_identity()
.unwrap();
assert_ne!(left.inventory_key, right.inventory_key);
}
#[test]
fn configured_models_use_canonical_enrichment() {
let models =
configured_models_to_inventory("anthropic", &[ModelInfo::new("claude-sonnet-4-5", 0)]);
assert_eq!(models.len(), 1);
assert!(models[0].name.contains("Claude"));
}
#[test]
fn inventory_uses_configured_models_before_first_successful_refresh() {
let configured_models = [ModelInfo::new("claude-sonnet-4-5", 0)];
let snapshot = InventorySnapshot {
models: vec![],
last_updated_at: None,
last_refresh_attempt_at: Some(Utc::now()),
last_refresh_error: Some("auth failed".to_string()),
};
let models =
inventory_models_from_snapshot(Some(&snapshot), "anthropic", &configured_models);
assert_eq!(models.len(), 1);
assert_eq!(models[0].id, "claude-sonnet-4-5");
}
#[test]
fn inventory_preserves_empty_models_after_successful_refresh() {
let configured_models = [ModelInfo::new("claude-sonnet-4-5", 0)];
let snapshot = InventorySnapshot {
models: vec![],
last_updated_at: Some(Utc::now()),
last_refresh_attempt_at: Some(Utc::now()),
last_refresh_error: None,
};
let models =
inventory_models_from_snapshot(Some(&snapshot), "anthropic", &configured_models);
assert!(models.is_empty());
}
}

View file

@ -1,3 +1,4 @@
mod acp_tooling;
pub mod amp_acp;
pub mod anthropic;
pub mod api_client;
@ -28,6 +29,7 @@ pub mod gemini_oauth;
pub mod githubcopilot;
pub mod google;
mod init;
pub mod inventory;
pub mod kimicode;
pub mod litellm;
#[cfg(feature = "local-inference")]
@ -56,6 +58,6 @@ pub mod xai;
pub use init::{
cleanup_provider, create, create_with_default_model, create_with_named_model,
get_from_registry, providers, refresh_custom_providers,
get_from_registry, inventory_identity, providers, refresh_custom_providers,
};
pub use retry::{retry_operation, RetryConfig};

View file

@ -1,6 +1,7 @@
use super::api_client::{ApiClient, AuthMethod};
use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
use super::errors::ProviderError;
use super::inventory::InventoryIdentityInput;
use super::openai_compatible::handle_status_openai_compat;
use super::retry::{ProviderRetry, RetryConfig};
use super::utils::{ImageFormat, RequestLog};
@ -256,6 +257,22 @@ impl ProviderDef for OllamaProvider {
) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(Self::from_env(model))
}
fn supports_inventory_refresh() -> bool {
true
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
let config = crate::config::Config::global();
Ok(
InventoryIdentityInput::new(OLLAMA_PROVIDER_NAME, OLLAMA_PROVIDER_NAME).with_public(
"host",
config
.get_param::<String>("OLLAMA_HOST")
.unwrap_or_else(|_| OLLAMA_HOST.to_string()),
),
)
}
}
#[async_trait]

View file

@ -7,6 +7,7 @@ use super::formats::openai_responses::{
create_responses_request, get_responses_usage, responses_api_to_message,
responses_api_to_streaming_message, ResponsesApiResponse,
};
use super::inventory::{config_secret_value, InventoryIdentityInput};
use super::openai_compatible::{
handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
};
@ -425,6 +426,58 @@ impl ProviderDef for OpenAiProvider {
) -> BoxFuture<'static, Result<Self::Provider>> {
Box::pin(Self::from_env(model))
}
fn supports_inventory_refresh() -> bool {
true
}
fn inventory_configured() -> bool {
let config = crate::config::Config::global();
// If the host is explicitly set to something non-default, trust the user's
// custom setup (e.g. a local server that doesn't require an API key).
if let Ok(host) = config.get_param::<String>("OPENAI_HOST") {
if host != "https://api.openai.com" {
return true;
}
}
// Standard OpenAI endpoint requires an API key.
config
.get_secret::<serde_json::Value>("OPENAI_API_KEY")
.is_ok()
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
let config = crate::config::Config::global();
let mut identity =
InventoryIdentityInput::new(OPEN_AI_PROVIDER_NAME, OPEN_AI_PROVIDER_NAME)
.with_public(
"host",
config
.get_param::<String>("OPENAI_HOST")
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
)
.with_public(
"base_path",
config
.get_param::<String>("OPENAI_BASE_PATH")
.unwrap_or_else(|_| OPEN_AI_DEFAULT_BASE_PATH.to_string()),
);
if let Ok(organization) = config.get_param::<String>("OPENAI_ORGANIZATION") {
identity = identity.with_public("organization", organization);
}
if let Ok(project) = config.get_param::<String>("OPENAI_PROJECT") {
identity = identity.with_public("project", project);
}
if let Some(api_key) = config_secret_value(config, "OPENAI_API_KEY") {
identity = identity.with_secret("api_key", api_key);
}
if let Some(custom_headers) = config_secret_value(config, "OPENAI_CUSTOM_HEADERS") {
identity = identity.with_secret("custom_headers", custom_headers);
}
Ok(identity)
}
}
#[async_trait]

View file

@ -9,7 +9,9 @@ use crate::acp::{
use crate::config::search_path::SearchPaths;
use crate::config::{Config, GooseMode};
use crate::model::ModelConfig;
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
use crate::providers::base::{ProviderDef, ProviderMetadata};
use crate::providers::inventory::InventoryIdentityInput;
const PI_ACP_PROVIDER_NAME: &str = "pi-acp";
const PI_ACP_DOC_URL: &str = "https://github.com/anthropics/pi";
@ -36,6 +38,7 @@ impl ProviderDef for PiAcpProvider {
"Set in your goose config file (`~/.config/goose/config.yaml` on macOS/Linux):\n GOOSE_PROVIDER: pi-acp\n GOOSE_MODEL: current",
"Restart goose for changes to take effect",
])
.with_model_selection_hint("Use the Pi CLI to configure models")
}
fn from_env(
@ -70,4 +73,16 @@ impl ProviderDef for PiAcpProvider {
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
})
}
fn supports_inventory_refresh() -> bool {
false
}
fn inventory_identity() -> Result<InventoryIdentityInput> {
acp_inventory_identity(PI_ACP_PROVIDER_NAME, PI_ACP_BINARY)
}
fn inventory_configured() -> bool {
acp_adapter_installed(PI_ACP_BINARY)
}
}

View file

@ -1,4 +1,5 @@
use super::base::{ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderType};
use super::inventory::InventoryIdentityInput;
use crate::config::{DeclarativeProviderConfig, ExtensionConfig};
use crate::model::ModelConfig;
use anyhow::Result;
@ -14,12 +15,20 @@ pub type ProviderConstructor = Arc<
pub type ProviderCleanup = Arc<dyn Fn() -> BoxFuture<'static, Result<()>> + Send + Sync>;
pub type ProviderInventoryIdentityResolver =
Arc<dyn Fn() -> Result<InventoryIdentityInput> + Send + Sync>;
pub type ProviderInventoryConfiguredResolver = Arc<dyn Fn() -> bool + Send + Sync>;
#[derive(Clone)]
pub struct ProviderEntry {
metadata: ProviderMetadata,
pub(crate) constructor: ProviderConstructor,
pub(crate) inventory_identity: ProviderInventoryIdentityResolver,
pub(crate) inventory_configured: ProviderInventoryConfiguredResolver,
pub(crate) cleanup: Option<ProviderCleanup>,
provider_type: ProviderType,
supports_inventory_refresh: bool,
}
impl ProviderEntry {
@ -27,6 +36,22 @@ impl ProviderEntry {
&self.metadata
}
pub fn provider_type(&self) -> ProviderType {
self.provider_type
}
pub fn supports_inventory_refresh(&self) -> bool {
self.supports_inventory_refresh
}
pub fn inventory_identity(&self) -> Result<InventoryIdentityInput> {
(self.inventory_identity)()
}
pub fn inventory_configured(&self) -> bool {
(self.inventory_configured)()
}
fn normalize_model_config(&self, mut model: ModelConfig) -> ModelConfig {
model = model.with_canonical_limits(&self.metadata.name);
@ -92,24 +117,30 @@ impl ProviderRegistry {
Ok(Arc::new(provider) as Arc<dyn Provider>)
})
}),
inventory_identity: Arc::new(F::inventory_identity),
inventory_configured: Arc::new(F::inventory_configured),
cleanup: None,
provider_type: if preferred {
ProviderType::Preferred
} else {
ProviderType::Builtin
},
supports_inventory_refresh: F::supports_inventory_refresh(),
},
);
}
pub fn register_with_name<P, F>(
pub fn register_with_name<P, F, G>(
&mut self,
config: &DeclarativeProviderConfig,
provider_type: ProviderType,
supports_inventory_refresh: bool,
constructor: F,
inventory_identity: G,
) where
P: ProviderDef + 'static,
F: Fn(ModelConfig) -> Result<P::Provider> + Send + Sync + 'static,
G: Fn() -> Result<InventoryIdentityInput> + Send + Sync + 'static,
{
let base_metadata = P::metadata();
let description = config
@ -174,7 +205,9 @@ impl ProviderRegistry {
model_doc_link: base_metadata.model_doc_link,
config_keys,
setup_steps: vec![],
model_selection_hint: None,
};
let inventory_config_keys = custom_metadata.config_keys.clone();
self.entries.insert(
config.name.clone(),
@ -187,8 +220,16 @@ impl ProviderRegistry {
Ok(Arc::new(provider) as Arc<dyn Provider>)
})
}),
inventory_identity: Arc::new(inventory_identity),
inventory_configured: Arc::new(move || {
super::inventory::default_inventory_configured(
&inventory_config_keys,
crate::config::Config::global(),
)
}),
cleanup: None,
provider_type,
supports_inventory_refresh,
},
);
}

View file

@ -19,7 +19,7 @@ use std::sync::{Arc, LazyLock};
use tracing::{info, warn};
use utoipa::ToSchema;
pub const CURRENT_SCHEMA_VERSION: i32 = 10;
pub const CURRENT_SCHEMA_VERSION: i32 = 11;
pub const SESSIONS_FOLDER: &str = "sessions";
pub const DB_NAME: &str = "sessions.db";
@ -717,6 +717,8 @@ impl SessionStorage {
.execute(pool)
.await?;
crate::providers::inventory::create_tables(pool).await?;
Ok(())
}
@ -1060,6 +1062,9 @@ impl SessionStorage {
.execute(&mut **tx)
.await?;
}
11 => {
crate::providers::inventory::create_tables_in_tx(tx).await?;
}
_ => {
anyhow::bail!("Unknown migration version: {}", version);
}

View file

@ -374,6 +374,7 @@ mod tests {
model_doc_link: "".to_string(),
config_keys: vec![],
setup_steps: vec![],
model_selection_hint: None,
}
}
@ -542,6 +543,7 @@ mod tests {
model_doc_link: "".to_string(),
config_keys: vec![],
setup_steps: vec![],
model_selection_hint: None,
}
}
@ -867,6 +869,7 @@ mod tests {
model_doc_link: "".to_string(),
config_keys: vec![],
setup_steps: vec![],
model_selection_hint: None,
}
}

View file

@ -192,6 +192,7 @@ impl ProviderDef for MockCompactionProvider {
model_doc_link: "".to_string(),
config_keys: vec![],
setup_steps: vec![],
model_selection_hint: None,
}
}

View file

@ -6864,6 +6864,11 @@
"type": "string",
"description": "Link to the docs where models can be found"
},
"model_selection_hint": {
"type": "string",
"description": "Hint shown in the model picker when this provider manages its own model selection.",
"nullable": true
},
"name": {
"type": "string",
"description": "The unique identifier for this provider"

View file

@ -970,6 +970,10 @@ export type ProviderMetadata = {
* Link to the docs where models can be found
*/
model_doc_link: string;
/**
* Hint shown in the model picker when this provider manages its own model selection.
*/
model_selection_hint?: string | null;
/**
* The unique identifier for this provider
*/

View file

@ -36,9 +36,19 @@ const EXCEPTIONS = {
"Search-as-you-type filtering and draft-aware sidebar highlight logic.",
},
"src/app/AppShell.tsx": {
limit: 660,
limit: 730,
justification:
"Shell still coordinates ACP session loading, replay-buffer cleanup on load failure, project reassignment, and app-level chat routing. Includes gated [perf:load]/[perf:newtab] logging via perfLog (dev-only by default).",
"Shell still coordinates ACP session loading, replay-buffer cleanup on load failure, project reassignment, home-session restoration, app-level chat routing, and restored project-draft reuse. Includes gated [perf:load]/[perf:newtab] logging via perfLog (dev-only by default).",
},
"src/features/chat/hooks/useChatSessionController.ts": {
limit: 690,
justification:
"Controller now centralizes home-to-chat pending state transfer, workspace/project preparation, provider/model/persona handoff, Goose cross-provider model selection sequencing with rollback, and chat input orchestration pending a later decomposition pass.",
},
"src/features/chat/ui/AgentModelPicker.tsx": {
limit: 570,
justification:
"Agent-first picker currently keeps the full trigger, recommended-model view, searchable full-model view, and ACP/goose-specific labeling logic in one component pending later extraction.",
},
"src/features/chat/stores/__tests__/chatSessionStore.test.ts": {
limit: 540,
@ -56,7 +66,7 @@ const EXCEPTIONS = {
"Voice dictation send/stop guards, attachment handling, and mention/picker coordination still share one chat composer component.",
},
"src/features/chat/ui/__tests__/ChatInput.test.tsx": {
limit: 510,
limit: 520,
justification:
"Composer regression coverage spans personas, queueing, attachments, and voice-input edge cases in one interaction-heavy suite.",
},

View file

@ -0,0 +1,33 @@
#!/usr/bin/env bash
# Reset the provider inventory tables to empty, as if migration 12 just ran.
# This lets you test the first-use experience (cold inventory).
#
# Usage: ./scripts/reset-inventory.sh
set -euo pipefail
DB="${GOOSE_DB:-$HOME/.local/share/goose/sessions/sessions.db}"
if [ ! -f "$DB" ]; then
echo "Database not found at $DB"
echo "Set GOOSE_DB to override the path."
exit 1
fi
echo "Database: $DB"
echo ""
echo "Before:"
echo " provider_inventory_entries: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_entries;')"
echo " provider_inventory_models: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_models;')"
# ON DELETE CASCADE on provider_inventory_models means deleting entries clears both tables.
# Delete models first since CASCADE isn't reliable in all sqlite3 builds,
# then delete entries.
sqlite3 "$DB" "DELETE FROM provider_inventory_models; DELETE FROM provider_inventory_entries;"
echo ""
echo "After:"
echo " provider_inventory_entries: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_entries;')"
echo " provider_inventory_models: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_models;')"
echo ""
echo "Inventory tables are empty. Restart goose to test first-use flow."

View file

@ -1,7 +1,6 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { Sidebar } from "@/features/sidebar/ui/Sidebar";
import { StatusBar } from "@/features/status/ui/StatusBar";
import type { ChatAttachmentDraft } from "@/shared/types/messages";
import { CreateProjectDialog } from "@/features/projects/ui/CreateProjectDialog";
import { archiveProject } from "@/features/projects/api/projects";
import type { ProjectInfo } from "@/features/projects/api/projects";
@ -9,7 +8,11 @@ import { SettingsModal } from "@/features/settings/ui/SettingsModal";
import type { SectionId } from "@/features/settings/ui/SettingsModal";
import { TopBar } from "./ui/TopBar";
import { useChatStore } from "@/features/chat/stores/chatStore";
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
import {
type ChatSession,
hasSessionStarted,
useChatSessionStore,
} from "@/features/chat/stores/chatSessionStore";
import { useAgentStore } from "@/features/agents/stores/agentStore";
import { useProjectStore } from "@/features/projects/stores/projectStore";
import { findExistingDraft } from "@/features/chat/lib/newChat";
@ -37,6 +40,33 @@ const SIDEBAR_MIN_WIDTH = 180;
const SIDEBAR_MAX_WIDTH = 380;
const SIDEBAR_SNAP_COLLAPSE_THRESHOLD = 100;
const SIDEBAR_COLLAPSED_WIDTH = 48;
const HOME_SESSION_STORAGE_KEY = "goose:home-session-id";
function loadStoredHomeSessionId(): string | null {
if (typeof window === "undefined") {
return null;
}
try {
return window.localStorage.getItem(HOME_SESSION_STORAGE_KEY);
} catch {
return null;
}
}
function persistHomeSessionId(sessionId: string | null): void {
if (typeof window === "undefined") {
return;
}
try {
if (sessionId) {
window.localStorage.setItem(HOME_SESSION_STORAGE_KEY, sessionId);
return;
}
window.localStorage.removeItem(HOME_SESSION_STORAGE_KEY);
} catch {
// localStorage may be unavailable
}
}
export function AppShell({ children }: { children?: React.ReactNode }) {
const [sidebarCollapsed, setSidebarCollapsed] = useState(false);
@ -52,9 +82,9 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
null,
);
const [activeView, setActiveView] = useState<AppView>("home");
const [homeSelectedProvider, setHomeSelectedProvider] = useState<
string | undefined
>();
const [homeSessionId, setHomeSessionId] = useState<string | null>(() =>
loadStoredHomeSessionId(),
);
const chatStore = useChatStore();
const sessionStore = useChatSessionStore();
@ -64,6 +94,9 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
const pendingProjectCreatedRef = useRef<((projectId: string) => void) | null>(
null,
);
const homeSessionRequestRef = useRef<Promise<ChatSession | null> | null>(
null,
);
const loadSessionMessages = useCallback(async (sessionId: string) => {
const sid = sessionId.slice(0, 8);
@ -125,70 +158,146 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
}
}, [activeSessionId, activeView]);
const isHome = activeSessionId === null && activeView === "home";
const isHome = activeView === "home";
const activeSession = activeSessionId
? sessionStore.getSession(activeSessionId)
: undefined;
const modelName = activeSession?.modelName;
const tokenCount = activeSessionId
? chatStore.getSessionRuntime(activeSessionId).tokenState.totalTokens
: 0;
const modelName =
activeView === "chat" ? activeSession?.modelName : undefined;
const tokenCount =
activeView === "chat" && activeSessionId
? chatStore.getSessionRuntime(activeSessionId).tokenState.totalTokens
: 0;
const homeSession = homeSessionId
? sessionStore.getSession(homeSessionId)
: undefined;
const [pendingInitialMessage, setPendingInitialMessage] = useState<
string | undefined
>();
const [pendingInitialAttachments, setPendingInitialAttachments] = useState<
ChatAttachmentDraft[] | undefined
>();
const [homeSelectedPersonaId, setHomeSelectedPersonaId] = useState<
string | undefined
>();
useEffect(() => {
if (
!homeSessionId ||
!sessionStore.hasHydratedSessions ||
sessionStore.isLoading
) {
return;
}
if (
!homeSession ||
homeSession.archivedAt ||
hasSessionStarted(
homeSession,
chatStore.messagesBySession[homeSession.id],
)
) {
setHomeSessionId(null);
}
}, [
chatStore.messagesBySession,
homeSession,
homeSession?.archivedAt,
homeSession?.messageCount,
homeSessionId,
sessionStore.hasHydratedSessions,
sessionStore.isLoading,
]);
const cleanupEmptyDraft = useCallback(
(sessionId: string | null) => {
if (!sessionId) return;
const state = useChatSessionStore.getState();
const session = state.sessions.find((s) => s.id === sessionId);
if (!session?.draft) return;
const draft = useChatStore.getState().draftsBySession[sessionId] ?? "";
if (draft.length > 0) return; // has typed text — keep it
chatStore.cleanupSession(sessionId);
state.removeDraft(sessionId);
},
[chatStore],
);
useEffect(() => {
persistHomeSessionId(homeSessionId);
}, [homeSessionId]);
const ensureHomeSession = useCallback(async () => {
if (!sessionStore.hasHydratedSessions || sessionStore.isLoading) {
return null;
}
if (homeSessionRequestRef.current) {
return homeSessionRequestRef.current;
}
const request = (async () => {
if (
homeSession &&
!homeSession.archivedAt &&
homeSession.messageCount === 0
) {
const project = homeSession.projectId
? (projectStore.projects.find(
(candidate) => candidate.id === homeSession.projectId,
) ?? null)
: null;
const workingDir = await resolveSessionCwd(project);
await acpPrepareSession(
homeSession.id,
homeSession.providerId ?? agentStore.selectedProvider ?? "goose",
workingDir,
{
personaId: homeSession.personaId,
},
);
return homeSession;
}
const workingDir = await resolveSessionCwd(null);
const session = await sessionStore.createSession({
title: DEFAULT_CHAT_TITLE,
providerId: agentStore.selectedProvider ?? "goose",
workingDir,
});
setHomeSessionId(session.id);
return session;
})();
homeSessionRequestRef.current = request;
try {
return await request;
} finally {
if (homeSessionRequestRef.current === request) {
homeSessionRequestRef.current = null;
}
}
}, [
agentStore.selectedProvider,
homeSession,
projectStore.projects,
sessionStore.hasHydratedSessions,
sessionStore,
sessionStore.isLoading,
]);
useEffect(() => {
if (activeView !== "home") {
return;
}
void ensureHomeSession().catch((error) => {
console.error("Failed to ensure Home session:", error);
});
}, [activeView, ensureHomeSession]);
const createNewTab = useCallback(
(title = DEFAULT_CHAT_TITLE, project?: ProjectInfo) => {
async (title = DEFAULT_CHAT_TITLE, project?: ProjectInfo) => {
const tStart = performance.now();
perfLog(
`[perf:newtab] createNewTab start (project=${project?.id ?? "none"})`,
);
const agentId = agentStore.activeAgentId ?? undefined;
const providerId = project?.preferredProvider ?? homeSelectedProvider;
const personaId = homeSelectedPersonaId;
const providerId =
project?.preferredProvider ?? agentStore.selectedProvider ?? "goose";
const modelId = project?.preferredModel ?? undefined;
const sessionState = useChatSessionStore.getState();
const chatStoreState = useChatStore.getState();
const chatState = useChatStore.getState();
const existingDraft = findExistingDraft({
sessions: sessionState.sessions,
activeSessionId: sessionState.activeSessionId,
draftsBySession: chatStoreState.draftsBySession,
messagesBySession: chatStoreState.messagesBySession,
draftsBySession: chatState.draftsBySession,
messagesBySession: chatState.messagesBySession,
request: {
title,
projectId: project?.id,
agentId,
providerId,
personaId,
},
});
if (existingDraft) {
if (sessionState.activeSessionId !== existingDraft.id) {
cleanupEmptyDraft(sessionState.activeSessionId);
}
sessionState.setActiveSession(existingDraft.id);
sessionStore.setActiveSession(existingDraft.id);
setActiveView("chat");
chatStore.setActiveSession(existingDraft.id);
perfLog(
@ -196,46 +305,45 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
);
return existingDraft;
}
cleanupEmptyDraft(sessionState.activeSessionId);
const session = sessionStore.createDraftSession({
const workingDir = await resolveSessionCwd(project);
const session = await sessionStore.createSession({
title,
projectId: project?.id,
agentId,
providerId,
personaId,
workingDir,
modelId,
modelName: modelId,
});
sessionStore.setActiveSession(session.id);
setActiveView("chat");
chatStore.setActiveSession(session.id);
perfLog(
`[perf:newtab] ${session.id.slice(0, 8)} created draft in ${(performance.now() - tStart).toFixed(1)}ms`,
`[perf:newtab] ${session.id.slice(0, 8)} created session in ${(performance.now() - tStart).toFixed(1)}ms`,
);
return session;
},
[
agentStore.activeAgentId,
agentStore.selectedProvider,
chatStore,
sessionStore,
agentStore.activeAgentId,
homeSelectedPersonaId,
homeSelectedProvider,
cleanupEmptyDraft,
],
);
const handleStartChatFromProject = useCallback(
(project: ProjectInfo) => {
setHomeSelectedProvider(undefined);
createNewTab(DEFAULT_CHAT_TITLE, project);
void createNewTab(DEFAULT_CHAT_TITLE, project);
},
[createNewTab],
);
const handleNewChatInProject = useCallback(
(projectId: string) => {
setHomeSelectedProvider(undefined);
const project = projectStore.projects.find((p) => p.id === projectId);
if (project) {
createNewTab(DEFAULT_CHAT_TITLE, project);
void createNewTab(DEFAULT_CHAT_TITLE, project);
}
},
[createNewTab, projectStore.projects],
@ -255,12 +363,11 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
const clearActiveSession = useCallback(
(sessionId: string) => {
cleanupEmptyDraft(sessionId);
chatStore.cleanupSession(sessionId);
sessionStore.setActiveSession(null);
setActiveView("home");
},
[chatStore, sessionStore, cleanupEmptyDraft],
[chatStore, sessionStore],
);
const openSettings = useCallback((section: SectionId = "appearance") => {
setSettingsInitialSection(section);
@ -306,7 +413,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
sessionStore.updateSession(sessionId, { projectId });
const session = useChatSessionStore.getState().getSession(sessionId);
if (!session || session.draft) {
if (!session) {
return;
}
@ -362,41 +469,28 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
[],
);
const handleHomeStartChat = useCallback(
(
initialMessage?: string,
providerId?: string,
personaId?: string,
projectId?: string | null,
attachments?: ChatAttachmentDraft[],
) => {
setHomeSelectedProvider(providerId);
setHomeSelectedPersonaId(personaId);
setPendingInitialMessage(initialMessage);
setPendingInitialAttachments(attachments);
const selectedProject =
projectId != null
? projectStore.projects.find((project) => project.id === projectId)
: undefined;
createNewTab(
initialMessage?.slice(0, 40) || DEFAULT_CHAT_TITLE,
selectedProject,
);
const activateHomeSession = useCallback(
(sessionId: string) => {
if (homeSessionId === sessionId) {
setHomeSessionId(null);
}
sessionStore.setActiveSession(sessionId);
setActiveView("chat");
chatStore.setActiveSession(sessionId);
useChatStore.getState().markSessionRead(sessionId);
},
[createNewTab, projectStore.projects],
[chatStore, homeSessionId, sessionStore],
);
const handleSelectSession = useCallback(
(id: string) => {
cleanupEmptyDraft(useChatSessionStore.getState().activeSessionId);
sessionStore.setActiveSession(id);
setActiveView("chat");
chatStore.setActiveSession(id);
useChatStore.getState().markSessionRead(id);
loadSessionMessages(id);
},
[sessionStore, chatStore, loadSessionMessages, cleanupEmptyDraft],
[sessionStore, chatStore, loadSessionMessages],
);
const handleSelectSearchResult = useCallback(
@ -414,12 +508,11 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
const handleNavigate = useCallback(
(view: AppView) => {
if (view !== "chat") {
cleanupEmptyDraft(useChatSessionStore.getState().activeSessionId);
sessionStore.setActiveSession(null);
}
setActiveView(view);
},
[sessionStore, cleanupEmptyDraft],
[sessionStore],
);
const toggleSidebar = () => setSidebarCollapsed((prev) => !prev);
@ -496,20 +589,13 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
// Cmd+N opens new conversation screen
if (e.key === "n" && e.metaKey) {
e.preventDefault();
createNewTab();
sessionStore.setActiveSession(null);
setActiveView("home");
}
};
window.addEventListener("keydown", handler);
return () => window.removeEventListener("keydown", handler);
}, [clearActiveSession, createNewTab]);
const activeSessionPersonaId = activeSession?.personaId;
const handleInitialMessageConsumed = useCallback(() => {
setPendingInitialMessage(undefined);
setPendingInitialAttachments(undefined);
setHomeSelectedProvider(undefined);
setHomeSelectedPersonaId(undefined);
}, []);
}, [clearActiveSession, sessionStore]);
const editingProjectProp = useMemo(
() =>
@ -551,7 +637,10 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
onCollapse={toggleSidebar}
onNavigate={handleNavigate}
onNewChatInProject={handleNewChatInProject}
onNewChat={() => createNewTab()}
onNewChat={() => {
sessionStore.setActiveSession(null);
setActiveView("home");
}}
onCreateProject={() => openCreateProjectDialog()}
onEditProject={handleEditProject}
onArchiveProject={handleArchiveProject}
@ -582,15 +671,10 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
<AppShellContent
activeView={activeView}
activeSession={activeSession}
activeSessionPersonaId={activeSessionPersonaId}
homeSelectedProvider={homeSelectedProvider}
homeSelectedPersonaId={homeSelectedPersonaId}
pendingInitialMessage={pendingInitialMessage}
pendingInitialAttachments={pendingInitialAttachments}
homeSessionId={homeSessionId}
onArchiveChat={handleArchiveChat}
onCreateProject={openCreateProjectDialog}
onHomeStartChat={handleHomeStartChat}
onInitialMessageConsumed={handleInitialMessageConsumed}
onActivateHomeSession={activateHomeSession}
onRenameChat={handleRenameChat}
onSelectSession={handleSelectSession}
onSelectSearchResult={handleSelectSearchResult}

View file

@ -1,10 +1,17 @@
import { useEffect } from "react";
import { useAgentStore } from "@/features/agents/stores/agentStore";
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
import { useProviderInventoryStore } from "@/features/providers/stores/providerInventoryStore";
import { setNotificationHandler, getClient } from "@/shared/api/acpConnection";
import notificationHandler from "@/shared/api/acpNotificationHandler";
import { perfLog } from "@/shared/lib/perfLog";
const INVENTORY_POLL_DELAYS_MS = [250, 500, 750, 1000, 1500, 2000];
function sleep(ms: number): Promise<void> {
return new Promise((resolve) => window.setTimeout(resolve, ms));
}
export function useAppStartup() {
useEffect(() => {
(async () => {
@ -22,6 +29,7 @@ export function useAppStartup() {
}
const store = useAgentStore.getState();
const inventoryStore = useProviderInventoryStore.getState();
const loadPersonas = async () => {
const t0 = performance.now();
store.setPersonasLoading(true);
@ -56,6 +64,76 @@ export function useAppStartup() {
}
};
const loadProviderInventory = async () => {
const t0 = performance.now();
inventoryStore.setLoading(true);
try {
const { getProviderInventory } = await import(
"@/features/providers/api/inventory"
);
const entries = await getProviderInventory();
inventoryStore.setEntries(entries);
perfLog(
`[perf:startup] loadProviderInventory done in ${(performance.now() - t0).toFixed(1)}ms (n=${entries.length})`,
);
return entries;
} catch (err) {
console.error("Failed to load provider inventory on startup:", err);
return [];
} finally {
inventoryStore.setLoading(false);
}
};
const refreshConfiguredProviderInventory = async (
initialEntries?: Awaited<ReturnType<typeof loadProviderInventory>>,
) => {
try {
const entries =
initialEntries && initialEntries.length > 0
? initialEntries
: await (async () => {
const { getProviderInventory } = await import(
"@/features/providers/api/inventory"
);
return getProviderInventory();
})();
const configuredProviderIds = entries
.filter((entry) => entry.configured)
.map((entry) => entry.providerId);
if (configuredProviderIds.length === 0) {
return;
}
const { getProviderInventory, refreshProviderInventory } =
await import("@/features/providers/api/inventory");
const refresh = await refreshProviderInventory(configuredProviderIds);
if (refresh.started.length === 0) {
return;
}
inventoryStore.mergeEntries(
await getProviderInventory(refresh.started),
);
for (const delayMs of INVENTORY_POLL_DELAYS_MS) {
await sleep(delayMs);
const refreshedEntries = await getProviderInventory(
refresh.started,
);
inventoryStore.mergeEntries(refreshedEntries);
if (refreshedEntries.every((entry) => !entry.refreshing)) {
return;
}
}
} catch (err) {
console.error(
"Failed to refresh provider inventory on startup:",
err,
);
}
};
const loadSessionState = async () => {
const t0 = performance.now();
perfLog("[perf:startup] loadSessionState start");
@ -68,11 +146,17 @@ export function useAppStartup() {
setActiveSession(null);
};
const inventoryLoad = loadProviderInventory();
await Promise.allSettled([
loadPersonas(),
loadProviders(),
inventoryLoad,
loadSessionState(),
]);
void inventoryLoad.then((entries) =>
refreshConfiguredProviderInventory(entries),
);
perfLog(
`[perf:startup] useAppStartup complete in ${(performance.now() - tStartup).toFixed(1)}ms`,
);

View file

@ -5,31 +5,19 @@ import { AgentsView } from "@/features/agents/ui/AgentsView";
import { ProjectsView } from "@/features/projects/ui/ProjectsView";
import { SessionHistoryView } from "@/features/sessions/ui/SessionHistoryView";
import type { ChatSession } from "@/features/chat/stores/chatSessionStore";
import type { ChatAttachmentDraft } from "@/shared/types/messages";
import type { ProjectInfo } from "@/features/projects/api/projects";
import type { AppView } from "../AppShell";
interface AppShellContentProps {
activeView: AppView;
activeSession?: ChatSession;
activeSessionPersonaId?: string;
homeSelectedProvider?: string;
homeSelectedPersonaId?: string;
pendingInitialMessage?: string;
pendingInitialAttachments?: ChatAttachmentDraft[];
homeSessionId: string | null;
onArchiveChat: (sessionId: string) => Promise<void>;
onCreateProject: (options?: {
initialWorkingDir?: string | null;
onCreated?: (projectId: string) => void;
}) => void;
onHomeStartChat: (
initialMessage?: string,
providerId?: string,
personaId?: string,
projectId?: string | null,
attachments?: ChatAttachmentDraft[],
) => void;
onInitialMessageConsumed: () => void;
onActivateHomeSession: (sessionId: string) => void;
onRenameChat: (sessionId: string, nextTitle: string) => void;
onSelectSession: (sessionId: string) => void;
onSelectSearchResult: (
@ -43,15 +31,10 @@ interface AppShellContentProps {
export function AppShellContent({
activeView,
activeSession,
activeSessionPersonaId,
homeSelectedProvider,
homeSelectedPersonaId,
pendingInitialMessage,
pendingInitialAttachments,
homeSessionId,
onArchiveChat,
onCreateProject,
onHomeStartChat,
onInitialMessageConsumed,
onActivateHomeSession,
onRenameChat,
onSelectSession,
onSelectSearchResult,
@ -74,21 +57,24 @@ export function AppShellContent({
/>
);
case "chat":
case "home":
return activeSession ? (
<ChatView
key={activeSession.id}
sessionId={activeSession.id}
initialProvider={homeSelectedProvider}
initialPersonaId={activeSessionPersonaId ?? homeSelectedPersonaId}
initialMessage={pendingInitialMessage}
initialAttachments={pendingInitialAttachments}
onCreateProject={onCreateProject}
onInitialMessageConsumed={onInitialMessageConsumed}
/>
) : (
<HomeScreen
onStartChat={onHomeStartChat}
sessionId={homeSessionId}
onActivateSession={onActivateHomeSession}
onCreateProject={onCreateProject}
/>
);
case "home":
return (
<HomeScreen
sessionId={homeSessionId}
onActivateSession={onActivateHomeSession}
onCreateProject={onCreateProject}
/>
);

View file

@ -0,0 +1,125 @@
import { act, renderHook } from "@testing-library/react";
import { describe, expect, it, vi } from "vitest";
import { useAgentModelPickerState } from "../useAgentModelPickerState";
const mockUseProviderInventory = vi.fn();
vi.mock("@/features/providers/hooks/useProviderInventory", () => ({
useProviderInventory: () => mockUseProviderInventory(),
}));
describe("useAgentModelPickerState", () => {
it("switches to goose when the current provider is goose-backed", () => {
const onProviderSelected = vi.fn();
mockUseProviderInventory.mockReturnValue({
entries: new Map([
[
"anthropic",
{
providerId: "anthropic",
configured: true,
refreshing: false,
models: [],
},
],
]),
getEntry: (providerId: string) =>
providerId === "anthropic"
? {
providerId: "anthropic",
configured: true,
refreshing: false,
models: [],
}
: undefined,
configuredModelProviderEntries: [],
getModelsForAgent: () => [],
loading: false,
});
const { result } = renderHook(() =>
useAgentModelPickerState({
providers: [{ id: "anthropic", label: "Anthropic" }],
selectedProvider: "anthropic",
onProviderSelected,
}),
);
act(() => {
result.current.handleProviderChange("goose");
});
expect(onProviderSelected).toHaveBeenCalledWith("goose");
});
it("treats goose as a no-op only when goose is already selected", () => {
const onProviderSelected = vi.fn();
mockUseProviderInventory.mockReturnValue({
entries: new Map(),
getEntry: () => undefined,
configuredModelProviderEntries: [],
getModelsForAgent: () => [],
loading: false,
});
const { result } = renderHook(() =>
useAgentModelPickerState({
providers: [],
selectedProvider: "goose",
onProviderSelected,
}),
);
act(() => {
result.current.handleProviderChange("goose");
});
expect(onProviderSelected).not.toHaveBeenCalled();
});
it("passes the selected model provider through for goose model picks", () => {
const onModelSelected = vi.fn();
mockUseProviderInventory.mockReturnValue({
entries: new Map(),
getEntry: () => undefined,
configuredModelProviderEntries: [],
getModelsForAgent: () => [
{
id: "claude-sonnet-4",
name: "claude-sonnet-4",
displayName: "Claude Sonnet 4",
providerId: "anthropic",
providerName: "Anthropic",
recommended: true,
},
],
loading: false,
});
const { result } = renderHook(() =>
useAgentModelPickerState({
providers: [{ id: "goose", label: "Goose" }],
selectedProvider: "openai",
onProviderSelected: vi.fn(),
onModelSelected,
}),
);
act(() => {
result.current.handleModelChange("claude-sonnet-4");
});
expect(onModelSelected).toHaveBeenCalledWith({
id: "claude-sonnet-4",
name: "claude-sonnet-4",
displayName: "Claude Sonnet 4",
provider: undefined,
providerId: "anthropic",
providerName: "Anthropic",
recommended: true,
});
});
});

View file

@ -33,8 +33,6 @@ describe("useChat attachments", () => {
isLoading: false,
contextPanelOpenBySession: {},
activeWorkspaceBySession: {},
modelsBySession: {},
modelCacheByProvider: {},
});
useAgentStore.setState({
personas: [],

View file

@ -81,8 +81,6 @@ describe("useChat", () => {
isLoading: false,
contextPanelOpenBySession: {},
activeWorkspaceBySession: {},
modelsBySession: {},
modelCacheByProvider: {},
});
useAgentStore.setState({
personas: [
@ -354,7 +352,7 @@ describe("useChat", () => {
});
});
it("prepares draft sessions before applying a selected model on first send", async () => {
it("sends messages without an extra session preparation step", async () => {
useChatSessionStore.setState({
sessions: [
{
@ -366,28 +364,16 @@ describe("useChat", () => {
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messageCount: 0,
draft: true,
},
],
});
const { result } = renderHook(() =>
useChat("session-1", "openai", undefined, undefined, async () => "/tmp"),
);
const { result } = renderHook(() => useChat("session-1", "openai"));
await act(async () => {
await result.current.sendMessage("Hello");
});
expect(mockAcpPrepareSession).toHaveBeenCalledWith(
"session-1",
"openai",
"/tmp",
{
personaId: undefined,
},
);
expect(mockAcpSetModel).toHaveBeenCalledWith("session-1", "gpt-4.1");
expect(mockAcpSendMessage).toHaveBeenCalledWith("session-1", "Hello", {
systemPrompt: undefined,
personaId: undefined,
@ -396,6 +382,50 @@ describe("useChat", () => {
});
});
it("fires onMessageAccepted only after the message enters the session", async () => {
const onMessageAccepted = vi.fn();
const deferred = createDeferredPromise();
mockAcpSendMessage.mockReturnValue(deferred.promise);
const { result } = renderHook(() =>
useChat("session-1", undefined, undefined, undefined, {
onMessageAccepted,
}),
);
await act(async () => {
const sendPromise = result.current.sendMessage("Hello");
await Promise.resolve();
expect(onMessageAccepted).toHaveBeenCalledTimes(1);
expect(
useChatStore.getState().messagesBySession["session-1"],
).toHaveLength(1);
deferred.resolve();
await sendPromise;
});
});
it("awaits ensurePrepared before prompting", async () => {
const ensurePrepared = vi.fn().mockResolvedValue(undefined);
const { result } = renderHook(() =>
useChat("session-1", undefined, undefined, undefined, {
ensurePrepared,
}),
);
await act(async () => {
await result.current.sendMessage("Hello");
});
expect(ensurePrepared).toHaveBeenCalledTimes(1);
expect(ensurePrepared.mock.invocationCallOrder[0]).toBeLessThan(
mockAcpSendMessage.mock.invocationCallOrder[0],
);
});
it("appends an error message and removes the empty assistant placeholder when send fails", async () => {
mockAcpSendMessage.mockRejectedValue(
new Error("Working directory missing"),

View file

@ -0,0 +1,181 @@
import { act, renderHook, waitFor } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { useAgentStore } from "@/features/agents/stores/agentStore";
import { useProjectStore } from "@/features/projects/stores/projectStore";
import { useChatStore } from "../../stores/chatStore";
import { useChatSessionStore } from "../../stores/chatSessionStore";
const mockAcpPrepareSession = vi.fn();
const mockAcpSetModel = vi.fn();
const mockSetSelectedProvider = vi.fn();
const mockResolveSessionCwd = vi.fn();
vi.mock("@/shared/api/acp", () => ({
acpPrepareSession: (...args: unknown[]) => mockAcpPrepareSession(...args),
acpSetModel: (...args: unknown[]) => mockAcpSetModel(...args),
}));
vi.mock("../useChat", () => ({
useChat: () => ({
messages: [],
chatState: "idle",
tokenState: null,
sendMessage: vi.fn(),
stopStreaming: vi.fn(),
streamingMessageId: null,
}),
}));
vi.mock("../useMessageQueue", () => ({
useMessageQueue: () => ({
queuedMessage: null,
enqueue: vi.fn(),
}),
}));
vi.mock("@/features/agents/hooks/useProviderSelection", () => ({
useProviderSelection: () => ({
providers: [
{ id: "goose", label: "Goose" },
{ id: "openai", label: "OpenAI" },
{ id: "anthropic", label: "Anthropic" },
],
providersLoading: false,
selectedProvider: "openai",
setSelectedProvider: (...args: unknown[]) =>
mockSetSelectedProvider(...args),
}),
}));
vi.mock("@/features/projects/lib/sessionCwdSelection", () => ({
resolveSessionCwd: (...args: unknown[]) => mockResolveSessionCwd(...args),
}));
vi.mock("../useAgentModelPickerState", () => ({
useAgentModelPickerState: ({
onModelSelected,
}: {
onModelSelected?: (model: {
id: string;
name: string;
displayName?: string;
providerId?: string;
}) => void;
}) => ({
selectedAgentId: "goose",
pickerAgents: [{ id: "goose", label: "Goose" }],
availableModels: [],
modelsLoading: false,
modelStatusMessage: null,
handleProviderChange: vi.fn(),
handleModelChange: (modelId: string) => {
if (modelId === "claude-sonnet-4") {
onModelSelected?.({
id: modelId,
name: modelId,
displayName: "Claude Sonnet 4",
providerId: "anthropic",
});
}
},
}),
}));
import { useChatSessionController } from "../useChatSessionController";
describe("useChatSessionController", () => {
beforeEach(() => {
vi.clearAllMocks();
mockAcpPrepareSession.mockResolvedValue(undefined);
mockAcpSetModel.mockResolvedValue(undefined);
mockResolveSessionCwd.mockResolvedValue("/tmp/project");
useAgentStore.setState({
personas: [],
personasLoading: false,
agents: [],
agentsLoading: false,
providers: [],
providersLoading: false,
selectedProvider: "openai",
activeAgentId: null,
isLoading: false,
personaEditorOpen: false,
editingPersona: null,
});
useProjectStore.setState({
projects: [],
loading: false,
activeProjectId: null,
});
useChatStore.setState({
messagesBySession: {},
sessionStateById: {},
draftsBySession: {},
queuedMessageBySession: {},
scrollTargetMessageBySession: {},
activeSessionId: null,
isConnected: true,
});
useChatSessionStore.setState({
sessions: [
{
id: "session-1",
title: "Chat",
providerId: "openai",
modelId: "gpt-4o",
modelName: "GPT-4o",
createdAt: "2026-04-20T00:00:00.000Z",
updatedAt: "2026-04-20T00:00:00.000Z",
messageCount: 0,
},
],
activeSessionId: null,
isLoading: false,
hasHydratedSessions: true,
contextPanelOpenBySession: {},
activeWorkspaceBySession: {},
});
});
it("prepares the selected model provider before setting a goose model", async () => {
const { result } = renderHook(() =>
useChatSessionController({ sessionId: "session-1" }),
);
act(() => {
result.current.handleModelChange("claude-sonnet-4");
});
await waitFor(() => {
expect(mockAcpPrepareSession).toHaveBeenCalledWith(
"session-1",
"anthropic",
"/tmp/project",
{ personaId: undefined },
);
});
await waitFor(() => {
expect(mockAcpSetModel).toHaveBeenCalledWith(
"session-1",
"claude-sonnet-4",
);
});
expect(mockAcpPrepareSession.mock.invocationCallOrder[0]).toBeLessThan(
mockAcpSetModel.mock.invocationCallOrder[0],
);
expect(mockSetSelectedProvider).toHaveBeenCalledWith("anthropic");
expect(
useChatSessionStore.getState().getSession("session-1"),
).toMatchObject({
providerId: "anthropic",
modelId: "claude-sonnet-4",
modelName: "Claude Sonnet 4",
});
});
});

View file

@ -0,0 +1,173 @@
import { useCallback, useMemo } from "react";
import type { AcpProvider } from "@/shared/api/acp";
import { useProviderInventory } from "@/features/providers/hooks/useProviderInventory";
import {
getCatalogEntry,
resolveAgentProviderCatalogIdStrict,
} from "@/features/providers/providerCatalog";
import type { ModelOption } from "../types";
interface UseAgentModelPickerStateOptions {
providers: AcpProvider[];
selectedProvider?: string;
onProviderSelected: (providerId: string) => void;
onModelSelected?: (model: ModelOption) => void;
}
const EMPTY_MODELS: ModelOption[] = [];
export function useAgentModelPickerState({
providers,
selectedProvider,
onProviderSelected,
onModelSelected,
}: UseAgentModelPickerStateOptions) {
const {
entries: providerInventoryEntries,
getEntry: getProviderInventoryEntry,
configuredModelProviderEntries,
getModelsForAgent,
loading: providerInventoryLoading,
} = useProviderInventory();
const selectedAgentId = selectedProvider
? (resolveAgentProviderCatalogIdStrict(selectedProvider) ?? "goose")
: "goose";
const selectedProviderInventory = getProviderInventoryEntry(selectedAgentId);
const pickerAgents = useMemo(() => {
const visible = new Map<string, { id: string; label: string }>();
visible.set("goose", {
id: "goose",
label: getCatalogEntry("goose")?.displayName ?? "Goose",
});
for (const provider of providers) {
const agentId = resolveAgentProviderCatalogIdStrict(provider.id);
if (!agentId || agentId === "goose") {
continue;
}
const inventoryEntry = providerInventoryEntries.get(agentId);
if (!inventoryEntry?.configured && agentId !== selectedAgentId) {
continue;
}
visible.set(agentId, {
id: agentId,
label: getCatalogEntry(agentId)?.displayName ?? provider.label,
});
}
if (!visible.has(selectedAgentId)) {
visible.set(selectedAgentId, {
id: selectedAgentId,
label: getCatalogEntry(selectedAgentId)?.displayName ?? selectedAgentId,
});
}
return [...visible.values()];
}, [providerInventoryEntries, providers, selectedAgentId]);
const availableModels = useMemo(
() => getModelsForAgent(selectedAgentId) ?? EMPTY_MODELS,
[getModelsForAgent, selectedAgentId],
);
const modelsLoading = useMemo(() => {
// Show loading only when we have no models to display yet.
// If cached models exist, show them immediately — a background refresh
// will update the list when it completes.
if (availableModels.length > 0) {
return false;
}
if (providerInventoryLoading) {
return true;
}
if (selectedAgentId === "goose") {
return (
configuredModelProviderEntries.length > 0 &&
configuredModelProviderEntries.some((entry) => entry.refreshing)
);
}
return selectedProviderInventory?.refreshing === true;
}, [
availableModels.length,
configuredModelProviderEntries,
providerInventoryLoading,
selectedAgentId,
selectedProviderInventory?.refreshing,
]);
const modelStatusMessage = useMemo(() => {
if (availableModels.length > 0) {
return null;
}
if (selectedAgentId === "goose") {
const entryWithHint = configuredModelProviderEntries.find(
(entry) => entry.modelSelectionHint || entry.lastRefreshError,
);
return (
entryWithHint?.modelSelectionHint ??
entryWithHint?.lastRefreshError ??
null
);
}
return (
selectedProviderInventory?.modelSelectionHint ??
selectedProviderInventory?.lastRefreshError ??
null
);
}, [
availableModels.length,
configuredModelProviderEntries,
selectedAgentId,
selectedProviderInventory?.modelSelectionHint,
selectedProviderInventory?.lastRefreshError,
]);
const handleProviderChange = useCallback(
(providerId: string) => {
if (providerId === (selectedProvider ?? "goose")) {
return;
}
onProviderSelected(providerId);
},
[onProviderSelected, selectedProvider],
);
const handleModelChange = useCallback(
(modelId: string) => {
const selectedModel = availableModels.find(
(model) => model.id === modelId,
);
onModelSelected?.({
id: modelId,
name: selectedModel?.name ?? modelId,
displayName: selectedModel?.displayName ?? modelId,
provider: selectedModel?.provider,
providerId: selectedModel?.providerId,
providerName: selectedModel?.providerName,
recommended: selectedModel?.recommended,
});
},
[availableModels, onModelSelected],
);
return {
selectedAgentId,
pickerAgents,
availableModels,
modelsLoading,
modelStatusMessage,
handleProviderChange,
handleModelChange,
};
}

View file

@ -14,8 +14,6 @@ import {
acpSendMessage,
acpCancelSession,
acpLoadSession,
acpPrepareSession,
acpSetModel,
} from "@/shared/api/acp";
import { getGooseSessionId } from "@/shared/api/acpSessionTracker";
import { useAgentStore } from "@/features/agents/stores/agentStore";
@ -109,7 +107,10 @@ export function useChat(
providerOverride?: string,
systemPromptOverride?: string,
personaInfo?: { id: string; name: string },
getWorkingDir?: () => Promise<string | undefined>,
options?: {
onMessageAccepted?: (sessionId: string) => void;
ensurePrepared?: () => Promise<void>;
},
) {
const store = useChatStore();
const abortRef = useRef<AbortController | null>(null);
@ -215,15 +216,8 @@ export function useChat(
store.setChatState(sessionId, "thinking");
store.setError(sessionId, null);
// Promote draft to real backend session before first send
const sessionStore = useChatSessionStore.getState();
const session = sessionStore.getSession(sessionId);
const wasDraft = !!session?.draft;
const selectedModelId = session?.modelId;
if (wasDraft) {
sessionStore.promoteDraft(sessionId);
}
// Immediately set the session/sidebar title from the user's message when
// the session still has the default placeholder. This gives instant
@ -231,20 +225,18 @@ export function useChat(
// A better backend-generated title will overwrite this if it arrives
// via the acp:session_info event.
if (session && isDefaultChatTitle(session.title)) {
sessionStore.updateSession(
sessionId,
{
title: getSessionTitleFromDraft(text, attachments),
updatedAt: new Date().toISOString(),
},
{ localOnly: wasDraft },
);
sessionStore.updateSession(sessionId, {
title: getSessionTitleFromDraft(text, attachments),
updatedAt: new Date().toISOString(),
});
} else {
sessionStore.updateSession(sessionId, {
updatedAt: new Date().toISOString(),
});
}
options?.onMessageAccepted?.(sessionId);
store.clearDraft(sessionId);
const abort = new AbortController();
@ -252,26 +244,7 @@ export function useChat(
streamingPersonaIdRef.current = effectivePersonaInfo?.id ?? null;
try {
if (wasDraft || selectedModelId) {
const workingDir = await getWorkingDir?.();
if (!workingDir) {
throw new Error("Missing session working directory");
}
const tPrep = performance.now();
await acpPrepareSession(sessionId, providerId, workingDir, {
personaId: effectivePersonaInfo?.id,
});
perfLog(
`[perf:send] ${sid} acpPrepareSession in ${(performance.now() - tPrep).toFixed(1)}ms (wasDraft=${wasDraft})`,
);
if (selectedModelId) {
const tModel = performance.now();
await acpSetModel(sessionId, selectedModelId);
perfLog(
`[perf:send] ${sid} acpSetModel(${selectedModelId}) in ${(performance.now() - tModel).toFixed(1)}ms`,
);
}
}
await options?.ensurePrepared?.();
store.setChatState(sessionId, "streaming");
// When images are present with no text, pass a single space so the ACP
@ -298,18 +271,6 @@ export function useChat(
store.setChatState(sessionId, "idle");
store.setStreamingMessageId(sessionId, null);
if (wasDraft) {
const promoted = sessionStore.getSession(sessionId);
if (promoted) {
sessionStore.updateSession(sessionId, {
title: promoted.title,
providerId: promoted.providerId,
personaId: promoted.personaId,
projectId: promoted.projectId,
});
}
}
} catch (err) {
if (err instanceof DOMException && err.name === "AbortError") {
store.setChatState(sessionId, "idle");
@ -351,7 +312,7 @@ export function useChat(
providerOverride,
systemPromptOverride,
resolvePersonaInfo,
getWorkingDir,
options,
],
);
@ -415,6 +376,12 @@ export function useChat(
store.setPendingAssistantProvider(sessionId, null);
}, [sessionId, store]);
const getWorkingDir = useCallback(
() =>
useChatSessionStore.getState().activeWorkspaceBySession[sessionId]?.path,
[sessionId],
);
const compactConversation = useCallback(async () => {
const currentChatState = useChatStore
.getState()
@ -457,7 +424,7 @@ export function useChat(
// layer does not currently forward history replacement events. Drop those
// transient chunks and refresh the session from replay instead.
clearReplayBuffer(sessionId);
const workingDir = await getWorkingDir?.();
const workingDir = getWorkingDir();
await acpLoadSession(sessionId, gooseSessionId, workingDir);
store.setSessionLoading(sessionId, false);

View file

@ -0,0 +1,682 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { ChatAttachmentDraft } from "@/shared/types/messages";
import { useChat } from "./useChat";
import { useMessageQueue } from "./useMessageQueue";
import { useChatStore } from "../stores/chatStore";
import { useChatSessionStore } from "../stores/chatSessionStore";
import { useAgentStore } from "@/features/agents/stores/agentStore";
import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection";
import { useProjectStore } from "@/features/projects/stores/projectStore";
import { useAgentModelPickerState } from "./useAgentModelPickerState";
import {
buildProjectSystemPrompt,
composeSystemPrompt,
getProjectArtifactRoots,
resolveProjectDefaultArtifactRoot,
} from "@/features/projects/lib/chatProjectContext";
import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection";
import { acpPrepareSession, acpSetModel } from "@/shared/api/acp";
interface UseChatSessionControllerOptions {
sessionId: string | null;
onMessageAccepted?: (sessionId: string) => void;
}
const PENDING_HOME_SESSION_ID = "__home_pending__";
export function useChatSessionController({
sessionId,
onMessageAccepted,
}: UseChatSessionControllerOptions) {
const stateSessionId = sessionId ?? PENDING_HOME_SESSION_ID;
const {
providers,
providersLoading,
selectedProvider: globalSelectedProvider,
setSelectedProvider: setGlobalSelectedProvider,
} = useProviderSelection();
const personas = useAgentStore((s) => s.personas);
const session = useChatSessionStore((s) =>
sessionId
? s.sessions.find((candidate) => candidate.id === sessionId)
: undefined,
);
const activeWorkspace = useChatSessionStore((s) =>
sessionId ? s.activeWorkspaceBySession[sessionId] : undefined,
);
const clearActiveWorkspace = useChatSessionStore(
(s) => s.clearActiveWorkspace,
);
const projects = useProjectStore((s) => s.projects);
const projectsLoading = useProjectStore((s) => s.loading);
const [pendingPersonaId, setPendingPersonaId] = useState<string | null>();
const [pendingProjectId, setPendingProjectId] = useState<string | null>();
const [pendingProviderId, setPendingProviderId] = useState<string>();
const [pendingModelSelection, setPendingModelSelection] = useState<{
id: string;
name: string;
providerId?: string;
} | null>();
const pendingDraftValue = useChatStore(
(s) => s.draftsBySession[PENDING_HOME_SESSION_ID] ?? "",
);
const pendingQueuedMessage = useChatStore(
(s) => s.queuedMessageBySession[PENDING_HOME_SESSION_ID] ?? null,
);
const effectiveProjectId =
pendingProjectId !== undefined
? pendingProjectId
: (session?.projectId ?? null);
const storedProject = useProjectStore((s) =>
effectiveProjectId
? s.projects.find((candidate) => candidate.id === effectiveProjectId)
: undefined,
);
const project = storedProject ?? null;
const selectedProvider =
pendingProviderId ??
session?.providerId ??
project?.preferredProvider ??
globalSelectedProvider;
const selectedPersonaId =
pendingPersonaId !== undefined
? pendingPersonaId
: (session?.personaId ?? null);
const selectedPersona = personas.find(
(persona) => persona.id === selectedPersonaId,
);
const projectArtifactRoots = useMemo(
() => getProjectArtifactRoots(project),
[project],
);
const projectDefaultArtifactRoot = useMemo(
() => resolveProjectDefaultArtifactRoot(project),
[project],
);
const projectMetadataPending = Boolean(
effectiveProjectId && !projectDefaultArtifactRoot && projectsLoading,
);
const allowedArtifactRoots = useMemo(
() => [
...new Set(
projectArtifactRoots.map((path) => path.trim()).filter(Boolean),
),
],
[projectArtifactRoots],
);
const availableProjects = useMemo(
() =>
[...projects]
.sort((a, b) => a.order - b.order || a.name.localeCompare(b.name))
.map((projectInfo) => ({
id: projectInfo.id,
name: projectInfo.name,
workingDirs: projectInfo.workingDirs,
color: projectInfo.color,
})),
[projects],
);
const projectSystemPrompt = useMemo(
() => buildProjectSystemPrompt(project),
[project],
);
const workingContextPrompt = useMemo(() => {
if (!activeWorkspace?.branch) return undefined;
return `<active-working-context>\nActive branch: ${activeWorkspace.branch}\nWorking directory: ${activeWorkspace.path}\n</active-working-context>`;
}, [activeWorkspace?.branch, activeWorkspace?.path]);
const effectiveSystemPrompt = useMemo(
() =>
composeSystemPrompt(
selectedPersona?.systemPrompt,
projectSystemPrompt,
workingContextPrompt,
),
[projectSystemPrompt, selectedPersona?.systemPrompt, workingContextPrompt],
);
const prepareCurrentSession = useCallback(
async (
providerId: string,
nextProject = project,
nextWorkspacePath = activeWorkspace?.path,
personaId = selectedPersonaId ?? undefined,
) => {
if (!sessionId) {
return;
}
const workingDir = await resolveSessionCwd(
nextProject,
nextWorkspacePath,
);
await acpPrepareSession(sessionId, providerId, workingDir, { personaId });
},
[activeWorkspace?.path, project, selectedPersonaId, sessionId],
);
const prevProjectIdRef = useRef(session?.projectId);
useEffect(() => {
if (!sessionId) {
return;
}
const previousProjectId = prevProjectIdRef.current;
prevProjectIdRef.current = session?.projectId;
if (
previousProjectId !== undefined &&
previousProjectId !== session?.projectId
) {
clearActiveWorkspace(sessionId);
}
}, [clearActiveWorkspace, session?.projectId, sessionId]);
const prevWorkspaceRef = useRef(activeWorkspace);
useEffect(() => {
const previousWorkspace = prevWorkspaceRef.current;
if (
!sessionId ||
!activeWorkspace ||
!selectedProvider ||
activeWorkspace === previousWorkspace
) {
return;
}
prevWorkspaceRef.current = activeWorkspace;
if (previousWorkspace?.path === activeWorkspace.path) {
return;
}
void prepareCurrentSession(selectedProvider).catch((error) => {
console.error("Failed to prepare ACP session:", error);
});
}, [activeWorkspace, prepareCurrentSession, selectedProvider, sessionId]);
const {
selectedAgentId,
pickerAgents,
availableModels,
modelsLoading,
modelStatusMessage,
handleProviderChange,
handleModelChange,
} = useAgentModelPickerState({
providers,
selectedProvider,
onProviderSelected: (providerId) => {
if (!sessionId) {
setGlobalSelectedProvider(providerId);
setPendingProviderId(providerId);
setPendingModelSelection(null);
return;
}
useChatSessionStore
.getState()
.switchSessionProvider(sessionId, providerId);
setGlobalSelectedProvider(providerId);
void prepareCurrentSession(providerId).catch((error) => {
console.error("Failed to update ACP session provider:", error);
});
},
onModelSelected: (model) => {
const modelId = model.id;
const modelName = model.displayName ?? model.name ?? model.id;
const nextProviderId = model.providerId ?? selectedProvider;
if (!sessionId) {
if (nextProviderId && nextProviderId !== selectedProvider) {
setPendingProviderId(nextProviderId);
setGlobalSelectedProvider(nextProviderId);
}
setPendingModelSelection({
id: modelId,
name: modelName,
providerId: nextProviderId,
});
return;
}
if (
!session ||
(modelId === session.modelId &&
(!nextProviderId || nextProviderId === session.providerId))
) {
return;
}
const previousProviderId = session.providerId;
const previousModelId = session.modelId;
const previousModelName = session.modelName;
const providerChanged =
Boolean(nextProviderId) && nextProviderId !== session.providerId;
if (providerChanged && nextProviderId) {
useChatSessionStore
.getState()
.switchSessionProvider(sessionId, nextProviderId);
setGlobalSelectedProvider(nextProviderId);
}
useChatSessionStore.getState().updateSession(sessionId, {
modelId,
modelName,
});
void (async () => {
try {
if (providerChanged && nextProviderId) {
await prepareCurrentSession(nextProviderId);
}
await acpSetModel(sessionId, modelId);
} catch (error) {
console.error("Failed to set model:", error);
if (providerChanged && previousProviderId) {
setGlobalSelectedProvider(previousProviderId);
}
useChatSessionStore.getState().updateSession(sessionId, {
providerId: previousProviderId,
modelId: previousModelId,
modelName: previousModelName,
});
void (async () => {
try {
if (providerChanged && previousProviderId) {
await prepareCurrentSession(previousProviderId);
}
if (previousModelId) {
await acpSetModel(sessionId, previousModelId);
}
} catch (rollbackError) {
console.error(
"Failed to restore previous provider/model after setModel failure:",
rollbackError,
);
}
})();
}
})();
},
});
const handleProjectChange = useCallback(
(projectId: string | null) => {
if (!sessionId) {
setPendingProjectId(projectId);
return;
}
const nextProject =
projectId == null
? null
: (useProjectStore
.getState()
.projects.find((candidate) => candidate.id === projectId) ??
null);
useChatSessionStore.getState().updateSession(sessionId, { projectId });
if (!selectedProvider) {
return;
}
void prepareCurrentSession(selectedProvider, nextProject).catch(
(error) => {
console.error(
"Failed to update ACP session working directory:",
error,
);
},
);
},
[prepareCurrentSession, selectedProvider, sessionId],
);
const handlePersonaChange = useCallback(
(personaId: string | null) => {
const persona = personas.find((candidate) => candidate.id === personaId);
if (persona?.provider) {
const matchingProvider = providers.find(
(provider) =>
provider.id === persona.provider ||
provider.label.toLowerCase().includes(persona.provider ?? ""),
);
if (matchingProvider) {
if (!sessionId) {
setPendingProviderId(matchingProvider.id);
setPendingModelSelection(null);
setGlobalSelectedProvider(matchingProvider.id);
} else {
handleProviderChange(matchingProvider.id);
}
}
}
const agentStore = useAgentStore.getState();
const matchingAgent = agentStore.agents.find(
(agent) => agent.personaId === personaId,
);
if (matchingAgent) {
agentStore.setActiveAgent(matchingAgent.id);
}
if (!sessionId) {
setPendingPersonaId(personaId);
return;
}
useChatSessionStore
.getState()
.updateSession(sessionId, { personaId: personaId ?? undefined });
},
[
handleProviderChange,
personas,
providers,
sessionId,
setGlobalSelectedProvider,
],
);
useEffect(() => {
if (
selectedPersonaId !== null &&
personas.length > 0 &&
!personas.find((persona) => persona.id === selectedPersonaId)
) {
if (sessionId) {
useChatSessionStore
.getState()
.updateSession(sessionId, { personaId: undefined });
} else {
setPendingPersonaId(undefined);
}
}
}, [personas, selectedPersonaId, sessionId]);
const personaInfo = selectedPersona
? { id: selectedPersona.id, name: selectedPersona.displayName }
: undefined;
const {
messages,
chatState,
tokenState,
sendMessage,
stopStreaming,
streamingMessageId,
} = useChat(
stateSessionId,
selectedProvider,
effectiveSystemPrompt,
personaInfo,
{
onMessageAccepted: sessionId ? onMessageAccepted : undefined,
ensurePrepared: selectedProvider
? () => prepareCurrentSession(selectedProvider)
: undefined,
},
);
const isLoadingHistory = useChatStore((s) =>
sessionId
? s.loadingSessionIds.has(sessionId) &&
(s.messagesBySession[sessionId]?.length ?? 0) === 0
: false,
);
const deferredSend = useRef<{
text: string;
attachments?: ChatAttachmentDraft[];
} | null>(null);
const queue = useMessageQueue(
stateSessionId,
sessionId ? chatState : "thinking",
sendMessage,
);
const chatStore = useChatStore();
const handleSend = useCallback(
(text: string, personaId?: string, attachments?: ChatAttachmentDraft[]) => {
if (!sessionId) {
if (!queue.queuedMessage) {
queue.enqueue(text, personaId, attachments);
}
return;
}
if (personaId && personaId !== selectedPersonaId) {
const nextPersona = personas.find(
(persona) => persona.id === personaId,
);
if (nextPersona) {
chatStore.addMessage(sessionId, {
id: crypto.randomUUID(),
role: "system",
created: Date.now(),
content: [
{
type: "systemNotification",
notificationType: "info",
text: `Switched to ${nextPersona.displayName}`,
},
],
metadata: { userVisible: true, agentVisible: false },
});
}
handlePersonaChange(personaId);
deferredSend.current = { text, attachments };
return;
}
if (chatState !== "idle" && !queue.queuedMessage) {
queue.enqueue(text, personaId, attachments);
return;
}
sendMessage(text, undefined, attachments);
},
[
chatState,
chatStore,
handlePersonaChange,
personas,
queue,
sessionId,
selectedPersonaId,
sendMessage,
],
);
useEffect(() => {
if (deferredSend.current && selectedPersona) {
const { text, attachments } = deferredSend.current;
deferredSend.current = null;
sendMessage(text, undefined, attachments);
}
}, [selectedPersona, sendMessage]);
const handleCreatePersona = useCallback(() => {
useAgentStore.getState().openPersonaEditor();
}, []);
const sessionDraftValue = useChatStore((s) =>
sessionId ? (s.draftsBySession[sessionId] ?? "") : "",
);
const draftValue = sessionId ? sessionDraftValue : pendingDraftValue;
const handleDraftChange = useCallback(
(text: string) => {
useChatStore.getState().setDraft(stateSessionId, text);
},
[stateSessionId],
);
const scrollTarget = useChatStore((s) =>
sessionId ? (s.scrollTargetMessageBySession[sessionId] ?? null) : null,
);
const handleScrollTargetHandled = useCallback(() => {
if (!sessionId) {
return;
}
useChatStore.getState().clearScrollTargetMessage(sessionId);
}, [sessionId]);
useEffect(() => {
if (!sessionId) {
return;
}
let cancelled = false;
void pendingDraftValue;
void pendingQueuedMessage;
const syncPendingHomeState = async () => {
const chatState = useChatStore.getState();
const pendingDraft =
chatState.draftsBySession[PENDING_HOME_SESSION_ID] ?? "";
if (pendingDraft && !chatState.draftsBySession[sessionId]) {
chatState.setDraft(sessionId, pendingDraft);
}
const hasPendingProvider = pendingProviderId !== undefined;
const hasPendingPersona = pendingPersonaId !== undefined;
const hasPendingProject = pendingProjectId !== undefined;
const hasPendingModel = pendingModelSelection !== undefined;
if (
hasPendingProvider ||
hasPendingPersona ||
hasPendingProject ||
hasPendingModel
) {
const nextProviderId = pendingProviderId ?? selectedProvider;
const nextPersonaId =
pendingPersonaId !== undefined
? (pendingPersonaId ?? undefined)
: session?.personaId;
const nextProjectId =
pendingProjectId !== undefined
? pendingProjectId
: session?.projectId;
const nextProject =
nextProjectId == null
? null
: (useProjectStore
.getState()
.projects.find((candidate) => candidate.id === nextProjectId) ??
null);
const patch: {
providerId?: string;
personaId?: string | undefined;
projectId?: string | null;
modelId?: string | undefined;
modelName?: string | undefined;
} = {};
if (hasPendingProvider) {
patch.providerId = nextProviderId;
patch.modelId = undefined;
patch.modelName = undefined;
}
if (hasPendingPersona) {
patch.personaId = nextPersonaId;
}
if (hasPendingProject) {
patch.projectId = nextProjectId ?? null;
}
if (hasPendingModel) {
patch.modelId = pendingModelSelection?.id;
patch.modelName = pendingModelSelection?.name;
}
useChatSessionStore.getState().updateSession(sessionId, patch);
try {
await prepareCurrentSession(
nextProviderId,
nextProject,
activeWorkspace?.path,
nextPersonaId,
);
if (cancelled) {
return;
}
if (pendingModelSelection?.id) {
await acpSetModel(sessionId, pendingModelSelection.id);
if (cancelled) {
return;
}
}
} catch (error) {
console.error("Failed to sync pending Home state:", error);
return;
}
setPendingProviderId(undefined);
setPendingPersonaId(undefined);
setPendingProjectId(undefined);
setPendingModelSelection(undefined);
}
const latestChatState = useChatStore.getState();
const latestPendingQueue =
latestChatState.queuedMessageBySession[PENDING_HOME_SESSION_ID] ?? null;
if (
latestPendingQueue &&
!latestChatState.queuedMessageBySession[sessionId]
) {
latestChatState.enqueueMessage(sessionId, latestPendingQueue);
}
useChatStore.getState().clearDraft(PENDING_HOME_SESSION_ID);
useChatStore.getState().dismissQueuedMessage(PENDING_HOME_SESSION_ID);
useChatStore.getState().cleanupSession(PENDING_HOME_SESSION_ID);
};
void syncPendingHomeState();
return () => {
cancelled = true;
};
}, [
activeWorkspace?.path,
pendingDraftValue,
pendingModelSelection,
pendingPersonaId,
pendingProjectId,
pendingProviderId,
pendingQueuedMessage,
prepareCurrentSession,
selectedProvider,
session?.personaId,
session?.projectId,
sessionId,
]);
return {
session,
project,
allowedArtifactRoots,
messages,
chatState,
tokenState,
stopStreaming,
streamingMessageId,
isLoadingHistory,
queue,
handleSend,
draftValue,
handleDraftChange,
scrollTarget,
handleScrollTargetHandled,
projectMetadataPending,
personas,
selectedPersonaId,
handlePersonaChange,
handleCreatePersona,
pickerAgents,
providersLoading,
selectedProvider: selectedAgentId,
handleProviderChange,
currentModelId:
pendingModelSelection !== undefined
? (pendingModelSelection?.id ?? null)
: (session?.modelId ?? null),
currentModelName:
pendingModelSelection !== undefined
? (pendingModelSelection?.name ?? null)
: session?.modelName,
availableModels,
modelsLoading,
modelStatusMessage,
handleModelChange,
selectedProjectId: effectiveProjectId,
availableProjects,
handleProjectChange,
};
}

View file

@ -2,170 +2,78 @@ import { describe, expect, it } from "vitest";
import { findExistingDraft } from "./newChat";
import type { ChatSession } from "../stores/chatSessionStore";
function makeDraft(overrides: Partial<ChatSession> = {}): ChatSession {
function makeSession(
id: string,
overrides: Partial<ChatSession> = {},
): ChatSession {
return {
id: "session-1",
id,
title: "New Chat",
createdAt: "2026-03-31T10:00:00.000Z",
updatedAt: "2026-03-31T10:00:00.000Z",
createdAt: "2026-04-01T00:00:00.000Z",
updatedAt: "2026-04-01T00:00:00.000Z",
messageCount: 0,
draft: true,
...overrides,
};
}
describe("findExistingDraft", () => {
it("reuses the active empty draft session", () => {
const activeDraft = makeDraft({ id: "active-draft" });
const result = findExistingDraft({
sessions: [activeDraft],
activeSessionId: activeDraft.id,
draftsBySession: {},
messagesBySession: {},
request: { title: "New Chat" },
it("reuses a matching project draft with content", () => {
const draft = makeSession("alpha-draft", {
projectId: "alpha",
providerId: "goose",
});
expect(result?.id).toBe(activeDraft.id);
expect(
findExistingDraft({
sessions: [draft],
activeSessionId: null,
draftsBySession: { "alpha-draft": "alpha draft" },
messagesBySession: {},
request: {
title: "New Chat",
projectId: "alpha",
},
}),
).toEqual(draft);
});
it("does not reuse non-draft sessions", () => {
const realSession = makeDraft({ id: "real", draft: undefined });
const result = findExistingDraft({
sessions: [realSession],
activeSessionId: realSession.id,
draftsBySession: {},
messagesBySession: {},
request: { title: "New Chat" },
it("does not reuse a draft from a different project", () => {
const draft = makeSession("alpha-draft", {
projectId: "alpha",
providerId: "goose",
});
expect(result).toBeUndefined();
expect(
findExistingDraft({
sessions: [draft],
activeSessionId: null,
draftsBySession: { "alpha-draft": "alpha draft" },
messagesBySession: {},
request: {
title: "New Chat",
projectId: "beta",
},
}),
).toBeUndefined();
});
it("does not reuse drafts that already have messages", () => {
const result = findExistingDraft({
sessions: [makeDraft({ id: "used-draft", messageCount: 1 })],
activeSessionId: "used-draft",
draftsBySession: {},
messagesBySession: {},
request: { title: "New Chat" },
it("does not reuse an abandoned empty draft", () => {
const draft = makeSession("alpha-draft", {
projectId: "alpha",
providerId: "goose",
});
expect(result).toBeUndefined();
});
it("does not reuse drafts with local in-memory messages", () => {
const result = findExistingDraft({
sessions: [makeDraft({ id: "streaming-draft" })],
activeSessionId: "streaming-draft",
draftsBySession: {},
messagesBySession: {
"streaming-draft": [
{
id: "msg-1",
role: "user",
created: Date.now(),
content: [{ type: "text", text: "hello" }],
},
],
},
request: { title: "New Chat" },
});
expect(result).toBeUndefined();
});
it("only reuses drafts for the same chat context", () => {
const projectDraft = makeDraft({
id: "project-draft",
projectId: "project-1",
});
const result = findExistingDraft({
sessions: [projectDraft],
activeSessionId: projectDraft.id,
draftsBySession: {},
messagesBySession: {},
request: { title: "New Chat", projectId: "project-2" },
});
expect(result).toBeUndefined();
});
it("does not reuse drafts when creating a titled chat", () => {
const result = findExistingDraft({
sessions: [makeDraft()],
activeSessionId: "session-1",
draftsBySession: {},
messagesBySession: {},
request: { title: "What day is it?" },
});
expect(result).toBeUndefined();
});
it("finds a non-active draft with content", () => {
const draftWithContent = makeDraft({ id: "background-draft" });
const result = findExistingDraft({
sessions: [draftWithContent],
activeSessionId: null,
draftsBySession: { "background-draft": "some typed text" },
messagesBySession: {},
request: { title: "New Chat" },
});
expect(result?.id).toBe("background-draft");
});
it("prefers active draft with content over non-active", () => {
const activeDraft = makeDraft({ id: "active-draft" });
const backgroundDraft = makeDraft({ id: "background-draft" });
const result = findExistingDraft({
sessions: [backgroundDraft, activeDraft],
activeSessionId: "active-draft",
draftsBySession: {
"active-draft": "active text",
"background-draft": "background text",
},
messagesBySession: {},
request: { title: "New Chat" },
});
expect(result?.id).toBe("active-draft");
});
it("finds an inactive empty draft when no draft has content", () => {
const inactiveDraft = makeDraft({
id: "inactive-draft",
updatedAt: "2026-03-31T12:00:00.000Z",
});
const result = findExistingDraft({
sessions: [inactiveDraft],
activeSessionId: null,
draftsBySession: {},
messagesBySession: {},
request: { title: "New Chat" },
});
expect(result?.id).toBe("inactive-draft");
});
it("prefers drafts with content over empty drafts", () => {
const emptyDraft = makeDraft({ id: "empty-draft" });
const contentDraft = makeDraft({ id: "content-draft" });
const result = findExistingDraft({
sessions: [emptyDraft, contentDraft],
activeSessionId: "empty-draft",
draftsBySession: { "content-draft": "has text" },
messagesBySession: {},
request: { title: "New Chat" },
});
expect(result?.id).toBe("content-draft");
expect(
findExistingDraft({
sessions: [draft],
activeSessionId: null,
draftsBySession: {},
messagesBySession: {},
request: {
title: "New Chat",
projectId: "alpha",
},
}),
).toBeUndefined();
});
});

View file

@ -5,9 +5,6 @@ import { DEFAULT_CHAT_TITLE } from "./sessionTitle";
interface NewChatRequest {
title: string;
projectId?: string;
agentId?: string;
providerId?: string;
personaId?: string;
}
interface FindExistingDraftArgs {
@ -22,24 +19,14 @@ function isMatchingContext(
session: ChatSession,
request: Omit<NewChatRequest, "title">,
): boolean {
return (
session.projectId === request.projectId &&
session.agentId === request.agentId &&
session.providerId === request.providerId &&
session.personaId === request.personaId
);
return session.projectId === request.projectId;
}
function isReusableDraft(
session: ChatSession,
localMessages: Message[] | undefined,
_localMessages: Message[] | undefined,
): boolean {
return (
!!session.draft &&
!session.archivedAt &&
session.messageCount === 0 &&
(localMessages?.length ?? 0) === 0
);
return !session.archivedAt && session.messageCount === 0;
}
export function findExistingDraft({
@ -64,18 +51,14 @@ export function findExistingDraft({
}
const withContent = candidates.filter(
(s) => (draftsBySession[s.id] ?? "").length > 0,
(session) => (draftsBySession[session.id] ?? "").length > 0,
);
if (withContent.length > 0) {
return withContent.find((s) => s.id === activeSessionId) ?? withContent[0];
return (
withContent.find((session) => session.id === activeSessionId) ??
withContent[0]
);
}
const active = candidates.find((s) => s.id === activeSessionId);
if (active) {
return active;
}
return candidates.sort(
(a, b) => new Date(b.updatedAt).getTime() - new Date(a.updatedAt).getTime(),
)[0];
return candidates.find((session) => session.id === activeSessionId);
}

View file

@ -1,9 +1,7 @@
import type { ModelOption } from "../types";
const ACP_SESSION_METADATA_STORAGE_KEY = "goose:acp-session-metadata";
const DRAFT_SESSION_STORAGE_KEY = "goose:chat-draft-sessions";
const LEGACY_SESSION_CACHE_STORAGE_KEY = "goose:chat-sessions";
const DRAFT_TEXT_STORAGE_KEY = "goose:chat-drafts";
export interface SessionMetadataOverlayRecord {
sessionId: string;
@ -22,7 +20,7 @@ export interface SessionMetadataOverlayRecord {
updatedAt: string;
}
export interface DraftSessionRecord {
interface LegacySessionRecord {
id: string;
acpSessionId?: string;
title: string;
@ -36,7 +34,7 @@ export interface DraftSessionRecord {
updatedAt: string;
archivedAt?: string;
messageCount: number;
draft?: true;
draft?: boolean;
userSetName?: boolean;
}
@ -65,32 +63,14 @@ function persistStorageArray<T>(storageKey: string, records: T[]): void {
}
}
function draftsWithText(): Set<string> {
if (typeof window === "undefined") return new Set();
try {
const stored = window.localStorage.getItem(DRAFT_TEXT_STORAGE_KEY);
if (!stored) return new Set();
const parsed = JSON.parse(stored);
return new Set(
Object.entries(parsed)
.filter(([, value]) => typeof value === "string" && value.length > 0)
.map(([sessionId]) => sessionId),
);
} catch {
return new Set();
}
}
function loadLegacySessions(): Array<
DraftSessionRecord & {
userSetName?: boolean;
}
> {
return parseStorageArray(LEGACY_SESSION_CACHE_STORAGE_KEY);
function loadLegacySessions(): LegacySessionRecord[] {
return parseStorageArray<LegacySessionRecord>(
LEGACY_SESSION_CACHE_STORAGE_KEY,
);
}
function recordFromLegacySession(
session: DraftSessionRecord & { userSetName?: boolean },
session: LegacySessionRecord,
): SessionMetadataOverlayRecord {
return {
sessionId: session.acpSessionId ?? session.id,
@ -138,49 +118,6 @@ export function persistSessionMetadataOverlay(
);
}
export function loadDraftSessionRecords(): DraftSessionRecord[] {
const records = parseStorageArray<DraftSessionRecord>(
DRAFT_SESSION_STORAGE_KEY,
);
const drafts = new Map(records.map((record) => [record.id, record]));
for (const session of loadLegacySessions()) {
if (!session.draft) continue;
if (drafts.has(session.id)) continue;
drafts.set(session.id, session);
}
return [...drafts.values()];
}
export function persistDraftSessionRecords(
records: DraftSessionRecord[],
): void {
const withText = draftsWithText();
persistStorageArray(
DRAFT_SESSION_STORAGE_KEY,
records.filter((record) => withText.has(record.id)),
);
}
export function migrateSessionMetadataOverlayId(
previousId: string,
nextId: string,
): void {
if (!previousId || !nextId || previousId === nextId) return;
const overlays = loadSessionMetadataOverlay();
const previous = overlays.get(previousId);
if (!previous) return;
const existing = overlays.get(nextId);
overlays.set(nextId, {
...previous,
...existing,
sessionId: nextId,
});
overlays.delete(previousId);
persistSessionMetadataOverlay(overlays.values());
}
export function upsertSessionMetadataOverlayRecord(
record: SessionMetadataOverlayRecord,
): void {
@ -195,21 +132,6 @@ export function removeSessionMetadataOverlayRecord(sessionId: string): void {
persistSessionMetadataOverlay(overlays.values());
}
export function persistDraftSessionRecord(record: DraftSessionRecord): void {
const drafts = loadDraftSessionRecords();
const nextDrafts = [
...drafts.filter((draft) => draft.id !== record.id),
record,
];
persistDraftSessionRecords(nextDrafts);
}
export function removeDraftSessionRecord(sessionId: string): void {
persistDraftSessionRecords(
loadDraftSessionRecords().filter((record) => record.id !== sessionId),
);
}
export function modelIdsMatch(
cached: ModelOption[] | undefined,
next: ModelOption[],

View file

@ -1,148 +1,97 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import type { AcpSessionInfo } from "@/shared/api/acp";
import { useChatSessionStore } from "../chatSessionStore";
import { useChatSessionStore, type ChatSession } from "../chatSessionStore";
const mockAcpCreateSession = vi.fn();
const mockAcpListSessions = vi.fn();
vi.mock("@/shared/api/acp", () => ({
acpListSessions: vi.fn(),
acpCreateSession: (...args: unknown[]) => mockAcpCreateSession(...args),
acpListSessions: (...args: unknown[]) => mockAcpListSessions(...args),
}));
import { acpListSessions } from "@/shared/api/acp";
const mockedAcpListSessions = vi.mocked(acpListSessions);
const LEGACY_SESSION_CACHE_KEY = "goose:chat-sessions";
const OVERLAY_CACHE_KEY = "goose:acp-session-metadata";
const DRAFT_SESSION_CACHE_KEY = "goose:chat-draft-sessions";
function resetStore() {
useChatSessionStore.setState({
sessions: [],
activeSessionId: null,
isLoading: false,
hasHydratedSessions: false,
contextPanelOpenBySession: {},
activeWorkspaceBySession: {},
modelsBySession: {},
modelCacheByProvider: {},
});
}
function makeSession(overrides: Partial<ChatSession> = {}): ChatSession {
return {
id: "session-1",
acpSessionId: "session-1",
title: "Test Session",
createdAt: "2026-04-01T00:00:00.000Z",
updatedAt: "2026-04-01T00:00:00.000Z",
messageCount: 0,
...overrides,
};
}
function seedSession(overrides: Partial<ChatSession> = {}): ChatSession {
const session = makeSession(overrides);
useChatSessionStore.getState().addSession(session);
return session;
}
describe("chatSessionStore", () => {
beforeEach(() => {
resetStore();
window.localStorage.removeItem(LEGACY_SESSION_CACHE_KEY);
window.localStorage.removeItem(OVERLAY_CACHE_KEY);
window.localStorage.removeItem(DRAFT_SESSION_CACHE_KEY);
vi.clearAllMocks();
});
afterEach(() => {
window.localStorage.removeItem(LEGACY_SESSION_CACHE_KEY);
window.localStorage.removeItem(OVERLAY_CACHE_KEY);
window.localStorage.removeItem(DRAFT_SESSION_CACHE_KEY);
});
describe("createDraftSession", () => {
it("creates a draft session with default title", () => {
const session = useChatSessionStore.getState().createDraftSession();
describe("createSession", () => {
it("creates a real ACP-backed session", async () => {
mockAcpCreateSession.mockResolvedValue({ sessionId: "acp-1" });
expect(session.title).toBe("New Chat");
expect(session.draft).toBe(true);
expect(session.messageCount).toBe(0);
expect(useChatSessionStore.getState().sessions).toContainEqual(session);
});
it("creates a draft session with custom options", () => {
const session = useChatSessionStore.getState().createDraftSession({
title: "My Custom Chat",
projectId: "proj-1",
const session = await useChatSessionStore.getState().createSession({
title: "New Chat",
providerId: "openai",
personaId: "persona-1",
modelId: "gpt-4.1",
modelName: "GPT-4.1",
workingDir: "/tmp/project",
});
expect(session.title).toBe("My Custom Chat");
expect(session.projectId).toBe("proj-1");
expect(session.providerId).toBe("openai");
expect(session.personaId).toBe("persona-1");
expect(session.draft).toBe(true);
});
});
describe("promoteDraft", () => {
it("removes draft flag from a draft session", () => {
const session = useChatSessionStore.getState().createDraftSession();
expect(session.draft).toBe(true);
useChatSessionStore.getState().promoteDraft(session.id);
const updated = useChatSessionStore
.getState()
.sessions.find((s) => s.id === session.id);
expect(updated?.draft).toBeUndefined();
});
it("does nothing for non-draft sessions", () => {
useChatSessionStore.setState({
sessions: [
{
id: "non-draft",
title: "Regular Session",
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messageCount: 5,
},
],
expect(mockAcpCreateSession).toHaveBeenCalledWith(
"openai",
"/tmp/project",
{
personaId: "persona-1",
modelId: "gpt-4.1",
},
);
expect(session).toMatchObject({
id: "acp-1",
acpSessionId: "acp-1",
title: "New Chat",
providerId: "openai",
personaId: "persona-1",
modelId: "gpt-4.1",
modelName: "GPT-4.1",
});
useChatSessionStore.getState().promoteDraft("non-draft");
const session = useChatSessionStore
.getState()
.sessions.find((s) => s.id === "non-draft");
expect(session?.draft).toBeUndefined();
});
});
describe("removeDraft", () => {
it("removes a draft session", () => {
const session = useChatSessionStore.getState().createDraftSession();
expect(useChatSessionStore.getState().sessions).toHaveLength(1);
useChatSessionStore.getState().removeDraft(session.id);
expect(useChatSessionStore.getState().sessions).toHaveLength(0);
});
it("clears activeSessionId if removing the active draft", () => {
const session = useChatSessionStore.getState().createDraftSession();
useChatSessionStore.getState().setActiveSession(session.id);
useChatSessionStore.getState().removeDraft(session.id);
expect(useChatSessionStore.getState().activeSessionId).toBeNull();
});
it("does not remove non-draft sessions", () => {
useChatSessionStore.setState({
sessions: [
{
id: "non-draft",
title: "Regular Session",
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messageCount: 5,
},
],
});
useChatSessionStore.getState().removeDraft("non-draft");
expect(useChatSessionStore.getState().sessions).toHaveLength(1);
expect(useChatSessionStore.getState().sessions).toContainEqual(session);
});
});
describe("loadSessions", () => {
it("loads sessions from ACP and maps them correctly", async () => {
mockedAcpListSessions.mockResolvedValue([
mockAcpListSessions.mockResolvedValue([
{
sessionId: "acp-1",
title: "ACP Session 1",
@ -161,64 +110,14 @@ describe("chatSessionStore", () => {
const sessions = useChatSessionStore.getState().sessions;
expect(sessions).toHaveLength(2);
expect(sessions[0].id).toBe("acp-2"); // Most recent first
expect(sessions[0].title).toBe("Untitled"); // null title becomes "Untitled"
expect(sessions[0].id).toBe("acp-2");
expect(sessions[0].title).toBe("Untitled");
expect(sessions[0].messageCount).toBe(7);
expect(sessions[1].id).toBe("acp-1");
expect(sessions[1].title).toBe("ACP Session 1");
expect(sessions[1].messageCount).toBe(4);
});
it("preserves local drafts alongside ACP sessions", async () => {
const draft = useChatSessionStore.getState().createDraftSession({
title: "My Draft",
});
mockedAcpListSessions.mockResolvedValue([
{
sessionId: "acp-1",
title: "ACP Session",
updatedAt: "2026-04-01",
messageCount: 3,
},
]);
await useChatSessionStore.getState().loadSessions();
const sessions = useChatSessionStore.getState().sessions;
expect(sessions).toHaveLength(2);
expect(sessions.find((s) => s.id === "acp-1")).toBeDefined();
expect(sessions.find((s) => s.id === draft.id)).toBeDefined();
});
it("migrates promoted draft metadata onto the resolved ACP session id", async () => {
const draft = useChatSessionStore.getState().createDraftSession({
title: "Project Draft",
projectId: "project-123",
providerId: "goose",
});
useChatSessionStore.getState().promoteDraft(draft.id);
useChatSessionStore.getState().setSessionAcpId(draft.id, "acp-1");
mockedAcpListSessions.mockResolvedValue([
{
sessionId: "acp-1",
title: "ACP Session",
updatedAt: "2026-04-02",
messageCount: 3,
},
]);
await useChatSessionStore.getState().loadSessions();
const session = useChatSessionStore.getState().sessions[0];
expect(session.id).toBe("acp-1");
expect(session.acpSessionId).toBe("acp-1");
expect(session.projectId).toBe("project-123");
expect(session.providerId).toBe("goose");
});
it("rehydrates cached project metadata for ACP sessions", async () => {
window.localStorage.setItem(
LEGACY_SESSION_CACHE_KEY,
@ -237,7 +136,7 @@ describe("chatSessionStore", () => {
]),
);
mockedAcpListSessions.mockResolvedValue([
mockAcpListSessions.mockResolvedValue([
{
sessionId: "acp-1",
title: null,
@ -259,21 +158,37 @@ describe("chatSessionStore", () => {
expect(session.userSetName).toBe(true);
});
it("drops stale non-draft sessions that are no longer in ACP", async () => {
useChatSessionStore.setState({
sessions: [
it("ignores legacy draft records while hydrating overlays", async () => {
window.localStorage.setItem(
LEGACY_SESSION_CACHE_KEY,
JSON.stringify([
{
id: "stale-session",
title: "Stale Session",
id: "cached-draft",
title: "Cached Draft",
draft: true,
createdAt: "2026-04-01",
updatedAt: "2026-04-01",
messageCount: 2,
messageCount: 0,
},
]),
);
mockAcpListSessions.mockResolvedValue([]);
await useChatSessionStore.getState().loadSessions();
expect(useChatSessionStore.getState().sessions).toEqual([]);
});
it("drops stale sessions that are no longer in ACP", async () => {
useChatSessionStore.setState({
sessions: [
makeSession({ id: "stale-session", title: "Stale Session" }),
],
activeSessionId: "stale-session",
});
mockedAcpListSessions.mockResolvedValue([
mockAcpListSessions.mockResolvedValue([
{
sessionId: "acp-1",
title: "ACP Session",
@ -292,7 +207,7 @@ describe("chatSessionStore", () => {
it("sets isLoading during fetch", async () => {
let resolvePromise: (value: AcpSessionInfo[]) => void = () => {};
mockedAcpListSessions.mockReturnValue(
mockAcpListSessions.mockReturnValue(
new Promise((resolve) => {
resolvePromise = resolve;
}),
@ -300,11 +215,13 @@ describe("chatSessionStore", () => {
const loadPromise = useChatSessionStore.getState().loadSessions();
expect(useChatSessionStore.getState().isLoading).toBe(true);
expect(useChatSessionStore.getState().hasHydratedSessions).toBe(false);
resolvePromise([]);
await loadPromise;
expect(useChatSessionStore.getState().isLoading).toBe(false);
expect(useChatSessionStore.getState().hasHydratedSessions).toBe(true);
});
it("falls back to cached sessions on error", async () => {
@ -315,56 +232,51 @@ describe("chatSessionStore", () => {
id: "cached-session",
title: "Cached Session",
projectId: "project-123",
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
createdAt: "2026-04-01T00:00:00.000Z",
updatedAt: "2026-04-01T00:00:00.000Z",
messageCount: 8,
},
{
id: "cached-draft",
title: "Cached Draft",
draft: true,
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
createdAt: "2026-04-01T00:00:00.000Z",
updatedAt: "2026-04-01T00:00:00.000Z",
messageCount: 0,
},
]),
);
mockedAcpListSessions.mockRejectedValue(new Error("Network error"));
mockAcpListSessions.mockRejectedValue(new Error("Network error"));
await useChatSessionStore.getState().loadSessions();
const sessions = useChatSessionStore.getState().sessions;
expect(sessions).toHaveLength(2);
expect(
sessions.find((session) => session.id === "cached-session"),
).toMatchObject({
expect(sessions).toHaveLength(1);
expect(sessions[0]).toMatchObject({
id: "cached-session",
projectId: "project-123",
});
expect(
sessions.find((session) => session.id === "cached-draft")?.draft,
).toBe(true);
expect(useChatSessionStore.getState().hasHydratedSessions).toBe(true);
});
});
describe("updateSession", () => {
it("updates session properties", () => {
const session = useChatSessionStore.getState().createDraftSession();
const session = seedSession();
useChatSessionStore.getState().updateSession(session.id, {
title: "Updated Title",
projectId: "new-project",
});
const updated = useChatSessionStore
.getState()
.sessions.find((s) => s.id === session.id);
const updated = useChatSessionStore.getState().getSession(session.id);
expect(updated?.title).toBe("Updated Title");
expect(updated?.projectId).toBe("new-project");
});
it("preserves updatedAt when not explicitly provided in patch", () => {
const session = useChatSessionStore.getState().createDraftSession();
const session = seedSession();
const originalUpdatedAt = session.updatedAt;
vi.useFakeTimers();
@ -376,14 +288,12 @@ describe("chatSessionStore", () => {
vi.useRealTimers();
const updated = useChatSessionStore
.getState()
.sessions.find((s) => s.id === session.id);
const updated = useChatSessionStore.getState().getSession(session.id);
expect(updated?.updatedAt).toBe(originalUpdatedAt);
});
it("updates updatedAt when explicitly provided in patch", () => {
const session = useChatSessionStore.getState().createDraftSession();
const session = seedSession();
const originalUpdatedAt = session.updatedAt;
vi.useFakeTimers();
@ -397,76 +307,43 @@ describe("chatSessionStore", () => {
vi.useRealTimers();
const updated = useChatSessionStore
.getState()
.sessions.find((s) => s.id === session.id);
const updated = useChatSessionStore.getState().getSession(session.id);
expect(updated?.updatedAt).not.toBe(originalUpdatedAt);
expect(updated?.updatedAt).toBe(newTimestamp);
});
});
describe("session models", () => {
it("stores models per session", () => {
const session = useChatSessionStore.getState().createDraftSession();
describe("provider switching", () => {
it("clears the selected model when switching providers", () => {
const session = seedSession({
providerId: "openai",
modelId: "gpt-4o",
modelName: "GPT-4o",
});
useChatSessionStore.getState().setSessionModels(session.id, [
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
{ id: "gpt-4o", name: "GPT-4o" },
]);
expect(
useChatSessionStore.getState().getSessionModels(session.id),
).toEqual([
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
{ id: "gpt-4o", name: "GPT-4o" },
]);
});
it("removes stored models when a draft session is removed", () => {
const session = useChatSessionStore.getState().createDraftSession();
useChatSessionStore
.getState()
.setSessionModels(session.id, [
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
]);
.switchSessionProvider(session.id, "anthropic");
useChatSessionStore.getState().removeDraft(session.id);
expect(
useChatSessionStore.getState().getSessionModels(session.id),
).toEqual([]);
});
it("removes stored models when a session is archived", async () => {
const session = useChatSessionStore.getState().createDraftSession();
useChatSessionStore
.getState()
.setSessionModels(session.id, [
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
]);
await useChatSessionStore.getState().archiveSession(session.id);
expect(
useChatSessionStore.getState().getSessionModels(session.id),
).toEqual([]);
const updated = useChatSessionStore.getState().getSession(session.id);
expect(updated?.providerId).toBe("anthropic");
expect(updated?.modelId).toBeUndefined();
expect(updated?.modelName).toBeUndefined();
});
});
describe("archiveSession", () => {
it("sets archivedAt on the session", async () => {
const session = useChatSessionStore.getState().createDraftSession();
const session = seedSession();
await useChatSessionStore.getState().archiveSession(session.id);
const archived = useChatSessionStore
.getState()
.sessions.find((s) => s.id === session.id);
const archived = useChatSessionStore.getState().getSession(session.id);
expect(archived?.archivedAt).toBeDefined();
});
it("clears activeSessionId if archiving the active session", async () => {
const session = useChatSessionStore.getState().createDraftSession();
const session = seedSession();
useChatSessionStore.getState().setActiveSession(session.id);
await useChatSessionStore.getState().archiveSession(session.id);
@ -478,13 +355,14 @@ describe("chatSessionStore", () => {
describe("addSession", () => {
it("prepends a new session to the list", () => {
const { addSession } = useChatSessionStore.getState();
addSession({
id: "imported-1",
title: "Imported Session",
createdAt: "2026-01-01T00:00:00Z",
updatedAt: "2026-01-01T00:00:00Z",
messageCount: 5,
});
addSession(
makeSession({
id: "imported-1",
title: "Imported Session",
messageCount: 5,
}),
);
const sessions = useChatSessionStore.getState().sessions;
expect(sessions[0].id).toBe("imported-1");
expect(sessions[0].title).toBe("Imported Session");
@ -493,22 +371,13 @@ describe("chatSessionStore", () => {
it("does not create a duplicate if session ID already exists", () => {
const { addSession } = useChatSessionStore.getState();
addSession({
id: "dup-1",
title: "First",
createdAt: "2026-01-01T00:00:00Z",
updatedAt: "2026-01-01T00:00:00Z",
messageCount: 1,
});
addSession({
id: "dup-1",
title: "Second",
createdAt: "2026-01-01T00:00:00Z",
updatedAt: "2026-01-01T00:00:00Z",
messageCount: 2,
});
addSession(makeSession({ id: "dup-1", title: "First", messageCount: 1 }));
addSession(
makeSession({ id: "dup-1", title: "Second", messageCount: 2 }),
);
const sessions = useChatSessionStore.getState().sessions;
const matches = sessions.filter((s) => s.id === "dup-1");
const matches = sessions.filter((session) => session.id === "dup-1");
expect(matches).toHaveLength(1);
expect(matches[0].title).toBe("Second");
});

View file

@ -1,23 +1,17 @@
import { create } from "zustand";
import { acpListSessions, type AcpSessionInfo } from "@/shared/api/acp";
import {
acpCreateSession,
acpListSessions,
type AcpSessionInfo,
} from "@/shared/api/acp";
import type { Session } from "@/shared/types/chat";
import { DEFAULT_CHAT_TITLE } from "@/features/chat/lib/sessionTitle";
import {
loadDraftSessionRecords,
loadSessionMetadataOverlay,
migrateSessionMetadataOverlayId,
modelIdsMatch,
persistSessionMetadataOverlay,
persistDraftSessionRecord,
persistDraftSessionRecords,
removeDraftSessionRecord,
upsertSessionMetadataOverlayRecord,
type DraftSessionRecord,
type SessionMetadataOverlayRecord,
} from "@/features/chat/lib/sessionMetadataOverlay";
import type { ModelOption } from "../types";
const EMPTY_MODELS: ModelOption[] = [];
export interface ChatSession {
id: string;
@ -33,7 +27,6 @@ export interface ChatSession {
updatedAt: string;
archivedAt?: string;
messageCount: number;
draft?: boolean;
userSetName?: boolean;
}
@ -42,14 +35,31 @@ export interface ActiveWorkspace {
branch: string | null;
}
export function hasSessionStarted(
session: Pick<ChatSession, "messageCount">,
localMessages?: ArrayLike<unknown>,
): boolean {
return session.messageCount > 0 || (localMessages?.length ?? 0) > 0;
}
export function getVisibleSessions<
T extends Pick<ChatSession, "id" | "messageCount">,
>(
sessions: T[],
messagesBySession: Record<string, ArrayLike<unknown> | undefined>,
): T[] {
return sessions.filter((session) =>
hasSessionStarted(session, messagesBySession[session.id]),
);
}
interface ChatSessionStoreState {
sessions: ChatSession[];
activeSessionId: string | null;
isLoading: boolean;
hasHydratedSessions: boolean;
contextPanelOpenBySession: Record<string, boolean>;
activeWorkspaceBySession: Record<string, ActiveWorkspace>;
modelsBySession: Record<string, ModelOption[]>;
modelCacheByProvider: Record<string, ModelOption[]>;
}
interface CreateSessionOpts {
@ -58,18 +68,17 @@ interface CreateSessionOpts {
agentId?: string;
providerId?: string;
personaId?: string;
workingDir?: string;
modelId?: string;
modelName?: string;
}
interface UpdateSessionOptions {
localOnly?: boolean;
persistOverlay?: boolean;
}
interface ChatSessionStoreActions {
createSession: (opts?: CreateSessionOpts) => Promise<ChatSession>;
createDraftSession: (opts?: CreateSessionOpts) => ChatSession;
promoteDraft: (id: string) => void;
removeDraft: (id: string) => void;
loadSessions: () => Promise<void>;
updateSession: (
id: string,
@ -79,74 +88,20 @@ interface ChatSessionStoreActions {
addSession: (session: ChatSession) => void;
archiveSession: (id: string) => Promise<void>;
unarchiveSession: (id: string) => Promise<void>;
setSessionAcpId: (id: string, acpSessionId: string) => void;
setActiveSession: (sessionId: string | null) => void;
setContextPanelOpen: (sessionId: string, open: boolean) => void;
setActiveWorkspace: (sessionId: string, context: ActiveWorkspace) => void;
clearActiveWorkspace: (sessionId: string) => void;
setSessionModels: (sessionId: string, models: ModelOption[]) => void;
switchSessionProvider: (
sessionId: string,
providerId: string,
models: ModelOption[],
) => void;
cacheModelsForProvider: (providerId: string, models: ModelOption[]) => void;
getCachedModels: (providerId: string) => ModelOption[];
switchSessionProvider: (sessionId: string, providerId: string) => void;
getSession: (id: string) => ChatSession | undefined;
getActiveSession: () => ChatSession | null;
getArchivedSessions: () => ChatSession[];
getSessionModels: (sessionId: string) => ModelOption[];
}
export type ChatSessionStore = ChatSessionStoreState & ChatSessionStoreActions;
const MODEL_CACHE_STORAGE_KEY = "goose:model-cache";
function loadModelCache(): Record<string, ModelOption[]> {
if (typeof window === "undefined") return {};
try {
const stored = window.localStorage.getItem(MODEL_CACHE_STORAGE_KEY);
if (!stored) return {};
const parsed = JSON.parse(stored);
return typeof parsed === "object" && parsed !== null
? (parsed as Record<string, ModelOption[]>)
: {};
} catch {
return {};
}
}
function persistModelCache(cache: Record<string, ModelOption[]>): void {
if (typeof window === "undefined") return;
try {
window.localStorage.setItem(MODEL_CACHE_STORAGE_KEY, JSON.stringify(cache));
} catch {
// localStorage may be unavailable
}
}
function draftSessionToRecord(session: ChatSession): DraftSessionRecord {
return {
id: session.id,
acpSessionId: session.acpSessionId,
title: session.title,
projectId: session.projectId,
agentId: session.agentId,
providerId: session.providerId,
personaId: session.personaId,
modelId: session.modelId,
modelName: session.modelName,
createdAt: session.createdAt,
updatedAt: session.updatedAt,
archivedAt: session.archivedAt,
messageCount: session.messageCount,
draft: true,
userSetName: session.userSetName,
};
}
function overlayKeyForSession(
session: Pick<ChatSession, "id" | "acpSessionId">,
) {
@ -225,40 +180,12 @@ function mergeAcpSessionWithOverlay(
};
}
function mergeDraftSessions(
currentDrafts: ChatSession[],
persistedDrafts: DraftSessionRecord[],
): ChatSession[] {
const draftsById = new Map<string, ChatSession>(
persistedDrafts.map((record) => [
record.id,
{
...record,
draft: true,
} satisfies ChatSession,
]),
);
for (const draft of currentDrafts) {
draftsById.set(draft.id, draft);
}
return [...draftsById.values()];
}
function persistDraftsFromSessions(sessions: ChatSession[]): void {
persistDraftSessionRecords(
sessions.filter((session) => session.draft).map(draftSessionToRecord),
);
}
function syncOverlaySnapshots(
sessions: ChatSession[],
existingOverlays = loadSessionMetadataOverlay(),
): void {
const overlays = new Map(existingOverlays);
for (const session of sessions) {
if (session.draft) continue;
overlays.set(
overlayKeyForSession(session),
buildOverlayRecord(
@ -299,81 +226,55 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
sessions: [],
activeSessionId: null,
isLoading: false,
hasHydratedSessions: false,
contextPanelOpenBySession: {},
activeWorkspaceBySession: {},
modelsBySession: {},
modelCacheByProvider: loadModelCache(),
createSession: async (_opts) => {
throw new Error(
"createSession not yet wired to ACP — use createDraftSession",
);
},
createDraftSession: (opts) => {
createSession: async (opts) => {
if (!opts?.workingDir) {
throw new Error("createSession requires a working directory");
}
const now = new Date().toISOString();
const providerId = opts.providerId ?? "goose";
const { sessionId } = await acpCreateSession(providerId, opts.workingDir, {
personaId: opts.personaId,
modelId: opts.modelId,
});
const chatSession: ChatSession = {
id: crypto.randomUUID(),
title: opts?.title ?? DEFAULT_CHAT_TITLE,
projectId: opts?.projectId,
agentId: opts?.agentId,
providerId: opts?.providerId,
personaId: opts?.personaId,
id: sessionId,
acpSessionId: sessionId,
title: opts.title ?? DEFAULT_CHAT_TITLE,
projectId: opts.projectId,
agentId: opts.agentId,
providerId,
personaId: opts.personaId,
modelId: opts.modelId,
modelName: opts.modelName,
createdAt: now,
updatedAt: now,
messageCount: 0,
draft: true,
};
set((state) => ({ sessions: [...state.sessions, chatSession] }));
persistDraftSessionRecord(draftSessionToRecord(chatSession));
set((state) => ({ sessions: [chatSession, ...state.sessions] }));
const existing = loadSessionMetadataOverlay().get(
overlayKeyForSession(chatSession),
);
upsertSessionMetadataOverlayRecord(
buildOverlayRecord(chatSession, existing),
);
return chatSession;
},
promoteDraft: (id) => {
const session = get().sessions.find((candidate) => candidate.id === id);
if (!session?.draft) return;
set((state) => ({
sessions: state.sessions.map((candidate) =>
candidate.id === id ? { ...candidate, draft: undefined } : candidate,
),
}));
persistDraftsFromSessions(get().sessions);
},
removeDraft: (id) => {
const session = get().sessions.find((candidate) => candidate.id === id);
if (!session?.draft) return;
const { [id]: _ignoredPanelState, ...remainingPanelState } =
get().contextPanelOpenBySession;
const { [id]: _ignoredContext, ...remainingContextState } =
get().activeWorkspaceBySession;
const remainingModels = { ...get().modelsBySession };
delete remainingModels[id];
set((state) => ({
sessions: state.sessions.filter((candidate) => candidate.id !== id),
activeSessionId:
state.activeSessionId === id ? null : state.activeSessionId,
contextPanelOpenBySession: remainingPanelState,
activeWorkspaceBySession: remainingContextState,
modelsBySession: remainingModels,
}));
removeDraftSessionRecord(id);
},
loadSessions: async () => {
set({ isLoading: true });
try {
const overlays = loadSessionMetadataOverlay();
const acpSessions = await acpListSessions();
const persistedDrafts = loadDraftSessionRecords();
const currentDrafts = get().sessions.filter((session) => session.draft);
const drafts = mergeDraftSessions(currentDrafts, persistedDrafts);
const mergedAcpSessions = sortByUpdatedAtDesc(
acpSessions.map((session) =>
mergeAcpSessionWithOverlay(session, overlays.get(session.sessionId)),
),
);
const merged = [...mergedAcpSessions, ...drafts];
const merged = mergedAcpSessions;
const activeSessionId = get().activeSessionId;
const activeSessionStillExists =
activeSessionId == null ||
@ -382,22 +283,16 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
sessions: merged,
activeSessionId: activeSessionStillExists ? activeSessionId : null,
});
persistDraftsFromSessions(merged);
syncOverlaySnapshots(mergedAcpSessions, overlays);
} catch (error) {
console.error("Failed to load sessions from ACP:", error);
const overlays = loadSessionMetadataOverlay();
const persistedDrafts = loadDraftSessionRecords();
const currentDrafts = get().sessions.filter((session) => session.draft);
const drafts = mergeDraftSessions(currentDrafts, persistedDrafts);
const fallbackSessions = sortByUpdatedAtDesc([
...[...overlays.values()].map(overlayToFallbackSession),
...drafts,
]);
const fallbackSessions = sortByUpdatedAtDesc(
[...overlays.values()].map(overlayToFallbackSession),
);
set({ sessions: fallbackSessions });
persistDraftsFromSessions(fallbackSessions);
} finally {
set({ isLoading: false });
set({ isLoading: false, hasHydratedSessions: true });
}
},
@ -415,23 +310,15 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
}));
const updatedSession = get().sessions.find((session) => session.id === id);
persistDraftsFromSessions(get().sessions);
if (
updatedSession &&
!updatedSession.draft &&
opts?.persistOverlay !== false
) {
if (updatedSession && opts?.persistOverlay !== false) {
const key = overlayKeyForSession(updatedSession);
const existing = loadSessionMetadataOverlay().get(key);
upsertSessionMetadataOverlayRecord(
buildOverlayRecord(updatedSession, existing),
);
}
if (opts?.localOnly) return;
if (updatedSession?.draft) return;
// TODO: wire non-draft updates to ACP when supported
// TODO: wire session updates to ACP when supported
},
addSession: (session) => {
@ -450,20 +337,15 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
}
return { sessions: [normalizedSession, ...state.sessions] };
});
persistDraftsFromSessions(get().sessions);
if (!normalizedSession.draft) {
const existing = loadSessionMetadataOverlay().get(
overlayKeyForSession(normalizedSession),
);
upsertSessionMetadataOverlayRecord(
buildOverlayRecord(normalizedSession, existing),
);
}
const existing = loadSessionMetadataOverlay().get(
overlayKeyForSession(normalizedSession),
);
upsertSessionMetadataOverlayRecord(
buildOverlayRecord(normalizedSession, existing),
);
},
archiveSession: async (id) => {
const remainingModels = { ...get().modelsBySession };
delete remainingModels[id];
set((state) => ({
sessions: state.sessions.map((session) =>
session.id === id
@ -472,11 +354,9 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
),
activeSessionId:
state.activeSessionId === id ? null : state.activeSessionId,
modelsBySession: remainingModels,
}));
persistDraftsFromSessions(get().sessions);
const session = get().sessions.find((candidate) => candidate.id === id);
if (session && !session.draft) {
if (session) {
const existing = loadSessionMetadataOverlay().get(
overlayKeyForSession(session),
);
@ -490,9 +370,8 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
session.id === id ? { ...session, archivedAt: undefined } : session,
),
}));
persistDraftsFromSessions(get().sessions);
const session = get().sessions.find((candidate) => candidate.id === id);
if (session && !session.draft) {
if (session) {
const existing = loadSessionMetadataOverlay().get(
overlayKeyForSession(session),
);
@ -500,36 +379,6 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
}
},
setSessionAcpId: (id, acpSessionId) => {
if (!acpSessionId) return;
const session = get().sessions.find((candidate) => candidate.id === id);
if (!session || session.acpSessionId === acpSessionId) return;
set((state) => ({
sessions: state.sessions.map((candidate) =>
candidate.id === id ? { ...candidate, acpSessionId } : candidate,
),
}));
if (!session.draft) {
migrateSessionMetadataOverlayId(
overlayKeyForSession(session),
acpSessionId,
);
const updatedSession = get().sessions.find(
(candidate) => candidate.id === id,
);
if (updatedSession) {
const existing = loadSessionMetadataOverlay().get(acpSessionId);
upsertSessionMetadataOverlayRecord(
buildOverlayRecord(updatedSession, existing),
);
}
} else {
persistDraftsFromSessions(get().sessions);
}
},
setActiveSession: (sessionId) => {
if (get().activeSessionId === sessionId) return;
set({ activeSessionId: sessionId });
@ -560,41 +409,24 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
});
},
setSessionModels: (sessionId, models) => {
switchSessionProvider: (sessionId, providerId) => {
set((state) => ({
modelsBySession: {
...state.modelsBySession,
[sessionId]: models,
},
}));
},
switchSessionProvider: (sessionId, providerId, models) => {
set((state) => ({
modelsBySession: {
...state.modelsBySession,
[sessionId]: models,
},
sessions: state.sessions.map((session) =>
session.id === sessionId
? {
...session,
providerId,
modelId: models.length > 0 ? models[0].id : undefined,
modelName:
models.length > 0
? (models[0].displayName ?? models[0].name)
: undefined,
modelId: undefined,
modelName: undefined,
updatedAt: session.updatedAt,
}
: session,
),
}));
persistDraftsFromSessions(get().sessions);
const session = get().sessions.find(
(candidate) => candidate.id === sessionId,
);
if (session && !session.draft) {
if (session) {
const existing = loadSessionMetadataOverlay().get(
overlayKeyForSession(session),
);
@ -602,25 +434,6 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
}
},
cacheModelsForProvider: (providerId, models) => {
if (models.length === 0) return;
const existing = get().modelCacheByProvider[providerId];
if (modelIdsMatch(existing, models)) {
return;
}
set((state) => {
const updated = {
...state.modelCacheByProvider,
[providerId]: models,
};
persistModelCache(updated);
return { modelCacheByProvider: updated };
});
},
getCachedModels: (providerId) =>
get().modelCacheByProvider[providerId] ?? EMPTY_MODELS,
getSession: (id) => get().sessions.find((session) => session.id === id),
getActiveSession: () => {
@ -631,7 +444,4 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
getArchivedSessions: () =>
get().sessions.filter((session) => !!session.archivedAt),
getSessionModels: (sessionId) =>
get().modelsBySession[sessionId] ?? EMPTY_MODELS,
}));

View file

@ -1,6 +1,62 @@
import type { AcpProvider } from "@/shared/api/acp";
import type { Persona } from "@/shared/types/agents";
import type { ChatAttachmentDraft } from "@/shared/types/messages";
export interface ModelOption {
id: string;
name: string;
displayName?: string;
provider?: string;
providerId?: string;
providerName?: string;
/** Whether this model should appear in the compact recommended picker. */
recommended?: boolean;
}
export interface ProjectOption {
id: string;
name: string;
workingDirs: string[];
color?: string | null;
}
export interface ChatInputProps {
onSend: (
text: string,
personaId?: string,
attachments?: ChatAttachmentDraft[],
) => void;
onStop?: () => void;
isStreaming?: boolean;
disabled?: boolean;
queuedMessage?: { text: string } | null;
onDismissQueue?: () => void;
initialValue?: string;
onDraftChange?: (text: string) => void;
className?: string;
personas?: Persona[];
selectedPersonaId?: string | null;
onPersonaChange?: (personaId: string | null) => void;
onCreatePersona?: () => void;
providers?: AcpProvider[];
providersLoading?: boolean;
selectedProvider?: string;
onProviderChange?: (providerId: string) => void;
currentModelId?: string | null;
currentModel?: string;
availableModels?: ModelOption[];
modelsLoading?: boolean;
modelStatusMessage?: string | null;
onModelChange?: (modelId: string) => void;
selectedProjectId?: string | null;
availableProjects?: ProjectOption[];
onProjectChange?: (projectId: string | null) => void;
onCreateProject?: (options?: {
onCreated?: (projectId: string) => void;
}) => void;
contextTokens?: number;
contextLimit?: number;
onCompactContext?: () => void | Promise<void>;
canCompactContext?: boolean;
isCompactingContext?: boolean;
}

View file

@ -1,15 +1,19 @@
import { useEffect, useMemo, useState, type ReactNode } from "react";
import { useEffect, useMemo, useRef, useState, type ReactNode } from "react";
import {
IconCheck,
IconChevronDown,
IconChevronRight,
IconChevronLeft,
IconSearch,
} from "@tabler/icons-react";
import { useTranslation } from "react-i18next";
import type { AcpProvider } from "@/shared/api/acp";
import { getProviderInventory } from "@/features/providers/api/inventory";
import { useProviderInventoryStore } from "@/features/providers/stores/providerInventoryStore";
import { cn } from "@/shared/lib/cn";
import { Button } from "@/shared/ui/button";
import { Popover, PopoverContent, PopoverTrigger } from "@/shared/ui/popover";
import { ScrollArea } from "@/shared/ui/scroll-area";
import { Spinner } from "@/shared/ui/spinner";
import {
formatProviderLabel,
getProviderIcon,
@ -23,66 +27,42 @@ interface AgentModelPickerProps {
currentModelId?: string | null;
currentModelName?: string | null;
availableModels: ModelOption[];
modelsLoading?: boolean;
modelStatusMessage?: string | null;
onModelChange?: (modelId: string) => void;
loading?: boolean;
isCompact?: boolean;
}
interface ModelGroup {
provider: string;
models: ModelOption[];
hasSelectedModel: boolean;
}
const MODEL_PROVIDER_MATCHERS: Array<[string, RegExp]> = [
["Anthropic", /claude|anthropic/i],
["OpenAI", /(^|[\s-])(gpt|o1|o3|o4)([\s-]|$)|openai/i],
["Google", /gemini|google/i],
["Mistral", /mistral/i],
["Meta", /llama|meta/i],
["DeepSeek", /deepseek/i],
["Qwen", /qwen/i],
["Cohere", /cohere|command/i],
];
function getModelDisplayName(model: ModelOption) {
return model.displayName ?? model.name;
}
function inferModelProvider(model: ModelOption) {
if (model.provider) {
return model.provider;
function getGooseModelProviderLabel(model: ModelOption) {
if (model.providerName) {
return model.providerName;
}
const candidate = `${model.id} ${model.name}`;
for (const [provider, pattern] of MODEL_PROVIDER_MATCHERS) {
if (pattern.test(candidate)) {
return provider;
}
if (model.providerId) {
return formatProviderLabel(model.providerId);
}
return "Other";
return null;
}
function groupModels(models: ModelOption[], currentModelId: string | null) {
const grouped = new Map<string, ModelOption[]>();
function sortModels(models: ModelOption[], currentModelId: string | null) {
return [...models].sort((left, right) => {
if (left.id === currentModelId) return -1;
if (right.id === currentModelId) return 1;
for (const model of models) {
const provider = inferModelProvider(model);
const existing = grouped.get(provider) ?? [];
existing.push(model);
grouped.set(provider, existing);
}
const leftProvider = getGooseModelProviderLabel(left) ?? "";
const rightProvider = getGooseModelProviderLabel(right) ?? "";
if (leftProvider !== rightProvider) {
return leftProvider.localeCompare(rightProvider);
}
const groups = Array.from(grouped.entries())
.map(([provider, groupedModels]) => ({
provider,
models: groupedModels,
hasSelectedModel: groupedModels.some((m) => m.id === currentModelId),
}))
.sort((left, right) => left.provider.localeCompare(right.provider));
return groups;
return getModelDisplayName(left).localeCompare(getModelDisplayName(right));
});
}
function PickerItem({
@ -116,6 +96,219 @@ function PickerItem({
);
}
// ── Model list views ────────────────────────────────────────────────
type ModelView = "recommended" | "all";
function RecommendedModelList({
models,
currentModelId,
selectedAgentId,
onModelSelect,
onShowAll,
t,
}: {
models: ModelOption[];
currentModelId: string | null;
selectedAgentId: string;
onModelSelect: (id: string) => void;
onShowAll: () => void;
t: (key: string) => string;
}) {
const recommended = useMemo(() => {
const rec = models.filter((m) => m.recommended);
// If the current model isn't in the recommended list, prepend it
// so the user can always see what's selected.
if (
currentModelId &&
rec.length > 0 &&
!rec.some((m) => m.id === currentModelId)
) {
const current = models.find((m) => m.id === currentModelId);
if (current) {
return [current, ...rec];
}
}
// Fall back to full list if no recommendations exist (e.g. ACP agents).
return rec.length > 0 ? rec : models;
}, [models, currentModelId]);
const sorted = useMemo(
() => sortModels(recommended, currentModelId),
[recommended, currentModelId],
);
const hasMore = models.length > recommended.length;
return (
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
{t("toolbar.model")}
</div>
<ScrollArea className="min-h-0 min-w-0 flex-1">
<div className="space-y-0.5 p-1">
{sorted.map((model) => {
const providerLabel = getGooseModelProviderLabel(model);
return (
<PickerItem
key={`${model.providerId ?? "model"}:${model.id}`}
onClick={() => onModelSelect(model.id)}
selected={model.id === currentModelId}
className="justify-between"
>
<div className="flex min-w-0 flex-1 items-center gap-2 overflow-hidden">
{selectedAgentId === "goose" && model.providerId ? (
<span
className="shrink-0 text-muted-foreground"
title={providerLabel ?? undefined}
>
{getProviderIcon(model.providerId, "size-3.5")}
</span>
) : null}
<div className="min-w-0 flex-1 truncate">
{getModelDisplayName(model)}
</div>
</div>
{model.id === currentModelId ? (
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
) : null}
</PickerItem>
);
})}
</div>
</ScrollArea>
{hasMore ? (
<div className="shrink-0 border-t px-1 py-1">
<button
type="button"
onClick={onShowAll}
className="flex w-full items-center gap-1.5 rounded-sm px-2 py-1.5 text-xs text-muted-foreground transition-colors hover:bg-muted hover:text-foreground"
>
<IconSearch className="size-3.5" />
<span>{t("toolbar.showAllModels")}</span>
</button>
</div>
) : null}
</div>
);
}
function AllModelsList({
models,
currentModelId,
selectedAgentId,
onModelSelect,
onBack,
t,
}: {
models: ModelOption[];
currentModelId: string | null;
selectedAgentId: string;
onModelSelect: (id: string) => void;
onBack: () => void;
t: (key: string) => string;
}) {
const [query, setQuery] = useState("");
const inputRef = useRef<HTMLInputElement>(null);
useEffect(() => {
// Auto-focus search on mount.
inputRef.current?.focus();
}, []);
const filtered = useMemo(() => {
if (!query.trim()) {
return sortModels(models, currentModelId);
}
const q = query.toLowerCase();
const matches = models.filter(
(m) =>
m.name.toLowerCase().includes(q) ||
m.id.toLowerCase().includes(q) ||
m.displayName?.toLowerCase().includes(q) ||
m.providerName?.toLowerCase().includes(q) ||
m.providerId?.toLowerCase().includes(q),
);
return sortModels(matches, currentModelId);
}, [models, query, currentModelId]);
return (
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
<div className="flex shrink-0 items-center gap-1 px-1 py-1">
<button
type="button"
onClick={onBack}
className="flex shrink-0 items-center rounded-sm p-1 text-muted-foreground transition-colors hover:bg-muted hover:text-foreground"
aria-label={t("toolbar.model")}
>
<IconChevronLeft className="size-4" />
</button>
<div className="relative min-w-0 flex-1">
<IconSearch className="pointer-events-none absolute left-2 top-1/2 size-3.5 -translate-y-1/2 text-muted-foreground" />
<input
ref={inputRef}
type="text"
value={query}
onChange={(e) => setQuery(e.target.value)}
placeholder={t("toolbar.searchModels")}
className="h-7 w-full rounded-sm border bg-transparent pl-7 pr-2 text-sm outline-none placeholder:text-muted-foreground focus:ring-1 focus:ring-ring"
/>
</div>
</div>
{filtered.length > 0 ? (
<ScrollArea className="min-h-0 min-w-0 flex-1">
<div className="space-y-0.5 p-1">
{filtered.map((model) => {
const providerLabel = getGooseModelProviderLabel(model);
const displayName = getModelDisplayName(model);
// Show the raw model_id as secondary text when it differs from name
const showModelId =
model.id !== model.name && model.id !== displayName;
return (
<PickerItem
key={`${model.providerId ?? "model"}:${model.id}`}
onClick={() => onModelSelect(model.id)}
selected={model.id === currentModelId}
className="justify-between"
>
<div className="flex min-w-0 flex-1 items-center gap-2 overflow-hidden">
{selectedAgentId === "goose" && model.providerId ? (
<span
className="shrink-0 text-muted-foreground"
title={providerLabel ?? undefined}
>
{getProviderIcon(model.providerId, "size-3.5")}
</span>
) : null}
<div className="min-w-0 flex-1 overflow-hidden">
<div className="truncate">{displayName}</div>
{showModelId ? (
<div className="truncate text-xs text-muted-foreground">
{model.id}
</div>
) : null}
</div>
</div>
{model.id === currentModelId ? (
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
) : null}
</PickerItem>
);
})}
</div>
</ScrollArea>
) : (
<div className="px-3 py-4 text-center text-sm text-muted-foreground">
{t("toolbar.noSearchResults")}
</div>
)}
</div>
);
}
// ── Main component ──────────────────────────────────────────────────
export function AgentModelPicker({
agents,
selectedAgentId,
@ -123,54 +316,31 @@ export function AgentModelPicker({
currentModelId = null,
currentModelName = null,
availableModels,
modelsLoading = false,
modelStatusMessage = null,
onModelChange,
loading = false,
isCompact = false,
}: AgentModelPickerProps) {
const { t } = useTranslation("chat");
const [open, setOpen] = useState(false);
const [expandedGroups, setExpandedGroups] = useState<Set<string>>(new Set());
const [modelView, setModelView] = useState<ModelView>("recommended");
const mergeInventoryEntries = useProviderInventoryStore(
(s) => s.mergeEntries,
);
const selectedAgentLabel =
agents.find((agent) => agent.id === selectedAgentId)?.label ??
formatProviderLabel(selectedAgentId);
const groupedModels = useMemo<ModelGroup[]>(
() => groupModels(availableModels, currentModelId),
[availableModels, currentModelId],
);
const hasModelInfo = currentModelName !== null || availableModels.length > 0;
const triggerModelLabel = hasModelInfo
? (currentModelName ?? t("toolbar.loading"))
const hasSelectedModel = currentModelName !== null || currentModelId !== null;
const triggerModelLabel = hasSelectedModel
? (currentModelName ?? currentModelId)
: null;
useEffect(() => {
if (open) {
const selected = groupedModels.find((g) => g.hasSelectedModel);
if (selected) {
setExpandedGroups(new Set([selected.provider]));
}
}
}, [open, groupedModels]);
const toggleGroup = (provider: string) => {
setExpandedGroups((prev) => {
const next = new Set(prev);
if (next.has(provider)) {
next.delete(provider);
} else {
next.add(provider);
}
return next;
});
};
const isGroupExpanded = (group: ModelGroup) => {
return expandedGroups.has(group.provider);
};
const handleAgentSelect = (agentId: string) => {
if (agentId !== selectedAgentId) {
onAgentChange(agentId);
setModelView("recommended");
}
};
@ -179,6 +349,42 @@ export function AgentModelPicker({
setOpen(false);
};
// Reset to recommended view when popover closes.
useEffect(() => {
if (!open) {
setModelView("recommended");
}
}, [open]);
useEffect(() => {
if (!open) {
return;
}
let cancelled = false;
const syncInventory = async () => {
try {
const entries = await getProviderInventory();
if (cancelled) {
return;
}
mergeInventoryEntries(entries);
} catch (error) {
console.error("Failed to sync provider inventory from picker:", error);
}
};
void syncInventory();
return () => {
cancelled = true;
};
}, [open, mergeInventoryEntries]);
// When in "all" view, expand the popover to full width for the search experience.
const isAllView = modelView === "all";
return (
<Popover open={open} onOpenChange={setOpen}>
<PopoverTrigger asChild>
@ -251,133 +457,97 @@ export function AgentModelPicker({
}
}}
>
<div className="grid h-full grid-cols-[minmax(0,1fr)_minmax(0,1fr)] gap-1 overflow-hidden">
<div
data-col="agent"
className="flex min-h-0 min-w-0 overflow-hidden p-1"
>
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
{t("toolbar.agent")}
</div>
<ScrollArea className="min-h-0 min-w-0 flex-1">
<div className="p-1 space-y-0.5">
{agents.map((agent) => {
const isSelected = agent.id === selectedAgentId;
return (
<PickerItem
key={agent.id}
onClick={() => handleAgentSelect(agent.id)}
selected={isSelected}
>
<span className="shrink-0">
{getProviderIcon(agent.id, "size-4")}
</span>
<span className="min-w-0 flex-1 truncate">
{agent.label}
</span>
{isSelected ? (
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
) : null}
</PickerItem>
);
})}
<div
className={cn(
"grid h-full gap-1 overflow-hidden",
isAllView
? "grid-cols-1"
: "grid-cols-[minmax(0,1fr)_minmax(0,1fr)]",
)}
>
{/* Agent column — hidden in "all models" search view */}
{!isAllView ? (
<div
data-col="agent"
className="flex min-h-0 min-w-0 overflow-hidden p-1"
>
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
{t("toolbar.agent")}
</div>
</ScrollArea>
</div>
</div>
<div
data-col="model"
className="flex min-h-0 min-w-0 overflow-hidden p-1"
>
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
{t("toolbar.model")}
</div>
{groupedModels.length > 0 ? (
<ScrollArea className="min-h-0 min-w-0 flex-1">
<div className="p-1 space-y-0.5">
{groupedModels.map((group) => {
const expanded = isGroupExpanded(group);
<div className="space-y-0.5 p-1">
{agents.map((agent) => {
const isSelected = agent.id === selectedAgentId;
return (
<div key={group.provider}>
<button
type="button"
onClick={() => toggleGroup(group.provider)}
className={cn(
"flex min-w-0 w-full items-center gap-1.5 rounded-sm px-2 py-1.5 text-left text-sm font-medium transition-colors",
"hover:bg-muted focus:bg-muted focus:outline-none",
)}
>
<IconChevronRight
className={cn(
"size-3.5 shrink-0 text-muted-foreground transition-transform",
expanded && "rotate-90",
)}
/>
<span className="min-w-0 flex-1 truncate">
{group.provider}
</span>
<span className="text-xs text-muted-foreground">
{group.models.length}
</span>
</button>
{expanded ? (
<div className="overflow-hidden pb-1">
{group.models.map((model) => {
const modelName = getModelDisplayName(model);
return (
<PickerItem
key={model.id}
onClick={() => handleModelSelect(model.id)}
selected={model.id === currentModelId}
className="justify-between pl-6"
>
<div className="min-w-0 flex-1 truncate">
{modelName}
</div>
{model.id === currentModelId ? (
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
) : null}
</PickerItem>
);
})}
</div>
) : group.hasSelectedModel ? (
<div className="overflow-hidden pb-1">
{group.models
.filter((m) => m.id === currentModelId)
.map((model) => (
<PickerItem
key={model.id}
onClick={() => handleModelSelect(model.id)}
selected
className="justify-between pl-6"
>
<div className="min-w-0 flex-1 truncate">
{getModelDisplayName(model)}
</div>
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
</PickerItem>
))}
</div>
<PickerItem
key={agent.id}
onClick={() => handleAgentSelect(agent.id)}
selected={isSelected}
>
<span className="shrink-0">
{getProviderIcon(agent.id, "size-4")}
</span>
<span className="min-w-0 flex-1 truncate">
{agent.label}
</span>
{isSelected ? (
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
) : null}
</div>
</PickerItem>
);
})}
</div>
</ScrollArea>
</div>
</div>
) : null}
{/* Model column */}
<div
data-col="model"
className="flex min-h-0 min-w-0 overflow-hidden p-1"
>
{modelsLoading ? (
<div className="flex min-h-0 flex-1 items-center gap-2 px-2 py-2 text-sm text-muted-foreground">
<Spinner className="size-4" />
<span>{t("toolbar.loadingModels")}</span>
</div>
) : availableModels.length > 0 ? (
modelView === "recommended" ? (
<RecommendedModelList
models={availableModels}
currentModelId={currentModelId}
selectedAgentId={selectedAgentId}
onModelSelect={handleModelSelect}
onShowAll={() => setModelView("all")}
t={t}
/>
) : (
<AllModelsList
models={availableModels}
currentModelId={currentModelId}
selectedAgentId={selectedAgentId}
onModelSelect={handleModelSelect}
onBack={() => setModelView("recommended")}
t={t}
/>
)
) : (
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
{t("toolbar.model")}
</div>
<div className="px-2 py-2">
<div className="text-sm text-muted-foreground">
{currentModelName ? currentModelName : t("toolbar.loading")}
{modelStatusMessage ??
currentModelName ??
t("toolbar.noModelsAvailable")}
</div>
</div>
)}
</div>
</div>
)}
</div>
</div>
</PopoverContent>

View file

@ -2,8 +2,6 @@ import { useState, useRef, useCallback, useEffect, useMemo } from "react";
import { open } from "@tauri-apps/plugin-dialog";
import { X } from "lucide-react";
import { useTranslation } from "react-i18next";
import type { AcpProvider } from "@/shared/api/acp";
import type { Persona } from "@/shared/types/agents";
import { cn } from "@/shared/lib/cn";
import { Badge } from "@/shared/ui/badge";
import { Button } from "@/shared/ui/button";
@ -14,61 +12,14 @@ import { ChatInputToolbar } from "./ChatInputToolbar";
import { formatProviderLabel } from "@/shared/ui/icons/ProviderIcons";
import { TooltipProvider } from "@/shared/ui/tooltip";
import { PersonaAvatar } from "./PersonaPicker";
import type { ChatAttachmentDraft } from "@/shared/types/messages";
import { useAttachmentDropTarget } from "../hooks/useAttachmentDropTarget";
import {
normalizeDialogSelection,
useChatInputAttachments,
} from "../hooks/useChatInputAttachments";
import type { ModelOption } from "../types";
import { ChatInputAttachments } from "./ChatInputAttachments";
import { useVoiceDictation } from "../hooks/useVoiceDictation";
export interface ProjectOption {
id: string;
name: string;
workingDirs: string[];
color?: string | null;
}
interface ChatInputProps {
onSend: (
text: string,
personaId?: string,
attachments?: ChatAttachmentDraft[],
) => void;
onStop?: () => void;
isStreaming?: boolean;
disabled?: boolean;
queuedMessage?: { text: string } | null;
onDismissQueue?: () => void;
initialValue?: string;
onDraftChange?: (text: string) => void;
className?: string;
personas?: Persona[];
selectedPersonaId?: string | null;
onPersonaChange?: (personaId: string | null) => void;
onCreatePersona?: () => void;
providers?: AcpProvider[];
providersLoading?: boolean;
selectedProvider?: string;
onProviderChange?: (providerId: string) => void;
currentModelId?: string | null;
currentModel?: string;
availableModels?: ModelOption[];
onModelChange?: (modelId: string) => void;
selectedProjectId?: string | null;
availableProjects?: ProjectOption[];
onProjectChange?: (projectId: string | null) => void;
onCreateProject?: (options?: {
onCreated?: (projectId: string) => void;
}) => void;
contextTokens?: number;
contextLimit?: number;
onCompactContext?: () => void | Promise<void>;
canCompactContext?: boolean;
isCompactingContext?: boolean;
}
import type { ChatInputProps } from "../types";
export function ChatInput({
onSend,
@ -91,6 +42,8 @@ export function ChatInput({
currentModelId = null,
currentModel,
availableModels = [],
modelsLoading = false,
modelStatusMessage = null,
onModelChange,
selectedProjectId = null,
availableProjects = [],
@ -104,6 +57,9 @@ export function ChatInput({
}: ChatInputProps) {
const { t } = useTranslation("chat");
const [text, setTextRaw] = useState(initialValue);
useEffect(() => {
setTextRaw(initialValue);
}, [initialValue]);
const setText = useCallback(
(value: string) => {
setTextRaw(value);
@ -351,8 +307,18 @@ export function ChatInput({
providers.find((provider) => provider.id === selectedProvider)?.label ??
formatProviderLabel(selectedProvider);
const agentDisplayName = activePersona?.displayName ?? providerDisplayName;
const resolvedCurrentModel =
currentModel ?? availableModels[0]?.displayName ?? availableModels[0]?.name;
const resolvedCurrentModel = useMemo(() => {
if (currentModel) {
return currentModel;
}
if (!currentModelId) {
return undefined;
}
const selectedModel = availableModels.find(
(model) => model.id === currentModelId,
);
return selectedModel?.displayName ?? selectedModel?.name ?? currentModelId;
}, [availableModels, currentModel, currentModelId]);
const effectivePlaceholder = t("input.placeholder", {
agent: agentDisplayName,
});
@ -472,6 +438,8 @@ export function ChatInput({
currentModelId={currentModelId}
currentModel={resolvedCurrentModel}
availableModels={availableModels}
modelsLoading={modelsLoading}
modelStatusMessage={modelStatusMessage}
onModelChange={onModelChange}
selectedProjectId={selectedProjectId}
availableProjects={availableProjects}

View file

@ -16,7 +16,7 @@ import { cn } from "@/shared/lib/cn";
import { ChatInputSelector } from "./ChatInputSelector";
import { ContextRing } from "./ContextRing";
import { PersonaPicker } from "./PersonaPicker";
import type { ProjectOption } from "./ChatInput";
import type { ProjectOption } from "../types";
import { Button } from "@/shared/ui/button";
import {
DropdownMenu,
@ -30,11 +30,7 @@ import { Tooltip, TooltipTrigger, TooltipContent } from "@/shared/ui/tooltip";
import { AgentModelPicker } from "./AgentModelPicker";
import type { ModelOption } from "../types";
import { formatProviderLabel } from "@/shared/ui/icons/ProviderIcons";
import { useAgentProviderStatus } from "@/features/providers/hooks/useAgentProviderStatus";
import {
getCatalogEntry,
resolveAgentProviderCatalogIdStrict,
} from "@/features/providers/providerCatalog";
import { getCatalogEntry } from "@/features/providers/providerCatalog";
const NO_PROJECT_VALUE = "__no_project__";
const CREATE_PROJECT_VALUE = "__create_project__";
@ -67,6 +63,8 @@ interface ChatInputToolbarProps {
currentModelId?: string | null;
currentModel?: string;
availableModels: ModelOption[];
modelsLoading?: boolean;
modelStatusMessage?: string | null;
onModelChange?: (modelId: string) => void;
// Project
selectedProjectId: string | null;
@ -111,6 +109,8 @@ export function ChatInputToolbar({
currentModelId,
currentModel,
availableModels,
modelsLoading = false,
modelStatusMessage = null,
onModelChange,
selectedProjectId,
availableProjects,
@ -137,27 +137,22 @@ export function ChatInputToolbar({
}: ChatInputToolbarProps) {
const { t } = useTranslation("chat");
const { formatNumber } = useLocaleFormatting();
const { readyAgentIds } = useAgentProviderStatus();
const [isContextPopoverOpen, setIsContextPopoverOpen] = useState(false);
const agentProviders = useMemo(() => {
const seen = new Set<string>();
const connected: AcpProvider[] = [];
for (const p of providers) {
const catalogId = resolveAgentProviderCatalogIdStrict(p.id);
if (
catalogId === null ||
!readyAgentIds.has(catalogId) ||
seen.has(catalogId)
)
const available: AcpProvider[] = [];
for (const provider of providers) {
if (seen.has(provider.id)) {
continue;
seen.add(catalogId);
connected.push({
id: p.id,
label: getCatalogEntry(catalogId)?.displayName ?? p.label,
}
seen.add(provider.id);
available.push({
id: provider.id,
label: getCatalogEntry(provider.id)?.displayName ?? provider.label,
});
}
if (connected.length > 0) return connected;
if (available.length > 0) return available;
return [
{
id: selectedProvider,
@ -166,7 +161,7 @@ export function ChatInputToolbar({
formatProviderLabel(selectedProvider),
},
];
}, [providers, readyAgentIds, selectedProvider]);
}, [providers, selectedProvider]);
const selectedProject = availableProjects.find(
(project) => project.id === selectedProjectId,
);
@ -220,6 +215,8 @@ export function ChatInputToolbar({
currentModelId={currentModelId}
currentModelName={currentModel ?? null}
availableModels={availableModels}
modelsLoading={modelsLoading}
modelStatusMessage={modelStatusMessage}
onModelChange={onModelChange}
loading={providersLoading}
isCompact={isCompact}

View file

@ -1,507 +1,97 @@
import { useState, useEffect, useRef, useCallback, useMemo } from "react";
import { useState, useEffect, useRef } from "react";
import { useTranslation } from "react-i18next";
import { AnimatePresence } from "motion/react";
import { MessageTimeline } from "./MessageTimeline";
import { ChatInput } from "./ChatInput";
import type { ChatAttachmentDraft } from "@/shared/types/messages";
import { LoadingGoose } from "./LoadingGoose";
import { ChatLoadingSkeleton } from "./ChatLoadingSkeleton";
import { useChat } from "../hooks/useChat";
import { useMessageQueue } from "../hooks/useMessageQueue";
import { useChatStore } from "../stores/chatStore";
import { useAgentStore } from "@/features/agents/stores/agentStore";
import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection";
import { useChatSessionStore } from "../stores/chatSessionStore";
import { useProjectStore } from "@/features/projects/stores/projectStore";
import { acpPrepareSession, acpSetModel } from "@/shared/api/acp";
import {
buildProjectSystemPrompt,
composeSystemPrompt,
defaultGlobalArtifactRoot,
getProjectArtifactRoots,
resolveProjectDefaultArtifactRoot,
} from "@/features/projects/lib/chatProjectContext";
import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection";
import { defaultGlobalArtifactRoot } from "@/features/projects/lib/chatProjectContext";
import { ArtifactPolicyProvider } from "../hooks/ArtifactPolicyContext";
import type { ModelOption } from "../types";
import { ChatContextPanel } from "./ChatContextPanel";
import { perfLog } from "@/shared/lib/perfLog";
const EMPTY_MODELS: ModelOption[] = [];
import { useChatSessionController } from "../hooks/useChatSessionController";
interface ChatViewProps {
sessionId: string;
initialProvider?: string;
initialPersonaId?: string;
initialMessage?: string;
initialAttachments?: ChatAttachmentDraft[];
onInitialMessageConsumed?: () => void;
onCreateProject?: (options?: {
onCreated?: (projectId: string) => void;
}) => void;
}
export function ChatView({
sessionId,
initialProvider,
initialPersonaId,
initialMessage,
initialAttachments,
onInitialMessageConsumed,
onCreateProject,
}: ChatViewProps) {
export function ChatView({ sessionId, onCreateProject }: ChatViewProps) {
const { t } = useTranslation("chat");
const activeSessionId = sessionId;
const mountStart = useRef(performance.now());
const isContextPanelOpen = useChatSessionStore(
(s) => s.contextPanelOpenBySession[sessionId] ?? false,
);
const setContextPanelOpen = useChatSessionStore((s) => s.setContextPanelOpen);
const [globalArtifactRoot, setGlobalArtifactRoot] = useState<string | null>(
null,
);
const controller = useChatSessionController({ sessionId });
const contextPanelLabel = isContextPanelOpen
? t("context.closePanel")
: t("context.openPanel");
const allowedArtifactRoots = [
...controller.allowedArtifactRoots,
...(globalArtifactRoot ? [globalArtifactRoot] : []),
];
useEffect(() => {
const ms = (performance.now() - mountStart.current).toFixed(1);
perfLog(`[perf:chatview] ${sessionId.slice(0, 8)} mounted in ${ms}ms`);
}, [sessionId]);
const isContextPanelOpen = useChatSessionStore(
(s) => s.contextPanelOpenBySession[activeSessionId] ?? false,
);
const setContextPanelOpen = useChatSessionStore((s) => s.setContextPanelOpen);
const activeWorkspace = useChatSessionStore(
(s) => s.activeWorkspaceBySession[activeSessionId],
);
const clearActiveWorkspace = useChatSessionStore(
(s) => s.clearActiveWorkspace,
);
const {
providers,
providersLoading,
selectedProvider: globalSelectedProvider,
setSelectedProvider: setGlobalSelectedProvider,
} = useProviderSelection();
const personas = useAgentStore((s) => s.personas);
const [selectedPersonaId, setSelectedPersonaId] = useState<string | null>(
initialPersonaId ?? null,
);
const session = useChatSessionStore((s) =>
s.sessions.find((candidate) => candidate.id === activeSessionId),
);
const availableModels = useChatSessionStore(
(s) => s.modelsBySession[activeSessionId] ?? EMPTY_MODELS,
);
const projects = useProjectStore((s) => s.projects);
const projectsLoading = useProjectStore((s) => s.loading);
const storedProject = useProjectStore((s) =>
session?.projectId
? s.projects.find((candidate) => candidate.id === session.projectId)
: undefined,
);
const [globalArtifactRoot, setGlobalArtifactRoot] = useState<string | null>(
null,
);
const project = storedProject ?? null;
const contextPanelLabel = isContextPanelOpen
? t("context.closePanel")
: t("context.openPanel");
const availableProjects = useMemo(
() =>
[...projects]
.sort((a, b) => a.order - b.order || a.name.localeCompare(b.name))
.map((projectInfo) => ({
id: projectInfo.id,
name: projectInfo.name,
workingDirs: projectInfo.workingDirs,
color: projectInfo.color,
})),
[projects],
);
const selectedProvider =
session?.providerId ??
initialProvider ??
project?.preferredProvider ??
globalSelectedProvider;
const selectedPersona = personas.find((p) => p.id === selectedPersonaId);
const projectArtifactRoots = useMemo(
() => getProjectArtifactRoots(project),
[project],
);
const projectDefaultArtifactRoot = useMemo(
() => resolveProjectDefaultArtifactRoot(project),
[project],
);
const projectMetadataPending = Boolean(
session?.projectId && !projectDefaultArtifactRoot && projectsLoading,
);
const allowedArtifactRoots = useMemo(() => {
const roots = [
...projectArtifactRoots.map((path) => path.trim()).filter(Boolean),
];
if (globalArtifactRoot) {
roots.push(globalArtifactRoot);
}
return [...new Set(roots)];
}, [globalArtifactRoot, projectArtifactRoots]);
const projectSystemPrompt = useMemo(
() => buildProjectSystemPrompt(project),
[project],
);
const workingContextPrompt = useMemo(() => {
if (!activeWorkspace?.branch) return undefined;
return `<active-working-context>\nActive branch: ${activeWorkspace.branch}\nWorking directory: ${activeWorkspace.path}\n</active-working-context>`;
}, [activeWorkspace?.branch, activeWorkspace?.path]);
const effectiveSystemPrompt = useMemo(
() =>
composeSystemPrompt(
selectedPersona?.systemPrompt,
projectSystemPrompt,
workingContextPrompt,
),
[selectedPersona?.systemPrompt, projectSystemPrompt, workingContextPrompt],
);
useEffect(() => {
let cancelled = false;
defaultGlobalArtifactRoot()
.then((artifactRoot) => {
if (cancelled) return;
setGlobalArtifactRoot(artifactRoot);
if (!cancelled) {
setGlobalArtifactRoot(artifactRoot);
}
})
.catch(() => {
if (cancelled) return;
setGlobalArtifactRoot(null);
if (!cancelled) {
setGlobalArtifactRoot(null);
}
});
return () => {
cancelled = true;
};
}, []);
const prevProjectIdRef = useRef(session?.projectId);
useEffect(() => {
const prevProjectId = prevProjectIdRef.current;
prevProjectIdRef.current = session?.projectId;
if (prevProjectId !== undefined && prevProjectId !== session?.projectId) {
clearActiveWorkspace(activeSessionId);
}
}, [session?.projectId, activeSessionId, clearActiveWorkspace]);
const prevWorkspaceRef = useRef(activeWorkspace);
useEffect(() => {
const prev = prevWorkspaceRef.current;
if (
!activeWorkspace ||
!selectedProvider ||
session?.draft ||
activeWorkspace === prev
) {
return;
}
prevWorkspaceRef.current = activeWorkspace;
if (prev && prev.path === activeWorkspace.path) return;
async function prepareWorkspaceSession() {
const workingDir = await resolveSessionCwd(project, activeWorkspace.path);
if (!workingDir) {
return;
}
await acpPrepareSession(activeSessionId, selectedProvider, workingDir, {
personaId: selectedPersonaId ?? undefined,
});
}
void prepareWorkspaceSession().catch((error) => {
console.error("Failed to prepare ACP session:", error);
});
}, [
activeWorkspace,
activeSessionId,
project,
selectedProvider,
selectedPersonaId,
session?.draft,
]);
const handleProviderChange = useCallback(
(providerId: string) => {
if (providerId === selectedProvider) {
return;
}
const sessionStore = useChatSessionStore.getState();
const cached = sessionStore.getCachedModels(providerId);
sessionStore.switchSessionProvider(activeSessionId, providerId, cached);
setGlobalSelectedProvider(providerId);
},
[activeSessionId, selectedProvider, setGlobalSelectedProvider],
);
const handleProjectChange = useCallback(
(projectId: string | null) => {
const nextProject =
projectId == null
? null
: (useProjectStore
.getState()
.projects.find((candidate) => candidate.id === projectId) ??
null);
useChatSessionStore
.getState()
.updateSession(activeSessionId, { projectId });
if (!session?.draft && selectedProvider) {
async function updateProjectSessionCwd() {
const workingDir = await resolveSessionCwd(
nextProject,
activeWorkspace?.path,
);
if (!workingDir) {
return;
}
await acpPrepareSession(
activeSessionId,
selectedProvider,
workingDir,
{
personaId: selectedPersonaId ?? undefined,
},
);
}
void updateProjectSessionCwd().catch((error) => {
console.error(
"Failed to update ACP session working directory:",
error,
);
});
}
},
[
activeSessionId,
activeWorkspace?.path,
selectedPersonaId,
selectedProvider,
session?.draft,
],
);
const handleModelChange = useCallback(
(modelId: string) => {
if (!activeSessionId || modelId === session?.modelId) {
return;
}
const previousModelId = session?.modelId;
const previousModelName = session?.modelName;
const models = useChatSessionStore
.getState()
.getSessionModels(activeSessionId);
const selected = models.find((m) => m.id === modelId);
useChatSessionStore.getState().updateSession(activeSessionId, {
modelId,
modelName: selected?.displayName ?? selected?.name ?? modelId,
});
if (session?.draft) {
return;
}
acpSetModel(activeSessionId, modelId).catch((error) => {
console.error("Failed to set model:", error);
useChatSessionStore.getState().updateSession(activeSessionId, {
modelId: previousModelId,
modelName: previousModelName,
});
});
},
[activeSessionId, session?.draft, session?.modelId, session?.modelName],
);
// When persona changes, update the provider to match persona's default
const handlePersonaChange = useCallback(
(personaId: string | null) => {
setSelectedPersonaId(personaId);
const persona = personas.find((p) => p.id === personaId);
if (persona?.provider) {
const matchingProvider = providers.find(
(p) =>
p.id === persona.provider ||
p.label.toLowerCase().includes(persona.provider ?? ""),
);
if (matchingProvider) {
handleProviderChange(matchingProvider.id);
}
}
const agentStore = useAgentStore.getState();
const matchingAgent = agentStore.agents.find(
(a) => a.personaId === personaId,
);
if (matchingAgent) {
agentStore.setActiveAgent(matchingAgent.id);
}
useChatSessionStore
.getState()
.updateSession(activeSessionId, { personaId: personaId ?? undefined });
},
[personas, providers, activeSessionId, handleProviderChange],
);
// Validate persona still exists — fall back to default if deleted
useEffect(() => {
if (
selectedPersonaId !== null &&
personas.length > 0 &&
!personas.find((p) => p.id === selectedPersonaId)
) {
// Selected persona was deleted — reset to no persona
setSelectedPersonaId(null);
}
}, [personas, selectedPersonaId]);
const personaInfo = selectedPersona
? { id: selectedPersona.id, name: selectedPersona.displayName }
: undefined;
const resolveCurrentSessionCwd = useCallback(
() => resolveSessionCwd(project, activeWorkspace?.path),
[project, activeWorkspace?.path],
);
const {
messages,
chatState,
tokenState,
sendMessage,
compactConversation,
stopStreaming,
streamingMessageId,
} = useChat(
activeSessionId,
selectedProvider,
effectiveSystemPrompt,
personaInfo,
resolveCurrentSessionCwd,
);
const isLoadingHistory = useChatStore(
(s) =>
s.loadingSessionIds.has(activeSessionId) &&
(s.messagesBySession[activeSessionId]?.length ?? 0) === 0,
);
const deferredSend = useRef<{
text: string;
attachments?: ChatAttachmentDraft[];
} | null>(null);
const queue = useMessageQueue(activeSessionId, chatState, sendMessage);
const chatStore = useChatStore();
const handleSend = useCallback(
(text: string, personaId?: string, attachments?: ChatAttachmentDraft[]) => {
if (personaId && personaId !== selectedPersonaId) {
const newPersona = personas.find((p) => p.id === personaId);
if (newPersona) {
// Inject a system notification about the persona switch
chatStore.addMessage(activeSessionId, {
id: crypto.randomUUID(),
role: "system",
created: Date.now(),
content: [
{
type: "systemNotification",
notificationType: "info",
text: `Switched to ${newPersona.displayName}`,
},
],
metadata: { userVisible: true, agentVisible: false },
});
}
handlePersonaChange(personaId);
// Defer the send until after persona state updates
deferredSend.current = { text, attachments };
return;
}
// Queue if agent is busy and no message already queued
if (chatState !== "idle" && !queue.queuedMessage) {
queue.enqueue(text, personaId, attachments);
return;
}
sendMessage(text, undefined, attachments);
},
[
sendMessage,
selectedPersonaId,
handlePersonaChange,
personas,
chatStore,
activeSessionId,
chatState,
queue,
],
);
useEffect(() => {
if (deferredSend.current && selectedPersona) {
const { text, attachments } = deferredSend.current;
deferredSend.current = null;
sendMessage(text, undefined, attachments);
}
}, [sendMessage, selectedPersona]);
const initialMessageSent = useRef(false);
useEffect(() => {
if (
(initialMessage || initialAttachments?.length) &&
!initialMessageSent.current
) {
initialMessageSent.current = true;
handleSend(initialMessage ?? "", undefined, initialAttachments);
onInitialMessageConsumed?.();
}
}, [
initialAttachments,
initialMessage,
handleSend,
onInitialMessageConsumed,
]);
const isStreaming = chatState === "streaming";
const isCompacting = chatState === "compacting";
const showIndicator =
chatState === "thinking" ||
chatState === "streaming" ||
chatState === "waiting" ||
chatState === "compacting";
const handleCreatePersona = useCallback(() => {
useAgentStore.getState().openPersonaEditor();
}, []);
const draftValue = useChatStore(
(s) => s.draftsBySession[activeSessionId] ?? "",
);
const scrollTarget = useChatStore(
(s) => s.scrollTargetMessageBySession[activeSessionId] ?? null,
);
const handleDraftChange = useCallback(
(text: string) => {
useChatStore.getState().setDraft(activeSessionId, text);
},
[activeSessionId],
);
const handleScrollTargetHandled = useCallback(() => {
useChatStore.getState().clearScrollTargetMessage(activeSessionId);
}, [activeSessionId]);
controller.chatState === "thinking" ||
controller.chatState === "streaming" ||
controller.chatState === "waiting" ||
controller.chatState === "compacting";
return (
<ArtifactPolicyProvider
messages={messages}
messages={controller.messages}
allowedRoots={allowedArtifactRoots}
>
<div className="relative flex h-full min-w-0">
<div className="flex min-w-0 flex-1 flex-col pr-1">
{isLoadingHistory ? (
{controller.isLoadingHistory ? (
<ChatLoadingSkeleton />
) : (
<MessageTimeline
messages={messages}
streamingMessageId={streamingMessageId}
scrollTargetMessageId={scrollTarget?.messageId ?? null}
scrollTargetQuery={scrollTarget?.query ?? null}
onScrollTargetHandled={handleScrollTargetHandled}
messages={controller.messages}
streamingMessageId={controller.streamingMessageId}
scrollTargetMessageId={controller.scrollTarget?.messageId ?? null}
scrollTargetQuery={controller.scrollTarget?.query ?? null}
onScrollTargetHandled={controller.handleScrollTargetHandled}
/>
)}
<AnimatePresence initial={false}>
{showIndicator && !isLoadingHistory ? (
{showIndicator && !controller.isLoadingHistory ? (
<LoadingGoose
key="loading-indicator"
chatState={
chatState as
controller.chatState as
| "thinking"
| "streaming"
| "waiting"
@ -512,54 +102,52 @@ export function ChatView({
</AnimatePresence>
<ChatInput
onSend={handleSend}
disabled={projectMetadataPending}
queuedMessage={queue.queuedMessage}
onDismissQueue={queue.dismiss}
initialValue={draftValue}
onDraftChange={handleDraftChange}
onStop={stopStreaming}
isStreaming={isStreaming || chatState === "thinking"}
personas={personas}
selectedPersonaId={selectedPersonaId}
onPersonaChange={handlePersonaChange}
onCreatePersona={handleCreatePersona}
providers={providers}
providersLoading={providersLoading}
selectedProvider={selectedProvider}
onProviderChange={handleProviderChange}
currentModelId={session?.modelId ?? null}
currentModel={session?.modelName}
availableModels={availableModels}
onModelChange={handleModelChange}
selectedProjectId={session?.projectId ?? null}
availableProjects={availableProjects}
onProjectChange={handleProjectChange}
onSend={controller.handleSend}
disabled={controller.projectMetadataPending}
queuedMessage={controller.queue.queuedMessage}
onDismissQueue={controller.queue.dismiss}
initialValue={controller.draftValue}
onDraftChange={controller.handleDraftChange}
onStop={controller.stopStreaming}
isStreaming={
controller.chatState === "streaming" ||
controller.chatState === "thinking"
}
personas={controller.personas}
selectedPersonaId={controller.selectedPersonaId}
onPersonaChange={controller.handlePersonaChange}
onCreatePersona={controller.handleCreatePersona}
providers={controller.pickerAgents}
providersLoading={controller.providersLoading}
selectedProvider={controller.selectedProvider}
onProviderChange={controller.handleProviderChange}
currentModelId={controller.currentModelId}
currentModel={controller.currentModelName ?? undefined}
availableModels={controller.availableModels}
modelsLoading={controller.modelsLoading}
modelStatusMessage={controller.modelStatusMessage}
onModelChange={controller.handleModelChange}
selectedProjectId={controller.selectedProjectId}
availableProjects={controller.availableProjects}
onProjectChange={controller.handleProjectChange}
onCreateProject={(options) =>
onCreateProject?.({
onCreated: (projectId) => {
handleProjectChange(projectId);
controller.handleProjectChange(projectId);
options?.onCreated?.(projectId);
},
})
}
contextTokens={tokenState.accumulatedTotal}
contextLimit={tokenState.contextLimit}
onCompactContext={compactConversation}
canCompactContext={
chatState === "idle" &&
tokenState.accumulatedTotal > 0 &&
!projectMetadataPending
}
isCompactingContext={isCompacting}
contextTokens={controller.tokenState.accumulatedTotal}
contextLimit={controller.tokenState.contextLimit}
/>
</div>
<ChatContextPanel
activeSessionId={activeSessionId}
activeSessionId={sessionId}
isOpen={isContextPanelOpen}
label={contextPanelLabel}
project={project}
project={controller.project}
setOpen={setContextPanelOpen}
/>
</div>

View file

@ -59,7 +59,6 @@ describe("AgentModelPicker", () => {
screen.getByRole("button", { name: /choose agent and model/i }),
);
await user.click(screen.getByRole("button", { name: /OpenAI/i }));
await user.click(screen.getByRole("button", { name: "GPT-4o" }));
expect(onModelChange).toHaveBeenCalledWith("gpt-4o");
@ -149,4 +148,49 @@ describe("AgentModelPicker", () => {
expect(trigger).toHaveTextContent("Goose");
expect(trigger).not.toHaveTextContent("·");
});
it("shows a loading state while models are refreshing", async () => {
const user = userEvent.setup();
render(
<AgentModelPicker
agents={AGENTS}
selectedAgentId="goose"
onAgentChange={vi.fn()}
currentModelId={null}
currentModelName={null}
availableModels={[]}
modelsLoading
onModelChange={vi.fn()}
/>,
);
await user.click(
screen.getByRole("button", { name: /choose agent and model/i }),
);
expect(screen.getByText("Loading models...")).toBeInTheDocument();
});
it("shows an empty-state message when no inventory models are available", async () => {
const user = userEvent.setup();
render(
<AgentModelPicker
agents={AGENTS}
selectedAgentId="goose"
onAgentChange={vi.fn()}
currentModelId={null}
currentModelName={null}
availableModels={[]}
onModelChange={vi.fn()}
/>,
);
await user.click(
screen.getByRole("button", { name: /choose agent and model/i }),
);
expect(screen.getByText("No models available")).toBeInTheDocument();
});
});

View file

@ -149,7 +149,7 @@ describe("ChatInput", () => {
).toHaveTextContent("GPT-4o");
});
it("shows default model name in model picker", () => {
it("shows provider label when no current model is selected", () => {
render(
<ChatInput
onSend={vi.fn()}
@ -159,7 +159,7 @@ describe("ChatInput", () => {
);
expect(
screen.getByRole("button", { name: /choose agent and model/i }),
).toHaveTextContent("Claude Sonnet 4");
).toHaveTextContent("Goose");
});
it("shows default provider label", () => {
@ -176,6 +176,18 @@ describe("ChatInput", () => {
expect(providerButton).toHaveTextContent("Goose");
});
it("resets the textarea when initialValue changes", () => {
const { rerender } = render(
<ChatInput onSend={vi.fn()} initialValue="alpha draft" />,
);
expect(screen.getByRole("textbox")).toHaveValue("alpha draft");
rerender(<ChatInput onSend={vi.fn()} initialValue="" />);
expect(screen.getByRole("textbox")).toHaveValue("");
});
it("opens the agent and model picker", async () => {
const user = userEvent.setup();

View file

@ -5,6 +5,67 @@ import { HomeScreen } from "./HomeScreen";
const setSelectedProvider = vi.fn();
const setSelectedProviderWithoutPersist = vi.fn();
const mockController = {
handleSend: vi.fn(),
projectMetadataPending: false,
queue: { queuedMessage: null, dismiss: vi.fn() },
stopStreaming: vi.fn(),
chatState: "idle" as const,
personas: [
{
id: "builtin-solo",
displayName: "Solo",
systemPrompt: "You are Solo.",
provider: "openai",
description: null,
avatar: null,
createdBy: null,
source: "custom",
extensions: [],
metadata: null,
sortOrder: 0,
isDefault: false,
},
{
id: "builtin-goose",
displayName: "Goosey",
systemPrompt: "You are Goosey.",
isBuiltin: true,
description: null,
avatar: null,
createdBy: null,
source: "custom",
extensions: [],
metadata: null,
sortOrder: 1,
isDefault: false,
createdAt: "",
updatedAt: "",
},
],
draftValue: "",
handleDraftChange: vi.fn(),
selectedPersonaId: null,
handlePersonaChange: vi.fn(),
handleCreatePersona: vi.fn(),
pickerAgents: [
{ id: "goose", label: "Goose" },
{ id: "claude-acp", label: "Claude Code" },
],
providersLoading: false,
selectedProvider: "goose",
handleProviderChange: setSelectedProvider,
currentModelId: null,
currentModelName: null,
availableModels: [],
modelsLoading: false,
modelStatusMessage: null,
handleModelChange: vi.fn(),
selectedProjectId: null,
availableProjects: [],
handleProjectChange: vi.fn(),
tokenState: { accumulatedTotal: 0, contextLimit: 0 },
};
vi.mock("@/shared/api/acp", () => ({
discoverAcpProviders: vi.fn().mockResolvedValue([
@ -21,6 +82,10 @@ vi.mock("@/features/providers/hooks/useAgentProviderStatus", () => ({
}),
}));
vi.mock("@/features/chat/hooks/useChatSessionController", () => ({
useChatSessionController: () => mockController,
}));
vi.mock("@/features/agents/hooks/useProviderSelection", () => ({
useProviderSelection: () => ({
providers: [
@ -34,7 +99,6 @@ vi.mock("@/features/agents/hooks/useProviderSelection", () => ({
}),
}));
// HomeScreen now reads personas from the agent store, not from ACP providers
vi.mock("@/features/agents/stores/agentStore", async (importOriginal) => {
const actual =
await importOriginal<
@ -88,6 +152,9 @@ vi.mock("@/features/agents/stores/agentStore", async (importOriginal) => {
});
describe("HomeScreen", () => {
const renderHome = () =>
render(<HomeScreen sessionId="home-session" onActivateSession={vi.fn()} />);
beforeEach(() => {
vi.useFakeTimers();
vi.setSystemTime(new Date(2026, 2, 29, 14, 30, 0)); // 2:30 PM
@ -98,32 +165,32 @@ describe("HomeScreen", () => {
});
it("renders the clock", () => {
render(<HomeScreen />);
renderHome();
expect(screen.getByText("2:30")).toBeInTheDocument();
expect(screen.getByText("PM")).toBeInTheDocument();
});
it("shows afternoon greeting at 2:30 PM", () => {
render(<HomeScreen />);
renderHome();
expect(screen.getByText("Good afternoon")).toBeInTheDocument();
});
it("renders the chat input placeholder with default agent name when no persona selected", () => {
render(<HomeScreen />);
renderHome();
expect(
screen.getByPlaceholderText("Message Goose, @ to mention personas"),
).toBeInTheDocument();
});
it("renders the assistant chooser affordance", () => {
render(<HomeScreen />);
renderHome();
expect(
screen.getByRole("button", { name: /choose assistant/i }),
).toBeInTheDocument();
});
it("renders the provider and project controls on the home screen", () => {
render(<HomeScreen />);
renderHome();
expect(
screen.getByRole("button", { name: /choose agent and model/i }),
).toBeInTheDocument();
@ -132,23 +199,17 @@ describe("HomeScreen", () => {
).toBeInTheDocument();
});
it("reverts to the stored provider when a persona override is cleared", async () => {
it("forwards persona selection through the shared session controller", async () => {
vi.useRealTimers();
const user = userEvent.setup();
render(<HomeScreen />);
renderHome();
await user.click(screen.getByRole("button", { name: /choose assistant/i }));
await user.click(screen.getByRole("menuitem", { name: /solo/i }));
expect(setSelectedProviderWithoutPersist).toHaveBeenLastCalledWith(
"openai",
expect(mockController.handlePersonaChange).toHaveBeenLastCalledWith(
"builtin-solo",
);
await user.click(
screen.getByRole("button", { name: /clear active assistant/i }),
);
expect(setSelectedProviderWithoutPersist).toHaveBeenLastCalledWith("goose");
});
});

View file

@ -1,17 +1,8 @@
import { useState, useEffect, useCallback } from "react";
import { useState, useEffect } from "react";
import { useTranslation } from "react-i18next";
import {
getStoredProvider,
useAgentStore,
} from "@/features/agents/stores/agentStore";
import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection";
import { ChatInput } from "@/features/chat/ui/ChatInput";
import { useChatStore } from "@/features/chat/stores/chatStore";
import type { ChatAttachmentDraft } from "@/shared/types/messages";
import { useProjectStore } from "@/features/projects/stores/projectStore";
import { useLocaleFormatting } from "@/shared/i18n";
const HOME_DRAFT_KEY = "home";
import { useChatSessionController } from "@/features/chat/hooks/useChatSessionController";
function HomeClock() {
const [time, setTime] = useState(new Date());
@ -48,124 +39,94 @@ function getGreetingKey(hour: number): "morning" | "afternoon" | "evening" {
}
interface HomeScreenProps {
onStartChat?: (
initialMessage?: string,
providerId?: string,
personaId?: string,
projectId?: string | null,
attachments?: ChatAttachmentDraft[],
) => void;
sessionId: string | null;
onActivateSession: (sessionId: string) => void;
onCreateProject?: (options?: {
onCreated?: (projectId: string) => void;
}) => void;
}
export function HomeScreen({ onStartChat, onCreateProject }: HomeScreenProps) {
function HomeComposer({
sessionId,
onActivateSession,
onCreateProject,
}: {
sessionId: string | null;
onActivateSession: (sessionId: string) => void;
onCreateProject?: HomeScreenProps["onCreateProject"];
}) {
const controller = useChatSessionController({
sessionId,
onMessageAccepted: onActivateSession,
});
return (
<ChatInput
onSend={controller.handleSend}
disabled={controller.projectMetadataPending}
queuedMessage={controller.queue.queuedMessage}
onDismissQueue={controller.queue.dismiss}
initialValue={controller.draftValue}
onDraftChange={controller.handleDraftChange}
onStop={controller.stopStreaming}
isStreaming={
controller.chatState === "streaming" ||
controller.chatState === "thinking"
}
personas={controller.personas}
selectedPersonaId={controller.selectedPersonaId}
onPersonaChange={controller.handlePersonaChange}
onCreatePersona={controller.handleCreatePersona}
providers={controller.pickerAgents}
providersLoading={controller.providersLoading}
selectedProvider={controller.selectedProvider}
onProviderChange={controller.handleProviderChange}
currentModelId={controller.currentModelId}
currentModel={controller.currentModelName ?? undefined}
availableModels={controller.availableModels}
modelsLoading={controller.modelsLoading}
modelStatusMessage={controller.modelStatusMessage}
onModelChange={controller.handleModelChange}
selectedProjectId={controller.selectedProjectId}
availableProjects={controller.availableProjects}
onProjectChange={controller.handleProjectChange}
onCreateProject={(options) =>
onCreateProject?.({
onCreated: (projectId) => {
controller.handleProjectChange(projectId);
options?.onCreated?.(projectId);
},
})
}
contextTokens={controller.tokenState.accumulatedTotal}
contextLimit={controller.tokenState.contextLimit}
/>
);
}
export function HomeScreen({
sessionId,
onActivateSession,
onCreateProject,
}: HomeScreenProps) {
const { t } = useTranslation("home");
const [hour] = useState(() => new Date().getHours());
const greeting = t(`greeting.${getGreetingKey(hour)}`);
const personas = useAgentStore((s) => s.personas);
const {
providers,
providersLoading,
selectedProvider,
setSelectedProvider,
setSelectedProviderWithoutPersist,
} = useProviderSelection();
const projects = useProjectStore((s) => s.projects);
const [selectedPersonaId, setSelectedPersonaId] = useState<string | null>(
null,
);
const [selectedProjectId, setSelectedProjectId] = useState<string | null>(
null,
);
const handlePersonaChange = useCallback(
(personaId: string | null) => {
setSelectedPersonaId(personaId);
const persona = personaId
? personas.find((candidate) => candidate.id === personaId)
: null;
const nextProvider = persona?.provider ?? getStoredProvider(providers);
setSelectedProviderWithoutPersist(nextProvider);
},
[personas, providers, setSelectedProviderWithoutPersist],
);
const handleCreatePersona = useCallback(() => {
useAgentStore.getState().openPersonaEditor();
}, []);
const homeDraft = useChatStore(
(s) => s.draftsBySession[HOME_DRAFT_KEY] ?? "",
);
const handleDraftChange = useCallback((text: string) => {
useChatStore.getState().setDraft(HOME_DRAFT_KEY, text);
}, []);
const handleSend = useCallback(
(
message: string,
personaId?: string,
attachments?: ChatAttachmentDraft[],
) => {
const effectivePersonaId = personaId ?? selectedPersonaId ?? undefined;
useChatStore.getState().clearDraft(HOME_DRAFT_KEY);
onStartChat?.(
message,
selectedProvider,
effectivePersonaId,
selectedProjectId,
attachments,
);
},
[onStartChat, selectedPersonaId, selectedProjectId, selectedProvider],
);
return (
<div className="h-full w-full overflow-y-auto">
<div className="relative flex min-h-full flex-col items-center justify-center px-6 pb-4">
<div className="flex w-full max-w-[600px] flex-col antialiased">
{/* Clock */}
<HomeClock />
{/* Greeting */}
<p className="mb-6 pl-4 text-xl font-normal font-display text-muted-foreground">
{greeting}
</p>
{/* Chat input */}
<ChatInput
onSend={handleSend}
initialValue={homeDraft}
onDraftChange={handleDraftChange}
personas={personas}
selectedPersonaId={selectedPersonaId}
onPersonaChange={handlePersonaChange}
onCreatePersona={handleCreatePersona}
providers={providers}
providersLoading={providersLoading}
selectedProvider={selectedProvider}
onProviderChange={setSelectedProvider}
selectedProjectId={selectedProjectId}
availableProjects={projects.map((project) => ({
id: project.id,
name: project.name,
workingDirs: project.workingDirs,
color: project.color,
}))}
onProjectChange={setSelectedProjectId}
onCreateProject={(options) =>
onCreateProject?.({
onCreated: (projectId) => {
setSelectedProjectId(projectId);
options?.onCreated?.(projectId);
},
})
}
<HomeComposer
sessionId={sessionId}
onActivateSession={onActivateSession}
onCreateProject={onCreateProject}
/>
</div>
</div>

View file

@ -0,0 +1,32 @@
import type {
ProviderInventoryEntryDto,
RefreshProviderInventoryResponse,
} from "@aaif/goose-sdk";
import { getClient } from "@/shared/api/acpConnection";
import { perfLog } from "@/shared/lib/perfLog";
export async function getProviderInventory(
providerIds: string[] = [],
): Promise<ProviderInventoryEntryDto[]> {
const client = await getClient();
const t0 = performance.now();
const response = await client.goose.GooseProvidersInventory({ providerIds });
perfLog(
`[perf:inventory] getProviderInventory done in ${(performance.now() - t0).toFixed(1)}ms (n=${response.entries.length})`,
);
return response.entries;
}
export async function refreshProviderInventory(
providerIds: string[] = [],
): Promise<RefreshProviderInventoryResponse> {
const client = await getClient();
const t0 = performance.now();
const response = await client.goose.GooseProvidersInventoryRefresh({
providerIds,
});
perfLog(
`[perf:inventory] refreshProviderInventory done in ${(performance.now() - t0).toFixed(1)}ms started=[${response.started.join(",")}]`,
);
return response;
}

View file

@ -0,0 +1,83 @@
import { useCallback, useMemo } from "react";
import { useProviderInventoryStore } from "../stores/providerInventoryStore";
import type { ModelOption } from "@/features/chat/types";
import type {
ProviderInventoryEntryDto,
ProviderInventoryModelDto,
} from "@aaif/goose-sdk";
import { getModelProviders } from "../providerCatalog";
const MODEL_PROVIDER_IDS = new Set(getModelProviders().map((p) => p.id));
function inventoryModelToOption(
model: ProviderInventoryModelDto,
provider?: Pick<ProviderInventoryEntryDto, "providerId" | "providerName">,
): ModelOption {
return {
id: model.id,
name: model.name,
displayName: model.name !== model.id ? model.name : undefined,
provider: model.family ?? undefined,
providerId: provider?.providerId,
providerName: provider?.providerName,
recommended: model.recommended ?? false,
};
}
export function useProviderInventory() {
const entries = useProviderInventoryStore((s) => s.entries);
const loading = useProviderInventoryStore((s) => s.loading);
const getEntry = useCallback(
(providerId: string) => entries.get(providerId),
[entries],
);
const getModelsForProvider = useCallback(
(providerId: string): ModelOption[] => {
const entry = entries.get(providerId);
if (!entry) return [];
return entry.models.map((model) => inventoryModelToOption(model, entry));
},
[entries],
);
const configuredModelProviderEntries = useMemo(
() =>
[...entries.values()].filter(
(entry) => entry.configured && MODEL_PROVIDER_IDS.has(entry.providerId),
),
[entries],
);
const getModelsForAgent = useCallback(
(agentId: string): ModelOption[] => {
if (agentId !== "goose") {
return getModelsForProvider(agentId);
}
return configuredModelProviderEntries.flatMap((entry) =>
entry.models.map((model) => inventoryModelToOption(model, entry)),
);
},
[configuredModelProviderEntries, getModelsForProvider],
);
const configuredProviderIds = useMemo(
() =>
[...entries.values()]
.filter((e) => e.configured)
.map((e) => e.providerId),
[entries],
);
return {
entries,
loading,
getEntry,
configuredModelProviderEntries,
getModelsForAgent,
getModelsForProvider,
configuredProviderIds,
};
}

View file

@ -0,0 +1,47 @@
import { create } from "zustand";
import type { ProviderInventoryEntryDto } from "@aaif/goose-sdk";
import { perfLog } from "@/shared/lib/perfLog";
export interface ProviderInventoryState {
entries: Map<string, ProviderInventoryEntryDto>;
loading: boolean;
}
interface ProviderInventoryActions {
setEntries: (entries: ProviderInventoryEntryDto[]) => void;
mergeEntries: (entries: ProviderInventoryEntryDto[]) => void;
setLoading: (loading: boolean) => void;
}
export type ProviderInventoryStore = ProviderInventoryState &
ProviderInventoryActions;
export const useProviderInventoryStore = create<ProviderInventoryStore>(
(set) => ({
entries: new Map(),
loading: false,
setEntries: (entries) => {
const map = new Map<string, ProviderInventoryEntryDto>();
for (const entry of entries) {
map.set(entry.providerId, entry);
}
set({ entries: map });
perfLog(
`[perf:inventory] setEntries n=${entries.length} providers=[${entries.map((e) => e.providerId).join(",")}]`,
);
},
mergeEntries: (entries) => {
set((state) => {
const map = new Map(state.entries);
for (const entry of entries) {
map.set(entry.providerId, entry);
}
return { entries: map };
});
},
setLoading: (loading) => set({ loading }),
}),
);

View file

@ -8,7 +8,11 @@ import { Button } from "@/shared/ui/button";
import { SessionCard } from "./SessionCard";
import { groupSessionsByDate } from "../lib/groupSessionsByDate";
import { useAgentStore } from "@/features/agents/stores/agentStore";
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
import {
getVisibleSessions,
useChatSessionStore,
} from "@/features/chat/stores/chatSessionStore";
import { useChatStore } from "@/features/chat/stores/chatStore";
import { useProjectStore } from "@/features/projects/stores/projectStore";
import {
acpDuplicateSession,
@ -38,10 +42,14 @@ export function SessionHistoryView({
}: SessionHistoryViewProps) {
const { t, i18n } = useTranslation(["sessions", "common"]);
const sessions = useChatSessionStore((s) => s.sessions);
const messagesBySession = useChatStore((s) => s.messagesBySession);
const loadSessions = useChatSessionStore((s) => s.loadSessions);
const activeSessions = useMemo(
() => sessions.filter((session) => !session.draft && !session.archivedAt),
[sessions],
() =>
getVisibleSessions(sessions, messagesBySession).filter(
(session) => !session.archivedAt,
),
[messagesBySession, sessions],
);
const fileInputRef = useRef<HTMLInputElement>(null);

View file

@ -11,7 +11,10 @@ import { cn } from "@/shared/lib/cn";
import type { AppView } from "@/app/AppShell";
import type { ProjectInfo } from "@/features/projects/api/projects";
import { useChatStore } from "@/features/chat/stores/chatStore";
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
import {
getVisibleSessions,
useChatSessionStore,
} from "@/features/chat/stores/chatSessionStore";
import { isSessionRunning } from "@/features/chat/lib/sessionActivity";
import { useAgentStore } from "@/features/agents/stores/agentStore";
import { useProjectStore } from "@/features/projects/stores/projectStore";
@ -92,8 +95,12 @@ export function Sidebar({
const chatStore = useChatStore();
const { sessions } = useChatSessionStore();
const activeSessions = sessions.filter(
(session) => !session.draft && !session.archivedAt,
const visibleSessions = getVisibleSessions(
sessions,
chatStore.messagesBySession,
);
const activeSessions = visibleSessions.filter(
(session) => !session.archivedAt,
);
useEffect(() => {
@ -137,8 +144,8 @@ export function Sidebar({
};
const byProject: Record<string, SessionItem[]> = {};
const standalone: SessionItem[] = [];
for (const session of sessions) {
if (session.draft || session.archivedAt) continue;
for (const session of visibleSessions) {
if (session.archivedAt) continue;
const runtime = chatStore.getSessionRuntime(session.id);
const item: SessionItem = {
id: session.id,
@ -191,7 +198,7 @@ export function Sidebar({
useEffect(() => {
if (!activeSessionId) return;
const activeSession = sessions.find((s) => s.id === activeSessionId);
const activeSession = visibleSessions.find((s) => s.id === activeSessionId);
const projectId = activeSession?.projectId;
if (projectId) {
setExpandedProjects((prev) => {
@ -199,7 +206,7 @@ export function Sidebar({
return { ...prev, [projectId]: true };
});
}
}, [activeSessionId, sessions]);
}, [activeSessionId, visibleSessions]);
useEffect(() => {
try {
@ -254,17 +261,13 @@ export function Sidebar({
updateActiveRect,
} = useSidebarHighlight(navRef);
const activeDraft = activeSessionId
? sessions.find((s) => s.id === activeSessionId && s.draft)
: undefined;
const activeProjectId = activeDraft?.projectId ?? null;
const activeProjectId =
activeSessionId && activeView === "chat"
? (sessions.find((s) => s.id === activeSessionId)?.projectId ?? null)
: null;
useEffect(() => {
if (activeDraft) {
if (!activeProjectId) updateActiveRect(null);
return;
}
if (activeSessionId) return;
if (activeSessionId && activeView === "chat") return;
if (activeView === "home") {
updateActiveRect(homeRef.current);
} else if (activeView && navItemRefs.current[activeView]) {
@ -272,13 +275,7 @@ export function Sidebar({
} else {
updateActiveRect(null);
}
}, [
activeSessionId,
activeDraft,
activeProjectId,
activeView,
updateActiveRect,
]);
}, [activeSessionId, activeView, updateActiveRect]);
const activeSessionRefCallback = useCallback(
(el: HTMLElement | null) => {

View file

@ -9,12 +9,12 @@ const mockSessions: Array<{
updatedAt: string;
messageCount: number;
projectId?: string;
draft?: boolean;
archivedAt?: string;
}> = [];
vi.mock("@/features/chat/stores/chatStore", () => ({
useChatStore: () => ({
messagesBySession: {},
getSessionRuntime: () => ({
chatState: "idle",
hasUnread: false,
@ -23,6 +23,8 @@ vi.mock("@/features/chat/stores/chatStore", () => ({
}));
vi.mock("@/features/chat/stores/chatSessionStore", () => ({
getVisibleSessions: (sessions: typeof mockSessions) =>
sessions.filter((session) => session.messageCount > 0),
useChatSessionStore: () => ({
sessions: mockSessions,
}),
@ -65,6 +67,40 @@ describe("Sidebar", () => {
mockSessions.splice(0, mockSessions.length);
});
it("hides zero-message sessions from recents", () => {
mockSessions.splice(
0,
mockSessions.length,
{
id: "home-session",
title: "New Chat",
updatedAt: "2026-04-09T12:00:00.000Z",
messageCount: 0,
},
{
id: "session-1",
title: "Recovered Session",
updatedAt: "2026-04-09T12:01:00.000Z",
messageCount: 3,
},
);
render(
<Sidebar
collapsed={false}
onCollapse={vi.fn()}
onNavigate={vi.fn()}
onSelectSession={vi.fn()}
projects={[]}
/>,
);
expect(screen.queryByText("New Chat")).not.toBeInTheDocument();
expect(screen.getByText("Recovered Session")).toBeInTheDocument();
mockSessions.splice(0, mockSessions.length);
});
it("renders a home button in the sidebar header and navigates home", async () => {
const user = userEvent.setup();
const onNavigate = vi.fn();

View file

@ -16,7 +16,7 @@ vi.mock("../acpConnection", () => ({
}));
describe("dictation SDK wiring", () => {
let client: any;
let client: { goose: Record<string, ReturnType<typeof vi.fn>> };
beforeEach(() => {
client = {
goose: {
@ -33,7 +33,9 @@ describe("dictation SDK wiring", () => {
GooseDictationTranscribe: vi.fn().mockResolvedValue({ text: "hello" }),
},
};
vi.mocked(getClient).mockResolvedValue(client);
vi.mocked(getClient).mockResolvedValue(
client as unknown as Awaited<ReturnType<typeof getClient>>,
);
});
it("getDictationConfig calls GooseDictationConfig and returns providers map", async () => {
@ -46,7 +48,7 @@ describe("dictation SDK wiring", () => {
const result = await transcribeDictation({
audio: "base64==",
mimeType: "audio/webm",
provider: "openai" as any,
provider: "openai",
});
expect(client.goose.GooseDictationTranscribe).toHaveBeenCalledWith({
audio: "base64==",
@ -58,7 +60,7 @@ describe("dictation SDK wiring", () => {
it("saveDictationModelSelection calls GooseDictationModelSelect", async () => {
client.goose.GooseDictationModelSelect = vi.fn().mockResolvedValue({});
await saveDictationModelSelection("local" as any, "tiny");
await saveDictationModelSelection("local", "tiny");
expect(client.goose.GooseDictationModelSelect).toHaveBeenCalledWith({
provider: "local",
modelId: "tiny",

View file

@ -1,6 +1,10 @@
import type { ContentBlock } from "@agentclientprotocol/sdk";
import * as directAcp from "./acpApi";
import * as sessionTracker from "./acpSessionTracker";
import {
getCatalogEntry,
resolveAgentProviderCatalogId,
} from "@/features/providers/providerCatalog";
import {
setActiveMessageId,
clearActiveMessageId,
@ -25,9 +29,31 @@ export interface AcpPrepareSessionOptions {
personaId?: string;
}
export interface AcpCreateSessionOptions extends AcpPrepareSessionOptions {
modelId?: string | null;
}
/** Discover ACP providers installed on the system. */
export async function discoverAcpProviders(): Promise<AcpProvider[]> {
return directAcp.listProviders();
const providers = await directAcp.listProviders();
const seen = new Set<string>();
return providers
.map((provider) => {
const catalogId = resolveAgentProviderCatalogId(
provider.id,
provider.label,
);
if (!catalogId || seen.has(catalogId)) {
return null;
}
seen.add(catalogId);
return {
id: catalogId,
label: getCatalogEntry(catalogId)?.displayName ?? provider.label,
};
})
.filter((provider): provider is AcpProvider => provider !== null);
}
/** Send a message to an ACP agent. Response streams via Tauri events. */
@ -79,13 +105,13 @@ export async function acpPrepareSession(
providerId: string,
workingDir: string,
options: AcpPrepareSessionOptions = {},
): Promise<void> {
): Promise<string> {
const sid = sessionId.slice(0, 8);
const t0 = performance.now();
perfLog(
`[perf:prepare] ${sid} acpPrepareSession start (provider=${providerId})`,
);
await sessionTracker.prepareSession(
const gooseSessionId = await sessionTracker.prepareSession(
sessionId,
providerId,
workingDir,
@ -94,6 +120,31 @@ export async function acpPrepareSession(
perfLog(
`[perf:prepare] ${sid} acpPrepareSession done in ${(performance.now() - t0).toFixed(1)}ms`,
);
return gooseSessionId;
}
export async function acpCreateSession(
providerId: string,
workingDir: string,
options: AcpCreateSessionOptions = {},
): Promise<{ sessionId: string }> {
const localSessionId = crypto.randomUUID();
const gooseSessionId = await acpPrepareSession(
localSessionId,
providerId,
workingDir,
options,
);
sessionTracker.registerSession(
gooseSessionId,
gooseSessionId,
providerId,
workingDir,
);
if (options.modelId) {
await directAcp.setModel(gooseSessionId, options.modelId);
}
return { sessionId: gooseSessionId };
}
export async function acpSetModel(

View file

@ -455,7 +455,6 @@ function handleShared(sessionId: string, update: SessionUpdate): void {
currentModelId;
const sessionStore = useChatSessionStore.getState();
sessionStore.setSessionModels(sessionId, availableModels);
sessionStore.updateSession(
sessionId,
{ modelId: currentModelId, modelName: currentModelName },

View file

@ -115,8 +115,10 @@ export async function prepareSession(
`[perf:prepare] ${sid} tracker setProvider(${providerId}) in ${(performance.now() - tProv).toFixed(1)}ms (goose_sid=${gooseSid})`,
);
prepared.set(key, { gooseSessionId, providerId, workingDir });
prepared.set(sessionId, { gooseSessionId, providerId, workingDir });
const entry = { gooseSessionId, providerId, workingDir };
prepared.set(key, entry);
prepared.set(sessionId, entry);
prepared.set(gooseSessionId, entry);
gooseToLocal.set(gooseSessionId, sessionId);
notifySessionRegistered(sessionId, gooseSessionId);
@ -161,6 +163,7 @@ export function registerSession(
}
prepared.set(sessionId, entry);
prepared.set(gooseSessionId, entry);
gooseToLocal.set(gooseSessionId, sessionId);
notifySessionRegistered(sessionId, gooseSessionId);

View file

@ -163,7 +163,14 @@
"createProject": "Create project",
"generalChatWithoutProject": "General chat without project context",
"loading": "Loading...",
"loadingModels": "Loading models...",
"model": "Model",
"allModels": "All models",
"searchModels": "Search models...",
"recommended": "Recommended",
"noModelsAvailable": "No models available",
"noSearchResults": "No matching models",
"showAllModels": "Browse all models",
"noProject": "No project",
"selectModel": "Select model",
"selectProject": "Select project",

View file

@ -163,7 +163,14 @@
"createProject": "Crear proyecto",
"generalChatWithoutProject": "Chat general sin contexto de proyecto",
"loading": "Cargando...",
"loadingModels": "Cargando modelos...",
"model": "Modelo",
"allModels": "Todos los modelos",
"searchModels": "Buscar modelos...",
"recommended": "Recomendados",
"noModelsAvailable": "No hay modelos disponibles",
"noSearchResults": "Sin resultados",
"showAllModels": "Ver todos los modelos",
"noProject": "Sin proyecto",
"selectModel": "Seleccionar modelo",
"selectProject": "Seleccionar proyecto",

View file

@ -30,6 +30,137 @@ export function buildInitScript(options?: {
const PERSONAS = ${personas};
const SKILLS = ${skills};
const PROJECTS = ${projects};
const ACP_SESSIONS = [];
function nowIso() {
return new Date().toISOString();
}
function buildSession(sessionId, providerId = "goose") {
return {
sessionId,
title: "New Chat",
updatedAt: nowIso(),
messageCount: 0,
providerId,
modelId: null,
};
}
function findSession(sessionId) {
return ACP_SESSIONS.find((session) => session.sessionId === sessionId) ?? null;
}
function jsonRpcResult(id, result) {
return { jsonrpc: "2.0", id, result };
}
function handleAcpRequest(message) {
switch (message.method) {
case "initialize":
return jsonRpcResult(message.id, {
protocolVersion: "0.1.0",
agentCapabilities: {
loadSession: {},
listSessions: {},
},
agentInfo: {
name: "mock-goose",
version: "0.0.0",
},
authMethods: [],
});
case "session/list":
return jsonRpcResult(message.id, {
sessions: ACP_SESSIONS.map((session) => ({
sessionId: session.sessionId,
title: session.title,
updatedAt: session.updatedAt,
_meta: {
messageCount: session.messageCount,
},
})),
});
case "session/new": {
const providerId = message.params?.meta?.provider ?? "goose";
const sessionId = "session-" + Math.random().toString(36).slice(2, 10);
ACP_SESSIONS.unshift(buildSession(sessionId, providerId));
return jsonRpcResult(message.id, { sessionId });
}
case "session/load":
return jsonRpcResult(message.id, {});
case "session/set_config_option": {
const session = findSession(message.params?.sessionId);
if (session) {
if (message.params?.configId === "provider") {
session.providerId = message.params?.value ?? session.providerId;
session.modelId = null;
}
if (message.params?.configId === "model") {
session.modelId = message.params?.value ?? null;
}
session.updatedAt = nowIso();
}
return jsonRpcResult(message.id, {});
}
case "session/prompt": {
const session = findSession(message.params?.sessionId);
if (session) {
session.messageCount += 1;
session.updatedAt = nowIso();
}
return jsonRpcResult(message.id, { stopReason: "end_turn" });
}
case "_goose/providers/list":
return jsonRpcResult(message.id, { providers: [] });
case "_goose/providers/inventory":
return jsonRpcResult(message.id, { entries: [] });
case "_goose/providers/inventory/refresh":
return jsonRpcResult(message.id, { started: [], skipped: [] });
case "_goose/working_dir/update":
case "goose/working_dir/update":
return jsonRpcResult(message.id, {});
default:
return jsonRpcResult(message.id, {});
}
}
class MockWebSocket extends EventTarget {
constructor(url) {
super();
this.url = url;
this.readyState = 0;
queueMicrotask(() => {
this.readyState = 1;
this.dispatchEvent(new Event("open"));
});
}
send(raw) {
const message = JSON.parse(raw);
const response =
message && typeof message === "object" && "id" in message
? handleAcpRequest(message)
: null;
if (!response) {
return;
}
queueMicrotask(() => {
this.dispatchEvent(
new MessageEvent("message", {
data: JSON.stringify(response),
}),
);
});
}
close() {
this.readyState = 3;
this.dispatchEvent(new CloseEvent("close"));
}
}
window.WebSocket = MockWebSocket;
window.__TAURI_INTERNALS__ = {
invoke(cmd, args) {
@ -95,7 +226,16 @@ export function buildInitScript(options?: {
// ---- Sessions / Misc ----
case "list_sessions":
return Promise.resolve([]);
return Promise.resolve(
ACP_SESSIONS.map((session) => ({
sessionId: session.sessionId,
title: session.title,
updatedAt: session.updatedAt,
messageCount: session.messageCount,
})),
);
case "get_goose_serve_url":
return Promise.resolve("ws://mock-goose");
case "create_session":
return Promise.resolve({
id: "session-" + Math.random().toString(36).slice(2, 10),
@ -130,6 +270,16 @@ export function buildInitScript(options?: {
return Promise.resolve("/tmp/home");
case "path_exists":
return Promise.resolve(false);
case "resolve_path": {
const parts = args?.request?.parts ?? [];
const path = parts
.filter((part) => typeof part === "string" && part.length > 0)
.join("/");
const normalizedPath = path.startsWith("~/")
? "/tmp/home/" + path.slice(2)
: path;
return Promise.resolve({ path: normalizedPath });
}
// ---- Fallback ----
default:

View file

@ -31,8 +31,8 @@ import type {
GetExtensionsResponse,
GetProviderDetailsRequest,
GetProviderDetailsResponse,
GetProviderModelsRequest,
GetProviderModelsResponse,
GetProviderInventoryRequest,
GetProviderInventoryResponse,
GetSessionExtensionsRequest,
GetSessionExtensionsResponse,
GetToolsRequest,
@ -45,6 +45,8 @@ import type {
ReadConfigResponse,
ReadResourceRequest,
ReadResourceResponse,
RefreshProviderInventoryRequest,
RefreshProviderInventoryResponse,
RemoveConfigRequest,
RemoveExtensionRequest,
RemoveSecretRequest,
@ -62,13 +64,14 @@ import {
zExportSessionResponse,
zGetExtensionsResponse,
zGetProviderDetailsResponse,
zGetProviderModelsResponse,
zGetProviderInventoryResponse,
zGetSessionExtensionsResponse,
zGetToolsResponse,
zImportSessionResponse,
zListProvidersResponse,
zReadConfigResponse,
zReadResourceResponse,
zRefreshProviderInventoryResponse,
} from './zod.gen.js';
export class GooseExtClient {
@ -132,11 +135,25 @@ export class GooseExtClient {
return zGetProviderDetailsResponse.parse(raw) as GetProviderDetailsResponse;
}
async GooseProvidersModels(
params: GetProviderModelsRequest,
): Promise<GetProviderModelsResponse> {
const raw = await this.conn.extMethod("_goose/providers/models", params);
return zGetProviderModelsResponse.parse(raw) as GetProviderModelsResponse;
async GooseProvidersInventory(
params: GetProviderInventoryRequest,
): Promise<GetProviderInventoryResponse> {
const raw = await this.conn.extMethod("_goose/providers/inventory", params);
return zGetProviderInventoryResponse.parse(
raw,
) as GetProviderInventoryResponse;
}
async GooseProvidersInventoryRefresh(
params: RefreshProviderInventoryRequest,
): Promise<RefreshProviderInventoryResponse> {
const raw = await this.conn.extMethod(
"_goose/providers/inventory/refresh",
params,
);
return zRefreshProviderInventoryResponse.parse(
raw,
) as RefreshProviderInventoryResponse;
}
async GooseConfigRead(

View file

@ -1,6 +1,6 @@
// This file is auto-generated by @hey-api/openapi-ts
export type { AddExtensionRequest, ArchiveSessionRequest, CheckSecretRequest, CheckSecretResponse, DeleteSessionRequest, DictationConfigRequest, DictationConfigResponse, DictationDownloadProgress, DictationLocalModelStatus, DictationModelCancelRequest, DictationModelDeleteRequest, DictationModelDownloadProgressRequest, DictationModelDownloadProgressResponse, DictationModelDownloadRequest, DictationModelOption, DictationModelSelectRequest, DictationModelsListRequest, DictationModelsListResponse, DictationProviderStatusEntry, DictationTranscribeRequest, DictationTranscribeResponse, EmptyResponse, ExportSessionRequest, ExportSessionResponse, ExtRequest, ExtResponse, GetExtensionsRequest, GetExtensionsResponse, GetProviderDetailsRequest, GetProviderDetailsResponse, GetProviderModelsRequest, GetProviderModelsResponse, GetSessionExtensionsRequest, GetSessionExtensionsResponse, GetToolsRequest, GetToolsResponse, ImportSessionRequest, ImportSessionResponse, ListProvidersRequest, ListProvidersResponse, ModelEntry, ProviderConfigKey, ProviderDetailEntry, ProviderListEntry, ReadConfigRequest, ReadConfigResponse, ReadResourceRequest, ReadResourceResponse, RemoveConfigRequest, RemoveExtensionRequest, RemoveSecretRequest, UnarchiveSessionRequest, UpdateWorkingDirRequest, UpsertConfigRequest, UpsertSecretRequest } from './types.gen.js';
export type { AddExtensionRequest, ArchiveSessionRequest, CheckSecretRequest, CheckSecretResponse, DeleteSessionRequest, DictationConfigRequest, DictationConfigResponse, DictationDownloadProgress, DictationLocalModelStatus, DictationModelCancelRequest, DictationModelDeleteRequest, DictationModelDownloadProgressRequest, DictationModelDownloadProgressResponse, DictationModelDownloadRequest, DictationModelOption, DictationModelSelectRequest, DictationModelsListRequest, DictationModelsListResponse, DictationProviderStatusEntry, DictationTranscribeRequest, DictationTranscribeResponse, EmptyResponse, ExportSessionRequest, ExportSessionResponse, ExtRequest, ExtResponse, GetExtensionsRequest, GetExtensionsResponse, GetProviderDetailsRequest, GetProviderDetailsResponse, GetProviderInventoryRequest, GetProviderInventoryResponse, GetSessionExtensionsRequest, GetSessionExtensionsResponse, GetToolsRequest, GetToolsResponse, ImportSessionRequest, ImportSessionResponse, ListProvidersRequest, ListProvidersResponse, ModelEntry, ProviderConfigKey, ProviderDetailEntry, ProviderInventoryEntryDto, ProviderInventoryModelDto, ProviderListEntry, ReadConfigRequest, ReadConfigResponse, ReadResourceRequest, ReadResourceResponse, RefreshProviderInventoryRequest, RefreshProviderInventoryResponse, RefreshProviderInventorySkipDto, RefreshProviderInventorySkipReasonDto, RemoveConfigRequest, RemoveExtensionRequest, RemoveSecretRequest, UnarchiveSessionRequest, UpdateWorkingDirRequest, UpsertConfigRequest, UpsertSecretRequest } from './types.gen.js';
export const GOOSE_EXT_METHODS = [
{
@ -54,9 +54,14 @@ export const GOOSE_EXT_METHODS = [
responseType: "GetProviderDetailsResponse",
},
{
method: "_goose/providers/models",
requestType: "GetProviderModelsRequest",
responseType: "GetProviderModelsResponse",
method: "_goose/providers/inventory",
requestType: "GetProviderInventoryRequest",
responseType: "GetProviderInventoryResponse",
},
{
method: "_goose/providers/inventory/refresh",
requestType: "RefreshProviderInventoryRequest",
responseType: "RefreshProviderInventoryResponse",
},
{
method: "_goose/config/read",

View file

@ -165,19 +165,133 @@ export type ModelEntry = {
};
/**
* Fetch the full list of models available for a specific provider.
* Read per-provider inventory. Always returns immediately from stored state.
*/
export type GetProviderModelsRequest = {
providerName: string;
export type GetProviderInventoryRequest = {
/**
* Only return entries for these providers. Empty means all.
*/
providerIds?: Array<string>;
};
/**
* Provider models response.
* Provider inventory response.
*/
export type GetProviderModelsResponse = {
models: Array<string>;
export type GetProviderInventoryResponse = {
entries: Array<ProviderInventoryEntryDto>;
};
/**
* Provider inventory entry.
*/
export type ProviderInventoryEntryDto = {
/**
* Provider identifier.
*/
providerId: string;
/**
* Human-readable provider name.
*/
providerName: string;
/**
* Whether Goose has enough configuration to use this provider.
*/
configured: boolean;
/**
* Whether this provider supports background inventory refresh.
*/
supportsRefresh: boolean;
/**
* Whether a refresh is currently in flight.
*/
refreshing: boolean;
/**
* The list of available models.
*/
models: Array<ProviderInventoryModelDto>;
/**
* When this entry was last successfully refreshed (ISO 8601).
*/
lastUpdatedAt?: string | null;
/**
* When a refresh was most recently attempted (ISO 8601).
*/
lastRefreshAttemptAt?: string | null;
/**
* The last refresh failure message, if any.
*/
lastRefreshError?: string | null;
/**
* Whether we believe this data may be outdated.
*/
stale: boolean;
/**
* Guidance message shown when this provider manages its own model selection externally.
*/
modelSelectionHint?: string | null;
};
/**
* A single model in provider inventory.
*/
export type ProviderInventoryModelDto = {
/**
* Model identifier as the provider knows it.
*/
id: string;
/**
* Human-readable display name.
*/
name: string;
/**
* Model family for grouping in UI.
*/
family?: string | null;
/**
* Context window size in tokens.
*/
contextLimit?: number | null;
/**
* Whether the model supports reasoning/extended thinking.
*/
reasoning?: boolean | null;
/**
* Whether this model should appear in the compact recommended picker.
*/
recommended?: boolean;
};
/**
* Trigger a background refresh of provider inventories.
*/
export type RefreshProviderInventoryRequest = {
/**
* Which providers to refresh. Empty means all known providers.
*/
providerIds?: Array<string>;
};
/**
* Refresh acknowledgement.
*/
export type RefreshProviderInventoryResponse = {
/**
* Which providers will be refreshed.
*/
started: Array<string>;
/**
* Which providers were skipped and why.
*/
skipped?: Array<RefreshProviderInventorySkipDto>;
};
export type RefreshProviderInventorySkipDto = {
providerId: string;
reason: RefreshProviderInventorySkipReasonDto;
};
export type RefreshProviderInventorySkipReasonDto = 'unknown_provider' | 'not_configured' | 'does_not_support_refresh' | 'already_refreshing';
/**
* Read a single non-secret config value.
*/
@ -421,14 +535,14 @@ export type DictationModelSelectRequest = {
export type ExtRequest = {
id: string;
method: string;
params?: AddExtensionRequest | RemoveExtensionRequest | GetToolsRequest | ReadResourceRequest | UpdateWorkingDirRequest | DeleteSessionRequest | GetExtensionsRequest | GetSessionExtensionsRequest | ListProvidersRequest | GetProviderDetailsRequest | GetProviderModelsRequest | ReadConfigRequest | UpsertConfigRequest | RemoveConfigRequest | CheckSecretRequest | UpsertSecretRequest | RemoveSecretRequest | ExportSessionRequest | ImportSessionRequest | ArchiveSessionRequest | UnarchiveSessionRequest | DictationTranscribeRequest | DictationConfigRequest | DictationModelsListRequest | DictationModelDownloadRequest | DictationModelDownloadProgressRequest | DictationModelCancelRequest | DictationModelDeleteRequest | DictationModelSelectRequest | {
params?: AddExtensionRequest | RemoveExtensionRequest | GetToolsRequest | ReadResourceRequest | UpdateWorkingDirRequest | DeleteSessionRequest | GetExtensionsRequest | GetSessionExtensionsRequest | ListProvidersRequest | GetProviderDetailsRequest | GetProviderInventoryRequest | RefreshProviderInventoryRequest | ReadConfigRequest | UpsertConfigRequest | RemoveConfigRequest | CheckSecretRequest | UpsertSecretRequest | RemoveSecretRequest | ExportSessionRequest | ImportSessionRequest | ArchiveSessionRequest | UnarchiveSessionRequest | DictationTranscribeRequest | DictationConfigRequest | DictationModelsListRequest | DictationModelDownloadRequest | DictationModelDownloadProgressRequest | DictationModelCancelRequest | DictationModelDeleteRequest | DictationModelSelectRequest | {
[key: string]: unknown;
} | null;
};
export type ExtResponse = {
id: string;
result?: EmptyResponse | GetToolsResponse | ReadResourceResponse | GetExtensionsResponse | GetSessionExtensionsResponse | ListProvidersResponse | GetProviderDetailsResponse | GetProviderModelsResponse | ReadConfigResponse | CheckSecretResponse | ExportSessionResponse | ImportSessionResponse | DictationTranscribeResponse | DictationConfigResponse | DictationModelsListResponse | DictationModelDownloadProgressResponse | unknown;
result?: EmptyResponse | GetToolsResponse | ReadResourceResponse | GetExtensionsResponse | GetSessionExtensionsResponse | ListProvidersResponse | GetProviderDetailsResponse | GetProviderInventoryResponse | RefreshProviderInventoryResponse | ReadConfigResponse | CheckSecretResponse | ExportSessionResponse | ImportSessionResponse | DictationTranscribeResponse | DictationConfigResponse | DictationModelsListResponse | DictationModelDownloadProgressResponse | unknown;
} | {
error: {
code: number;

View file

@ -149,17 +149,94 @@ export const zGetProviderDetailsResponse = z.object({
});
/**
* Fetch the full list of models available for a specific provider.
* Read per-provider inventory. Always returns immediately from stored state.
*/
export const zGetProviderModelsRequest = z.object({
providerName: z.string()
export const zGetProviderInventoryRequest = z.object({
providerIds: z.array(z.string()).optional().default([])
});
/**
* Provider models response.
* A single model in provider inventory.
*/
export const zGetProviderModelsResponse = z.object({
models: z.array(z.string())
export const zProviderInventoryModelDto = z.object({
id: z.string(),
name: z.string(),
family: z.union([
z.string(),
z.null()
]).optional(),
contextLimit: z.union([
z.number().int().gte(0),
z.null()
]).optional(),
reasoning: z.union([
z.boolean(),
z.null()
]).optional(),
recommended: z.boolean().optional().default(false)
});
/**
* Provider inventory entry.
*/
export const zProviderInventoryEntryDto = z.object({
providerId: z.string(),
providerName: z.string(),
configured: z.boolean(),
supportsRefresh: z.boolean(),
refreshing: z.boolean(),
models: z.array(zProviderInventoryModelDto),
lastUpdatedAt: z.union([
z.string(),
z.null()
]).optional(),
lastRefreshAttemptAt: z.union([
z.string(),
z.null()
]).optional(),
lastRefreshError: z.union([
z.string(),
z.null()
]).optional(),
stale: z.boolean(),
modelSelectionHint: z.union([
z.string(),
z.null()
]).optional()
});
/**
* Provider inventory response.
*/
export const zGetProviderInventoryResponse = z.object({
entries: z.array(zProviderInventoryEntryDto)
});
/**
* Trigger a background refresh of provider inventories.
*/
export const zRefreshProviderInventoryRequest = z.object({
providerIds: z.array(z.string()).optional().default([])
});
export const zRefreshProviderInventorySkipReasonDto = z.enum([
'unknown_provider',
'not_configured',
'does_not_support_refresh',
'already_refreshing'
]);
export const zRefreshProviderInventorySkipDto = z.object({
providerId: z.string(),
reason: zRefreshProviderInventorySkipReasonDto
});
/**
* Refresh acknowledgement.
*/
export const zRefreshProviderInventoryResponse = z.object({
started: z.array(z.string()),
skipped: z.array(zRefreshProviderInventorySkipDto).optional().default([])
});
/**
@ -426,7 +503,8 @@ export const zExtRequest = z.object({
zGetSessionExtensionsRequest,
zListProvidersRequest,
zGetProviderDetailsRequest,
zGetProviderModelsRequest,
zGetProviderInventoryRequest,
zRefreshProviderInventoryRequest,
zReadConfigRequest,
zUpsertConfigRequest,
zRemoveConfigRequest,
@ -465,7 +543,8 @@ export const zExtResponse = z.union([
zGetSessionExtensionsResponse,
zListProvidersResponse,
zGetProviderDetailsResponse,
zGetProviderModelsResponse,
zGetProviderInventoryResponse,
zRefreshProviderInventoryResponse,
zReadConfigResponse,
zCheckSecretResponse,
zExportSessionResponse,