mirror of
https://github.com/block/goose.git
synced 2026-04-28 03:29:36 +00:00
overhaul provider inventory and agent/model selection (#8652)
Some checks are pending
Canary / Prepare Version (push) Waiting to run
Canary / build-cli (push) Blocked by required conditions
Canary / Upload Install Script (push) Blocked by required conditions
Canary / bundle-desktop (push) Blocked by required conditions
Canary / bundle-desktop-intel (push) Blocked by required conditions
Canary / bundle-desktop-linux (push) Blocked by required conditions
Canary / bundle-desktop-windows (push) Blocked by required conditions
Canary / Release (push) Blocked by required conditions
Cargo Deny / deny (push) Waiting to run
Unused Dependencies / machete (push) Waiting to run
CI / changes (push) Waiting to run
CI / Check Rust Code Format (push) Blocked by required conditions
CI / Build and Test Rust Project (push) Blocked by required conditions
CI / Build Rust Project on Windows (push) Waiting to run
CI / Lint Rust Code (push) Blocked by required conditions
CI / Check Generated Schemas are Up-to-Date (push) Blocked by required conditions
CI / Test and Lint Electron Desktop App (push) Blocked by required conditions
Goose 2 CI / Lint & Format (push) Waiting to run
Goose 2 CI / Unit Tests (push) Waiting to run
Goose 2 CI / Desktop Build & E2E (push) Waiting to run
Goose 2 CI / Rust Lint (push) Waiting to run
Live Provider Tests / check-fork (push) Waiting to run
Live Provider Tests / changes (push) Blocked by required conditions
Live Provider Tests / Build Binary (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (Code Execution) (push) Blocked by required conditions
Live Provider Tests / Compaction Tests (push) Blocked by required conditions
Live Provider Tests / goose server HTTP integration tests (push) Blocked by required conditions
Publish Docker Image / docker (push) Waiting to run
Scorecard supply-chain security / Scorecard analysis (push) Waiting to run
Some checks are pending
Canary / Prepare Version (push) Waiting to run
Canary / build-cli (push) Blocked by required conditions
Canary / Upload Install Script (push) Blocked by required conditions
Canary / bundle-desktop (push) Blocked by required conditions
Canary / bundle-desktop-intel (push) Blocked by required conditions
Canary / bundle-desktop-linux (push) Blocked by required conditions
Canary / bundle-desktop-windows (push) Blocked by required conditions
Canary / Release (push) Blocked by required conditions
Cargo Deny / deny (push) Waiting to run
Unused Dependencies / machete (push) Waiting to run
CI / changes (push) Waiting to run
CI / Check Rust Code Format (push) Blocked by required conditions
CI / Build and Test Rust Project (push) Blocked by required conditions
CI / Build Rust Project on Windows (push) Waiting to run
CI / Lint Rust Code (push) Blocked by required conditions
CI / Check Generated Schemas are Up-to-Date (push) Blocked by required conditions
CI / Test and Lint Electron Desktop App (push) Blocked by required conditions
Goose 2 CI / Lint & Format (push) Waiting to run
Goose 2 CI / Unit Tests (push) Waiting to run
Goose 2 CI / Desktop Build & E2E (push) Waiting to run
Goose 2 CI / Rust Lint (push) Waiting to run
Live Provider Tests / check-fork (push) Waiting to run
Live Provider Tests / changes (push) Blocked by required conditions
Live Provider Tests / Build Binary (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (push) Blocked by required conditions
Live Provider Tests / Smoke Tests (Code Execution) (push) Blocked by required conditions
Live Provider Tests / Compaction Tests (push) Blocked by required conditions
Live Provider Tests / goose server HTTP integration tests (push) Blocked by required conditions
Publish Docker Image / docker (push) Waiting to run
Scorecard supply-chain security / Scorecard analysis (push) Waiting to run
Signed-off-by: Bradley Axen <baxen@squareup.com>
This commit is contained in:
parent
3d582943fd
commit
8eda6fdabc
70 changed files with 5321 additions and 2123 deletions
|
|
@ -51,9 +51,14 @@
|
|||
"responseType": "GetProviderDetailsResponse"
|
||||
},
|
||||
{
|
||||
"method": "_goose/providers/models",
|
||||
"requestType": "GetProviderModelsRequest",
|
||||
"responseType": "GetProviderModelsResponse"
|
||||
"method": "_goose/providers/inventory",
|
||||
"requestType": "GetProviderInventoryRequest",
|
||||
"responseType": "GetProviderInventoryResponse"
|
||||
},
|
||||
{
|
||||
"method": "_goose/providers/inventory/refresh",
|
||||
"requestType": "RefreshProviderInventoryRequest",
|
||||
"responseType": "RefreshProviderInventoryResponse"
|
||||
},
|
||||
{
|
||||
"method": "_goose/config/read",
|
||||
|
|
|
|||
|
|
@ -362,36 +362,224 @@
|
|||
"contextLimit"
|
||||
]
|
||||
},
|
||||
"GetProviderModelsRequest": {
|
||||
"GetProviderInventoryRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"providerName": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"providerName"
|
||||
],
|
||||
"description": "Fetch the full list of models available for a specific provider.",
|
||||
"x-side": "agent",
|
||||
"x-method": "_goose/providers/models"
|
||||
},
|
||||
"GetProviderModelsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"models": {
|
||||
"providerIds": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Only return entries for these providers. Empty means all.",
|
||||
"default": []
|
||||
}
|
||||
},
|
||||
"description": "Read per-provider inventory. Always returns immediately from stored state.",
|
||||
"x-side": "agent",
|
||||
"x-method": "_goose/providers/inventory"
|
||||
},
|
||||
"GetProviderInventoryResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entries": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/$defs/ProviderInventoryEntryDto"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"models"
|
||||
"entries"
|
||||
],
|
||||
"description": "Provider models response.",
|
||||
"description": "Provider inventory response.",
|
||||
"x-side": "agent",
|
||||
"x-method": "_goose/providers/models"
|
||||
"x-method": "_goose/providers/inventory"
|
||||
},
|
||||
"ProviderInventoryEntryDto": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"providerId": {
|
||||
"type": "string",
|
||||
"description": "Provider identifier."
|
||||
},
|
||||
"providerName": {
|
||||
"type": "string",
|
||||
"description": "Human-readable provider name."
|
||||
},
|
||||
"configured": {
|
||||
"type": "boolean",
|
||||
"description": "Whether Goose has enough configuration to use this provider."
|
||||
},
|
||||
"supportsRefresh": {
|
||||
"type": "boolean",
|
||||
"description": "Whether this provider supports background inventory refresh."
|
||||
},
|
||||
"refreshing": {
|
||||
"type": "boolean",
|
||||
"description": "Whether a refresh is currently in flight."
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/$defs/ProviderInventoryModelDto"
|
||||
},
|
||||
"description": "The list of available models."
|
||||
},
|
||||
"lastUpdatedAt": {
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
],
|
||||
"description": "When this entry was last successfully refreshed (ISO 8601)."
|
||||
},
|
||||
"lastRefreshAttemptAt": {
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
],
|
||||
"description": "When a refresh was most recently attempted (ISO 8601)."
|
||||
},
|
||||
"lastRefreshError": {
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
],
|
||||
"description": "The last refresh failure message, if any."
|
||||
},
|
||||
"stale": {
|
||||
"type": "boolean",
|
||||
"description": "Whether we believe this data may be outdated."
|
||||
},
|
||||
"modelSelectionHint": {
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
],
|
||||
"description": "Guidance message shown when this provider manages its own model selection externally."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"providerId",
|
||||
"providerName",
|
||||
"configured",
|
||||
"supportsRefresh",
|
||||
"refreshing",
|
||||
"models",
|
||||
"stale"
|
||||
],
|
||||
"description": "Provider inventory entry."
|
||||
},
|
||||
"ProviderInventoryModelDto": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "Model identifier as the provider knows it."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Human-readable display name."
|
||||
},
|
||||
"family": {
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
],
|
||||
"description": "Model family for grouping in UI."
|
||||
},
|
||||
"contextLimit": {
|
||||
"type": [
|
||||
"integer",
|
||||
"null"
|
||||
],
|
||||
"format": "uint",
|
||||
"minimum": 0,
|
||||
"description": "Context window size in tokens."
|
||||
},
|
||||
"reasoning": {
|
||||
"type": [
|
||||
"boolean",
|
||||
"null"
|
||||
],
|
||||
"description": "Whether the model supports reasoning/extended thinking."
|
||||
},
|
||||
"recommended": {
|
||||
"type": "boolean",
|
||||
"description": "Whether this model should appear in the compact recommended picker.",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"id",
|
||||
"name"
|
||||
],
|
||||
"description": "A single model in provider inventory."
|
||||
},
|
||||
"RefreshProviderInventoryRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"providerIds": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Which providers to refresh. Empty means all known providers.",
|
||||
"default": []
|
||||
}
|
||||
},
|
||||
"description": "Trigger a background refresh of provider inventories.",
|
||||
"x-side": "agent",
|
||||
"x-method": "_goose/providers/inventory/refresh"
|
||||
},
|
||||
"RefreshProviderInventoryResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"started": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Which providers will be refreshed."
|
||||
},
|
||||
"skipped": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/$defs/RefreshProviderInventorySkipDto"
|
||||
},
|
||||
"description": "Which providers were skipped and why.",
|
||||
"default": []
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"started"
|
||||
],
|
||||
"description": "Refresh acknowledgement.",
|
||||
"x-side": "agent",
|
||||
"x-method": "_goose/providers/inventory/refresh"
|
||||
},
|
||||
"RefreshProviderInventorySkipDto": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"providerId": {
|
||||
"type": "string"
|
||||
},
|
||||
"reason": {
|
||||
"$ref": "#/$defs/RefreshProviderInventorySkipReasonDto"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"providerId",
|
||||
"reason"
|
||||
]
|
||||
},
|
||||
"RefreshProviderInventorySkipReasonDto": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"unknown_provider",
|
||||
"not_configured",
|
||||
"does_not_support_refresh",
|
||||
"already_refreshing"
|
||||
]
|
||||
},
|
||||
"ReadConfigRequest": {
|
||||
"type": "object",
|
||||
|
|
@ -1035,11 +1223,20 @@
|
|||
{
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/$defs/GetProviderModelsRequest"
|
||||
"$ref": "#/$defs/GetProviderInventoryRequest"
|
||||
}
|
||||
],
|
||||
"description": "Params for _goose/providers/models",
|
||||
"title": "GetProviderModelsRequest"
|
||||
"description": "Params for _goose/providers/inventory",
|
||||
"title": "GetProviderInventoryRequest"
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/$defs/RefreshProviderInventoryRequest"
|
||||
}
|
||||
],
|
||||
"description": "Params for _goose/providers/inventory/refresh",
|
||||
"title": "RefreshProviderInventoryRequest"
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
|
|
@ -1292,10 +1489,18 @@
|
|||
{
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/$defs/GetProviderModelsResponse"
|
||||
"$ref": "#/$defs/GetProviderInventoryResponse"
|
||||
}
|
||||
],
|
||||
"title": "GetProviderModelsResponse"
|
||||
"title": "GetProviderInventoryResponse"
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/$defs/RefreshProviderInventoryResponse"
|
||||
}
|
||||
],
|
||||
"title": "RefreshProviderInventoryResponse"
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ use goose::mcp_utils::ToolResult;
|
|||
use goose::permission::permission_confirmation::PrincipalType;
|
||||
use goose::permission::{Permission, PermissionConfirmation};
|
||||
use goose::providers::base::Provider;
|
||||
use goose::providers::inventory::{
|
||||
ProviderInventoryEntry, ProviderInventoryService, RefreshSkipReason,
|
||||
};
|
||||
use goose::session::session_manager::SessionType;
|
||||
use goose::session::{EnabledExtensionsState, Session, SessionManager};
|
||||
use goose_acp_macros::custom_methods;
|
||||
|
|
@ -125,6 +128,8 @@ struct AgentSetupRequest {
|
|||
/// Pre-resolved provider name + model config (from config, no network).
|
||||
/// When present the spawn skips re-deriving these from config.
|
||||
resolved_provider: Option<(String, goose::model::ModelConfig)>,
|
||||
/// Pre-instantiated provider reused from synchronous session initialization.
|
||||
prebuilt_provider: Option<Arc<dyn Provider>>,
|
||||
}
|
||||
|
||||
pub struct GooseAcpAgent {
|
||||
|
|
@ -139,6 +144,7 @@ pub struct GooseAcpAgent {
|
|||
permission_manager: Arc<PermissionManager>,
|
||||
goose_mode: GooseMode,
|
||||
disable_session_naming: bool,
|
||||
provider_inventory: ProviderInventoryService,
|
||||
}
|
||||
|
||||
/// Shorten a session/thread id for perf log correlation.
|
||||
|
|
@ -415,19 +421,50 @@ fn builtin_to_extension_config(name: &str) -> ExtensionConfig {
|
|||
}
|
||||
}
|
||||
|
||||
async fn build_model_state(provider: &dyn Provider) -> Result<SessionModelState, sacp::Error> {
|
||||
let models = provider
|
||||
.fetch_recommended_models()
|
||||
.await
|
||||
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
|
||||
let current_model = &provider.get_model_config().model_name;
|
||||
Ok(SessionModelState::new(
|
||||
ModelId::new(current_model.as_str()),
|
||||
models
|
||||
.iter()
|
||||
.map(|name| ModelInfo::new(ModelId::new(&**name), &**name))
|
||||
fn inventory_entry_to_dto(entry: ProviderInventoryEntry) -> ProviderInventoryEntryDto {
|
||||
let stale = ProviderInventoryService::is_stale(&entry);
|
||||
ProviderInventoryEntryDto {
|
||||
provider_id: entry.provider_id,
|
||||
provider_name: entry.provider_name,
|
||||
configured: entry.configured,
|
||||
supports_refresh: entry.supports_refresh,
|
||||
refreshing: entry.refreshing,
|
||||
models: entry
|
||||
.models
|
||||
.into_iter()
|
||||
.map(|m| ProviderInventoryModelDto {
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
family: m.family,
|
||||
context_limit: m.context_limit,
|
||||
reasoning: m.reasoning,
|
||||
recommended: m.recommended,
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
last_updated_at: entry.last_updated_at.map(|t| t.to_rfc3339()),
|
||||
last_refresh_attempt_at: entry.last_refresh_attempt_at.map(|t| t.to_rfc3339()),
|
||||
last_refresh_error: entry.last_refresh_error,
|
||||
stale,
|
||||
model_selection_hint: entry.model_selection_hint,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_model_state(current_model: &str, inventory: &ProviderInventoryEntry) -> SessionModelState {
|
||||
let mut available_models = inventory
|
||||
.models
|
||||
.iter()
|
||||
.map(|model| ModelInfo::new(ModelId::new(model.id.as_str()), model.name.as_str()))
|
||||
.collect::<Vec<_>>();
|
||||
if !available_models
|
||||
.iter()
|
||||
.any(|model| model.model_id.0.as_ref() == current_model)
|
||||
{
|
||||
available_models.insert(
|
||||
0,
|
||||
ModelInfo::new(ModelId::new(current_model), current_model),
|
||||
);
|
||||
}
|
||||
SessionModelState::new(ModelId::new(current_model), available_models)
|
||||
}
|
||||
|
||||
async fn list_provider_entries(current_provider: Option<&str>) -> Vec<ProviderListEntry> {
|
||||
|
|
@ -546,31 +583,25 @@ fn build_mode_state(current_mode: GooseMode) -> Result<SessionModeState, sacp::E
|
|||
))
|
||||
}
|
||||
|
||||
/// Build model state and config options eagerly from the canonical registry.
|
||||
///
|
||||
/// TODO: This trades speed for correctness — the canonical registry may not perfectly
|
||||
/// match what the provider API returns (new models not yet in the registry, deprecated
|
||||
/// models still listed, or locally-installed models for providers like Ollama). Consider
|
||||
/// whether to reconcile with a live API call in the background.
|
||||
async fn build_eager_config(
|
||||
resolved: &Result<(String, goose::model::ModelConfig), String>,
|
||||
fn should_refresh_inventory_for_session_init(entry: &ProviderInventoryEntry) -> bool {
|
||||
entry.configured
|
||||
&& entry.supports_refresh
|
||||
&& (entry.last_updated_at.is_none() || ProviderInventoryService::is_stale(entry))
|
||||
}
|
||||
|
||||
async fn build_eager_config_from_inventory(
|
||||
provider_name: &str,
|
||||
current_model: &str,
|
||||
inventory: &ProviderInventoryEntry,
|
||||
mode_state: &SessionModeState,
|
||||
goose_session: &Session,
|
||||
) -> (Option<SessionModelState>, Option<Vec<SessionConfigOption>>) {
|
||||
let Ok((ref provider_name, ref mc)) = resolved else {
|
||||
return (None, None);
|
||||
};
|
||||
let recommended = goose::providers::canonical::recommended_models_from_registry(provider_name);
|
||||
let available: Vec<ModelInfo> = recommended
|
||||
.iter()
|
||||
.map(|name| ModelInfo::new(ModelId::new(&**name), &**name))
|
||||
.collect();
|
||||
let ms = SessionModelState::new(ModelId::new(mc.model_name.as_str()), available);
|
||||
) -> (SessionModelState, Vec<SessionConfigOption>) {
|
||||
let ms = build_model_state(current_model, inventory);
|
||||
let provider_selection = session_provider_selection(goose_session);
|
||||
let provider_options = build_provider_options(Some(provider_name.as_str())).await;
|
||||
let provider_options = build_provider_options(Some(provider_name)).await;
|
||||
let config_options =
|
||||
build_config_options(mode_state, &ms, provider_selection, provider_options);
|
||||
(Some(ms), Some(config_options))
|
||||
(ms, config_options)
|
||||
}
|
||||
|
||||
fn build_config_options(
|
||||
|
|
@ -651,6 +682,7 @@ impl GooseAcpAgent {
|
|||
session_manager.storage().clone(),
|
||||
));
|
||||
let permission_manager = Arc::new(PermissionManager::new(config_dir.clone()));
|
||||
let provider_inventory = ProviderInventoryService::new(session_manager.storage().clone());
|
||||
|
||||
Ok(Self {
|
||||
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
|
|
@ -664,6 +696,7 @@ impl GooseAcpAgent {
|
|||
permission_manager,
|
||||
goose_mode,
|
||||
disable_session_naming,
|
||||
provider_inventory,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -680,6 +713,125 @@ impl GooseAcpAgent {
|
|||
(self.provider_factory)(provider_name.to_string(), model_config, extensions).await
|
||||
}
|
||||
|
||||
async fn prepare_session_init_config(
|
||||
&self,
|
||||
resolved: &Result<(String, goose::model::ModelConfig), String>,
|
||||
mode_state: &SessionModeState,
|
||||
goose_session: &Session,
|
||||
) -> (
|
||||
Option<SessionModelState>,
|
||||
Option<Vec<SessionConfigOption>>,
|
||||
Option<Arc<dyn Provider>>,
|
||||
) {
|
||||
let Ok((provider_name, model_config)) = resolved else {
|
||||
return (None, None, None);
|
||||
};
|
||||
|
||||
let Some(mut inventory) = self
|
||||
.provider_inventory
|
||||
.entry_for_provider(provider_name)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
else {
|
||||
return (None, None, None);
|
||||
};
|
||||
|
||||
let mut prebuilt_provider = None;
|
||||
if should_refresh_inventory_for_session_init(&inventory) {
|
||||
match self.load_config() {
|
||||
Ok(config) => {
|
||||
let ext_state = EnabledExtensionsState::extensions_or_default(
|
||||
Some(&goose_session.extension_data),
|
||||
&config,
|
||||
);
|
||||
match self
|
||||
.create_provider(provider_name, model_config.clone(), ext_state)
|
||||
.await
|
||||
{
|
||||
Ok(provider) => {
|
||||
let provider_id = provider_name.clone();
|
||||
prebuilt_provider = Some(provider.clone());
|
||||
match self
|
||||
.provider_inventory
|
||||
.plan_refresh(std::slice::from_ref(&provider_id))
|
||||
.await
|
||||
{
|
||||
Ok(plan) if plan.started.iter().any(|id| id == &provider_id) => {
|
||||
match provider.fetch_recommended_models().await {
|
||||
Ok(models) => {
|
||||
if let Err(error) = self
|
||||
.provider_inventory
|
||||
.store_refreshed_models(&provider_id, &models)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
provider = %provider_id,
|
||||
error = %error,
|
||||
"failed to store refreshed provider inventory during session init"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
if let Err(store_error) = self
|
||||
.provider_inventory
|
||||
.store_refresh_error(
|
||||
&provider_id,
|
||||
error.to_string(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!(
|
||||
provider = %provider_id,
|
||||
error = %store_error,
|
||||
"failed to store provider inventory refresh error during session init"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(error) => warn!(
|
||||
provider = %provider_id,
|
||||
error = %error,
|
||||
"failed to plan provider inventory refresh during session init"
|
||||
),
|
||||
}
|
||||
|
||||
if let Ok(Some(refreshed_inventory)) = self
|
||||
.provider_inventory
|
||||
.entry_for_provider(provider_name)
|
||||
.await
|
||||
{
|
||||
inventory = refreshed_inventory;
|
||||
}
|
||||
}
|
||||
Err(error) => warn!(
|
||||
provider = %provider_name,
|
||||
error = %error,
|
||||
"failed to initialize provider during synchronous inventory refresh"
|
||||
),
|
||||
}
|
||||
}
|
||||
Err(error) => warn!(
|
||||
provider = %provider_name,
|
||||
error = %error,
|
||||
"failed to load config during synchronous inventory refresh"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
let (model_state, config_options) = build_eager_config_from_inventory(
|
||||
provider_name,
|
||||
model_config.model_name.as_str(),
|
||||
&inventory,
|
||||
mode_state,
|
||||
goose_session,
|
||||
)
|
||||
.await;
|
||||
(Some(model_state), Some(config_options), prebuilt_provider)
|
||||
}
|
||||
|
||||
fn spawn_agent_setup(
|
||||
&self,
|
||||
cx: &ConnectionTo<Client>,
|
||||
|
|
@ -691,6 +843,7 @@ impl GooseAcpAgent {
|
|||
goose_session,
|
||||
mcp_servers,
|
||||
resolved_provider,
|
||||
prebuilt_provider,
|
||||
} = req;
|
||||
|
||||
let goose_mode = goose_session.goose_mode;
|
||||
|
|
@ -845,9 +998,12 @@ impl GooseAcpAgent {
|
|||
Some(&goose_session.extension_data),
|
||||
&config,
|
||||
);
|
||||
let provider = provider_factory(provider_name.to_string(), model_config, ext_state)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let provider = match prebuilt_provider {
|
||||
Some(provider) => provider,
|
||||
None => provider_factory(provider_name.to_string(), model_config, ext_state)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?,
|
||||
};
|
||||
agent
|
||||
.update_provider(provider.clone(), &goose_session.id)
|
||||
.await
|
||||
|
|
@ -1416,9 +1572,10 @@ impl GooseAcpAgent {
|
|||
.as_ref()
|
||||
.ok()
|
||||
.map(|(_, mc)| build_usage_update(&goose_session, mc.context_limit()));
|
||||
let (model_state, config_options) =
|
||||
build_eager_config(&resolved, &mode_state, &goose_session).await;
|
||||
let session_id = SessionId::new(thread_id.clone());
|
||||
let (model_state, config_options, prebuilt_provider) = self
|
||||
.prepare_session_init_config(&resolved, &mode_state, &goose_session)
|
||||
.await;
|
||||
|
||||
self.spawn_agent_setup(
|
||||
cx,
|
||||
|
|
@ -1428,6 +1585,7 @@ impl GooseAcpAgent {
|
|||
goose_session,
|
||||
mcp_servers: args.mcp_servers,
|
||||
resolved_provider: resolved.as_ref().ok().cloned(),
|
||||
prebuilt_provider,
|
||||
},
|
||||
);
|
||||
|
||||
|
|
@ -1798,8 +1956,9 @@ impl GooseAcpAgent {
|
|||
.as_ref()
|
||||
.map(|mc| build_usage_update(&goose_session, mc.context_limit()))
|
||||
});
|
||||
let (model_state, config_options) =
|
||||
build_eager_config(&resolved, &mode_state, &goose_session).await;
|
||||
let (model_state, config_options, prebuilt_provider) = self
|
||||
.prepare_session_init_config(&resolved, &mode_state, &goose_session)
|
||||
.await;
|
||||
|
||||
self.spawn_agent_setup(
|
||||
cx,
|
||||
|
|
@ -1809,6 +1968,7 @@ impl GooseAcpAgent {
|
|||
goose_session,
|
||||
mcp_servers: args.mcp_servers,
|
||||
resolved_provider: None,
|
||||
prebuilt_provider,
|
||||
},
|
||||
);
|
||||
|
||||
|
|
@ -2116,10 +2276,21 @@ impl GooseAcpAgent {
|
|||
let provider = agent.provider().await.map_err(|e| {
|
||||
sacp::Error::internal_error().data(format!("Failed to get provider: {}", e))
|
||||
})?;
|
||||
let provider_name = provider.get_name().to_string();
|
||||
let current_model = provider.get_model_config().model_name.clone();
|
||||
let goose_mode = agent.goose_mode().await;
|
||||
let model_state = build_model_state(&*provider).await?;
|
||||
let inventory = self
|
||||
.provider_inventory
|
||||
.entry_for_provider(&provider_name)
|
||||
.await
|
||||
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
|
||||
let Some(inventory) = inventory else {
|
||||
return Err(sacp::Error::internal_error()
|
||||
.data(format!("Unknown provider inventory: {}", provider_name)));
|
||||
};
|
||||
let model_state = build_model_state(current_model.as_str(), &inventory);
|
||||
let mode_state = build_mode_state(goose_mode)?;
|
||||
let provider_options = build_provider_options(Some(provider.get_name())).await;
|
||||
let provider_options = build_provider_options(Some(&provider_name)).await;
|
||||
let config_options = build_config_options(
|
||||
&mode_state,
|
||||
&model_state,
|
||||
|
|
@ -2399,8 +2570,9 @@ impl GooseAcpAgent {
|
|||
|
||||
let mode_state = build_mode_state(self.goose_mode)?;
|
||||
let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await;
|
||||
let (model_state, config_options) =
|
||||
build_eager_config(&resolved, &mode_state, &goose_session).await;
|
||||
let (model_state, config_options, prebuilt_provider) = self
|
||||
.prepare_session_init_config(&resolved, &mode_state, &goose_session)
|
||||
.await;
|
||||
|
||||
self.spawn_agent_setup(
|
||||
cx,
|
||||
|
|
@ -2410,6 +2582,7 @@ impl GooseAcpAgent {
|
|||
goose_session,
|
||||
mcp_servers: args.mcp_servers,
|
||||
resolved_provider: resolved.ok(),
|
||||
prebuilt_provider,
|
||||
},
|
||||
);
|
||||
|
||||
|
|
@ -2677,58 +2850,82 @@ impl GooseAcpAgent {
|
|||
Ok(GetProviderDetailsResponse { providers: entries })
|
||||
}
|
||||
|
||||
#[custom_method(GetProviderModelsRequest)]
|
||||
async fn on_get_provider_models(
|
||||
#[custom_method(GetProviderInventoryRequest)]
|
||||
async fn on_get_provider_inventory(
|
||||
&self,
|
||||
req: GetProviderModelsRequest,
|
||||
) -> Result<GetProviderModelsResponse, sacp::Error> {
|
||||
let config = self.load_config().ok();
|
||||
let all = goose::providers::providers().await;
|
||||
req: GetProviderInventoryRequest,
|
||||
) -> Result<GetProviderInventoryResponse, sacp::Error> {
|
||||
let entries = self
|
||||
.provider_inventory
|
||||
.entries(&req.provider_ids)
|
||||
.await
|
||||
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
|
||||
Ok(GetProviderInventoryResponse {
|
||||
entries: entries.into_iter().map(inventory_entry_to_dto).collect(),
|
||||
})
|
||||
}
|
||||
|
||||
let Some((metadata, _provider_type)) =
|
||||
all.into_iter().find(|(m, _)| m.name == req.provider_name)
|
||||
else {
|
||||
return Err(sacp::Error::invalid_params()
|
||||
.data(format!("Unknown provider: {}", req.provider_name)));
|
||||
};
|
||||
|
||||
let is_configured = config
|
||||
.as_ref()
|
||||
.map(|c| {
|
||||
metadata.config_keys.iter().all(|k| {
|
||||
if !k.required {
|
||||
return true;
|
||||
}
|
||||
if k.secret {
|
||||
c.get_secret::<String>(&k.name).is_ok()
|
||||
} else {
|
||||
c.get_param::<String>(&k.name).is_ok()
|
||||
}
|
||||
})
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_configured {
|
||||
return Err(sacp::Error::invalid_params().data(format!(
|
||||
"Provider '{}' is not configured",
|
||||
req.provider_name
|
||||
)));
|
||||
#[custom_method(RefreshProviderInventoryRequest)]
|
||||
async fn on_refresh_provider_inventory(
|
||||
&self,
|
||||
req: RefreshProviderInventoryRequest,
|
||||
) -> Result<RefreshProviderInventoryResponse, sacp::Error> {
|
||||
let refresh_plan = self
|
||||
.provider_inventory
|
||||
.plan_refresh(&req.provider_ids)
|
||||
.await;
|
||||
let refresh_plan =
|
||||
refresh_plan.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
|
||||
for provider_id in &refresh_plan.started {
|
||||
let provider_inventory = self.provider_inventory.clone();
|
||||
let provider_factory = Arc::clone(&self.provider_factory);
|
||||
let provider_id = provider_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let result = async {
|
||||
let metadata = goose::providers::get_from_registry(&provider_id).await?;
|
||||
let model_config =
|
||||
goose::model::ModelConfig::new(&metadata.metadata().default_model)?
|
||||
.with_canonical_limits(&provider_id);
|
||||
let provider =
|
||||
provider_factory(provider_id.clone(), model_config, Vec::new()).await?;
|
||||
let models = provider.fetch_recommended_models().await?;
|
||||
provider_inventory
|
||||
.store_refreshed_models(&provider_id, &models)
|
||||
.await
|
||||
}
|
||||
.await;
|
||||
if let Err(error) = result {
|
||||
let _ = provider_inventory
|
||||
.store_refresh_error(&provider_id, error.to_string())
|
||||
.await;
|
||||
warn!(provider = %provider_id, error = %error, "provider inventory refresh failed");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let model_config = goose::model::ModelConfig::new(&metadata.default_model)
|
||||
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?
|
||||
.with_canonical_limits(&req.provider_name);
|
||||
|
||||
let provider = (self.provider_factory)(req.provider_name.clone(), model_config, Vec::new())
|
||||
.await
|
||||
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
|
||||
|
||||
let models = provider
|
||||
.fetch_recommended_models()
|
||||
.await
|
||||
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
|
||||
|
||||
Ok(GetProviderModelsResponse { models })
|
||||
Ok(RefreshProviderInventoryResponse {
|
||||
started: refresh_plan.started,
|
||||
skipped: refresh_plan
|
||||
.skipped
|
||||
.into_iter()
|
||||
.map(|entry| RefreshProviderInventorySkipDto {
|
||||
provider_id: entry.provider_id,
|
||||
reason: match entry.reason {
|
||||
RefreshSkipReason::UnknownProvider => {
|
||||
RefreshProviderInventorySkipReasonDto::UnknownProvider
|
||||
}
|
||||
RefreshSkipReason::NotConfigured => {
|
||||
RefreshProviderInventorySkipReasonDto::NotConfigured
|
||||
}
|
||||
RefreshSkipReason::DoesNotSupportRefresh => {
|
||||
RefreshProviderInventorySkipReasonDto::DoesNotSupportRefresh
|
||||
}
|
||||
RefreshSkipReason::AlreadyRefreshing => {
|
||||
RefreshProviderInventorySkipReasonDto::AlreadyRefreshing
|
||||
}
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
#[custom_method(ReadConfigRequest)]
|
||||
|
|
@ -3450,11 +3647,77 @@ impl HandleDispatchFrom<Client> for GooseAcpHandler {
|
|||
return Ok(());
|
||||
}
|
||||
}
|
||||
// Respond immediately using the current provider inventory snapshot.
|
||||
let t_tail = std::time::Instant::now();
|
||||
let (notification, config_options) = agent.build_config_update(&session_id).await?;
|
||||
cx.send_notification(notification)?;
|
||||
responder.respond(SetSessionConfigOptionResponse::new(config_options))?;
|
||||
debug!(target: "perf", sid = %sid, ms = t_tail.elapsed().as_millis() as u64, "perf: set_config_option notification_and_respond");
|
||||
debug!(target: "perf", sid = %sid, ms = t_tail.elapsed().as_millis() as u64, "perf: set_config_option inventory_respond");
|
||||
|
||||
let maybe_refresh = if config_id == "provider" {
|
||||
let provider_id = value_id.0.to_string();
|
||||
agent
|
||||
.provider_inventory
|
||||
.plan_refresh(std::slice::from_ref(&provider_id))
|
||||
.await
|
||||
.ok()
|
||||
.filter(|plan| plan.started.iter().any(|id| id == &provider_id))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if maybe_refresh.is_some() {
|
||||
let agent_bg = agent.clone();
|
||||
let cx_bg = cx.clone();
|
||||
let session_id_bg = session_id.clone();
|
||||
let sid_bg = sid.clone();
|
||||
tokio::spawn(async move {
|
||||
let t_bg = std::time::Instant::now();
|
||||
let refreshed = async {
|
||||
let session_agent =
|
||||
agent_bg.get_session_agent(&session_id_bg.0, None).await?;
|
||||
let provider = session_agent
|
||||
.provider()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e.to_string()))?;
|
||||
let provider_name = provider.get_name().to_string();
|
||||
let models = provider
|
||||
.fetch_recommended_models()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e.to_string()))?;
|
||||
agent_bg
|
||||
.provider_inventory
|
||||
.store_refreshed_models(&provider_name, &models)
|
||||
.await?;
|
||||
agent_bg
|
||||
.build_config_update(&session_id_bg)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e.to_string()))
|
||||
}
|
||||
.await;
|
||||
|
||||
match refreshed {
|
||||
Ok((fresh_notification, _)) => {
|
||||
let _ = cx_bg.send_notification(fresh_notification);
|
||||
debug!(target: "perf", sid = %sid_bg, ms = t_bg.elapsed().as_millis() as u64, "perf: set_config_option background_refresh done");
|
||||
}
|
||||
Err(e) => {
|
||||
if let Ok(session_agent) =
|
||||
agent_bg.get_session_agent(&session_id_bg.0, None).await
|
||||
{
|
||||
if let Ok(provider) = session_agent.provider().await {
|
||||
let provider_name = provider.get_name().to_string();
|
||||
let _ = agent_bg
|
||||
.provider_inventory
|
||||
.store_refresh_error(&provider_name, e.to_string())
|
||||
.await;
|
||||
}
|
||||
}
|
||||
debug!(target: "perf", sid = %sid_bg, error = %e, ms = t_bg.elapsed().as_millis() as u64, "perf: set_config_option background_refresh failed");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
debug!(target: "perf", sid = %sid, ms = t_handler.elapsed().as_millis() as u64, config_id = %config_id, "perf: set_config_option done");
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -3597,7 +3860,6 @@ pub async fn run(builtins: Vec<String>) -> Result<()> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use goose::conversation::message::{ToolRequest, ToolResponse};
|
||||
use goose::providers::errors::ProviderError;
|
||||
use rmcp::model::{CallToolRequestParams, Content as RmcpContent};
|
||||
use sacp::schema::{
|
||||
EnvVariable, HttpHeader, McpServer, McpServerHttp, McpServerSse, McpServerStdio,
|
||||
|
|
@ -3787,61 +4049,48 @@ print(\"hello, world\")
|
|||
assert_eq!(outcome_to_confirmation(&input), expected);
|
||||
}
|
||||
|
||||
struct MockModelProvider {
|
||||
models: Result<Vec<String>, ProviderError>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Provider for MockModelProvider {
|
||||
fn get_name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
_model_config: &goose::model::ModelConfig,
|
||||
_session_id: &str,
|
||||
_system: &str,
|
||||
_messages: &[goose::conversation::message::Message],
|
||||
_tools: &[rmcp::model::Tool],
|
||||
) -> Result<goose::providers::base::MessageStream, ProviderError> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> goose::model::ModelConfig {
|
||||
goose::model::ModelConfig::new_or_fail("unused")
|
||||
}
|
||||
|
||||
async fn fetch_recommended_models(&self) -> Result<Vec<String>, ProviderError> {
|
||||
self.models.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(
|
||||
Ok(vec!["model-a".into(), "model-b".into()])
|
||||
=> Ok(SessionModelState::new(
|
||||
vec!["model-a".into(), "model-b".into()]
|
||||
=> SessionModelState::new(
|
||||
ModelId::new("unused"),
|
||||
vec![ModelInfo::new(ModelId::new("model-a"), "model-a"),
|
||||
vec![ModelInfo::new(ModelId::new("unused"), "unused"),
|
||||
ModelInfo::new(ModelId::new("model-a"), "model-a"),
|
||||
ModelInfo::new(ModelId::new("model-b"), "model-b")],
|
||||
))
|
||||
)
|
||||
; "returns current and available models"
|
||||
)]
|
||||
#[test_case(
|
||||
Ok(vec![])
|
||||
=> Ok(SessionModelState::new(ModelId::new("unused"), vec![]))
|
||||
vec![]
|
||||
=> SessionModelState::new(
|
||||
ModelId::new("unused"),
|
||||
vec![ModelInfo::new(ModelId::new("unused"), "unused")],
|
||||
)
|
||||
; "empty model list"
|
||||
)]
|
||||
#[test_case(
|
||||
Err(ProviderError::ExecutionError("fail".into()))
|
||||
=> Err(sacp::Error::internal_error().data("Execution error: fail".to_string()))
|
||||
; "fetch error propagates"
|
||||
)]
|
||||
#[tokio::test]
|
||||
async fn test_build_model_state(
|
||||
models: Result<Vec<String>, ProviderError>,
|
||||
) -> Result<SessionModelState, sacp::Error> {
|
||||
let provider = MockModelProvider { models };
|
||||
build_model_state(&provider).await
|
||||
fn test_build_model_state(models: Vec<String>) -> SessionModelState {
|
||||
let inventory = ProviderInventoryEntry {
|
||||
provider_id: "mock".to_string(),
|
||||
provider_name: "Mock".to_string(),
|
||||
configured: true,
|
||||
supports_refresh: true,
|
||||
refreshing: false,
|
||||
models: models
|
||||
.into_iter()
|
||||
.map(|id| goose::providers::inventory::InventoryModel {
|
||||
name: id.clone(),
|
||||
id,
|
||||
family: None,
|
||||
context_limit: None,
|
||||
reasoning: None,
|
||||
recommended: false,
|
||||
})
|
||||
.collect(),
|
||||
last_updated_at: None,
|
||||
last_refresh_attempt_at: None,
|
||||
last_refresh_error: None,
|
||||
model_selection_hint: None,
|
||||
};
|
||||
build_model_state("unused", &inventory)
|
||||
}
|
||||
|
||||
fn json_object(pairs: Vec<(&str, serde_json::Value)>) -> rmcp::model::JsonObject {
|
||||
|
|
|
|||
|
|
@ -257,20 +257,6 @@ pub struct GetProviderDetailsResponse {
|
|||
pub providers: Vec<ProviderDetailEntry>,
|
||||
}
|
||||
|
||||
/// Fetch the full list of models available for a specific provider.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcRequest)]
|
||||
#[request(method = "_goose/providers/models", response = GetProviderModelsResponse)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GetProviderModelsRequest {
|
||||
pub provider_name: String,
|
||||
}
|
||||
|
||||
/// Provider models response.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
|
||||
pub struct GetProviderModelsResponse {
|
||||
pub models: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProviderDetailEntry {
|
||||
|
|
@ -370,6 +356,120 @@ pub struct DictationConfigResponse {
|
|||
pub providers: HashMap<String, DictationProviderStatusEntry>,
|
||||
}
|
||||
|
||||
/// Read per-provider inventory. Always returns immediately from stored state.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcRequest)]
|
||||
#[request(
|
||||
method = "_goose/providers/inventory",
|
||||
response = GetProviderInventoryResponse
|
||||
)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GetProviderInventoryRequest {
|
||||
/// Only return entries for these providers. Empty means all.
|
||||
#[serde(default)]
|
||||
pub provider_ids: Vec<String>,
|
||||
}
|
||||
|
||||
/// Provider inventory response.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
|
||||
pub struct GetProviderInventoryResponse {
|
||||
pub entries: Vec<ProviderInventoryEntryDto>,
|
||||
}
|
||||
|
||||
/// Trigger a background refresh of provider inventories.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcRequest)]
|
||||
#[request(
|
||||
method = "_goose/providers/inventory/refresh",
|
||||
response = RefreshProviderInventoryResponse
|
||||
)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RefreshProviderInventoryRequest {
|
||||
/// Which providers to refresh. Empty means all known providers.
|
||||
#[serde(default)]
|
||||
pub provider_ids: Vec<String>,
|
||||
}
|
||||
|
||||
/// Refresh acknowledgement.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RefreshProviderInventoryResponse {
|
||||
/// Which providers will be refreshed.
|
||||
pub started: Vec<String>,
|
||||
/// Which providers were skipped and why.
|
||||
#[serde(default)]
|
||||
pub skipped: Vec<RefreshProviderInventorySkipDto>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RefreshProviderInventorySkipDto {
|
||||
pub provider_id: String,
|
||||
pub reason: RefreshProviderInventorySkipReasonDto,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RefreshProviderInventorySkipReasonDto {
|
||||
#[default]
|
||||
UnknownProvider,
|
||||
NotConfigured,
|
||||
DoesNotSupportRefresh,
|
||||
AlreadyRefreshing,
|
||||
}
|
||||
|
||||
/// A single model in provider inventory.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProviderInventoryModelDto {
|
||||
/// Model identifier as the provider knows it.
|
||||
pub id: String,
|
||||
/// Human-readable display name.
|
||||
pub name: String,
|
||||
/// Model family for grouping in UI.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub family: Option<String>,
|
||||
/// Context window size in tokens.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub context_limit: Option<usize>,
|
||||
/// Whether the model supports reasoning/extended thinking.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning: Option<bool>,
|
||||
/// Whether this model should appear in the compact recommended picker.
|
||||
#[serde(default)]
|
||||
pub recommended: bool,
|
||||
}
|
||||
|
||||
/// Provider inventory entry.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProviderInventoryEntryDto {
|
||||
/// Provider identifier.
|
||||
pub provider_id: String,
|
||||
/// Human-readable provider name.
|
||||
pub provider_name: String,
|
||||
/// Whether Goose has enough configuration to use this provider.
|
||||
pub configured: bool,
|
||||
/// Whether this provider supports background inventory refresh.
|
||||
pub supports_refresh: bool,
|
||||
/// Whether a refresh is currently in flight.
|
||||
pub refreshing: bool,
|
||||
/// The list of available models.
|
||||
pub models: Vec<ProviderInventoryModelDto>,
|
||||
/// When this entry was last successfully refreshed (ISO 8601).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub last_updated_at: Option<String>,
|
||||
/// When a refresh was most recently attempted (ISO 8601).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub last_refresh_attempt_at: Option<String>,
|
||||
/// The last refresh failure message, if any.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub last_refresh_error: Option<String>,
|
||||
/// Whether we believe this data may be outdated.
|
||||
pub stale: bool,
|
||||
/// Guidance message shown when this provider manages its own model selection externally.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model_selection_hint: Option<String>,
|
||||
}
|
||||
|
||||
/// Empty success response for operations that return no data.
|
||||
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, JsonRpcResponse)]
|
||||
pub struct EmptyResponse {}
|
||||
|
|
|
|||
|
|
@ -416,6 +416,11 @@ mod tests {
|
|||
use std::sync::Arc;
|
||||
|
||||
let tmp_dir = tempfile::tempdir().unwrap();
|
||||
let temp_root = tmp_dir.path().display().to_string();
|
||||
let _guard = env_lock::lock_env([
|
||||
("HOME", Some(temp_root.as_str())),
|
||||
("GOOSE_PATH_ROOT", Some(temp_root.as_str())),
|
||||
]);
|
||||
let session_manager = Arc::new(SessionManager::new(tmp_dir.path().to_path_buf()));
|
||||
let session = session_manager
|
||||
.create_session(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use crate::config::paths::Paths;
|
|||
use crate::config::Config;
|
||||
use crate::providers::anthropic::AnthropicProvider;
|
||||
use crate::providers::base::{ModelInfo, ProviderType};
|
||||
use crate::providers::inventory::declarative_inventory_identity;
|
||||
use crate::providers::ollama::OllamaProvider;
|
||||
use crate::providers::openai::OpenAiProvider;
|
||||
use anyhow::Result;
|
||||
|
|
@ -460,38 +461,59 @@ pub fn register_declarative_provider(
|
|||
match config.engine {
|
||||
ProviderEngine::OpenAI => {
|
||||
let captured = config.clone();
|
||||
registry.register_with_name::<OpenAiProvider, _>(
|
||||
let identity_config = config.clone();
|
||||
registry.register_with_name::<OpenAiProvider, _, _>(
|
||||
&config,
|
||||
provider_type,
|
||||
config.dynamic_models.unwrap_or(false),
|
||||
move |model| {
|
||||
let mut cfg = captured.clone();
|
||||
resolve_config(&mut cfg)?;
|
||||
OpenAiProvider::from_custom_config(model, cfg)
|
||||
},
|
||||
move || {
|
||||
let mut cfg = identity_config.clone();
|
||||
resolve_config(&mut cfg)?;
|
||||
declarative_inventory_identity(&cfg)
|
||||
},
|
||||
);
|
||||
}
|
||||
ProviderEngine::Ollama => {
|
||||
let captured = config.clone();
|
||||
registry.register_with_name::<OllamaProvider, _>(
|
||||
let identity_config = config.clone();
|
||||
registry.register_with_name::<OllamaProvider, _, _>(
|
||||
&config,
|
||||
provider_type,
|
||||
config.dynamic_models.unwrap_or(false),
|
||||
move |model| {
|
||||
let mut cfg = captured.clone();
|
||||
resolve_config(&mut cfg)?;
|
||||
OllamaProvider::from_custom_config(model, cfg)
|
||||
},
|
||||
move || {
|
||||
let mut cfg = identity_config.clone();
|
||||
resolve_config(&mut cfg)?;
|
||||
declarative_inventory_identity(&cfg)
|
||||
},
|
||||
);
|
||||
}
|
||||
ProviderEngine::Anthropic => {
|
||||
let captured = config.clone();
|
||||
registry.register_with_name::<AnthropicProvider, _>(
|
||||
let identity_config = config.clone();
|
||||
registry.register_with_name::<AnthropicProvider, _, _>(
|
||||
&config,
|
||||
provider_type,
|
||||
config.dynamic_models.unwrap_or(false),
|
||||
move |model| {
|
||||
let mut cfg = captured.clone();
|
||||
resolve_config(&mut cfg)?;
|
||||
AnthropicProvider::from_custom_config(model, cfg)
|
||||
},
|
||||
move || {
|
||||
let mut cfg = identity_config.clone();
|
||||
resolve_config(&mut cfg)?;
|
||||
declarative_inventory_identity(&cfg)
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
18
crates/goose/src/providers/acp_tooling.rs
Normal file
18
crates/goose/src/providers/acp_tooling.rs
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
use crate::config::search_path::SearchPaths;
|
||||
use crate::providers::inventory::InventoryIdentityInput;
|
||||
use anyhow::Result;
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub fn acp_adapter_installed(command: &str) -> bool {
|
||||
resolve_acp_command(command).is_ok()
|
||||
}
|
||||
|
||||
pub fn acp_inventory_identity(provider_id: &str, command: &str) -> Result<InventoryIdentityInput> {
|
||||
let resolved_command = resolve_acp_command(command)?;
|
||||
Ok(InventoryIdentityInput::new(provider_id, provider_id)
|
||||
.with_public("command", resolved_command.display().to_string()))
|
||||
}
|
||||
|
||||
fn resolve_acp_command(command: &str) -> Result<PathBuf> {
|
||||
SearchPaths::builder().with_npm().resolve(command)
|
||||
}
|
||||
|
|
@ -9,7 +9,9 @@ use crate::acp::{
|
|||
use crate::config::search_path::SearchPaths;
|
||||
use crate::config::{Config, GooseMode};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
|
||||
use crate::providers::base::{ProviderDef, ProviderMetadata};
|
||||
use crate::providers::inventory::InventoryIdentityInput;
|
||||
|
||||
const AMP_ACP_PROVIDER_NAME: &str = "amp-acp";
|
||||
const AMP_ACP_DOC_URL: &str = "https://ampcode.com";
|
||||
|
|
@ -37,6 +39,7 @@ impl ProviderDef for AmpAcpProvider {
|
|||
"Set in your goose config file (`~/.config/goose/config.yaml` on macOS/Linux):\n GOOSE_PROVIDER: amp-acp\n GOOSE_MODEL: current",
|
||||
"Restart goose for changes to take effect",
|
||||
])
|
||||
.with_model_selection_hint("Use the Amp CLI to configure models")
|
||||
}
|
||||
|
||||
fn from_env(
|
||||
|
|
@ -49,10 +52,12 @@ impl ProviderDef for AmpAcpProvider {
|
|||
let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto);
|
||||
|
||||
let mode_mapping = HashMap::from([
|
||||
(GooseMode::Auto, "auto".to_string()),
|
||||
(GooseMode::Approve, "approve".to_string()),
|
||||
(GooseMode::SmartApprove, "smart-approve".to_string()),
|
||||
(GooseMode::Chat, "chat".to_string()),
|
||||
// "bypass" skips confirmations, closest to autonomous mode.
|
||||
(GooseMode::Auto, "bypass".to_string()),
|
||||
// "default" prompts before risky actions.
|
||||
(GooseMode::Approve, "default".to_string()),
|
||||
(GooseMode::SmartApprove, "default".to_string()),
|
||||
(GooseMode::Chat, "default".to_string()),
|
||||
]);
|
||||
|
||||
let provider_config = AcpProviderConfig {
|
||||
|
|
@ -71,4 +76,16 @@ impl ProviderDef for AmpAcpProvider {
|
|||
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
|
||||
})
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
acp_inventory_identity(AMP_ACP_PROVIDER_NAME, AMP_ACP_BINARY)
|
||||
}
|
||||
|
||||
fn inventory_configured() -> bool {
|
||||
acp_adapter_installed(AMP_ACP_BINARY)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ use super::errors::ProviderError;
|
|||
use super::formats::anthropic::{
|
||||
create_request, response_to_streaming_message, thinking_type, ThinkingType,
|
||||
};
|
||||
use super::inventory::{config_secret_value, serialize_string_map, InventoryIdentityInput};
|
||||
use super::openai_compatible::handle_status_openai_compat;
|
||||
use super::openai_compatible::map_http_error_to_provider_error;
|
||||
use super::retry::ProviderRetry;
|
||||
|
|
@ -235,6 +236,33 @@ impl ProviderDef for AnthropicProvider {
|
|||
) -> BoxFuture<'static, Result<Self::Provider>> {
|
||||
Box::pin(Self::from_env(model))
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
let config = crate::config::Config::global();
|
||||
let mut identity =
|
||||
InventoryIdentityInput::new(ANTHROPIC_PROVIDER_NAME, ANTHROPIC_PROVIDER_NAME)
|
||||
.with_public(
|
||||
"host",
|
||||
config
|
||||
.get_param::<String>("ANTHROPIC_HOST")
|
||||
.unwrap_or_else(|_| "https://api.anthropic.com".to_string()),
|
||||
);
|
||||
|
||||
if let Some(api_key) = config_secret_value(config, "ANTHROPIC_API_KEY") {
|
||||
identity = identity.with_secret("api_key", api_key);
|
||||
}
|
||||
if let Ok(headers) = config
|
||||
.get_secret::<std::collections::HashMap<String, String>>("ANTHROPIC_CUSTOM_HEADERS")
|
||||
{
|
||||
identity = identity.with_secret("headers", serialize_string_map(&headers)?);
|
||||
}
|
||||
|
||||
Ok(identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
use super::canonical::{map_to_canonical_model, CanonicalModelRegistry};
|
||||
use super::errors::ProviderError;
|
||||
use super::inventory::{default_inventory_identity, InventoryIdentityInput};
|
||||
use super::retry::RetryConfig;
|
||||
use crate::config::base::ConfigValue;
|
||||
use crate::config::{ExtensionConfig, GooseMode};
|
||||
use crate::config::{Config, ExtensionConfig, GooseMode};
|
||||
use crate::conversation::message::{Message, MessageContent};
|
||||
use crate::conversation::Conversation;
|
||||
use crate::model::ModelConfig;
|
||||
|
|
@ -179,6 +180,9 @@ pub struct ProviderMetadata {
|
|||
/// step-by-step instructions for set up providers eg: api key
|
||||
#[serde(default)]
|
||||
pub setup_steps: Vec<String>,
|
||||
/// Hint shown in the model picker when this provider manages its own model selection.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub model_selection_hint: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderMetadata {
|
||||
|
|
@ -212,6 +216,7 @@ impl ProviderMetadata {
|
|||
model_doc_link: model_doc_link.to_string(),
|
||||
config_keys,
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -233,6 +238,7 @@ impl ProviderMetadata {
|
|||
model_doc_link: model_doc_link.to_string(),
|
||||
config_keys,
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -246,6 +252,7 @@ impl ProviderMetadata {
|
|||
model_doc_link: "".to_string(),
|
||||
config_keys: vec![],
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -253,6 +260,11 @@ impl ProviderMetadata {
|
|||
self.setup_steps = steps.into_iter().map(|s| s.to_string()).collect();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_model_selection_hint(mut self, hint: &str) -> Self {
|
||||
self.model_selection_hint = Some(hint.to_string());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration key metadata for provider setup
|
||||
|
|
@ -492,6 +504,34 @@ pub trait ProviderDef: Send + Sync {
|
|||
) -> BoxFuture<'static, Result<Self::Provider>>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
fn supports_inventory_refresh() -> bool
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
false
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let metadata = Self::metadata();
|
||||
Ok(default_inventory_identity(
|
||||
&metadata.name,
|
||||
&metadata.name,
|
||||
&metadata.config_keys,
|
||||
Config::global(),
|
||||
))
|
||||
}
|
||||
|
||||
fn inventory_configured() -> bool
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let metadata = Self::metadata();
|
||||
super::inventory::default_inventory_configured(&metadata.config_keys, Config::global())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
|
|
@ -588,7 +628,7 @@ pub trait Provider: Send + Sync {
|
|||
false
|
||||
}
|
||||
|
||||
/// Fetch models filtered by canonical registry and usability
|
||||
/// Fetch inventory models filtered by canonical registry and usability.
|
||||
async fn fetch_recommended_models(&self) -> Result<Vec<String>, ProviderError> {
|
||||
let all_models = self.fetch_supported_models().await?;
|
||||
|
||||
|
|
@ -637,15 +677,15 @@ pub trait Provider: Send + Sync {
|
|||
(None, None) => a.0.cmp(&b.0),
|
||||
});
|
||||
|
||||
let recommended_models: Vec<String> = models_with_dates
|
||||
let inventory_models: Vec<String> = models_with_dates
|
||||
.into_iter()
|
||||
.map(|(name, _)| name)
|
||||
.collect();
|
||||
|
||||
if recommended_models.is_empty() {
|
||||
if inventory_models.is_empty() {
|
||||
Ok(all_models)
|
||||
} else {
|
||||
Ok(recommended_models)
|
||||
Ok(inventory_models)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ use crate::acp::{
|
|||
use crate::config::search_path::SearchPaths;
|
||||
use crate::config::{Config, GooseMode};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
|
||||
use crate::providers::base::{ProviderDef, ProviderMetadata};
|
||||
use crate::providers::inventory::InventoryIdentityInput;
|
||||
|
||||
const CLAUDE_ACP_PROVIDER_NAME: &str = "claude-acp";
|
||||
const CLAUDE_ACP_DOC_URL: &str = "https://github.com/zed-industries/claude-agent-acp";
|
||||
|
|
@ -78,4 +80,16 @@ impl ProviderDef for ClaudeAcpProvider {
|
|||
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
|
||||
})
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
acp_inventory_identity(CLAUDE_ACP_PROVIDER_NAME, CLAUDE_ACP_BINARY)
|
||||
}
|
||||
|
||||
fn inventory_configured() -> bool {
|
||||
acp_adapter_installed(CLAUDE_ACP_BINARY)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ use crate::acp::{
|
|||
use crate::config::search_path::SearchPaths;
|
||||
use crate::config::{Config, GooseMode};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
|
||||
use crate::providers::base::{ProviderDef, ProviderMetadata};
|
||||
use crate::providers::inventory::InventoryIdentityInput;
|
||||
|
||||
const CODEX_ACP_PROVIDER_NAME: &str = "codex-acp";
|
||||
const CODEX_ACP_DOC_URL: &str = "https://github.com/zed-industries/codex-acp";
|
||||
|
|
@ -98,6 +100,18 @@ impl ProviderDef for CodexAcpProvider {
|
|||
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
|
||||
})
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
acp_inventory_identity(CODEX_ACP_PROVIDER_NAME, CODEX_ACP_PROVIDER_NAME)
|
||||
}
|
||||
|
||||
fn inventory_configured() -> bool {
|
||||
acp_adapter_installed(CODEX_ACP_PROVIDER_NAME)
|
||||
}
|
||||
}
|
||||
|
||||
// Codex sandbox scope determines what needs approval: operations within the
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ use crate::acp::{
|
|||
use crate::config::search_path::SearchPaths;
|
||||
use crate::config::{Config, GooseMode};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
|
||||
use crate::providers::base::{ProviderDef, ProviderMetadata};
|
||||
use crate::providers::inventory::InventoryIdentityInput;
|
||||
|
||||
const COPILOT_ACP_PROVIDER_NAME: &str = "copilot-acp";
|
||||
const COPILOT_ACP_DOC_URL: &str = "https://github.com/github/copilot-cli";
|
||||
|
|
@ -84,4 +86,16 @@ impl ProviderDef for CopilotAcpProvider {
|
|||
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
|
||||
})
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
acp_inventory_identity(COPILOT_ACP_PROVIDER_NAME, COPILOT_ACP_BINARY)
|
||||
}
|
||||
|
||||
fn inventory_configured() -> bool {
|
||||
acp_adapter_installed(COPILOT_ACP_BINARY)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -340,6 +340,10 @@ impl ProviderDef for DatabricksProvider {
|
|||
) -> BoxFuture<'static, Result<Self::Provider>> {
|
||||
Box::pin(Self::from_env(model))
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
|
|
@ -1100,12 +1100,16 @@ mod tests {
|
|||
fn test_create_request_enabled_thinking_with_budget() -> anyhow::Result<()> {
|
||||
let _guard = env_lock::lock_env([
|
||||
("CLAUDE_THINKING_TYPE", None::<&str>),
|
||||
("CLAUDE_THINKING_ENABLED", Some("1")),
|
||||
("CLAUDE_THINKING_ENABLED", None::<&str>),
|
||||
("CLAUDE_THINKING_BUDGET", Some("10000")),
|
||||
]);
|
||||
|
||||
let mut model_config = ModelConfig::new_or_fail("databricks-claude-3-7-sonnet");
|
||||
model_config.max_tokens = Some(4096);
|
||||
model_config = model_config.with_request_params(Some(std::collections::HashMap::from([(
|
||||
"thinking_type".to_string(),
|
||||
json!("enabled"),
|
||||
)])));
|
||||
|
||||
let request = create_request(&model_config, "system", &[], &[], &ImageFormat::OpenAi)?;
|
||||
|
||||
|
|
|
|||
|
|
@ -149,6 +149,10 @@ pub async fn get_from_registry(name: &str) -> Result<ProviderEntry> {
|
|||
.cloned()
|
||||
}
|
||||
|
||||
pub async fn inventory_identity(name: &str) -> Result<super::inventory::InventoryIdentityInput> {
|
||||
get_from_registry(name).await?.inventory_identity()
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
name: &str,
|
||||
model: ModelConfig,
|
||||
|
|
|
|||
970
crates/goose/src/providers/inventory/mod.rs
Normal file
970
crates/goose/src/providers/inventory/mod.rs
Normal file
|
|
@ -0,0 +1,970 @@
|
|||
use super::base::{ConfigKey, ModelInfo};
|
||||
use super::canonical::{map_provider_name, map_to_canonical_model, CanonicalModelRegistry};
|
||||
use crate::config::declarative_providers::{DeclarativeProviderConfig, ProviderEngine};
|
||||
use crate::config::Config;
|
||||
use crate::session::session_manager::SessionStorage;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::{Pool, Row, Sqlite, Transaction};
|
||||
use std::collections::{BTreeMap, HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
const STALE_AFTER_HOURS: i64 = 24;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProviderInventoryEntry {
|
||||
pub provider_id: String,
|
||||
pub provider_name: String,
|
||||
pub configured: bool,
|
||||
pub supports_refresh: bool,
|
||||
pub refreshing: bool,
|
||||
pub models: Vec<InventoryModel>,
|
||||
pub last_updated_at: Option<DateTime<Utc>>,
|
||||
pub last_refresh_attempt_at: Option<DateTime<Utc>>,
|
||||
pub last_refresh_error: Option<String>,
|
||||
pub model_selection_hint: Option<String>,
|
||||
}
|
||||
|
||||
/// Families whose latest model should be surfaced in the compact picker.
|
||||
/// Each entry is matched against the `family` field of enriched models.
|
||||
const RECOMMENDED_FAMILIES: &[&str] = &[
|
||||
"claude-opus",
|
||||
"claude-sonnet",
|
||||
"gpt",
|
||||
"gpt-mini",
|
||||
"glm",
|
||||
"gemini-pro",
|
||||
"gemini-flash",
|
||||
"gemma",
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InventoryModel {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub family: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub context_limit: Option<usize>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning: Option<bool>,
|
||||
/// Whether this model should appear in the compact recommended picker.
|
||||
pub recommended: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InventoryIdentity {
|
||||
pub provider_id: String,
|
||||
pub provider_family: String,
|
||||
pub inventory_key: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct InventoryIdentityInput {
|
||||
pub provider_id: String,
|
||||
pub provider_family: String,
|
||||
pub public_inputs: BTreeMap<String, String>,
|
||||
pub secret_inputs: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
impl InventoryIdentityInput {
|
||||
pub fn new(
|
||||
provider_id: impl Into<String>,
|
||||
provider_family: impl Into<String>,
|
||||
) -> InventoryIdentityInput {
|
||||
InventoryIdentityInput {
|
||||
provider_id: provider_id.into(),
|
||||
provider_family: provider_family.into(),
|
||||
public_inputs: BTreeMap::new(),
|
||||
secret_inputs: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_public(
|
||||
mut self,
|
||||
key: impl Into<String>,
|
||||
value: impl Into<String>,
|
||||
) -> InventoryIdentityInput {
|
||||
self.public_inputs.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_secret(
|
||||
mut self,
|
||||
key: impl Into<String>,
|
||||
value: impl Into<String>,
|
||||
) -> InventoryIdentityInput {
|
||||
self.secret_inputs.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn into_identity(self) -> Result<InventoryIdentity> {
|
||||
let InventoryIdentityInput {
|
||||
provider_id,
|
||||
provider_family,
|
||||
public_inputs,
|
||||
secret_inputs,
|
||||
} = self;
|
||||
let payload = serde_json::json!({
|
||||
"provider_family": provider_family,
|
||||
"public_inputs": public_inputs,
|
||||
"secret_inputs": secret_inputs,
|
||||
});
|
||||
let digest = Sha256::digest(serde_json::to_vec(&payload)?);
|
||||
Ok(InventoryIdentity {
|
||||
provider_id,
|
||||
provider_family,
|
||||
inventory_key: format!("{digest:x}"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum RefreshSkipReason {
|
||||
UnknownProvider,
|
||||
NotConfigured,
|
||||
DoesNotSupportRefresh,
|
||||
AlreadyRefreshing,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RefreshSkip {
|
||||
pub provider_id: String,
|
||||
pub reason: RefreshSkipReason,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RefreshPlan {
|
||||
pub started: Vec<String>,
|
||||
pub skipped: Vec<RefreshSkip>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProviderInventoryService {
|
||||
storage: Arc<SessionStorage>,
|
||||
refreshing_keys: Arc<RwLock<HashSet<String>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct InventorySnapshot {
|
||||
models: Vec<InventoryModel>,
|
||||
last_updated_at: Option<DateTime<Utc>>,
|
||||
last_refresh_attempt_at: Option<DateTime<Utc>>,
|
||||
last_refresh_error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ProviderDescriptor {
|
||||
provider_id: String,
|
||||
provider_name: String,
|
||||
identity: InventoryIdentity,
|
||||
configured: bool,
|
||||
supports_refresh: bool,
|
||||
static_models: Vec<ModelInfo>,
|
||||
model_selection_hint: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderInventoryService {
|
||||
pub fn new(storage: Arc<SessionStorage>) -> ProviderInventoryService {
|
||||
ProviderInventoryService {
|
||||
storage,
|
||||
refreshing_keys: Arc::new(RwLock::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn entry_for_provider(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
) -> Result<Option<ProviderInventoryEntry>> {
|
||||
let Some(descriptor) = self.describe_provider(provider_id).await? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let snapshot = self.read_snapshot(&descriptor.identity).await?;
|
||||
let refreshing = self
|
||||
.refreshing_keys
|
||||
.read()
|
||||
.await
|
||||
.contains(&descriptor.identity.inventory_key);
|
||||
let models = inventory_models_from_snapshot(
|
||||
snapshot.as_ref(),
|
||||
&descriptor.identity.provider_family,
|
||||
&descriptor.static_models,
|
||||
);
|
||||
|
||||
Ok(Some(ProviderInventoryEntry {
|
||||
provider_id: descriptor.provider_id,
|
||||
provider_name: descriptor.provider_name,
|
||||
configured: descriptor.configured,
|
||||
supports_refresh: descriptor.supports_refresh,
|
||||
refreshing,
|
||||
models,
|
||||
last_updated_at: snapshot
|
||||
.as_ref()
|
||||
.and_then(|snapshot| snapshot.last_updated_at),
|
||||
last_refresh_attempt_at: snapshot
|
||||
.as_ref()
|
||||
.and_then(|snapshot| snapshot.last_refresh_attempt_at),
|
||||
last_refresh_error: snapshot.and_then(|snapshot| snapshot.last_refresh_error),
|
||||
model_selection_hint: descriptor.model_selection_hint,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn entries(&self, provider_ids: &[String]) -> Result<Vec<ProviderInventoryEntry>> {
|
||||
let ids = self.resolve_provider_ids(provider_ids).await;
|
||||
let mut entries = Vec::with_capacity(ids.len());
|
||||
for provider_id in ids {
|
||||
if let Some(entry) = self.entry_for_provider(&provider_id).await? {
|
||||
entries.push(entry);
|
||||
}
|
||||
}
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
pub async fn plan_refresh(&self, provider_ids: &[String]) -> Result<RefreshPlan> {
|
||||
let ids = self.resolve_provider_ids(provider_ids).await;
|
||||
let mut plan = RefreshPlan::default();
|
||||
|
||||
for provider_id in ids {
|
||||
let Some(descriptor) = self.describe_provider(&provider_id).await? else {
|
||||
plan.skipped.push(RefreshSkip {
|
||||
provider_id,
|
||||
reason: RefreshSkipReason::UnknownProvider,
|
||||
});
|
||||
continue;
|
||||
};
|
||||
|
||||
if !descriptor.supports_refresh {
|
||||
plan.skipped.push(RefreshSkip {
|
||||
provider_id: descriptor.provider_id,
|
||||
reason: RefreshSkipReason::DoesNotSupportRefresh,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if !descriptor.configured {
|
||||
plan.skipped.push(RefreshSkip {
|
||||
provider_id: descriptor.provider_id,
|
||||
reason: RefreshSkipReason::NotConfigured,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut refreshing_keys = self.refreshing_keys.write().await;
|
||||
if refreshing_keys.contains(&descriptor.identity.inventory_key) {
|
||||
plan.skipped.push(RefreshSkip {
|
||||
provider_id: descriptor.provider_id,
|
||||
reason: RefreshSkipReason::AlreadyRefreshing,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
refreshing_keys.insert(descriptor.identity.inventory_key.clone());
|
||||
drop(refreshing_keys);
|
||||
|
||||
self.mark_refresh_started(&descriptor.identity).await?;
|
||||
plan.started.push(descriptor.provider_id);
|
||||
}
|
||||
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
pub async fn store_refreshed_models(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
model_ids: &[String],
|
||||
) -> Result<()> {
|
||||
let descriptor = self.require_provider(provider_id).await?;
|
||||
let models =
|
||||
enrich_model_ids_with_canonical(&descriptor.identity.provider_family, model_ids);
|
||||
let now = Utc::now();
|
||||
let pool = self.storage.pool().await?;
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_inventory_entries (
|
||||
inventory_key,
|
||||
provider_id,
|
||||
provider_family,
|
||||
last_updated_at,
|
||||
last_refresh_attempt_at,
|
||||
last_refresh_error,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, NULL, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(inventory_key) DO UPDATE SET
|
||||
provider_id = excluded.provider_id,
|
||||
provider_family = excluded.provider_family,
|
||||
last_updated_at = excluded.last_updated_at,
|
||||
last_refresh_attempt_at = excluded.last_refresh_attempt_at,
|
||||
last_refresh_error = NULL,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#,
|
||||
)
|
||||
.bind(&descriptor.identity.inventory_key)
|
||||
.bind(&descriptor.identity.provider_id)
|
||||
.bind(&descriptor.identity.provider_family)
|
||||
.bind(now.to_rfc3339())
|
||||
.bind(now.to_rfc3339())
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
sqlx::query("DELETE FROM provider_inventory_models WHERE inventory_key = ?")
|
||||
.bind(&descriptor.identity.inventory_key)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
for (ordinal, model) in models.iter().enumerate() {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_inventory_models (
|
||||
inventory_key,
|
||||
ordinal,
|
||||
model_id,
|
||||
name,
|
||||
family,
|
||||
context_limit,
|
||||
reasoning,
|
||||
recommended
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&descriptor.identity.inventory_key)
|
||||
.bind(i64::try_from(ordinal)?)
|
||||
.bind(&model.id)
|
||||
.bind(&model.name)
|
||||
.bind(&model.family)
|
||||
.bind(model.context_limit.map(i64::try_from).transpose()?)
|
||||
.bind(model.reasoning)
|
||||
.bind(model.recommended)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
self.refreshing_keys
|
||||
.write()
|
||||
.await
|
||||
.remove(&descriptor.identity.inventory_key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn store_refresh_error(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
error: impl Into<String>,
|
||||
) -> Result<()> {
|
||||
let descriptor = self.require_provider(provider_id).await?;
|
||||
let error = error.into();
|
||||
let existing = self.read_snapshot(&descriptor.identity).await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_inventory_entries (
|
||||
inventory_key,
|
||||
provider_id,
|
||||
provider_family,
|
||||
last_updated_at,
|
||||
last_refresh_attempt_at,
|
||||
last_refresh_error,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(inventory_key) DO UPDATE SET
|
||||
provider_id = excluded.provider_id,
|
||||
provider_family = excluded.provider_family,
|
||||
last_updated_at = excluded.last_updated_at,
|
||||
last_refresh_attempt_at = excluded.last_refresh_attempt_at,
|
||||
last_refresh_error = excluded.last_refresh_error,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#,
|
||||
)
|
||||
.bind(&descriptor.identity.inventory_key)
|
||||
.bind(&descriptor.identity.provider_id)
|
||||
.bind(&descriptor.identity.provider_family)
|
||||
.bind(existing.and_then(|snapshot| snapshot.last_updated_at.map(|time| time.to_rfc3339())))
|
||||
.bind(Utc::now().to_rfc3339())
|
||||
.bind(error)
|
||||
.execute(self.storage.pool().await?)
|
||||
.await?;
|
||||
|
||||
self.refreshing_keys
|
||||
.write()
|
||||
.await
|
||||
.remove(&descriptor.identity.inventory_key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_stale(entry: &ProviderInventoryEntry) -> bool {
|
||||
let Some(last_updated_at) = entry.last_updated_at else {
|
||||
return false;
|
||||
};
|
||||
entry.supports_refresh && Utc::now() - last_updated_at > Duration::hours(STALE_AFTER_HOURS)
|
||||
}
|
||||
|
||||
async fn describe_provider(&self, provider_id: &str) -> Result<Option<ProviderDescriptor>> {
|
||||
let entry = match crate::providers::get_from_registry(provider_id).await {
|
||||
Ok(entry) => entry,
|
||||
Err(_) => return Ok(None),
|
||||
};
|
||||
let metadata = entry.metadata().clone();
|
||||
let identity = crate::providers::inventory_identity(provider_id)
|
||||
.await
|
||||
.unwrap_or_else(|_| fallback_inventory_identity(provider_id))
|
||||
.into_identity()?;
|
||||
|
||||
Ok(Some(ProviderDescriptor {
|
||||
provider_id: metadata.name.clone(),
|
||||
provider_name: metadata.display_name.clone(),
|
||||
identity,
|
||||
configured: entry.inventory_configured(),
|
||||
supports_refresh: entry.supports_inventory_refresh(),
|
||||
static_models: metadata.known_models,
|
||||
model_selection_hint: metadata.model_selection_hint,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn require_provider(&self, provider_id: &str) -> Result<ProviderDescriptor> {
|
||||
self.describe_provider(provider_id)
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("Unknown provider: {}", provider_id))
|
||||
}
|
||||
|
||||
async fn mark_refresh_started(&self, identity: &InventoryIdentity) -> Result<()> {
|
||||
let existing = self.read_snapshot(identity).await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_inventory_entries (
|
||||
inventory_key,
|
||||
provider_id,
|
||||
provider_family,
|
||||
last_updated_at,
|
||||
last_refresh_attempt_at,
|
||||
last_refresh_error,
|
||||
updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, NULL, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(inventory_key) DO UPDATE SET
|
||||
provider_id = excluded.provider_id,
|
||||
provider_family = excluded.provider_family,
|
||||
last_updated_at = excluded.last_updated_at,
|
||||
last_refresh_attempt_at = excluded.last_refresh_attempt_at,
|
||||
last_refresh_error = NULL,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#,
|
||||
)
|
||||
.bind(&identity.inventory_key)
|
||||
.bind(&identity.provider_id)
|
||||
.bind(&identity.provider_family)
|
||||
.bind(existing.and_then(|snapshot| snapshot.last_updated_at.map(|time| time.to_rfc3339())))
|
||||
.bind(Utc::now().to_rfc3339())
|
||||
.execute(self.storage.pool().await?)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn read_snapshot(
|
||||
&self,
|
||||
identity: &InventoryIdentity,
|
||||
) -> Result<Option<InventorySnapshot>> {
|
||||
let pool = self.storage.pool().await?;
|
||||
let entry = sqlx::query(
|
||||
r#"
|
||||
SELECT last_updated_at, last_refresh_attempt_at, last_refresh_error
|
||||
FROM provider_inventory_entries
|
||||
WHERE inventory_key = ?
|
||||
"#,
|
||||
)
|
||||
.bind(&identity.inventory_key)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let Some(entry) = entry else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let last_updated_at = parse_optional_datetime(entry.try_get("last_updated_at")?)?;
|
||||
let last_refresh_attempt_at =
|
||||
parse_optional_datetime(entry.try_get("last_refresh_attempt_at")?)?;
|
||||
let last_refresh_error = entry.try_get("last_refresh_error")?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT model_id, name, family, context_limit, reasoning, recommended
|
||||
FROM provider_inventory_models
|
||||
WHERE inventory_key = ?
|
||||
ORDER BY ordinal
|
||||
"#,
|
||||
)
|
||||
.bind(&identity.inventory_key)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let models = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
Ok(InventoryModel {
|
||||
id: row.try_get("model_id")?,
|
||||
name: row.try_get("name")?,
|
||||
family: row.try_get("family")?,
|
||||
context_limit: row
|
||||
.try_get::<Option<i64>, _>("context_limit")?
|
||||
.map(usize::try_from)
|
||||
.transpose()?,
|
||||
reasoning: row.try_get("reasoning")?,
|
||||
recommended: row
|
||||
.try_get::<Option<bool>, _>("recommended")?
|
||||
.unwrap_or(false),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, anyhow::Error>>()?;
|
||||
|
||||
Ok(Some(InventorySnapshot {
|
||||
models,
|
||||
last_updated_at,
|
||||
last_refresh_attempt_at,
|
||||
last_refresh_error,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn resolve_provider_ids(&self, provider_ids: &[String]) -> Vec<String> {
|
||||
let mut ids = if provider_ids.is_empty() {
|
||||
crate::providers::providers()
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(metadata, _)| metadata.name)
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
provider_ids.to_vec()
|
||||
};
|
||||
ids.sort();
|
||||
ids.dedup();
|
||||
ids
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_inventory_identity(
|
||||
provider_id: &str,
|
||||
provider_family: &str,
|
||||
config_keys: &[ConfigKey],
|
||||
config: &Config,
|
||||
) -> InventoryIdentityInput {
|
||||
let mut identity = InventoryIdentityInput::new(provider_id, provider_family);
|
||||
|
||||
for key in config_keys {
|
||||
if key.secret {
|
||||
if let Some(value) = config_secret_value(config, &key.name) {
|
||||
identity.secret_inputs.insert(key.name.clone(), value);
|
||||
}
|
||||
} else if let Some(value) = config_param_value(config, &key.name) {
|
||||
identity.public_inputs.insert(key.name.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
identity
|
||||
}
|
||||
|
||||
pub fn default_inventory_configured(config_keys: &[ConfigKey], config: &Config) -> bool {
|
||||
config_keys.iter().all(|key| {
|
||||
if !key.required {
|
||||
return true;
|
||||
}
|
||||
if key.default.is_some() {
|
||||
return true;
|
||||
}
|
||||
if key.secret {
|
||||
config.get_secret::<serde_json::Value>(&key.name).is_ok()
|
||||
} else {
|
||||
config.get_param::<serde_json::Value>(&key.name).is_ok()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn declarative_inventory_identity(
|
||||
config: &DeclarativeProviderConfig,
|
||||
) -> Result<InventoryIdentityInput> {
|
||||
let global = Config::global();
|
||||
let mut identity = InventoryIdentityInput::new(
|
||||
config.name.clone(),
|
||||
config
|
||||
.catalog_provider_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| match config.engine {
|
||||
ProviderEngine::OpenAI => "openai".to_string(),
|
||||
ProviderEngine::Anthropic => "anthropic".to_string(),
|
||||
ProviderEngine::Ollama => "ollama".to_string(),
|
||||
}),
|
||||
);
|
||||
|
||||
identity
|
||||
.public_inputs
|
||||
.insert("base_url".to_string(), config.base_url.clone());
|
||||
|
||||
if let Some(base_path) = &config.base_path {
|
||||
identity
|
||||
.public_inputs
|
||||
.insert("base_path".to_string(), base_path.clone());
|
||||
}
|
||||
if let Some(catalog_provider_id) = &config.catalog_provider_id {
|
||||
identity.public_inputs.insert(
|
||||
"catalog_provider_id".to_string(),
|
||||
catalog_provider_id.clone(),
|
||||
);
|
||||
}
|
||||
if let Some(dynamic_models) = config.dynamic_models {
|
||||
identity
|
||||
.public_inputs
|
||||
.insert("dynamic_models".to_string(), dynamic_models.to_string());
|
||||
}
|
||||
identity.public_inputs.insert(
|
||||
"skip_canonical_filtering".to_string(),
|
||||
config.skip_canonical_filtering.to_string(),
|
||||
);
|
||||
if !config.models.is_empty() {
|
||||
identity.public_inputs.insert(
|
||||
"models".to_string(),
|
||||
serde_json::to_string(
|
||||
&config
|
||||
.models
|
||||
.iter()
|
||||
.map(|model| &model.name)
|
||||
.collect::<Vec<_>>(),
|
||||
)?,
|
||||
);
|
||||
}
|
||||
if let Some(headers) = &config.headers {
|
||||
identity
|
||||
.public_inputs
|
||||
.insert("headers".to_string(), serialize_string_map(headers)?);
|
||||
}
|
||||
if config.requires_auth && !config.api_key_env.is_empty() {
|
||||
if let Some(value) = config_secret_value(global, &config.api_key_env) {
|
||||
identity
|
||||
.secret_inputs
|
||||
.insert(config.api_key_env.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(identity)
|
||||
}
|
||||
|
||||
pub fn config_param_value(config: &Config, key: &str) -> Option<String> {
|
||||
config
|
||||
.get_param::<serde_json::Value>(key)
|
||||
.ok()
|
||||
.and_then(|value| normalize_json_value(&value))
|
||||
}
|
||||
|
||||
pub fn config_secret_value(config: &Config, key: &str) -> Option<String> {
|
||||
config
|
||||
.get_secret::<serde_json::Value>(key)
|
||||
.ok()
|
||||
.and_then(|value| normalize_json_value(&value))
|
||||
}
|
||||
|
||||
pub fn serialize_string_map(map: &HashMap<String, String>) -> Result<String> {
|
||||
let ordered = map
|
||||
.iter()
|
||||
.map(|(key, value)| (key.clone(), value.clone()))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
Ok(serde_json::to_string(&ordered)?)
|
||||
}
|
||||
|
||||
fn parse_optional_datetime(value: Option<String>) -> Result<Option<DateTime<Utc>>> {
|
||||
value
|
||||
.map(|value| value.parse::<DateTime<Utc>>())
|
||||
.transpose()
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn normalize_json_value(value: &serde_json::Value) -> Option<String> {
|
||||
match value {
|
||||
serde_json::Value::Null => None,
|
||||
serde_json::Value::String(value) if value.is_empty() => None,
|
||||
serde_json::Value::String(value) => Some(value.clone()),
|
||||
other => serde_json::to_string(other).ok(),
|
||||
}
|
||||
}
|
||||
|
||||
fn fallback_inventory_identity(provider_id: &str) -> InventoryIdentityInput {
|
||||
InventoryIdentityInput::new(
|
||||
provider_id.to_string(),
|
||||
map_provider_name(provider_id).to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
fn enrich_model_ids_with_canonical(
|
||||
provider_family: &str,
|
||||
model_ids: &[String],
|
||||
) -> Vec<InventoryModel> {
|
||||
let mut models: Vec<InventoryModel> = Vec::new();
|
||||
let mut seen_names: HashSet<String> = HashSet::new();
|
||||
|
||||
for id in model_ids {
|
||||
let model = enriched_model(provider_family, id, None);
|
||||
if !seen_names.insert(model.name.clone()) {
|
||||
continue;
|
||||
}
|
||||
models.push(model);
|
||||
}
|
||||
|
||||
// For databricks, prefer goose- prefixed model_ids when there are duplicates.
|
||||
// Re-scan: if a later model_id with "goose-" prefix maps to the same display name,
|
||||
// swap it in.
|
||||
if provider_family == "databricks" {
|
||||
let mut name_to_idx: HashMap<String, usize> = HashMap::new();
|
||||
for (idx, model) in models.iter().enumerate() {
|
||||
name_to_idx.insert(model.name.clone(), idx);
|
||||
}
|
||||
for id in model_ids {
|
||||
if !id.starts_with("goose-") {
|
||||
continue;
|
||||
}
|
||||
let candidate = enriched_model(provider_family, id, None);
|
||||
if let Some(&idx) = name_to_idx.get(&candidate.name) {
|
||||
if !models[idx].id.starts_with("goose-") {
|
||||
models[idx].id = candidate.id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mark the latest model per recommended family.
|
||||
let mut seen_recommended_families: HashSet<String> = HashSet::new();
|
||||
for model in &mut models {
|
||||
if let Some(family) = &model.family {
|
||||
if RECOMMENDED_FAMILIES.contains(&family.as_str())
|
||||
&& seen_recommended_families.insert(family.clone())
|
||||
{
|
||||
model.recommended = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
models
|
||||
}
|
||||
|
||||
fn configured_models_to_inventory(
|
||||
provider_family: &str,
|
||||
models: &[ModelInfo],
|
||||
) -> Vec<InventoryModel> {
|
||||
let mut result: Vec<InventoryModel> = Vec::new();
|
||||
let mut seen_names: HashSet<String> = HashSet::new();
|
||||
for model in models {
|
||||
let enriched = enriched_model(provider_family, &model.name, Some(model.context_limit));
|
||||
if seen_names.insert(enriched.name.clone()) {
|
||||
result.push(enriched);
|
||||
}
|
||||
}
|
||||
|
||||
let mut seen_recommended_families: HashSet<String> = HashSet::new();
|
||||
for model in &mut result {
|
||||
if let Some(family) = &model.family {
|
||||
if RECOMMENDED_FAMILIES.contains(&family.as_str())
|
||||
&& seen_recommended_families.insert(family.clone())
|
||||
{
|
||||
model.recommended = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn inventory_models_from_snapshot(
|
||||
snapshot: Option<&InventorySnapshot>,
|
||||
provider_family: &str,
|
||||
configured_models: &[ModelInfo],
|
||||
) -> Vec<InventoryModel> {
|
||||
match snapshot {
|
||||
Some(snapshot) if !snapshot.models.is_empty() || snapshot.last_updated_at.is_some() => {
|
||||
snapshot.models.clone()
|
||||
}
|
||||
_ => configured_models_to_inventory(provider_family, configured_models),
|
||||
}
|
||||
}
|
||||
|
||||
fn enriched_model(
|
||||
provider_family: &str,
|
||||
model_id: &str,
|
||||
fallback_context_limit: Option<usize>,
|
||||
) -> InventoryModel {
|
||||
let registry = CanonicalModelRegistry::bundled().ok();
|
||||
let canonical = registry.as_ref().and_then(|registry| {
|
||||
let canonical_id = map_to_canonical_model(provider_family, model_id, registry)?;
|
||||
let (provider, model) = canonical_id.split_once('/')?;
|
||||
registry.get(provider, model).cloned()
|
||||
});
|
||||
|
||||
InventoryModel {
|
||||
id: model_id.to_string(),
|
||||
name: canonical
|
||||
.as_ref()
|
||||
.map(|model| model.name.clone())
|
||||
.unwrap_or_else(|| model_id.to_string()),
|
||||
family: canonical.as_ref().and_then(|model| model.family.clone()),
|
||||
context_limit: canonical
|
||||
.as_ref()
|
||||
.map(|model| model.limit.context)
|
||||
.or(fallback_context_limit),
|
||||
reasoning: canonical.as_ref().and_then(|model| model.reasoning),
|
||||
recommended: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_tables(pool: &Pool<Sqlite>) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS provider_inventory_entries (
|
||||
inventory_key TEXT PRIMARY KEY,
|
||||
provider_id TEXT NOT NULL,
|
||||
provider_family TEXT NOT NULL,
|
||||
last_updated_at TEXT,
|
||||
last_refresh_attempt_at TEXT,
|
||||
last_refresh_error TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS provider_inventory_models (
|
||||
inventory_key TEXT NOT NULL REFERENCES provider_inventory_entries(inventory_key) ON DELETE CASCADE,
|
||||
ordinal INTEGER NOT NULL,
|
||||
model_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
family TEXT,
|
||||
context_limit INTEGER,
|
||||
reasoning BOOLEAN,
|
||||
recommended BOOLEAN,
|
||||
PRIMARY KEY (inventory_key, ordinal)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
"CREATE INDEX IF NOT EXISTS idx_provider_inventory_provider_id ON provider_inventory_entries(provider_id)",
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_tables_in_tx(tx: &mut Transaction<'_, Sqlite>) -> Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS provider_inventory_entries (
|
||||
inventory_key TEXT PRIMARY KEY,
|
||||
provider_id TEXT NOT NULL,
|
||||
provider_family TEXT NOT NULL,
|
||||
last_updated_at TEXT,
|
||||
last_refresh_attempt_at TEXT,
|
||||
last_refresh_error TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS provider_inventory_models (
|
||||
inventory_key TEXT NOT NULL REFERENCES provider_inventory_entries(inventory_key) ON DELETE CASCADE,
|
||||
ordinal INTEGER NOT NULL,
|
||||
model_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
family TEXT,
|
||||
context_limit INTEGER,
|
||||
reasoning BOOLEAN,
|
||||
recommended BOOLEAN,
|
||||
PRIMARY KEY (inventory_key, ordinal)
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
|
||||
sqlx::query(
|
||||
"CREATE INDEX IF NOT EXISTS idx_provider_inventory_provider_id ON provider_inventory_entries(provider_id)",
|
||||
)
|
||||
.execute(&mut **tx)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn inventory_identity_hash_changes_with_secret_inputs() {
|
||||
let left = InventoryIdentityInput::new("openai", "openai")
|
||||
.with_public("host", "https://api.openai.com")
|
||||
.with_secret("api_key", "secret-a")
|
||||
.into_identity()
|
||||
.unwrap();
|
||||
let right = InventoryIdentityInput::new("openai", "openai")
|
||||
.with_public("host", "https://api.openai.com")
|
||||
.with_secret("api_key", "secret-b")
|
||||
.into_identity()
|
||||
.unwrap();
|
||||
|
||||
assert_ne!(left.inventory_key, right.inventory_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn configured_models_use_canonical_enrichment() {
|
||||
let models =
|
||||
configured_models_to_inventory("anthropic", &[ModelInfo::new("claude-sonnet-4-5", 0)]);
|
||||
|
||||
assert_eq!(models.len(), 1);
|
||||
assert!(models[0].name.contains("Claude"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inventory_uses_configured_models_before_first_successful_refresh() {
|
||||
let configured_models = [ModelInfo::new("claude-sonnet-4-5", 0)];
|
||||
let snapshot = InventorySnapshot {
|
||||
models: vec![],
|
||||
last_updated_at: None,
|
||||
last_refresh_attempt_at: Some(Utc::now()),
|
||||
last_refresh_error: Some("auth failed".to_string()),
|
||||
};
|
||||
|
||||
let models =
|
||||
inventory_models_from_snapshot(Some(&snapshot), "anthropic", &configured_models);
|
||||
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].id, "claude-sonnet-4-5");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inventory_preserves_empty_models_after_successful_refresh() {
|
||||
let configured_models = [ModelInfo::new("claude-sonnet-4-5", 0)];
|
||||
let snapshot = InventorySnapshot {
|
||||
models: vec![],
|
||||
last_updated_at: Some(Utc::now()),
|
||||
last_refresh_attempt_at: Some(Utc::now()),
|
||||
last_refresh_error: None,
|
||||
};
|
||||
|
||||
let models =
|
||||
inventory_models_from_snapshot(Some(&snapshot), "anthropic", &configured_models);
|
||||
|
||||
assert!(models.is_empty());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
mod acp_tooling;
|
||||
pub mod amp_acp;
|
||||
pub mod anthropic;
|
||||
pub mod api_client;
|
||||
|
|
@ -28,6 +29,7 @@ pub mod gemini_oauth;
|
|||
pub mod githubcopilot;
|
||||
pub mod google;
|
||||
mod init;
|
||||
pub mod inventory;
|
||||
pub mod kimicode;
|
||||
pub mod litellm;
|
||||
#[cfg(feature = "local-inference")]
|
||||
|
|
@ -56,6 +58,6 @@ pub mod xai;
|
|||
|
||||
pub use init::{
|
||||
cleanup_provider, create, create_with_default_model, create_with_named_model,
|
||||
get_from_registry, providers, refresh_custom_providers,
|
||||
get_from_registry, inventory_identity, providers, refresh_custom_providers,
|
||||
};
|
||||
pub use retry::{retry_operation, RetryConfig};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use super::api_client::{ApiClient, AuthMethod};
|
||||
use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata};
|
||||
use super::errors::ProviderError;
|
||||
use super::inventory::InventoryIdentityInput;
|
||||
use super::openai_compatible::handle_status_openai_compat;
|
||||
use super::retry::{ProviderRetry, RetryConfig};
|
||||
use super::utils::{ImageFormat, RequestLog};
|
||||
|
|
@ -256,6 +257,22 @@ impl ProviderDef for OllamaProvider {
|
|||
) -> BoxFuture<'static, Result<Self::Provider>> {
|
||||
Box::pin(Self::from_env(model))
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
let config = crate::config::Config::global();
|
||||
Ok(
|
||||
InventoryIdentityInput::new(OLLAMA_PROVIDER_NAME, OLLAMA_PROVIDER_NAME).with_public(
|
||||
"host",
|
||||
config
|
||||
.get_param::<String>("OLLAMA_HOST")
|
||||
.unwrap_or_else(|_| OLLAMA_HOST.to_string()),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ use super::formats::openai_responses::{
|
|||
create_responses_request, get_responses_usage, responses_api_to_message,
|
||||
responses_api_to_streaming_message, ResponsesApiResponse,
|
||||
};
|
||||
use super::inventory::{config_secret_value, InventoryIdentityInput};
|
||||
use super::openai_compatible::{
|
||||
handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
|
||||
};
|
||||
|
|
@ -425,6 +426,58 @@ impl ProviderDef for OpenAiProvider {
|
|||
) -> BoxFuture<'static, Result<Self::Provider>> {
|
||||
Box::pin(Self::from_env(model))
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn inventory_configured() -> bool {
|
||||
let config = crate::config::Config::global();
|
||||
// If the host is explicitly set to something non-default, trust the user's
|
||||
// custom setup (e.g. a local server that doesn't require an API key).
|
||||
if let Ok(host) = config.get_param::<String>("OPENAI_HOST") {
|
||||
if host != "https://api.openai.com" {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// Standard OpenAI endpoint requires an API key.
|
||||
config
|
||||
.get_secret::<serde_json::Value>("OPENAI_API_KEY")
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
let config = crate::config::Config::global();
|
||||
let mut identity =
|
||||
InventoryIdentityInput::new(OPEN_AI_PROVIDER_NAME, OPEN_AI_PROVIDER_NAME)
|
||||
.with_public(
|
||||
"host",
|
||||
config
|
||||
.get_param::<String>("OPENAI_HOST")
|
||||
.unwrap_or_else(|_| "https://api.openai.com".to_string()),
|
||||
)
|
||||
.with_public(
|
||||
"base_path",
|
||||
config
|
||||
.get_param::<String>("OPENAI_BASE_PATH")
|
||||
.unwrap_or_else(|_| OPEN_AI_DEFAULT_BASE_PATH.to_string()),
|
||||
);
|
||||
|
||||
if let Ok(organization) = config.get_param::<String>("OPENAI_ORGANIZATION") {
|
||||
identity = identity.with_public("organization", organization);
|
||||
}
|
||||
if let Ok(project) = config.get_param::<String>("OPENAI_PROJECT") {
|
||||
identity = identity.with_public("project", project);
|
||||
}
|
||||
if let Some(api_key) = config_secret_value(config, "OPENAI_API_KEY") {
|
||||
identity = identity.with_secret("api_key", api_key);
|
||||
}
|
||||
if let Some(custom_headers) = config_secret_value(config, "OPENAI_CUSTOM_HEADERS") {
|
||||
identity = identity.with_secret("custom_headers", custom_headers);
|
||||
}
|
||||
|
||||
Ok(identity)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ use crate::acp::{
|
|||
use crate::config::search_path::SearchPaths;
|
||||
use crate::config::{Config, GooseMode};
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::acp_tooling::{acp_adapter_installed, acp_inventory_identity};
|
||||
use crate::providers::base::{ProviderDef, ProviderMetadata};
|
||||
use crate::providers::inventory::InventoryIdentityInput;
|
||||
|
||||
const PI_ACP_PROVIDER_NAME: &str = "pi-acp";
|
||||
const PI_ACP_DOC_URL: &str = "https://github.com/anthropics/pi";
|
||||
|
|
@ -36,6 +38,7 @@ impl ProviderDef for PiAcpProvider {
|
|||
"Set in your goose config file (`~/.config/goose/config.yaml` on macOS/Linux):\n GOOSE_PROVIDER: pi-acp\n GOOSE_MODEL: current",
|
||||
"Restart goose for changes to take effect",
|
||||
])
|
||||
.with_model_selection_hint("Use the Pi CLI to configure models")
|
||||
}
|
||||
|
||||
fn from_env(
|
||||
|
|
@ -70,4 +73,16 @@ impl ProviderDef for PiAcpProvider {
|
|||
AcpProvider::connect(metadata.name, model, goose_mode, provider_config).await
|
||||
})
|
||||
}
|
||||
|
||||
fn supports_inventory_refresh() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn inventory_identity() -> Result<InventoryIdentityInput> {
|
||||
acp_inventory_identity(PI_ACP_PROVIDER_NAME, PI_ACP_BINARY)
|
||||
}
|
||||
|
||||
fn inventory_configured() -> bool {
|
||||
acp_adapter_installed(PI_ACP_BINARY)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use super::base::{ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderType};
|
||||
use super::inventory::InventoryIdentityInput;
|
||||
use crate::config::{DeclarativeProviderConfig, ExtensionConfig};
|
||||
use crate::model::ModelConfig;
|
||||
use anyhow::Result;
|
||||
|
|
@ -14,12 +15,20 @@ pub type ProviderConstructor = Arc<
|
|||
|
||||
pub type ProviderCleanup = Arc<dyn Fn() -> BoxFuture<'static, Result<()>> + Send + Sync>;
|
||||
|
||||
pub type ProviderInventoryIdentityResolver =
|
||||
Arc<dyn Fn() -> Result<InventoryIdentityInput> + Send + Sync>;
|
||||
|
||||
pub type ProviderInventoryConfiguredResolver = Arc<dyn Fn() -> bool + Send + Sync>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProviderEntry {
|
||||
metadata: ProviderMetadata,
|
||||
pub(crate) constructor: ProviderConstructor,
|
||||
pub(crate) inventory_identity: ProviderInventoryIdentityResolver,
|
||||
pub(crate) inventory_configured: ProviderInventoryConfiguredResolver,
|
||||
pub(crate) cleanup: Option<ProviderCleanup>,
|
||||
provider_type: ProviderType,
|
||||
supports_inventory_refresh: bool,
|
||||
}
|
||||
|
||||
impl ProviderEntry {
|
||||
|
|
@ -27,6 +36,22 @@ impl ProviderEntry {
|
|||
&self.metadata
|
||||
}
|
||||
|
||||
pub fn provider_type(&self) -> ProviderType {
|
||||
self.provider_type
|
||||
}
|
||||
|
||||
pub fn supports_inventory_refresh(&self) -> bool {
|
||||
self.supports_inventory_refresh
|
||||
}
|
||||
|
||||
pub fn inventory_identity(&self) -> Result<InventoryIdentityInput> {
|
||||
(self.inventory_identity)()
|
||||
}
|
||||
|
||||
pub fn inventory_configured(&self) -> bool {
|
||||
(self.inventory_configured)()
|
||||
}
|
||||
|
||||
fn normalize_model_config(&self, mut model: ModelConfig) -> ModelConfig {
|
||||
model = model.with_canonical_limits(&self.metadata.name);
|
||||
|
||||
|
|
@ -92,24 +117,30 @@ impl ProviderRegistry {
|
|||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
})
|
||||
}),
|
||||
inventory_identity: Arc::new(F::inventory_identity),
|
||||
inventory_configured: Arc::new(F::inventory_configured),
|
||||
cleanup: None,
|
||||
provider_type: if preferred {
|
||||
ProviderType::Preferred
|
||||
} else {
|
||||
ProviderType::Builtin
|
||||
},
|
||||
supports_inventory_refresh: F::supports_inventory_refresh(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn register_with_name<P, F>(
|
||||
pub fn register_with_name<P, F, G>(
|
||||
&mut self,
|
||||
config: &DeclarativeProviderConfig,
|
||||
provider_type: ProviderType,
|
||||
supports_inventory_refresh: bool,
|
||||
constructor: F,
|
||||
inventory_identity: G,
|
||||
) where
|
||||
P: ProviderDef + 'static,
|
||||
F: Fn(ModelConfig) -> Result<P::Provider> + Send + Sync + 'static,
|
||||
G: Fn() -> Result<InventoryIdentityInput> + Send + Sync + 'static,
|
||||
{
|
||||
let base_metadata = P::metadata();
|
||||
let description = config
|
||||
|
|
@ -174,7 +205,9 @@ impl ProviderRegistry {
|
|||
model_doc_link: base_metadata.model_doc_link,
|
||||
config_keys,
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
};
|
||||
let inventory_config_keys = custom_metadata.config_keys.clone();
|
||||
|
||||
self.entries.insert(
|
||||
config.name.clone(),
|
||||
|
|
@ -187,8 +220,16 @@ impl ProviderRegistry {
|
|||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
})
|
||||
}),
|
||||
inventory_identity: Arc::new(inventory_identity),
|
||||
inventory_configured: Arc::new(move || {
|
||||
super::inventory::default_inventory_configured(
|
||||
&inventory_config_keys,
|
||||
crate::config::Config::global(),
|
||||
)
|
||||
}),
|
||||
cleanup: None,
|
||||
provider_type,
|
||||
supports_inventory_refresh,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ use std::sync::{Arc, LazyLock};
|
|||
use tracing::{info, warn};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
pub const CURRENT_SCHEMA_VERSION: i32 = 10;
|
||||
pub const CURRENT_SCHEMA_VERSION: i32 = 11;
|
||||
pub const SESSIONS_FOLDER: &str = "sessions";
|
||||
pub const DB_NAME: &str = "sessions.db";
|
||||
|
||||
|
|
@ -717,6 +717,8 @@ impl SessionStorage {
|
|||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
crate::providers::inventory::create_tables(pool).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -1060,6 +1062,9 @@ impl SessionStorage {
|
|||
.execute(&mut **tx)
|
||||
.await?;
|
||||
}
|
||||
11 => {
|
||||
crate::providers::inventory::create_tables_in_tx(tx).await?;
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!("Unknown migration version: {}", version);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -374,6 +374,7 @@ mod tests {
|
|||
model_doc_link: "".to_string(),
|
||||
config_keys: vec![],
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -542,6 +543,7 @@ mod tests {
|
|||
model_doc_link: "".to_string(),
|
||||
config_keys: vec![],
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -867,6 +869,7 @@ mod tests {
|
|||
model_doc_link: "".to_string(),
|
||||
config_keys: vec![],
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -192,6 +192,7 @@ impl ProviderDef for MockCompactionProvider {
|
|||
model_doc_link: "".to_string(),
|
||||
config_keys: vec![],
|
||||
setup_steps: vec![],
|
||||
model_selection_hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6864,6 +6864,11 @@
|
|||
"type": "string",
|
||||
"description": "Link to the docs where models can be found"
|
||||
},
|
||||
"model_selection_hint": {
|
||||
"type": "string",
|
||||
"description": "Hint shown in the model picker when this provider manages its own model selection.",
|
||||
"nullable": true
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "The unique identifier for this provider"
|
||||
|
|
|
|||
|
|
@ -970,6 +970,10 @@ export type ProviderMetadata = {
|
|||
* Link to the docs where models can be found
|
||||
*/
|
||||
model_doc_link: string;
|
||||
/**
|
||||
* Hint shown in the model picker when this provider manages its own model selection.
|
||||
*/
|
||||
model_selection_hint?: string | null;
|
||||
/**
|
||||
* The unique identifier for this provider
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -36,9 +36,19 @@ const EXCEPTIONS = {
|
|||
"Search-as-you-type filtering and draft-aware sidebar highlight logic.",
|
||||
},
|
||||
"src/app/AppShell.tsx": {
|
||||
limit: 660,
|
||||
limit: 730,
|
||||
justification:
|
||||
"Shell still coordinates ACP session loading, replay-buffer cleanup on load failure, project reassignment, and app-level chat routing. Includes gated [perf:load]/[perf:newtab] logging via perfLog (dev-only by default).",
|
||||
"Shell still coordinates ACP session loading, replay-buffer cleanup on load failure, project reassignment, home-session restoration, app-level chat routing, and restored project-draft reuse. Includes gated [perf:load]/[perf:newtab] logging via perfLog (dev-only by default).",
|
||||
},
|
||||
"src/features/chat/hooks/useChatSessionController.ts": {
|
||||
limit: 690,
|
||||
justification:
|
||||
"Controller now centralizes home-to-chat pending state transfer, workspace/project preparation, provider/model/persona handoff, Goose cross-provider model selection sequencing with rollback, and chat input orchestration pending a later decomposition pass.",
|
||||
},
|
||||
"src/features/chat/ui/AgentModelPicker.tsx": {
|
||||
limit: 570,
|
||||
justification:
|
||||
"Agent-first picker currently keeps the full trigger, recommended-model view, searchable full-model view, and ACP/goose-specific labeling logic in one component pending later extraction.",
|
||||
},
|
||||
"src/features/chat/stores/__tests__/chatSessionStore.test.ts": {
|
||||
limit: 540,
|
||||
|
|
@ -56,7 +66,7 @@ const EXCEPTIONS = {
|
|||
"Voice dictation send/stop guards, attachment handling, and mention/picker coordination still share one chat composer component.",
|
||||
},
|
||||
"src/features/chat/ui/__tests__/ChatInput.test.tsx": {
|
||||
limit: 510,
|
||||
limit: 520,
|
||||
justification:
|
||||
"Composer regression coverage spans personas, queueing, attachments, and voice-input edge cases in one interaction-heavy suite.",
|
||||
},
|
||||
|
|
|
|||
33
ui/goose2/scripts/reset-inventory.sh
Executable file
33
ui/goose2/scripts/reset-inventory.sh
Executable file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env bash
|
||||
# Reset the provider inventory tables to empty, as if migration 12 just ran.
|
||||
# This lets you test the first-use experience (cold inventory).
|
||||
#
|
||||
# Usage: ./scripts/reset-inventory.sh
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
DB="${GOOSE_DB:-$HOME/.local/share/goose/sessions/sessions.db}"
|
||||
|
||||
if [ ! -f "$DB" ]; then
|
||||
echo "Database not found at $DB"
|
||||
echo "Set GOOSE_DB to override the path."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Database: $DB"
|
||||
echo ""
|
||||
echo "Before:"
|
||||
echo " provider_inventory_entries: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_entries;')"
|
||||
echo " provider_inventory_models: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_models;')"
|
||||
|
||||
# ON DELETE CASCADE on provider_inventory_models means deleting entries clears both tables.
|
||||
# Delete models first since CASCADE isn't reliable in all sqlite3 builds,
|
||||
# then delete entries.
|
||||
sqlite3 "$DB" "DELETE FROM provider_inventory_models; DELETE FROM provider_inventory_entries;"
|
||||
|
||||
echo ""
|
||||
echo "After:"
|
||||
echo " provider_inventory_entries: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_entries;')"
|
||||
echo " provider_inventory_models: $(sqlite3 "$DB" 'SELECT COUNT(*) FROM provider_inventory_models;')"
|
||||
echo ""
|
||||
echo "Inventory tables are empty. Restart goose to test first-use flow."
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { Sidebar } from "@/features/sidebar/ui/Sidebar";
|
||||
import { StatusBar } from "@/features/status/ui/StatusBar";
|
||||
import type { ChatAttachmentDraft } from "@/shared/types/messages";
|
||||
import { CreateProjectDialog } from "@/features/projects/ui/CreateProjectDialog";
|
||||
import { archiveProject } from "@/features/projects/api/projects";
|
||||
import type { ProjectInfo } from "@/features/projects/api/projects";
|
||||
|
|
@ -9,7 +8,11 @@ import { SettingsModal } from "@/features/settings/ui/SettingsModal";
|
|||
import type { SectionId } from "@/features/settings/ui/SettingsModal";
|
||||
import { TopBar } from "./ui/TopBar";
|
||||
import { useChatStore } from "@/features/chat/stores/chatStore";
|
||||
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
|
||||
import {
|
||||
type ChatSession,
|
||||
hasSessionStarted,
|
||||
useChatSessionStore,
|
||||
} from "@/features/chat/stores/chatSessionStore";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
import { findExistingDraft } from "@/features/chat/lib/newChat";
|
||||
|
|
@ -37,6 +40,33 @@ const SIDEBAR_MIN_WIDTH = 180;
|
|||
const SIDEBAR_MAX_WIDTH = 380;
|
||||
const SIDEBAR_SNAP_COLLAPSE_THRESHOLD = 100;
|
||||
const SIDEBAR_COLLAPSED_WIDTH = 48;
|
||||
const HOME_SESSION_STORAGE_KEY = "goose:home-session-id";
|
||||
|
||||
function loadStoredHomeSessionId(): string | null {
|
||||
if (typeof window === "undefined") {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return window.localStorage.getItem(HOME_SESSION_STORAGE_KEY);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function persistHomeSessionId(sessionId: string | null): void {
|
||||
if (typeof window === "undefined") {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (sessionId) {
|
||||
window.localStorage.setItem(HOME_SESSION_STORAGE_KEY, sessionId);
|
||||
return;
|
||||
}
|
||||
window.localStorage.removeItem(HOME_SESSION_STORAGE_KEY);
|
||||
} catch {
|
||||
// localStorage may be unavailable
|
||||
}
|
||||
}
|
||||
|
||||
export function AppShell({ children }: { children?: React.ReactNode }) {
|
||||
const [sidebarCollapsed, setSidebarCollapsed] = useState(false);
|
||||
|
|
@ -52,9 +82,9 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
null,
|
||||
);
|
||||
const [activeView, setActiveView] = useState<AppView>("home");
|
||||
const [homeSelectedProvider, setHomeSelectedProvider] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
const [homeSessionId, setHomeSessionId] = useState<string | null>(() =>
|
||||
loadStoredHomeSessionId(),
|
||||
);
|
||||
|
||||
const chatStore = useChatStore();
|
||||
const sessionStore = useChatSessionStore();
|
||||
|
|
@ -64,6 +94,9 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
const pendingProjectCreatedRef = useRef<((projectId: string) => void) | null>(
|
||||
null,
|
||||
);
|
||||
const homeSessionRequestRef = useRef<Promise<ChatSession | null> | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const loadSessionMessages = useCallback(async (sessionId: string) => {
|
||||
const sid = sessionId.slice(0, 8);
|
||||
|
|
@ -125,70 +158,146 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
}
|
||||
}, [activeSessionId, activeView]);
|
||||
|
||||
const isHome = activeSessionId === null && activeView === "home";
|
||||
const isHome = activeView === "home";
|
||||
|
||||
const activeSession = activeSessionId
|
||||
? sessionStore.getSession(activeSessionId)
|
||||
: undefined;
|
||||
const modelName = activeSession?.modelName;
|
||||
const tokenCount = activeSessionId
|
||||
? chatStore.getSessionRuntime(activeSessionId).tokenState.totalTokens
|
||||
: 0;
|
||||
const modelName =
|
||||
activeView === "chat" ? activeSession?.modelName : undefined;
|
||||
const tokenCount =
|
||||
activeView === "chat" && activeSessionId
|
||||
? chatStore.getSessionRuntime(activeSessionId).tokenState.totalTokens
|
||||
: 0;
|
||||
const homeSession = homeSessionId
|
||||
? sessionStore.getSession(homeSessionId)
|
||||
: undefined;
|
||||
|
||||
const [pendingInitialMessage, setPendingInitialMessage] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
const [pendingInitialAttachments, setPendingInitialAttachments] = useState<
|
||||
ChatAttachmentDraft[] | undefined
|
||||
>();
|
||||
const [homeSelectedPersonaId, setHomeSelectedPersonaId] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
useEffect(() => {
|
||||
if (
|
||||
!homeSessionId ||
|
||||
!sessionStore.hasHydratedSessions ||
|
||||
sessionStore.isLoading
|
||||
) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
!homeSession ||
|
||||
homeSession.archivedAt ||
|
||||
hasSessionStarted(
|
||||
homeSession,
|
||||
chatStore.messagesBySession[homeSession.id],
|
||||
)
|
||||
) {
|
||||
setHomeSessionId(null);
|
||||
}
|
||||
}, [
|
||||
chatStore.messagesBySession,
|
||||
homeSession,
|
||||
homeSession?.archivedAt,
|
||||
homeSession?.messageCount,
|
||||
homeSessionId,
|
||||
sessionStore.hasHydratedSessions,
|
||||
sessionStore.isLoading,
|
||||
]);
|
||||
|
||||
const cleanupEmptyDraft = useCallback(
|
||||
(sessionId: string | null) => {
|
||||
if (!sessionId) return;
|
||||
const state = useChatSessionStore.getState();
|
||||
const session = state.sessions.find((s) => s.id === sessionId);
|
||||
if (!session?.draft) return;
|
||||
const draft = useChatStore.getState().draftsBySession[sessionId] ?? "";
|
||||
if (draft.length > 0) return; // has typed text — keep it
|
||||
chatStore.cleanupSession(sessionId);
|
||||
state.removeDraft(sessionId);
|
||||
},
|
||||
[chatStore],
|
||||
);
|
||||
useEffect(() => {
|
||||
persistHomeSessionId(homeSessionId);
|
||||
}, [homeSessionId]);
|
||||
|
||||
const ensureHomeSession = useCallback(async () => {
|
||||
if (!sessionStore.hasHydratedSessions || sessionStore.isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (homeSessionRequestRef.current) {
|
||||
return homeSessionRequestRef.current;
|
||||
}
|
||||
|
||||
const request = (async () => {
|
||||
if (
|
||||
homeSession &&
|
||||
!homeSession.archivedAt &&
|
||||
homeSession.messageCount === 0
|
||||
) {
|
||||
const project = homeSession.projectId
|
||||
? (projectStore.projects.find(
|
||||
(candidate) => candidate.id === homeSession.projectId,
|
||||
) ?? null)
|
||||
: null;
|
||||
const workingDir = await resolveSessionCwd(project);
|
||||
await acpPrepareSession(
|
||||
homeSession.id,
|
||||
homeSession.providerId ?? agentStore.selectedProvider ?? "goose",
|
||||
workingDir,
|
||||
{
|
||||
personaId: homeSession.personaId,
|
||||
},
|
||||
);
|
||||
return homeSession;
|
||||
}
|
||||
|
||||
const workingDir = await resolveSessionCwd(null);
|
||||
const session = await sessionStore.createSession({
|
||||
title: DEFAULT_CHAT_TITLE,
|
||||
providerId: agentStore.selectedProvider ?? "goose",
|
||||
workingDir,
|
||||
});
|
||||
setHomeSessionId(session.id);
|
||||
return session;
|
||||
})();
|
||||
|
||||
homeSessionRequestRef.current = request;
|
||||
try {
|
||||
return await request;
|
||||
} finally {
|
||||
if (homeSessionRequestRef.current === request) {
|
||||
homeSessionRequestRef.current = null;
|
||||
}
|
||||
}
|
||||
}, [
|
||||
agentStore.selectedProvider,
|
||||
homeSession,
|
||||
projectStore.projects,
|
||||
sessionStore.hasHydratedSessions,
|
||||
sessionStore,
|
||||
sessionStore.isLoading,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (activeView !== "home") {
|
||||
return;
|
||||
}
|
||||
void ensureHomeSession().catch((error) => {
|
||||
console.error("Failed to ensure Home session:", error);
|
||||
});
|
||||
}, [activeView, ensureHomeSession]);
|
||||
|
||||
const createNewTab = useCallback(
|
||||
(title = DEFAULT_CHAT_TITLE, project?: ProjectInfo) => {
|
||||
async (title = DEFAULT_CHAT_TITLE, project?: ProjectInfo) => {
|
||||
const tStart = performance.now();
|
||||
perfLog(
|
||||
`[perf:newtab] createNewTab start (project=${project?.id ?? "none"})`,
|
||||
);
|
||||
const agentId = agentStore.activeAgentId ?? undefined;
|
||||
const providerId = project?.preferredProvider ?? homeSelectedProvider;
|
||||
const personaId = homeSelectedPersonaId;
|
||||
const providerId =
|
||||
project?.preferredProvider ?? agentStore.selectedProvider ?? "goose";
|
||||
const modelId = project?.preferredModel ?? undefined;
|
||||
const sessionState = useChatSessionStore.getState();
|
||||
const chatStoreState = useChatStore.getState();
|
||||
const chatState = useChatStore.getState();
|
||||
const existingDraft = findExistingDraft({
|
||||
sessions: sessionState.sessions,
|
||||
activeSessionId: sessionState.activeSessionId,
|
||||
draftsBySession: chatStoreState.draftsBySession,
|
||||
messagesBySession: chatStoreState.messagesBySession,
|
||||
draftsBySession: chatState.draftsBySession,
|
||||
messagesBySession: chatState.messagesBySession,
|
||||
request: {
|
||||
title,
|
||||
projectId: project?.id,
|
||||
agentId,
|
||||
providerId,
|
||||
personaId,
|
||||
},
|
||||
});
|
||||
|
||||
if (existingDraft) {
|
||||
if (sessionState.activeSessionId !== existingDraft.id) {
|
||||
cleanupEmptyDraft(sessionState.activeSessionId);
|
||||
}
|
||||
sessionState.setActiveSession(existingDraft.id);
|
||||
sessionStore.setActiveSession(existingDraft.id);
|
||||
setActiveView("chat");
|
||||
chatStore.setActiveSession(existingDraft.id);
|
||||
perfLog(
|
||||
|
|
@ -196,46 +305,45 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
);
|
||||
return existingDraft;
|
||||
}
|
||||
cleanupEmptyDraft(sessionState.activeSessionId);
|
||||
const session = sessionStore.createDraftSession({
|
||||
|
||||
const workingDir = await resolveSessionCwd(project);
|
||||
const session = await sessionStore.createSession({
|
||||
title,
|
||||
projectId: project?.id,
|
||||
agentId,
|
||||
providerId,
|
||||
personaId,
|
||||
workingDir,
|
||||
modelId,
|
||||
modelName: modelId,
|
||||
});
|
||||
sessionStore.setActiveSession(session.id);
|
||||
setActiveView("chat");
|
||||
chatStore.setActiveSession(session.id);
|
||||
perfLog(
|
||||
`[perf:newtab] ${session.id.slice(0, 8)} created draft in ${(performance.now() - tStart).toFixed(1)}ms`,
|
||||
`[perf:newtab] ${session.id.slice(0, 8)} created session in ${(performance.now() - tStart).toFixed(1)}ms`,
|
||||
);
|
||||
return session;
|
||||
},
|
||||
[
|
||||
agentStore.activeAgentId,
|
||||
agentStore.selectedProvider,
|
||||
chatStore,
|
||||
sessionStore,
|
||||
agentStore.activeAgentId,
|
||||
homeSelectedPersonaId,
|
||||
homeSelectedProvider,
|
||||
cleanupEmptyDraft,
|
||||
],
|
||||
);
|
||||
|
||||
const handleStartChatFromProject = useCallback(
|
||||
(project: ProjectInfo) => {
|
||||
setHomeSelectedProvider(undefined);
|
||||
createNewTab(DEFAULT_CHAT_TITLE, project);
|
||||
void createNewTab(DEFAULT_CHAT_TITLE, project);
|
||||
},
|
||||
[createNewTab],
|
||||
);
|
||||
|
||||
const handleNewChatInProject = useCallback(
|
||||
(projectId: string) => {
|
||||
setHomeSelectedProvider(undefined);
|
||||
const project = projectStore.projects.find((p) => p.id === projectId);
|
||||
if (project) {
|
||||
createNewTab(DEFAULT_CHAT_TITLE, project);
|
||||
void createNewTab(DEFAULT_CHAT_TITLE, project);
|
||||
}
|
||||
},
|
||||
[createNewTab, projectStore.projects],
|
||||
|
|
@ -255,12 +363,11 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
|
||||
const clearActiveSession = useCallback(
|
||||
(sessionId: string) => {
|
||||
cleanupEmptyDraft(sessionId);
|
||||
chatStore.cleanupSession(sessionId);
|
||||
sessionStore.setActiveSession(null);
|
||||
setActiveView("home");
|
||||
},
|
||||
[chatStore, sessionStore, cleanupEmptyDraft],
|
||||
[chatStore, sessionStore],
|
||||
);
|
||||
const openSettings = useCallback((section: SectionId = "appearance") => {
|
||||
setSettingsInitialSection(section);
|
||||
|
|
@ -306,7 +413,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
sessionStore.updateSession(sessionId, { projectId });
|
||||
|
||||
const session = useChatSessionStore.getState().getSession(sessionId);
|
||||
if (!session || session.draft) {
|
||||
if (!session) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -362,41 +469,28 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
[],
|
||||
);
|
||||
|
||||
const handleHomeStartChat = useCallback(
|
||||
(
|
||||
initialMessage?: string,
|
||||
providerId?: string,
|
||||
personaId?: string,
|
||||
projectId?: string | null,
|
||||
attachments?: ChatAttachmentDraft[],
|
||||
) => {
|
||||
setHomeSelectedProvider(providerId);
|
||||
setHomeSelectedPersonaId(personaId);
|
||||
setPendingInitialMessage(initialMessage);
|
||||
setPendingInitialAttachments(attachments);
|
||||
const selectedProject =
|
||||
projectId != null
|
||||
? projectStore.projects.find((project) => project.id === projectId)
|
||||
: undefined;
|
||||
|
||||
createNewTab(
|
||||
initialMessage?.slice(0, 40) || DEFAULT_CHAT_TITLE,
|
||||
selectedProject,
|
||||
);
|
||||
const activateHomeSession = useCallback(
|
||||
(sessionId: string) => {
|
||||
if (homeSessionId === sessionId) {
|
||||
setHomeSessionId(null);
|
||||
}
|
||||
sessionStore.setActiveSession(sessionId);
|
||||
setActiveView("chat");
|
||||
chatStore.setActiveSession(sessionId);
|
||||
useChatStore.getState().markSessionRead(sessionId);
|
||||
},
|
||||
[createNewTab, projectStore.projects],
|
||||
[chatStore, homeSessionId, sessionStore],
|
||||
);
|
||||
|
||||
const handleSelectSession = useCallback(
|
||||
(id: string) => {
|
||||
cleanupEmptyDraft(useChatSessionStore.getState().activeSessionId);
|
||||
sessionStore.setActiveSession(id);
|
||||
setActiveView("chat");
|
||||
chatStore.setActiveSession(id);
|
||||
useChatStore.getState().markSessionRead(id);
|
||||
loadSessionMessages(id);
|
||||
},
|
||||
[sessionStore, chatStore, loadSessionMessages, cleanupEmptyDraft],
|
||||
[sessionStore, chatStore, loadSessionMessages],
|
||||
);
|
||||
|
||||
const handleSelectSearchResult = useCallback(
|
||||
|
|
@ -414,12 +508,11 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
const handleNavigate = useCallback(
|
||||
(view: AppView) => {
|
||||
if (view !== "chat") {
|
||||
cleanupEmptyDraft(useChatSessionStore.getState().activeSessionId);
|
||||
sessionStore.setActiveSession(null);
|
||||
}
|
||||
setActiveView(view);
|
||||
},
|
||||
[sessionStore, cleanupEmptyDraft],
|
||||
[sessionStore],
|
||||
);
|
||||
|
||||
const toggleSidebar = () => setSidebarCollapsed((prev) => !prev);
|
||||
|
|
@ -496,20 +589,13 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
// Cmd+N opens new conversation screen
|
||||
if (e.key === "n" && e.metaKey) {
|
||||
e.preventDefault();
|
||||
createNewTab();
|
||||
sessionStore.setActiveSession(null);
|
||||
setActiveView("home");
|
||||
}
|
||||
};
|
||||
window.addEventListener("keydown", handler);
|
||||
return () => window.removeEventListener("keydown", handler);
|
||||
}, [clearActiveSession, createNewTab]);
|
||||
|
||||
const activeSessionPersonaId = activeSession?.personaId;
|
||||
const handleInitialMessageConsumed = useCallback(() => {
|
||||
setPendingInitialMessage(undefined);
|
||||
setPendingInitialAttachments(undefined);
|
||||
setHomeSelectedProvider(undefined);
|
||||
setHomeSelectedPersonaId(undefined);
|
||||
}, []);
|
||||
}, [clearActiveSession, sessionStore]);
|
||||
|
||||
const editingProjectProp = useMemo(
|
||||
() =>
|
||||
|
|
@ -551,7 +637,10 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
onCollapse={toggleSidebar}
|
||||
onNavigate={handleNavigate}
|
||||
onNewChatInProject={handleNewChatInProject}
|
||||
onNewChat={() => createNewTab()}
|
||||
onNewChat={() => {
|
||||
sessionStore.setActiveSession(null);
|
||||
setActiveView("home");
|
||||
}}
|
||||
onCreateProject={() => openCreateProjectDialog()}
|
||||
onEditProject={handleEditProject}
|
||||
onArchiveProject={handleArchiveProject}
|
||||
|
|
@ -582,15 +671,10 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
<AppShellContent
|
||||
activeView={activeView}
|
||||
activeSession={activeSession}
|
||||
activeSessionPersonaId={activeSessionPersonaId}
|
||||
homeSelectedProvider={homeSelectedProvider}
|
||||
homeSelectedPersonaId={homeSelectedPersonaId}
|
||||
pendingInitialMessage={pendingInitialMessage}
|
||||
pendingInitialAttachments={pendingInitialAttachments}
|
||||
homeSessionId={homeSessionId}
|
||||
onArchiveChat={handleArchiveChat}
|
||||
onCreateProject={openCreateProjectDialog}
|
||||
onHomeStartChat={handleHomeStartChat}
|
||||
onInitialMessageConsumed={handleInitialMessageConsumed}
|
||||
onActivateHomeSession={activateHomeSession}
|
||||
onRenameChat={handleRenameChat}
|
||||
onSelectSession={handleSelectSession}
|
||||
onSelectSearchResult={handleSelectSearchResult}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,17 @@
|
|||
import { useEffect } from "react";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
|
||||
import { useProviderInventoryStore } from "@/features/providers/stores/providerInventoryStore";
|
||||
import { setNotificationHandler, getClient } from "@/shared/api/acpConnection";
|
||||
import notificationHandler from "@/shared/api/acpNotificationHandler";
|
||||
import { perfLog } from "@/shared/lib/perfLog";
|
||||
|
||||
const INVENTORY_POLL_DELAYS_MS = [250, 500, 750, 1000, 1500, 2000];
|
||||
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => window.setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
export function useAppStartup() {
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
|
|
@ -22,6 +29,7 @@ export function useAppStartup() {
|
|||
}
|
||||
|
||||
const store = useAgentStore.getState();
|
||||
const inventoryStore = useProviderInventoryStore.getState();
|
||||
const loadPersonas = async () => {
|
||||
const t0 = performance.now();
|
||||
store.setPersonasLoading(true);
|
||||
|
|
@ -56,6 +64,76 @@ export function useAppStartup() {
|
|||
}
|
||||
};
|
||||
|
||||
const loadProviderInventory = async () => {
|
||||
const t0 = performance.now();
|
||||
inventoryStore.setLoading(true);
|
||||
try {
|
||||
const { getProviderInventory } = await import(
|
||||
"@/features/providers/api/inventory"
|
||||
);
|
||||
const entries = await getProviderInventory();
|
||||
inventoryStore.setEntries(entries);
|
||||
perfLog(
|
||||
`[perf:startup] loadProviderInventory done in ${(performance.now() - t0).toFixed(1)}ms (n=${entries.length})`,
|
||||
);
|
||||
return entries;
|
||||
} catch (err) {
|
||||
console.error("Failed to load provider inventory on startup:", err);
|
||||
return [];
|
||||
} finally {
|
||||
inventoryStore.setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const refreshConfiguredProviderInventory = async (
|
||||
initialEntries?: Awaited<ReturnType<typeof loadProviderInventory>>,
|
||||
) => {
|
||||
try {
|
||||
const entries =
|
||||
initialEntries && initialEntries.length > 0
|
||||
? initialEntries
|
||||
: await (async () => {
|
||||
const { getProviderInventory } = await import(
|
||||
"@/features/providers/api/inventory"
|
||||
);
|
||||
return getProviderInventory();
|
||||
})();
|
||||
const configuredProviderIds = entries
|
||||
.filter((entry) => entry.configured)
|
||||
.map((entry) => entry.providerId);
|
||||
if (configuredProviderIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { getProviderInventory, refreshProviderInventory } =
|
||||
await import("@/features/providers/api/inventory");
|
||||
const refresh = await refreshProviderInventory(configuredProviderIds);
|
||||
if (refresh.started.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
inventoryStore.mergeEntries(
|
||||
await getProviderInventory(refresh.started),
|
||||
);
|
||||
|
||||
for (const delayMs of INVENTORY_POLL_DELAYS_MS) {
|
||||
await sleep(delayMs);
|
||||
const refreshedEntries = await getProviderInventory(
|
||||
refresh.started,
|
||||
);
|
||||
inventoryStore.mergeEntries(refreshedEntries);
|
||||
if (refreshedEntries.every((entry) => !entry.refreshing)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(
|
||||
"Failed to refresh provider inventory on startup:",
|
||||
err,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const loadSessionState = async () => {
|
||||
const t0 = performance.now();
|
||||
perfLog("[perf:startup] loadSessionState start");
|
||||
|
|
@ -68,11 +146,17 @@ export function useAppStartup() {
|
|||
setActiveSession(null);
|
||||
};
|
||||
|
||||
const inventoryLoad = loadProviderInventory();
|
||||
|
||||
await Promise.allSettled([
|
||||
loadPersonas(),
|
||||
loadProviders(),
|
||||
inventoryLoad,
|
||||
loadSessionState(),
|
||||
]);
|
||||
void inventoryLoad.then((entries) =>
|
||||
refreshConfiguredProviderInventory(entries),
|
||||
);
|
||||
perfLog(
|
||||
`[perf:startup] useAppStartup complete in ${(performance.now() - tStartup).toFixed(1)}ms`,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -5,31 +5,19 @@ import { AgentsView } from "@/features/agents/ui/AgentsView";
|
|||
import { ProjectsView } from "@/features/projects/ui/ProjectsView";
|
||||
import { SessionHistoryView } from "@/features/sessions/ui/SessionHistoryView";
|
||||
import type { ChatSession } from "@/features/chat/stores/chatSessionStore";
|
||||
import type { ChatAttachmentDraft } from "@/shared/types/messages";
|
||||
import type { ProjectInfo } from "@/features/projects/api/projects";
|
||||
import type { AppView } from "../AppShell";
|
||||
|
||||
interface AppShellContentProps {
|
||||
activeView: AppView;
|
||||
activeSession?: ChatSession;
|
||||
activeSessionPersonaId?: string;
|
||||
homeSelectedProvider?: string;
|
||||
homeSelectedPersonaId?: string;
|
||||
pendingInitialMessage?: string;
|
||||
pendingInitialAttachments?: ChatAttachmentDraft[];
|
||||
homeSessionId: string | null;
|
||||
onArchiveChat: (sessionId: string) => Promise<void>;
|
||||
onCreateProject: (options?: {
|
||||
initialWorkingDir?: string | null;
|
||||
onCreated?: (projectId: string) => void;
|
||||
}) => void;
|
||||
onHomeStartChat: (
|
||||
initialMessage?: string,
|
||||
providerId?: string,
|
||||
personaId?: string,
|
||||
projectId?: string | null,
|
||||
attachments?: ChatAttachmentDraft[],
|
||||
) => void;
|
||||
onInitialMessageConsumed: () => void;
|
||||
onActivateHomeSession: (sessionId: string) => void;
|
||||
onRenameChat: (sessionId: string, nextTitle: string) => void;
|
||||
onSelectSession: (sessionId: string) => void;
|
||||
onSelectSearchResult: (
|
||||
|
|
@ -43,15 +31,10 @@ interface AppShellContentProps {
|
|||
export function AppShellContent({
|
||||
activeView,
|
||||
activeSession,
|
||||
activeSessionPersonaId,
|
||||
homeSelectedProvider,
|
||||
homeSelectedPersonaId,
|
||||
pendingInitialMessage,
|
||||
pendingInitialAttachments,
|
||||
homeSessionId,
|
||||
onArchiveChat,
|
||||
onCreateProject,
|
||||
onHomeStartChat,
|
||||
onInitialMessageConsumed,
|
||||
onActivateHomeSession,
|
||||
onRenameChat,
|
||||
onSelectSession,
|
||||
onSelectSearchResult,
|
||||
|
|
@ -74,21 +57,24 @@ export function AppShellContent({
|
|||
/>
|
||||
);
|
||||
case "chat":
|
||||
case "home":
|
||||
return activeSession ? (
|
||||
<ChatView
|
||||
key={activeSession.id}
|
||||
sessionId={activeSession.id}
|
||||
initialProvider={homeSelectedProvider}
|
||||
initialPersonaId={activeSessionPersonaId ?? homeSelectedPersonaId}
|
||||
initialMessage={pendingInitialMessage}
|
||||
initialAttachments={pendingInitialAttachments}
|
||||
onCreateProject={onCreateProject}
|
||||
onInitialMessageConsumed={onInitialMessageConsumed}
|
||||
/>
|
||||
) : (
|
||||
<HomeScreen
|
||||
onStartChat={onHomeStartChat}
|
||||
sessionId={homeSessionId}
|
||||
onActivateSession={onActivateHomeSession}
|
||||
onCreateProject={onCreateProject}
|
||||
/>
|
||||
);
|
||||
case "home":
|
||||
return (
|
||||
<HomeScreen
|
||||
sessionId={homeSessionId}
|
||||
onActivateSession={onActivateHomeSession}
|
||||
onCreateProject={onCreateProject}
|
||||
/>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,125 @@
|
|||
import { act, renderHook } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { useAgentModelPickerState } from "../useAgentModelPickerState";
|
||||
|
||||
const mockUseProviderInventory = vi.fn();
|
||||
|
||||
vi.mock("@/features/providers/hooks/useProviderInventory", () => ({
|
||||
useProviderInventory: () => mockUseProviderInventory(),
|
||||
}));
|
||||
|
||||
describe("useAgentModelPickerState", () => {
|
||||
it("switches to goose when the current provider is goose-backed", () => {
|
||||
const onProviderSelected = vi.fn();
|
||||
|
||||
mockUseProviderInventory.mockReturnValue({
|
||||
entries: new Map([
|
||||
[
|
||||
"anthropic",
|
||||
{
|
||||
providerId: "anthropic",
|
||||
configured: true,
|
||||
refreshing: false,
|
||||
models: [],
|
||||
},
|
||||
],
|
||||
]),
|
||||
getEntry: (providerId: string) =>
|
||||
providerId === "anthropic"
|
||||
? {
|
||||
providerId: "anthropic",
|
||||
configured: true,
|
||||
refreshing: false,
|
||||
models: [],
|
||||
}
|
||||
: undefined,
|
||||
configuredModelProviderEntries: [],
|
||||
getModelsForAgent: () => [],
|
||||
loading: false,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useAgentModelPickerState({
|
||||
providers: [{ id: "anthropic", label: "Anthropic" }],
|
||||
selectedProvider: "anthropic",
|
||||
onProviderSelected,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleProviderChange("goose");
|
||||
});
|
||||
|
||||
expect(onProviderSelected).toHaveBeenCalledWith("goose");
|
||||
});
|
||||
|
||||
it("treats goose as a no-op only when goose is already selected", () => {
|
||||
const onProviderSelected = vi.fn();
|
||||
|
||||
mockUseProviderInventory.mockReturnValue({
|
||||
entries: new Map(),
|
||||
getEntry: () => undefined,
|
||||
configuredModelProviderEntries: [],
|
||||
getModelsForAgent: () => [],
|
||||
loading: false,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useAgentModelPickerState({
|
||||
providers: [],
|
||||
selectedProvider: "goose",
|
||||
onProviderSelected,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleProviderChange("goose");
|
||||
});
|
||||
|
||||
expect(onProviderSelected).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("passes the selected model provider through for goose model picks", () => {
|
||||
const onModelSelected = vi.fn();
|
||||
|
||||
mockUseProviderInventory.mockReturnValue({
|
||||
entries: new Map(),
|
||||
getEntry: () => undefined,
|
||||
configuredModelProviderEntries: [],
|
||||
getModelsForAgent: () => [
|
||||
{
|
||||
id: "claude-sonnet-4",
|
||||
name: "claude-sonnet-4",
|
||||
displayName: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
providerName: "Anthropic",
|
||||
recommended: true,
|
||||
},
|
||||
],
|
||||
loading: false,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useAgentModelPickerState({
|
||||
providers: [{ id: "goose", label: "Goose" }],
|
||||
selectedProvider: "openai",
|
||||
onProviderSelected: vi.fn(),
|
||||
onModelSelected,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleModelChange("claude-sonnet-4");
|
||||
});
|
||||
|
||||
expect(onModelSelected).toHaveBeenCalledWith({
|
||||
id: "claude-sonnet-4",
|
||||
name: "claude-sonnet-4",
|
||||
displayName: "Claude Sonnet 4",
|
||||
provider: undefined,
|
||||
providerId: "anthropic",
|
||||
providerName: "Anthropic",
|
||||
recommended: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -33,8 +33,6 @@ describe("useChat attachments", () => {
|
|||
isLoading: false,
|
||||
contextPanelOpenBySession: {},
|
||||
activeWorkspaceBySession: {},
|
||||
modelsBySession: {},
|
||||
modelCacheByProvider: {},
|
||||
});
|
||||
useAgentStore.setState({
|
||||
personas: [],
|
||||
|
|
|
|||
|
|
@ -81,8 +81,6 @@ describe("useChat", () => {
|
|||
isLoading: false,
|
||||
contextPanelOpenBySession: {},
|
||||
activeWorkspaceBySession: {},
|
||||
modelsBySession: {},
|
||||
modelCacheByProvider: {},
|
||||
});
|
||||
useAgentStore.setState({
|
||||
personas: [
|
||||
|
|
@ -354,7 +352,7 @@ describe("useChat", () => {
|
|||
});
|
||||
});
|
||||
|
||||
it("prepares draft sessions before applying a selected model on first send", async () => {
|
||||
it("sends messages without an extra session preparation step", async () => {
|
||||
useChatSessionStore.setState({
|
||||
sessions: [
|
||||
{
|
||||
|
|
@ -366,28 +364,16 @@ describe("useChat", () => {
|
|||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
messageCount: 0,
|
||||
draft: true,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useChat("session-1", "openai", undefined, undefined, async () => "/tmp"),
|
||||
);
|
||||
const { result } = renderHook(() => useChat("session-1", "openai"));
|
||||
|
||||
await act(async () => {
|
||||
await result.current.sendMessage("Hello");
|
||||
});
|
||||
|
||||
expect(mockAcpPrepareSession).toHaveBeenCalledWith(
|
||||
"session-1",
|
||||
"openai",
|
||||
"/tmp",
|
||||
{
|
||||
personaId: undefined,
|
||||
},
|
||||
);
|
||||
expect(mockAcpSetModel).toHaveBeenCalledWith("session-1", "gpt-4.1");
|
||||
expect(mockAcpSendMessage).toHaveBeenCalledWith("session-1", "Hello", {
|
||||
systemPrompt: undefined,
|
||||
personaId: undefined,
|
||||
|
|
@ -396,6 +382,50 @@ describe("useChat", () => {
|
|||
});
|
||||
});
|
||||
|
||||
it("fires onMessageAccepted only after the message enters the session", async () => {
|
||||
const onMessageAccepted = vi.fn();
|
||||
const deferred = createDeferredPromise();
|
||||
mockAcpSendMessage.mockReturnValue(deferred.promise);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useChat("session-1", undefined, undefined, undefined, {
|
||||
onMessageAccepted,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
const sendPromise = result.current.sendMessage("Hello");
|
||||
await Promise.resolve();
|
||||
|
||||
expect(onMessageAccepted).toHaveBeenCalledTimes(1);
|
||||
expect(
|
||||
useChatStore.getState().messagesBySession["session-1"],
|
||||
).toHaveLength(1);
|
||||
|
||||
deferred.resolve();
|
||||
await sendPromise;
|
||||
});
|
||||
});
|
||||
|
||||
it("awaits ensurePrepared before prompting", async () => {
|
||||
const ensurePrepared = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useChat("session-1", undefined, undefined, undefined, {
|
||||
ensurePrepared,
|
||||
}),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.sendMessage("Hello");
|
||||
});
|
||||
|
||||
expect(ensurePrepared).toHaveBeenCalledTimes(1);
|
||||
expect(ensurePrepared.mock.invocationCallOrder[0]).toBeLessThan(
|
||||
mockAcpSendMessage.mock.invocationCallOrder[0],
|
||||
);
|
||||
});
|
||||
|
||||
it("appends an error message and removes the empty assistant placeholder when send fails", async () => {
|
||||
mockAcpSendMessage.mockRejectedValue(
|
||||
new Error("Working directory missing"),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,181 @@
|
|||
import { act, renderHook, waitFor } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
import { useChatStore } from "../../stores/chatStore";
|
||||
import { useChatSessionStore } from "../../stores/chatSessionStore";
|
||||
|
||||
const mockAcpPrepareSession = vi.fn();
|
||||
const mockAcpSetModel = vi.fn();
|
||||
const mockSetSelectedProvider = vi.fn();
|
||||
const mockResolveSessionCwd = vi.fn();
|
||||
|
||||
vi.mock("@/shared/api/acp", () => ({
|
||||
acpPrepareSession: (...args: unknown[]) => mockAcpPrepareSession(...args),
|
||||
acpSetModel: (...args: unknown[]) => mockAcpSetModel(...args),
|
||||
}));
|
||||
|
||||
vi.mock("../useChat", () => ({
|
||||
useChat: () => ({
|
||||
messages: [],
|
||||
chatState: "idle",
|
||||
tokenState: null,
|
||||
sendMessage: vi.fn(),
|
||||
stopStreaming: vi.fn(),
|
||||
streamingMessageId: null,
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("../useMessageQueue", () => ({
|
||||
useMessageQueue: () => ({
|
||||
queuedMessage: null,
|
||||
enqueue: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("@/features/agents/hooks/useProviderSelection", () => ({
|
||||
useProviderSelection: () => ({
|
||||
providers: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "openai", label: "OpenAI" },
|
||||
{ id: "anthropic", label: "Anthropic" },
|
||||
],
|
||||
providersLoading: false,
|
||||
selectedProvider: "openai",
|
||||
setSelectedProvider: (...args: unknown[]) =>
|
||||
mockSetSelectedProvider(...args),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("@/features/projects/lib/sessionCwdSelection", () => ({
|
||||
resolveSessionCwd: (...args: unknown[]) => mockResolveSessionCwd(...args),
|
||||
}));
|
||||
|
||||
vi.mock("../useAgentModelPickerState", () => ({
|
||||
useAgentModelPickerState: ({
|
||||
onModelSelected,
|
||||
}: {
|
||||
onModelSelected?: (model: {
|
||||
id: string;
|
||||
name: string;
|
||||
displayName?: string;
|
||||
providerId?: string;
|
||||
}) => void;
|
||||
}) => ({
|
||||
selectedAgentId: "goose",
|
||||
pickerAgents: [{ id: "goose", label: "Goose" }],
|
||||
availableModels: [],
|
||||
modelsLoading: false,
|
||||
modelStatusMessage: null,
|
||||
handleProviderChange: vi.fn(),
|
||||
handleModelChange: (modelId: string) => {
|
||||
if (modelId === "claude-sonnet-4") {
|
||||
onModelSelected?.({
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
displayName: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
});
|
||||
}
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
import { useChatSessionController } from "../useChatSessionController";
|
||||
|
||||
describe("useChatSessionController", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockAcpPrepareSession.mockResolvedValue(undefined);
|
||||
mockAcpSetModel.mockResolvedValue(undefined);
|
||||
mockResolveSessionCwd.mockResolvedValue("/tmp/project");
|
||||
|
||||
useAgentStore.setState({
|
||||
personas: [],
|
||||
personasLoading: false,
|
||||
agents: [],
|
||||
agentsLoading: false,
|
||||
providers: [],
|
||||
providersLoading: false,
|
||||
selectedProvider: "openai",
|
||||
activeAgentId: null,
|
||||
isLoading: false,
|
||||
personaEditorOpen: false,
|
||||
editingPersona: null,
|
||||
});
|
||||
|
||||
useProjectStore.setState({
|
||||
projects: [],
|
||||
loading: false,
|
||||
activeProjectId: null,
|
||||
});
|
||||
|
||||
useChatStore.setState({
|
||||
messagesBySession: {},
|
||||
sessionStateById: {},
|
||||
draftsBySession: {},
|
||||
queuedMessageBySession: {},
|
||||
scrollTargetMessageBySession: {},
|
||||
activeSessionId: null,
|
||||
isConnected: true,
|
||||
});
|
||||
|
||||
useChatSessionStore.setState({
|
||||
sessions: [
|
||||
{
|
||||
id: "session-1",
|
||||
title: "Chat",
|
||||
providerId: "openai",
|
||||
modelId: "gpt-4o",
|
||||
modelName: "GPT-4o",
|
||||
createdAt: "2026-04-20T00:00:00.000Z",
|
||||
updatedAt: "2026-04-20T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
},
|
||||
],
|
||||
activeSessionId: null,
|
||||
isLoading: false,
|
||||
hasHydratedSessions: true,
|
||||
contextPanelOpenBySession: {},
|
||||
activeWorkspaceBySession: {},
|
||||
});
|
||||
});
|
||||
|
||||
it("prepares the selected model provider before setting a goose model", async () => {
|
||||
const { result } = renderHook(() =>
|
||||
useChatSessionController({ sessionId: "session-1" }),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleModelChange("claude-sonnet-4");
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockAcpPrepareSession).toHaveBeenCalledWith(
|
||||
"session-1",
|
||||
"anthropic",
|
||||
"/tmp/project",
|
||||
{ personaId: undefined },
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockAcpSetModel).toHaveBeenCalledWith(
|
||||
"session-1",
|
||||
"claude-sonnet-4",
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockAcpPrepareSession.mock.invocationCallOrder[0]).toBeLessThan(
|
||||
mockAcpSetModel.mock.invocationCallOrder[0],
|
||||
);
|
||||
expect(mockSetSelectedProvider).toHaveBeenCalledWith("anthropic");
|
||||
expect(
|
||||
useChatSessionStore.getState().getSession("session-1"),
|
||||
).toMatchObject({
|
||||
providerId: "anthropic",
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
});
|
||||
});
|
||||
});
|
||||
173
ui/goose2/src/features/chat/hooks/useAgentModelPickerState.ts
Normal file
173
ui/goose2/src/features/chat/hooks/useAgentModelPickerState.ts
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
import { useCallback, useMemo } from "react";
|
||||
import type { AcpProvider } from "@/shared/api/acp";
|
||||
import { useProviderInventory } from "@/features/providers/hooks/useProviderInventory";
|
||||
import {
|
||||
getCatalogEntry,
|
||||
resolveAgentProviderCatalogIdStrict,
|
||||
} from "@/features/providers/providerCatalog";
|
||||
import type { ModelOption } from "../types";
|
||||
|
||||
interface UseAgentModelPickerStateOptions {
|
||||
providers: AcpProvider[];
|
||||
selectedProvider?: string;
|
||||
onProviderSelected: (providerId: string) => void;
|
||||
onModelSelected?: (model: ModelOption) => void;
|
||||
}
|
||||
|
||||
const EMPTY_MODELS: ModelOption[] = [];
|
||||
|
||||
export function useAgentModelPickerState({
|
||||
providers,
|
||||
selectedProvider,
|
||||
onProviderSelected,
|
||||
onModelSelected,
|
||||
}: UseAgentModelPickerStateOptions) {
|
||||
const {
|
||||
entries: providerInventoryEntries,
|
||||
getEntry: getProviderInventoryEntry,
|
||||
configuredModelProviderEntries,
|
||||
getModelsForAgent,
|
||||
loading: providerInventoryLoading,
|
||||
} = useProviderInventory();
|
||||
|
||||
const selectedAgentId = selectedProvider
|
||||
? (resolveAgentProviderCatalogIdStrict(selectedProvider) ?? "goose")
|
||||
: "goose";
|
||||
const selectedProviderInventory = getProviderInventoryEntry(selectedAgentId);
|
||||
|
||||
const pickerAgents = useMemo(() => {
|
||||
const visible = new Map<string, { id: string; label: string }>();
|
||||
|
||||
visible.set("goose", {
|
||||
id: "goose",
|
||||
label: getCatalogEntry("goose")?.displayName ?? "Goose",
|
||||
});
|
||||
|
||||
for (const provider of providers) {
|
||||
const agentId = resolveAgentProviderCatalogIdStrict(provider.id);
|
||||
if (!agentId || agentId === "goose") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const inventoryEntry = providerInventoryEntries.get(agentId);
|
||||
if (!inventoryEntry?.configured && agentId !== selectedAgentId) {
|
||||
continue;
|
||||
}
|
||||
|
||||
visible.set(agentId, {
|
||||
id: agentId,
|
||||
label: getCatalogEntry(agentId)?.displayName ?? provider.label,
|
||||
});
|
||||
}
|
||||
|
||||
if (!visible.has(selectedAgentId)) {
|
||||
visible.set(selectedAgentId, {
|
||||
id: selectedAgentId,
|
||||
label: getCatalogEntry(selectedAgentId)?.displayName ?? selectedAgentId,
|
||||
});
|
||||
}
|
||||
|
||||
return [...visible.values()];
|
||||
}, [providerInventoryEntries, providers, selectedAgentId]);
|
||||
|
||||
const availableModels = useMemo(
|
||||
() => getModelsForAgent(selectedAgentId) ?? EMPTY_MODELS,
|
||||
[getModelsForAgent, selectedAgentId],
|
||||
);
|
||||
|
||||
const modelsLoading = useMemo(() => {
|
||||
// Show loading only when we have no models to display yet.
|
||||
// If cached models exist, show them immediately — a background refresh
|
||||
// will update the list when it completes.
|
||||
if (availableModels.length > 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (providerInventoryLoading) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (selectedAgentId === "goose") {
|
||||
return (
|
||||
configuredModelProviderEntries.length > 0 &&
|
||||
configuredModelProviderEntries.some((entry) => entry.refreshing)
|
||||
);
|
||||
}
|
||||
|
||||
return selectedProviderInventory?.refreshing === true;
|
||||
}, [
|
||||
availableModels.length,
|
||||
configuredModelProviderEntries,
|
||||
providerInventoryLoading,
|
||||
selectedAgentId,
|
||||
selectedProviderInventory?.refreshing,
|
||||
]);
|
||||
|
||||
const modelStatusMessage = useMemo(() => {
|
||||
if (availableModels.length > 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (selectedAgentId === "goose") {
|
||||
const entryWithHint = configuredModelProviderEntries.find(
|
||||
(entry) => entry.modelSelectionHint || entry.lastRefreshError,
|
||||
);
|
||||
return (
|
||||
entryWithHint?.modelSelectionHint ??
|
||||
entryWithHint?.lastRefreshError ??
|
||||
null
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
selectedProviderInventory?.modelSelectionHint ??
|
||||
selectedProviderInventory?.lastRefreshError ??
|
||||
null
|
||||
);
|
||||
}, [
|
||||
availableModels.length,
|
||||
configuredModelProviderEntries,
|
||||
selectedAgentId,
|
||||
selectedProviderInventory?.modelSelectionHint,
|
||||
selectedProviderInventory?.lastRefreshError,
|
||||
]);
|
||||
|
||||
const handleProviderChange = useCallback(
|
||||
(providerId: string) => {
|
||||
if (providerId === (selectedProvider ?? "goose")) {
|
||||
return;
|
||||
}
|
||||
|
||||
onProviderSelected(providerId);
|
||||
},
|
||||
[onProviderSelected, selectedProvider],
|
||||
);
|
||||
|
||||
const handleModelChange = useCallback(
|
||||
(modelId: string) => {
|
||||
const selectedModel = availableModels.find(
|
||||
(model) => model.id === modelId,
|
||||
);
|
||||
onModelSelected?.({
|
||||
id: modelId,
|
||||
name: selectedModel?.name ?? modelId,
|
||||
displayName: selectedModel?.displayName ?? modelId,
|
||||
provider: selectedModel?.provider,
|
||||
providerId: selectedModel?.providerId,
|
||||
providerName: selectedModel?.providerName,
|
||||
recommended: selectedModel?.recommended,
|
||||
});
|
||||
},
|
||||
[availableModels, onModelSelected],
|
||||
);
|
||||
|
||||
return {
|
||||
selectedAgentId,
|
||||
pickerAgents,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
handleProviderChange,
|
||||
handleModelChange,
|
||||
};
|
||||
}
|
||||
|
|
@ -14,8 +14,6 @@ import {
|
|||
acpSendMessage,
|
||||
acpCancelSession,
|
||||
acpLoadSession,
|
||||
acpPrepareSession,
|
||||
acpSetModel,
|
||||
} from "@/shared/api/acp";
|
||||
import { getGooseSessionId } from "@/shared/api/acpSessionTracker";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
|
|
@ -109,7 +107,10 @@ export function useChat(
|
|||
providerOverride?: string,
|
||||
systemPromptOverride?: string,
|
||||
personaInfo?: { id: string; name: string },
|
||||
getWorkingDir?: () => Promise<string | undefined>,
|
||||
options?: {
|
||||
onMessageAccepted?: (sessionId: string) => void;
|
||||
ensurePrepared?: () => Promise<void>;
|
||||
},
|
||||
) {
|
||||
const store = useChatStore();
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
|
|
@ -215,15 +216,8 @@ export function useChat(
|
|||
store.setChatState(sessionId, "thinking");
|
||||
store.setError(sessionId, null);
|
||||
|
||||
// Promote draft to real backend session before first send
|
||||
const sessionStore = useChatSessionStore.getState();
|
||||
const session = sessionStore.getSession(sessionId);
|
||||
const wasDraft = !!session?.draft;
|
||||
const selectedModelId = session?.modelId;
|
||||
|
||||
if (wasDraft) {
|
||||
sessionStore.promoteDraft(sessionId);
|
||||
}
|
||||
|
||||
// Immediately set the session/sidebar title from the user's message when
|
||||
// the session still has the default placeholder. This gives instant
|
||||
|
|
@ -231,20 +225,18 @@ export function useChat(
|
|||
// A better backend-generated title will overwrite this if it arrives
|
||||
// via the acp:session_info event.
|
||||
if (session && isDefaultChatTitle(session.title)) {
|
||||
sessionStore.updateSession(
|
||||
sessionId,
|
||||
{
|
||||
title: getSessionTitleFromDraft(text, attachments),
|
||||
updatedAt: new Date().toISOString(),
|
||||
},
|
||||
{ localOnly: wasDraft },
|
||||
);
|
||||
sessionStore.updateSession(sessionId, {
|
||||
title: getSessionTitleFromDraft(text, attachments),
|
||||
updatedAt: new Date().toISOString(),
|
||||
});
|
||||
} else {
|
||||
sessionStore.updateSession(sessionId, {
|
||||
updatedAt: new Date().toISOString(),
|
||||
});
|
||||
}
|
||||
|
||||
options?.onMessageAccepted?.(sessionId);
|
||||
|
||||
store.clearDraft(sessionId);
|
||||
|
||||
const abort = new AbortController();
|
||||
|
|
@ -252,26 +244,7 @@ export function useChat(
|
|||
streamingPersonaIdRef.current = effectivePersonaInfo?.id ?? null;
|
||||
|
||||
try {
|
||||
if (wasDraft || selectedModelId) {
|
||||
const workingDir = await getWorkingDir?.();
|
||||
if (!workingDir) {
|
||||
throw new Error("Missing session working directory");
|
||||
}
|
||||
const tPrep = performance.now();
|
||||
await acpPrepareSession(sessionId, providerId, workingDir, {
|
||||
personaId: effectivePersonaInfo?.id,
|
||||
});
|
||||
perfLog(
|
||||
`[perf:send] ${sid} acpPrepareSession in ${(performance.now() - tPrep).toFixed(1)}ms (wasDraft=${wasDraft})`,
|
||||
);
|
||||
if (selectedModelId) {
|
||||
const tModel = performance.now();
|
||||
await acpSetModel(sessionId, selectedModelId);
|
||||
perfLog(
|
||||
`[perf:send] ${sid} acpSetModel(${selectedModelId}) in ${(performance.now() - tModel).toFixed(1)}ms`,
|
||||
);
|
||||
}
|
||||
}
|
||||
await options?.ensurePrepared?.();
|
||||
|
||||
store.setChatState(sessionId, "streaming");
|
||||
// When images are present with no text, pass a single space so the ACP
|
||||
|
|
@ -298,18 +271,6 @@ export function useChat(
|
|||
|
||||
store.setChatState(sessionId, "idle");
|
||||
store.setStreamingMessageId(sessionId, null);
|
||||
|
||||
if (wasDraft) {
|
||||
const promoted = sessionStore.getSession(sessionId);
|
||||
if (promoted) {
|
||||
sessionStore.updateSession(sessionId, {
|
||||
title: promoted.title,
|
||||
providerId: promoted.providerId,
|
||||
personaId: promoted.personaId,
|
||||
projectId: promoted.projectId,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof DOMException && err.name === "AbortError") {
|
||||
store.setChatState(sessionId, "idle");
|
||||
|
|
@ -351,7 +312,7 @@ export function useChat(
|
|||
providerOverride,
|
||||
systemPromptOverride,
|
||||
resolvePersonaInfo,
|
||||
getWorkingDir,
|
||||
options,
|
||||
],
|
||||
);
|
||||
|
||||
|
|
@ -415,6 +376,12 @@ export function useChat(
|
|||
store.setPendingAssistantProvider(sessionId, null);
|
||||
}, [sessionId, store]);
|
||||
|
||||
const getWorkingDir = useCallback(
|
||||
() =>
|
||||
useChatSessionStore.getState().activeWorkspaceBySession[sessionId]?.path,
|
||||
[sessionId],
|
||||
);
|
||||
|
||||
const compactConversation = useCallback(async () => {
|
||||
const currentChatState = useChatStore
|
||||
.getState()
|
||||
|
|
@ -457,7 +424,7 @@ export function useChat(
|
|||
// layer does not currently forward history replacement events. Drop those
|
||||
// transient chunks and refresh the session from replay instead.
|
||||
clearReplayBuffer(sessionId);
|
||||
const workingDir = await getWorkingDir?.();
|
||||
const workingDir = getWorkingDir();
|
||||
await acpLoadSession(sessionId, gooseSessionId, workingDir);
|
||||
|
||||
store.setSessionLoading(sessionId, false);
|
||||
|
|
|
|||
682
ui/goose2/src/features/chat/hooks/useChatSessionController.ts
Normal file
682
ui/goose2/src/features/chat/hooks/useChatSessionController.ts
Normal file
|
|
@ -0,0 +1,682 @@
|
|||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import type { ChatAttachmentDraft } from "@/shared/types/messages";
|
||||
import { useChat } from "./useChat";
|
||||
import { useMessageQueue } from "./useMessageQueue";
|
||||
import { useChatStore } from "../stores/chatStore";
|
||||
import { useChatSessionStore } from "../stores/chatSessionStore";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
import { useAgentModelPickerState } from "./useAgentModelPickerState";
|
||||
import {
|
||||
buildProjectSystemPrompt,
|
||||
composeSystemPrompt,
|
||||
getProjectArtifactRoots,
|
||||
resolveProjectDefaultArtifactRoot,
|
||||
} from "@/features/projects/lib/chatProjectContext";
|
||||
import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection";
|
||||
import { acpPrepareSession, acpSetModel } from "@/shared/api/acp";
|
||||
|
||||
interface UseChatSessionControllerOptions {
|
||||
sessionId: string | null;
|
||||
onMessageAccepted?: (sessionId: string) => void;
|
||||
}
|
||||
|
||||
const PENDING_HOME_SESSION_ID = "__home_pending__";
|
||||
|
||||
export function useChatSessionController({
|
||||
sessionId,
|
||||
onMessageAccepted,
|
||||
}: UseChatSessionControllerOptions) {
|
||||
const stateSessionId = sessionId ?? PENDING_HOME_SESSION_ID;
|
||||
const {
|
||||
providers,
|
||||
providersLoading,
|
||||
selectedProvider: globalSelectedProvider,
|
||||
setSelectedProvider: setGlobalSelectedProvider,
|
||||
} = useProviderSelection();
|
||||
const personas = useAgentStore((s) => s.personas);
|
||||
const session = useChatSessionStore((s) =>
|
||||
sessionId
|
||||
? s.sessions.find((candidate) => candidate.id === sessionId)
|
||||
: undefined,
|
||||
);
|
||||
const activeWorkspace = useChatSessionStore((s) =>
|
||||
sessionId ? s.activeWorkspaceBySession[sessionId] : undefined,
|
||||
);
|
||||
const clearActiveWorkspace = useChatSessionStore(
|
||||
(s) => s.clearActiveWorkspace,
|
||||
);
|
||||
const projects = useProjectStore((s) => s.projects);
|
||||
const projectsLoading = useProjectStore((s) => s.loading);
|
||||
const [pendingPersonaId, setPendingPersonaId] = useState<string | null>();
|
||||
const [pendingProjectId, setPendingProjectId] = useState<string | null>();
|
||||
const [pendingProviderId, setPendingProviderId] = useState<string>();
|
||||
const [pendingModelSelection, setPendingModelSelection] = useState<{
|
||||
id: string;
|
||||
name: string;
|
||||
providerId?: string;
|
||||
} | null>();
|
||||
const pendingDraftValue = useChatStore(
|
||||
(s) => s.draftsBySession[PENDING_HOME_SESSION_ID] ?? "",
|
||||
);
|
||||
const pendingQueuedMessage = useChatStore(
|
||||
(s) => s.queuedMessageBySession[PENDING_HOME_SESSION_ID] ?? null,
|
||||
);
|
||||
const effectiveProjectId =
|
||||
pendingProjectId !== undefined
|
||||
? pendingProjectId
|
||||
: (session?.projectId ?? null);
|
||||
const storedProject = useProjectStore((s) =>
|
||||
effectiveProjectId
|
||||
? s.projects.find((candidate) => candidate.id === effectiveProjectId)
|
||||
: undefined,
|
||||
);
|
||||
const project = storedProject ?? null;
|
||||
const selectedProvider =
|
||||
pendingProviderId ??
|
||||
session?.providerId ??
|
||||
project?.preferredProvider ??
|
||||
globalSelectedProvider;
|
||||
const selectedPersonaId =
|
||||
pendingPersonaId !== undefined
|
||||
? pendingPersonaId
|
||||
: (session?.personaId ?? null);
|
||||
const selectedPersona = personas.find(
|
||||
(persona) => persona.id === selectedPersonaId,
|
||||
);
|
||||
const projectArtifactRoots = useMemo(
|
||||
() => getProjectArtifactRoots(project),
|
||||
[project],
|
||||
);
|
||||
const projectDefaultArtifactRoot = useMemo(
|
||||
() => resolveProjectDefaultArtifactRoot(project),
|
||||
[project],
|
||||
);
|
||||
const projectMetadataPending = Boolean(
|
||||
effectiveProjectId && !projectDefaultArtifactRoot && projectsLoading,
|
||||
);
|
||||
const allowedArtifactRoots = useMemo(
|
||||
() => [
|
||||
...new Set(
|
||||
projectArtifactRoots.map((path) => path.trim()).filter(Boolean),
|
||||
),
|
||||
],
|
||||
[projectArtifactRoots],
|
||||
);
|
||||
const availableProjects = useMemo(
|
||||
() =>
|
||||
[...projects]
|
||||
.sort((a, b) => a.order - b.order || a.name.localeCompare(b.name))
|
||||
.map((projectInfo) => ({
|
||||
id: projectInfo.id,
|
||||
name: projectInfo.name,
|
||||
workingDirs: projectInfo.workingDirs,
|
||||
color: projectInfo.color,
|
||||
})),
|
||||
[projects],
|
||||
);
|
||||
const projectSystemPrompt = useMemo(
|
||||
() => buildProjectSystemPrompt(project),
|
||||
[project],
|
||||
);
|
||||
const workingContextPrompt = useMemo(() => {
|
||||
if (!activeWorkspace?.branch) return undefined;
|
||||
return `<active-working-context>\nActive branch: ${activeWorkspace.branch}\nWorking directory: ${activeWorkspace.path}\n</active-working-context>`;
|
||||
}, [activeWorkspace?.branch, activeWorkspace?.path]);
|
||||
const effectiveSystemPrompt = useMemo(
|
||||
() =>
|
||||
composeSystemPrompt(
|
||||
selectedPersona?.systemPrompt,
|
||||
projectSystemPrompt,
|
||||
workingContextPrompt,
|
||||
),
|
||||
[projectSystemPrompt, selectedPersona?.systemPrompt, workingContextPrompt],
|
||||
);
|
||||
|
||||
const prepareCurrentSession = useCallback(
|
||||
async (
|
||||
providerId: string,
|
||||
nextProject = project,
|
||||
nextWorkspacePath = activeWorkspace?.path,
|
||||
personaId = selectedPersonaId ?? undefined,
|
||||
) => {
|
||||
if (!sessionId) {
|
||||
return;
|
||||
}
|
||||
const workingDir = await resolveSessionCwd(
|
||||
nextProject,
|
||||
nextWorkspacePath,
|
||||
);
|
||||
await acpPrepareSession(sessionId, providerId, workingDir, { personaId });
|
||||
},
|
||||
[activeWorkspace?.path, project, selectedPersonaId, sessionId],
|
||||
);
|
||||
|
||||
const prevProjectIdRef = useRef(session?.projectId);
|
||||
useEffect(() => {
|
||||
if (!sessionId) {
|
||||
return;
|
||||
}
|
||||
const previousProjectId = prevProjectIdRef.current;
|
||||
prevProjectIdRef.current = session?.projectId;
|
||||
if (
|
||||
previousProjectId !== undefined &&
|
||||
previousProjectId !== session?.projectId
|
||||
) {
|
||||
clearActiveWorkspace(sessionId);
|
||||
}
|
||||
}, [clearActiveWorkspace, session?.projectId, sessionId]);
|
||||
|
||||
const prevWorkspaceRef = useRef(activeWorkspace);
|
||||
useEffect(() => {
|
||||
const previousWorkspace = prevWorkspaceRef.current;
|
||||
if (
|
||||
!sessionId ||
|
||||
!activeWorkspace ||
|
||||
!selectedProvider ||
|
||||
activeWorkspace === previousWorkspace
|
||||
) {
|
||||
return;
|
||||
}
|
||||
prevWorkspaceRef.current = activeWorkspace;
|
||||
if (previousWorkspace?.path === activeWorkspace.path) {
|
||||
return;
|
||||
}
|
||||
void prepareCurrentSession(selectedProvider).catch((error) => {
|
||||
console.error("Failed to prepare ACP session:", error);
|
||||
});
|
||||
}, [activeWorkspace, prepareCurrentSession, selectedProvider, sessionId]);
|
||||
|
||||
const {
|
||||
selectedAgentId,
|
||||
pickerAgents,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
handleProviderChange,
|
||||
handleModelChange,
|
||||
} = useAgentModelPickerState({
|
||||
providers,
|
||||
selectedProvider,
|
||||
onProviderSelected: (providerId) => {
|
||||
if (!sessionId) {
|
||||
setGlobalSelectedProvider(providerId);
|
||||
setPendingProviderId(providerId);
|
||||
setPendingModelSelection(null);
|
||||
return;
|
||||
}
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.switchSessionProvider(sessionId, providerId);
|
||||
setGlobalSelectedProvider(providerId);
|
||||
void prepareCurrentSession(providerId).catch((error) => {
|
||||
console.error("Failed to update ACP session provider:", error);
|
||||
});
|
||||
},
|
||||
onModelSelected: (model) => {
|
||||
const modelId = model.id;
|
||||
const modelName = model.displayName ?? model.name ?? model.id;
|
||||
const nextProviderId = model.providerId ?? selectedProvider;
|
||||
|
||||
if (!sessionId) {
|
||||
if (nextProviderId && nextProviderId !== selectedProvider) {
|
||||
setPendingProviderId(nextProviderId);
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
}
|
||||
setPendingModelSelection({
|
||||
id: modelId,
|
||||
name: modelName,
|
||||
providerId: nextProviderId,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (
|
||||
!session ||
|
||||
(modelId === session.modelId &&
|
||||
(!nextProviderId || nextProviderId === session.providerId))
|
||||
) {
|
||||
return;
|
||||
}
|
||||
const previousProviderId = session.providerId;
|
||||
const previousModelId = session.modelId;
|
||||
const previousModelName = session.modelName;
|
||||
const providerChanged =
|
||||
Boolean(nextProviderId) && nextProviderId !== session.providerId;
|
||||
|
||||
if (providerChanged && nextProviderId) {
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.switchSessionProvider(sessionId, nextProviderId);
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
}
|
||||
|
||||
useChatSessionStore.getState().updateSession(sessionId, {
|
||||
modelId,
|
||||
modelName,
|
||||
});
|
||||
|
||||
void (async () => {
|
||||
try {
|
||||
if (providerChanged && nextProviderId) {
|
||||
await prepareCurrentSession(nextProviderId);
|
||||
}
|
||||
await acpSetModel(sessionId, modelId);
|
||||
} catch (error) {
|
||||
console.error("Failed to set model:", error);
|
||||
if (providerChanged && previousProviderId) {
|
||||
setGlobalSelectedProvider(previousProviderId);
|
||||
}
|
||||
useChatSessionStore.getState().updateSession(sessionId, {
|
||||
providerId: previousProviderId,
|
||||
modelId: previousModelId,
|
||||
modelName: previousModelName,
|
||||
});
|
||||
void (async () => {
|
||||
try {
|
||||
if (providerChanged && previousProviderId) {
|
||||
await prepareCurrentSession(previousProviderId);
|
||||
}
|
||||
if (previousModelId) {
|
||||
await acpSetModel(sessionId, previousModelId);
|
||||
}
|
||||
} catch (rollbackError) {
|
||||
console.error(
|
||||
"Failed to restore previous provider/model after setModel failure:",
|
||||
rollbackError,
|
||||
);
|
||||
}
|
||||
})();
|
||||
}
|
||||
})();
|
||||
},
|
||||
});
|
||||
|
||||
const handleProjectChange = useCallback(
|
||||
(projectId: string | null) => {
|
||||
if (!sessionId) {
|
||||
setPendingProjectId(projectId);
|
||||
return;
|
||||
}
|
||||
const nextProject =
|
||||
projectId == null
|
||||
? null
|
||||
: (useProjectStore
|
||||
.getState()
|
||||
.projects.find((candidate) => candidate.id === projectId) ??
|
||||
null);
|
||||
|
||||
useChatSessionStore.getState().updateSession(sessionId, { projectId });
|
||||
if (!selectedProvider) {
|
||||
return;
|
||||
}
|
||||
void prepareCurrentSession(selectedProvider, nextProject).catch(
|
||||
(error) => {
|
||||
console.error(
|
||||
"Failed to update ACP session working directory:",
|
||||
error,
|
||||
);
|
||||
},
|
||||
);
|
||||
},
|
||||
[prepareCurrentSession, selectedProvider, sessionId],
|
||||
);
|
||||
|
||||
const handlePersonaChange = useCallback(
|
||||
(personaId: string | null) => {
|
||||
const persona = personas.find((candidate) => candidate.id === personaId);
|
||||
if (persona?.provider) {
|
||||
const matchingProvider = providers.find(
|
||||
(provider) =>
|
||||
provider.id === persona.provider ||
|
||||
provider.label.toLowerCase().includes(persona.provider ?? ""),
|
||||
);
|
||||
if (matchingProvider) {
|
||||
if (!sessionId) {
|
||||
setPendingProviderId(matchingProvider.id);
|
||||
setPendingModelSelection(null);
|
||||
setGlobalSelectedProvider(matchingProvider.id);
|
||||
} else {
|
||||
handleProviderChange(matchingProvider.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
const agentStore = useAgentStore.getState();
|
||||
const matchingAgent = agentStore.agents.find(
|
||||
(agent) => agent.personaId === personaId,
|
||||
);
|
||||
if (matchingAgent) {
|
||||
agentStore.setActiveAgent(matchingAgent.id);
|
||||
}
|
||||
if (!sessionId) {
|
||||
setPendingPersonaId(personaId);
|
||||
return;
|
||||
}
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.updateSession(sessionId, { personaId: personaId ?? undefined });
|
||||
},
|
||||
[
|
||||
handleProviderChange,
|
||||
personas,
|
||||
providers,
|
||||
sessionId,
|
||||
setGlobalSelectedProvider,
|
||||
],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
selectedPersonaId !== null &&
|
||||
personas.length > 0 &&
|
||||
!personas.find((persona) => persona.id === selectedPersonaId)
|
||||
) {
|
||||
if (sessionId) {
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.updateSession(sessionId, { personaId: undefined });
|
||||
} else {
|
||||
setPendingPersonaId(undefined);
|
||||
}
|
||||
}
|
||||
}, [personas, selectedPersonaId, sessionId]);
|
||||
|
||||
const personaInfo = selectedPersona
|
||||
? { id: selectedPersona.id, name: selectedPersona.displayName }
|
||||
: undefined;
|
||||
const {
|
||||
messages,
|
||||
chatState,
|
||||
tokenState,
|
||||
sendMessage,
|
||||
stopStreaming,
|
||||
streamingMessageId,
|
||||
} = useChat(
|
||||
stateSessionId,
|
||||
selectedProvider,
|
||||
effectiveSystemPrompt,
|
||||
personaInfo,
|
||||
{
|
||||
onMessageAccepted: sessionId ? onMessageAccepted : undefined,
|
||||
ensurePrepared: selectedProvider
|
||||
? () => prepareCurrentSession(selectedProvider)
|
||||
: undefined,
|
||||
},
|
||||
);
|
||||
const isLoadingHistory = useChatStore((s) =>
|
||||
sessionId
|
||||
? s.loadingSessionIds.has(sessionId) &&
|
||||
(s.messagesBySession[sessionId]?.length ?? 0) === 0
|
||||
: false,
|
||||
);
|
||||
const deferredSend = useRef<{
|
||||
text: string;
|
||||
attachments?: ChatAttachmentDraft[];
|
||||
} | null>(null);
|
||||
const queue = useMessageQueue(
|
||||
stateSessionId,
|
||||
sessionId ? chatState : "thinking",
|
||||
sendMessage,
|
||||
);
|
||||
const chatStore = useChatStore();
|
||||
|
||||
const handleSend = useCallback(
|
||||
(text: string, personaId?: string, attachments?: ChatAttachmentDraft[]) => {
|
||||
if (!sessionId) {
|
||||
if (!queue.queuedMessage) {
|
||||
queue.enqueue(text, personaId, attachments);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (personaId && personaId !== selectedPersonaId) {
|
||||
const nextPersona = personas.find(
|
||||
(persona) => persona.id === personaId,
|
||||
);
|
||||
if (nextPersona) {
|
||||
chatStore.addMessage(sessionId, {
|
||||
id: crypto.randomUUID(),
|
||||
role: "system",
|
||||
created: Date.now(),
|
||||
content: [
|
||||
{
|
||||
type: "systemNotification",
|
||||
notificationType: "info",
|
||||
text: `Switched to ${nextPersona.displayName}`,
|
||||
},
|
||||
],
|
||||
metadata: { userVisible: true, agentVisible: false },
|
||||
});
|
||||
}
|
||||
handlePersonaChange(personaId);
|
||||
deferredSend.current = { text, attachments };
|
||||
return;
|
||||
}
|
||||
|
||||
if (chatState !== "idle" && !queue.queuedMessage) {
|
||||
queue.enqueue(text, personaId, attachments);
|
||||
return;
|
||||
}
|
||||
|
||||
sendMessage(text, undefined, attachments);
|
||||
},
|
||||
[
|
||||
chatState,
|
||||
chatStore,
|
||||
handlePersonaChange,
|
||||
personas,
|
||||
queue,
|
||||
sessionId,
|
||||
selectedPersonaId,
|
||||
sendMessage,
|
||||
],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (deferredSend.current && selectedPersona) {
|
||||
const { text, attachments } = deferredSend.current;
|
||||
deferredSend.current = null;
|
||||
sendMessage(text, undefined, attachments);
|
||||
}
|
||||
}, [selectedPersona, sendMessage]);
|
||||
|
||||
const handleCreatePersona = useCallback(() => {
|
||||
useAgentStore.getState().openPersonaEditor();
|
||||
}, []);
|
||||
|
||||
const sessionDraftValue = useChatStore((s) =>
|
||||
sessionId ? (s.draftsBySession[sessionId] ?? "") : "",
|
||||
);
|
||||
const draftValue = sessionId ? sessionDraftValue : pendingDraftValue;
|
||||
const handleDraftChange = useCallback(
|
||||
(text: string) => {
|
||||
useChatStore.getState().setDraft(stateSessionId, text);
|
||||
},
|
||||
[stateSessionId],
|
||||
);
|
||||
const scrollTarget = useChatStore((s) =>
|
||||
sessionId ? (s.scrollTargetMessageBySession[sessionId] ?? null) : null,
|
||||
);
|
||||
const handleScrollTargetHandled = useCallback(() => {
|
||||
if (!sessionId) {
|
||||
return;
|
||||
}
|
||||
useChatStore.getState().clearScrollTargetMessage(sessionId);
|
||||
}, [sessionId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!sessionId) {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
void pendingDraftValue;
|
||||
void pendingQueuedMessage;
|
||||
|
||||
const syncPendingHomeState = async () => {
|
||||
const chatState = useChatStore.getState();
|
||||
const pendingDraft =
|
||||
chatState.draftsBySession[PENDING_HOME_SESSION_ID] ?? "";
|
||||
|
||||
if (pendingDraft && !chatState.draftsBySession[sessionId]) {
|
||||
chatState.setDraft(sessionId, pendingDraft);
|
||||
}
|
||||
|
||||
const hasPendingProvider = pendingProviderId !== undefined;
|
||||
const hasPendingPersona = pendingPersonaId !== undefined;
|
||||
const hasPendingProject = pendingProjectId !== undefined;
|
||||
const hasPendingModel = pendingModelSelection !== undefined;
|
||||
|
||||
if (
|
||||
hasPendingProvider ||
|
||||
hasPendingPersona ||
|
||||
hasPendingProject ||
|
||||
hasPendingModel
|
||||
) {
|
||||
const nextProviderId = pendingProviderId ?? selectedProvider;
|
||||
const nextPersonaId =
|
||||
pendingPersonaId !== undefined
|
||||
? (pendingPersonaId ?? undefined)
|
||||
: session?.personaId;
|
||||
const nextProjectId =
|
||||
pendingProjectId !== undefined
|
||||
? pendingProjectId
|
||||
: session?.projectId;
|
||||
const nextProject =
|
||||
nextProjectId == null
|
||||
? null
|
||||
: (useProjectStore
|
||||
.getState()
|
||||
.projects.find((candidate) => candidate.id === nextProjectId) ??
|
||||
null);
|
||||
|
||||
const patch: {
|
||||
providerId?: string;
|
||||
personaId?: string | undefined;
|
||||
projectId?: string | null;
|
||||
modelId?: string | undefined;
|
||||
modelName?: string | undefined;
|
||||
} = {};
|
||||
|
||||
if (hasPendingProvider) {
|
||||
patch.providerId = nextProviderId;
|
||||
patch.modelId = undefined;
|
||||
patch.modelName = undefined;
|
||||
}
|
||||
if (hasPendingPersona) {
|
||||
patch.personaId = nextPersonaId;
|
||||
}
|
||||
if (hasPendingProject) {
|
||||
patch.projectId = nextProjectId ?? null;
|
||||
}
|
||||
if (hasPendingModel) {
|
||||
patch.modelId = pendingModelSelection?.id;
|
||||
patch.modelName = pendingModelSelection?.name;
|
||||
}
|
||||
|
||||
useChatSessionStore.getState().updateSession(sessionId, patch);
|
||||
|
||||
try {
|
||||
await prepareCurrentSession(
|
||||
nextProviderId,
|
||||
nextProject,
|
||||
activeWorkspace?.path,
|
||||
nextPersonaId,
|
||||
);
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
if (pendingModelSelection?.id) {
|
||||
await acpSetModel(sessionId, pendingModelSelection.id);
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to sync pending Home state:", error);
|
||||
return;
|
||||
}
|
||||
|
||||
setPendingProviderId(undefined);
|
||||
setPendingPersonaId(undefined);
|
||||
setPendingProjectId(undefined);
|
||||
setPendingModelSelection(undefined);
|
||||
}
|
||||
|
||||
const latestChatState = useChatStore.getState();
|
||||
const latestPendingQueue =
|
||||
latestChatState.queuedMessageBySession[PENDING_HOME_SESSION_ID] ?? null;
|
||||
if (
|
||||
latestPendingQueue &&
|
||||
!latestChatState.queuedMessageBySession[sessionId]
|
||||
) {
|
||||
latestChatState.enqueueMessage(sessionId, latestPendingQueue);
|
||||
}
|
||||
|
||||
useChatStore.getState().clearDraft(PENDING_HOME_SESSION_ID);
|
||||
useChatStore.getState().dismissQueuedMessage(PENDING_HOME_SESSION_ID);
|
||||
useChatStore.getState().cleanupSession(PENDING_HOME_SESSION_ID);
|
||||
};
|
||||
|
||||
void syncPendingHomeState();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [
|
||||
activeWorkspace?.path,
|
||||
pendingDraftValue,
|
||||
pendingModelSelection,
|
||||
pendingPersonaId,
|
||||
pendingProjectId,
|
||||
pendingProviderId,
|
||||
pendingQueuedMessage,
|
||||
prepareCurrentSession,
|
||||
selectedProvider,
|
||||
session?.personaId,
|
||||
session?.projectId,
|
||||
sessionId,
|
||||
]);
|
||||
|
||||
return {
|
||||
session,
|
||||
project,
|
||||
allowedArtifactRoots,
|
||||
messages,
|
||||
chatState,
|
||||
tokenState,
|
||||
stopStreaming,
|
||||
streamingMessageId,
|
||||
isLoadingHistory,
|
||||
queue,
|
||||
handleSend,
|
||||
draftValue,
|
||||
handleDraftChange,
|
||||
scrollTarget,
|
||||
handleScrollTargetHandled,
|
||||
projectMetadataPending,
|
||||
personas,
|
||||
selectedPersonaId,
|
||||
handlePersonaChange,
|
||||
handleCreatePersona,
|
||||
pickerAgents,
|
||||
providersLoading,
|
||||
selectedProvider: selectedAgentId,
|
||||
handleProviderChange,
|
||||
currentModelId:
|
||||
pendingModelSelection !== undefined
|
||||
? (pendingModelSelection?.id ?? null)
|
||||
: (session?.modelId ?? null),
|
||||
currentModelName:
|
||||
pendingModelSelection !== undefined
|
||||
? (pendingModelSelection?.name ?? null)
|
||||
: session?.modelName,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
handleModelChange,
|
||||
selectedProjectId: effectiveProjectId,
|
||||
availableProjects,
|
||||
handleProjectChange,
|
||||
};
|
||||
}
|
||||
|
|
@ -2,170 +2,78 @@ import { describe, expect, it } from "vitest";
|
|||
import { findExistingDraft } from "./newChat";
|
||||
import type { ChatSession } from "../stores/chatSessionStore";
|
||||
|
||||
function makeDraft(overrides: Partial<ChatSession> = {}): ChatSession {
|
||||
function makeSession(
|
||||
id: string,
|
||||
overrides: Partial<ChatSession> = {},
|
||||
): ChatSession {
|
||||
return {
|
||||
id: "session-1",
|
||||
id,
|
||||
title: "New Chat",
|
||||
createdAt: "2026-03-31T10:00:00.000Z",
|
||||
updatedAt: "2026-03-31T10:00:00.000Z",
|
||||
createdAt: "2026-04-01T00:00:00.000Z",
|
||||
updatedAt: "2026-04-01T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
draft: true,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("findExistingDraft", () => {
|
||||
it("reuses the active empty draft session", () => {
|
||||
const activeDraft = makeDraft({ id: "active-draft" });
|
||||
|
||||
const result = findExistingDraft({
|
||||
sessions: [activeDraft],
|
||||
activeSessionId: activeDraft.id,
|
||||
draftsBySession: {},
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat" },
|
||||
it("reuses a matching project draft with content", () => {
|
||||
const draft = makeSession("alpha-draft", {
|
||||
projectId: "alpha",
|
||||
providerId: "goose",
|
||||
});
|
||||
|
||||
expect(result?.id).toBe(activeDraft.id);
|
||||
expect(
|
||||
findExistingDraft({
|
||||
sessions: [draft],
|
||||
activeSessionId: null,
|
||||
draftsBySession: { "alpha-draft": "alpha draft" },
|
||||
messagesBySession: {},
|
||||
request: {
|
||||
title: "New Chat",
|
||||
projectId: "alpha",
|
||||
},
|
||||
}),
|
||||
).toEqual(draft);
|
||||
});
|
||||
|
||||
it("does not reuse non-draft sessions", () => {
|
||||
const realSession = makeDraft({ id: "real", draft: undefined });
|
||||
|
||||
const result = findExistingDraft({
|
||||
sessions: [realSession],
|
||||
activeSessionId: realSession.id,
|
||||
draftsBySession: {},
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat" },
|
||||
it("does not reuse a draft from a different project", () => {
|
||||
const draft = makeSession("alpha-draft", {
|
||||
projectId: "alpha",
|
||||
providerId: "goose",
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
expect(
|
||||
findExistingDraft({
|
||||
sessions: [draft],
|
||||
activeSessionId: null,
|
||||
draftsBySession: { "alpha-draft": "alpha draft" },
|
||||
messagesBySession: {},
|
||||
request: {
|
||||
title: "New Chat",
|
||||
projectId: "beta",
|
||||
},
|
||||
}),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
it("does not reuse drafts that already have messages", () => {
|
||||
const result = findExistingDraft({
|
||||
sessions: [makeDraft({ id: "used-draft", messageCount: 1 })],
|
||||
activeSessionId: "used-draft",
|
||||
draftsBySession: {},
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat" },
|
||||
it("does not reuse an abandoned empty draft", () => {
|
||||
const draft = makeSession("alpha-draft", {
|
||||
projectId: "alpha",
|
||||
providerId: "goose",
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it("does not reuse drafts with local in-memory messages", () => {
|
||||
const result = findExistingDraft({
|
||||
sessions: [makeDraft({ id: "streaming-draft" })],
|
||||
activeSessionId: "streaming-draft",
|
||||
draftsBySession: {},
|
||||
messagesBySession: {
|
||||
"streaming-draft": [
|
||||
{
|
||||
id: "msg-1",
|
||||
role: "user",
|
||||
created: Date.now(),
|
||||
content: [{ type: "text", text: "hello" }],
|
||||
},
|
||||
],
|
||||
},
|
||||
request: { title: "New Chat" },
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it("only reuses drafts for the same chat context", () => {
|
||||
const projectDraft = makeDraft({
|
||||
id: "project-draft",
|
||||
projectId: "project-1",
|
||||
});
|
||||
|
||||
const result = findExistingDraft({
|
||||
sessions: [projectDraft],
|
||||
activeSessionId: projectDraft.id,
|
||||
draftsBySession: {},
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat", projectId: "project-2" },
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it("does not reuse drafts when creating a titled chat", () => {
|
||||
const result = findExistingDraft({
|
||||
sessions: [makeDraft()],
|
||||
activeSessionId: "session-1",
|
||||
draftsBySession: {},
|
||||
messagesBySession: {},
|
||||
request: { title: "What day is it?" },
|
||||
});
|
||||
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it("finds a non-active draft with content", () => {
|
||||
const draftWithContent = makeDraft({ id: "background-draft" });
|
||||
|
||||
const result = findExistingDraft({
|
||||
sessions: [draftWithContent],
|
||||
activeSessionId: null,
|
||||
draftsBySession: { "background-draft": "some typed text" },
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat" },
|
||||
});
|
||||
|
||||
expect(result?.id).toBe("background-draft");
|
||||
});
|
||||
|
||||
it("prefers active draft with content over non-active", () => {
|
||||
const activeDraft = makeDraft({ id: "active-draft" });
|
||||
const backgroundDraft = makeDraft({ id: "background-draft" });
|
||||
|
||||
const result = findExistingDraft({
|
||||
sessions: [backgroundDraft, activeDraft],
|
||||
activeSessionId: "active-draft",
|
||||
draftsBySession: {
|
||||
"active-draft": "active text",
|
||||
"background-draft": "background text",
|
||||
},
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat" },
|
||||
});
|
||||
|
||||
expect(result?.id).toBe("active-draft");
|
||||
});
|
||||
|
||||
it("finds an inactive empty draft when no draft has content", () => {
|
||||
const inactiveDraft = makeDraft({
|
||||
id: "inactive-draft",
|
||||
updatedAt: "2026-03-31T12:00:00.000Z",
|
||||
});
|
||||
|
||||
const result = findExistingDraft({
|
||||
sessions: [inactiveDraft],
|
||||
activeSessionId: null,
|
||||
draftsBySession: {},
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat" },
|
||||
});
|
||||
|
||||
expect(result?.id).toBe("inactive-draft");
|
||||
});
|
||||
|
||||
it("prefers drafts with content over empty drafts", () => {
|
||||
const emptyDraft = makeDraft({ id: "empty-draft" });
|
||||
const contentDraft = makeDraft({ id: "content-draft" });
|
||||
|
||||
const result = findExistingDraft({
|
||||
sessions: [emptyDraft, contentDraft],
|
||||
activeSessionId: "empty-draft",
|
||||
draftsBySession: { "content-draft": "has text" },
|
||||
messagesBySession: {},
|
||||
request: { title: "New Chat" },
|
||||
});
|
||||
|
||||
expect(result?.id).toBe("content-draft");
|
||||
expect(
|
||||
findExistingDraft({
|
||||
sessions: [draft],
|
||||
activeSessionId: null,
|
||||
draftsBySession: {},
|
||||
messagesBySession: {},
|
||||
request: {
|
||||
title: "New Chat",
|
||||
projectId: "alpha",
|
||||
},
|
||||
}),
|
||||
).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -5,9 +5,6 @@ import { DEFAULT_CHAT_TITLE } from "./sessionTitle";
|
|||
interface NewChatRequest {
|
||||
title: string;
|
||||
projectId?: string;
|
||||
agentId?: string;
|
||||
providerId?: string;
|
||||
personaId?: string;
|
||||
}
|
||||
|
||||
interface FindExistingDraftArgs {
|
||||
|
|
@ -22,24 +19,14 @@ function isMatchingContext(
|
|||
session: ChatSession,
|
||||
request: Omit<NewChatRequest, "title">,
|
||||
): boolean {
|
||||
return (
|
||||
session.projectId === request.projectId &&
|
||||
session.agentId === request.agentId &&
|
||||
session.providerId === request.providerId &&
|
||||
session.personaId === request.personaId
|
||||
);
|
||||
return session.projectId === request.projectId;
|
||||
}
|
||||
|
||||
function isReusableDraft(
|
||||
session: ChatSession,
|
||||
localMessages: Message[] | undefined,
|
||||
_localMessages: Message[] | undefined,
|
||||
): boolean {
|
||||
return (
|
||||
!!session.draft &&
|
||||
!session.archivedAt &&
|
||||
session.messageCount === 0 &&
|
||||
(localMessages?.length ?? 0) === 0
|
||||
);
|
||||
return !session.archivedAt && session.messageCount === 0;
|
||||
}
|
||||
|
||||
export function findExistingDraft({
|
||||
|
|
@ -64,18 +51,14 @@ export function findExistingDraft({
|
|||
}
|
||||
|
||||
const withContent = candidates.filter(
|
||||
(s) => (draftsBySession[s.id] ?? "").length > 0,
|
||||
(session) => (draftsBySession[session.id] ?? "").length > 0,
|
||||
);
|
||||
if (withContent.length > 0) {
|
||||
return withContent.find((s) => s.id === activeSessionId) ?? withContent[0];
|
||||
return (
|
||||
withContent.find((session) => session.id === activeSessionId) ??
|
||||
withContent[0]
|
||||
);
|
||||
}
|
||||
|
||||
const active = candidates.find((s) => s.id === activeSessionId);
|
||||
if (active) {
|
||||
return active;
|
||||
}
|
||||
|
||||
return candidates.sort(
|
||||
(a, b) => new Date(b.updatedAt).getTime() - new Date(a.updatedAt).getTime(),
|
||||
)[0];
|
||||
return candidates.find((session) => session.id === activeSessionId);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import type { ModelOption } from "../types";
|
||||
|
||||
const ACP_SESSION_METADATA_STORAGE_KEY = "goose:acp-session-metadata";
|
||||
const DRAFT_SESSION_STORAGE_KEY = "goose:chat-draft-sessions";
|
||||
const LEGACY_SESSION_CACHE_STORAGE_KEY = "goose:chat-sessions";
|
||||
const DRAFT_TEXT_STORAGE_KEY = "goose:chat-drafts";
|
||||
|
||||
export interface SessionMetadataOverlayRecord {
|
||||
sessionId: string;
|
||||
|
|
@ -22,7 +20,7 @@ export interface SessionMetadataOverlayRecord {
|
|||
updatedAt: string;
|
||||
}
|
||||
|
||||
export interface DraftSessionRecord {
|
||||
interface LegacySessionRecord {
|
||||
id: string;
|
||||
acpSessionId?: string;
|
||||
title: string;
|
||||
|
|
@ -36,7 +34,7 @@ export interface DraftSessionRecord {
|
|||
updatedAt: string;
|
||||
archivedAt?: string;
|
||||
messageCount: number;
|
||||
draft?: true;
|
||||
draft?: boolean;
|
||||
userSetName?: boolean;
|
||||
}
|
||||
|
||||
|
|
@ -65,32 +63,14 @@ function persistStorageArray<T>(storageKey: string, records: T[]): void {
|
|||
}
|
||||
}
|
||||
|
||||
function draftsWithText(): Set<string> {
|
||||
if (typeof window === "undefined") return new Set();
|
||||
try {
|
||||
const stored = window.localStorage.getItem(DRAFT_TEXT_STORAGE_KEY);
|
||||
if (!stored) return new Set();
|
||||
const parsed = JSON.parse(stored);
|
||||
return new Set(
|
||||
Object.entries(parsed)
|
||||
.filter(([, value]) => typeof value === "string" && value.length > 0)
|
||||
.map(([sessionId]) => sessionId),
|
||||
);
|
||||
} catch {
|
||||
return new Set();
|
||||
}
|
||||
}
|
||||
|
||||
function loadLegacySessions(): Array<
|
||||
DraftSessionRecord & {
|
||||
userSetName?: boolean;
|
||||
}
|
||||
> {
|
||||
return parseStorageArray(LEGACY_SESSION_CACHE_STORAGE_KEY);
|
||||
function loadLegacySessions(): LegacySessionRecord[] {
|
||||
return parseStorageArray<LegacySessionRecord>(
|
||||
LEGACY_SESSION_CACHE_STORAGE_KEY,
|
||||
);
|
||||
}
|
||||
|
||||
function recordFromLegacySession(
|
||||
session: DraftSessionRecord & { userSetName?: boolean },
|
||||
session: LegacySessionRecord,
|
||||
): SessionMetadataOverlayRecord {
|
||||
return {
|
||||
sessionId: session.acpSessionId ?? session.id,
|
||||
|
|
@ -138,49 +118,6 @@ export function persistSessionMetadataOverlay(
|
|||
);
|
||||
}
|
||||
|
||||
export function loadDraftSessionRecords(): DraftSessionRecord[] {
|
||||
const records = parseStorageArray<DraftSessionRecord>(
|
||||
DRAFT_SESSION_STORAGE_KEY,
|
||||
);
|
||||
const drafts = new Map(records.map((record) => [record.id, record]));
|
||||
|
||||
for (const session of loadLegacySessions()) {
|
||||
if (!session.draft) continue;
|
||||
if (drafts.has(session.id)) continue;
|
||||
drafts.set(session.id, session);
|
||||
}
|
||||
|
||||
return [...drafts.values()];
|
||||
}
|
||||
|
||||
export function persistDraftSessionRecords(
|
||||
records: DraftSessionRecord[],
|
||||
): void {
|
||||
const withText = draftsWithText();
|
||||
persistStorageArray(
|
||||
DRAFT_SESSION_STORAGE_KEY,
|
||||
records.filter((record) => withText.has(record.id)),
|
||||
);
|
||||
}
|
||||
|
||||
export function migrateSessionMetadataOverlayId(
|
||||
previousId: string,
|
||||
nextId: string,
|
||||
): void {
|
||||
if (!previousId || !nextId || previousId === nextId) return;
|
||||
const overlays = loadSessionMetadataOverlay();
|
||||
const previous = overlays.get(previousId);
|
||||
if (!previous) return;
|
||||
const existing = overlays.get(nextId);
|
||||
overlays.set(nextId, {
|
||||
...previous,
|
||||
...existing,
|
||||
sessionId: nextId,
|
||||
});
|
||||
overlays.delete(previousId);
|
||||
persistSessionMetadataOverlay(overlays.values());
|
||||
}
|
||||
|
||||
export function upsertSessionMetadataOverlayRecord(
|
||||
record: SessionMetadataOverlayRecord,
|
||||
): void {
|
||||
|
|
@ -195,21 +132,6 @@ export function removeSessionMetadataOverlayRecord(sessionId: string): void {
|
|||
persistSessionMetadataOverlay(overlays.values());
|
||||
}
|
||||
|
||||
export function persistDraftSessionRecord(record: DraftSessionRecord): void {
|
||||
const drafts = loadDraftSessionRecords();
|
||||
const nextDrafts = [
|
||||
...drafts.filter((draft) => draft.id !== record.id),
|
||||
record,
|
||||
];
|
||||
persistDraftSessionRecords(nextDrafts);
|
||||
}
|
||||
|
||||
export function removeDraftSessionRecord(sessionId: string): void {
|
||||
persistDraftSessionRecords(
|
||||
loadDraftSessionRecords().filter((record) => record.id !== sessionId),
|
||||
);
|
||||
}
|
||||
|
||||
export function modelIdsMatch(
|
||||
cached: ModelOption[] | undefined,
|
||||
next: ModelOption[],
|
||||
|
|
|
|||
|
|
@ -1,148 +1,97 @@
|
|||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { AcpSessionInfo } from "@/shared/api/acp";
|
||||
import { useChatSessionStore } from "../chatSessionStore";
|
||||
import { useChatSessionStore, type ChatSession } from "../chatSessionStore";
|
||||
|
||||
const mockAcpCreateSession = vi.fn();
|
||||
const mockAcpListSessions = vi.fn();
|
||||
|
||||
vi.mock("@/shared/api/acp", () => ({
|
||||
acpListSessions: vi.fn(),
|
||||
acpCreateSession: (...args: unknown[]) => mockAcpCreateSession(...args),
|
||||
acpListSessions: (...args: unknown[]) => mockAcpListSessions(...args),
|
||||
}));
|
||||
|
||||
import { acpListSessions } from "@/shared/api/acp";
|
||||
|
||||
const mockedAcpListSessions = vi.mocked(acpListSessions);
|
||||
|
||||
const LEGACY_SESSION_CACHE_KEY = "goose:chat-sessions";
|
||||
const OVERLAY_CACHE_KEY = "goose:acp-session-metadata";
|
||||
const DRAFT_SESSION_CACHE_KEY = "goose:chat-draft-sessions";
|
||||
|
||||
function resetStore() {
|
||||
useChatSessionStore.setState({
|
||||
sessions: [],
|
||||
activeSessionId: null,
|
||||
isLoading: false,
|
||||
hasHydratedSessions: false,
|
||||
contextPanelOpenBySession: {},
|
||||
activeWorkspaceBySession: {},
|
||||
modelsBySession: {},
|
||||
modelCacheByProvider: {},
|
||||
});
|
||||
}
|
||||
|
||||
function makeSession(overrides: Partial<ChatSession> = {}): ChatSession {
|
||||
return {
|
||||
id: "session-1",
|
||||
acpSessionId: "session-1",
|
||||
title: "Test Session",
|
||||
createdAt: "2026-04-01T00:00:00.000Z",
|
||||
updatedAt: "2026-04-01T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function seedSession(overrides: Partial<ChatSession> = {}): ChatSession {
|
||||
const session = makeSession(overrides);
|
||||
useChatSessionStore.getState().addSession(session);
|
||||
return session;
|
||||
}
|
||||
|
||||
describe("chatSessionStore", () => {
|
||||
beforeEach(() => {
|
||||
resetStore();
|
||||
window.localStorage.removeItem(LEGACY_SESSION_CACHE_KEY);
|
||||
window.localStorage.removeItem(OVERLAY_CACHE_KEY);
|
||||
window.localStorage.removeItem(DRAFT_SESSION_CACHE_KEY);
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
window.localStorage.removeItem(LEGACY_SESSION_CACHE_KEY);
|
||||
window.localStorage.removeItem(OVERLAY_CACHE_KEY);
|
||||
window.localStorage.removeItem(DRAFT_SESSION_CACHE_KEY);
|
||||
});
|
||||
|
||||
describe("createDraftSession", () => {
|
||||
it("creates a draft session with default title", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
describe("createSession", () => {
|
||||
it("creates a real ACP-backed session", async () => {
|
||||
mockAcpCreateSession.mockResolvedValue({ sessionId: "acp-1" });
|
||||
|
||||
expect(session.title).toBe("New Chat");
|
||||
expect(session.draft).toBe(true);
|
||||
expect(session.messageCount).toBe(0);
|
||||
expect(useChatSessionStore.getState().sessions).toContainEqual(session);
|
||||
});
|
||||
|
||||
it("creates a draft session with custom options", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession({
|
||||
title: "My Custom Chat",
|
||||
projectId: "proj-1",
|
||||
const session = await useChatSessionStore.getState().createSession({
|
||||
title: "New Chat",
|
||||
providerId: "openai",
|
||||
personaId: "persona-1",
|
||||
modelId: "gpt-4.1",
|
||||
modelName: "GPT-4.1",
|
||||
workingDir: "/tmp/project",
|
||||
});
|
||||
|
||||
expect(session.title).toBe("My Custom Chat");
|
||||
expect(session.projectId).toBe("proj-1");
|
||||
expect(session.providerId).toBe("openai");
|
||||
expect(session.personaId).toBe("persona-1");
|
||||
expect(session.draft).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("promoteDraft", () => {
|
||||
it("removes draft flag from a draft session", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
expect(session.draft).toBe(true);
|
||||
|
||||
useChatSessionStore.getState().promoteDraft(session.id);
|
||||
|
||||
const updated = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.find((s) => s.id === session.id);
|
||||
expect(updated?.draft).toBeUndefined();
|
||||
});
|
||||
|
||||
it("does nothing for non-draft sessions", () => {
|
||||
useChatSessionStore.setState({
|
||||
sessions: [
|
||||
{
|
||||
id: "non-draft",
|
||||
title: "Regular Session",
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
messageCount: 5,
|
||||
},
|
||||
],
|
||||
expect(mockAcpCreateSession).toHaveBeenCalledWith(
|
||||
"openai",
|
||||
"/tmp/project",
|
||||
{
|
||||
personaId: "persona-1",
|
||||
modelId: "gpt-4.1",
|
||||
},
|
||||
);
|
||||
expect(session).toMatchObject({
|
||||
id: "acp-1",
|
||||
acpSessionId: "acp-1",
|
||||
title: "New Chat",
|
||||
providerId: "openai",
|
||||
personaId: "persona-1",
|
||||
modelId: "gpt-4.1",
|
||||
modelName: "GPT-4.1",
|
||||
});
|
||||
|
||||
useChatSessionStore.getState().promoteDraft("non-draft");
|
||||
|
||||
const session = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.find((s) => s.id === "non-draft");
|
||||
expect(session?.draft).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("removeDraft", () => {
|
||||
it("removes a draft session", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
expect(useChatSessionStore.getState().sessions).toHaveLength(1);
|
||||
|
||||
useChatSessionStore.getState().removeDraft(session.id);
|
||||
|
||||
expect(useChatSessionStore.getState().sessions).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("clears activeSessionId if removing the active draft", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
useChatSessionStore.getState().setActiveSession(session.id);
|
||||
|
||||
useChatSessionStore.getState().removeDraft(session.id);
|
||||
|
||||
expect(useChatSessionStore.getState().activeSessionId).toBeNull();
|
||||
});
|
||||
|
||||
it("does not remove non-draft sessions", () => {
|
||||
useChatSessionStore.setState({
|
||||
sessions: [
|
||||
{
|
||||
id: "non-draft",
|
||||
title: "Regular Session",
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
messageCount: 5,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
useChatSessionStore.getState().removeDraft("non-draft");
|
||||
|
||||
expect(useChatSessionStore.getState().sessions).toHaveLength(1);
|
||||
expect(useChatSessionStore.getState().sessions).toContainEqual(session);
|
||||
});
|
||||
});
|
||||
|
||||
describe("loadSessions", () => {
|
||||
it("loads sessions from ACP and maps them correctly", async () => {
|
||||
mockedAcpListSessions.mockResolvedValue([
|
||||
mockAcpListSessions.mockResolvedValue([
|
||||
{
|
||||
sessionId: "acp-1",
|
||||
title: "ACP Session 1",
|
||||
|
|
@ -161,64 +110,14 @@ describe("chatSessionStore", () => {
|
|||
|
||||
const sessions = useChatSessionStore.getState().sessions;
|
||||
expect(sessions).toHaveLength(2);
|
||||
expect(sessions[0].id).toBe("acp-2"); // Most recent first
|
||||
expect(sessions[0].title).toBe("Untitled"); // null title becomes "Untitled"
|
||||
expect(sessions[0].id).toBe("acp-2");
|
||||
expect(sessions[0].title).toBe("Untitled");
|
||||
expect(sessions[0].messageCount).toBe(7);
|
||||
expect(sessions[1].id).toBe("acp-1");
|
||||
expect(sessions[1].title).toBe("ACP Session 1");
|
||||
expect(sessions[1].messageCount).toBe(4);
|
||||
});
|
||||
|
||||
it("preserves local drafts alongside ACP sessions", async () => {
|
||||
const draft = useChatSessionStore.getState().createDraftSession({
|
||||
title: "My Draft",
|
||||
});
|
||||
|
||||
mockedAcpListSessions.mockResolvedValue([
|
||||
{
|
||||
sessionId: "acp-1",
|
||||
title: "ACP Session",
|
||||
updatedAt: "2026-04-01",
|
||||
messageCount: 3,
|
||||
},
|
||||
]);
|
||||
|
||||
await useChatSessionStore.getState().loadSessions();
|
||||
|
||||
const sessions = useChatSessionStore.getState().sessions;
|
||||
expect(sessions).toHaveLength(2);
|
||||
expect(sessions.find((s) => s.id === "acp-1")).toBeDefined();
|
||||
expect(sessions.find((s) => s.id === draft.id)).toBeDefined();
|
||||
});
|
||||
|
||||
it("migrates promoted draft metadata onto the resolved ACP session id", async () => {
|
||||
const draft = useChatSessionStore.getState().createDraftSession({
|
||||
title: "Project Draft",
|
||||
projectId: "project-123",
|
||||
providerId: "goose",
|
||||
});
|
||||
|
||||
useChatSessionStore.getState().promoteDraft(draft.id);
|
||||
useChatSessionStore.getState().setSessionAcpId(draft.id, "acp-1");
|
||||
|
||||
mockedAcpListSessions.mockResolvedValue([
|
||||
{
|
||||
sessionId: "acp-1",
|
||||
title: "ACP Session",
|
||||
updatedAt: "2026-04-02",
|
||||
messageCount: 3,
|
||||
},
|
||||
]);
|
||||
|
||||
await useChatSessionStore.getState().loadSessions();
|
||||
|
||||
const session = useChatSessionStore.getState().sessions[0];
|
||||
expect(session.id).toBe("acp-1");
|
||||
expect(session.acpSessionId).toBe("acp-1");
|
||||
expect(session.projectId).toBe("project-123");
|
||||
expect(session.providerId).toBe("goose");
|
||||
});
|
||||
|
||||
it("rehydrates cached project metadata for ACP sessions", async () => {
|
||||
window.localStorage.setItem(
|
||||
LEGACY_SESSION_CACHE_KEY,
|
||||
|
|
@ -237,7 +136,7 @@ describe("chatSessionStore", () => {
|
|||
]),
|
||||
);
|
||||
|
||||
mockedAcpListSessions.mockResolvedValue([
|
||||
mockAcpListSessions.mockResolvedValue([
|
||||
{
|
||||
sessionId: "acp-1",
|
||||
title: null,
|
||||
|
|
@ -259,21 +158,37 @@ describe("chatSessionStore", () => {
|
|||
expect(session.userSetName).toBe(true);
|
||||
});
|
||||
|
||||
it("drops stale non-draft sessions that are no longer in ACP", async () => {
|
||||
useChatSessionStore.setState({
|
||||
sessions: [
|
||||
it("ignores legacy draft records while hydrating overlays", async () => {
|
||||
window.localStorage.setItem(
|
||||
LEGACY_SESSION_CACHE_KEY,
|
||||
JSON.stringify([
|
||||
{
|
||||
id: "stale-session",
|
||||
title: "Stale Session",
|
||||
id: "cached-draft",
|
||||
title: "Cached Draft",
|
||||
draft: true,
|
||||
createdAt: "2026-04-01",
|
||||
updatedAt: "2026-04-01",
|
||||
messageCount: 2,
|
||||
messageCount: 0,
|
||||
},
|
||||
]),
|
||||
);
|
||||
|
||||
mockAcpListSessions.mockResolvedValue([]);
|
||||
|
||||
await useChatSessionStore.getState().loadSessions();
|
||||
|
||||
expect(useChatSessionStore.getState().sessions).toEqual([]);
|
||||
});
|
||||
|
||||
it("drops stale sessions that are no longer in ACP", async () => {
|
||||
useChatSessionStore.setState({
|
||||
sessions: [
|
||||
makeSession({ id: "stale-session", title: "Stale Session" }),
|
||||
],
|
||||
activeSessionId: "stale-session",
|
||||
});
|
||||
|
||||
mockedAcpListSessions.mockResolvedValue([
|
||||
mockAcpListSessions.mockResolvedValue([
|
||||
{
|
||||
sessionId: "acp-1",
|
||||
title: "ACP Session",
|
||||
|
|
@ -292,7 +207,7 @@ describe("chatSessionStore", () => {
|
|||
|
||||
it("sets isLoading during fetch", async () => {
|
||||
let resolvePromise: (value: AcpSessionInfo[]) => void = () => {};
|
||||
mockedAcpListSessions.mockReturnValue(
|
||||
mockAcpListSessions.mockReturnValue(
|
||||
new Promise((resolve) => {
|
||||
resolvePromise = resolve;
|
||||
}),
|
||||
|
|
@ -300,11 +215,13 @@ describe("chatSessionStore", () => {
|
|||
|
||||
const loadPromise = useChatSessionStore.getState().loadSessions();
|
||||
expect(useChatSessionStore.getState().isLoading).toBe(true);
|
||||
expect(useChatSessionStore.getState().hasHydratedSessions).toBe(false);
|
||||
|
||||
resolvePromise([]);
|
||||
await loadPromise;
|
||||
|
||||
expect(useChatSessionStore.getState().isLoading).toBe(false);
|
||||
expect(useChatSessionStore.getState().hasHydratedSessions).toBe(true);
|
||||
});
|
||||
|
||||
it("falls back to cached sessions on error", async () => {
|
||||
|
|
@ -315,56 +232,51 @@ describe("chatSessionStore", () => {
|
|||
id: "cached-session",
|
||||
title: "Cached Session",
|
||||
projectId: "project-123",
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
createdAt: "2026-04-01T00:00:00.000Z",
|
||||
updatedAt: "2026-04-01T00:00:00.000Z",
|
||||
messageCount: 8,
|
||||
},
|
||||
{
|
||||
id: "cached-draft",
|
||||
title: "Cached Draft",
|
||||
draft: true,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
createdAt: "2026-04-01T00:00:00.000Z",
|
||||
updatedAt: "2026-04-01T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
},
|
||||
]),
|
||||
);
|
||||
|
||||
mockedAcpListSessions.mockRejectedValue(new Error("Network error"));
|
||||
mockAcpListSessions.mockRejectedValue(new Error("Network error"));
|
||||
|
||||
await useChatSessionStore.getState().loadSessions();
|
||||
|
||||
const sessions = useChatSessionStore.getState().sessions;
|
||||
expect(sessions).toHaveLength(2);
|
||||
expect(
|
||||
sessions.find((session) => session.id === "cached-session"),
|
||||
).toMatchObject({
|
||||
expect(sessions).toHaveLength(1);
|
||||
expect(sessions[0]).toMatchObject({
|
||||
id: "cached-session",
|
||||
projectId: "project-123",
|
||||
});
|
||||
expect(
|
||||
sessions.find((session) => session.id === "cached-draft")?.draft,
|
||||
).toBe(true);
|
||||
expect(useChatSessionStore.getState().hasHydratedSessions).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("updateSession", () => {
|
||||
it("updates session properties", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
const session = seedSession();
|
||||
|
||||
useChatSessionStore.getState().updateSession(session.id, {
|
||||
title: "Updated Title",
|
||||
projectId: "new-project",
|
||||
});
|
||||
|
||||
const updated = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.find((s) => s.id === session.id);
|
||||
const updated = useChatSessionStore.getState().getSession(session.id);
|
||||
expect(updated?.title).toBe("Updated Title");
|
||||
expect(updated?.projectId).toBe("new-project");
|
||||
});
|
||||
|
||||
it("preserves updatedAt when not explicitly provided in patch", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
const session = seedSession();
|
||||
const originalUpdatedAt = session.updatedAt;
|
||||
|
||||
vi.useFakeTimers();
|
||||
|
|
@ -376,14 +288,12 @@ describe("chatSessionStore", () => {
|
|||
|
||||
vi.useRealTimers();
|
||||
|
||||
const updated = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.find((s) => s.id === session.id);
|
||||
const updated = useChatSessionStore.getState().getSession(session.id);
|
||||
expect(updated?.updatedAt).toBe(originalUpdatedAt);
|
||||
});
|
||||
|
||||
it("updates updatedAt when explicitly provided in patch", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
const session = seedSession();
|
||||
const originalUpdatedAt = session.updatedAt;
|
||||
|
||||
vi.useFakeTimers();
|
||||
|
|
@ -397,76 +307,43 @@ describe("chatSessionStore", () => {
|
|||
|
||||
vi.useRealTimers();
|
||||
|
||||
const updated = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.find((s) => s.id === session.id);
|
||||
const updated = useChatSessionStore.getState().getSession(session.id);
|
||||
expect(updated?.updatedAt).not.toBe(originalUpdatedAt);
|
||||
expect(updated?.updatedAt).toBe(newTimestamp);
|
||||
});
|
||||
});
|
||||
|
||||
describe("session models", () => {
|
||||
it("stores models per session", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
describe("provider switching", () => {
|
||||
it("clears the selected model when switching providers", () => {
|
||||
const session = seedSession({
|
||||
providerId: "openai",
|
||||
modelId: "gpt-4o",
|
||||
modelName: "GPT-4o",
|
||||
});
|
||||
|
||||
useChatSessionStore.getState().setSessionModels(session.id, [
|
||||
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
|
||||
{ id: "gpt-4o", name: "GPT-4o" },
|
||||
]);
|
||||
|
||||
expect(
|
||||
useChatSessionStore.getState().getSessionModels(session.id),
|
||||
).toEqual([
|
||||
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
|
||||
{ id: "gpt-4o", name: "GPT-4o" },
|
||||
]);
|
||||
});
|
||||
|
||||
it("removes stored models when a draft session is removed", () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.setSessionModels(session.id, [
|
||||
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
|
||||
]);
|
||||
.switchSessionProvider(session.id, "anthropic");
|
||||
|
||||
useChatSessionStore.getState().removeDraft(session.id);
|
||||
|
||||
expect(
|
||||
useChatSessionStore.getState().getSessionModels(session.id),
|
||||
).toEqual([]);
|
||||
});
|
||||
|
||||
it("removes stored models when a session is archived", async () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.setSessionModels(session.id, [
|
||||
{ id: "claude-sonnet-4", name: "Claude Sonnet 4" },
|
||||
]);
|
||||
|
||||
await useChatSessionStore.getState().archiveSession(session.id);
|
||||
|
||||
expect(
|
||||
useChatSessionStore.getState().getSessionModels(session.id),
|
||||
).toEqual([]);
|
||||
const updated = useChatSessionStore.getState().getSession(session.id);
|
||||
expect(updated?.providerId).toBe("anthropic");
|
||||
expect(updated?.modelId).toBeUndefined();
|
||||
expect(updated?.modelName).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("archiveSession", () => {
|
||||
it("sets archivedAt on the session", async () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
const session = seedSession();
|
||||
|
||||
await useChatSessionStore.getState().archiveSession(session.id);
|
||||
|
||||
const archived = useChatSessionStore
|
||||
.getState()
|
||||
.sessions.find((s) => s.id === session.id);
|
||||
const archived = useChatSessionStore.getState().getSession(session.id);
|
||||
expect(archived?.archivedAt).toBeDefined();
|
||||
});
|
||||
|
||||
it("clears activeSessionId if archiving the active session", async () => {
|
||||
const session = useChatSessionStore.getState().createDraftSession();
|
||||
const session = seedSession();
|
||||
useChatSessionStore.getState().setActiveSession(session.id);
|
||||
|
||||
await useChatSessionStore.getState().archiveSession(session.id);
|
||||
|
|
@ -478,13 +355,14 @@ describe("chatSessionStore", () => {
|
|||
describe("addSession", () => {
|
||||
it("prepends a new session to the list", () => {
|
||||
const { addSession } = useChatSessionStore.getState();
|
||||
addSession({
|
||||
id: "imported-1",
|
||||
title: "Imported Session",
|
||||
createdAt: "2026-01-01T00:00:00Z",
|
||||
updatedAt: "2026-01-01T00:00:00Z",
|
||||
messageCount: 5,
|
||||
});
|
||||
addSession(
|
||||
makeSession({
|
||||
id: "imported-1",
|
||||
title: "Imported Session",
|
||||
messageCount: 5,
|
||||
}),
|
||||
);
|
||||
|
||||
const sessions = useChatSessionStore.getState().sessions;
|
||||
expect(sessions[0].id).toBe("imported-1");
|
||||
expect(sessions[0].title).toBe("Imported Session");
|
||||
|
|
@ -493,22 +371,13 @@ describe("chatSessionStore", () => {
|
|||
|
||||
it("does not create a duplicate if session ID already exists", () => {
|
||||
const { addSession } = useChatSessionStore.getState();
|
||||
addSession({
|
||||
id: "dup-1",
|
||||
title: "First",
|
||||
createdAt: "2026-01-01T00:00:00Z",
|
||||
updatedAt: "2026-01-01T00:00:00Z",
|
||||
messageCount: 1,
|
||||
});
|
||||
addSession({
|
||||
id: "dup-1",
|
||||
title: "Second",
|
||||
createdAt: "2026-01-01T00:00:00Z",
|
||||
updatedAt: "2026-01-01T00:00:00Z",
|
||||
messageCount: 2,
|
||||
});
|
||||
addSession(makeSession({ id: "dup-1", title: "First", messageCount: 1 }));
|
||||
addSession(
|
||||
makeSession({ id: "dup-1", title: "Second", messageCount: 2 }),
|
||||
);
|
||||
|
||||
const sessions = useChatSessionStore.getState().sessions;
|
||||
const matches = sessions.filter((s) => s.id === "dup-1");
|
||||
const matches = sessions.filter((session) => session.id === "dup-1");
|
||||
expect(matches).toHaveLength(1);
|
||||
expect(matches[0].title).toBe("Second");
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,23 +1,17 @@
|
|||
import { create } from "zustand";
|
||||
import { acpListSessions, type AcpSessionInfo } from "@/shared/api/acp";
|
||||
import {
|
||||
acpCreateSession,
|
||||
acpListSessions,
|
||||
type AcpSessionInfo,
|
||||
} from "@/shared/api/acp";
|
||||
import type { Session } from "@/shared/types/chat";
|
||||
import { DEFAULT_CHAT_TITLE } from "@/features/chat/lib/sessionTitle";
|
||||
import {
|
||||
loadDraftSessionRecords,
|
||||
loadSessionMetadataOverlay,
|
||||
migrateSessionMetadataOverlayId,
|
||||
modelIdsMatch,
|
||||
persistSessionMetadataOverlay,
|
||||
persistDraftSessionRecord,
|
||||
persistDraftSessionRecords,
|
||||
removeDraftSessionRecord,
|
||||
upsertSessionMetadataOverlayRecord,
|
||||
type DraftSessionRecord,
|
||||
type SessionMetadataOverlayRecord,
|
||||
} from "@/features/chat/lib/sessionMetadataOverlay";
|
||||
import type { ModelOption } from "../types";
|
||||
|
||||
const EMPTY_MODELS: ModelOption[] = [];
|
||||
|
||||
export interface ChatSession {
|
||||
id: string;
|
||||
|
|
@ -33,7 +27,6 @@ export interface ChatSession {
|
|||
updatedAt: string;
|
||||
archivedAt?: string;
|
||||
messageCount: number;
|
||||
draft?: boolean;
|
||||
userSetName?: boolean;
|
||||
}
|
||||
|
||||
|
|
@ -42,14 +35,31 @@ export interface ActiveWorkspace {
|
|||
branch: string | null;
|
||||
}
|
||||
|
||||
export function hasSessionStarted(
|
||||
session: Pick<ChatSession, "messageCount">,
|
||||
localMessages?: ArrayLike<unknown>,
|
||||
): boolean {
|
||||
return session.messageCount > 0 || (localMessages?.length ?? 0) > 0;
|
||||
}
|
||||
|
||||
export function getVisibleSessions<
|
||||
T extends Pick<ChatSession, "id" | "messageCount">,
|
||||
>(
|
||||
sessions: T[],
|
||||
messagesBySession: Record<string, ArrayLike<unknown> | undefined>,
|
||||
): T[] {
|
||||
return sessions.filter((session) =>
|
||||
hasSessionStarted(session, messagesBySession[session.id]),
|
||||
);
|
||||
}
|
||||
|
||||
interface ChatSessionStoreState {
|
||||
sessions: ChatSession[];
|
||||
activeSessionId: string | null;
|
||||
isLoading: boolean;
|
||||
hasHydratedSessions: boolean;
|
||||
contextPanelOpenBySession: Record<string, boolean>;
|
||||
activeWorkspaceBySession: Record<string, ActiveWorkspace>;
|
||||
modelsBySession: Record<string, ModelOption[]>;
|
||||
modelCacheByProvider: Record<string, ModelOption[]>;
|
||||
}
|
||||
|
||||
interface CreateSessionOpts {
|
||||
|
|
@ -58,18 +68,17 @@ interface CreateSessionOpts {
|
|||
agentId?: string;
|
||||
providerId?: string;
|
||||
personaId?: string;
|
||||
workingDir?: string;
|
||||
modelId?: string;
|
||||
modelName?: string;
|
||||
}
|
||||
|
||||
interface UpdateSessionOptions {
|
||||
localOnly?: boolean;
|
||||
persistOverlay?: boolean;
|
||||
}
|
||||
|
||||
interface ChatSessionStoreActions {
|
||||
createSession: (opts?: CreateSessionOpts) => Promise<ChatSession>;
|
||||
createDraftSession: (opts?: CreateSessionOpts) => ChatSession;
|
||||
promoteDraft: (id: string) => void;
|
||||
removeDraft: (id: string) => void;
|
||||
loadSessions: () => Promise<void>;
|
||||
updateSession: (
|
||||
id: string,
|
||||
|
|
@ -79,74 +88,20 @@ interface ChatSessionStoreActions {
|
|||
addSession: (session: ChatSession) => void;
|
||||
archiveSession: (id: string) => Promise<void>;
|
||||
unarchiveSession: (id: string) => Promise<void>;
|
||||
setSessionAcpId: (id: string, acpSessionId: string) => void;
|
||||
|
||||
setActiveSession: (sessionId: string | null) => void;
|
||||
setContextPanelOpen: (sessionId: string, open: boolean) => void;
|
||||
setActiveWorkspace: (sessionId: string, context: ActiveWorkspace) => void;
|
||||
clearActiveWorkspace: (sessionId: string) => void;
|
||||
setSessionModels: (sessionId: string, models: ModelOption[]) => void;
|
||||
switchSessionProvider: (
|
||||
sessionId: string,
|
||||
providerId: string,
|
||||
models: ModelOption[],
|
||||
) => void;
|
||||
cacheModelsForProvider: (providerId: string, models: ModelOption[]) => void;
|
||||
getCachedModels: (providerId: string) => ModelOption[];
|
||||
switchSessionProvider: (sessionId: string, providerId: string) => void;
|
||||
|
||||
getSession: (id: string) => ChatSession | undefined;
|
||||
getActiveSession: () => ChatSession | null;
|
||||
getArchivedSessions: () => ChatSession[];
|
||||
getSessionModels: (sessionId: string) => ModelOption[];
|
||||
}
|
||||
|
||||
export type ChatSessionStore = ChatSessionStoreState & ChatSessionStoreActions;
|
||||
|
||||
const MODEL_CACHE_STORAGE_KEY = "goose:model-cache";
|
||||
|
||||
function loadModelCache(): Record<string, ModelOption[]> {
|
||||
if (typeof window === "undefined") return {};
|
||||
try {
|
||||
const stored = window.localStorage.getItem(MODEL_CACHE_STORAGE_KEY);
|
||||
if (!stored) return {};
|
||||
const parsed = JSON.parse(stored);
|
||||
return typeof parsed === "object" && parsed !== null
|
||||
? (parsed as Record<string, ModelOption[]>)
|
||||
: {};
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function persistModelCache(cache: Record<string, ModelOption[]>): void {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
window.localStorage.setItem(MODEL_CACHE_STORAGE_KEY, JSON.stringify(cache));
|
||||
} catch {
|
||||
// localStorage may be unavailable
|
||||
}
|
||||
}
|
||||
|
||||
function draftSessionToRecord(session: ChatSession): DraftSessionRecord {
|
||||
return {
|
||||
id: session.id,
|
||||
acpSessionId: session.acpSessionId,
|
||||
title: session.title,
|
||||
projectId: session.projectId,
|
||||
agentId: session.agentId,
|
||||
providerId: session.providerId,
|
||||
personaId: session.personaId,
|
||||
modelId: session.modelId,
|
||||
modelName: session.modelName,
|
||||
createdAt: session.createdAt,
|
||||
updatedAt: session.updatedAt,
|
||||
archivedAt: session.archivedAt,
|
||||
messageCount: session.messageCount,
|
||||
draft: true,
|
||||
userSetName: session.userSetName,
|
||||
};
|
||||
}
|
||||
|
||||
function overlayKeyForSession(
|
||||
session: Pick<ChatSession, "id" | "acpSessionId">,
|
||||
) {
|
||||
|
|
@ -225,40 +180,12 @@ function mergeAcpSessionWithOverlay(
|
|||
};
|
||||
}
|
||||
|
||||
function mergeDraftSessions(
|
||||
currentDrafts: ChatSession[],
|
||||
persistedDrafts: DraftSessionRecord[],
|
||||
): ChatSession[] {
|
||||
const draftsById = new Map<string, ChatSession>(
|
||||
persistedDrafts.map((record) => [
|
||||
record.id,
|
||||
{
|
||||
...record,
|
||||
draft: true,
|
||||
} satisfies ChatSession,
|
||||
]),
|
||||
);
|
||||
|
||||
for (const draft of currentDrafts) {
|
||||
draftsById.set(draft.id, draft);
|
||||
}
|
||||
|
||||
return [...draftsById.values()];
|
||||
}
|
||||
|
||||
function persistDraftsFromSessions(sessions: ChatSession[]): void {
|
||||
persistDraftSessionRecords(
|
||||
sessions.filter((session) => session.draft).map(draftSessionToRecord),
|
||||
);
|
||||
}
|
||||
|
||||
function syncOverlaySnapshots(
|
||||
sessions: ChatSession[],
|
||||
existingOverlays = loadSessionMetadataOverlay(),
|
||||
): void {
|
||||
const overlays = new Map(existingOverlays);
|
||||
for (const session of sessions) {
|
||||
if (session.draft) continue;
|
||||
overlays.set(
|
||||
overlayKeyForSession(session),
|
||||
buildOverlayRecord(
|
||||
|
|
@ -299,81 +226,55 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
sessions: [],
|
||||
activeSessionId: null,
|
||||
isLoading: false,
|
||||
hasHydratedSessions: false,
|
||||
contextPanelOpenBySession: {},
|
||||
activeWorkspaceBySession: {},
|
||||
modelsBySession: {},
|
||||
modelCacheByProvider: loadModelCache(),
|
||||
|
||||
createSession: async (_opts) => {
|
||||
throw new Error(
|
||||
"createSession not yet wired to ACP — use createDraftSession",
|
||||
);
|
||||
},
|
||||
|
||||
createDraftSession: (opts) => {
|
||||
createSession: async (opts) => {
|
||||
if (!opts?.workingDir) {
|
||||
throw new Error("createSession requires a working directory");
|
||||
}
|
||||
const now = new Date().toISOString();
|
||||
const providerId = opts.providerId ?? "goose";
|
||||
const { sessionId } = await acpCreateSession(providerId, opts.workingDir, {
|
||||
personaId: opts.personaId,
|
||||
modelId: opts.modelId,
|
||||
});
|
||||
const chatSession: ChatSession = {
|
||||
id: crypto.randomUUID(),
|
||||
title: opts?.title ?? DEFAULT_CHAT_TITLE,
|
||||
projectId: opts?.projectId,
|
||||
agentId: opts?.agentId,
|
||||
providerId: opts?.providerId,
|
||||
personaId: opts?.personaId,
|
||||
id: sessionId,
|
||||
acpSessionId: sessionId,
|
||||
title: opts.title ?? DEFAULT_CHAT_TITLE,
|
||||
projectId: opts.projectId,
|
||||
agentId: opts.agentId,
|
||||
providerId,
|
||||
personaId: opts.personaId,
|
||||
modelId: opts.modelId,
|
||||
modelName: opts.modelName,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
messageCount: 0,
|
||||
draft: true,
|
||||
};
|
||||
set((state) => ({ sessions: [...state.sessions, chatSession] }));
|
||||
persistDraftSessionRecord(draftSessionToRecord(chatSession));
|
||||
set((state) => ({ sessions: [chatSession, ...state.sessions] }));
|
||||
const existing = loadSessionMetadataOverlay().get(
|
||||
overlayKeyForSession(chatSession),
|
||||
);
|
||||
upsertSessionMetadataOverlayRecord(
|
||||
buildOverlayRecord(chatSession, existing),
|
||||
);
|
||||
return chatSession;
|
||||
},
|
||||
|
||||
promoteDraft: (id) => {
|
||||
const session = get().sessions.find((candidate) => candidate.id === id);
|
||||
if (!session?.draft) return;
|
||||
set((state) => ({
|
||||
sessions: state.sessions.map((candidate) =>
|
||||
candidate.id === id ? { ...candidate, draft: undefined } : candidate,
|
||||
),
|
||||
}));
|
||||
persistDraftsFromSessions(get().sessions);
|
||||
},
|
||||
|
||||
removeDraft: (id) => {
|
||||
const session = get().sessions.find((candidate) => candidate.id === id);
|
||||
if (!session?.draft) return;
|
||||
const { [id]: _ignoredPanelState, ...remainingPanelState } =
|
||||
get().contextPanelOpenBySession;
|
||||
const { [id]: _ignoredContext, ...remainingContextState } =
|
||||
get().activeWorkspaceBySession;
|
||||
const remainingModels = { ...get().modelsBySession };
|
||||
delete remainingModels[id];
|
||||
set((state) => ({
|
||||
sessions: state.sessions.filter((candidate) => candidate.id !== id),
|
||||
activeSessionId:
|
||||
state.activeSessionId === id ? null : state.activeSessionId,
|
||||
contextPanelOpenBySession: remainingPanelState,
|
||||
activeWorkspaceBySession: remainingContextState,
|
||||
modelsBySession: remainingModels,
|
||||
}));
|
||||
removeDraftSessionRecord(id);
|
||||
},
|
||||
|
||||
loadSessions: async () => {
|
||||
set({ isLoading: true });
|
||||
try {
|
||||
const overlays = loadSessionMetadataOverlay();
|
||||
const acpSessions = await acpListSessions();
|
||||
const persistedDrafts = loadDraftSessionRecords();
|
||||
const currentDrafts = get().sessions.filter((session) => session.draft);
|
||||
const drafts = mergeDraftSessions(currentDrafts, persistedDrafts);
|
||||
const mergedAcpSessions = sortByUpdatedAtDesc(
|
||||
acpSessions.map((session) =>
|
||||
mergeAcpSessionWithOverlay(session, overlays.get(session.sessionId)),
|
||||
),
|
||||
);
|
||||
const merged = [...mergedAcpSessions, ...drafts];
|
||||
const merged = mergedAcpSessions;
|
||||
const activeSessionId = get().activeSessionId;
|
||||
const activeSessionStillExists =
|
||||
activeSessionId == null ||
|
||||
|
|
@ -382,22 +283,16 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
sessions: merged,
|
||||
activeSessionId: activeSessionStillExists ? activeSessionId : null,
|
||||
});
|
||||
persistDraftsFromSessions(merged);
|
||||
syncOverlaySnapshots(mergedAcpSessions, overlays);
|
||||
} catch (error) {
|
||||
console.error("Failed to load sessions from ACP:", error);
|
||||
const overlays = loadSessionMetadataOverlay();
|
||||
const persistedDrafts = loadDraftSessionRecords();
|
||||
const currentDrafts = get().sessions.filter((session) => session.draft);
|
||||
const drafts = mergeDraftSessions(currentDrafts, persistedDrafts);
|
||||
const fallbackSessions = sortByUpdatedAtDesc([
|
||||
...[...overlays.values()].map(overlayToFallbackSession),
|
||||
...drafts,
|
||||
]);
|
||||
const fallbackSessions = sortByUpdatedAtDesc(
|
||||
[...overlays.values()].map(overlayToFallbackSession),
|
||||
);
|
||||
set({ sessions: fallbackSessions });
|
||||
persistDraftsFromSessions(fallbackSessions);
|
||||
} finally {
|
||||
set({ isLoading: false });
|
||||
set({ isLoading: false, hasHydratedSessions: true });
|
||||
}
|
||||
},
|
||||
|
||||
|
|
@ -415,23 +310,15 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
}));
|
||||
|
||||
const updatedSession = get().sessions.find((session) => session.id === id);
|
||||
persistDraftsFromSessions(get().sessions);
|
||||
|
||||
if (
|
||||
updatedSession &&
|
||||
!updatedSession.draft &&
|
||||
opts?.persistOverlay !== false
|
||||
) {
|
||||
if (updatedSession && opts?.persistOverlay !== false) {
|
||||
const key = overlayKeyForSession(updatedSession);
|
||||
const existing = loadSessionMetadataOverlay().get(key);
|
||||
upsertSessionMetadataOverlayRecord(
|
||||
buildOverlayRecord(updatedSession, existing),
|
||||
);
|
||||
}
|
||||
|
||||
if (opts?.localOnly) return;
|
||||
if (updatedSession?.draft) return;
|
||||
// TODO: wire non-draft updates to ACP when supported
|
||||
// TODO: wire session updates to ACP when supported
|
||||
},
|
||||
|
||||
addSession: (session) => {
|
||||
|
|
@ -450,20 +337,15 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
}
|
||||
return { sessions: [normalizedSession, ...state.sessions] };
|
||||
});
|
||||
persistDraftsFromSessions(get().sessions);
|
||||
if (!normalizedSession.draft) {
|
||||
const existing = loadSessionMetadataOverlay().get(
|
||||
overlayKeyForSession(normalizedSession),
|
||||
);
|
||||
upsertSessionMetadataOverlayRecord(
|
||||
buildOverlayRecord(normalizedSession, existing),
|
||||
);
|
||||
}
|
||||
const existing = loadSessionMetadataOverlay().get(
|
||||
overlayKeyForSession(normalizedSession),
|
||||
);
|
||||
upsertSessionMetadataOverlayRecord(
|
||||
buildOverlayRecord(normalizedSession, existing),
|
||||
);
|
||||
},
|
||||
|
||||
archiveSession: async (id) => {
|
||||
const remainingModels = { ...get().modelsBySession };
|
||||
delete remainingModels[id];
|
||||
set((state) => ({
|
||||
sessions: state.sessions.map((session) =>
|
||||
session.id === id
|
||||
|
|
@ -472,11 +354,9 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
),
|
||||
activeSessionId:
|
||||
state.activeSessionId === id ? null : state.activeSessionId,
|
||||
modelsBySession: remainingModels,
|
||||
}));
|
||||
persistDraftsFromSessions(get().sessions);
|
||||
const session = get().sessions.find((candidate) => candidate.id === id);
|
||||
if (session && !session.draft) {
|
||||
if (session) {
|
||||
const existing = loadSessionMetadataOverlay().get(
|
||||
overlayKeyForSession(session),
|
||||
);
|
||||
|
|
@ -490,9 +370,8 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
session.id === id ? { ...session, archivedAt: undefined } : session,
|
||||
),
|
||||
}));
|
||||
persistDraftsFromSessions(get().sessions);
|
||||
const session = get().sessions.find((candidate) => candidate.id === id);
|
||||
if (session && !session.draft) {
|
||||
if (session) {
|
||||
const existing = loadSessionMetadataOverlay().get(
|
||||
overlayKeyForSession(session),
|
||||
);
|
||||
|
|
@ -500,36 +379,6 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
}
|
||||
},
|
||||
|
||||
setSessionAcpId: (id, acpSessionId) => {
|
||||
if (!acpSessionId) return;
|
||||
const session = get().sessions.find((candidate) => candidate.id === id);
|
||||
if (!session || session.acpSessionId === acpSessionId) return;
|
||||
|
||||
set((state) => ({
|
||||
sessions: state.sessions.map((candidate) =>
|
||||
candidate.id === id ? { ...candidate, acpSessionId } : candidate,
|
||||
),
|
||||
}));
|
||||
|
||||
if (!session.draft) {
|
||||
migrateSessionMetadataOverlayId(
|
||||
overlayKeyForSession(session),
|
||||
acpSessionId,
|
||||
);
|
||||
const updatedSession = get().sessions.find(
|
||||
(candidate) => candidate.id === id,
|
||||
);
|
||||
if (updatedSession) {
|
||||
const existing = loadSessionMetadataOverlay().get(acpSessionId);
|
||||
upsertSessionMetadataOverlayRecord(
|
||||
buildOverlayRecord(updatedSession, existing),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
persistDraftsFromSessions(get().sessions);
|
||||
}
|
||||
},
|
||||
|
||||
setActiveSession: (sessionId) => {
|
||||
if (get().activeSessionId === sessionId) return;
|
||||
set({ activeSessionId: sessionId });
|
||||
|
|
@ -560,41 +409,24 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
});
|
||||
},
|
||||
|
||||
setSessionModels: (sessionId, models) => {
|
||||
switchSessionProvider: (sessionId, providerId) => {
|
||||
set((state) => ({
|
||||
modelsBySession: {
|
||||
...state.modelsBySession,
|
||||
[sessionId]: models,
|
||||
},
|
||||
}));
|
||||
},
|
||||
|
||||
switchSessionProvider: (sessionId, providerId, models) => {
|
||||
set((state) => ({
|
||||
modelsBySession: {
|
||||
...state.modelsBySession,
|
||||
[sessionId]: models,
|
||||
},
|
||||
sessions: state.sessions.map((session) =>
|
||||
session.id === sessionId
|
||||
? {
|
||||
...session,
|
||||
providerId,
|
||||
modelId: models.length > 0 ? models[0].id : undefined,
|
||||
modelName:
|
||||
models.length > 0
|
||||
? (models[0].displayName ?? models[0].name)
|
||||
: undefined,
|
||||
modelId: undefined,
|
||||
modelName: undefined,
|
||||
updatedAt: session.updatedAt,
|
||||
}
|
||||
: session,
|
||||
),
|
||||
}));
|
||||
persistDraftsFromSessions(get().sessions);
|
||||
const session = get().sessions.find(
|
||||
(candidate) => candidate.id === sessionId,
|
||||
);
|
||||
if (session && !session.draft) {
|
||||
if (session) {
|
||||
const existing = loadSessionMetadataOverlay().get(
|
||||
overlayKeyForSession(session),
|
||||
);
|
||||
|
|
@ -602,25 +434,6 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
}
|
||||
},
|
||||
|
||||
cacheModelsForProvider: (providerId, models) => {
|
||||
if (models.length === 0) return;
|
||||
const existing = get().modelCacheByProvider[providerId];
|
||||
if (modelIdsMatch(existing, models)) {
|
||||
return;
|
||||
}
|
||||
set((state) => {
|
||||
const updated = {
|
||||
...state.modelCacheByProvider,
|
||||
[providerId]: models,
|
||||
};
|
||||
persistModelCache(updated);
|
||||
return { modelCacheByProvider: updated };
|
||||
});
|
||||
},
|
||||
|
||||
getCachedModels: (providerId) =>
|
||||
get().modelCacheByProvider[providerId] ?? EMPTY_MODELS,
|
||||
|
||||
getSession: (id) => get().sessions.find((session) => session.id === id),
|
||||
|
||||
getActiveSession: () => {
|
||||
|
|
@ -631,7 +444,4 @@ export const useChatSessionStore = create<ChatSessionStore>((set, get) => ({
|
|||
|
||||
getArchivedSessions: () =>
|
||||
get().sessions.filter((session) => !!session.archivedAt),
|
||||
|
||||
getSessionModels: (sessionId) =>
|
||||
get().modelsBySession[sessionId] ?? EMPTY_MODELS,
|
||||
}));
|
||||
|
|
|
|||
|
|
@ -1,6 +1,62 @@
|
|||
import type { AcpProvider } from "@/shared/api/acp";
|
||||
import type { Persona } from "@/shared/types/agents";
|
||||
import type { ChatAttachmentDraft } from "@/shared/types/messages";
|
||||
|
||||
export interface ModelOption {
|
||||
id: string;
|
||||
name: string;
|
||||
displayName?: string;
|
||||
provider?: string;
|
||||
providerId?: string;
|
||||
providerName?: string;
|
||||
/** Whether this model should appear in the compact recommended picker. */
|
||||
recommended?: boolean;
|
||||
}
|
||||
|
||||
export interface ProjectOption {
|
||||
id: string;
|
||||
name: string;
|
||||
workingDirs: string[];
|
||||
color?: string | null;
|
||||
}
|
||||
|
||||
export interface ChatInputProps {
|
||||
onSend: (
|
||||
text: string,
|
||||
personaId?: string,
|
||||
attachments?: ChatAttachmentDraft[],
|
||||
) => void;
|
||||
onStop?: () => void;
|
||||
isStreaming?: boolean;
|
||||
disabled?: boolean;
|
||||
queuedMessage?: { text: string } | null;
|
||||
onDismissQueue?: () => void;
|
||||
initialValue?: string;
|
||||
onDraftChange?: (text: string) => void;
|
||||
className?: string;
|
||||
personas?: Persona[];
|
||||
selectedPersonaId?: string | null;
|
||||
onPersonaChange?: (personaId: string | null) => void;
|
||||
onCreatePersona?: () => void;
|
||||
providers?: AcpProvider[];
|
||||
providersLoading?: boolean;
|
||||
selectedProvider?: string;
|
||||
onProviderChange?: (providerId: string) => void;
|
||||
currentModelId?: string | null;
|
||||
currentModel?: string;
|
||||
availableModels?: ModelOption[];
|
||||
modelsLoading?: boolean;
|
||||
modelStatusMessage?: string | null;
|
||||
onModelChange?: (modelId: string) => void;
|
||||
selectedProjectId?: string | null;
|
||||
availableProjects?: ProjectOption[];
|
||||
onProjectChange?: (projectId: string | null) => void;
|
||||
onCreateProject?: (options?: {
|
||||
onCreated?: (projectId: string) => void;
|
||||
}) => void;
|
||||
contextTokens?: number;
|
||||
contextLimit?: number;
|
||||
onCompactContext?: () => void | Promise<void>;
|
||||
canCompactContext?: boolean;
|
||||
isCompactingContext?: boolean;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
import { useEffect, useMemo, useState, type ReactNode } from "react";
|
||||
import { useEffect, useMemo, useRef, useState, type ReactNode } from "react";
|
||||
import {
|
||||
IconCheck,
|
||||
IconChevronDown,
|
||||
IconChevronRight,
|
||||
IconChevronLeft,
|
||||
IconSearch,
|
||||
} from "@tabler/icons-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import type { AcpProvider } from "@/shared/api/acp";
|
||||
import { getProviderInventory } from "@/features/providers/api/inventory";
|
||||
import { useProviderInventoryStore } from "@/features/providers/stores/providerInventoryStore";
|
||||
import { cn } from "@/shared/lib/cn";
|
||||
import { Button } from "@/shared/ui/button";
|
||||
import { Popover, PopoverContent, PopoverTrigger } from "@/shared/ui/popover";
|
||||
import { ScrollArea } from "@/shared/ui/scroll-area";
|
||||
import { Spinner } from "@/shared/ui/spinner";
|
||||
import {
|
||||
formatProviderLabel,
|
||||
getProviderIcon,
|
||||
|
|
@ -23,66 +27,42 @@ interface AgentModelPickerProps {
|
|||
currentModelId?: string | null;
|
||||
currentModelName?: string | null;
|
||||
availableModels: ModelOption[];
|
||||
modelsLoading?: boolean;
|
||||
modelStatusMessage?: string | null;
|
||||
onModelChange?: (modelId: string) => void;
|
||||
loading?: boolean;
|
||||
isCompact?: boolean;
|
||||
}
|
||||
|
||||
interface ModelGroup {
|
||||
provider: string;
|
||||
models: ModelOption[];
|
||||
hasSelectedModel: boolean;
|
||||
}
|
||||
|
||||
const MODEL_PROVIDER_MATCHERS: Array<[string, RegExp]> = [
|
||||
["Anthropic", /claude|anthropic/i],
|
||||
["OpenAI", /(^|[\s-])(gpt|o1|o3|o4)([\s-]|$)|openai/i],
|
||||
["Google", /gemini|google/i],
|
||||
["Mistral", /mistral/i],
|
||||
["Meta", /llama|meta/i],
|
||||
["DeepSeek", /deepseek/i],
|
||||
["Qwen", /qwen/i],
|
||||
["Cohere", /cohere|command/i],
|
||||
];
|
||||
|
||||
function getModelDisplayName(model: ModelOption) {
|
||||
return model.displayName ?? model.name;
|
||||
}
|
||||
|
||||
function inferModelProvider(model: ModelOption) {
|
||||
if (model.provider) {
|
||||
return model.provider;
|
||||
function getGooseModelProviderLabel(model: ModelOption) {
|
||||
if (model.providerName) {
|
||||
return model.providerName;
|
||||
}
|
||||
|
||||
const candidate = `${model.id} ${model.name}`;
|
||||
for (const [provider, pattern] of MODEL_PROVIDER_MATCHERS) {
|
||||
if (pattern.test(candidate)) {
|
||||
return provider;
|
||||
}
|
||||
if (model.providerId) {
|
||||
return formatProviderLabel(model.providerId);
|
||||
}
|
||||
|
||||
return "Other";
|
||||
return null;
|
||||
}
|
||||
|
||||
function groupModels(models: ModelOption[], currentModelId: string | null) {
|
||||
const grouped = new Map<string, ModelOption[]>();
|
||||
function sortModels(models: ModelOption[], currentModelId: string | null) {
|
||||
return [...models].sort((left, right) => {
|
||||
if (left.id === currentModelId) return -1;
|
||||
if (right.id === currentModelId) return 1;
|
||||
|
||||
for (const model of models) {
|
||||
const provider = inferModelProvider(model);
|
||||
const existing = grouped.get(provider) ?? [];
|
||||
existing.push(model);
|
||||
grouped.set(provider, existing);
|
||||
}
|
||||
const leftProvider = getGooseModelProviderLabel(left) ?? "";
|
||||
const rightProvider = getGooseModelProviderLabel(right) ?? "";
|
||||
if (leftProvider !== rightProvider) {
|
||||
return leftProvider.localeCompare(rightProvider);
|
||||
}
|
||||
|
||||
const groups = Array.from(grouped.entries())
|
||||
.map(([provider, groupedModels]) => ({
|
||||
provider,
|
||||
models: groupedModels,
|
||||
hasSelectedModel: groupedModels.some((m) => m.id === currentModelId),
|
||||
}))
|
||||
.sort((left, right) => left.provider.localeCompare(right.provider));
|
||||
|
||||
return groups;
|
||||
return getModelDisplayName(left).localeCompare(getModelDisplayName(right));
|
||||
});
|
||||
}
|
||||
|
||||
function PickerItem({
|
||||
|
|
@ -116,6 +96,219 @@ function PickerItem({
|
|||
);
|
||||
}
|
||||
|
||||
// ── Model list views ────────────────────────────────────────────────
|
||||
|
||||
type ModelView = "recommended" | "all";
|
||||
|
||||
function RecommendedModelList({
|
||||
models,
|
||||
currentModelId,
|
||||
selectedAgentId,
|
||||
onModelSelect,
|
||||
onShowAll,
|
||||
t,
|
||||
}: {
|
||||
models: ModelOption[];
|
||||
currentModelId: string | null;
|
||||
selectedAgentId: string;
|
||||
onModelSelect: (id: string) => void;
|
||||
onShowAll: () => void;
|
||||
t: (key: string) => string;
|
||||
}) {
|
||||
const recommended = useMemo(() => {
|
||||
const rec = models.filter((m) => m.recommended);
|
||||
// If the current model isn't in the recommended list, prepend it
|
||||
// so the user can always see what's selected.
|
||||
if (
|
||||
currentModelId &&
|
||||
rec.length > 0 &&
|
||||
!rec.some((m) => m.id === currentModelId)
|
||||
) {
|
||||
const current = models.find((m) => m.id === currentModelId);
|
||||
if (current) {
|
||||
return [current, ...rec];
|
||||
}
|
||||
}
|
||||
// Fall back to full list if no recommendations exist (e.g. ACP agents).
|
||||
return rec.length > 0 ? rec : models;
|
||||
}, [models, currentModelId]);
|
||||
|
||||
const sorted = useMemo(
|
||||
() => sortModels(recommended, currentModelId),
|
||||
[recommended, currentModelId],
|
||||
);
|
||||
|
||||
const hasMore = models.length > recommended.length;
|
||||
|
||||
return (
|
||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
|
||||
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
|
||||
{t("toolbar.model")}
|
||||
</div>
|
||||
<ScrollArea className="min-h-0 min-w-0 flex-1">
|
||||
<div className="space-y-0.5 p-1">
|
||||
{sorted.map((model) => {
|
||||
const providerLabel = getGooseModelProviderLabel(model);
|
||||
return (
|
||||
<PickerItem
|
||||
key={`${model.providerId ?? "model"}:${model.id}`}
|
||||
onClick={() => onModelSelect(model.id)}
|
||||
selected={model.id === currentModelId}
|
||||
className="justify-between"
|
||||
>
|
||||
<div className="flex min-w-0 flex-1 items-center gap-2 overflow-hidden">
|
||||
{selectedAgentId === "goose" && model.providerId ? (
|
||||
<span
|
||||
className="shrink-0 text-muted-foreground"
|
||||
title={providerLabel ?? undefined}
|
||||
>
|
||||
{getProviderIcon(model.providerId, "size-3.5")}
|
||||
</span>
|
||||
) : null}
|
||||
<div className="min-w-0 flex-1 truncate">
|
||||
{getModelDisplayName(model)}
|
||||
</div>
|
||||
</div>
|
||||
{model.id === currentModelId ? (
|
||||
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
|
||||
) : null}
|
||||
</PickerItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
{hasMore ? (
|
||||
<div className="shrink-0 border-t px-1 py-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onShowAll}
|
||||
className="flex w-full items-center gap-1.5 rounded-sm px-2 py-1.5 text-xs text-muted-foreground transition-colors hover:bg-muted hover:text-foreground"
|
||||
>
|
||||
<IconSearch className="size-3.5" />
|
||||
<span>{t("toolbar.showAllModels")}</span>
|
||||
</button>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function AllModelsList({
|
||||
models,
|
||||
currentModelId,
|
||||
selectedAgentId,
|
||||
onModelSelect,
|
||||
onBack,
|
||||
t,
|
||||
}: {
|
||||
models: ModelOption[];
|
||||
currentModelId: string | null;
|
||||
selectedAgentId: string;
|
||||
onModelSelect: (id: string) => void;
|
||||
onBack: () => void;
|
||||
t: (key: string) => string;
|
||||
}) {
|
||||
const [query, setQuery] = useState("");
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
// Auto-focus search on mount.
|
||||
inputRef.current?.focus();
|
||||
}, []);
|
||||
|
||||
const filtered = useMemo(() => {
|
||||
if (!query.trim()) {
|
||||
return sortModels(models, currentModelId);
|
||||
}
|
||||
const q = query.toLowerCase();
|
||||
const matches = models.filter(
|
||||
(m) =>
|
||||
m.name.toLowerCase().includes(q) ||
|
||||
m.id.toLowerCase().includes(q) ||
|
||||
m.displayName?.toLowerCase().includes(q) ||
|
||||
m.providerName?.toLowerCase().includes(q) ||
|
||||
m.providerId?.toLowerCase().includes(q),
|
||||
);
|
||||
return sortModels(matches, currentModelId);
|
||||
}, [models, query, currentModelId]);
|
||||
|
||||
return (
|
||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
|
||||
<div className="flex shrink-0 items-center gap-1 px-1 py-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onBack}
|
||||
className="flex shrink-0 items-center rounded-sm p-1 text-muted-foreground transition-colors hover:bg-muted hover:text-foreground"
|
||||
aria-label={t("toolbar.model")}
|
||||
>
|
||||
<IconChevronLeft className="size-4" />
|
||||
</button>
|
||||
<div className="relative min-w-0 flex-1">
|
||||
<IconSearch className="pointer-events-none absolute left-2 top-1/2 size-3.5 -translate-y-1/2 text-muted-foreground" />
|
||||
<input
|
||||
ref={inputRef}
|
||||
type="text"
|
||||
value={query}
|
||||
onChange={(e) => setQuery(e.target.value)}
|
||||
placeholder={t("toolbar.searchModels")}
|
||||
className="h-7 w-full rounded-sm border bg-transparent pl-7 pr-2 text-sm outline-none placeholder:text-muted-foreground focus:ring-1 focus:ring-ring"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{filtered.length > 0 ? (
|
||||
<ScrollArea className="min-h-0 min-w-0 flex-1">
|
||||
<div className="space-y-0.5 p-1">
|
||||
{filtered.map((model) => {
|
||||
const providerLabel = getGooseModelProviderLabel(model);
|
||||
const displayName = getModelDisplayName(model);
|
||||
// Show the raw model_id as secondary text when it differs from name
|
||||
const showModelId =
|
||||
model.id !== model.name && model.id !== displayName;
|
||||
|
||||
return (
|
||||
<PickerItem
|
||||
key={`${model.providerId ?? "model"}:${model.id}`}
|
||||
onClick={() => onModelSelect(model.id)}
|
||||
selected={model.id === currentModelId}
|
||||
className="justify-between"
|
||||
>
|
||||
<div className="flex min-w-0 flex-1 items-center gap-2 overflow-hidden">
|
||||
{selectedAgentId === "goose" && model.providerId ? (
|
||||
<span
|
||||
className="shrink-0 text-muted-foreground"
|
||||
title={providerLabel ?? undefined}
|
||||
>
|
||||
{getProviderIcon(model.providerId, "size-3.5")}
|
||||
</span>
|
||||
) : null}
|
||||
<div className="min-w-0 flex-1 overflow-hidden">
|
||||
<div className="truncate">{displayName}</div>
|
||||
{showModelId ? (
|
||||
<div className="truncate text-xs text-muted-foreground">
|
||||
{model.id}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
{model.id === currentModelId ? (
|
||||
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
|
||||
) : null}
|
||||
</PickerItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
) : (
|
||||
<div className="px-3 py-4 text-center text-sm text-muted-foreground">
|
||||
{t("toolbar.noSearchResults")}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ── Main component ──────────────────────────────────────────────────
|
||||
|
||||
export function AgentModelPicker({
|
||||
agents,
|
||||
selectedAgentId,
|
||||
|
|
@ -123,54 +316,31 @@ export function AgentModelPicker({
|
|||
currentModelId = null,
|
||||
currentModelName = null,
|
||||
availableModels,
|
||||
modelsLoading = false,
|
||||
modelStatusMessage = null,
|
||||
onModelChange,
|
||||
loading = false,
|
||||
isCompact = false,
|
||||
}: AgentModelPickerProps) {
|
||||
const { t } = useTranslation("chat");
|
||||
const [open, setOpen] = useState(false);
|
||||
const [expandedGroups, setExpandedGroups] = useState<Set<string>>(new Set());
|
||||
const [modelView, setModelView] = useState<ModelView>("recommended");
|
||||
const mergeInventoryEntries = useProviderInventoryStore(
|
||||
(s) => s.mergeEntries,
|
||||
);
|
||||
|
||||
const selectedAgentLabel =
|
||||
agents.find((agent) => agent.id === selectedAgentId)?.label ??
|
||||
formatProviderLabel(selectedAgentId);
|
||||
const groupedModels = useMemo<ModelGroup[]>(
|
||||
() => groupModels(availableModels, currentModelId),
|
||||
[availableModels, currentModelId],
|
||||
);
|
||||
const hasModelInfo = currentModelName !== null || availableModels.length > 0;
|
||||
const triggerModelLabel = hasModelInfo
|
||||
? (currentModelName ?? t("toolbar.loading"))
|
||||
const hasSelectedModel = currentModelName !== null || currentModelId !== null;
|
||||
const triggerModelLabel = hasSelectedModel
|
||||
? (currentModelName ?? currentModelId)
|
||||
: null;
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
const selected = groupedModels.find((g) => g.hasSelectedModel);
|
||||
if (selected) {
|
||||
setExpandedGroups(new Set([selected.provider]));
|
||||
}
|
||||
}
|
||||
}, [open, groupedModels]);
|
||||
|
||||
const toggleGroup = (provider: string) => {
|
||||
setExpandedGroups((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(provider)) {
|
||||
next.delete(provider);
|
||||
} else {
|
||||
next.add(provider);
|
||||
}
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const isGroupExpanded = (group: ModelGroup) => {
|
||||
return expandedGroups.has(group.provider);
|
||||
};
|
||||
|
||||
const handleAgentSelect = (agentId: string) => {
|
||||
if (agentId !== selectedAgentId) {
|
||||
onAgentChange(agentId);
|
||||
setModelView("recommended");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -179,6 +349,42 @@ export function AgentModelPicker({
|
|||
setOpen(false);
|
||||
};
|
||||
|
||||
// Reset to recommended view when popover closes.
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
setModelView("recommended");
|
||||
}
|
||||
}, [open]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
|
||||
const syncInventory = async () => {
|
||||
try {
|
||||
const entries = await getProviderInventory();
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
mergeInventoryEntries(entries);
|
||||
} catch (error) {
|
||||
console.error("Failed to sync provider inventory from picker:", error);
|
||||
}
|
||||
};
|
||||
|
||||
void syncInventory();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [open, mergeInventoryEntries]);
|
||||
|
||||
// When in "all" view, expand the popover to full width for the search experience.
|
||||
const isAllView = modelView === "all";
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
|
|
@ -251,133 +457,97 @@ export function AgentModelPicker({
|
|||
}
|
||||
}}
|
||||
>
|
||||
<div className="grid h-full grid-cols-[minmax(0,1fr)_minmax(0,1fr)] gap-1 overflow-hidden">
|
||||
<div
|
||||
data-col="agent"
|
||||
className="flex min-h-0 min-w-0 overflow-hidden p-1"
|
||||
>
|
||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
|
||||
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
|
||||
{t("toolbar.agent")}
|
||||
</div>
|
||||
<ScrollArea className="min-h-0 min-w-0 flex-1">
|
||||
<div className="p-1 space-y-0.5">
|
||||
{agents.map((agent) => {
|
||||
const isSelected = agent.id === selectedAgentId;
|
||||
|
||||
return (
|
||||
<PickerItem
|
||||
key={agent.id}
|
||||
onClick={() => handleAgentSelect(agent.id)}
|
||||
selected={isSelected}
|
||||
>
|
||||
<span className="shrink-0">
|
||||
{getProviderIcon(agent.id, "size-4")}
|
||||
</span>
|
||||
<span className="min-w-0 flex-1 truncate">
|
||||
{agent.label}
|
||||
</span>
|
||||
{isSelected ? (
|
||||
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
|
||||
) : null}
|
||||
</PickerItem>
|
||||
);
|
||||
})}
|
||||
<div
|
||||
className={cn(
|
||||
"grid h-full gap-1 overflow-hidden",
|
||||
isAllView
|
||||
? "grid-cols-1"
|
||||
: "grid-cols-[minmax(0,1fr)_minmax(0,1fr)]",
|
||||
)}
|
||||
>
|
||||
{/* Agent column — hidden in "all models" search view */}
|
||||
{!isAllView ? (
|
||||
<div
|
||||
data-col="agent"
|
||||
className="flex min-h-0 min-w-0 overflow-hidden p-1"
|
||||
>
|
||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
|
||||
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
|
||||
{t("toolbar.agent")}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
data-col="model"
|
||||
className="flex min-h-0 min-w-0 overflow-hidden p-1"
|
||||
>
|
||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
|
||||
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
|
||||
{t("toolbar.model")}
|
||||
</div>
|
||||
{groupedModels.length > 0 ? (
|
||||
<ScrollArea className="min-h-0 min-w-0 flex-1">
|
||||
<div className="p-1 space-y-0.5">
|
||||
{groupedModels.map((group) => {
|
||||
const expanded = isGroupExpanded(group);
|
||||
<div className="space-y-0.5 p-1">
|
||||
{agents.map((agent) => {
|
||||
const isSelected = agent.id === selectedAgentId;
|
||||
|
||||
return (
|
||||
<div key={group.provider}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleGroup(group.provider)}
|
||||
className={cn(
|
||||
"flex min-w-0 w-full items-center gap-1.5 rounded-sm px-2 py-1.5 text-left text-sm font-medium transition-colors",
|
||||
"hover:bg-muted focus:bg-muted focus:outline-none",
|
||||
)}
|
||||
>
|
||||
<IconChevronRight
|
||||
className={cn(
|
||||
"size-3.5 shrink-0 text-muted-foreground transition-transform",
|
||||
expanded && "rotate-90",
|
||||
)}
|
||||
/>
|
||||
<span className="min-w-0 flex-1 truncate">
|
||||
{group.provider}
|
||||
</span>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{group.models.length}
|
||||
</span>
|
||||
</button>
|
||||
{expanded ? (
|
||||
<div className="overflow-hidden pb-1">
|
||||
{group.models.map((model) => {
|
||||
const modelName = getModelDisplayName(model);
|
||||
|
||||
return (
|
||||
<PickerItem
|
||||
key={model.id}
|
||||
onClick={() => handleModelSelect(model.id)}
|
||||
selected={model.id === currentModelId}
|
||||
className="justify-between pl-6"
|
||||
>
|
||||
<div className="min-w-0 flex-1 truncate">
|
||||
{modelName}
|
||||
</div>
|
||||
{model.id === currentModelId ? (
|
||||
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
|
||||
) : null}
|
||||
</PickerItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : group.hasSelectedModel ? (
|
||||
<div className="overflow-hidden pb-1">
|
||||
{group.models
|
||||
.filter((m) => m.id === currentModelId)
|
||||
.map((model) => (
|
||||
<PickerItem
|
||||
key={model.id}
|
||||
onClick={() => handleModelSelect(model.id)}
|
||||
selected
|
||||
className="justify-between pl-6"
|
||||
>
|
||||
<div className="min-w-0 flex-1 truncate">
|
||||
{getModelDisplayName(model)}
|
||||
</div>
|
||||
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
|
||||
</PickerItem>
|
||||
))}
|
||||
</div>
|
||||
<PickerItem
|
||||
key={agent.id}
|
||||
onClick={() => handleAgentSelect(agent.id)}
|
||||
selected={isSelected}
|
||||
>
|
||||
<span className="shrink-0">
|
||||
{getProviderIcon(agent.id, "size-4")}
|
||||
</span>
|
||||
<span className="min-w-0 flex-1 truncate">
|
||||
{agent.label}
|
||||
</span>
|
||||
{isSelected ? (
|
||||
<IconCheck className="size-4 shrink-0 text-muted-foreground" />
|
||||
) : null}
|
||||
</div>
|
||||
</PickerItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{/* Model column */}
|
||||
<div
|
||||
data-col="model"
|
||||
className="flex min-h-0 min-w-0 overflow-hidden p-1"
|
||||
>
|
||||
{modelsLoading ? (
|
||||
<div className="flex min-h-0 flex-1 items-center gap-2 px-2 py-2 text-sm text-muted-foreground">
|
||||
<Spinner className="size-4" />
|
||||
<span>{t("toolbar.loadingModels")}</span>
|
||||
</div>
|
||||
) : availableModels.length > 0 ? (
|
||||
modelView === "recommended" ? (
|
||||
<RecommendedModelList
|
||||
models={availableModels}
|
||||
currentModelId={currentModelId}
|
||||
selectedAgentId={selectedAgentId}
|
||||
onModelSelect={handleModelSelect}
|
||||
onShowAll={() => setModelView("all")}
|
||||
t={t}
|
||||
/>
|
||||
) : (
|
||||
<AllModelsList
|
||||
models={availableModels}
|
||||
currentModelId={currentModelId}
|
||||
selectedAgentId={selectedAgentId}
|
||||
onModelSelect={handleModelSelect}
|
||||
onBack={() => setModelView("recommended")}
|
||||
t={t}
|
||||
/>
|
||||
)
|
||||
) : (
|
||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col">
|
||||
<div className="shrink-0 px-2 py-1.5 text-sm font-semibold">
|
||||
{t("toolbar.model")}
|
||||
</div>
|
||||
<div className="px-2 py-2">
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{currentModelName ? currentModelName : t("toolbar.loading")}
|
||||
{modelStatusMessage ??
|
||||
currentModelName ??
|
||||
t("toolbar.noModelsAvailable")}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</PopoverContent>
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@ import { useState, useRef, useCallback, useEffect, useMemo } from "react";
|
|||
import { open } from "@tauri-apps/plugin-dialog";
|
||||
import { X } from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import type { AcpProvider } from "@/shared/api/acp";
|
||||
import type { Persona } from "@/shared/types/agents";
|
||||
import { cn } from "@/shared/lib/cn";
|
||||
import { Badge } from "@/shared/ui/badge";
|
||||
import { Button } from "@/shared/ui/button";
|
||||
|
|
@ -14,61 +12,14 @@ import { ChatInputToolbar } from "./ChatInputToolbar";
|
|||
import { formatProviderLabel } from "@/shared/ui/icons/ProviderIcons";
|
||||
import { TooltipProvider } from "@/shared/ui/tooltip";
|
||||
import { PersonaAvatar } from "./PersonaPicker";
|
||||
import type { ChatAttachmentDraft } from "@/shared/types/messages";
|
||||
import { useAttachmentDropTarget } from "../hooks/useAttachmentDropTarget";
|
||||
import {
|
||||
normalizeDialogSelection,
|
||||
useChatInputAttachments,
|
||||
} from "../hooks/useChatInputAttachments";
|
||||
import type { ModelOption } from "../types";
|
||||
import { ChatInputAttachments } from "./ChatInputAttachments";
|
||||
import { useVoiceDictation } from "../hooks/useVoiceDictation";
|
||||
|
||||
export interface ProjectOption {
|
||||
id: string;
|
||||
name: string;
|
||||
workingDirs: string[];
|
||||
color?: string | null;
|
||||
}
|
||||
|
||||
interface ChatInputProps {
|
||||
onSend: (
|
||||
text: string,
|
||||
personaId?: string,
|
||||
attachments?: ChatAttachmentDraft[],
|
||||
) => void;
|
||||
onStop?: () => void;
|
||||
isStreaming?: boolean;
|
||||
disabled?: boolean;
|
||||
queuedMessage?: { text: string } | null;
|
||||
onDismissQueue?: () => void;
|
||||
initialValue?: string;
|
||||
onDraftChange?: (text: string) => void;
|
||||
className?: string;
|
||||
personas?: Persona[];
|
||||
selectedPersonaId?: string | null;
|
||||
onPersonaChange?: (personaId: string | null) => void;
|
||||
onCreatePersona?: () => void;
|
||||
providers?: AcpProvider[];
|
||||
providersLoading?: boolean;
|
||||
selectedProvider?: string;
|
||||
onProviderChange?: (providerId: string) => void;
|
||||
currentModelId?: string | null;
|
||||
currentModel?: string;
|
||||
availableModels?: ModelOption[];
|
||||
onModelChange?: (modelId: string) => void;
|
||||
selectedProjectId?: string | null;
|
||||
availableProjects?: ProjectOption[];
|
||||
onProjectChange?: (projectId: string | null) => void;
|
||||
onCreateProject?: (options?: {
|
||||
onCreated?: (projectId: string) => void;
|
||||
}) => void;
|
||||
contextTokens?: number;
|
||||
contextLimit?: number;
|
||||
onCompactContext?: () => void | Promise<void>;
|
||||
canCompactContext?: boolean;
|
||||
isCompactingContext?: boolean;
|
||||
}
|
||||
import type { ChatInputProps } from "../types";
|
||||
|
||||
export function ChatInput({
|
||||
onSend,
|
||||
|
|
@ -91,6 +42,8 @@ export function ChatInput({
|
|||
currentModelId = null,
|
||||
currentModel,
|
||||
availableModels = [],
|
||||
modelsLoading = false,
|
||||
modelStatusMessage = null,
|
||||
onModelChange,
|
||||
selectedProjectId = null,
|
||||
availableProjects = [],
|
||||
|
|
@ -104,6 +57,9 @@ export function ChatInput({
|
|||
}: ChatInputProps) {
|
||||
const { t } = useTranslation("chat");
|
||||
const [text, setTextRaw] = useState(initialValue);
|
||||
useEffect(() => {
|
||||
setTextRaw(initialValue);
|
||||
}, [initialValue]);
|
||||
const setText = useCallback(
|
||||
(value: string) => {
|
||||
setTextRaw(value);
|
||||
|
|
@ -351,8 +307,18 @@ export function ChatInput({
|
|||
providers.find((provider) => provider.id === selectedProvider)?.label ??
|
||||
formatProviderLabel(selectedProvider);
|
||||
const agentDisplayName = activePersona?.displayName ?? providerDisplayName;
|
||||
const resolvedCurrentModel =
|
||||
currentModel ?? availableModels[0]?.displayName ?? availableModels[0]?.name;
|
||||
const resolvedCurrentModel = useMemo(() => {
|
||||
if (currentModel) {
|
||||
return currentModel;
|
||||
}
|
||||
if (!currentModelId) {
|
||||
return undefined;
|
||||
}
|
||||
const selectedModel = availableModels.find(
|
||||
(model) => model.id === currentModelId,
|
||||
);
|
||||
return selectedModel?.displayName ?? selectedModel?.name ?? currentModelId;
|
||||
}, [availableModels, currentModel, currentModelId]);
|
||||
const effectivePlaceholder = t("input.placeholder", {
|
||||
agent: agentDisplayName,
|
||||
});
|
||||
|
|
@ -472,6 +438,8 @@ export function ChatInput({
|
|||
currentModelId={currentModelId}
|
||||
currentModel={resolvedCurrentModel}
|
||||
availableModels={availableModels}
|
||||
modelsLoading={modelsLoading}
|
||||
modelStatusMessage={modelStatusMessage}
|
||||
onModelChange={onModelChange}
|
||||
selectedProjectId={selectedProjectId}
|
||||
availableProjects={availableProjects}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import { cn } from "@/shared/lib/cn";
|
|||
import { ChatInputSelector } from "./ChatInputSelector";
|
||||
import { ContextRing } from "./ContextRing";
|
||||
import { PersonaPicker } from "./PersonaPicker";
|
||||
import type { ProjectOption } from "./ChatInput";
|
||||
import type { ProjectOption } from "../types";
|
||||
import { Button } from "@/shared/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
|
|
@ -30,11 +30,7 @@ import { Tooltip, TooltipTrigger, TooltipContent } from "@/shared/ui/tooltip";
|
|||
import { AgentModelPicker } from "./AgentModelPicker";
|
||||
import type { ModelOption } from "../types";
|
||||
import { formatProviderLabel } from "@/shared/ui/icons/ProviderIcons";
|
||||
import { useAgentProviderStatus } from "@/features/providers/hooks/useAgentProviderStatus";
|
||||
import {
|
||||
getCatalogEntry,
|
||||
resolveAgentProviderCatalogIdStrict,
|
||||
} from "@/features/providers/providerCatalog";
|
||||
import { getCatalogEntry } from "@/features/providers/providerCatalog";
|
||||
|
||||
const NO_PROJECT_VALUE = "__no_project__";
|
||||
const CREATE_PROJECT_VALUE = "__create_project__";
|
||||
|
|
@ -67,6 +63,8 @@ interface ChatInputToolbarProps {
|
|||
currentModelId?: string | null;
|
||||
currentModel?: string;
|
||||
availableModels: ModelOption[];
|
||||
modelsLoading?: boolean;
|
||||
modelStatusMessage?: string | null;
|
||||
onModelChange?: (modelId: string) => void;
|
||||
// Project
|
||||
selectedProjectId: string | null;
|
||||
|
|
@ -111,6 +109,8 @@ export function ChatInputToolbar({
|
|||
currentModelId,
|
||||
currentModel,
|
||||
availableModels,
|
||||
modelsLoading = false,
|
||||
modelStatusMessage = null,
|
||||
onModelChange,
|
||||
selectedProjectId,
|
||||
availableProjects,
|
||||
|
|
@ -137,27 +137,22 @@ export function ChatInputToolbar({
|
|||
}: ChatInputToolbarProps) {
|
||||
const { t } = useTranslation("chat");
|
||||
const { formatNumber } = useLocaleFormatting();
|
||||
const { readyAgentIds } = useAgentProviderStatus();
|
||||
const [isContextPopoverOpen, setIsContextPopoverOpen] = useState(false);
|
||||
|
||||
const agentProviders = useMemo(() => {
|
||||
const seen = new Set<string>();
|
||||
const connected: AcpProvider[] = [];
|
||||
for (const p of providers) {
|
||||
const catalogId = resolveAgentProviderCatalogIdStrict(p.id);
|
||||
if (
|
||||
catalogId === null ||
|
||||
!readyAgentIds.has(catalogId) ||
|
||||
seen.has(catalogId)
|
||||
)
|
||||
const available: AcpProvider[] = [];
|
||||
for (const provider of providers) {
|
||||
if (seen.has(provider.id)) {
|
||||
continue;
|
||||
seen.add(catalogId);
|
||||
connected.push({
|
||||
id: p.id,
|
||||
label: getCatalogEntry(catalogId)?.displayName ?? p.label,
|
||||
}
|
||||
seen.add(provider.id);
|
||||
available.push({
|
||||
id: provider.id,
|
||||
label: getCatalogEntry(provider.id)?.displayName ?? provider.label,
|
||||
});
|
||||
}
|
||||
if (connected.length > 0) return connected;
|
||||
if (available.length > 0) return available;
|
||||
return [
|
||||
{
|
||||
id: selectedProvider,
|
||||
|
|
@ -166,7 +161,7 @@ export function ChatInputToolbar({
|
|||
formatProviderLabel(selectedProvider),
|
||||
},
|
||||
];
|
||||
}, [providers, readyAgentIds, selectedProvider]);
|
||||
}, [providers, selectedProvider]);
|
||||
const selectedProject = availableProjects.find(
|
||||
(project) => project.id === selectedProjectId,
|
||||
);
|
||||
|
|
@ -220,6 +215,8 @@ export function ChatInputToolbar({
|
|||
currentModelId={currentModelId}
|
||||
currentModelName={currentModel ?? null}
|
||||
availableModels={availableModels}
|
||||
modelsLoading={modelsLoading}
|
||||
modelStatusMessage={modelStatusMessage}
|
||||
onModelChange={onModelChange}
|
||||
loading={providersLoading}
|
||||
isCompact={isCompact}
|
||||
|
|
|
|||
|
|
@ -1,507 +1,97 @@
|
|||
import { useState, useEffect, useRef, useCallback, useMemo } from "react";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { AnimatePresence } from "motion/react";
|
||||
import { MessageTimeline } from "./MessageTimeline";
|
||||
import { ChatInput } from "./ChatInput";
|
||||
import type { ChatAttachmentDraft } from "@/shared/types/messages";
|
||||
import { LoadingGoose } from "./LoadingGoose";
|
||||
import { ChatLoadingSkeleton } from "./ChatLoadingSkeleton";
|
||||
import { useChat } from "../hooks/useChat";
|
||||
import { useMessageQueue } from "../hooks/useMessageQueue";
|
||||
import { useChatStore } from "../stores/chatStore";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection";
|
||||
import { useChatSessionStore } from "../stores/chatSessionStore";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
import { acpPrepareSession, acpSetModel } from "@/shared/api/acp";
|
||||
import {
|
||||
buildProjectSystemPrompt,
|
||||
composeSystemPrompt,
|
||||
defaultGlobalArtifactRoot,
|
||||
getProjectArtifactRoots,
|
||||
resolveProjectDefaultArtifactRoot,
|
||||
} from "@/features/projects/lib/chatProjectContext";
|
||||
import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection";
|
||||
import { defaultGlobalArtifactRoot } from "@/features/projects/lib/chatProjectContext";
|
||||
import { ArtifactPolicyProvider } from "../hooks/ArtifactPolicyContext";
|
||||
import type { ModelOption } from "../types";
|
||||
import { ChatContextPanel } from "./ChatContextPanel";
|
||||
import { perfLog } from "@/shared/lib/perfLog";
|
||||
|
||||
const EMPTY_MODELS: ModelOption[] = [];
|
||||
import { useChatSessionController } from "../hooks/useChatSessionController";
|
||||
|
||||
interface ChatViewProps {
|
||||
sessionId: string;
|
||||
initialProvider?: string;
|
||||
initialPersonaId?: string;
|
||||
initialMessage?: string;
|
||||
initialAttachments?: ChatAttachmentDraft[];
|
||||
onInitialMessageConsumed?: () => void;
|
||||
onCreateProject?: (options?: {
|
||||
onCreated?: (projectId: string) => void;
|
||||
}) => void;
|
||||
}
|
||||
|
||||
export function ChatView({
|
||||
sessionId,
|
||||
initialProvider,
|
||||
initialPersonaId,
|
||||
initialMessage,
|
||||
initialAttachments,
|
||||
onInitialMessageConsumed,
|
||||
onCreateProject,
|
||||
}: ChatViewProps) {
|
||||
export function ChatView({ sessionId, onCreateProject }: ChatViewProps) {
|
||||
const { t } = useTranslation("chat");
|
||||
const activeSessionId = sessionId;
|
||||
const mountStart = useRef(performance.now());
|
||||
const isContextPanelOpen = useChatSessionStore(
|
||||
(s) => s.contextPanelOpenBySession[sessionId] ?? false,
|
||||
);
|
||||
const setContextPanelOpen = useChatSessionStore((s) => s.setContextPanelOpen);
|
||||
const [globalArtifactRoot, setGlobalArtifactRoot] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
const controller = useChatSessionController({ sessionId });
|
||||
const contextPanelLabel = isContextPanelOpen
|
||||
? t("context.closePanel")
|
||||
: t("context.openPanel");
|
||||
const allowedArtifactRoots = [
|
||||
...controller.allowedArtifactRoots,
|
||||
...(globalArtifactRoot ? [globalArtifactRoot] : []),
|
||||
];
|
||||
|
||||
useEffect(() => {
|
||||
const ms = (performance.now() - mountStart.current).toFixed(1);
|
||||
perfLog(`[perf:chatview] ${sessionId.slice(0, 8)} mounted in ${ms}ms`);
|
||||
}, [sessionId]);
|
||||
const isContextPanelOpen = useChatSessionStore(
|
||||
(s) => s.contextPanelOpenBySession[activeSessionId] ?? false,
|
||||
);
|
||||
const setContextPanelOpen = useChatSessionStore((s) => s.setContextPanelOpen);
|
||||
const activeWorkspace = useChatSessionStore(
|
||||
(s) => s.activeWorkspaceBySession[activeSessionId],
|
||||
);
|
||||
const clearActiveWorkspace = useChatSessionStore(
|
||||
(s) => s.clearActiveWorkspace,
|
||||
);
|
||||
|
||||
const {
|
||||
providers,
|
||||
providersLoading,
|
||||
selectedProvider: globalSelectedProvider,
|
||||
setSelectedProvider: setGlobalSelectedProvider,
|
||||
} = useProviderSelection();
|
||||
const personas = useAgentStore((s) => s.personas);
|
||||
const [selectedPersonaId, setSelectedPersonaId] = useState<string | null>(
|
||||
initialPersonaId ?? null,
|
||||
);
|
||||
const session = useChatSessionStore((s) =>
|
||||
s.sessions.find((candidate) => candidate.id === activeSessionId),
|
||||
);
|
||||
const availableModels = useChatSessionStore(
|
||||
(s) => s.modelsBySession[activeSessionId] ?? EMPTY_MODELS,
|
||||
);
|
||||
const projects = useProjectStore((s) => s.projects);
|
||||
const projectsLoading = useProjectStore((s) => s.loading);
|
||||
const storedProject = useProjectStore((s) =>
|
||||
session?.projectId
|
||||
? s.projects.find((candidate) => candidate.id === session.projectId)
|
||||
: undefined,
|
||||
);
|
||||
const [globalArtifactRoot, setGlobalArtifactRoot] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
const project = storedProject ?? null;
|
||||
const contextPanelLabel = isContextPanelOpen
|
||||
? t("context.closePanel")
|
||||
: t("context.openPanel");
|
||||
const availableProjects = useMemo(
|
||||
() =>
|
||||
[...projects]
|
||||
.sort((a, b) => a.order - b.order || a.name.localeCompare(b.name))
|
||||
.map((projectInfo) => ({
|
||||
id: projectInfo.id,
|
||||
name: projectInfo.name,
|
||||
workingDirs: projectInfo.workingDirs,
|
||||
color: projectInfo.color,
|
||||
})),
|
||||
[projects],
|
||||
);
|
||||
const selectedProvider =
|
||||
session?.providerId ??
|
||||
initialProvider ??
|
||||
project?.preferredProvider ??
|
||||
globalSelectedProvider;
|
||||
|
||||
const selectedPersona = personas.find((p) => p.id === selectedPersonaId);
|
||||
const projectArtifactRoots = useMemo(
|
||||
() => getProjectArtifactRoots(project),
|
||||
[project],
|
||||
);
|
||||
const projectDefaultArtifactRoot = useMemo(
|
||||
() => resolveProjectDefaultArtifactRoot(project),
|
||||
[project],
|
||||
);
|
||||
const projectMetadataPending = Boolean(
|
||||
session?.projectId && !projectDefaultArtifactRoot && projectsLoading,
|
||||
);
|
||||
const allowedArtifactRoots = useMemo(() => {
|
||||
const roots = [
|
||||
...projectArtifactRoots.map((path) => path.trim()).filter(Boolean),
|
||||
];
|
||||
if (globalArtifactRoot) {
|
||||
roots.push(globalArtifactRoot);
|
||||
}
|
||||
return [...new Set(roots)];
|
||||
}, [globalArtifactRoot, projectArtifactRoots]);
|
||||
const projectSystemPrompt = useMemo(
|
||||
() => buildProjectSystemPrompt(project),
|
||||
[project],
|
||||
);
|
||||
const workingContextPrompt = useMemo(() => {
|
||||
if (!activeWorkspace?.branch) return undefined;
|
||||
return `<active-working-context>\nActive branch: ${activeWorkspace.branch}\nWorking directory: ${activeWorkspace.path}\n</active-working-context>`;
|
||||
}, [activeWorkspace?.branch, activeWorkspace?.path]);
|
||||
|
||||
const effectiveSystemPrompt = useMemo(
|
||||
() =>
|
||||
composeSystemPrompt(
|
||||
selectedPersona?.systemPrompt,
|
||||
projectSystemPrompt,
|
||||
workingContextPrompt,
|
||||
),
|
||||
[selectedPersona?.systemPrompt, projectSystemPrompt, workingContextPrompt],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
defaultGlobalArtifactRoot()
|
||||
.then((artifactRoot) => {
|
||||
if (cancelled) return;
|
||||
setGlobalArtifactRoot(artifactRoot);
|
||||
if (!cancelled) {
|
||||
setGlobalArtifactRoot(artifactRoot);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
if (cancelled) return;
|
||||
setGlobalArtifactRoot(null);
|
||||
if (!cancelled) {
|
||||
setGlobalArtifactRoot(null);
|
||||
}
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, []);
|
||||
|
||||
const prevProjectIdRef = useRef(session?.projectId);
|
||||
useEffect(() => {
|
||||
const prevProjectId = prevProjectIdRef.current;
|
||||
prevProjectIdRef.current = session?.projectId;
|
||||
if (prevProjectId !== undefined && prevProjectId !== session?.projectId) {
|
||||
clearActiveWorkspace(activeSessionId);
|
||||
}
|
||||
}, [session?.projectId, activeSessionId, clearActiveWorkspace]);
|
||||
|
||||
const prevWorkspaceRef = useRef(activeWorkspace);
|
||||
useEffect(() => {
|
||||
const prev = prevWorkspaceRef.current;
|
||||
if (
|
||||
!activeWorkspace ||
|
||||
!selectedProvider ||
|
||||
session?.draft ||
|
||||
activeWorkspace === prev
|
||||
) {
|
||||
return;
|
||||
}
|
||||
prevWorkspaceRef.current = activeWorkspace;
|
||||
if (prev && prev.path === activeWorkspace.path) return;
|
||||
|
||||
async function prepareWorkspaceSession() {
|
||||
const workingDir = await resolveSessionCwd(project, activeWorkspace.path);
|
||||
if (!workingDir) {
|
||||
return;
|
||||
}
|
||||
await acpPrepareSession(activeSessionId, selectedProvider, workingDir, {
|
||||
personaId: selectedPersonaId ?? undefined,
|
||||
});
|
||||
}
|
||||
|
||||
void prepareWorkspaceSession().catch((error) => {
|
||||
console.error("Failed to prepare ACP session:", error);
|
||||
});
|
||||
}, [
|
||||
activeWorkspace,
|
||||
activeSessionId,
|
||||
project,
|
||||
selectedProvider,
|
||||
selectedPersonaId,
|
||||
session?.draft,
|
||||
]);
|
||||
|
||||
const handleProviderChange = useCallback(
|
||||
(providerId: string) => {
|
||||
if (providerId === selectedProvider) {
|
||||
return;
|
||||
}
|
||||
const sessionStore = useChatSessionStore.getState();
|
||||
const cached = sessionStore.getCachedModels(providerId);
|
||||
sessionStore.switchSessionProvider(activeSessionId, providerId, cached);
|
||||
setGlobalSelectedProvider(providerId);
|
||||
},
|
||||
[activeSessionId, selectedProvider, setGlobalSelectedProvider],
|
||||
);
|
||||
|
||||
const handleProjectChange = useCallback(
|
||||
(projectId: string | null) => {
|
||||
const nextProject =
|
||||
projectId == null
|
||||
? null
|
||||
: (useProjectStore
|
||||
.getState()
|
||||
.projects.find((candidate) => candidate.id === projectId) ??
|
||||
null);
|
||||
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.updateSession(activeSessionId, { projectId });
|
||||
|
||||
if (!session?.draft && selectedProvider) {
|
||||
async function updateProjectSessionCwd() {
|
||||
const workingDir = await resolveSessionCwd(
|
||||
nextProject,
|
||||
activeWorkspace?.path,
|
||||
);
|
||||
if (!workingDir) {
|
||||
return;
|
||||
}
|
||||
|
||||
await acpPrepareSession(
|
||||
activeSessionId,
|
||||
selectedProvider,
|
||||
workingDir,
|
||||
{
|
||||
personaId: selectedPersonaId ?? undefined,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
void updateProjectSessionCwd().catch((error) => {
|
||||
console.error(
|
||||
"Failed to update ACP session working directory:",
|
||||
error,
|
||||
);
|
||||
});
|
||||
}
|
||||
},
|
||||
[
|
||||
activeSessionId,
|
||||
activeWorkspace?.path,
|
||||
selectedPersonaId,
|
||||
selectedProvider,
|
||||
session?.draft,
|
||||
],
|
||||
);
|
||||
const handleModelChange = useCallback(
|
||||
(modelId: string) => {
|
||||
if (!activeSessionId || modelId === session?.modelId) {
|
||||
return;
|
||||
}
|
||||
const previousModelId = session?.modelId;
|
||||
const previousModelName = session?.modelName;
|
||||
const models = useChatSessionStore
|
||||
.getState()
|
||||
.getSessionModels(activeSessionId);
|
||||
const selected = models.find((m) => m.id === modelId);
|
||||
useChatSessionStore.getState().updateSession(activeSessionId, {
|
||||
modelId,
|
||||
modelName: selected?.displayName ?? selected?.name ?? modelId,
|
||||
});
|
||||
if (session?.draft) {
|
||||
return;
|
||||
}
|
||||
acpSetModel(activeSessionId, modelId).catch((error) => {
|
||||
console.error("Failed to set model:", error);
|
||||
useChatSessionStore.getState().updateSession(activeSessionId, {
|
||||
modelId: previousModelId,
|
||||
modelName: previousModelName,
|
||||
});
|
||||
});
|
||||
},
|
||||
[activeSessionId, session?.draft, session?.modelId, session?.modelName],
|
||||
);
|
||||
|
||||
// When persona changes, update the provider to match persona's default
|
||||
const handlePersonaChange = useCallback(
|
||||
(personaId: string | null) => {
|
||||
setSelectedPersonaId(personaId);
|
||||
const persona = personas.find((p) => p.id === personaId);
|
||||
if (persona?.provider) {
|
||||
const matchingProvider = providers.find(
|
||||
(p) =>
|
||||
p.id === persona.provider ||
|
||||
p.label.toLowerCase().includes(persona.provider ?? ""),
|
||||
);
|
||||
if (matchingProvider) {
|
||||
handleProviderChange(matchingProvider.id);
|
||||
}
|
||||
}
|
||||
const agentStore = useAgentStore.getState();
|
||||
const matchingAgent = agentStore.agents.find(
|
||||
(a) => a.personaId === personaId,
|
||||
);
|
||||
if (matchingAgent) {
|
||||
agentStore.setActiveAgent(matchingAgent.id);
|
||||
}
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.updateSession(activeSessionId, { personaId: personaId ?? undefined });
|
||||
},
|
||||
[personas, providers, activeSessionId, handleProviderChange],
|
||||
);
|
||||
|
||||
// Validate persona still exists — fall back to default if deleted
|
||||
useEffect(() => {
|
||||
if (
|
||||
selectedPersonaId !== null &&
|
||||
personas.length > 0 &&
|
||||
!personas.find((p) => p.id === selectedPersonaId)
|
||||
) {
|
||||
// Selected persona was deleted — reset to no persona
|
||||
setSelectedPersonaId(null);
|
||||
}
|
||||
}, [personas, selectedPersonaId]);
|
||||
|
||||
const personaInfo = selectedPersona
|
||||
? { id: selectedPersona.id, name: selectedPersona.displayName }
|
||||
: undefined;
|
||||
const resolveCurrentSessionCwd = useCallback(
|
||||
() => resolveSessionCwd(project, activeWorkspace?.path),
|
||||
[project, activeWorkspace?.path],
|
||||
);
|
||||
const {
|
||||
messages,
|
||||
chatState,
|
||||
tokenState,
|
||||
sendMessage,
|
||||
compactConversation,
|
||||
stopStreaming,
|
||||
streamingMessageId,
|
||||
} = useChat(
|
||||
activeSessionId,
|
||||
selectedProvider,
|
||||
effectiveSystemPrompt,
|
||||
personaInfo,
|
||||
resolveCurrentSessionCwd,
|
||||
);
|
||||
const isLoadingHistory = useChatStore(
|
||||
(s) =>
|
||||
s.loadingSessionIds.has(activeSessionId) &&
|
||||
(s.messagesBySession[activeSessionId]?.length ?? 0) === 0,
|
||||
);
|
||||
|
||||
const deferredSend = useRef<{
|
||||
text: string;
|
||||
attachments?: ChatAttachmentDraft[];
|
||||
} | null>(null);
|
||||
const queue = useMessageQueue(activeSessionId, chatState, sendMessage);
|
||||
const chatStore = useChatStore();
|
||||
const handleSend = useCallback(
|
||||
(text: string, personaId?: string, attachments?: ChatAttachmentDraft[]) => {
|
||||
if (personaId && personaId !== selectedPersonaId) {
|
||||
const newPersona = personas.find((p) => p.id === personaId);
|
||||
if (newPersona) {
|
||||
// Inject a system notification about the persona switch
|
||||
chatStore.addMessage(activeSessionId, {
|
||||
id: crypto.randomUUID(),
|
||||
role: "system",
|
||||
created: Date.now(),
|
||||
content: [
|
||||
{
|
||||
type: "systemNotification",
|
||||
notificationType: "info",
|
||||
text: `Switched to ${newPersona.displayName}`,
|
||||
},
|
||||
],
|
||||
metadata: { userVisible: true, agentVisible: false },
|
||||
});
|
||||
}
|
||||
handlePersonaChange(personaId);
|
||||
// Defer the send until after persona state updates
|
||||
deferredSend.current = { text, attachments };
|
||||
return;
|
||||
}
|
||||
// Queue if agent is busy and no message already queued
|
||||
if (chatState !== "idle" && !queue.queuedMessage) {
|
||||
queue.enqueue(text, personaId, attachments);
|
||||
return;
|
||||
}
|
||||
|
||||
sendMessage(text, undefined, attachments);
|
||||
},
|
||||
[
|
||||
sendMessage,
|
||||
selectedPersonaId,
|
||||
handlePersonaChange,
|
||||
personas,
|
||||
chatStore,
|
||||
activeSessionId,
|
||||
chatState,
|
||||
queue,
|
||||
],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (deferredSend.current && selectedPersona) {
|
||||
const { text, attachments } = deferredSend.current;
|
||||
deferredSend.current = null;
|
||||
sendMessage(text, undefined, attachments);
|
||||
}
|
||||
}, [sendMessage, selectedPersona]);
|
||||
const initialMessageSent = useRef(false);
|
||||
useEffect(() => {
|
||||
if (
|
||||
(initialMessage || initialAttachments?.length) &&
|
||||
!initialMessageSent.current
|
||||
) {
|
||||
initialMessageSent.current = true;
|
||||
handleSend(initialMessage ?? "", undefined, initialAttachments);
|
||||
onInitialMessageConsumed?.();
|
||||
}
|
||||
}, [
|
||||
initialAttachments,
|
||||
initialMessage,
|
||||
handleSend,
|
||||
onInitialMessageConsumed,
|
||||
]);
|
||||
const isStreaming = chatState === "streaming";
|
||||
const isCompacting = chatState === "compacting";
|
||||
const showIndicator =
|
||||
chatState === "thinking" ||
|
||||
chatState === "streaming" ||
|
||||
chatState === "waiting" ||
|
||||
chatState === "compacting";
|
||||
const handleCreatePersona = useCallback(() => {
|
||||
useAgentStore.getState().openPersonaEditor();
|
||||
}, []);
|
||||
const draftValue = useChatStore(
|
||||
(s) => s.draftsBySession[activeSessionId] ?? "",
|
||||
);
|
||||
const scrollTarget = useChatStore(
|
||||
(s) => s.scrollTargetMessageBySession[activeSessionId] ?? null,
|
||||
);
|
||||
const handleDraftChange = useCallback(
|
||||
(text: string) => {
|
||||
useChatStore.getState().setDraft(activeSessionId, text);
|
||||
},
|
||||
[activeSessionId],
|
||||
);
|
||||
const handleScrollTargetHandled = useCallback(() => {
|
||||
useChatStore.getState().clearScrollTargetMessage(activeSessionId);
|
||||
}, [activeSessionId]);
|
||||
controller.chatState === "thinking" ||
|
||||
controller.chatState === "streaming" ||
|
||||
controller.chatState === "waiting" ||
|
||||
controller.chatState === "compacting";
|
||||
|
||||
return (
|
||||
<ArtifactPolicyProvider
|
||||
messages={messages}
|
||||
messages={controller.messages}
|
||||
allowedRoots={allowedArtifactRoots}
|
||||
>
|
||||
<div className="relative flex h-full min-w-0">
|
||||
<div className="flex min-w-0 flex-1 flex-col pr-1">
|
||||
{isLoadingHistory ? (
|
||||
{controller.isLoadingHistory ? (
|
||||
<ChatLoadingSkeleton />
|
||||
) : (
|
||||
<MessageTimeline
|
||||
messages={messages}
|
||||
streamingMessageId={streamingMessageId}
|
||||
scrollTargetMessageId={scrollTarget?.messageId ?? null}
|
||||
scrollTargetQuery={scrollTarget?.query ?? null}
|
||||
onScrollTargetHandled={handleScrollTargetHandled}
|
||||
messages={controller.messages}
|
||||
streamingMessageId={controller.streamingMessageId}
|
||||
scrollTargetMessageId={controller.scrollTarget?.messageId ?? null}
|
||||
scrollTargetQuery={controller.scrollTarget?.query ?? null}
|
||||
onScrollTargetHandled={controller.handleScrollTargetHandled}
|
||||
/>
|
||||
)}
|
||||
|
||||
<AnimatePresence initial={false}>
|
||||
{showIndicator && !isLoadingHistory ? (
|
||||
{showIndicator && !controller.isLoadingHistory ? (
|
||||
<LoadingGoose
|
||||
key="loading-indicator"
|
||||
chatState={
|
||||
chatState as
|
||||
controller.chatState as
|
||||
| "thinking"
|
||||
| "streaming"
|
||||
| "waiting"
|
||||
|
|
@ -512,54 +102,52 @@ export function ChatView({
|
|||
</AnimatePresence>
|
||||
|
||||
<ChatInput
|
||||
onSend={handleSend}
|
||||
disabled={projectMetadataPending}
|
||||
queuedMessage={queue.queuedMessage}
|
||||
onDismissQueue={queue.dismiss}
|
||||
initialValue={draftValue}
|
||||
onDraftChange={handleDraftChange}
|
||||
onStop={stopStreaming}
|
||||
isStreaming={isStreaming || chatState === "thinking"}
|
||||
personas={personas}
|
||||
selectedPersonaId={selectedPersonaId}
|
||||
onPersonaChange={handlePersonaChange}
|
||||
onCreatePersona={handleCreatePersona}
|
||||
providers={providers}
|
||||
providersLoading={providersLoading}
|
||||
selectedProvider={selectedProvider}
|
||||
onProviderChange={handleProviderChange}
|
||||
currentModelId={session?.modelId ?? null}
|
||||
currentModel={session?.modelName}
|
||||
availableModels={availableModels}
|
||||
onModelChange={handleModelChange}
|
||||
selectedProjectId={session?.projectId ?? null}
|
||||
availableProjects={availableProjects}
|
||||
onProjectChange={handleProjectChange}
|
||||
onSend={controller.handleSend}
|
||||
disabled={controller.projectMetadataPending}
|
||||
queuedMessage={controller.queue.queuedMessage}
|
||||
onDismissQueue={controller.queue.dismiss}
|
||||
initialValue={controller.draftValue}
|
||||
onDraftChange={controller.handleDraftChange}
|
||||
onStop={controller.stopStreaming}
|
||||
isStreaming={
|
||||
controller.chatState === "streaming" ||
|
||||
controller.chatState === "thinking"
|
||||
}
|
||||
personas={controller.personas}
|
||||
selectedPersonaId={controller.selectedPersonaId}
|
||||
onPersonaChange={controller.handlePersonaChange}
|
||||
onCreatePersona={controller.handleCreatePersona}
|
||||
providers={controller.pickerAgents}
|
||||
providersLoading={controller.providersLoading}
|
||||
selectedProvider={controller.selectedProvider}
|
||||
onProviderChange={controller.handleProviderChange}
|
||||
currentModelId={controller.currentModelId}
|
||||
currentModel={controller.currentModelName ?? undefined}
|
||||
availableModels={controller.availableModels}
|
||||
modelsLoading={controller.modelsLoading}
|
||||
modelStatusMessage={controller.modelStatusMessage}
|
||||
onModelChange={controller.handleModelChange}
|
||||
selectedProjectId={controller.selectedProjectId}
|
||||
availableProjects={controller.availableProjects}
|
||||
onProjectChange={controller.handleProjectChange}
|
||||
onCreateProject={(options) =>
|
||||
onCreateProject?.({
|
||||
onCreated: (projectId) => {
|
||||
handleProjectChange(projectId);
|
||||
controller.handleProjectChange(projectId);
|
||||
options?.onCreated?.(projectId);
|
||||
},
|
||||
})
|
||||
}
|
||||
contextTokens={tokenState.accumulatedTotal}
|
||||
contextLimit={tokenState.contextLimit}
|
||||
onCompactContext={compactConversation}
|
||||
canCompactContext={
|
||||
chatState === "idle" &&
|
||||
tokenState.accumulatedTotal > 0 &&
|
||||
!projectMetadataPending
|
||||
}
|
||||
isCompactingContext={isCompacting}
|
||||
contextTokens={controller.tokenState.accumulatedTotal}
|
||||
contextLimit={controller.tokenState.contextLimit}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<ChatContextPanel
|
||||
activeSessionId={activeSessionId}
|
||||
activeSessionId={sessionId}
|
||||
isOpen={isContextPanelOpen}
|
||||
label={contextPanelLabel}
|
||||
project={project}
|
||||
project={controller.project}
|
||||
setOpen={setContextPanelOpen}
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ describe("AgentModelPicker", () => {
|
|||
screen.getByRole("button", { name: /choose agent and model/i }),
|
||||
);
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /OpenAI/i }));
|
||||
await user.click(screen.getByRole("button", { name: "GPT-4o" }));
|
||||
|
||||
expect(onModelChange).toHaveBeenCalledWith("gpt-4o");
|
||||
|
|
@ -149,4 +148,49 @@ describe("AgentModelPicker", () => {
|
|||
expect(trigger).toHaveTextContent("Goose");
|
||||
expect(trigger).not.toHaveTextContent("·");
|
||||
});
|
||||
|
||||
it("shows a loading state while models are refreshing", async () => {
|
||||
const user = userEvent.setup();
|
||||
|
||||
render(
|
||||
<AgentModelPicker
|
||||
agents={AGENTS}
|
||||
selectedAgentId="goose"
|
||||
onAgentChange={vi.fn()}
|
||||
currentModelId={null}
|
||||
currentModelName={null}
|
||||
availableModels={[]}
|
||||
modelsLoading
|
||||
onModelChange={vi.fn()}
|
||||
/>,
|
||||
);
|
||||
|
||||
await user.click(
|
||||
screen.getByRole("button", { name: /choose agent and model/i }),
|
||||
);
|
||||
|
||||
expect(screen.getByText("Loading models...")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("shows an empty-state message when no inventory models are available", async () => {
|
||||
const user = userEvent.setup();
|
||||
|
||||
render(
|
||||
<AgentModelPicker
|
||||
agents={AGENTS}
|
||||
selectedAgentId="goose"
|
||||
onAgentChange={vi.fn()}
|
||||
currentModelId={null}
|
||||
currentModelName={null}
|
||||
availableModels={[]}
|
||||
onModelChange={vi.fn()}
|
||||
/>,
|
||||
);
|
||||
|
||||
await user.click(
|
||||
screen.getByRole("button", { name: /choose agent and model/i }),
|
||||
);
|
||||
|
||||
expect(screen.getByText("No models available")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ describe("ChatInput", () => {
|
|||
).toHaveTextContent("GPT-4o");
|
||||
});
|
||||
|
||||
it("shows default model name in model picker", () => {
|
||||
it("shows provider label when no current model is selected", () => {
|
||||
render(
|
||||
<ChatInput
|
||||
onSend={vi.fn()}
|
||||
|
|
@ -159,7 +159,7 @@ describe("ChatInput", () => {
|
|||
);
|
||||
expect(
|
||||
screen.getByRole("button", { name: /choose agent and model/i }),
|
||||
).toHaveTextContent("Claude Sonnet 4");
|
||||
).toHaveTextContent("Goose");
|
||||
});
|
||||
|
||||
it("shows default provider label", () => {
|
||||
|
|
@ -176,6 +176,18 @@ describe("ChatInput", () => {
|
|||
expect(providerButton).toHaveTextContent("Goose");
|
||||
});
|
||||
|
||||
it("resets the textarea when initialValue changes", () => {
|
||||
const { rerender } = render(
|
||||
<ChatInput onSend={vi.fn()} initialValue="alpha draft" />,
|
||||
);
|
||||
|
||||
expect(screen.getByRole("textbox")).toHaveValue("alpha draft");
|
||||
|
||||
rerender(<ChatInput onSend={vi.fn()} initialValue="" />);
|
||||
|
||||
expect(screen.getByRole("textbox")).toHaveValue("");
|
||||
});
|
||||
|
||||
it("opens the agent and model picker", async () => {
|
||||
const user = userEvent.setup();
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,67 @@ import { HomeScreen } from "./HomeScreen";
|
|||
|
||||
const setSelectedProvider = vi.fn();
|
||||
const setSelectedProviderWithoutPersist = vi.fn();
|
||||
const mockController = {
|
||||
handleSend: vi.fn(),
|
||||
projectMetadataPending: false,
|
||||
queue: { queuedMessage: null, dismiss: vi.fn() },
|
||||
stopStreaming: vi.fn(),
|
||||
chatState: "idle" as const,
|
||||
personas: [
|
||||
{
|
||||
id: "builtin-solo",
|
||||
displayName: "Solo",
|
||||
systemPrompt: "You are Solo.",
|
||||
provider: "openai",
|
||||
description: null,
|
||||
avatar: null,
|
||||
createdBy: null,
|
||||
source: "custom",
|
||||
extensions: [],
|
||||
metadata: null,
|
||||
sortOrder: 0,
|
||||
isDefault: false,
|
||||
},
|
||||
{
|
||||
id: "builtin-goose",
|
||||
displayName: "Goosey",
|
||||
systemPrompt: "You are Goosey.",
|
||||
isBuiltin: true,
|
||||
description: null,
|
||||
avatar: null,
|
||||
createdBy: null,
|
||||
source: "custom",
|
||||
extensions: [],
|
||||
metadata: null,
|
||||
sortOrder: 1,
|
||||
isDefault: false,
|
||||
createdAt: "",
|
||||
updatedAt: "",
|
||||
},
|
||||
],
|
||||
draftValue: "",
|
||||
handleDraftChange: vi.fn(),
|
||||
selectedPersonaId: null,
|
||||
handlePersonaChange: vi.fn(),
|
||||
handleCreatePersona: vi.fn(),
|
||||
pickerAgents: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "claude-acp", label: "Claude Code" },
|
||||
],
|
||||
providersLoading: false,
|
||||
selectedProvider: "goose",
|
||||
handleProviderChange: setSelectedProvider,
|
||||
currentModelId: null,
|
||||
currentModelName: null,
|
||||
availableModels: [],
|
||||
modelsLoading: false,
|
||||
modelStatusMessage: null,
|
||||
handleModelChange: vi.fn(),
|
||||
selectedProjectId: null,
|
||||
availableProjects: [],
|
||||
handleProjectChange: vi.fn(),
|
||||
tokenState: { accumulatedTotal: 0, contextLimit: 0 },
|
||||
};
|
||||
|
||||
vi.mock("@/shared/api/acp", () => ({
|
||||
discoverAcpProviders: vi.fn().mockResolvedValue([
|
||||
|
|
@ -21,6 +82,10 @@ vi.mock("@/features/providers/hooks/useAgentProviderStatus", () => ({
|
|||
}),
|
||||
}));
|
||||
|
||||
vi.mock("@/features/chat/hooks/useChatSessionController", () => ({
|
||||
useChatSessionController: () => mockController,
|
||||
}));
|
||||
|
||||
vi.mock("@/features/agents/hooks/useProviderSelection", () => ({
|
||||
useProviderSelection: () => ({
|
||||
providers: [
|
||||
|
|
@ -34,7 +99,6 @@ vi.mock("@/features/agents/hooks/useProviderSelection", () => ({
|
|||
}),
|
||||
}));
|
||||
|
||||
// HomeScreen now reads personas from the agent store, not from ACP providers
|
||||
vi.mock("@/features/agents/stores/agentStore", async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<
|
||||
|
|
@ -88,6 +152,9 @@ vi.mock("@/features/agents/stores/agentStore", async (importOriginal) => {
|
|||
});
|
||||
|
||||
describe("HomeScreen", () => {
|
||||
const renderHome = () =>
|
||||
render(<HomeScreen sessionId="home-session" onActivateSession={vi.fn()} />);
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date(2026, 2, 29, 14, 30, 0)); // 2:30 PM
|
||||
|
|
@ -98,32 +165,32 @@ describe("HomeScreen", () => {
|
|||
});
|
||||
|
||||
it("renders the clock", () => {
|
||||
render(<HomeScreen />);
|
||||
renderHome();
|
||||
expect(screen.getByText("2:30")).toBeInTheDocument();
|
||||
expect(screen.getByText("PM")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("shows afternoon greeting at 2:30 PM", () => {
|
||||
render(<HomeScreen />);
|
||||
renderHome();
|
||||
expect(screen.getByText("Good afternoon")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders the chat input placeholder with default agent name when no persona selected", () => {
|
||||
render(<HomeScreen />);
|
||||
renderHome();
|
||||
expect(
|
||||
screen.getByPlaceholderText("Message Goose, @ to mention personas"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders the assistant chooser affordance", () => {
|
||||
render(<HomeScreen />);
|
||||
renderHome();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /choose assistant/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders the provider and project controls on the home screen", () => {
|
||||
render(<HomeScreen />);
|
||||
renderHome();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /choose agent and model/i }),
|
||||
).toBeInTheDocument();
|
||||
|
|
@ -132,23 +199,17 @@ describe("HomeScreen", () => {
|
|||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("reverts to the stored provider when a persona override is cleared", async () => {
|
||||
it("forwards persona selection through the shared session controller", async () => {
|
||||
vi.useRealTimers();
|
||||
const user = userEvent.setup();
|
||||
|
||||
render(<HomeScreen />);
|
||||
renderHome();
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /choose assistant/i }));
|
||||
await user.click(screen.getByRole("menuitem", { name: /solo/i }));
|
||||
|
||||
expect(setSelectedProviderWithoutPersist).toHaveBeenLastCalledWith(
|
||||
"openai",
|
||||
expect(mockController.handlePersonaChange).toHaveBeenLastCalledWith(
|
||||
"builtin-solo",
|
||||
);
|
||||
|
||||
await user.click(
|
||||
screen.getByRole("button", { name: /clear active assistant/i }),
|
||||
);
|
||||
|
||||
expect(setSelectedProviderWithoutPersist).toHaveBeenLastCalledWith("goose");
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,17 +1,8 @@
|
|||
import { useState, useEffect, useCallback } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import {
|
||||
getStoredProvider,
|
||||
useAgentStore,
|
||||
} from "@/features/agents/stores/agentStore";
|
||||
import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection";
|
||||
import { ChatInput } from "@/features/chat/ui/ChatInput";
|
||||
import { useChatStore } from "@/features/chat/stores/chatStore";
|
||||
import type { ChatAttachmentDraft } from "@/shared/types/messages";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
import { useLocaleFormatting } from "@/shared/i18n";
|
||||
|
||||
const HOME_DRAFT_KEY = "home";
|
||||
import { useChatSessionController } from "@/features/chat/hooks/useChatSessionController";
|
||||
|
||||
function HomeClock() {
|
||||
const [time, setTime] = useState(new Date());
|
||||
|
|
@ -48,124 +39,94 @@ function getGreetingKey(hour: number): "morning" | "afternoon" | "evening" {
|
|||
}
|
||||
|
||||
interface HomeScreenProps {
|
||||
onStartChat?: (
|
||||
initialMessage?: string,
|
||||
providerId?: string,
|
||||
personaId?: string,
|
||||
projectId?: string | null,
|
||||
attachments?: ChatAttachmentDraft[],
|
||||
) => void;
|
||||
sessionId: string | null;
|
||||
onActivateSession: (sessionId: string) => void;
|
||||
onCreateProject?: (options?: {
|
||||
onCreated?: (projectId: string) => void;
|
||||
}) => void;
|
||||
}
|
||||
|
||||
export function HomeScreen({ onStartChat, onCreateProject }: HomeScreenProps) {
|
||||
function HomeComposer({
|
||||
sessionId,
|
||||
onActivateSession,
|
||||
onCreateProject,
|
||||
}: {
|
||||
sessionId: string | null;
|
||||
onActivateSession: (sessionId: string) => void;
|
||||
onCreateProject?: HomeScreenProps["onCreateProject"];
|
||||
}) {
|
||||
const controller = useChatSessionController({
|
||||
sessionId,
|
||||
onMessageAccepted: onActivateSession,
|
||||
});
|
||||
|
||||
return (
|
||||
<ChatInput
|
||||
onSend={controller.handleSend}
|
||||
disabled={controller.projectMetadataPending}
|
||||
queuedMessage={controller.queue.queuedMessage}
|
||||
onDismissQueue={controller.queue.dismiss}
|
||||
initialValue={controller.draftValue}
|
||||
onDraftChange={controller.handleDraftChange}
|
||||
onStop={controller.stopStreaming}
|
||||
isStreaming={
|
||||
controller.chatState === "streaming" ||
|
||||
controller.chatState === "thinking"
|
||||
}
|
||||
personas={controller.personas}
|
||||
selectedPersonaId={controller.selectedPersonaId}
|
||||
onPersonaChange={controller.handlePersonaChange}
|
||||
onCreatePersona={controller.handleCreatePersona}
|
||||
providers={controller.pickerAgents}
|
||||
providersLoading={controller.providersLoading}
|
||||
selectedProvider={controller.selectedProvider}
|
||||
onProviderChange={controller.handleProviderChange}
|
||||
currentModelId={controller.currentModelId}
|
||||
currentModel={controller.currentModelName ?? undefined}
|
||||
availableModels={controller.availableModels}
|
||||
modelsLoading={controller.modelsLoading}
|
||||
modelStatusMessage={controller.modelStatusMessage}
|
||||
onModelChange={controller.handleModelChange}
|
||||
selectedProjectId={controller.selectedProjectId}
|
||||
availableProjects={controller.availableProjects}
|
||||
onProjectChange={controller.handleProjectChange}
|
||||
onCreateProject={(options) =>
|
||||
onCreateProject?.({
|
||||
onCreated: (projectId) => {
|
||||
controller.handleProjectChange(projectId);
|
||||
options?.onCreated?.(projectId);
|
||||
},
|
||||
})
|
||||
}
|
||||
contextTokens={controller.tokenState.accumulatedTotal}
|
||||
contextLimit={controller.tokenState.contextLimit}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export function HomeScreen({
|
||||
sessionId,
|
||||
onActivateSession,
|
||||
onCreateProject,
|
||||
}: HomeScreenProps) {
|
||||
const { t } = useTranslation("home");
|
||||
const [hour] = useState(() => new Date().getHours());
|
||||
const greeting = t(`greeting.${getGreetingKey(hour)}`);
|
||||
|
||||
const personas = useAgentStore((s) => s.personas);
|
||||
const {
|
||||
providers,
|
||||
providersLoading,
|
||||
selectedProvider,
|
||||
setSelectedProvider,
|
||||
setSelectedProviderWithoutPersist,
|
||||
} = useProviderSelection();
|
||||
const projects = useProjectStore((s) => s.projects);
|
||||
const [selectedPersonaId, setSelectedPersonaId] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
const [selectedProjectId, setSelectedProjectId] = useState<string | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
const handlePersonaChange = useCallback(
|
||||
(personaId: string | null) => {
|
||||
setSelectedPersonaId(personaId);
|
||||
const persona = personaId
|
||||
? personas.find((candidate) => candidate.id === personaId)
|
||||
: null;
|
||||
const nextProvider = persona?.provider ?? getStoredProvider(providers);
|
||||
|
||||
setSelectedProviderWithoutPersist(nextProvider);
|
||||
},
|
||||
[personas, providers, setSelectedProviderWithoutPersist],
|
||||
);
|
||||
|
||||
const handleCreatePersona = useCallback(() => {
|
||||
useAgentStore.getState().openPersonaEditor();
|
||||
}, []);
|
||||
|
||||
const homeDraft = useChatStore(
|
||||
(s) => s.draftsBySession[HOME_DRAFT_KEY] ?? "",
|
||||
);
|
||||
const handleDraftChange = useCallback((text: string) => {
|
||||
useChatStore.getState().setDraft(HOME_DRAFT_KEY, text);
|
||||
}, []);
|
||||
|
||||
const handleSend = useCallback(
|
||||
(
|
||||
message: string,
|
||||
personaId?: string,
|
||||
attachments?: ChatAttachmentDraft[],
|
||||
) => {
|
||||
const effectivePersonaId = personaId ?? selectedPersonaId ?? undefined;
|
||||
|
||||
useChatStore.getState().clearDraft(HOME_DRAFT_KEY);
|
||||
onStartChat?.(
|
||||
message,
|
||||
selectedProvider,
|
||||
effectivePersonaId,
|
||||
selectedProjectId,
|
||||
attachments,
|
||||
);
|
||||
},
|
||||
[onStartChat, selectedPersonaId, selectedProjectId, selectedProvider],
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="h-full w-full overflow-y-auto">
|
||||
<div className="relative flex min-h-full flex-col items-center justify-center px-6 pb-4">
|
||||
<div className="flex w-full max-w-[600px] flex-col antialiased">
|
||||
{/* Clock */}
|
||||
<HomeClock />
|
||||
|
||||
{/* Greeting */}
|
||||
<p className="mb-6 pl-4 text-xl font-normal font-display text-muted-foreground">
|
||||
{greeting}
|
||||
</p>
|
||||
|
||||
{/* Chat input */}
|
||||
<ChatInput
|
||||
onSend={handleSend}
|
||||
initialValue={homeDraft}
|
||||
onDraftChange={handleDraftChange}
|
||||
personas={personas}
|
||||
selectedPersonaId={selectedPersonaId}
|
||||
onPersonaChange={handlePersonaChange}
|
||||
onCreatePersona={handleCreatePersona}
|
||||
providers={providers}
|
||||
providersLoading={providersLoading}
|
||||
selectedProvider={selectedProvider}
|
||||
onProviderChange={setSelectedProvider}
|
||||
selectedProjectId={selectedProjectId}
|
||||
availableProjects={projects.map((project) => ({
|
||||
id: project.id,
|
||||
name: project.name,
|
||||
workingDirs: project.workingDirs,
|
||||
color: project.color,
|
||||
}))}
|
||||
onProjectChange={setSelectedProjectId}
|
||||
onCreateProject={(options) =>
|
||||
onCreateProject?.({
|
||||
onCreated: (projectId) => {
|
||||
setSelectedProjectId(projectId);
|
||||
options?.onCreated?.(projectId);
|
||||
},
|
||||
})
|
||||
}
|
||||
<HomeComposer
|
||||
sessionId={sessionId}
|
||||
onActivateSession={onActivateSession}
|
||||
onCreateProject={onCreateProject}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
32
ui/goose2/src/features/providers/api/inventory.ts
Normal file
32
ui/goose2/src/features/providers/api/inventory.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import type {
|
||||
ProviderInventoryEntryDto,
|
||||
RefreshProviderInventoryResponse,
|
||||
} from "@aaif/goose-sdk";
|
||||
import { getClient } from "@/shared/api/acpConnection";
|
||||
import { perfLog } from "@/shared/lib/perfLog";
|
||||
|
||||
export async function getProviderInventory(
|
||||
providerIds: string[] = [],
|
||||
): Promise<ProviderInventoryEntryDto[]> {
|
||||
const client = await getClient();
|
||||
const t0 = performance.now();
|
||||
const response = await client.goose.GooseProvidersInventory({ providerIds });
|
||||
perfLog(
|
||||
`[perf:inventory] getProviderInventory done in ${(performance.now() - t0).toFixed(1)}ms (n=${response.entries.length})`,
|
||||
);
|
||||
return response.entries;
|
||||
}
|
||||
|
||||
export async function refreshProviderInventory(
|
||||
providerIds: string[] = [],
|
||||
): Promise<RefreshProviderInventoryResponse> {
|
||||
const client = await getClient();
|
||||
const t0 = performance.now();
|
||||
const response = await client.goose.GooseProvidersInventoryRefresh({
|
||||
providerIds,
|
||||
});
|
||||
perfLog(
|
||||
`[perf:inventory] refreshProviderInventory done in ${(performance.now() - t0).toFixed(1)}ms started=[${response.started.join(",")}]`,
|
||||
);
|
||||
return response;
|
||||
}
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
import { useCallback, useMemo } from "react";
|
||||
import { useProviderInventoryStore } from "../stores/providerInventoryStore";
|
||||
import type { ModelOption } from "@/features/chat/types";
|
||||
import type {
|
||||
ProviderInventoryEntryDto,
|
||||
ProviderInventoryModelDto,
|
||||
} from "@aaif/goose-sdk";
|
||||
import { getModelProviders } from "../providerCatalog";
|
||||
|
||||
const MODEL_PROVIDER_IDS = new Set(getModelProviders().map((p) => p.id));
|
||||
|
||||
function inventoryModelToOption(
|
||||
model: ProviderInventoryModelDto,
|
||||
provider?: Pick<ProviderInventoryEntryDto, "providerId" | "providerName">,
|
||||
): ModelOption {
|
||||
return {
|
||||
id: model.id,
|
||||
name: model.name,
|
||||
displayName: model.name !== model.id ? model.name : undefined,
|
||||
provider: model.family ?? undefined,
|
||||
providerId: provider?.providerId,
|
||||
providerName: provider?.providerName,
|
||||
recommended: model.recommended ?? false,
|
||||
};
|
||||
}
|
||||
|
||||
export function useProviderInventory() {
|
||||
const entries = useProviderInventoryStore((s) => s.entries);
|
||||
const loading = useProviderInventoryStore((s) => s.loading);
|
||||
|
||||
const getEntry = useCallback(
|
||||
(providerId: string) => entries.get(providerId),
|
||||
[entries],
|
||||
);
|
||||
|
||||
const getModelsForProvider = useCallback(
|
||||
(providerId: string): ModelOption[] => {
|
||||
const entry = entries.get(providerId);
|
||||
if (!entry) return [];
|
||||
return entry.models.map((model) => inventoryModelToOption(model, entry));
|
||||
},
|
||||
[entries],
|
||||
);
|
||||
|
||||
const configuredModelProviderEntries = useMemo(
|
||||
() =>
|
||||
[...entries.values()].filter(
|
||||
(entry) => entry.configured && MODEL_PROVIDER_IDS.has(entry.providerId),
|
||||
),
|
||||
[entries],
|
||||
);
|
||||
|
||||
const getModelsForAgent = useCallback(
|
||||
(agentId: string): ModelOption[] => {
|
||||
if (agentId !== "goose") {
|
||||
return getModelsForProvider(agentId);
|
||||
}
|
||||
|
||||
return configuredModelProviderEntries.flatMap((entry) =>
|
||||
entry.models.map((model) => inventoryModelToOption(model, entry)),
|
||||
);
|
||||
},
|
||||
[configuredModelProviderEntries, getModelsForProvider],
|
||||
);
|
||||
|
||||
const configuredProviderIds = useMemo(
|
||||
() =>
|
||||
[...entries.values()]
|
||||
.filter((e) => e.configured)
|
||||
.map((e) => e.providerId),
|
||||
[entries],
|
||||
);
|
||||
|
||||
return {
|
||||
entries,
|
||||
loading,
|
||||
getEntry,
|
||||
configuredModelProviderEntries,
|
||||
getModelsForAgent,
|
||||
getModelsForProvider,
|
||||
configuredProviderIds,
|
||||
};
|
||||
}
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
import { create } from "zustand";
|
||||
import type { ProviderInventoryEntryDto } from "@aaif/goose-sdk";
|
||||
import { perfLog } from "@/shared/lib/perfLog";
|
||||
|
||||
export interface ProviderInventoryState {
|
||||
entries: Map<string, ProviderInventoryEntryDto>;
|
||||
loading: boolean;
|
||||
}
|
||||
|
||||
interface ProviderInventoryActions {
|
||||
setEntries: (entries: ProviderInventoryEntryDto[]) => void;
|
||||
mergeEntries: (entries: ProviderInventoryEntryDto[]) => void;
|
||||
setLoading: (loading: boolean) => void;
|
||||
}
|
||||
|
||||
export type ProviderInventoryStore = ProviderInventoryState &
|
||||
ProviderInventoryActions;
|
||||
|
||||
export const useProviderInventoryStore = create<ProviderInventoryStore>(
|
||||
(set) => ({
|
||||
entries: new Map(),
|
||||
loading: false,
|
||||
|
||||
setEntries: (entries) => {
|
||||
const map = new Map<string, ProviderInventoryEntryDto>();
|
||||
for (const entry of entries) {
|
||||
map.set(entry.providerId, entry);
|
||||
}
|
||||
set({ entries: map });
|
||||
perfLog(
|
||||
`[perf:inventory] setEntries n=${entries.length} providers=[${entries.map((e) => e.providerId).join(",")}]`,
|
||||
);
|
||||
},
|
||||
|
||||
mergeEntries: (entries) => {
|
||||
set((state) => {
|
||||
const map = new Map(state.entries);
|
||||
for (const entry of entries) {
|
||||
map.set(entry.providerId, entry);
|
||||
}
|
||||
return { entries: map };
|
||||
});
|
||||
},
|
||||
|
||||
setLoading: (loading) => set({ loading }),
|
||||
}),
|
||||
);
|
||||
|
|
@ -8,7 +8,11 @@ import { Button } from "@/shared/ui/button";
|
|||
import { SessionCard } from "./SessionCard";
|
||||
import { groupSessionsByDate } from "../lib/groupSessionsByDate";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
|
||||
import {
|
||||
getVisibleSessions,
|
||||
useChatSessionStore,
|
||||
} from "@/features/chat/stores/chatSessionStore";
|
||||
import { useChatStore } from "@/features/chat/stores/chatStore";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
import {
|
||||
acpDuplicateSession,
|
||||
|
|
@ -38,10 +42,14 @@ export function SessionHistoryView({
|
|||
}: SessionHistoryViewProps) {
|
||||
const { t, i18n } = useTranslation(["sessions", "common"]);
|
||||
const sessions = useChatSessionStore((s) => s.sessions);
|
||||
const messagesBySession = useChatStore((s) => s.messagesBySession);
|
||||
const loadSessions = useChatSessionStore((s) => s.loadSessions);
|
||||
const activeSessions = useMemo(
|
||||
() => sessions.filter((session) => !session.draft && !session.archivedAt),
|
||||
[sessions],
|
||||
() =>
|
||||
getVisibleSessions(sessions, messagesBySession).filter(
|
||||
(session) => !session.archivedAt,
|
||||
),
|
||||
[messagesBySession, sessions],
|
||||
);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,10 @@ import { cn } from "@/shared/lib/cn";
|
|||
import type { AppView } from "@/app/AppShell";
|
||||
import type { ProjectInfo } from "@/features/projects/api/projects";
|
||||
import { useChatStore } from "@/features/chat/stores/chatStore";
|
||||
import { useChatSessionStore } from "@/features/chat/stores/chatSessionStore";
|
||||
import {
|
||||
getVisibleSessions,
|
||||
useChatSessionStore,
|
||||
} from "@/features/chat/stores/chatSessionStore";
|
||||
import { isSessionRunning } from "@/features/chat/lib/sessionActivity";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
|
|
@ -92,8 +95,12 @@ export function Sidebar({
|
|||
|
||||
const chatStore = useChatStore();
|
||||
const { sessions } = useChatSessionStore();
|
||||
const activeSessions = sessions.filter(
|
||||
(session) => !session.draft && !session.archivedAt,
|
||||
const visibleSessions = getVisibleSessions(
|
||||
sessions,
|
||||
chatStore.messagesBySession,
|
||||
);
|
||||
const activeSessions = visibleSessions.filter(
|
||||
(session) => !session.archivedAt,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -137,8 +144,8 @@ export function Sidebar({
|
|||
};
|
||||
const byProject: Record<string, SessionItem[]> = {};
|
||||
const standalone: SessionItem[] = [];
|
||||
for (const session of sessions) {
|
||||
if (session.draft || session.archivedAt) continue;
|
||||
for (const session of visibleSessions) {
|
||||
if (session.archivedAt) continue;
|
||||
const runtime = chatStore.getSessionRuntime(session.id);
|
||||
const item: SessionItem = {
|
||||
id: session.id,
|
||||
|
|
@ -191,7 +198,7 @@ export function Sidebar({
|
|||
|
||||
useEffect(() => {
|
||||
if (!activeSessionId) return;
|
||||
const activeSession = sessions.find((s) => s.id === activeSessionId);
|
||||
const activeSession = visibleSessions.find((s) => s.id === activeSessionId);
|
||||
const projectId = activeSession?.projectId;
|
||||
if (projectId) {
|
||||
setExpandedProjects((prev) => {
|
||||
|
|
@ -199,7 +206,7 @@ export function Sidebar({
|
|||
return { ...prev, [projectId]: true };
|
||||
});
|
||||
}
|
||||
}, [activeSessionId, sessions]);
|
||||
}, [activeSessionId, visibleSessions]);
|
||||
|
||||
useEffect(() => {
|
||||
try {
|
||||
|
|
@ -254,17 +261,13 @@ export function Sidebar({
|
|||
updateActiveRect,
|
||||
} = useSidebarHighlight(navRef);
|
||||
|
||||
const activeDraft = activeSessionId
|
||||
? sessions.find((s) => s.id === activeSessionId && s.draft)
|
||||
: undefined;
|
||||
const activeProjectId = activeDraft?.projectId ?? null;
|
||||
const activeProjectId =
|
||||
activeSessionId && activeView === "chat"
|
||||
? (sessions.find((s) => s.id === activeSessionId)?.projectId ?? null)
|
||||
: null;
|
||||
|
||||
useEffect(() => {
|
||||
if (activeDraft) {
|
||||
if (!activeProjectId) updateActiveRect(null);
|
||||
return;
|
||||
}
|
||||
if (activeSessionId) return;
|
||||
if (activeSessionId && activeView === "chat") return;
|
||||
if (activeView === "home") {
|
||||
updateActiveRect(homeRef.current);
|
||||
} else if (activeView && navItemRefs.current[activeView]) {
|
||||
|
|
@ -272,13 +275,7 @@ export function Sidebar({
|
|||
} else {
|
||||
updateActiveRect(null);
|
||||
}
|
||||
}, [
|
||||
activeSessionId,
|
||||
activeDraft,
|
||||
activeProjectId,
|
||||
activeView,
|
||||
updateActiveRect,
|
||||
]);
|
||||
}, [activeSessionId, activeView, updateActiveRect]);
|
||||
|
||||
const activeSessionRefCallback = useCallback(
|
||||
(el: HTMLElement | null) => {
|
||||
|
|
|
|||
|
|
@ -9,12 +9,12 @@ const mockSessions: Array<{
|
|||
updatedAt: string;
|
||||
messageCount: number;
|
||||
projectId?: string;
|
||||
draft?: boolean;
|
||||
archivedAt?: string;
|
||||
}> = [];
|
||||
|
||||
vi.mock("@/features/chat/stores/chatStore", () => ({
|
||||
useChatStore: () => ({
|
||||
messagesBySession: {},
|
||||
getSessionRuntime: () => ({
|
||||
chatState: "idle",
|
||||
hasUnread: false,
|
||||
|
|
@ -23,6 +23,8 @@ vi.mock("@/features/chat/stores/chatStore", () => ({
|
|||
}));
|
||||
|
||||
vi.mock("@/features/chat/stores/chatSessionStore", () => ({
|
||||
getVisibleSessions: (sessions: typeof mockSessions) =>
|
||||
sessions.filter((session) => session.messageCount > 0),
|
||||
useChatSessionStore: () => ({
|
||||
sessions: mockSessions,
|
||||
}),
|
||||
|
|
@ -65,6 +67,40 @@ describe("Sidebar", () => {
|
|||
mockSessions.splice(0, mockSessions.length);
|
||||
});
|
||||
|
||||
it("hides zero-message sessions from recents", () => {
|
||||
mockSessions.splice(
|
||||
0,
|
||||
mockSessions.length,
|
||||
{
|
||||
id: "home-session",
|
||||
title: "New Chat",
|
||||
updatedAt: "2026-04-09T12:00:00.000Z",
|
||||
messageCount: 0,
|
||||
},
|
||||
{
|
||||
id: "session-1",
|
||||
title: "Recovered Session",
|
||||
updatedAt: "2026-04-09T12:01:00.000Z",
|
||||
messageCount: 3,
|
||||
},
|
||||
);
|
||||
|
||||
render(
|
||||
<Sidebar
|
||||
collapsed={false}
|
||||
onCollapse={vi.fn()}
|
||||
onNavigate={vi.fn()}
|
||||
onSelectSession={vi.fn()}
|
||||
projects={[]}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByText("New Chat")).not.toBeInTheDocument();
|
||||
expect(screen.getByText("Recovered Session")).toBeInTheDocument();
|
||||
|
||||
mockSessions.splice(0, mockSessions.length);
|
||||
});
|
||||
|
||||
it("renders a home button in the sidebar header and navigates home", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onNavigate = vi.fn();
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ vi.mock("../acpConnection", () => ({
|
|||
}));
|
||||
|
||||
describe("dictation SDK wiring", () => {
|
||||
let client: any;
|
||||
let client: { goose: Record<string, ReturnType<typeof vi.fn>> };
|
||||
beforeEach(() => {
|
||||
client = {
|
||||
goose: {
|
||||
|
|
@ -33,7 +33,9 @@ describe("dictation SDK wiring", () => {
|
|||
GooseDictationTranscribe: vi.fn().mockResolvedValue({ text: "hello" }),
|
||||
},
|
||||
};
|
||||
vi.mocked(getClient).mockResolvedValue(client);
|
||||
vi.mocked(getClient).mockResolvedValue(
|
||||
client as unknown as Awaited<ReturnType<typeof getClient>>,
|
||||
);
|
||||
});
|
||||
|
||||
it("getDictationConfig calls GooseDictationConfig and returns providers map", async () => {
|
||||
|
|
@ -46,7 +48,7 @@ describe("dictation SDK wiring", () => {
|
|||
const result = await transcribeDictation({
|
||||
audio: "base64==",
|
||||
mimeType: "audio/webm",
|
||||
provider: "openai" as any,
|
||||
provider: "openai",
|
||||
});
|
||||
expect(client.goose.GooseDictationTranscribe).toHaveBeenCalledWith({
|
||||
audio: "base64==",
|
||||
|
|
@ -58,7 +60,7 @@ describe("dictation SDK wiring", () => {
|
|||
|
||||
it("saveDictationModelSelection calls GooseDictationModelSelect", async () => {
|
||||
client.goose.GooseDictationModelSelect = vi.fn().mockResolvedValue({});
|
||||
await saveDictationModelSelection("local" as any, "tiny");
|
||||
await saveDictationModelSelection("local", "tiny");
|
||||
expect(client.goose.GooseDictationModelSelect).toHaveBeenCalledWith({
|
||||
provider: "local",
|
||||
modelId: "tiny",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
import type { ContentBlock } from "@agentclientprotocol/sdk";
|
||||
import * as directAcp from "./acpApi";
|
||||
import * as sessionTracker from "./acpSessionTracker";
|
||||
import {
|
||||
getCatalogEntry,
|
||||
resolveAgentProviderCatalogId,
|
||||
} from "@/features/providers/providerCatalog";
|
||||
import {
|
||||
setActiveMessageId,
|
||||
clearActiveMessageId,
|
||||
|
|
@ -25,9 +29,31 @@ export interface AcpPrepareSessionOptions {
|
|||
personaId?: string;
|
||||
}
|
||||
|
||||
export interface AcpCreateSessionOptions extends AcpPrepareSessionOptions {
|
||||
modelId?: string | null;
|
||||
}
|
||||
|
||||
/** Discover ACP providers installed on the system. */
|
||||
export async function discoverAcpProviders(): Promise<AcpProvider[]> {
|
||||
return directAcp.listProviders();
|
||||
const providers = await directAcp.listProviders();
|
||||
const seen = new Set<string>();
|
||||
|
||||
return providers
|
||||
.map((provider) => {
|
||||
const catalogId = resolveAgentProviderCatalogId(
|
||||
provider.id,
|
||||
provider.label,
|
||||
);
|
||||
if (!catalogId || seen.has(catalogId)) {
|
||||
return null;
|
||||
}
|
||||
seen.add(catalogId);
|
||||
return {
|
||||
id: catalogId,
|
||||
label: getCatalogEntry(catalogId)?.displayName ?? provider.label,
|
||||
};
|
||||
})
|
||||
.filter((provider): provider is AcpProvider => provider !== null);
|
||||
}
|
||||
|
||||
/** Send a message to an ACP agent. Response streams via Tauri events. */
|
||||
|
|
@ -79,13 +105,13 @@ export async function acpPrepareSession(
|
|||
providerId: string,
|
||||
workingDir: string,
|
||||
options: AcpPrepareSessionOptions = {},
|
||||
): Promise<void> {
|
||||
): Promise<string> {
|
||||
const sid = sessionId.slice(0, 8);
|
||||
const t0 = performance.now();
|
||||
perfLog(
|
||||
`[perf:prepare] ${sid} acpPrepareSession start (provider=${providerId})`,
|
||||
);
|
||||
await sessionTracker.prepareSession(
|
||||
const gooseSessionId = await sessionTracker.prepareSession(
|
||||
sessionId,
|
||||
providerId,
|
||||
workingDir,
|
||||
|
|
@ -94,6 +120,31 @@ export async function acpPrepareSession(
|
|||
perfLog(
|
||||
`[perf:prepare] ${sid} acpPrepareSession done in ${(performance.now() - t0).toFixed(1)}ms`,
|
||||
);
|
||||
return gooseSessionId;
|
||||
}
|
||||
|
||||
export async function acpCreateSession(
|
||||
providerId: string,
|
||||
workingDir: string,
|
||||
options: AcpCreateSessionOptions = {},
|
||||
): Promise<{ sessionId: string }> {
|
||||
const localSessionId = crypto.randomUUID();
|
||||
const gooseSessionId = await acpPrepareSession(
|
||||
localSessionId,
|
||||
providerId,
|
||||
workingDir,
|
||||
options,
|
||||
);
|
||||
sessionTracker.registerSession(
|
||||
gooseSessionId,
|
||||
gooseSessionId,
|
||||
providerId,
|
||||
workingDir,
|
||||
);
|
||||
if (options.modelId) {
|
||||
await directAcp.setModel(gooseSessionId, options.modelId);
|
||||
}
|
||||
return { sessionId: gooseSessionId };
|
||||
}
|
||||
|
||||
export async function acpSetModel(
|
||||
|
|
|
|||
|
|
@ -455,7 +455,6 @@ function handleShared(sessionId: string, update: SessionUpdate): void {
|
|||
currentModelId;
|
||||
|
||||
const sessionStore = useChatSessionStore.getState();
|
||||
sessionStore.setSessionModels(sessionId, availableModels);
|
||||
sessionStore.updateSession(
|
||||
sessionId,
|
||||
{ modelId: currentModelId, modelName: currentModelName },
|
||||
|
|
|
|||
|
|
@ -115,8 +115,10 @@ export async function prepareSession(
|
|||
`[perf:prepare] ${sid} tracker setProvider(${providerId}) in ${(performance.now() - tProv).toFixed(1)}ms (goose_sid=${gooseSid})`,
|
||||
);
|
||||
|
||||
prepared.set(key, { gooseSessionId, providerId, workingDir });
|
||||
prepared.set(sessionId, { gooseSessionId, providerId, workingDir });
|
||||
const entry = { gooseSessionId, providerId, workingDir };
|
||||
prepared.set(key, entry);
|
||||
prepared.set(sessionId, entry);
|
||||
prepared.set(gooseSessionId, entry);
|
||||
gooseToLocal.set(gooseSessionId, sessionId);
|
||||
notifySessionRegistered(sessionId, gooseSessionId);
|
||||
|
||||
|
|
@ -161,6 +163,7 @@ export function registerSession(
|
|||
}
|
||||
|
||||
prepared.set(sessionId, entry);
|
||||
prepared.set(gooseSessionId, entry);
|
||||
gooseToLocal.set(gooseSessionId, sessionId);
|
||||
notifySessionRegistered(sessionId, gooseSessionId);
|
||||
|
||||
|
|
|
|||
|
|
@ -163,7 +163,14 @@
|
|||
"createProject": "Create project",
|
||||
"generalChatWithoutProject": "General chat without project context",
|
||||
"loading": "Loading...",
|
||||
"loadingModels": "Loading models...",
|
||||
"model": "Model",
|
||||
"allModels": "All models",
|
||||
"searchModels": "Search models...",
|
||||
"recommended": "Recommended",
|
||||
"noModelsAvailable": "No models available",
|
||||
"noSearchResults": "No matching models",
|
||||
"showAllModels": "Browse all models",
|
||||
"noProject": "No project",
|
||||
"selectModel": "Select model",
|
||||
"selectProject": "Select project",
|
||||
|
|
|
|||
|
|
@ -163,7 +163,14 @@
|
|||
"createProject": "Crear proyecto",
|
||||
"generalChatWithoutProject": "Chat general sin contexto de proyecto",
|
||||
"loading": "Cargando...",
|
||||
"loadingModels": "Cargando modelos...",
|
||||
"model": "Modelo",
|
||||
"allModels": "Todos los modelos",
|
||||
"searchModels": "Buscar modelos...",
|
||||
"recommended": "Recomendados",
|
||||
"noModelsAvailable": "No hay modelos disponibles",
|
||||
"noSearchResults": "Sin resultados",
|
||||
"showAllModels": "Ver todos los modelos",
|
||||
"noProject": "Sin proyecto",
|
||||
"selectModel": "Seleccionar modelo",
|
||||
"selectProject": "Seleccionar proyecto",
|
||||
|
|
|
|||
|
|
@ -30,6 +30,137 @@ export function buildInitScript(options?: {
|
|||
const PERSONAS = ${personas};
|
||||
const SKILLS = ${skills};
|
||||
const PROJECTS = ${projects};
|
||||
const ACP_SESSIONS = [];
|
||||
|
||||
function nowIso() {
|
||||
return new Date().toISOString();
|
||||
}
|
||||
|
||||
function buildSession(sessionId, providerId = "goose") {
|
||||
return {
|
||||
sessionId,
|
||||
title: "New Chat",
|
||||
updatedAt: nowIso(),
|
||||
messageCount: 0,
|
||||
providerId,
|
||||
modelId: null,
|
||||
};
|
||||
}
|
||||
|
||||
function findSession(sessionId) {
|
||||
return ACP_SESSIONS.find((session) => session.sessionId === sessionId) ?? null;
|
||||
}
|
||||
|
||||
function jsonRpcResult(id, result) {
|
||||
return { jsonrpc: "2.0", id, result };
|
||||
}
|
||||
|
||||
function handleAcpRequest(message) {
|
||||
switch (message.method) {
|
||||
case "initialize":
|
||||
return jsonRpcResult(message.id, {
|
||||
protocolVersion: "0.1.0",
|
||||
agentCapabilities: {
|
||||
loadSession: {},
|
||||
listSessions: {},
|
||||
},
|
||||
agentInfo: {
|
||||
name: "mock-goose",
|
||||
version: "0.0.0",
|
||||
},
|
||||
authMethods: [],
|
||||
});
|
||||
case "session/list":
|
||||
return jsonRpcResult(message.id, {
|
||||
sessions: ACP_SESSIONS.map((session) => ({
|
||||
sessionId: session.sessionId,
|
||||
title: session.title,
|
||||
updatedAt: session.updatedAt,
|
||||
_meta: {
|
||||
messageCount: session.messageCount,
|
||||
},
|
||||
})),
|
||||
});
|
||||
case "session/new": {
|
||||
const providerId = message.params?.meta?.provider ?? "goose";
|
||||
const sessionId = "session-" + Math.random().toString(36).slice(2, 10);
|
||||
ACP_SESSIONS.unshift(buildSession(sessionId, providerId));
|
||||
return jsonRpcResult(message.id, { sessionId });
|
||||
}
|
||||
case "session/load":
|
||||
return jsonRpcResult(message.id, {});
|
||||
case "session/set_config_option": {
|
||||
const session = findSession(message.params?.sessionId);
|
||||
if (session) {
|
||||
if (message.params?.configId === "provider") {
|
||||
session.providerId = message.params?.value ?? session.providerId;
|
||||
session.modelId = null;
|
||||
}
|
||||
if (message.params?.configId === "model") {
|
||||
session.modelId = message.params?.value ?? null;
|
||||
}
|
||||
session.updatedAt = nowIso();
|
||||
}
|
||||
return jsonRpcResult(message.id, {});
|
||||
}
|
||||
case "session/prompt": {
|
||||
const session = findSession(message.params?.sessionId);
|
||||
if (session) {
|
||||
session.messageCount += 1;
|
||||
session.updatedAt = nowIso();
|
||||
}
|
||||
return jsonRpcResult(message.id, { stopReason: "end_turn" });
|
||||
}
|
||||
case "_goose/providers/list":
|
||||
return jsonRpcResult(message.id, { providers: [] });
|
||||
case "_goose/providers/inventory":
|
||||
return jsonRpcResult(message.id, { entries: [] });
|
||||
case "_goose/providers/inventory/refresh":
|
||||
return jsonRpcResult(message.id, { started: [], skipped: [] });
|
||||
case "_goose/working_dir/update":
|
||||
case "goose/working_dir/update":
|
||||
return jsonRpcResult(message.id, {});
|
||||
default:
|
||||
return jsonRpcResult(message.id, {});
|
||||
}
|
||||
}
|
||||
|
||||
class MockWebSocket extends EventTarget {
|
||||
constructor(url) {
|
||||
super();
|
||||
this.url = url;
|
||||
this.readyState = 0;
|
||||
queueMicrotask(() => {
|
||||
this.readyState = 1;
|
||||
this.dispatchEvent(new Event("open"));
|
||||
});
|
||||
}
|
||||
|
||||
send(raw) {
|
||||
const message = JSON.parse(raw);
|
||||
const response =
|
||||
message && typeof message === "object" && "id" in message
|
||||
? handleAcpRequest(message)
|
||||
: null;
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
queueMicrotask(() => {
|
||||
this.dispatchEvent(
|
||||
new MessageEvent("message", {
|
||||
data: JSON.stringify(response),
|
||||
}),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
close() {
|
||||
this.readyState = 3;
|
||||
this.dispatchEvent(new CloseEvent("close"));
|
||||
}
|
||||
}
|
||||
|
||||
window.WebSocket = MockWebSocket;
|
||||
|
||||
window.__TAURI_INTERNALS__ = {
|
||||
invoke(cmd, args) {
|
||||
|
|
@ -95,7 +226,16 @@ export function buildInitScript(options?: {
|
|||
|
||||
// ---- Sessions / Misc ----
|
||||
case "list_sessions":
|
||||
return Promise.resolve([]);
|
||||
return Promise.resolve(
|
||||
ACP_SESSIONS.map((session) => ({
|
||||
sessionId: session.sessionId,
|
||||
title: session.title,
|
||||
updatedAt: session.updatedAt,
|
||||
messageCount: session.messageCount,
|
||||
})),
|
||||
);
|
||||
case "get_goose_serve_url":
|
||||
return Promise.resolve("ws://mock-goose");
|
||||
case "create_session":
|
||||
return Promise.resolve({
|
||||
id: "session-" + Math.random().toString(36).slice(2, 10),
|
||||
|
|
@ -130,6 +270,16 @@ export function buildInitScript(options?: {
|
|||
return Promise.resolve("/tmp/home");
|
||||
case "path_exists":
|
||||
return Promise.resolve(false);
|
||||
case "resolve_path": {
|
||||
const parts = args?.request?.parts ?? [];
|
||||
const path = parts
|
||||
.filter((part) => typeof part === "string" && part.length > 0)
|
||||
.join("/");
|
||||
const normalizedPath = path.startsWith("~/")
|
||||
? "/tmp/home/" + path.slice(2)
|
||||
: path;
|
||||
return Promise.resolve({ path: normalizedPath });
|
||||
}
|
||||
|
||||
// ---- Fallback ----
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -31,8 +31,8 @@ import type {
|
|||
GetExtensionsResponse,
|
||||
GetProviderDetailsRequest,
|
||||
GetProviderDetailsResponse,
|
||||
GetProviderModelsRequest,
|
||||
GetProviderModelsResponse,
|
||||
GetProviderInventoryRequest,
|
||||
GetProviderInventoryResponse,
|
||||
GetSessionExtensionsRequest,
|
||||
GetSessionExtensionsResponse,
|
||||
GetToolsRequest,
|
||||
|
|
@ -45,6 +45,8 @@ import type {
|
|||
ReadConfigResponse,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResponse,
|
||||
RefreshProviderInventoryRequest,
|
||||
RefreshProviderInventoryResponse,
|
||||
RemoveConfigRequest,
|
||||
RemoveExtensionRequest,
|
||||
RemoveSecretRequest,
|
||||
|
|
@ -62,13 +64,14 @@ import {
|
|||
zExportSessionResponse,
|
||||
zGetExtensionsResponse,
|
||||
zGetProviderDetailsResponse,
|
||||
zGetProviderModelsResponse,
|
||||
zGetProviderInventoryResponse,
|
||||
zGetSessionExtensionsResponse,
|
||||
zGetToolsResponse,
|
||||
zImportSessionResponse,
|
||||
zListProvidersResponse,
|
||||
zReadConfigResponse,
|
||||
zReadResourceResponse,
|
||||
zRefreshProviderInventoryResponse,
|
||||
} from './zod.gen.js';
|
||||
|
||||
export class GooseExtClient {
|
||||
|
|
@ -132,11 +135,25 @@ export class GooseExtClient {
|
|||
return zGetProviderDetailsResponse.parse(raw) as GetProviderDetailsResponse;
|
||||
}
|
||||
|
||||
async GooseProvidersModels(
|
||||
params: GetProviderModelsRequest,
|
||||
): Promise<GetProviderModelsResponse> {
|
||||
const raw = await this.conn.extMethod("_goose/providers/models", params);
|
||||
return zGetProviderModelsResponse.parse(raw) as GetProviderModelsResponse;
|
||||
async GooseProvidersInventory(
|
||||
params: GetProviderInventoryRequest,
|
||||
): Promise<GetProviderInventoryResponse> {
|
||||
const raw = await this.conn.extMethod("_goose/providers/inventory", params);
|
||||
return zGetProviderInventoryResponse.parse(
|
||||
raw,
|
||||
) as GetProviderInventoryResponse;
|
||||
}
|
||||
|
||||
async GooseProvidersInventoryRefresh(
|
||||
params: RefreshProviderInventoryRequest,
|
||||
): Promise<RefreshProviderInventoryResponse> {
|
||||
const raw = await this.conn.extMethod(
|
||||
"_goose/providers/inventory/refresh",
|
||||
params,
|
||||
);
|
||||
return zRefreshProviderInventoryResponse.parse(
|
||||
raw,
|
||||
) as RefreshProviderInventoryResponse;
|
||||
}
|
||||
|
||||
async GooseConfigRead(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// This file is auto-generated by @hey-api/openapi-ts
|
||||
|
||||
export type { AddExtensionRequest, ArchiveSessionRequest, CheckSecretRequest, CheckSecretResponse, DeleteSessionRequest, DictationConfigRequest, DictationConfigResponse, DictationDownloadProgress, DictationLocalModelStatus, DictationModelCancelRequest, DictationModelDeleteRequest, DictationModelDownloadProgressRequest, DictationModelDownloadProgressResponse, DictationModelDownloadRequest, DictationModelOption, DictationModelSelectRequest, DictationModelsListRequest, DictationModelsListResponse, DictationProviderStatusEntry, DictationTranscribeRequest, DictationTranscribeResponse, EmptyResponse, ExportSessionRequest, ExportSessionResponse, ExtRequest, ExtResponse, GetExtensionsRequest, GetExtensionsResponse, GetProviderDetailsRequest, GetProviderDetailsResponse, GetProviderModelsRequest, GetProviderModelsResponse, GetSessionExtensionsRequest, GetSessionExtensionsResponse, GetToolsRequest, GetToolsResponse, ImportSessionRequest, ImportSessionResponse, ListProvidersRequest, ListProvidersResponse, ModelEntry, ProviderConfigKey, ProviderDetailEntry, ProviderListEntry, ReadConfigRequest, ReadConfigResponse, ReadResourceRequest, ReadResourceResponse, RemoveConfigRequest, RemoveExtensionRequest, RemoveSecretRequest, UnarchiveSessionRequest, UpdateWorkingDirRequest, UpsertConfigRequest, UpsertSecretRequest } from './types.gen.js';
|
||||
export type { AddExtensionRequest, ArchiveSessionRequest, CheckSecretRequest, CheckSecretResponse, DeleteSessionRequest, DictationConfigRequest, DictationConfigResponse, DictationDownloadProgress, DictationLocalModelStatus, DictationModelCancelRequest, DictationModelDeleteRequest, DictationModelDownloadProgressRequest, DictationModelDownloadProgressResponse, DictationModelDownloadRequest, DictationModelOption, DictationModelSelectRequest, DictationModelsListRequest, DictationModelsListResponse, DictationProviderStatusEntry, DictationTranscribeRequest, DictationTranscribeResponse, EmptyResponse, ExportSessionRequest, ExportSessionResponse, ExtRequest, ExtResponse, GetExtensionsRequest, GetExtensionsResponse, GetProviderDetailsRequest, GetProviderDetailsResponse, GetProviderInventoryRequest, GetProviderInventoryResponse, GetSessionExtensionsRequest, GetSessionExtensionsResponse, GetToolsRequest, GetToolsResponse, ImportSessionRequest, ImportSessionResponse, ListProvidersRequest, ListProvidersResponse, ModelEntry, ProviderConfigKey, ProviderDetailEntry, ProviderInventoryEntryDto, ProviderInventoryModelDto, ProviderListEntry, ReadConfigRequest, ReadConfigResponse, ReadResourceRequest, ReadResourceResponse, RefreshProviderInventoryRequest, RefreshProviderInventoryResponse, RefreshProviderInventorySkipDto, RefreshProviderInventorySkipReasonDto, RemoveConfigRequest, RemoveExtensionRequest, RemoveSecretRequest, UnarchiveSessionRequest, UpdateWorkingDirRequest, UpsertConfigRequest, UpsertSecretRequest } from './types.gen.js';
|
||||
|
||||
export const GOOSE_EXT_METHODS = [
|
||||
{
|
||||
|
|
@ -54,9 +54,14 @@ export const GOOSE_EXT_METHODS = [
|
|||
responseType: "GetProviderDetailsResponse",
|
||||
},
|
||||
{
|
||||
method: "_goose/providers/models",
|
||||
requestType: "GetProviderModelsRequest",
|
||||
responseType: "GetProviderModelsResponse",
|
||||
method: "_goose/providers/inventory",
|
||||
requestType: "GetProviderInventoryRequest",
|
||||
responseType: "GetProviderInventoryResponse",
|
||||
},
|
||||
{
|
||||
method: "_goose/providers/inventory/refresh",
|
||||
requestType: "RefreshProviderInventoryRequest",
|
||||
responseType: "RefreshProviderInventoryResponse",
|
||||
},
|
||||
{
|
||||
method: "_goose/config/read",
|
||||
|
|
|
|||
|
|
@ -165,19 +165,133 @@ export type ModelEntry = {
|
|||
};
|
||||
|
||||
/**
|
||||
* Fetch the full list of models available for a specific provider.
|
||||
* Read per-provider inventory. Always returns immediately from stored state.
|
||||
*/
|
||||
export type GetProviderModelsRequest = {
|
||||
providerName: string;
|
||||
export type GetProviderInventoryRequest = {
|
||||
/**
|
||||
* Only return entries for these providers. Empty means all.
|
||||
*/
|
||||
providerIds?: Array<string>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider models response.
|
||||
* Provider inventory response.
|
||||
*/
|
||||
export type GetProviderModelsResponse = {
|
||||
models: Array<string>;
|
||||
export type GetProviderInventoryResponse = {
|
||||
entries: Array<ProviderInventoryEntryDto>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider inventory entry.
|
||||
*/
|
||||
export type ProviderInventoryEntryDto = {
|
||||
/**
|
||||
* Provider identifier.
|
||||
*/
|
||||
providerId: string;
|
||||
/**
|
||||
* Human-readable provider name.
|
||||
*/
|
||||
providerName: string;
|
||||
/**
|
||||
* Whether Goose has enough configuration to use this provider.
|
||||
*/
|
||||
configured: boolean;
|
||||
/**
|
||||
* Whether this provider supports background inventory refresh.
|
||||
*/
|
||||
supportsRefresh: boolean;
|
||||
/**
|
||||
* Whether a refresh is currently in flight.
|
||||
*/
|
||||
refreshing: boolean;
|
||||
/**
|
||||
* The list of available models.
|
||||
*/
|
||||
models: Array<ProviderInventoryModelDto>;
|
||||
/**
|
||||
* When this entry was last successfully refreshed (ISO 8601).
|
||||
*/
|
||||
lastUpdatedAt?: string | null;
|
||||
/**
|
||||
* When a refresh was most recently attempted (ISO 8601).
|
||||
*/
|
||||
lastRefreshAttemptAt?: string | null;
|
||||
/**
|
||||
* The last refresh failure message, if any.
|
||||
*/
|
||||
lastRefreshError?: string | null;
|
||||
/**
|
||||
* Whether we believe this data may be outdated.
|
||||
*/
|
||||
stale: boolean;
|
||||
/**
|
||||
* Guidance message shown when this provider manages its own model selection externally.
|
||||
*/
|
||||
modelSelectionHint?: string | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* A single model in provider inventory.
|
||||
*/
|
||||
export type ProviderInventoryModelDto = {
|
||||
/**
|
||||
* Model identifier as the provider knows it.
|
||||
*/
|
||||
id: string;
|
||||
/**
|
||||
* Human-readable display name.
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* Model family for grouping in UI.
|
||||
*/
|
||||
family?: string | null;
|
||||
/**
|
||||
* Context window size in tokens.
|
||||
*/
|
||||
contextLimit?: number | null;
|
||||
/**
|
||||
* Whether the model supports reasoning/extended thinking.
|
||||
*/
|
||||
reasoning?: boolean | null;
|
||||
/**
|
||||
* Whether this model should appear in the compact recommended picker.
|
||||
*/
|
||||
recommended?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Trigger a background refresh of provider inventories.
|
||||
*/
|
||||
export type RefreshProviderInventoryRequest = {
|
||||
/**
|
||||
* Which providers to refresh. Empty means all known providers.
|
||||
*/
|
||||
providerIds?: Array<string>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Refresh acknowledgement.
|
||||
*/
|
||||
export type RefreshProviderInventoryResponse = {
|
||||
/**
|
||||
* Which providers will be refreshed.
|
||||
*/
|
||||
started: Array<string>;
|
||||
/**
|
||||
* Which providers were skipped and why.
|
||||
*/
|
||||
skipped?: Array<RefreshProviderInventorySkipDto>;
|
||||
};
|
||||
|
||||
export type RefreshProviderInventorySkipDto = {
|
||||
providerId: string;
|
||||
reason: RefreshProviderInventorySkipReasonDto;
|
||||
};
|
||||
|
||||
export type RefreshProviderInventorySkipReasonDto = 'unknown_provider' | 'not_configured' | 'does_not_support_refresh' | 'already_refreshing';
|
||||
|
||||
/**
|
||||
* Read a single non-secret config value.
|
||||
*/
|
||||
|
|
@ -421,14 +535,14 @@ export type DictationModelSelectRequest = {
|
|||
export type ExtRequest = {
|
||||
id: string;
|
||||
method: string;
|
||||
params?: AddExtensionRequest | RemoveExtensionRequest | GetToolsRequest | ReadResourceRequest | UpdateWorkingDirRequest | DeleteSessionRequest | GetExtensionsRequest | GetSessionExtensionsRequest | ListProvidersRequest | GetProviderDetailsRequest | GetProviderModelsRequest | ReadConfigRequest | UpsertConfigRequest | RemoveConfigRequest | CheckSecretRequest | UpsertSecretRequest | RemoveSecretRequest | ExportSessionRequest | ImportSessionRequest | ArchiveSessionRequest | UnarchiveSessionRequest | DictationTranscribeRequest | DictationConfigRequest | DictationModelsListRequest | DictationModelDownloadRequest | DictationModelDownloadProgressRequest | DictationModelCancelRequest | DictationModelDeleteRequest | DictationModelSelectRequest | {
|
||||
params?: AddExtensionRequest | RemoveExtensionRequest | GetToolsRequest | ReadResourceRequest | UpdateWorkingDirRequest | DeleteSessionRequest | GetExtensionsRequest | GetSessionExtensionsRequest | ListProvidersRequest | GetProviderDetailsRequest | GetProviderInventoryRequest | RefreshProviderInventoryRequest | ReadConfigRequest | UpsertConfigRequest | RemoveConfigRequest | CheckSecretRequest | UpsertSecretRequest | RemoveSecretRequest | ExportSessionRequest | ImportSessionRequest | ArchiveSessionRequest | UnarchiveSessionRequest | DictationTranscribeRequest | DictationConfigRequest | DictationModelsListRequest | DictationModelDownloadRequest | DictationModelDownloadProgressRequest | DictationModelCancelRequest | DictationModelDeleteRequest | DictationModelSelectRequest | {
|
||||
[key: string]: unknown;
|
||||
} | null;
|
||||
};
|
||||
|
||||
export type ExtResponse = {
|
||||
id: string;
|
||||
result?: EmptyResponse | GetToolsResponse | ReadResourceResponse | GetExtensionsResponse | GetSessionExtensionsResponse | ListProvidersResponse | GetProviderDetailsResponse | GetProviderModelsResponse | ReadConfigResponse | CheckSecretResponse | ExportSessionResponse | ImportSessionResponse | DictationTranscribeResponse | DictationConfigResponse | DictationModelsListResponse | DictationModelDownloadProgressResponse | unknown;
|
||||
result?: EmptyResponse | GetToolsResponse | ReadResourceResponse | GetExtensionsResponse | GetSessionExtensionsResponse | ListProvidersResponse | GetProviderDetailsResponse | GetProviderInventoryResponse | RefreshProviderInventoryResponse | ReadConfigResponse | CheckSecretResponse | ExportSessionResponse | ImportSessionResponse | DictationTranscribeResponse | DictationConfigResponse | DictationModelsListResponse | DictationModelDownloadProgressResponse | unknown;
|
||||
} | {
|
||||
error: {
|
||||
code: number;
|
||||
|
|
|
|||
|
|
@ -149,17 +149,94 @@ export const zGetProviderDetailsResponse = z.object({
|
|||
});
|
||||
|
||||
/**
|
||||
* Fetch the full list of models available for a specific provider.
|
||||
* Read per-provider inventory. Always returns immediately from stored state.
|
||||
*/
|
||||
export const zGetProviderModelsRequest = z.object({
|
||||
providerName: z.string()
|
||||
export const zGetProviderInventoryRequest = z.object({
|
||||
providerIds: z.array(z.string()).optional().default([])
|
||||
});
|
||||
|
||||
/**
|
||||
* Provider models response.
|
||||
* A single model in provider inventory.
|
||||
*/
|
||||
export const zGetProviderModelsResponse = z.object({
|
||||
models: z.array(z.string())
|
||||
export const zProviderInventoryModelDto = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
family: z.union([
|
||||
z.string(),
|
||||
z.null()
|
||||
]).optional(),
|
||||
contextLimit: z.union([
|
||||
z.number().int().gte(0),
|
||||
z.null()
|
||||
]).optional(),
|
||||
reasoning: z.union([
|
||||
z.boolean(),
|
||||
z.null()
|
||||
]).optional(),
|
||||
recommended: z.boolean().optional().default(false)
|
||||
});
|
||||
|
||||
/**
|
||||
* Provider inventory entry.
|
||||
*/
|
||||
export const zProviderInventoryEntryDto = z.object({
|
||||
providerId: z.string(),
|
||||
providerName: z.string(),
|
||||
configured: z.boolean(),
|
||||
supportsRefresh: z.boolean(),
|
||||
refreshing: z.boolean(),
|
||||
models: z.array(zProviderInventoryModelDto),
|
||||
lastUpdatedAt: z.union([
|
||||
z.string(),
|
||||
z.null()
|
||||
]).optional(),
|
||||
lastRefreshAttemptAt: z.union([
|
||||
z.string(),
|
||||
z.null()
|
||||
]).optional(),
|
||||
lastRefreshError: z.union([
|
||||
z.string(),
|
||||
z.null()
|
||||
]).optional(),
|
||||
stale: z.boolean(),
|
||||
modelSelectionHint: z.union([
|
||||
z.string(),
|
||||
z.null()
|
||||
]).optional()
|
||||
});
|
||||
|
||||
/**
|
||||
* Provider inventory response.
|
||||
*/
|
||||
export const zGetProviderInventoryResponse = z.object({
|
||||
entries: z.array(zProviderInventoryEntryDto)
|
||||
});
|
||||
|
||||
/**
|
||||
* Trigger a background refresh of provider inventories.
|
||||
*/
|
||||
export const zRefreshProviderInventoryRequest = z.object({
|
||||
providerIds: z.array(z.string()).optional().default([])
|
||||
});
|
||||
|
||||
export const zRefreshProviderInventorySkipReasonDto = z.enum([
|
||||
'unknown_provider',
|
||||
'not_configured',
|
||||
'does_not_support_refresh',
|
||||
'already_refreshing'
|
||||
]);
|
||||
|
||||
export const zRefreshProviderInventorySkipDto = z.object({
|
||||
providerId: z.string(),
|
||||
reason: zRefreshProviderInventorySkipReasonDto
|
||||
});
|
||||
|
||||
/**
|
||||
* Refresh acknowledgement.
|
||||
*/
|
||||
export const zRefreshProviderInventoryResponse = z.object({
|
||||
started: z.array(z.string()),
|
||||
skipped: z.array(zRefreshProviderInventorySkipDto).optional().default([])
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
@ -426,7 +503,8 @@ export const zExtRequest = z.object({
|
|||
zGetSessionExtensionsRequest,
|
||||
zListProvidersRequest,
|
||||
zGetProviderDetailsRequest,
|
||||
zGetProviderModelsRequest,
|
||||
zGetProviderInventoryRequest,
|
||||
zRefreshProviderInventoryRequest,
|
||||
zReadConfigRequest,
|
||||
zUpsertConfigRequest,
|
||||
zRemoveConfigRequest,
|
||||
|
|
@ -465,7 +543,8 @@ export const zExtResponse = z.union([
|
|||
zGetSessionExtensionsResponse,
|
||||
zListProvidersResponse,
|
||||
zGetProviderDetailsResponse,
|
||||
zGetProviderModelsResponse,
|
||||
zGetProviderInventoryResponse,
|
||||
zRefreshProviderInventoryResponse,
|
||||
zReadConfigResponse,
|
||||
zCheckSecretResponse,
|
||||
zExportSessionResponse,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue