diff --git a/ui/goose2/src/app/AppShell.tsx b/ui/goose2/src/app/AppShell.tsx index 3ba9c0f792..ce33962fd5 100644 --- a/ui/goose2/src/app/AppShell.tsx +++ b/ui/goose2/src/app/AppShell.tsx @@ -10,7 +10,6 @@ import { TopBar } from "./ui/TopBar"; import { useChatStore } from "@/features/chat/stores/chatStore"; import { type ChatSession, - hasSessionStarted, useChatSessionStore, } from "@/features/chat/stores/chatSessionStore"; import { useAgentStore } from "@/features/agents/stores/agentStore"; @@ -18,14 +17,18 @@ import { useProjectStore } from "@/features/projects/stores/projectStore"; import { findExistingDraft } from "@/features/chat/lib/newChat"; import { DEFAULT_CHAT_TITLE } from "@/features/chat/lib/sessionTitle"; import { useAppStartup } from "./hooks/useAppStartup"; +import { useHomeSessionStateSync } from "./hooks/useHomeSessionStateSync"; +import { loadStoredHomeSessionId } from "./lib/homeSessionStorage"; +import { resolveSupportedSessionModelPreference } from "./lib/resolveSupportedSessionModelPreference"; import { AppShellContent } from "./ui/AppShellContent"; -import { acpPrepareSession } from "@/shared/api/acp"; +import { acpPrepareSession, acpSetModel } from "@/shared/api/acp"; import { clearReplayBuffer, getAndDeleteReplayBuffer, } from "@/features/chat/hooks/replayBuffer"; import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection"; import { perfLog } from "@/shared/lib/perfLog"; +import { useProviderInventoryStore } from "@/features/providers/stores/providerInventoryStore"; export type AppView = | "home" @@ -40,34 +43,6 @@ const SIDEBAR_MIN_WIDTH = 180; const SIDEBAR_MAX_WIDTH = 380; const SIDEBAR_SNAP_COLLAPSE_THRESHOLD = 100; const SIDEBAR_COLLAPSED_WIDTH = 48; -const HOME_SESSION_STORAGE_KEY = "goose:home-session-id"; - -function loadStoredHomeSessionId(): string | null { - if (typeof window === "undefined") { - return null; - } - try { - return window.localStorage.getItem(HOME_SESSION_STORAGE_KEY); - } catch { - return null; - } -} - -function persistHomeSessionId(sessionId: string | null): void { - if (typeof window === "undefined") { - return; - } - try { - if (sessionId) { - window.localStorage.setItem(HOME_SESSION_STORAGE_KEY, sessionId); - return; - } - window.localStorage.removeItem(HOME_SESSION_STORAGE_KEY); - } catch { - // localStorage may be unavailable - } -} - export function AppShell({ children }: { children?: React.ReactNode }) { const [sidebarCollapsed, setSidebarCollapsed] = useState(false); const [sidebarWidth, setSidebarWidth] = useState(SIDEBAR_DEFAULT_WIDTH); @@ -90,6 +65,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) { const sessionStore = useChatSessionStore(); const agentStore = useAgentStore(); const projectStore = useProjectStore(); + const providerInventoryEntries = useProviderInventoryStore((s) => s.entries); const pendingProjectCreatedRef = useRef<((projectId: string) => void) | null>( null, @@ -173,37 +149,14 @@ export function AppShell({ children }: { children?: React.ReactNode }) { ? sessionStore.getSession(homeSessionId) : undefined; - useEffect(() => { - if ( - !homeSessionId || - !sessionStore.hasHydratedSessions || - sessionStore.isLoading - ) { - return; - } - if ( - !homeSession || - homeSession.archivedAt || - hasSessionStarted( - homeSession, - chatStore.messagesBySession[homeSession.id], - ) - ) { - setHomeSessionId(null); - } - }, [ - chatStore.messagesBySession, - homeSession, - homeSession?.archivedAt, - homeSession?.messageCount, + useHomeSessionStateSync({ homeSessionId, - sessionStore.hasHydratedSessions, - sessionStore.isLoading, - ]); - - useEffect(() => { - persistHomeSessionId(homeSessionId); - }, [homeSessionId]); + homeSession, + messagesBySession: chatStore.messagesBySession, + hasHydratedSessions: sessionStore.hasHydratedSessions, + isLoading: sessionStore.isLoading, + setHomeSessionId, + }); const ensureHomeSession = useCallback(async () => { if (!sessionStore.hasHydratedSessions || sessionStore.isLoading) { @@ -220,6 +173,11 @@ export function AppShell({ children }: { children?: React.ReactNode }) { !homeSession.archivedAt && homeSession.messageCount === 0 ) { + const sessionModelPreference = + await resolveSupportedSessionModelPreference( + agentStore.selectedProvider ?? "goose", + providerInventoryEntries, + ); const project = homeSession.projectId ? (projectStore.projects.find( (candidate) => candidate.id === homeSession.projectId, @@ -228,20 +186,42 @@ export function AppShell({ children }: { children?: React.ReactNode }) { const workingDir = await resolveSessionCwd(project); await acpPrepareSession( homeSession.id, - homeSession.providerId ?? agentStore.selectedProvider ?? "goose", + sessionModelPreference.providerId, workingDir, { personaId: homeSession.personaId, }, ); + const shouldClearHomeModel = + sessionModelPreference.providerId !== homeSession.providerId || + !sessionModelPreference.modelId; + sessionStore.updateSession(homeSession.id, { + providerId: sessionModelPreference.providerId, + modelId: shouldClearHomeModel ? undefined : homeSession.modelId, + modelName: shouldClearHomeModel ? undefined : homeSession.modelName, + }); + if (sessionModelPreference.modelId) { + await acpSetModel(homeSession.id, sessionModelPreference.modelId); + sessionStore.updateSession(homeSession.id, { + modelId: sessionModelPreference.modelId, + modelName: sessionModelPreference.modelName, + }); + } return homeSession; } const workingDir = await resolveSessionCwd(null); + const sessionModelPreference = + await resolveSupportedSessionModelPreference( + agentStore.selectedProvider ?? "goose", + providerInventoryEntries, + ); const session = await sessionStore.createSession({ title: DEFAULT_CHAT_TITLE, - providerId: agentStore.selectedProvider ?? "goose", + providerId: sessionModelPreference.providerId, workingDir, + modelId: sessionModelPreference.modelId, + modelName: sessionModelPreference.modelName, }); setHomeSessionId(session.id); return session; @@ -258,6 +238,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) { }, [ agentStore.selectedProvider, homeSession, + providerInventoryEntries, projectStore.projects, sessionStore.hasHydratedSessions, sessionStore, @@ -282,7 +263,12 @@ export function AppShell({ children }: { children?: React.ReactNode }) { const agentId = agentStore.activeAgentId ?? undefined; const providerId = project?.preferredProvider ?? agentStore.selectedProvider ?? "goose"; - const modelId = project?.preferredModel ?? undefined; + const sessionModelPreference = + await resolveSupportedSessionModelPreference( + providerId, + providerInventoryEntries, + project?.preferredModel ?? undefined, + ); const sessionState = useChatSessionStore.getState(); const chatState = useChatStore.getState(); const existingDraft = findExistingDraft({ @@ -311,10 +297,10 @@ export function AppShell({ children }: { children?: React.ReactNode }) { title, projectId: project?.id, agentId, - providerId, + providerId: sessionModelPreference.providerId, workingDir, - modelId, - modelName: modelId, + modelId: sessionModelPreference.modelId, + modelName: sessionModelPreference.modelName, }); sessionStore.setActiveSession(session.id); setActiveView("chat"); @@ -328,6 +314,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) { agentStore.activeAgentId, agentStore.selectedProvider, chatStore, + providerInventoryEntries, sessionStore, ], ); diff --git a/ui/goose2/src/app/hooks/useHomeSessionStateSync.ts b/ui/goose2/src/app/hooks/useHomeSessionStateSync.ts new file mode 100644 index 0000000000..f55e8a9f01 --- /dev/null +++ b/ui/goose2/src/app/hooks/useHomeSessionStateSync.ts @@ -0,0 +1,51 @@ +import { useEffect } from "react"; +import { + hasSessionStarted, + type ChatSession, +} from "@/features/chat/stores/chatSessionStore"; +import { persistHomeSessionId } from "../lib/homeSessionStorage"; + +interface UseHomeSessionStateSyncOptions { + homeSessionId: string | null; + homeSession?: ChatSession; + messagesBySession: Record | undefined>; + hasHydratedSessions: boolean; + isLoading: boolean; + setHomeSessionId: (sessionId: string | null) => void; +} + +export function useHomeSessionStateSync({ + homeSessionId, + homeSession, + messagesBySession, + hasHydratedSessions, + isLoading, + setHomeSessionId, +}: UseHomeSessionStateSyncOptions): void { + useEffect(() => { + if (!homeSessionId || !hasHydratedSessions || isLoading) { + return; + } + + if ( + !homeSession || + homeSession.archivedAt || + hasSessionStarted(homeSession, messagesBySession[homeSession.id]) + ) { + setHomeSessionId(null); + } + }, [ + hasHydratedSessions, + homeSession, + homeSession?.archivedAt, + homeSession?.messageCount, + homeSessionId, + isLoading, + messagesBySession, + setHomeSessionId, + ]); + + useEffect(() => { + persistHomeSessionId(homeSessionId); + }, [homeSessionId]); +} diff --git a/ui/goose2/src/app/lib/homeSessionStorage.ts b/ui/goose2/src/app/lib/homeSessionStorage.ts new file mode 100644 index 0000000000..f4c0475434 --- /dev/null +++ b/ui/goose2/src/app/lib/homeSessionStorage.ts @@ -0,0 +1,27 @@ +const HOME_SESSION_STORAGE_KEY = "goose:home-session-id"; + +export function loadStoredHomeSessionId(): string | null { + if (typeof window === "undefined") { + return null; + } + try { + return window.localStorage.getItem(HOME_SESSION_STORAGE_KEY); + } catch { + return null; + } +} + +export function persistHomeSessionId(sessionId: string | null): void { + if (typeof window === "undefined") { + return; + } + try { + if (sessionId) { + window.localStorage.setItem(HOME_SESSION_STORAGE_KEY, sessionId); + return; + } + window.localStorage.removeItem(HOME_SESSION_STORAGE_KEY); + } catch { + // localStorage may be unavailable + } +} diff --git a/ui/goose2/src/app/lib/resolveSupportedSessionModelPreference.test.ts b/ui/goose2/src/app/lib/resolveSupportedSessionModelPreference.test.ts new file mode 100644 index 0000000000..98583b5582 --- /dev/null +++ b/ui/goose2/src/app/lib/resolveSupportedSessionModelPreference.test.ts @@ -0,0 +1,58 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { resolveSupportedSessionModelPreference } from "./resolveSupportedSessionModelPreference"; + +const mockGetProviderInventory = vi.fn(); + +vi.mock("@/features/providers/api/inventory", () => ({ + getProviderInventory: (...args: unknown[]) => + mockGetProviderInventory(...args), +})); + +describe("resolveSupportedSessionModelPreference", () => { + beforeEach(() => { + vi.clearAllMocks(); + window.localStorage.clear(); + }); + + it("drops the model when provider inventory lookup fails", async () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "gpt-5.4", + modelName: "GPT-5.4", + providerId: "openai", + }, + }), + ); + mockGetProviderInventory.mockRejectedValue( + new Error("inventory unavailable"), + ); + + await expect( + resolveSupportedSessionModelPreference("goose", new Map()), + ).resolves.toEqual({ + providerId: "openai", + }); + }); + + it("drops the model when provider inventory has no matching entry", async () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "gpt-5.4", + modelName: "GPT-5.4", + providerId: "openai", + }, + }), + ); + mockGetProviderInventory.mockResolvedValue([]); + + await expect( + resolveSupportedSessionModelPreference("goose", new Map()), + ).resolves.toEqual({ + providerId: "openai", + }); + }); +}); diff --git a/ui/goose2/src/app/lib/resolveSupportedSessionModelPreference.ts b/ui/goose2/src/app/lib/resolveSupportedSessionModelPreference.ts new file mode 100644 index 0000000000..dedbd53ffb --- /dev/null +++ b/ui/goose2/src/app/lib/resolveSupportedSessionModelPreference.ts @@ -0,0 +1,36 @@ +import type { ProviderInventoryEntryDto } from "@aaif/goose-sdk"; +import { getProviderInventory } from "@/features/providers/api/inventory"; +import { + resolveSessionModelPreference, + sanitizeSessionModelPreference, + type SessionModelPreference, +} from "@/features/chat/lib/sessionModelPreference"; + +export async function resolveSupportedSessionModelPreference( + providerId: string, + inventoryEntries: Map, + preferredModel?: string, +): Promise { + const sessionModelPreference = resolveSessionModelPreference({ + providerId, + preferredModel, + }); + + if (!sessionModelPreference.modelId) { + return sessionModelPreference; + } + + const inventoryEntry = + inventoryEntries.get(sessionModelPreference.providerId) ?? + (await getProviderInventory([sessionModelPreference.providerId]) + .then(([entry]) => entry) + .catch(() => undefined)); + + if (!inventoryEntry) { + return { + providerId: sessionModelPreference.providerId, + }; + } + + return sanitizeSessionModelPreference(sessionModelPreference, inventoryEntry); +} diff --git a/ui/goose2/src/features/chat/hooks/__tests__/useChatSessionController.test.ts b/ui/goose2/src/features/chat/hooks/__tests__/useChatSessionController.test.ts index 9c7c4a6dd8..cdc3153929 100644 --- a/ui/goose2/src/features/chat/hooks/__tests__/useChatSessionController.test.ts +++ b/ui/goose2/src/features/chat/hooks/__tests__/useChatSessionController.test.ts @@ -9,12 +9,37 @@ const mockAcpPrepareSession = vi.fn(); const mockAcpSetModel = vi.fn(); const mockSetSelectedProvider = vi.fn(); const mockResolveSessionCwd = vi.fn(); +const mockGooseConfigRead = vi.fn(); +const mockUseProviderInventory = vi.fn(); +const mockPickerState = { + pickerAgents: [{ id: "goose", label: "Goose" }], + availableModels: [] as Array<{ + id: string; + name: string; + displayName?: string; + providerId?: string; + }>, + modelsLoading: false, + modelStatusMessage: null as string | null, +}; vi.mock("@/shared/api/acp", () => ({ acpPrepareSession: (...args: unknown[]) => mockAcpPrepareSession(...args), acpSetModel: (...args: unknown[]) => mockAcpSetModel(...args), })); +vi.mock("@/shared/api/acpConnection", () => ({ + getClient: async () => ({ + goose: { + GooseConfigRead: (...args: unknown[]) => mockGooseConfigRead(...args), + }, + }), +})); + +vi.mock("@/features/providers/hooks/useProviderInventory", () => ({ + useProviderInventory: () => mockUseProviderInventory(), +})); + vi.mock("../useChat", () => ({ useChat: () => ({ messages: [], @@ -41,7 +66,7 @@ vi.mock("@/features/agents/hooks/useProviderSelection", () => ({ { id: "anthropic", label: "Anthropic" }, ], providersLoading: false, - selectedProvider: "openai", + selectedProvider: useAgentStore.getState().selectedProvider ?? "openai", setSelectedProvider: (...args: unknown[]) => mockSetSelectedProvider(...args), }), @@ -63,10 +88,10 @@ vi.mock("../useAgentModelPickerState", () => ({ }) => void; }) => ({ selectedAgentId: "goose", - pickerAgents: [{ id: "goose", label: "Goose" }], - availableModels: [], - modelsLoading: false, - modelStatusMessage: null, + pickerAgents: mockPickerState.pickerAgents, + availableModels: mockPickerState.availableModels, + modelsLoading: mockPickerState.modelsLoading, + modelStatusMessage: mockPickerState.modelStatusMessage, handleProviderChange: vi.fn(), handleModelChange: (modelId: string) => { if (modelId === "claude-sonnet-4") { @@ -86,9 +111,18 @@ import { useChatSessionController } from "../useChatSessionController"; describe("useChatSessionController", () => { beforeEach(() => { vi.clearAllMocks(); + window.localStorage.clear(); mockAcpPrepareSession.mockResolvedValue(undefined); mockAcpSetModel.mockResolvedValue(undefined); mockResolveSessionCwd.mockResolvedValue("/tmp/project"); + mockGooseConfigRead.mockResolvedValue({ value: null }); + mockUseProviderInventory.mockReturnValue({ + getEntry: () => undefined, + }); + mockPickerState.pickerAgents = [{ id: "goose", label: "Goose" }]; + mockPickerState.availableModels = []; + mockPickerState.modelsLoading = false; + mockPickerState.modelStatusMessage = null; useAgentStore.setState({ personas: [], @@ -178,4 +212,217 @@ describe("useChatSessionController", () => { modelName: "Claude Sonnet 4", }); }); + + it("restores the previous stored model preference when setting a model fails", async () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "gpt-4o", + modelName: "GPT-4o", + providerId: "openai", + }, + }), + ); + mockAcpSetModel.mockRejectedValueOnce(new Error("set model failed")); + + const { result } = renderHook(() => + useChatSessionController({ sessionId: "session-1" }), + ); + + act(() => { + result.current.handleModelChange("claude-sonnet-4"); + }); + + await waitFor(() => { + expect( + useChatSessionStore.getState().getSession("session-1"), + ).toMatchObject({ + providerId: "openai", + modelId: "gpt-4o", + modelName: "GPT-4o", + }); + }); + + expect( + JSON.parse( + window.localStorage.getItem("goose:preferredModelsByAgent") ?? "{}", + ), + ).toEqual({ + goose: { + modelId: "gpt-4o", + modelName: "GPT-4o", + providerId: "openai", + }, + }); + }); + + it("shows the stored explicit model for new chats", async () => { + useAgentStore.setState({ selectedProvider: "goose" }); + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + providerId: "anthropic", + }, + }), + ); + + const { result } = renderHook(() => + useChatSessionController({ sessionId: null }), + ); + + await waitFor(() => { + expect(result.current.currentModelId).toBe("claude-sonnet-4"); + }); + expect(result.current.currentModelName).toBe("Claude Sonnet 4"); + }); + + it("falls back to the configured goose default model when no explicit model is stored", async () => { + useAgentStore.setState({ selectedProvider: "goose" }); + mockGooseConfigRead.mockImplementation( + async ({ key }: { key: string }): Promise<{ value: string | null }> => { + if (key === "GOOSE_PROVIDER") { + return { value: "databricks" }; + } + if (key === "GOOSE_MODEL") { + return { value: "goose-claude-4-6-opus" }; + } + return { value: null }; + }, + ); + mockPickerState.availableModels = [ + { + id: "goose-claude-4-6-opus", + name: "Claude 4.6 Opus", + providerId: "databricks", + }, + ]; + + const { result } = renderHook(() => + useChatSessionController({ sessionId: null }), + ); + + await waitFor(() => { + expect(result.current.currentModelId).toBe("goose-claude-4-6-opus"); + }); + expect(result.current.currentModelName).toBe("Claude 4.6 Opus"); + }); + + it("applies the pending Home model to ACP when a real session becomes active", async () => { + const { result, rerender } = renderHook( + ({ sessionId }: { sessionId: string | null }) => + useChatSessionController({ sessionId }), + { + initialProps: { sessionId: null as string | null }, + }, + ); + + act(() => { + result.current.handleModelChange("claude-sonnet-4"); + }); + + useChatSessionStore.setState((state) => ({ + sessions: [ + { + id: "session-2", + title: "Chat", + providerId: "openai", + createdAt: "2026-04-21T00:00:00.000Z", + updatedAt: "2026-04-21T00:00:00.000Z", + messageCount: 0, + }, + ...state.sessions, + ], + })); + + rerender({ sessionId: "session-2" }); + + await waitFor(() => { + expect(mockAcpPrepareSession).toHaveBeenCalledWith( + "session-2", + "anthropic", + "/tmp/project", + { personaId: undefined }, + ); + }); + + await waitFor(() => { + expect(mockAcpSetModel).toHaveBeenCalledWith( + "session-2", + "claude-sonnet-4", + ); + }); + + expect( + useChatSessionStore.getState().getSession("session-2"), + ).toMatchObject({ + providerId: "anthropic", + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + }); + }); + + it("does not persist or record a pending Home model when ACP rejects it", async () => { + mockAcpSetModel.mockRejectedValueOnce(new Error("set model failed")); + + const { result, rerender } = renderHook( + ({ sessionId }: { sessionId: string | null }) => + useChatSessionController({ sessionId }), + { + initialProps: { sessionId: null as string | null }, + }, + ); + + act(() => { + result.current.handleModelChange("claude-sonnet-4"); + }); + + expect( + window.localStorage.getItem("goose:preferredModelsByAgent"), + ).toBeNull(); + + useChatSessionStore.setState((state) => ({ + sessions: [ + { + id: "session-3", + title: "Chat", + providerId: "openai", + createdAt: "2026-04-21T00:00:00.000Z", + updatedAt: "2026-04-21T00:00:00.000Z", + messageCount: 0, + }, + ...state.sessions, + ], + })); + + rerender({ sessionId: "session-3" }); + + await waitFor(() => { + expect(mockAcpSetModel).toHaveBeenCalledWith( + "session-3", + "claude-sonnet-4", + ); + }); + + await waitFor(() => { + expect( + useChatSessionStore.getState().getSession("session-3"), + ).toMatchObject({ + providerId: "anthropic", + }); + }); + + expect( + useChatSessionStore.getState().getSession("session-3"), + ).not.toMatchObject({ + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + }); + expect( + window.localStorage.getItem("goose:preferredModelsByAgent"), + ).toBeNull(); + }); }); diff --git a/ui/goose2/src/features/chat/hooks/__tests__/useResolvedAgentModelPicker.test.ts b/ui/goose2/src/features/chat/hooks/__tests__/useResolvedAgentModelPicker.test.ts new file mode 100644 index 0000000000..4e71dd022a --- /dev/null +++ b/ui/goose2/src/features/chat/hooks/__tests__/useResolvedAgentModelPicker.test.ts @@ -0,0 +1,396 @@ +import { act, renderHook } from "@testing-library/react"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { useResolvedAgentModelPicker } from "../useResolvedAgentModelPicker"; + +const mockUseProviderInventory = vi.fn(); +const mockUseAgentModelPickerState = vi.fn(); +const mockGetClient = vi.fn(); + +vi.mock("@/features/providers/hooks/useProviderInventory", () => ({ + useProviderInventory: () => mockUseProviderInventory(), +})); + +vi.mock("../useAgentModelPickerState", () => ({ + useAgentModelPickerState: (args: unknown) => + mockUseAgentModelPickerState(args), +})); + +vi.mock("@/shared/api/acpConnection", () => ({ + getClient: (...args: unknown[]) => mockGetClient(...args), +})); + +describe("useResolvedAgentModelPicker", () => { + beforeEach(() => { + vi.clearAllMocks(); + window.localStorage.clear(); + + mockGetClient.mockResolvedValue({ + goose: { + GooseConfigRead: vi.fn().mockResolvedValue({ value: null }), + }, + }); + + mockUseProviderInventory.mockReturnValue({ + getEntry: (providerId: string) => + providerId === "codex-acp" + ? { + providerId: "codex-acp", + defaultModel: "gpt-5.4", + models: [ + { + id: "gpt-5.4", + name: "GPT-5.4", + recommended: true, + }, + ], + } + : undefined, + }); + + mockUseAgentModelPickerState.mockImplementation( + ({ + onProviderSelected, + }: { + onProviderSelected: (providerId: string) => void; + }) => ({ + pickerAgents: [ + { id: "goose", label: "Goose" }, + { id: "codex-acp", label: "Codex" }, + ], + availableModels: [], + modelsLoading: false, + modelStatusMessage: null, + handleProviderChange: (providerId: string) => + onProviderSelected(providerId), + handleModelChange: vi.fn(), + }), + ); + }); + + it("selects the agent default model when switching to a provider without a saved model", () => { + const setPendingProviderId = vi.fn(); + const setPendingModelSelection = vi.fn(); + const setGlobalSelectedProvider = vi.fn(); + + const { result } = renderHook(() => + useResolvedAgentModelPicker({ + providers: [ + { id: "goose", label: "Goose" }, + { id: "codex-acp", label: "Codex" }, + ], + selectedProvider: "goose", + sessionId: null, + session: undefined, + pendingModelSelection: undefined, + setPendingProviderId, + setPendingModelSelection, + setGlobalSelectedProvider, + prepareSelectedProvider: vi.fn(), + }), + ); + + act(() => { + result.current.handleProviderChange("codex-acp"); + }); + + expect(setGlobalSelectedProvider).toHaveBeenCalledWith("codex-acp"); + expect(setPendingProviderId).toHaveBeenCalledWith("codex-acp"); + expect(setPendingModelSelection).toHaveBeenCalledWith({ + id: "gpt-5.4", + name: "GPT-5.4", + providerId: "codex-acp", + source: "default", + }); + }); + + it("selects the saved model when switching back to an agent", () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + "codex-acp": { + modelId: "gpt-5.4-mini", + modelName: "GPT-5.4 mini", + providerId: "codex-acp", + }, + }), + ); + + const setPendingProviderId = vi.fn(); + const setPendingModelSelection = vi.fn(); + const setGlobalSelectedProvider = vi.fn(); + + const { result } = renderHook(() => + useResolvedAgentModelPicker({ + providers: [ + { id: "goose", label: "Goose" }, + { id: "codex-acp", label: "Codex" }, + ], + selectedProvider: "goose", + sessionId: null, + session: undefined, + pendingModelSelection: undefined, + setPendingProviderId, + setPendingModelSelection, + setGlobalSelectedProvider, + prepareSelectedProvider: vi.fn(), + }), + ); + + act(() => { + result.current.handleProviderChange("codex-acp"); + }); + + expect(setGlobalSelectedProvider).toHaveBeenCalledWith("codex-acp"); + expect(setPendingProviderId).toHaveBeenCalledWith("codex-acp"); + expect(setPendingModelSelection).toHaveBeenCalledWith({ + id: "gpt-5.4-mini", + name: "GPT-5.4 mini", + providerId: "codex-acp", + source: "explicit", + }); + }); + + it("keeps explicit concrete provider requests authoritative", () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + providerId: "anthropic", + }, + }), + ); + + const setPendingProviderId = vi.fn(); + const setPendingModelSelection = vi.fn(); + const setGlobalSelectedProvider = vi.fn(); + + const { result } = renderHook(() => + useResolvedAgentModelPicker({ + providers: [ + { id: "goose", label: "Goose" }, + { id: "openai", label: "OpenAI" }, + ], + selectedProvider: "anthropic", + sessionId: null, + session: undefined, + pendingModelSelection: undefined, + setPendingProviderId, + setPendingModelSelection, + setGlobalSelectedProvider, + prepareSelectedProvider: vi.fn(), + }), + ); + + act(() => { + result.current.handleProviderChange("openai"); + }); + + expect(setGlobalSelectedProvider).toHaveBeenCalledWith("openai"); + expect(setPendingProviderId).toHaveBeenCalledWith("openai"); + expect(setPendingModelSelection).toHaveBeenCalledWith(undefined); + }); + + it("resolves ACP alias defaults to a concrete model when switching agents", () => { + mockUseProviderInventory.mockReturnValue({ + getEntry: (providerId: string) => + providerId === "claude-acp" + ? { + providerId: "claude-acp", + defaultModel: "current", + models: [ + { + id: "sonnet", + name: "Claude Sonnet", + recommended: true, + }, + { + id: "opus", + name: "Claude Opus", + recommended: false, + }, + ], + } + : undefined, + }); + + const setPendingProviderId = vi.fn(); + const setPendingModelSelection = vi.fn(); + const setGlobalSelectedProvider = vi.fn(); + + const { result } = renderHook(() => + useResolvedAgentModelPicker({ + providers: [ + { id: "goose", label: "Goose" }, + { id: "claude-acp", label: "Claude Code" }, + ], + selectedProvider: "goose", + sessionId: null, + session: undefined, + pendingModelSelection: undefined, + setPendingProviderId, + setPendingModelSelection, + setGlobalSelectedProvider, + prepareSelectedProvider: vi.fn(), + }), + ); + + act(() => { + result.current.handleProviderChange("claude-acp"); + }); + + expect(setPendingModelSelection).toHaveBeenCalledWith({ + id: "sonnet", + name: "Claude Sonnet", + providerId: "claude-acp", + source: "default", + }); + }); + + it("prefers a concrete default model over a session alias like current", () => { + mockUseProviderInventory.mockReturnValue({ + getEntry: (providerId: string) => + providerId === "claude-acp" + ? { + providerId: "claude-acp", + defaultModel: "current", + models: [ + { + id: "sonnet", + name: "Claude Sonnet", + recommended: true, + }, + ], + } + : undefined, + }); + + mockUseAgentModelPickerState.mockImplementation( + ({ + onProviderSelected, + }: { + onProviderSelected: (providerId: string) => void; + }) => ({ + pickerAgents: [ + { id: "goose", label: "Goose" }, + { id: "claude-acp", label: "Claude Code" }, + ], + availableModels: [ + { + id: "sonnet", + name: "Claude Sonnet", + recommended: true, + }, + ], + modelsLoading: false, + modelStatusMessage: null, + handleProviderChange: (providerId: string) => + onProviderSelected(providerId), + handleModelChange: vi.fn(), + }), + ); + + const { result } = renderHook(() => + useResolvedAgentModelPicker({ + providers: [ + { id: "goose", label: "Goose" }, + { id: "claude-acp", label: "Claude Code" }, + ], + selectedProvider: "claude-acp", + sessionId: "session-1", + session: { + id: "session-1", + title: "Chat", + providerId: "claude-acp", + modelId: "current", + modelName: "current", + createdAt: "2026-04-21T00:00:00.000Z", + updatedAt: "2026-04-21T00:00:00.000Z", + messageCount: 0, + }, + pendingModelSelection: undefined, + setPendingProviderId: vi.fn(), + setPendingModelSelection: vi.fn(), + setGlobalSelectedProvider: vi.fn(), + prepareSelectedProvider: vi.fn(), + }), + ); + + expect(result.current.effectiveModelSelection).toEqual({ + id: "sonnet", + name: "Claude Sonnet", + providerId: "claude-acp", + source: "default", + }); + }); + + it("drops Goose fallback models that are incompatible with a concrete provider", () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + providerId: "anthropic", + }, + }), + ); + + mockUseAgentModelPickerState.mockImplementation( + ({ + onProviderSelected, + }: { + onProviderSelected: (providerId: string) => void; + }) => ({ + pickerAgents: [ + { id: "goose", label: "Goose" }, + { id: "openai", label: "OpenAI" }, + ], + availableModels: [ + { + id: "gpt-5.4", + name: "GPT-5.4", + providerId: "openai", + }, + { + id: "claude-sonnet-4", + name: "Claude Sonnet 4", + providerId: "anthropic", + }, + ], + modelsLoading: false, + modelStatusMessage: null, + handleProviderChange: (providerId: string) => + onProviderSelected(providerId), + handleModelChange: vi.fn(), + }), + ); + + const { result } = renderHook(() => + useResolvedAgentModelPicker({ + providers: [ + { id: "goose", label: "Goose" }, + { id: "openai", label: "OpenAI" }, + ], + selectedProvider: "openai", + sessionId: "session-1", + session: { + id: "session-1", + title: "Chat", + providerId: "openai", + createdAt: "2026-04-21T00:00:00.000Z", + updatedAt: "2026-04-21T00:00:00.000Z", + messageCount: 0, + }, + pendingModelSelection: undefined, + setPendingProviderId: vi.fn(), + setPendingModelSelection: vi.fn(), + setGlobalSelectedProvider: vi.fn(), + prepareSelectedProvider: vi.fn(), + }), + ); + + expect(result.current.effectiveModelSelection).toBeNull(); + }); +}); diff --git a/ui/goose2/src/features/chat/hooks/useChatSessionController.ts b/ui/goose2/src/features/chat/hooks/useChatSessionController.ts index 2b321fed90..56938a86ef 100644 --- a/ui/goose2/src/features/chat/hooks/useChatSessionController.ts +++ b/ui/goose2/src/features/chat/hooks/useChatSessionController.ts @@ -7,15 +7,20 @@ import { useChatSessionStore } from "../stores/chatSessionStore"; import { useAgentStore } from "@/features/agents/stores/agentStore"; import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection"; import { useProjectStore } from "@/features/projects/stores/projectStore"; -import { useAgentModelPickerState } from "./useAgentModelPickerState"; +import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog"; import { buildProjectSystemPrompt, composeSystemPrompt, getProjectArtifactRoots, resolveProjectDefaultArtifactRoot, } from "@/features/projects/lib/chatProjectContext"; +import { setStoredModelPreference } from "../lib/modelPreferences"; import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection"; import { acpPrepareSession, acpSetModel } from "@/shared/api/acp"; +import { + useResolvedAgentModelPicker, + type PreferredModelSelection, +} from "./useResolvedAgentModelPicker"; interface UseChatSessionControllerOptions { sessionId: string | null; @@ -52,11 +57,8 @@ export function useChatSessionController({ const [pendingPersonaId, setPendingPersonaId] = useState(); const [pendingProjectId, setPendingProjectId] = useState(); const [pendingProviderId, setPendingProviderId] = useState(); - const [pendingModelSelection, setPendingModelSelection] = useState<{ - id: string; - name: string; - providerId?: string; - } | null>(); + const [pendingModelSelection, setPendingModelSelection] = + useState(); const pendingDraftValue = useChatStore( (s) => s.draftsBySession[PENDING_HOME_SESSION_ID] ?? "", ); @@ -140,6 +142,7 @@ export function useChatSessionController({ nextProject = project, nextWorkspacePath = activeWorkspace?.path, personaId = selectedPersonaId ?? undefined, + modelSelection?: PreferredModelSelection | null, ) => { if (!sessionId) { return; @@ -149,9 +152,39 @@ export function useChatSessionController({ nextWorkspacePath, ); await acpPrepareSession(sessionId, providerId, workingDir, { personaId }); + if (!modelSelection?.id) { + return; + } + + const sessionStore = useChatSessionStore.getState(); + const liveSession = sessionStore.getSession(sessionId); + const modelAlreadyApplied = + liveSession?.modelId === modelSelection.id && + liveSession?.modelName === modelSelection.name; + + if (modelAlreadyApplied) { + return; + } + + await acpSetModel(sessionId, modelSelection.id); + sessionStore.updateSession(sessionId, { + modelId: modelSelection.id, + modelName: modelSelection.name, + }); }, [activeWorkspace?.path, project, selectedPersonaId, sessionId], ); + const prepareSelectedProvider = useCallback( + (providerId: string, modelSelection?: PreferredModelSelection | null) => + prepareCurrentSession( + providerId, + project, + activeWorkspace?.path, + selectedPersonaId ?? undefined, + modelSelection, + ), + [activeWorkspace?.path, prepareCurrentSession, project, selectedPersonaId], + ); const prevProjectIdRef = useRef(session?.projectId); useEffect(() => { @@ -168,6 +201,27 @@ export function useChatSessionController({ } }, [clearActiveWorkspace, session?.projectId, sessionId]); + const { + selectedAgentId, + pickerAgents, + availableModels, + modelsLoading, + modelStatusMessage, + handleProviderChange, + handleModelChange, + effectiveModelSelection, + } = useResolvedAgentModelPicker({ + providers, + selectedProvider, + sessionId, + session, + pendingModelSelection, + setPendingProviderId, + setPendingModelSelection, + setGlobalSelectedProvider, + prepareSelectedProvider, + }); + const prevWorkspaceRef = useRef(activeWorkspace); useEffect(() => { const previousWorkspace = prevWorkspaceRef.current; @@ -183,114 +237,19 @@ export function useChatSessionController({ if (previousWorkspace?.path === activeWorkspace.path) { return; } - void prepareCurrentSession(selectedProvider).catch((error) => { + void prepareSelectedProvider( + selectedProvider, + effectiveModelSelection, + ).catch((error) => { console.error("Failed to prepare ACP session:", error); }); - }, [activeWorkspace, prepareCurrentSession, selectedProvider, sessionId]); - - const { - selectedAgentId, - pickerAgents, - availableModels, - modelsLoading, - modelStatusMessage, - handleProviderChange, - handleModelChange, - } = useAgentModelPickerState({ - providers, + }, [ + activeWorkspace, + effectiveModelSelection, + prepareSelectedProvider, selectedProvider, - onProviderSelected: (providerId) => { - if (!sessionId) { - setGlobalSelectedProvider(providerId); - setPendingProviderId(providerId); - setPendingModelSelection(null); - return; - } - useChatSessionStore - .getState() - .switchSessionProvider(sessionId, providerId); - setGlobalSelectedProvider(providerId); - void prepareCurrentSession(providerId).catch((error) => { - console.error("Failed to update ACP session provider:", error); - }); - }, - onModelSelected: (model) => { - const modelId = model.id; - const modelName = model.displayName ?? model.name ?? model.id; - const nextProviderId = model.providerId ?? selectedProvider; - - if (!sessionId) { - if (nextProviderId && nextProviderId !== selectedProvider) { - setPendingProviderId(nextProviderId); - setGlobalSelectedProvider(nextProviderId); - } - setPendingModelSelection({ - id: modelId, - name: modelName, - providerId: nextProviderId, - }); - return; - } - if ( - !session || - (modelId === session.modelId && - (!nextProviderId || nextProviderId === session.providerId)) - ) { - return; - } - const previousProviderId = session.providerId; - const previousModelId = session.modelId; - const previousModelName = session.modelName; - const providerChanged = - Boolean(nextProviderId) && nextProviderId !== session.providerId; - - if (providerChanged && nextProviderId) { - useChatSessionStore - .getState() - .switchSessionProvider(sessionId, nextProviderId); - setGlobalSelectedProvider(nextProviderId); - } - - useChatSessionStore.getState().updateSession(sessionId, { - modelId, - modelName, - }); - - void (async () => { - try { - if (providerChanged && nextProviderId) { - await prepareCurrentSession(nextProviderId); - } - await acpSetModel(sessionId, modelId); - } catch (error) { - console.error("Failed to set model:", error); - if (providerChanged && previousProviderId) { - setGlobalSelectedProvider(previousProviderId); - } - useChatSessionStore.getState().updateSession(sessionId, { - providerId: previousProviderId, - modelId: previousModelId, - modelName: previousModelName, - }); - void (async () => { - try { - if (providerChanged && previousProviderId) { - await prepareCurrentSession(previousProviderId); - } - if (previousModelId) { - await acpSetModel(sessionId, previousModelId); - } - } catch (rollbackError) { - console.error( - "Failed to restore previous provider/model after setModel failure:", - rollbackError, - ); - } - })(); - } - })(); - }, - }); + sessionId, + ]); const handleProjectChange = useCallback( (projectId: string | null) => { @@ -310,16 +269,24 @@ export function useChatSessionController({ if (!selectedProvider) { return; } - void prepareCurrentSession(selectedProvider, nextProject).catch( - (error) => { - console.error( - "Failed to update ACP session working directory:", - error, - ); - }, - ); + void prepareCurrentSession( + selectedProvider, + nextProject, + activeWorkspace?.path, + selectedPersonaId ?? undefined, + effectiveModelSelection, + ).catch((error) => { + console.error("Failed to update ACP session working directory:", error); + }); }, - [prepareCurrentSession, selectedProvider, sessionId], + [ + activeWorkspace?.path, + effectiveModelSelection, + prepareCurrentSession, + selectedPersonaId, + selectedProvider, + sessionId, + ], ); const handlePersonaChange = useCallback( @@ -334,7 +301,7 @@ export function useChatSessionController({ if (matchingProvider) { if (!sessionId) { setPendingProviderId(matchingProvider.id); - setPendingModelSelection(null); + setPendingModelSelection(undefined); setGlobalSelectedProvider(matchingProvider.id); } else { handleProviderChange(matchingProvider.id); @@ -399,7 +366,8 @@ export function useChatSessionController({ { onMessageAccepted: sessionId ? onMessageAccepted : undefined, ensurePrepared: selectedProvider - ? () => prepareCurrentSession(selectedProvider) + ? () => + prepareSelectedProvider(selectedProvider, effectiveModelSelection) : undefined, }, ); @@ -569,10 +537,6 @@ export function useChatSessionController({ if (hasPendingProject) { patch.projectId = nextProjectId ?? null; } - if (hasPendingModel) { - patch.modelId = pendingModelSelection?.id; - patch.modelName = pendingModelSelection?.name; - } useChatSessionStore.getState().updateSession(sessionId, patch); @@ -582,15 +546,21 @@ export function useChatSessionController({ nextProject, activeWorkspace?.path, nextPersonaId, + pendingModelSelection, ); if (cancelled) { return; } - if (pendingModelSelection?.id) { - await acpSetModel(sessionId, pendingModelSelection.id); - if (cancelled) { - return; - } + if (pendingModelSelection?.source === "explicit") { + const agentId = + resolveAgentProviderCatalogIdStrict( + pendingModelSelection.providerId ?? nextProviderId, + ) ?? "goose"; + setStoredModelPreference(agentId, { + modelId: pendingModelSelection.id, + modelName: pendingModelSelection.name, + providerId: pendingModelSelection.providerId ?? nextProviderId, + }); } } catch (error) { console.error("Failed to sync pending Home state:", error); @@ -663,14 +633,8 @@ export function useChatSessionController({ providersLoading, selectedProvider: selectedAgentId, handleProviderChange, - currentModelId: - pendingModelSelection !== undefined - ? (pendingModelSelection?.id ?? null) - : (session?.modelId ?? null), - currentModelName: - pendingModelSelection !== undefined - ? (pendingModelSelection?.name ?? null) - : session?.modelName, + currentModelId: effectiveModelSelection?.id ?? null, + currentModelName: effectiveModelSelection?.name ?? null, availableModels, modelsLoading, modelStatusMessage, diff --git a/ui/goose2/src/features/chat/hooks/useResolvedAgentModelPicker.ts b/ui/goose2/src/features/chat/hooks/useResolvedAgentModelPicker.ts new file mode 100644 index 0000000000..de65b34399 --- /dev/null +++ b/ui/goose2/src/features/chat/hooks/useResolvedAgentModelPicker.ts @@ -0,0 +1,483 @@ +import { useEffect, useMemo, useState } from "react"; +import type { AcpProvider } from "@/shared/api/acp"; +import { useProviderInventory } from "@/features/providers/hooks/useProviderInventory"; +import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog"; +import { getClient } from "@/shared/api/acpConnection"; +import { acpSetModel } from "@/shared/api/acp"; +import { + useChatSessionStore, + type ChatSession, +} from "../stores/chatSessionStore"; +import { useAgentModelPickerState } from "./useAgentModelPickerState"; +import { + clearStoredModelPreference, + getStoredModelPreference, + setStoredModelPreference, +} from "../lib/modelPreferences"; + +const GOOSE_PROVIDER_CONFIG_KEY = "GOOSE_PROVIDER"; +const GOOSE_MODEL_CONFIG_KEY = "GOOSE_MODEL"; +const MODEL_ALIAS_IDS = new Set(["current", "default"]); + +export type PreferredModelSelection = { + id: string; + name: string; + providerId?: string; + source: "default" | "explicit"; +}; + +interface UseResolvedAgentModelPickerOptions { + providers: AcpProvider[]; + selectedProvider: string; + sessionId: string | null; + session?: ChatSession; + pendingModelSelection: PreferredModelSelection | null | undefined; + setPendingProviderId: (providerId: string | undefined) => void; + setPendingModelSelection: ( + selection: PreferredModelSelection | null | undefined, + ) => void; + setGlobalSelectedProvider: (providerId: string) => void; + prepareSelectedProvider: ( + providerId: string, + modelSelection?: PreferredModelSelection | null, + ) => Promise; +} + +function isModelAlias(modelId?: string | null): boolean { + return modelId != null && MODEL_ALIAS_IDS.has(modelId); +} + +export function useResolvedAgentModelPicker({ + providers, + selectedProvider, + sessionId, + session, + pendingModelSelection, + setPendingProviderId, + setPendingModelSelection, + setGlobalSelectedProvider, + prepareSelectedProvider, +}: UseResolvedAgentModelPickerOptions) { + const { getEntry: getProviderInventoryEntry } = useProviderInventory(); + const [gooseDefaultSelection, setGooseDefaultSelection] = + useState(null); + + const selectedAgentId = + resolveAgentProviderCatalogIdStrict(selectedProvider) ?? "goose"; + const concreteSelectedProviderId = + resolveAgentProviderCatalogIdStrict(selectedProvider) == null + ? selectedProvider + : null; + const storedModelPreference = useMemo( + () => getStoredModelPreference(selectedAgentId), + [selectedAgentId], + ); + + const getPreferredSelectionForAgent = useMemo( + () => (agentId: string, fallbackProviderId?: string) => { + const preferredModel = getStoredModelPreference(agentId); + if (preferredModel) { + return { + id: preferredModel.modelId, + name: preferredModel.modelName, + providerId: preferredModel.providerId ?? fallbackProviderId, + source: "explicit" as const, + }; + } + + if (agentId === "goose") { + if (!gooseDefaultSelection) { + return null; + } + + return { + ...gooseDefaultSelection, + providerId: gooseDefaultSelection.providerId ?? fallbackProviderId, + }; + } + + const inventoryEntry = getProviderInventoryEntry(agentId); + if (!inventoryEntry?.defaultModel) { + return null; + } + + const resolvedInventoryModel = + inventoryEntry.models.find( + (model) => + model.id === inventoryEntry.defaultModel && !isModelAlias(model.id), + ) ?? + inventoryEntry.models.find((model) => model.recommended) ?? + inventoryEntry.models.find((model) => !isModelAlias(model.id)) ?? + inventoryEntry.models.find( + (model) => model.id === inventoryEntry.defaultModel, + ) ?? + inventoryEntry.models[0]; + + if (resolvedInventoryModel) { + return { + id: resolvedInventoryModel.id, + name: resolvedInventoryModel.name, + providerId: + inventoryEntry.providerId === agentId + ? inventoryEntry.providerId + : fallbackProviderId, + source: "default" as const, + }; + } + + return { + id: inventoryEntry.defaultModel, + name: inventoryEntry.defaultModel, + providerId: + inventoryEntry.providerId === agentId + ? inventoryEntry.providerId + : fallbackProviderId, + source: "default" as const, + }; + }, + [getProviderInventoryEntry, gooseDefaultSelection], + ); + + useEffect(() => { + if (selectedAgentId !== "goose") { + setGooseDefaultSelection(null); + return; + } + + let cancelled = false; + + const loadGooseDefaultSelection = async () => { + try { + const client = await getClient(); + const [providerResponse, modelResponse] = await Promise.all([ + client.goose.GooseConfigRead({ key: GOOSE_PROVIDER_CONFIG_KEY }), + client.goose.GooseConfigRead({ key: GOOSE_MODEL_CONFIG_KEY }), + ]); + + if (cancelled) { + return; + } + + const providerId = + typeof providerResponse.value === "string" + ? providerResponse.value + : undefined; + const modelId = + typeof modelResponse.value === "string" + ? modelResponse.value + : undefined; + + if (!modelId) { + setGooseDefaultSelection(null); + return; + } + + setGooseDefaultSelection({ + id: modelId, + name: modelId, + providerId, + source: "default", + }); + } catch { + if (!cancelled) { + setGooseDefaultSelection(null); + } + } + }; + + void loadGooseDefaultSelection(); + + return () => { + cancelled = true; + }; + }, [selectedAgentId]); + + const { + pickerAgents, + availableModels, + modelsLoading, + modelStatusMessage, + handleProviderChange, + handleModelChange, + } = useAgentModelPickerState({ + providers, + selectedProvider, + onProviderSelected: (providerId) => { + const requestedAgentId = resolveAgentProviderCatalogIdStrict(providerId); + const preferredModelSelection = getPreferredSelectionForAgent( + requestedAgentId ?? "goose", + providerId, + ); + const nextProviderId = requestedAgentId + ? (preferredModelSelection?.providerId ?? providerId) + : providerId; + const nextModelSelection = + !requestedAgentId && + preferredModelSelection?.providerId && + preferredModelSelection.providerId !== providerId + ? undefined + : preferredModelSelection + ? { + ...preferredModelSelection, + providerId: + requestedAgentId == null + ? providerId + : preferredModelSelection.providerId, + } + : undefined; + + if (!sessionId) { + setGlobalSelectedProvider(nextProviderId); + setPendingProviderId(nextProviderId); + setPendingModelSelection(nextModelSelection); + return; + } + + useChatSessionStore + .getState() + .switchSessionProvider(sessionId, nextProviderId); + setGlobalSelectedProvider(nextProviderId); + void prepareSelectedProvider(nextProviderId, nextModelSelection).catch( + (error) => { + console.error("Failed to update ACP session provider:", error); + }, + ); + }, + onModelSelected: (model) => { + const modelId = model.id; + const modelName = model.displayName ?? model.name ?? model.id; + const nextProviderId = model.providerId ?? selectedProvider; + const nextStoredModelPreference = { + modelId, + modelName, + providerId: nextProviderId, + }; + + if (!sessionId) { + if (nextProviderId && nextProviderId !== selectedProvider) { + setPendingProviderId(nextProviderId); + setGlobalSelectedProvider(nextProviderId); + } + setPendingModelSelection({ + id: modelId, + name: modelName, + providerId: nextProviderId, + source: "explicit", + }); + return; + } + + if ( + !session || + (modelId === session.modelId && + (!nextProviderId || nextProviderId === session.providerId)) + ) { + return; + } + + const previousStoredModelPreference = + getStoredModelPreference(selectedAgentId); + const previousProviderId = session.providerId; + const previousModelId = session.modelId; + const previousModelName = session.modelName; + const providerChanged = + Boolean(nextProviderId) && nextProviderId !== session.providerId; + + if (providerChanged && nextProviderId) { + useChatSessionStore + .getState() + .switchSessionProvider(sessionId, nextProviderId); + setGlobalSelectedProvider(nextProviderId); + } + + useChatSessionStore.getState().updateSession(sessionId, { + modelId, + modelName, + }); + + void (async () => { + try { + if (providerChanged && nextProviderId) { + await prepareSelectedProvider(nextProviderId); + } + await acpSetModel(sessionId, modelId); + setStoredModelPreference(selectedAgentId, nextStoredModelPreference); + } catch (error) { + console.error("Failed to set model:", error); + if (providerChanged && previousProviderId) { + setGlobalSelectedProvider(previousProviderId); + } + if (previousStoredModelPreference) { + setStoredModelPreference( + selectedAgentId, + previousStoredModelPreference, + ); + } else { + clearStoredModelPreference(selectedAgentId); + } + useChatSessionStore.getState().updateSession(sessionId, { + providerId: previousProviderId, + modelId: previousModelId, + modelName: previousModelName, + }); + void (async () => { + try { + if (providerChanged && previousProviderId) { + await prepareSelectedProvider(previousProviderId); + } + if (previousModelId) { + await acpSetModel(sessionId, previousModelId); + } + } catch (rollbackError) { + console.error( + "Failed to restore previous provider/model after setModel failure:", + rollbackError, + ); + } + })(); + } + })(); + }, + }); + + const preferredModelSelection = + useMemo(() => { + if (storedModelPreference) { + const matchingStoredModel = + availableModels.find( + (model) => + model.id === storedModelPreference.modelId && + (!storedModelPreference.providerId || + !model.providerId || + model.providerId === storedModelPreference.providerId) && + (!concreteSelectedProviderId || + !model.providerId || + model.providerId === concreteSelectedProviderId), + ) ?? null; + const storedSelectionCompatible = + !concreteSelectedProviderId || + storedModelPreference.providerId === concreteSelectedProviderId; + + if ( + matchingStoredModel || + ((availableModels.length === 0 || modelsLoading) && + storedSelectionCompatible) + ) { + return { + id: storedModelPreference.modelId, + name: + matchingStoredModel?.displayName ?? + matchingStoredModel?.name ?? + storedModelPreference.modelName, + providerId: + matchingStoredModel?.providerId ?? + storedModelPreference.providerId, + source: "explicit", + }; + } + } + + const inventoryDefaultSelection = getPreferredSelectionForAgent( + selectedAgentId, + selectedProvider, + ); + + if (!inventoryDefaultSelection) { + return null; + } + + const matchingDefaultModel = + availableModels.find( + (model) => + model.id === inventoryDefaultSelection.id && + (!inventoryDefaultSelection.providerId || + !model.providerId || + model.providerId === inventoryDefaultSelection.providerId) && + (!concreteSelectedProviderId || + !model.providerId || + model.providerId === concreteSelectedProviderId), + ) ?? null; + const defaultSelectionCompatible = + !concreteSelectedProviderId || + inventoryDefaultSelection.providerId === concreteSelectedProviderId; + + if (!matchingDefaultModel && !defaultSelectionCompatible) { + return null; + } + + return { + id: inventoryDefaultSelection.id, + name: + matchingDefaultModel?.displayName ?? + matchingDefaultModel?.name ?? + inventoryDefaultSelection.name, + providerId: + matchingDefaultModel?.providerId ?? + inventoryDefaultSelection.providerId, + source: "default", + }; + }, [ + availableModels, + getPreferredSelectionForAgent, + modelsLoading, + concreteSelectedProviderId, + selectedProvider, + selectedAgentId, + storedModelPreference, + ]); + + const sessionModelSelection = useMemo(() => { + if (!session?.modelId) { + return null; + } + + const matchingSessionModel = + availableModels.find( + (model) => + model.id === session.modelId && + (!session.providerId || + !model.providerId || + model.providerId === session.providerId), + ) ?? null; + + if (matchingSessionModel) { + return { + id: matchingSessionModel.id, + name: + matchingSessionModel.displayName ?? + matchingSessionModel.name ?? + session.modelName ?? + session.modelId, + providerId: matchingSessionModel.providerId ?? session.providerId, + source: "explicit", + }; + } + + if (isModelAlias(session.modelId)) { + return null; + } + + return { + id: session.modelId, + name: session.modelName ?? session.modelId, + providerId: session.providerId, + source: "explicit", + }; + }, [availableModels, session]); + + const effectiveModelSelection = + pendingModelSelection !== undefined + ? pendingModelSelection + : (sessionModelSelection ?? preferredModelSelection); + + return { + selectedAgentId, + pickerAgents, + availableModels, + modelsLoading, + modelStatusMessage, + handleProviderChange, + handleModelChange, + effectiveModelSelection, + }; +} diff --git a/ui/goose2/src/features/chat/lib/modelPreferences.ts b/ui/goose2/src/features/chat/lib/modelPreferences.ts new file mode 100644 index 0000000000..7055deb894 --- /dev/null +++ b/ui/goose2/src/features/chat/lib/modelPreferences.ts @@ -0,0 +1,86 @@ +import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog"; + +const MODEL_PREFERENCES_STORAGE_KEY = "goose:preferredModelsByAgent"; + +export interface StoredModelPreference { + modelId: string; + modelName: string; + providerId?: string; +} + +type StoredModelPreferences = Record; + +function readStoredModelPreferences(): StoredModelPreferences { + if (typeof window === "undefined") { + return {}; + } + + try { + const stored = window.localStorage.getItem(MODEL_PREFERENCES_STORAGE_KEY); + if (!stored) { + return {}; + } + + const parsed = JSON.parse(stored); + if (!parsed || typeof parsed !== "object") { + return {}; + } + + return parsed as StoredModelPreferences; + } catch { + return {}; + } +} + +function persistStoredModelPreferences( + preferences: StoredModelPreferences, +): void { + if (typeof window === "undefined") { + return; + } + + try { + if (Object.keys(preferences).length === 0) { + window.localStorage.removeItem(MODEL_PREFERENCES_STORAGE_KEY); + return; + } + + window.localStorage.setItem( + MODEL_PREFERENCES_STORAGE_KEY, + JSON.stringify(preferences), + ); + } catch { + // localStorage may be unavailable + } +} + +export function getStoredModelPreference( + agentId: string, +): StoredModelPreference | null { + return readStoredModelPreferences()[agentId] ?? null; +} + +export function getStoredModelPreferenceForProvider( + providerId: string, +): StoredModelPreference | null { + const agentId = resolveAgentProviderCatalogIdStrict(providerId) ?? "goose"; + return getStoredModelPreference(agentId); +} + +export function setStoredModelPreference( + agentId: string, + preference: StoredModelPreference, +): void { + const next = readStoredModelPreferences(); + next[agentId] = preference; + persistStoredModelPreferences(next); +} + +export function clearStoredModelPreference(agentId: string): void { + const next = readStoredModelPreferences(); + if (!(agentId in next)) { + return; + } + delete next[agentId]; + persistStoredModelPreferences(next); +} diff --git a/ui/goose2/src/features/chat/lib/sessionModelPreference.test.ts b/ui/goose2/src/features/chat/lib/sessionModelPreference.test.ts new file mode 100644 index 0000000000..69c2ddd94e --- /dev/null +++ b/ui/goose2/src/features/chat/lib/sessionModelPreference.test.ts @@ -0,0 +1,114 @@ +import { beforeEach, describe, expect, it } from "vitest"; +import { + resolveSessionModelPreference, + sanitizeSessionModelPreference, +} from "./sessionModelPreference"; + +describe("resolveSessionModelPreference", () => { + beforeEach(() => { + window.localStorage.clear(); + }); + + it("keeps a requested concrete provider when the stored preference uses a different provider", () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + providerId: "anthropic", + }, + }), + ); + + expect( + resolveSessionModelPreference({ + providerId: "openai", + }), + ).toEqual({ + providerId: "openai", + }); + }); + + it("reuses a stored model when it matches the requested concrete provider", () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "gpt-5.4", + modelName: "GPT-5.4", + providerId: "openai", + }, + }), + ); + + expect( + resolveSessionModelPreference({ + providerId: "openai", + }), + ).toEqual({ + providerId: "openai", + modelId: "gpt-5.4", + modelName: "GPT-5.4", + }); + }); + + it("resolves an agent provider to the stored concrete provider and model", () => { + window.localStorage.setItem( + "goose:preferredModelsByAgent", + JSON.stringify({ + goose: { + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + providerId: "anthropic", + }, + }), + ); + + expect( + resolveSessionModelPreference({ + providerId: "goose", + }), + ).toEqual({ + providerId: "anthropic", + modelId: "claude-sonnet-4", + modelName: "Claude Sonnet 4", + }); + }); + + it("keeps a stored model when the provider inventory still contains it", () => { + expect( + sanitizeSessionModelPreference( + { + providerId: "openai", + modelId: "gpt-5.4", + modelName: "GPT-5.4", + }, + { + models: [{ id: "gpt-5.4" }, { id: "gpt-5.4-mini" }], + }, + ), + ).toEqual({ + providerId: "openai", + modelId: "gpt-5.4", + modelName: "GPT-5.4", + }); + }); + + it("drops a stored model when the provider inventory no longer contains it", () => { + expect( + sanitizeSessionModelPreference( + { + providerId: "openai", + modelId: "gpt-4.1", + modelName: "GPT-4.1", + }, + { + models: [{ id: "gpt-5.4" }, { id: "gpt-5.4-mini" }], + }, + ), + ).toEqual({ + providerId: "openai", + }); + }); +}); diff --git a/ui/goose2/src/features/chat/lib/sessionModelPreference.ts b/ui/goose2/src/features/chat/lib/sessionModelPreference.ts new file mode 100644 index 0000000000..fc60b4bbc5 --- /dev/null +++ b/ui/goose2/src/features/chat/lib/sessionModelPreference.ts @@ -0,0 +1,75 @@ +import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog"; +import { getStoredModelPreferenceForProvider } from "./modelPreferences"; + +interface SessionModelPreferenceOptions { + providerId: string; + preferredModel?: string; +} + +export interface SessionModelPreference { + providerId: string; + modelId?: string; + modelName?: string; +} + +interface ProviderInventoryEntryLike { + models: Array<{ + id: string; + }>; +} + +export function resolveSessionModelPreference({ + providerId, + preferredModel, +}: SessionModelPreferenceOptions): SessionModelPreference { + if (preferredModel) { + return { + providerId, + modelId: preferredModel, + modelName: preferredModel, + }; + } + + const storedModelPreference = getStoredModelPreferenceForProvider(providerId); + if (!storedModelPreference) { + return { providerId }; + } + + if (resolveAgentProviderCatalogIdStrict(providerId)) { + return { + providerId: storedModelPreference.providerId ?? providerId, + modelId: storedModelPreference.modelId, + modelName: storedModelPreference.modelName, + }; + } + + if ( + storedModelPreference.providerId && + storedModelPreference.providerId !== providerId + ) { + return { providerId }; + } + + return { + providerId, + modelId: storedModelPreference.modelId, + modelName: storedModelPreference.modelName, + }; +} + +export function sanitizeSessionModelPreference( + preference: SessionModelPreference, + inventoryEntry?: ProviderInventoryEntryLike | null, +): SessionModelPreference { + if (!preference.modelId || !inventoryEntry) { + return preference; + } + + if (inventoryEntry.models.some((model) => model.id === preference.modelId)) { + return preference; + } + + return { + providerId: preference.providerId, + }; +}