mirror of
https://github.com/block/goose.git
synced 2026-04-26 10:40:45 +00:00
persist and reliably apply chat model selection (#8734)
Signed-off-by: morgmart <98432065+morgmart@users.noreply.github.com>
This commit is contained in:
parent
38941b1d26
commit
7325fbdae3
12 changed files with 1734 additions and 210 deletions
|
|
@ -10,7 +10,6 @@ import { TopBar } from "./ui/TopBar";
|
|||
import { useChatStore } from "@/features/chat/stores/chatStore";
|
||||
import {
|
||||
type ChatSession,
|
||||
hasSessionStarted,
|
||||
useChatSessionStore,
|
||||
} from "@/features/chat/stores/chatSessionStore";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
|
|
@ -18,14 +17,18 @@ import { useProjectStore } from "@/features/projects/stores/projectStore";
|
|||
import { findExistingDraft } from "@/features/chat/lib/newChat";
|
||||
import { DEFAULT_CHAT_TITLE } from "@/features/chat/lib/sessionTitle";
|
||||
import { useAppStartup } from "./hooks/useAppStartup";
|
||||
import { useHomeSessionStateSync } from "./hooks/useHomeSessionStateSync";
|
||||
import { loadStoredHomeSessionId } from "./lib/homeSessionStorage";
|
||||
import { resolveSupportedSessionModelPreference } from "./lib/resolveSupportedSessionModelPreference";
|
||||
import { AppShellContent } from "./ui/AppShellContent";
|
||||
import { acpPrepareSession } from "@/shared/api/acp";
|
||||
import { acpPrepareSession, acpSetModel } from "@/shared/api/acp";
|
||||
import {
|
||||
clearReplayBuffer,
|
||||
getAndDeleteReplayBuffer,
|
||||
} from "@/features/chat/hooks/replayBuffer";
|
||||
import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection";
|
||||
import { perfLog } from "@/shared/lib/perfLog";
|
||||
import { useProviderInventoryStore } from "@/features/providers/stores/providerInventoryStore";
|
||||
|
||||
export type AppView =
|
||||
| "home"
|
||||
|
|
@ -40,34 +43,6 @@ const SIDEBAR_MIN_WIDTH = 180;
|
|||
const SIDEBAR_MAX_WIDTH = 380;
|
||||
const SIDEBAR_SNAP_COLLAPSE_THRESHOLD = 100;
|
||||
const SIDEBAR_COLLAPSED_WIDTH = 48;
|
||||
const HOME_SESSION_STORAGE_KEY = "goose:home-session-id";
|
||||
|
||||
function loadStoredHomeSessionId(): string | null {
|
||||
if (typeof window === "undefined") {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return window.localStorage.getItem(HOME_SESSION_STORAGE_KEY);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function persistHomeSessionId(sessionId: string | null): void {
|
||||
if (typeof window === "undefined") {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (sessionId) {
|
||||
window.localStorage.setItem(HOME_SESSION_STORAGE_KEY, sessionId);
|
||||
return;
|
||||
}
|
||||
window.localStorage.removeItem(HOME_SESSION_STORAGE_KEY);
|
||||
} catch {
|
||||
// localStorage may be unavailable
|
||||
}
|
||||
}
|
||||
|
||||
export function AppShell({ children }: { children?: React.ReactNode }) {
|
||||
const [sidebarCollapsed, setSidebarCollapsed] = useState(false);
|
||||
const [sidebarWidth, setSidebarWidth] = useState(SIDEBAR_DEFAULT_WIDTH);
|
||||
|
|
@ -90,6 +65,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
const sessionStore = useChatSessionStore();
|
||||
const agentStore = useAgentStore();
|
||||
const projectStore = useProjectStore();
|
||||
const providerInventoryEntries = useProviderInventoryStore((s) => s.entries);
|
||||
|
||||
const pendingProjectCreatedRef = useRef<((projectId: string) => void) | null>(
|
||||
null,
|
||||
|
|
@ -173,37 +149,14 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
? sessionStore.getSession(homeSessionId)
|
||||
: undefined;
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
!homeSessionId ||
|
||||
!sessionStore.hasHydratedSessions ||
|
||||
sessionStore.isLoading
|
||||
) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
!homeSession ||
|
||||
homeSession.archivedAt ||
|
||||
hasSessionStarted(
|
||||
homeSession,
|
||||
chatStore.messagesBySession[homeSession.id],
|
||||
)
|
||||
) {
|
||||
setHomeSessionId(null);
|
||||
}
|
||||
}, [
|
||||
chatStore.messagesBySession,
|
||||
homeSession,
|
||||
homeSession?.archivedAt,
|
||||
homeSession?.messageCount,
|
||||
useHomeSessionStateSync({
|
||||
homeSessionId,
|
||||
sessionStore.hasHydratedSessions,
|
||||
sessionStore.isLoading,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
persistHomeSessionId(homeSessionId);
|
||||
}, [homeSessionId]);
|
||||
homeSession,
|
||||
messagesBySession: chatStore.messagesBySession,
|
||||
hasHydratedSessions: sessionStore.hasHydratedSessions,
|
||||
isLoading: sessionStore.isLoading,
|
||||
setHomeSessionId,
|
||||
});
|
||||
|
||||
const ensureHomeSession = useCallback(async () => {
|
||||
if (!sessionStore.hasHydratedSessions || sessionStore.isLoading) {
|
||||
|
|
@ -220,6 +173,11 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
!homeSession.archivedAt &&
|
||||
homeSession.messageCount === 0
|
||||
) {
|
||||
const sessionModelPreference =
|
||||
await resolveSupportedSessionModelPreference(
|
||||
agentStore.selectedProvider ?? "goose",
|
||||
providerInventoryEntries,
|
||||
);
|
||||
const project = homeSession.projectId
|
||||
? (projectStore.projects.find(
|
||||
(candidate) => candidate.id === homeSession.projectId,
|
||||
|
|
@ -228,20 +186,42 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
const workingDir = await resolveSessionCwd(project);
|
||||
await acpPrepareSession(
|
||||
homeSession.id,
|
||||
homeSession.providerId ?? agentStore.selectedProvider ?? "goose",
|
||||
sessionModelPreference.providerId,
|
||||
workingDir,
|
||||
{
|
||||
personaId: homeSession.personaId,
|
||||
},
|
||||
);
|
||||
const shouldClearHomeModel =
|
||||
sessionModelPreference.providerId !== homeSession.providerId ||
|
||||
!sessionModelPreference.modelId;
|
||||
sessionStore.updateSession(homeSession.id, {
|
||||
providerId: sessionModelPreference.providerId,
|
||||
modelId: shouldClearHomeModel ? undefined : homeSession.modelId,
|
||||
modelName: shouldClearHomeModel ? undefined : homeSession.modelName,
|
||||
});
|
||||
if (sessionModelPreference.modelId) {
|
||||
await acpSetModel(homeSession.id, sessionModelPreference.modelId);
|
||||
sessionStore.updateSession(homeSession.id, {
|
||||
modelId: sessionModelPreference.modelId,
|
||||
modelName: sessionModelPreference.modelName,
|
||||
});
|
||||
}
|
||||
return homeSession;
|
||||
}
|
||||
|
||||
const workingDir = await resolveSessionCwd(null);
|
||||
const sessionModelPreference =
|
||||
await resolveSupportedSessionModelPreference(
|
||||
agentStore.selectedProvider ?? "goose",
|
||||
providerInventoryEntries,
|
||||
);
|
||||
const session = await sessionStore.createSession({
|
||||
title: DEFAULT_CHAT_TITLE,
|
||||
providerId: agentStore.selectedProvider ?? "goose",
|
||||
providerId: sessionModelPreference.providerId,
|
||||
workingDir,
|
||||
modelId: sessionModelPreference.modelId,
|
||||
modelName: sessionModelPreference.modelName,
|
||||
});
|
||||
setHomeSessionId(session.id);
|
||||
return session;
|
||||
|
|
@ -258,6 +238,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
}, [
|
||||
agentStore.selectedProvider,
|
||||
homeSession,
|
||||
providerInventoryEntries,
|
||||
projectStore.projects,
|
||||
sessionStore.hasHydratedSessions,
|
||||
sessionStore,
|
||||
|
|
@ -282,7 +263,12 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
const agentId = agentStore.activeAgentId ?? undefined;
|
||||
const providerId =
|
||||
project?.preferredProvider ?? agentStore.selectedProvider ?? "goose";
|
||||
const modelId = project?.preferredModel ?? undefined;
|
||||
const sessionModelPreference =
|
||||
await resolveSupportedSessionModelPreference(
|
||||
providerId,
|
||||
providerInventoryEntries,
|
||||
project?.preferredModel ?? undefined,
|
||||
);
|
||||
const sessionState = useChatSessionStore.getState();
|
||||
const chatState = useChatStore.getState();
|
||||
const existingDraft = findExistingDraft({
|
||||
|
|
@ -311,10 +297,10 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
title,
|
||||
projectId: project?.id,
|
||||
agentId,
|
||||
providerId,
|
||||
providerId: sessionModelPreference.providerId,
|
||||
workingDir,
|
||||
modelId,
|
||||
modelName: modelId,
|
||||
modelId: sessionModelPreference.modelId,
|
||||
modelName: sessionModelPreference.modelName,
|
||||
});
|
||||
sessionStore.setActiveSession(session.id);
|
||||
setActiveView("chat");
|
||||
|
|
@ -328,6 +314,7 @@ export function AppShell({ children }: { children?: React.ReactNode }) {
|
|||
agentStore.activeAgentId,
|
||||
agentStore.selectedProvider,
|
||||
chatStore,
|
||||
providerInventoryEntries,
|
||||
sessionStore,
|
||||
],
|
||||
);
|
||||
|
|
|
|||
51
ui/goose2/src/app/hooks/useHomeSessionStateSync.ts
Normal file
51
ui/goose2/src/app/hooks/useHomeSessionStateSync.ts
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
import { useEffect } from "react";
|
||||
import {
|
||||
hasSessionStarted,
|
||||
type ChatSession,
|
||||
} from "@/features/chat/stores/chatSessionStore";
|
||||
import { persistHomeSessionId } from "../lib/homeSessionStorage";
|
||||
|
||||
interface UseHomeSessionStateSyncOptions {
|
||||
homeSessionId: string | null;
|
||||
homeSession?: ChatSession;
|
||||
messagesBySession: Record<string, ArrayLike<unknown> | undefined>;
|
||||
hasHydratedSessions: boolean;
|
||||
isLoading: boolean;
|
||||
setHomeSessionId: (sessionId: string | null) => void;
|
||||
}
|
||||
|
||||
export function useHomeSessionStateSync({
|
||||
homeSessionId,
|
||||
homeSession,
|
||||
messagesBySession,
|
||||
hasHydratedSessions,
|
||||
isLoading,
|
||||
setHomeSessionId,
|
||||
}: UseHomeSessionStateSyncOptions): void {
|
||||
useEffect(() => {
|
||||
if (!homeSessionId || !hasHydratedSessions || isLoading) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
!homeSession ||
|
||||
homeSession.archivedAt ||
|
||||
hasSessionStarted(homeSession, messagesBySession[homeSession.id])
|
||||
) {
|
||||
setHomeSessionId(null);
|
||||
}
|
||||
}, [
|
||||
hasHydratedSessions,
|
||||
homeSession,
|
||||
homeSession?.archivedAt,
|
||||
homeSession?.messageCount,
|
||||
homeSessionId,
|
||||
isLoading,
|
||||
messagesBySession,
|
||||
setHomeSessionId,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
persistHomeSessionId(homeSessionId);
|
||||
}, [homeSessionId]);
|
||||
}
|
||||
27
ui/goose2/src/app/lib/homeSessionStorage.ts
Normal file
27
ui/goose2/src/app/lib/homeSessionStorage.ts
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
const HOME_SESSION_STORAGE_KEY = "goose:home-session-id";
|
||||
|
||||
export function loadStoredHomeSessionId(): string | null {
|
||||
if (typeof window === "undefined") {
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return window.localStorage.getItem(HOME_SESSION_STORAGE_KEY);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function persistHomeSessionId(sessionId: string | null): void {
|
||||
if (typeof window === "undefined") {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (sessionId) {
|
||||
window.localStorage.setItem(HOME_SESSION_STORAGE_KEY, sessionId);
|
||||
return;
|
||||
}
|
||||
window.localStorage.removeItem(HOME_SESSION_STORAGE_KEY);
|
||||
} catch {
|
||||
// localStorage may be unavailable
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { resolveSupportedSessionModelPreference } from "./resolveSupportedSessionModelPreference";
|
||||
|
||||
const mockGetProviderInventory = vi.fn();
|
||||
|
||||
vi.mock("@/features/providers/api/inventory", () => ({
|
||||
getProviderInventory: (...args: unknown[]) =>
|
||||
mockGetProviderInventory(...args),
|
||||
}));
|
||||
|
||||
describe("resolveSupportedSessionModelPreference", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
window.localStorage.clear();
|
||||
});
|
||||
|
||||
it("drops the model when provider inventory lookup fails", async () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "gpt-5.4",
|
||||
modelName: "GPT-5.4",
|
||||
providerId: "openai",
|
||||
},
|
||||
}),
|
||||
);
|
||||
mockGetProviderInventory.mockRejectedValue(
|
||||
new Error("inventory unavailable"),
|
||||
);
|
||||
|
||||
await expect(
|
||||
resolveSupportedSessionModelPreference("goose", new Map()),
|
||||
).resolves.toEqual({
|
||||
providerId: "openai",
|
||||
});
|
||||
});
|
||||
|
||||
it("drops the model when provider inventory has no matching entry", async () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "gpt-5.4",
|
||||
modelName: "GPT-5.4",
|
||||
providerId: "openai",
|
||||
},
|
||||
}),
|
||||
);
|
||||
mockGetProviderInventory.mockResolvedValue([]);
|
||||
|
||||
await expect(
|
||||
resolveSupportedSessionModelPreference("goose", new Map()),
|
||||
).resolves.toEqual({
|
||||
providerId: "openai",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
import type { ProviderInventoryEntryDto } from "@aaif/goose-sdk";
|
||||
import { getProviderInventory } from "@/features/providers/api/inventory";
|
||||
import {
|
||||
resolveSessionModelPreference,
|
||||
sanitizeSessionModelPreference,
|
||||
type SessionModelPreference,
|
||||
} from "@/features/chat/lib/sessionModelPreference";
|
||||
|
||||
export async function resolveSupportedSessionModelPreference(
|
||||
providerId: string,
|
||||
inventoryEntries: Map<string, ProviderInventoryEntryDto>,
|
||||
preferredModel?: string,
|
||||
): Promise<SessionModelPreference> {
|
||||
const sessionModelPreference = resolveSessionModelPreference({
|
||||
providerId,
|
||||
preferredModel,
|
||||
});
|
||||
|
||||
if (!sessionModelPreference.modelId) {
|
||||
return sessionModelPreference;
|
||||
}
|
||||
|
||||
const inventoryEntry =
|
||||
inventoryEntries.get(sessionModelPreference.providerId) ??
|
||||
(await getProviderInventory([sessionModelPreference.providerId])
|
||||
.then(([entry]) => entry)
|
||||
.catch(() => undefined));
|
||||
|
||||
if (!inventoryEntry) {
|
||||
return {
|
||||
providerId: sessionModelPreference.providerId,
|
||||
};
|
||||
}
|
||||
|
||||
return sanitizeSessionModelPreference(sessionModelPreference, inventoryEntry);
|
||||
}
|
||||
|
|
@ -9,12 +9,37 @@ const mockAcpPrepareSession = vi.fn();
|
|||
const mockAcpSetModel = vi.fn();
|
||||
const mockSetSelectedProvider = vi.fn();
|
||||
const mockResolveSessionCwd = vi.fn();
|
||||
const mockGooseConfigRead = vi.fn();
|
||||
const mockUseProviderInventory = vi.fn();
|
||||
const mockPickerState = {
|
||||
pickerAgents: [{ id: "goose", label: "Goose" }],
|
||||
availableModels: [] as Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
displayName?: string;
|
||||
providerId?: string;
|
||||
}>,
|
||||
modelsLoading: false,
|
||||
modelStatusMessage: null as string | null,
|
||||
};
|
||||
|
||||
vi.mock("@/shared/api/acp", () => ({
|
||||
acpPrepareSession: (...args: unknown[]) => mockAcpPrepareSession(...args),
|
||||
acpSetModel: (...args: unknown[]) => mockAcpSetModel(...args),
|
||||
}));
|
||||
|
||||
vi.mock("@/shared/api/acpConnection", () => ({
|
||||
getClient: async () => ({
|
||||
goose: {
|
||||
GooseConfigRead: (...args: unknown[]) => mockGooseConfigRead(...args),
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("@/features/providers/hooks/useProviderInventory", () => ({
|
||||
useProviderInventory: () => mockUseProviderInventory(),
|
||||
}));
|
||||
|
||||
vi.mock("../useChat", () => ({
|
||||
useChat: () => ({
|
||||
messages: [],
|
||||
|
|
@ -41,7 +66,7 @@ vi.mock("@/features/agents/hooks/useProviderSelection", () => ({
|
|||
{ id: "anthropic", label: "Anthropic" },
|
||||
],
|
||||
providersLoading: false,
|
||||
selectedProvider: "openai",
|
||||
selectedProvider: useAgentStore.getState().selectedProvider ?? "openai",
|
||||
setSelectedProvider: (...args: unknown[]) =>
|
||||
mockSetSelectedProvider(...args),
|
||||
}),
|
||||
|
|
@ -63,10 +88,10 @@ vi.mock("../useAgentModelPickerState", () => ({
|
|||
}) => void;
|
||||
}) => ({
|
||||
selectedAgentId: "goose",
|
||||
pickerAgents: [{ id: "goose", label: "Goose" }],
|
||||
availableModels: [],
|
||||
modelsLoading: false,
|
||||
modelStatusMessage: null,
|
||||
pickerAgents: mockPickerState.pickerAgents,
|
||||
availableModels: mockPickerState.availableModels,
|
||||
modelsLoading: mockPickerState.modelsLoading,
|
||||
modelStatusMessage: mockPickerState.modelStatusMessage,
|
||||
handleProviderChange: vi.fn(),
|
||||
handleModelChange: (modelId: string) => {
|
||||
if (modelId === "claude-sonnet-4") {
|
||||
|
|
@ -86,9 +111,18 @@ import { useChatSessionController } from "../useChatSessionController";
|
|||
describe("useChatSessionController", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
window.localStorage.clear();
|
||||
mockAcpPrepareSession.mockResolvedValue(undefined);
|
||||
mockAcpSetModel.mockResolvedValue(undefined);
|
||||
mockResolveSessionCwd.mockResolvedValue("/tmp/project");
|
||||
mockGooseConfigRead.mockResolvedValue({ value: null });
|
||||
mockUseProviderInventory.mockReturnValue({
|
||||
getEntry: () => undefined,
|
||||
});
|
||||
mockPickerState.pickerAgents = [{ id: "goose", label: "Goose" }];
|
||||
mockPickerState.availableModels = [];
|
||||
mockPickerState.modelsLoading = false;
|
||||
mockPickerState.modelStatusMessage = null;
|
||||
|
||||
useAgentStore.setState({
|
||||
personas: [],
|
||||
|
|
@ -178,4 +212,217 @@ describe("useChatSessionController", () => {
|
|||
modelName: "Claude Sonnet 4",
|
||||
});
|
||||
});
|
||||
|
||||
it("restores the previous stored model preference when setting a model fails", async () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "gpt-4o",
|
||||
modelName: "GPT-4o",
|
||||
providerId: "openai",
|
||||
},
|
||||
}),
|
||||
);
|
||||
mockAcpSetModel.mockRejectedValueOnce(new Error("set model failed"));
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useChatSessionController({ sessionId: "session-1" }),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleModelChange("claude-sonnet-4");
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
useChatSessionStore.getState().getSession("session-1"),
|
||||
).toMatchObject({
|
||||
providerId: "openai",
|
||||
modelId: "gpt-4o",
|
||||
modelName: "GPT-4o",
|
||||
});
|
||||
});
|
||||
|
||||
expect(
|
||||
JSON.parse(
|
||||
window.localStorage.getItem("goose:preferredModelsByAgent") ?? "{}",
|
||||
),
|
||||
).toEqual({
|
||||
goose: {
|
||||
modelId: "gpt-4o",
|
||||
modelName: "GPT-4o",
|
||||
providerId: "openai",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("shows the stored explicit model for new chats", async () => {
|
||||
useAgentStore.setState({ selectedProvider: "goose" });
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useChatSessionController({ sessionId: null }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.currentModelId).toBe("claude-sonnet-4");
|
||||
});
|
||||
expect(result.current.currentModelName).toBe("Claude Sonnet 4");
|
||||
});
|
||||
|
||||
it("falls back to the configured goose default model when no explicit model is stored", async () => {
|
||||
useAgentStore.setState({ selectedProvider: "goose" });
|
||||
mockGooseConfigRead.mockImplementation(
|
||||
async ({ key }: { key: string }): Promise<{ value: string | null }> => {
|
||||
if (key === "GOOSE_PROVIDER") {
|
||||
return { value: "databricks" };
|
||||
}
|
||||
if (key === "GOOSE_MODEL") {
|
||||
return { value: "goose-claude-4-6-opus" };
|
||||
}
|
||||
return { value: null };
|
||||
},
|
||||
);
|
||||
mockPickerState.availableModels = [
|
||||
{
|
||||
id: "goose-claude-4-6-opus",
|
||||
name: "Claude 4.6 Opus",
|
||||
providerId: "databricks",
|
||||
},
|
||||
];
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useChatSessionController({ sessionId: null }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.currentModelId).toBe("goose-claude-4-6-opus");
|
||||
});
|
||||
expect(result.current.currentModelName).toBe("Claude 4.6 Opus");
|
||||
});
|
||||
|
||||
it("applies the pending Home model to ACP when a real session becomes active", async () => {
|
||||
const { result, rerender } = renderHook(
|
||||
({ sessionId }: { sessionId: string | null }) =>
|
||||
useChatSessionController({ sessionId }),
|
||||
{
|
||||
initialProps: { sessionId: null as string | null },
|
||||
},
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleModelChange("claude-sonnet-4");
|
||||
});
|
||||
|
||||
useChatSessionStore.setState((state) => ({
|
||||
sessions: [
|
||||
{
|
||||
id: "session-2",
|
||||
title: "Chat",
|
||||
providerId: "openai",
|
||||
createdAt: "2026-04-21T00:00:00.000Z",
|
||||
updatedAt: "2026-04-21T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
},
|
||||
...state.sessions,
|
||||
],
|
||||
}));
|
||||
|
||||
rerender({ sessionId: "session-2" });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockAcpPrepareSession).toHaveBeenCalledWith(
|
||||
"session-2",
|
||||
"anthropic",
|
||||
"/tmp/project",
|
||||
{ personaId: undefined },
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockAcpSetModel).toHaveBeenCalledWith(
|
||||
"session-2",
|
||||
"claude-sonnet-4",
|
||||
);
|
||||
});
|
||||
|
||||
expect(
|
||||
useChatSessionStore.getState().getSession("session-2"),
|
||||
).toMatchObject({
|
||||
providerId: "anthropic",
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
});
|
||||
});
|
||||
|
||||
it("does not persist or record a pending Home model when ACP rejects it", async () => {
|
||||
mockAcpSetModel.mockRejectedValueOnce(new Error("set model failed"));
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ sessionId }: { sessionId: string | null }) =>
|
||||
useChatSessionController({ sessionId }),
|
||||
{
|
||||
initialProps: { sessionId: null as string | null },
|
||||
},
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleModelChange("claude-sonnet-4");
|
||||
});
|
||||
|
||||
expect(
|
||||
window.localStorage.getItem("goose:preferredModelsByAgent"),
|
||||
).toBeNull();
|
||||
|
||||
useChatSessionStore.setState((state) => ({
|
||||
sessions: [
|
||||
{
|
||||
id: "session-3",
|
||||
title: "Chat",
|
||||
providerId: "openai",
|
||||
createdAt: "2026-04-21T00:00:00.000Z",
|
||||
updatedAt: "2026-04-21T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
},
|
||||
...state.sessions,
|
||||
],
|
||||
}));
|
||||
|
||||
rerender({ sessionId: "session-3" });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockAcpSetModel).toHaveBeenCalledWith(
|
||||
"session-3",
|
||||
"claude-sonnet-4",
|
||||
);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
useChatSessionStore.getState().getSession("session-3"),
|
||||
).toMatchObject({
|
||||
providerId: "anthropic",
|
||||
});
|
||||
});
|
||||
|
||||
expect(
|
||||
useChatSessionStore.getState().getSession("session-3"),
|
||||
).not.toMatchObject({
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
});
|
||||
expect(
|
||||
window.localStorage.getItem("goose:preferredModelsByAgent"),
|
||||
).toBeNull();
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,396 @@
|
|||
import { act, renderHook } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { useResolvedAgentModelPicker } from "../useResolvedAgentModelPicker";
|
||||
|
||||
const mockUseProviderInventory = vi.fn();
|
||||
const mockUseAgentModelPickerState = vi.fn();
|
||||
const mockGetClient = vi.fn();
|
||||
|
||||
vi.mock("@/features/providers/hooks/useProviderInventory", () => ({
|
||||
useProviderInventory: () => mockUseProviderInventory(),
|
||||
}));
|
||||
|
||||
vi.mock("../useAgentModelPickerState", () => ({
|
||||
useAgentModelPickerState: (args: unknown) =>
|
||||
mockUseAgentModelPickerState(args),
|
||||
}));
|
||||
|
||||
vi.mock("@/shared/api/acpConnection", () => ({
|
||||
getClient: (...args: unknown[]) => mockGetClient(...args),
|
||||
}));
|
||||
|
||||
describe("useResolvedAgentModelPicker", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
window.localStorage.clear();
|
||||
|
||||
mockGetClient.mockResolvedValue({
|
||||
goose: {
|
||||
GooseConfigRead: vi.fn().mockResolvedValue({ value: null }),
|
||||
},
|
||||
});
|
||||
|
||||
mockUseProviderInventory.mockReturnValue({
|
||||
getEntry: (providerId: string) =>
|
||||
providerId === "codex-acp"
|
||||
? {
|
||||
providerId: "codex-acp",
|
||||
defaultModel: "gpt-5.4",
|
||||
models: [
|
||||
{
|
||||
id: "gpt-5.4",
|
||||
name: "GPT-5.4",
|
||||
recommended: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
|
||||
mockUseAgentModelPickerState.mockImplementation(
|
||||
({
|
||||
onProviderSelected,
|
||||
}: {
|
||||
onProviderSelected: (providerId: string) => void;
|
||||
}) => ({
|
||||
pickerAgents: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "codex-acp", label: "Codex" },
|
||||
],
|
||||
availableModels: [],
|
||||
modelsLoading: false,
|
||||
modelStatusMessage: null,
|
||||
handleProviderChange: (providerId: string) =>
|
||||
onProviderSelected(providerId),
|
||||
handleModelChange: vi.fn(),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("selects the agent default model when switching to a provider without a saved model", () => {
|
||||
const setPendingProviderId = vi.fn();
|
||||
const setPendingModelSelection = vi.fn();
|
||||
const setGlobalSelectedProvider = vi.fn();
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useResolvedAgentModelPicker({
|
||||
providers: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "codex-acp", label: "Codex" },
|
||||
],
|
||||
selectedProvider: "goose",
|
||||
sessionId: null,
|
||||
session: undefined,
|
||||
pendingModelSelection: undefined,
|
||||
setPendingProviderId,
|
||||
setPendingModelSelection,
|
||||
setGlobalSelectedProvider,
|
||||
prepareSelectedProvider: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleProviderChange("codex-acp");
|
||||
});
|
||||
|
||||
expect(setGlobalSelectedProvider).toHaveBeenCalledWith("codex-acp");
|
||||
expect(setPendingProviderId).toHaveBeenCalledWith("codex-acp");
|
||||
expect(setPendingModelSelection).toHaveBeenCalledWith({
|
||||
id: "gpt-5.4",
|
||||
name: "GPT-5.4",
|
||||
providerId: "codex-acp",
|
||||
source: "default",
|
||||
});
|
||||
});
|
||||
|
||||
it("selects the saved model when switching back to an agent", () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
"codex-acp": {
|
||||
modelId: "gpt-5.4-mini",
|
||||
modelName: "GPT-5.4 mini",
|
||||
providerId: "codex-acp",
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const setPendingProviderId = vi.fn();
|
||||
const setPendingModelSelection = vi.fn();
|
||||
const setGlobalSelectedProvider = vi.fn();
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useResolvedAgentModelPicker({
|
||||
providers: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "codex-acp", label: "Codex" },
|
||||
],
|
||||
selectedProvider: "goose",
|
||||
sessionId: null,
|
||||
session: undefined,
|
||||
pendingModelSelection: undefined,
|
||||
setPendingProviderId,
|
||||
setPendingModelSelection,
|
||||
setGlobalSelectedProvider,
|
||||
prepareSelectedProvider: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleProviderChange("codex-acp");
|
||||
});
|
||||
|
||||
expect(setGlobalSelectedProvider).toHaveBeenCalledWith("codex-acp");
|
||||
expect(setPendingProviderId).toHaveBeenCalledWith("codex-acp");
|
||||
expect(setPendingModelSelection).toHaveBeenCalledWith({
|
||||
id: "gpt-5.4-mini",
|
||||
name: "GPT-5.4 mini",
|
||||
providerId: "codex-acp",
|
||||
source: "explicit",
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps explicit concrete provider requests authoritative", () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
const setPendingProviderId = vi.fn();
|
||||
const setPendingModelSelection = vi.fn();
|
||||
const setGlobalSelectedProvider = vi.fn();
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useResolvedAgentModelPicker({
|
||||
providers: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "openai", label: "OpenAI" },
|
||||
],
|
||||
selectedProvider: "anthropic",
|
||||
sessionId: null,
|
||||
session: undefined,
|
||||
pendingModelSelection: undefined,
|
||||
setPendingProviderId,
|
||||
setPendingModelSelection,
|
||||
setGlobalSelectedProvider,
|
||||
prepareSelectedProvider: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleProviderChange("openai");
|
||||
});
|
||||
|
||||
expect(setGlobalSelectedProvider).toHaveBeenCalledWith("openai");
|
||||
expect(setPendingProviderId).toHaveBeenCalledWith("openai");
|
||||
expect(setPendingModelSelection).toHaveBeenCalledWith(undefined);
|
||||
});
|
||||
|
||||
it("resolves ACP alias defaults to a concrete model when switching agents", () => {
|
||||
mockUseProviderInventory.mockReturnValue({
|
||||
getEntry: (providerId: string) =>
|
||||
providerId === "claude-acp"
|
||||
? {
|
||||
providerId: "claude-acp",
|
||||
defaultModel: "current",
|
||||
models: [
|
||||
{
|
||||
id: "sonnet",
|
||||
name: "Claude Sonnet",
|
||||
recommended: true,
|
||||
},
|
||||
{
|
||||
id: "opus",
|
||||
name: "Claude Opus",
|
||||
recommended: false,
|
||||
},
|
||||
],
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
|
||||
const setPendingProviderId = vi.fn();
|
||||
const setPendingModelSelection = vi.fn();
|
||||
const setGlobalSelectedProvider = vi.fn();
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useResolvedAgentModelPicker({
|
||||
providers: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "claude-acp", label: "Claude Code" },
|
||||
],
|
||||
selectedProvider: "goose",
|
||||
sessionId: null,
|
||||
session: undefined,
|
||||
pendingModelSelection: undefined,
|
||||
setPendingProviderId,
|
||||
setPendingModelSelection,
|
||||
setGlobalSelectedProvider,
|
||||
prepareSelectedProvider: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.handleProviderChange("claude-acp");
|
||||
});
|
||||
|
||||
expect(setPendingModelSelection).toHaveBeenCalledWith({
|
||||
id: "sonnet",
|
||||
name: "Claude Sonnet",
|
||||
providerId: "claude-acp",
|
||||
source: "default",
|
||||
});
|
||||
});
|
||||
|
||||
it("prefers a concrete default model over a session alias like current", () => {
|
||||
mockUseProviderInventory.mockReturnValue({
|
||||
getEntry: (providerId: string) =>
|
||||
providerId === "claude-acp"
|
||||
? {
|
||||
providerId: "claude-acp",
|
||||
defaultModel: "current",
|
||||
models: [
|
||||
{
|
||||
id: "sonnet",
|
||||
name: "Claude Sonnet",
|
||||
recommended: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
|
||||
mockUseAgentModelPickerState.mockImplementation(
|
||||
({
|
||||
onProviderSelected,
|
||||
}: {
|
||||
onProviderSelected: (providerId: string) => void;
|
||||
}) => ({
|
||||
pickerAgents: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "claude-acp", label: "Claude Code" },
|
||||
],
|
||||
availableModels: [
|
||||
{
|
||||
id: "sonnet",
|
||||
name: "Claude Sonnet",
|
||||
recommended: true,
|
||||
},
|
||||
],
|
||||
modelsLoading: false,
|
||||
modelStatusMessage: null,
|
||||
handleProviderChange: (providerId: string) =>
|
||||
onProviderSelected(providerId),
|
||||
handleModelChange: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useResolvedAgentModelPicker({
|
||||
providers: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "claude-acp", label: "Claude Code" },
|
||||
],
|
||||
selectedProvider: "claude-acp",
|
||||
sessionId: "session-1",
|
||||
session: {
|
||||
id: "session-1",
|
||||
title: "Chat",
|
||||
providerId: "claude-acp",
|
||||
modelId: "current",
|
||||
modelName: "current",
|
||||
createdAt: "2026-04-21T00:00:00.000Z",
|
||||
updatedAt: "2026-04-21T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
},
|
||||
pendingModelSelection: undefined,
|
||||
setPendingProviderId: vi.fn(),
|
||||
setPendingModelSelection: vi.fn(),
|
||||
setGlobalSelectedProvider: vi.fn(),
|
||||
prepareSelectedProvider: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result.current.effectiveModelSelection).toEqual({
|
||||
id: "sonnet",
|
||||
name: "Claude Sonnet",
|
||||
providerId: "claude-acp",
|
||||
source: "default",
|
||||
});
|
||||
});
|
||||
|
||||
it("drops Goose fallback models that are incompatible with a concrete provider", () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
mockUseAgentModelPickerState.mockImplementation(
|
||||
({
|
||||
onProviderSelected,
|
||||
}: {
|
||||
onProviderSelected: (providerId: string) => void;
|
||||
}) => ({
|
||||
pickerAgents: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "openai", label: "OpenAI" },
|
||||
],
|
||||
availableModels: [
|
||||
{
|
||||
id: "gpt-5.4",
|
||||
name: "GPT-5.4",
|
||||
providerId: "openai",
|
||||
},
|
||||
{
|
||||
id: "claude-sonnet-4",
|
||||
name: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
},
|
||||
],
|
||||
modelsLoading: false,
|
||||
modelStatusMessage: null,
|
||||
handleProviderChange: (providerId: string) =>
|
||||
onProviderSelected(providerId),
|
||||
handleModelChange: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useResolvedAgentModelPicker({
|
||||
providers: [
|
||||
{ id: "goose", label: "Goose" },
|
||||
{ id: "openai", label: "OpenAI" },
|
||||
],
|
||||
selectedProvider: "openai",
|
||||
sessionId: "session-1",
|
||||
session: {
|
||||
id: "session-1",
|
||||
title: "Chat",
|
||||
providerId: "openai",
|
||||
createdAt: "2026-04-21T00:00:00.000Z",
|
||||
updatedAt: "2026-04-21T00:00:00.000Z",
|
||||
messageCount: 0,
|
||||
},
|
||||
pendingModelSelection: undefined,
|
||||
setPendingProviderId: vi.fn(),
|
||||
setPendingModelSelection: vi.fn(),
|
||||
setGlobalSelectedProvider: vi.fn(),
|
||||
prepareSelectedProvider: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result.current.effectiveModelSelection).toBeNull();
|
||||
});
|
||||
});
|
||||
|
|
@ -7,15 +7,20 @@ import { useChatSessionStore } from "../stores/chatSessionStore";
|
|||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import { useProviderSelection } from "@/features/agents/hooks/useProviderSelection";
|
||||
import { useProjectStore } from "@/features/projects/stores/projectStore";
|
||||
import { useAgentModelPickerState } from "./useAgentModelPickerState";
|
||||
import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog";
|
||||
import {
|
||||
buildProjectSystemPrompt,
|
||||
composeSystemPrompt,
|
||||
getProjectArtifactRoots,
|
||||
resolveProjectDefaultArtifactRoot,
|
||||
} from "@/features/projects/lib/chatProjectContext";
|
||||
import { setStoredModelPreference } from "../lib/modelPreferences";
|
||||
import { resolveSessionCwd } from "@/features/projects/lib/sessionCwdSelection";
|
||||
import { acpPrepareSession, acpSetModel } from "@/shared/api/acp";
|
||||
import {
|
||||
useResolvedAgentModelPicker,
|
||||
type PreferredModelSelection,
|
||||
} from "./useResolvedAgentModelPicker";
|
||||
|
||||
interface UseChatSessionControllerOptions {
|
||||
sessionId: string | null;
|
||||
|
|
@ -52,11 +57,8 @@ export function useChatSessionController({
|
|||
const [pendingPersonaId, setPendingPersonaId] = useState<string | null>();
|
||||
const [pendingProjectId, setPendingProjectId] = useState<string | null>();
|
||||
const [pendingProviderId, setPendingProviderId] = useState<string>();
|
||||
const [pendingModelSelection, setPendingModelSelection] = useState<{
|
||||
id: string;
|
||||
name: string;
|
||||
providerId?: string;
|
||||
} | null>();
|
||||
const [pendingModelSelection, setPendingModelSelection] =
|
||||
useState<PreferredModelSelection | null>();
|
||||
const pendingDraftValue = useChatStore(
|
||||
(s) => s.draftsBySession[PENDING_HOME_SESSION_ID] ?? "",
|
||||
);
|
||||
|
|
@ -140,6 +142,7 @@ export function useChatSessionController({
|
|||
nextProject = project,
|
||||
nextWorkspacePath = activeWorkspace?.path,
|
||||
personaId = selectedPersonaId ?? undefined,
|
||||
modelSelection?: PreferredModelSelection | null,
|
||||
) => {
|
||||
if (!sessionId) {
|
||||
return;
|
||||
|
|
@ -149,9 +152,39 @@ export function useChatSessionController({
|
|||
nextWorkspacePath,
|
||||
);
|
||||
await acpPrepareSession(sessionId, providerId, workingDir, { personaId });
|
||||
if (!modelSelection?.id) {
|
||||
return;
|
||||
}
|
||||
|
||||
const sessionStore = useChatSessionStore.getState();
|
||||
const liveSession = sessionStore.getSession(sessionId);
|
||||
const modelAlreadyApplied =
|
||||
liveSession?.modelId === modelSelection.id &&
|
||||
liveSession?.modelName === modelSelection.name;
|
||||
|
||||
if (modelAlreadyApplied) {
|
||||
return;
|
||||
}
|
||||
|
||||
await acpSetModel(sessionId, modelSelection.id);
|
||||
sessionStore.updateSession(sessionId, {
|
||||
modelId: modelSelection.id,
|
||||
modelName: modelSelection.name,
|
||||
});
|
||||
},
|
||||
[activeWorkspace?.path, project, selectedPersonaId, sessionId],
|
||||
);
|
||||
const prepareSelectedProvider = useCallback(
|
||||
(providerId: string, modelSelection?: PreferredModelSelection | null) =>
|
||||
prepareCurrentSession(
|
||||
providerId,
|
||||
project,
|
||||
activeWorkspace?.path,
|
||||
selectedPersonaId ?? undefined,
|
||||
modelSelection,
|
||||
),
|
||||
[activeWorkspace?.path, prepareCurrentSession, project, selectedPersonaId],
|
||||
);
|
||||
|
||||
const prevProjectIdRef = useRef(session?.projectId);
|
||||
useEffect(() => {
|
||||
|
|
@ -168,6 +201,27 @@ export function useChatSessionController({
|
|||
}
|
||||
}, [clearActiveWorkspace, session?.projectId, sessionId]);
|
||||
|
||||
const {
|
||||
selectedAgentId,
|
||||
pickerAgents,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
handleProviderChange,
|
||||
handleModelChange,
|
||||
effectiveModelSelection,
|
||||
} = useResolvedAgentModelPicker({
|
||||
providers,
|
||||
selectedProvider,
|
||||
sessionId,
|
||||
session,
|
||||
pendingModelSelection,
|
||||
setPendingProviderId,
|
||||
setPendingModelSelection,
|
||||
setGlobalSelectedProvider,
|
||||
prepareSelectedProvider,
|
||||
});
|
||||
|
||||
const prevWorkspaceRef = useRef(activeWorkspace);
|
||||
useEffect(() => {
|
||||
const previousWorkspace = prevWorkspaceRef.current;
|
||||
|
|
@ -183,114 +237,19 @@ export function useChatSessionController({
|
|||
if (previousWorkspace?.path === activeWorkspace.path) {
|
||||
return;
|
||||
}
|
||||
void prepareCurrentSession(selectedProvider).catch((error) => {
|
||||
void prepareSelectedProvider(
|
||||
selectedProvider,
|
||||
effectiveModelSelection,
|
||||
).catch((error) => {
|
||||
console.error("Failed to prepare ACP session:", error);
|
||||
});
|
||||
}, [activeWorkspace, prepareCurrentSession, selectedProvider, sessionId]);
|
||||
|
||||
const {
|
||||
selectedAgentId,
|
||||
pickerAgents,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
handleProviderChange,
|
||||
handleModelChange,
|
||||
} = useAgentModelPickerState({
|
||||
providers,
|
||||
}, [
|
||||
activeWorkspace,
|
||||
effectiveModelSelection,
|
||||
prepareSelectedProvider,
|
||||
selectedProvider,
|
||||
onProviderSelected: (providerId) => {
|
||||
if (!sessionId) {
|
||||
setGlobalSelectedProvider(providerId);
|
||||
setPendingProviderId(providerId);
|
||||
setPendingModelSelection(null);
|
||||
return;
|
||||
}
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.switchSessionProvider(sessionId, providerId);
|
||||
setGlobalSelectedProvider(providerId);
|
||||
void prepareCurrentSession(providerId).catch((error) => {
|
||||
console.error("Failed to update ACP session provider:", error);
|
||||
});
|
||||
},
|
||||
onModelSelected: (model) => {
|
||||
const modelId = model.id;
|
||||
const modelName = model.displayName ?? model.name ?? model.id;
|
||||
const nextProviderId = model.providerId ?? selectedProvider;
|
||||
|
||||
if (!sessionId) {
|
||||
if (nextProviderId && nextProviderId !== selectedProvider) {
|
||||
setPendingProviderId(nextProviderId);
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
}
|
||||
setPendingModelSelection({
|
||||
id: modelId,
|
||||
name: modelName,
|
||||
providerId: nextProviderId,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (
|
||||
!session ||
|
||||
(modelId === session.modelId &&
|
||||
(!nextProviderId || nextProviderId === session.providerId))
|
||||
) {
|
||||
return;
|
||||
}
|
||||
const previousProviderId = session.providerId;
|
||||
const previousModelId = session.modelId;
|
||||
const previousModelName = session.modelName;
|
||||
const providerChanged =
|
||||
Boolean(nextProviderId) && nextProviderId !== session.providerId;
|
||||
|
||||
if (providerChanged && nextProviderId) {
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.switchSessionProvider(sessionId, nextProviderId);
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
}
|
||||
|
||||
useChatSessionStore.getState().updateSession(sessionId, {
|
||||
modelId,
|
||||
modelName,
|
||||
});
|
||||
|
||||
void (async () => {
|
||||
try {
|
||||
if (providerChanged && nextProviderId) {
|
||||
await prepareCurrentSession(nextProviderId);
|
||||
}
|
||||
await acpSetModel(sessionId, modelId);
|
||||
} catch (error) {
|
||||
console.error("Failed to set model:", error);
|
||||
if (providerChanged && previousProviderId) {
|
||||
setGlobalSelectedProvider(previousProviderId);
|
||||
}
|
||||
useChatSessionStore.getState().updateSession(sessionId, {
|
||||
providerId: previousProviderId,
|
||||
modelId: previousModelId,
|
||||
modelName: previousModelName,
|
||||
});
|
||||
void (async () => {
|
||||
try {
|
||||
if (providerChanged && previousProviderId) {
|
||||
await prepareCurrentSession(previousProviderId);
|
||||
}
|
||||
if (previousModelId) {
|
||||
await acpSetModel(sessionId, previousModelId);
|
||||
}
|
||||
} catch (rollbackError) {
|
||||
console.error(
|
||||
"Failed to restore previous provider/model after setModel failure:",
|
||||
rollbackError,
|
||||
);
|
||||
}
|
||||
})();
|
||||
}
|
||||
})();
|
||||
},
|
||||
});
|
||||
sessionId,
|
||||
]);
|
||||
|
||||
const handleProjectChange = useCallback(
|
||||
(projectId: string | null) => {
|
||||
|
|
@ -310,16 +269,24 @@ export function useChatSessionController({
|
|||
if (!selectedProvider) {
|
||||
return;
|
||||
}
|
||||
void prepareCurrentSession(selectedProvider, nextProject).catch(
|
||||
(error) => {
|
||||
console.error(
|
||||
"Failed to update ACP session working directory:",
|
||||
error,
|
||||
);
|
||||
},
|
||||
);
|
||||
void prepareCurrentSession(
|
||||
selectedProvider,
|
||||
nextProject,
|
||||
activeWorkspace?.path,
|
||||
selectedPersonaId ?? undefined,
|
||||
effectiveModelSelection,
|
||||
).catch((error) => {
|
||||
console.error("Failed to update ACP session working directory:", error);
|
||||
});
|
||||
},
|
||||
[prepareCurrentSession, selectedProvider, sessionId],
|
||||
[
|
||||
activeWorkspace?.path,
|
||||
effectiveModelSelection,
|
||||
prepareCurrentSession,
|
||||
selectedPersonaId,
|
||||
selectedProvider,
|
||||
sessionId,
|
||||
],
|
||||
);
|
||||
|
||||
const handlePersonaChange = useCallback(
|
||||
|
|
@ -334,7 +301,7 @@ export function useChatSessionController({
|
|||
if (matchingProvider) {
|
||||
if (!sessionId) {
|
||||
setPendingProviderId(matchingProvider.id);
|
||||
setPendingModelSelection(null);
|
||||
setPendingModelSelection(undefined);
|
||||
setGlobalSelectedProvider(matchingProvider.id);
|
||||
} else {
|
||||
handleProviderChange(matchingProvider.id);
|
||||
|
|
@ -399,7 +366,8 @@ export function useChatSessionController({
|
|||
{
|
||||
onMessageAccepted: sessionId ? onMessageAccepted : undefined,
|
||||
ensurePrepared: selectedProvider
|
||||
? () => prepareCurrentSession(selectedProvider)
|
||||
? () =>
|
||||
prepareSelectedProvider(selectedProvider, effectiveModelSelection)
|
||||
: undefined,
|
||||
},
|
||||
);
|
||||
|
|
@ -569,10 +537,6 @@ export function useChatSessionController({
|
|||
if (hasPendingProject) {
|
||||
patch.projectId = nextProjectId ?? null;
|
||||
}
|
||||
if (hasPendingModel) {
|
||||
patch.modelId = pendingModelSelection?.id;
|
||||
patch.modelName = pendingModelSelection?.name;
|
||||
}
|
||||
|
||||
useChatSessionStore.getState().updateSession(sessionId, patch);
|
||||
|
||||
|
|
@ -582,15 +546,21 @@ export function useChatSessionController({
|
|||
nextProject,
|
||||
activeWorkspace?.path,
|
||||
nextPersonaId,
|
||||
pendingModelSelection,
|
||||
);
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
if (pendingModelSelection?.id) {
|
||||
await acpSetModel(sessionId, pendingModelSelection.id);
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
if (pendingModelSelection?.source === "explicit") {
|
||||
const agentId =
|
||||
resolveAgentProviderCatalogIdStrict(
|
||||
pendingModelSelection.providerId ?? nextProviderId,
|
||||
) ?? "goose";
|
||||
setStoredModelPreference(agentId, {
|
||||
modelId: pendingModelSelection.id,
|
||||
modelName: pendingModelSelection.name,
|
||||
providerId: pendingModelSelection.providerId ?? nextProviderId,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to sync pending Home state:", error);
|
||||
|
|
@ -663,14 +633,8 @@ export function useChatSessionController({
|
|||
providersLoading,
|
||||
selectedProvider: selectedAgentId,
|
||||
handleProviderChange,
|
||||
currentModelId:
|
||||
pendingModelSelection !== undefined
|
||||
? (pendingModelSelection?.id ?? null)
|
||||
: (session?.modelId ?? null),
|
||||
currentModelName:
|
||||
pendingModelSelection !== undefined
|
||||
? (pendingModelSelection?.name ?? null)
|
||||
: session?.modelName,
|
||||
currentModelId: effectiveModelSelection?.id ?? null,
|
||||
currentModelName: effectiveModelSelection?.name ?? null,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
|
|
|
|||
483
ui/goose2/src/features/chat/hooks/useResolvedAgentModelPicker.ts
Normal file
483
ui/goose2/src/features/chat/hooks/useResolvedAgentModelPicker.ts
Normal file
|
|
@ -0,0 +1,483 @@
|
|||
import { useEffect, useMemo, useState } from "react";
|
||||
import type { AcpProvider } from "@/shared/api/acp";
|
||||
import { useProviderInventory } from "@/features/providers/hooks/useProviderInventory";
|
||||
import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog";
|
||||
import { getClient } from "@/shared/api/acpConnection";
|
||||
import { acpSetModel } from "@/shared/api/acp";
|
||||
import {
|
||||
useChatSessionStore,
|
||||
type ChatSession,
|
||||
} from "../stores/chatSessionStore";
|
||||
import { useAgentModelPickerState } from "./useAgentModelPickerState";
|
||||
import {
|
||||
clearStoredModelPreference,
|
||||
getStoredModelPreference,
|
||||
setStoredModelPreference,
|
||||
} from "../lib/modelPreferences";
|
||||
|
||||
const GOOSE_PROVIDER_CONFIG_KEY = "GOOSE_PROVIDER";
|
||||
const GOOSE_MODEL_CONFIG_KEY = "GOOSE_MODEL";
|
||||
const MODEL_ALIAS_IDS = new Set(["current", "default"]);
|
||||
|
||||
export type PreferredModelSelection = {
|
||||
id: string;
|
||||
name: string;
|
||||
providerId?: string;
|
||||
source: "default" | "explicit";
|
||||
};
|
||||
|
||||
interface UseResolvedAgentModelPickerOptions {
|
||||
providers: AcpProvider[];
|
||||
selectedProvider: string;
|
||||
sessionId: string | null;
|
||||
session?: ChatSession;
|
||||
pendingModelSelection: PreferredModelSelection | null | undefined;
|
||||
setPendingProviderId: (providerId: string | undefined) => void;
|
||||
setPendingModelSelection: (
|
||||
selection: PreferredModelSelection | null | undefined,
|
||||
) => void;
|
||||
setGlobalSelectedProvider: (providerId: string) => void;
|
||||
prepareSelectedProvider: (
|
||||
providerId: string,
|
||||
modelSelection?: PreferredModelSelection | null,
|
||||
) => Promise<void>;
|
||||
}
|
||||
|
||||
function isModelAlias(modelId?: string | null): boolean {
|
||||
return modelId != null && MODEL_ALIAS_IDS.has(modelId);
|
||||
}
|
||||
|
||||
export function useResolvedAgentModelPicker({
|
||||
providers,
|
||||
selectedProvider,
|
||||
sessionId,
|
||||
session,
|
||||
pendingModelSelection,
|
||||
setPendingProviderId,
|
||||
setPendingModelSelection,
|
||||
setGlobalSelectedProvider,
|
||||
prepareSelectedProvider,
|
||||
}: UseResolvedAgentModelPickerOptions) {
|
||||
const { getEntry: getProviderInventoryEntry } = useProviderInventory();
|
||||
const [gooseDefaultSelection, setGooseDefaultSelection] =
|
||||
useState<PreferredModelSelection | null>(null);
|
||||
|
||||
const selectedAgentId =
|
||||
resolveAgentProviderCatalogIdStrict(selectedProvider) ?? "goose";
|
||||
const concreteSelectedProviderId =
|
||||
resolveAgentProviderCatalogIdStrict(selectedProvider) == null
|
||||
? selectedProvider
|
||||
: null;
|
||||
const storedModelPreference = useMemo(
|
||||
() => getStoredModelPreference(selectedAgentId),
|
||||
[selectedAgentId],
|
||||
);
|
||||
|
||||
const getPreferredSelectionForAgent = useMemo(
|
||||
() => (agentId: string, fallbackProviderId?: string) => {
|
||||
const preferredModel = getStoredModelPreference(agentId);
|
||||
if (preferredModel) {
|
||||
return {
|
||||
id: preferredModel.modelId,
|
||||
name: preferredModel.modelName,
|
||||
providerId: preferredModel.providerId ?? fallbackProviderId,
|
||||
source: "explicit" as const,
|
||||
};
|
||||
}
|
||||
|
||||
if (agentId === "goose") {
|
||||
if (!gooseDefaultSelection) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
...gooseDefaultSelection,
|
||||
providerId: gooseDefaultSelection.providerId ?? fallbackProviderId,
|
||||
};
|
||||
}
|
||||
|
||||
const inventoryEntry = getProviderInventoryEntry(agentId);
|
||||
if (!inventoryEntry?.defaultModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const resolvedInventoryModel =
|
||||
inventoryEntry.models.find(
|
||||
(model) =>
|
||||
model.id === inventoryEntry.defaultModel && !isModelAlias(model.id),
|
||||
) ??
|
||||
inventoryEntry.models.find((model) => model.recommended) ??
|
||||
inventoryEntry.models.find((model) => !isModelAlias(model.id)) ??
|
||||
inventoryEntry.models.find(
|
||||
(model) => model.id === inventoryEntry.defaultModel,
|
||||
) ??
|
||||
inventoryEntry.models[0];
|
||||
|
||||
if (resolvedInventoryModel) {
|
||||
return {
|
||||
id: resolvedInventoryModel.id,
|
||||
name: resolvedInventoryModel.name,
|
||||
providerId:
|
||||
inventoryEntry.providerId === agentId
|
||||
? inventoryEntry.providerId
|
||||
: fallbackProviderId,
|
||||
source: "default" as const,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
id: inventoryEntry.defaultModel,
|
||||
name: inventoryEntry.defaultModel,
|
||||
providerId:
|
||||
inventoryEntry.providerId === agentId
|
||||
? inventoryEntry.providerId
|
||||
: fallbackProviderId,
|
||||
source: "default" as const,
|
||||
};
|
||||
},
|
||||
[getProviderInventoryEntry, gooseDefaultSelection],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedAgentId !== "goose") {
|
||||
setGooseDefaultSelection(null);
|
||||
return;
|
||||
}
|
||||
|
||||
let cancelled = false;
|
||||
|
||||
const loadGooseDefaultSelection = async () => {
|
||||
try {
|
||||
const client = await getClient();
|
||||
const [providerResponse, modelResponse] = await Promise.all([
|
||||
client.goose.GooseConfigRead({ key: GOOSE_PROVIDER_CONFIG_KEY }),
|
||||
client.goose.GooseConfigRead({ key: GOOSE_MODEL_CONFIG_KEY }),
|
||||
]);
|
||||
|
||||
if (cancelled) {
|
||||
return;
|
||||
}
|
||||
|
||||
const providerId =
|
||||
typeof providerResponse.value === "string"
|
||||
? providerResponse.value
|
||||
: undefined;
|
||||
const modelId =
|
||||
typeof modelResponse.value === "string"
|
||||
? modelResponse.value
|
||||
: undefined;
|
||||
|
||||
if (!modelId) {
|
||||
setGooseDefaultSelection(null);
|
||||
return;
|
||||
}
|
||||
|
||||
setGooseDefaultSelection({
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
providerId,
|
||||
source: "default",
|
||||
});
|
||||
} catch {
|
||||
if (!cancelled) {
|
||||
setGooseDefaultSelection(null);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void loadGooseDefaultSelection();
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [selectedAgentId]);
|
||||
|
||||
const {
|
||||
pickerAgents,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
handleProviderChange,
|
||||
handleModelChange,
|
||||
} = useAgentModelPickerState({
|
||||
providers,
|
||||
selectedProvider,
|
||||
onProviderSelected: (providerId) => {
|
||||
const requestedAgentId = resolveAgentProviderCatalogIdStrict(providerId);
|
||||
const preferredModelSelection = getPreferredSelectionForAgent(
|
||||
requestedAgentId ?? "goose",
|
||||
providerId,
|
||||
);
|
||||
const nextProviderId = requestedAgentId
|
||||
? (preferredModelSelection?.providerId ?? providerId)
|
||||
: providerId;
|
||||
const nextModelSelection =
|
||||
!requestedAgentId &&
|
||||
preferredModelSelection?.providerId &&
|
||||
preferredModelSelection.providerId !== providerId
|
||||
? undefined
|
||||
: preferredModelSelection
|
||||
? {
|
||||
...preferredModelSelection,
|
||||
providerId:
|
||||
requestedAgentId == null
|
||||
? providerId
|
||||
: preferredModelSelection.providerId,
|
||||
}
|
||||
: undefined;
|
||||
|
||||
if (!sessionId) {
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
setPendingProviderId(nextProviderId);
|
||||
setPendingModelSelection(nextModelSelection);
|
||||
return;
|
||||
}
|
||||
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.switchSessionProvider(sessionId, nextProviderId);
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
void prepareSelectedProvider(nextProviderId, nextModelSelection).catch(
|
||||
(error) => {
|
||||
console.error("Failed to update ACP session provider:", error);
|
||||
},
|
||||
);
|
||||
},
|
||||
onModelSelected: (model) => {
|
||||
const modelId = model.id;
|
||||
const modelName = model.displayName ?? model.name ?? model.id;
|
||||
const nextProviderId = model.providerId ?? selectedProvider;
|
||||
const nextStoredModelPreference = {
|
||||
modelId,
|
||||
modelName,
|
||||
providerId: nextProviderId,
|
||||
};
|
||||
|
||||
if (!sessionId) {
|
||||
if (nextProviderId && nextProviderId !== selectedProvider) {
|
||||
setPendingProviderId(nextProviderId);
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
}
|
||||
setPendingModelSelection({
|
||||
id: modelId,
|
||||
name: modelName,
|
||||
providerId: nextProviderId,
|
||||
source: "explicit",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
!session ||
|
||||
(modelId === session.modelId &&
|
||||
(!nextProviderId || nextProviderId === session.providerId))
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const previousStoredModelPreference =
|
||||
getStoredModelPreference(selectedAgentId);
|
||||
const previousProviderId = session.providerId;
|
||||
const previousModelId = session.modelId;
|
||||
const previousModelName = session.modelName;
|
||||
const providerChanged =
|
||||
Boolean(nextProviderId) && nextProviderId !== session.providerId;
|
||||
|
||||
if (providerChanged && nextProviderId) {
|
||||
useChatSessionStore
|
||||
.getState()
|
||||
.switchSessionProvider(sessionId, nextProviderId);
|
||||
setGlobalSelectedProvider(nextProviderId);
|
||||
}
|
||||
|
||||
useChatSessionStore.getState().updateSession(sessionId, {
|
||||
modelId,
|
||||
modelName,
|
||||
});
|
||||
|
||||
void (async () => {
|
||||
try {
|
||||
if (providerChanged && nextProviderId) {
|
||||
await prepareSelectedProvider(nextProviderId);
|
||||
}
|
||||
await acpSetModel(sessionId, modelId);
|
||||
setStoredModelPreference(selectedAgentId, nextStoredModelPreference);
|
||||
} catch (error) {
|
||||
console.error("Failed to set model:", error);
|
||||
if (providerChanged && previousProviderId) {
|
||||
setGlobalSelectedProvider(previousProviderId);
|
||||
}
|
||||
if (previousStoredModelPreference) {
|
||||
setStoredModelPreference(
|
||||
selectedAgentId,
|
||||
previousStoredModelPreference,
|
||||
);
|
||||
} else {
|
||||
clearStoredModelPreference(selectedAgentId);
|
||||
}
|
||||
useChatSessionStore.getState().updateSession(sessionId, {
|
||||
providerId: previousProviderId,
|
||||
modelId: previousModelId,
|
||||
modelName: previousModelName,
|
||||
});
|
||||
void (async () => {
|
||||
try {
|
||||
if (providerChanged && previousProviderId) {
|
||||
await prepareSelectedProvider(previousProviderId);
|
||||
}
|
||||
if (previousModelId) {
|
||||
await acpSetModel(sessionId, previousModelId);
|
||||
}
|
||||
} catch (rollbackError) {
|
||||
console.error(
|
||||
"Failed to restore previous provider/model after setModel failure:",
|
||||
rollbackError,
|
||||
);
|
||||
}
|
||||
})();
|
||||
}
|
||||
})();
|
||||
},
|
||||
});
|
||||
|
||||
const preferredModelSelection =
|
||||
useMemo<PreferredModelSelection | null>(() => {
|
||||
if (storedModelPreference) {
|
||||
const matchingStoredModel =
|
||||
availableModels.find(
|
||||
(model) =>
|
||||
model.id === storedModelPreference.modelId &&
|
||||
(!storedModelPreference.providerId ||
|
||||
!model.providerId ||
|
||||
model.providerId === storedModelPreference.providerId) &&
|
||||
(!concreteSelectedProviderId ||
|
||||
!model.providerId ||
|
||||
model.providerId === concreteSelectedProviderId),
|
||||
) ?? null;
|
||||
const storedSelectionCompatible =
|
||||
!concreteSelectedProviderId ||
|
||||
storedModelPreference.providerId === concreteSelectedProviderId;
|
||||
|
||||
if (
|
||||
matchingStoredModel ||
|
||||
((availableModels.length === 0 || modelsLoading) &&
|
||||
storedSelectionCompatible)
|
||||
) {
|
||||
return {
|
||||
id: storedModelPreference.modelId,
|
||||
name:
|
||||
matchingStoredModel?.displayName ??
|
||||
matchingStoredModel?.name ??
|
||||
storedModelPreference.modelName,
|
||||
providerId:
|
||||
matchingStoredModel?.providerId ??
|
||||
storedModelPreference.providerId,
|
||||
source: "explicit",
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const inventoryDefaultSelection = getPreferredSelectionForAgent(
|
||||
selectedAgentId,
|
||||
selectedProvider,
|
||||
);
|
||||
|
||||
if (!inventoryDefaultSelection) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const matchingDefaultModel =
|
||||
availableModels.find(
|
||||
(model) =>
|
||||
model.id === inventoryDefaultSelection.id &&
|
||||
(!inventoryDefaultSelection.providerId ||
|
||||
!model.providerId ||
|
||||
model.providerId === inventoryDefaultSelection.providerId) &&
|
||||
(!concreteSelectedProviderId ||
|
||||
!model.providerId ||
|
||||
model.providerId === concreteSelectedProviderId),
|
||||
) ?? null;
|
||||
const defaultSelectionCompatible =
|
||||
!concreteSelectedProviderId ||
|
||||
inventoryDefaultSelection.providerId === concreteSelectedProviderId;
|
||||
|
||||
if (!matchingDefaultModel && !defaultSelectionCompatible) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
id: inventoryDefaultSelection.id,
|
||||
name:
|
||||
matchingDefaultModel?.displayName ??
|
||||
matchingDefaultModel?.name ??
|
||||
inventoryDefaultSelection.name,
|
||||
providerId:
|
||||
matchingDefaultModel?.providerId ??
|
||||
inventoryDefaultSelection.providerId,
|
||||
source: "default",
|
||||
};
|
||||
}, [
|
||||
availableModels,
|
||||
getPreferredSelectionForAgent,
|
||||
modelsLoading,
|
||||
concreteSelectedProviderId,
|
||||
selectedProvider,
|
||||
selectedAgentId,
|
||||
storedModelPreference,
|
||||
]);
|
||||
|
||||
const sessionModelSelection = useMemo<PreferredModelSelection | null>(() => {
|
||||
if (!session?.modelId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const matchingSessionModel =
|
||||
availableModels.find(
|
||||
(model) =>
|
||||
model.id === session.modelId &&
|
||||
(!session.providerId ||
|
||||
!model.providerId ||
|
||||
model.providerId === session.providerId),
|
||||
) ?? null;
|
||||
|
||||
if (matchingSessionModel) {
|
||||
return {
|
||||
id: matchingSessionModel.id,
|
||||
name:
|
||||
matchingSessionModel.displayName ??
|
||||
matchingSessionModel.name ??
|
||||
session.modelName ??
|
||||
session.modelId,
|
||||
providerId: matchingSessionModel.providerId ?? session.providerId,
|
||||
source: "explicit",
|
||||
};
|
||||
}
|
||||
|
||||
if (isModelAlias(session.modelId)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
id: session.modelId,
|
||||
name: session.modelName ?? session.modelId,
|
||||
providerId: session.providerId,
|
||||
source: "explicit",
|
||||
};
|
||||
}, [availableModels, session]);
|
||||
|
||||
const effectiveModelSelection =
|
||||
pendingModelSelection !== undefined
|
||||
? pendingModelSelection
|
||||
: (sessionModelSelection ?? preferredModelSelection);
|
||||
|
||||
return {
|
||||
selectedAgentId,
|
||||
pickerAgents,
|
||||
availableModels,
|
||||
modelsLoading,
|
||||
modelStatusMessage,
|
||||
handleProviderChange,
|
||||
handleModelChange,
|
||||
effectiveModelSelection,
|
||||
};
|
||||
}
|
||||
86
ui/goose2/src/features/chat/lib/modelPreferences.ts
Normal file
86
ui/goose2/src/features/chat/lib/modelPreferences.ts
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog";
|
||||
|
||||
const MODEL_PREFERENCES_STORAGE_KEY = "goose:preferredModelsByAgent";
|
||||
|
||||
export interface StoredModelPreference {
|
||||
modelId: string;
|
||||
modelName: string;
|
||||
providerId?: string;
|
||||
}
|
||||
|
||||
type StoredModelPreferences = Record<string, StoredModelPreference>;
|
||||
|
||||
function readStoredModelPreferences(): StoredModelPreferences {
|
||||
if (typeof window === "undefined") {
|
||||
return {};
|
||||
}
|
||||
|
||||
try {
|
||||
const stored = window.localStorage.getItem(MODEL_PREFERENCES_STORAGE_KEY);
|
||||
if (!stored) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const parsed = JSON.parse(stored);
|
||||
if (!parsed || typeof parsed !== "object") {
|
||||
return {};
|
||||
}
|
||||
|
||||
return parsed as StoredModelPreferences;
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function persistStoredModelPreferences(
|
||||
preferences: StoredModelPreferences,
|
||||
): void {
|
||||
if (typeof window === "undefined") {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
if (Object.keys(preferences).length === 0) {
|
||||
window.localStorage.removeItem(MODEL_PREFERENCES_STORAGE_KEY);
|
||||
return;
|
||||
}
|
||||
|
||||
window.localStorage.setItem(
|
||||
MODEL_PREFERENCES_STORAGE_KEY,
|
||||
JSON.stringify(preferences),
|
||||
);
|
||||
} catch {
|
||||
// localStorage may be unavailable
|
||||
}
|
||||
}
|
||||
|
||||
export function getStoredModelPreference(
|
||||
agentId: string,
|
||||
): StoredModelPreference | null {
|
||||
return readStoredModelPreferences()[agentId] ?? null;
|
||||
}
|
||||
|
||||
export function getStoredModelPreferenceForProvider(
|
||||
providerId: string,
|
||||
): StoredModelPreference | null {
|
||||
const agentId = resolveAgentProviderCatalogIdStrict(providerId) ?? "goose";
|
||||
return getStoredModelPreference(agentId);
|
||||
}
|
||||
|
||||
export function setStoredModelPreference(
|
||||
agentId: string,
|
||||
preference: StoredModelPreference,
|
||||
): void {
|
||||
const next = readStoredModelPreferences();
|
||||
next[agentId] = preference;
|
||||
persistStoredModelPreferences(next);
|
||||
}
|
||||
|
||||
export function clearStoredModelPreference(agentId: string): void {
|
||||
const next = readStoredModelPreferences();
|
||||
if (!(agentId in next)) {
|
||||
return;
|
||||
}
|
||||
delete next[agentId];
|
||||
persistStoredModelPreferences(next);
|
||||
}
|
||||
114
ui/goose2/src/features/chat/lib/sessionModelPreference.test.ts
Normal file
114
ui/goose2/src/features/chat/lib/sessionModelPreference.test.ts
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
import { beforeEach, describe, expect, it } from "vitest";
|
||||
import {
|
||||
resolveSessionModelPreference,
|
||||
sanitizeSessionModelPreference,
|
||||
} from "./sessionModelPreference";
|
||||
|
||||
describe("resolveSessionModelPreference", () => {
|
||||
beforeEach(() => {
|
||||
window.localStorage.clear();
|
||||
});
|
||||
|
||||
it("keeps a requested concrete provider when the stored preference uses a different provider", () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
expect(
|
||||
resolveSessionModelPreference({
|
||||
providerId: "openai",
|
||||
}),
|
||||
).toEqual({
|
||||
providerId: "openai",
|
||||
});
|
||||
});
|
||||
|
||||
it("reuses a stored model when it matches the requested concrete provider", () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "gpt-5.4",
|
||||
modelName: "GPT-5.4",
|
||||
providerId: "openai",
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
expect(
|
||||
resolveSessionModelPreference({
|
||||
providerId: "openai",
|
||||
}),
|
||||
).toEqual({
|
||||
providerId: "openai",
|
||||
modelId: "gpt-5.4",
|
||||
modelName: "GPT-5.4",
|
||||
});
|
||||
});
|
||||
|
||||
it("resolves an agent provider to the stored concrete provider and model", () => {
|
||||
window.localStorage.setItem(
|
||||
"goose:preferredModelsByAgent",
|
||||
JSON.stringify({
|
||||
goose: {
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
providerId: "anthropic",
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
expect(
|
||||
resolveSessionModelPreference({
|
||||
providerId: "goose",
|
||||
}),
|
||||
).toEqual({
|
||||
providerId: "anthropic",
|
||||
modelId: "claude-sonnet-4",
|
||||
modelName: "Claude Sonnet 4",
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps a stored model when the provider inventory still contains it", () => {
|
||||
expect(
|
||||
sanitizeSessionModelPreference(
|
||||
{
|
||||
providerId: "openai",
|
||||
modelId: "gpt-5.4",
|
||||
modelName: "GPT-5.4",
|
||||
},
|
||||
{
|
||||
models: [{ id: "gpt-5.4" }, { id: "gpt-5.4-mini" }],
|
||||
},
|
||||
),
|
||||
).toEqual({
|
||||
providerId: "openai",
|
||||
modelId: "gpt-5.4",
|
||||
modelName: "GPT-5.4",
|
||||
});
|
||||
});
|
||||
|
||||
it("drops a stored model when the provider inventory no longer contains it", () => {
|
||||
expect(
|
||||
sanitizeSessionModelPreference(
|
||||
{
|
||||
providerId: "openai",
|
||||
modelId: "gpt-4.1",
|
||||
modelName: "GPT-4.1",
|
||||
},
|
||||
{
|
||||
models: [{ id: "gpt-5.4" }, { id: "gpt-5.4-mini" }],
|
||||
},
|
||||
),
|
||||
).toEqual({
|
||||
providerId: "openai",
|
||||
});
|
||||
});
|
||||
});
|
||||
75
ui/goose2/src/features/chat/lib/sessionModelPreference.ts
Normal file
75
ui/goose2/src/features/chat/lib/sessionModelPreference.ts
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
import { resolveAgentProviderCatalogIdStrict } from "@/features/providers/providerCatalog";
|
||||
import { getStoredModelPreferenceForProvider } from "./modelPreferences";
|
||||
|
||||
interface SessionModelPreferenceOptions {
|
||||
providerId: string;
|
||||
preferredModel?: string;
|
||||
}
|
||||
|
||||
export interface SessionModelPreference {
|
||||
providerId: string;
|
||||
modelId?: string;
|
||||
modelName?: string;
|
||||
}
|
||||
|
||||
interface ProviderInventoryEntryLike {
|
||||
models: Array<{
|
||||
id: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
export function resolveSessionModelPreference({
|
||||
providerId,
|
||||
preferredModel,
|
||||
}: SessionModelPreferenceOptions): SessionModelPreference {
|
||||
if (preferredModel) {
|
||||
return {
|
||||
providerId,
|
||||
modelId: preferredModel,
|
||||
modelName: preferredModel,
|
||||
};
|
||||
}
|
||||
|
||||
const storedModelPreference = getStoredModelPreferenceForProvider(providerId);
|
||||
if (!storedModelPreference) {
|
||||
return { providerId };
|
||||
}
|
||||
|
||||
if (resolveAgentProviderCatalogIdStrict(providerId)) {
|
||||
return {
|
||||
providerId: storedModelPreference.providerId ?? providerId,
|
||||
modelId: storedModelPreference.modelId,
|
||||
modelName: storedModelPreference.modelName,
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
storedModelPreference.providerId &&
|
||||
storedModelPreference.providerId !== providerId
|
||||
) {
|
||||
return { providerId };
|
||||
}
|
||||
|
||||
return {
|
||||
providerId,
|
||||
modelId: storedModelPreference.modelId,
|
||||
modelName: storedModelPreference.modelName,
|
||||
};
|
||||
}
|
||||
|
||||
export function sanitizeSessionModelPreference(
|
||||
preference: SessionModelPreference,
|
||||
inventoryEntry?: ProviderInventoryEntryLike | null,
|
||||
): SessionModelPreference {
|
||||
if (!preference.modelId || !inventoryEntry) {
|
||||
return preference;
|
||||
}
|
||||
|
||||
if (inventoryEntry.models.some((model) => model.id === preference.modelId)) {
|
||||
return preference;
|
||||
}
|
||||
|
||||
return {
|
||||
providerId: preference.providerId,
|
||||
};
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue