zen: tpm routing

This commit is contained in:
Frank 2026-04-20 22:21:06 -04:00
parent 3e8abac625
commit f74a255ca9
5 changed files with 2746 additions and 29 deletions

View file

@ -448,31 +448,40 @@ export async function handler(
return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider)
}
// Filter out TPM limited providers
const allProviders = modelInfo.providers.filter((provider) => {
if (!provider.tpmLimit) return true
const usage = modelTpmLimits?.[`${provider.id}/${provider.model}`] ?? 0
return usage < provider.tpmLimit * 1_000_000
})
// Always use the same provider for the same session
if (stickyProvider) {
const provider = modelInfo.providers.find((provider) => provider.id === stickyProvider)
const provider = allProviders.find((provider) => provider.id === stickyProvider)
if (provider) return provider
}
if (trialProviders) {
const trialProvider = trialProviders[Math.floor(Math.random() * trialProviders.length)]
const provider = modelInfo.providers.find((provider) => provider.id === trialProvider)
const provider = allProviders.find((provider) => provider.id === trialProvider)
if (provider) return provider
}
if (retry.retryCount !== MAX_FAILOVER_RETRIES) {
const allProviders = modelInfo.providers
let topPriority = Infinity
const providers = allProviders
.filter((provider) => !provider.disabled)
.filter((provider) => provider.weight !== 0)
.filter((provider) => !retry.excludeProviders.includes(provider.id))
.filter((provider) => {
if (!provider.tpmLimit) return true
const usage = modelTpmLimits?.[`${provider.id}/${provider.model}`] ?? 0
return usage < provider.tpmLimit * 1_000_000
return usage < provider.tpmLimit * 1_000_000 * 0.8
})
.map((provider) => {
topPriority = Math.min(topPriority, provider.priority)
return provider
})
const topPriority = Math.min(...allProviders.map((p) => p.priority))
const providers = allProviders
.filter((p) => p.priority <= topPriority)
.flatMap((provider) => Array<typeof provider>(provider.weight).fill(provider))

View file

@ -1,28 +1,25 @@
import { and, Database, eq, inArray, sql } from "@opencode-ai/console-core/drizzle/index.js"
import { ModelTpmLimitTable } from "@opencode-ai/console-core/schema/ip.sql.js"
import { ModelTpmRateLimitTable } from "@opencode-ai/console-core/schema/ip.sql.js"
import { UsageInfo } from "./provider/provider"
export function createModelTpmLimiter(providers: { id: string; model: string; tpmLimit?: number }[]) {
const ids = providers.filter((p) => p.tpmLimit).map((p) => `${p.id}/${p.model}`)
if (ids.length === 0) return
const yyyyMMddHHmm = new Date(Date.now())
.toISOString()
.replace(/[^0-9]/g, "")
.substring(0, 12)
const yyyyMMddHHmm = parseInt(
new Date(Date.now())
.toISOString()
.replace(/[^0-9]/g, "")
.substring(0, 12),
)
return {
check: async () => {
const data = await Database.use((tx) =>
tx
.select()
.from(ModelTpmLimitTable)
.where(
inArray(
ModelTpmLimitTable.id,
ids.map((id) => formatId(id, yyyyMMddHHmm)),
),
),
.from(ModelTpmRateLimitTable)
.where(and(inArray(ModelTpmRateLimitTable.id, ids), eq(ModelTpmRateLimitTable.interval, yyyyMMddHHmm))),
)
// convert to map of model to count
@ -41,14 +38,10 @@ export function createModelTpmLimiter(providers: { id: string; model: string; tp
if (usage <= 0) return
await Database.use((tx) =>
tx
.insert(ModelTpmLimitTable)
.values({ id: formatId(id, yyyyMMddHHmm), count: usage })
.onDuplicateKeyUpdate({ set: { count: sql`${ModelTpmLimitTable.count} + ${usage}` } }),
.insert(ModelTpmRateLimitTable)
.values({ id, interval: yyyyMMddHHmm, count: usage })
.onDuplicateKeyUpdate({ set: { count: sql`${ModelTpmRateLimitTable.count} + ${usage}` } }),
)
},
}
function formatId(id: string, yyyyMMddHHmm: string) {
return `${id.substring(0, 200)}/${yyyyMMddHHmm}`
}
}

View file

@ -0,0 +1,5 @@
-- Per-interval token usage counters used for TPM (tokens-per-minute) limiting.
--   id       : limiter key; the writer uses "<provider id>/<model>".
--   interval : time bucket stored as a number; the writer uses yyyyMMddHHmm.
--   count    : tokens accumulated for that key within that bucket.
--
-- The primary key must cover BOTH `id` and `interval`. The writer upserts with
-- INSERT ... ON DUPLICATE KEY UPDATE count = count + <usage>, and the reader
-- filters on `interval` = <current bucket>. With a key on `id` alone, every
-- write after the first bucket collides with the existing row: `count` keeps
-- growing, `interval` never advances, and the reader finds no row for the
-- current bucket, so the limit is never applied.
CREATE TABLE `model_tpm_rate_limit` (
	`id` varchar(255) NOT NULL,
	`interval` bigint NOT NULL,
	`count` int NOT NULL,
	PRIMARY KEY (`id`, `interval`)
);

View file

@ -1,4 +1,4 @@
import { mysqlTable, int, primaryKey, varchar } from "drizzle-orm/mysql-core"
import { mysqlTable, int, primaryKey, varchar, bigint } from "drizzle-orm/mysql-core"
import { timestamps } from "../drizzle/types"
export const IpTable = mysqlTable(
@ -31,10 +31,11 @@ export const KeyRateLimitTable = mysqlTable(
(table) => [primaryKey({ columns: [table.key, table.interval] })],
)
export const ModelTpmLimitTable = mysqlTable(
"model_tpm_limit",
export const ModelTpmRateLimitTable = mysqlTable(
"model_tpm_rate_limit",
{
id: varchar("id", { length: 255 }).notNull(),
interval: bigint("interval", { mode: "number" }).notNull(),
count: int("count").notNull(),
},
(table) => [primaryKey({ columns: [table.id] })],