From 9f157158770c27b08b3e15d7a9a6dd421a0ba025 Mon Sep 17 00:00:00 2001 From: Frank Date: Thu, 28 May 2026 18:38:07 -0400 Subject: [PATCH] zen: sync --- .../app/src/routes/zen/util/handler.ts | 47 ++++++++++++------- .../src/routes/zen/util/modelTpsLimiter.ts | 9 +--- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/packages/console/app/src/routes/zen/util/handler.ts b/packages/console/app/src/routes/zen/util/handler.ts index 6ddc967097..ed76c16c1a 100644 --- a/packages/console/app/src/routes/zen/util/handler.ts +++ b/packages/console/app/src/routes/zen/util/handler.ts @@ -478,29 +478,24 @@ export async function handler( stickyId: string, trialProviders: string[] | undefined, retry: RetryOptions, - stickyProvider: string | undefined, + stickyProviderId: string | undefined, modelTpmLimits: Record | undefined, - modelTpsLimits: Record | undefined, + modelTpsLimits: Record | undefined, ) { const modelProvider = (() => { - const allProviders = modelInfo.providers.filter((provider) => !provider.disabled) - // Byok is top priority b/c if user set their own API key, we should use it // instead of using the sticky provider for the same session if (authInfo?.provider?.credentials) { - return allProviders.find((provider) => provider.id === modelInfo.byokProvider) - } - - // Always use the same provider for the same session - if (stickyProvider) { - const provider = allProviders.find((provider) => provider.id === stickyProvider) - if (provider) return provider + return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider) } + // Prioritize trial providers + let allProviders = modelInfo.providers.filter((provider) => !provider.disabled) if (trialProviders) { - const trialProvider = trialProviders[Math.floor(Math.random() * trialProviders.length)] - const provider = allProviders.find((provider) => provider.id === trialProvider) - if (provider) return provider + allProviders = allProviders.map((provider) => ({ + ...provider, + priority: trialProviders.includes(provider.id) ? 0 : provider.priority, + })) } if (retry.retryCount !== MAX_FAILOVER_RETRIES) { @@ -515,7 +510,11 @@ export async function handler( }) .filter((provider) => { if (!provider.tpsGoal) return true - const isLowTps = modelTpsLimits?.[`${provider.id}/${provider.model}/${provider.tpsGoal}`] ?? false + const tps = modelTpsLimits?.[`${provider.id}/${provider.model}/${provider.tpsGoal}`] ?? { + qualify: 0, + unqualify: 0, + } + const isLowTps = tps.qualify + tps.unqualify > 10 && tps.qualify < tps.unqualify return !isLowTps }) .map((provider) => { @@ -533,7 +532,23 @@ export async function handler( } const index = (h >>> 0) % providers.length // make unsigned + range 0..length-1 const provider = providers[index || 0] - if (provider) return provider + + // sticky provider does not exist => use selected provider + if (!stickyProviderId) return provider + const stickProvider = allProviders.find((provider) => provider.id === stickyProviderId) + if (!stickProvider) return provider + + // stick provider exists + selected provider is API type => use sticky provider + if (!provider.tpsGoal) return stickProvider + + // stick provier exists + selected provider is GPU type + GPU not idle => use selected provider + const tps = modelTpsLimits?.[`${provider.id}/${provider.model}/${provider.tpsGoal}`] ?? { + qualify: 0, + unqualify: 0, + } + if (tps.qualify <= tps.unqualify * 3) return stickProvider + + return provider } // fallback provider diff --git a/packages/console/app/src/routes/zen/util/modelTpsLimiter.ts b/packages/console/app/src/routes/zen/util/modelTpsLimiter.ts index 477d08ce68..3ff63f7d49 100644 --- a/packages/console/app/src/routes/zen/util/modelTpsLimiter.ts +++ b/packages/console/app/src/routes/zen/util/modelTpsLimiter.ts @@ -37,7 +37,7 @@ export function createModelTpsLimiter(providers: { id: string; model: string; tp ) // convert to map of model to summed count across current and previous intervals - const result = data.reduce( + return data.reduce( (acc, curr) => { const existing = acc[curr.id] ?? { qualify: 0, unqualify: 0 } acc[curr.id] = { @@ -48,13 +48,6 @@ export function createModelTpsLimiter(providers: { id: string; model: string; tp }, {} as Record, ) - - return Object.fromEntries( - Object.entries(result).map(([id, { qualify, unqualify }]) => { - const isLowTps = qualify + unqualify > 10 && qualify < unqualify - return [id, isLowTps] - }), - ) }, track: async ( provider: string,