fix(server-runtime): enforce auth & routing safety, fix lifecycle leaks (#1686)

This commit is contained in:
Iro 2026-04-22 13:03:16 +07:00 committed by GitHub
parent e165829b19
commit 3f829421d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 455 additions and 83 deletions

View file

@ -1 +1 @@
sha256-69tCpJaxRUnyR9CrHlmWJWEWLDpYzyzska7ui9++QoY=
sha256-+ruQJso6gF5n4vdu9xTMiYoTP6q3xRrCVsGrEEggQqc=

View file

@ -16,6 +16,7 @@ import { Buffer } from 'node:buffer'
import { timingSafeEqual } from 'node:crypto'
import { availableLogLevelStrings, Format, LogLevelString, logLevelStringToLogLevelMap, useLogg } from '@guiiai/logg'
import { errorMessageFrom } from '@moeru/std'
import {
createInvalidJsonServerErrorMessage,
ServerErrorMessages,
@ -122,17 +123,61 @@ export interface ConsumerSelectionCandidate {
healthy?: boolean
}
function isConsumerDeliveryMode(mode: unknown): mode is 'consumer' | 'consumer-group' {
return mode === 'consumer' || mode === 'consumer-group'
}
function normalizeConsumerMode(mode: unknown, group?: string): 'consumer' | 'consumer-group' {
if (isConsumerDeliveryMode(mode)) {
return mode
}
return group ? 'consumer-group' : 'consumer'
}
function normalizeConsumerPriority(priority: unknown) {
return typeof priority === 'number' && Number.isFinite(priority)
? priority
: 0
}
// helper send function
function send(peer: Peer, event: WebSocketEvent<Record<string, unknown>> | string) {
peer.send(typeof event === 'string' ? event : stringify(event))
}
/**
* Detects raw websocket heartbeat control frames surfaced as text payloads.
*
* Use when:
* - A websocket runtime forwards ping/pong frames through the normal message callback
* - The runtime should ignore transport heartbeats instead of treating them as protocol JSON
*
* Expects:
* - Raw text payloads such as `ping` and `pong`
*
* Returns:
* - The heartbeat kind when the text is a control frame, otherwise `undefined`
*/
export function detectHeartbeatControlFrame(text: string): MessageHeartbeatKind | undefined {
if (text === MessageHeartbeatKind.Ping || text === MessageHeartbeatKind.Pong) {
return text
}
}
/**
* Resolves the effective delivery configuration for an event.
*
* Use when:
* - Protocol defaults should be merged with route-level overrides
* - Delivery mode selection needs to happen before routing or consumer dispatch
*
* Expects:
* - Route delivery to override protocol metadata field-by-field
*
* Returns:
* - The merged delivery config or `undefined` when the event has no delivery rules
*/
export function resolveDeliveryConfig(event: WebSocketEvent): DeliveryConfig | undefined {
const eventMetadata = getProtocolEventMetadata(event.type)
const defaultDelivery = eventMetadata?.delivery
@ -170,6 +215,19 @@ function sortConsumers(entries: Array<Pick<ConsumerSelectionCandidate, 'peerId'
})
}
/**
* Selects a concrete consumer peer for consumer-style delivery modes.
*
* Use when:
* - An event should be sent to exactly one registered consumer
* - Sticky or round-robin routing needs to be resolved against the live peer registry
*
* Expects:
* - Candidates to already describe the authenticated and health state of each peer
*
* Returns:
* - The selected peer id, or `undefined` when no eligible consumer is available
*/
export function selectConsumerPeerId(options: {
eventType: string
fromPeerId: string
@ -241,6 +299,19 @@ export interface AppOptions {
}
}
/**
* Normalizes logger settings from explicit options and environment variables.
*
* Use when:
* - The runtime should support config-driven and env-driven logging
* - App and websocket logger settings need consistent defaults
*
* Expects:
* - Explicit websocket settings to override app-level defaults
*
* Returns:
* - The resolved app and websocket logger configuration
*/
export function normalizeLoggerConfig(options?: AppOptions) {
const appLogLevel = optionOrEnv(options?.logger?.app?.level, 'LOG_LEVEL', LogLevelString.Log, { validator: (value): value is LogLevelString => availableLogLevelStrings.includes(value as LogLevelString) })
const appLogFormat = optionOrEnv(options?.logger?.app?.format, 'LOG_FORMAT', Format.Pretty, { validator: (value): value is Format => Object.values(Format).includes(value as Format) })
@ -255,7 +326,20 @@ export function normalizeLoggerConfig(options?: AppOptions) {
}
}
export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () => void } {
/**
* Creates the H3 websocket application and its in-memory peer registry.
*
* Use when:
* - Embedding the AIRI websocket runtime inside a server process
* - Spinning up a testable application instance before binding a socket listener
*
* Expects:
* - Caller lifecycle management to invoke `dispose` when the app is no longer needed
*
* Returns:
* - The H3 app plus cleanup helpers for peer shutdown and timer disposal
*/
export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () => void, dispose: () => void } {
const instanceId = options?.instanceId || optionOrEnv(undefined, 'SERVER_INSTANCE_ID', nanoid())
const authToken = optionOrEnv(options?.auth?.token, 'AUTHENTICATION_TOKEN', '')
@ -284,8 +368,41 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
const HEALTH_CHECK_MISSES_UNHEALTHY = 5
const HEALTH_CHECK_MISSES_DEAD = HEALTH_CHECK_MISSES_UNHEALTHY * 2
const healthCheckIntervalMs = Math.max(5_000, Math.floor(heartbeatTtlMs / HEALTH_CHECK_MISSES_UNHEALTHY))
let disposed = false
setInterval(() => {
function broadcastPeerHealthy(peerInfo: AuthenticatedPeer, parentId?: string) {
if (!peerInfo.name || !peerInfo.identity) {
return
}
broadcastToAuthenticated({
type: 'registry:modules:health:healthy',
data: { name: peerInfo.name, index: peerInfo.index, identity: peerInfo.identity },
metadata: createServerEventMetadata(instanceId, parentId),
})
}
function markPeerAlive(peerInfo: AuthenticatedPeer, options?: { parentId?: string, logMessage?: string }) {
peerInfo.lastHeartbeatAt = Date.now()
peerInfo.missedHeartbeats = 0
if (peerInfo.healthy === false && peerInfo.authenticated) {
peerInfo.healthy = true
logger.withFields({ peer: peerInfo.peer.id, peerName: peerInfo.name }).debug(options?.logMessage ?? 'peer activity recovered, marking healthy')
broadcastPeerHealthy(peerInfo, options?.parentId)
}
}
function resetRoutingState() {
peers.clear()
peersByModule.clear()
consumerRegistry.clear()
consumerKeysByPeer.clear()
deliveryRoundRobinCursor.clear()
stickyAssignments.clear()
}
const healthCheckInterval = setInterval(() => {
const now = Date.now()
for (const [id, peerInfo] of peers.entries()) {
if (!peerInfo.lastHeartbeatAt) {
@ -325,6 +442,9 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
}
}
}, healthCheckIntervalMs)
if (typeof healthCheckInterval === 'object') {
healthCheckInterval.unref?.()
}
function registerModulePeer(p: AuthenticatedPeer, name: string, index?: number) {
if (!peersByModule.has(name)) {
@ -361,7 +481,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
event,
group: normalizedGroup,
peerId,
priority: priority ?? 0,
priority: normalizeConsumerPriority(priority),
registeredAt: Date.now(),
})
@ -466,33 +586,45 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
return peers.get(selectedPeerId)
}
function unregisterModulePeer(p: AuthenticatedPeer, reason?: string) {
unregisterPeerConsumers(p.peer.id)
function unregisterModuleRegistration(
peerInfo: AuthenticatedPeer,
options?: { reason?: string, unregisterConsumers?: boolean },
) {
if (options?.unregisterConsumers !== false) {
unregisterPeerConsumers(peerInfo.peer.id)
}
if (!p.name)
if (!peerInfo.name)
return
const group = peersByModule.get(p.name)
const group = peersByModule.get(peerInfo.name)
if (group) {
group.delete(p.index)
group.delete(peerInfo.index)
if (group.size === 0) {
peersByModule.delete(p.name)
peersByModule.delete(peerInfo.name)
}
}
// broadcast module:de-announced to all authenticated peers
if (p.identity) {
if (peerInfo.identity) {
broadcastToAuthenticated({
type: 'module:de-announced',
data: { name: p.name, index: p.index, identity: p.identity, reason },
data: { name: peerInfo.name, index: peerInfo.index, identity: peerInfo.identity, reason: options?.reason },
metadata: createServerEventMetadata(instanceId),
})
}
peerInfo.name = ''
peerInfo.index = undefined
broadcastRegistrySync()
}
function unregisterModulePeer(peerInfo: AuthenticatedPeer, reason?: string) {
unregisterModuleRegistration(peerInfo, { reason })
}
function listKnownModules() {
return Array.from(peers.values())
.filter(peerInfo => peerInfo.name && peerInfo.identity)
@ -553,19 +685,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
// liveness only so they do not leak into the application event protocol.
if (controlFrame) {
if (authenticatedPeer) {
authenticatedPeer.lastHeartbeatAt = Date.now()
authenticatedPeer.missedHeartbeats = 0
if (authenticatedPeer.healthy === false && authenticatedPeer.name && authenticatedPeer.identity) {
authenticatedPeer.healthy = true
logger.withFields({ peer: peer.id, peerName: authenticatedPeer.name })
.debug('ping/pong recovered, marking healthy')
broadcastToAuthenticated({
type: 'registry:modules:health:healthy',
data: { name: authenticatedPeer.name, index: authenticatedPeer.index, identity: authenticatedPeer.identity },
metadata: createServerEventMetadata(instanceId),
})
}
markPeerAlive(authenticatedPeer, { logMessage: 'ping/pong recovered, marking healthy' })
}
return
@ -592,7 +712,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
event = potentialEvent as WebSocketEvent
}
catch (err) {
const errorMessage = err instanceof Error ? err.message : String(err)
const errorMessage = errorMessageFrom(err) ?? 'Unknown JSON parsing error'
send(peer, RESPONSES.error(createInvalidJsonServerErrorMessage(errorMessage), instanceId))
return
@ -606,8 +726,9 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
}).debug('received event')
if (authenticatedPeer) {
authenticatedPeer.lastHeartbeatAt = Date.now()
if (event.metadata?.source) {
markPeerAlive(authenticatedPeer, { parentId: event.metadata?.event.id })
if (authenticatedPeer.authenticated && event.metadata?.source) {
authenticatedPeer.identity = event.metadata.source
}
}
@ -616,19 +737,12 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
case 'transport:connection:heartbeat': {
const p = peers.get(peer.id)
if (p) {
p.lastHeartbeatAt = Date.now()
p.missedHeartbeats = 0
markPeerAlive(p, {
parentId: event.metadata?.event.id,
logMessage: 'heartbeat recovered, marking healthy',
})
// recover from unhealthy → healthy
if (p.healthy === false && p.name && p.identity) {
p.healthy = true
logger.withFields({ peer: peer.id, peerName: p.name }).debug('heartbeat recovered, marking healthy')
broadcastToAuthenticated({
type: 'registry:modules:health:healthy',
data: { name: p.name, index: p.index, identity: p.identity },
metadata: createServerEventMetadata(instanceId, event.metadata?.event.id),
})
}
}
if (event.data.kind === MessageHeartbeatKind.Ping) {
@ -664,9 +778,6 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
return
}
unregisterModulePeer(p, 're-announcing')
// verify
const { name, index, identity } = event.data as { name: string, index?: number, identity?: MetadataEventSource }
if (!name || typeof name !== 'string') {
send(peer, RESPONSES.error(ServerErrorMessages.moduleAnnounceNameInvalid, instanceId))
@ -691,11 +802,14 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
return
}
unregisterModuleRegistration(p, {
reason: 're-announcing',
unregisterConsumers: false,
})
p.name = name
p.index = index
if (identity) {
p.identity = identity
}
p.identity = identity
registerModulePeer(p, name, index)
@ -775,7 +889,13 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
return
}
registerConsumer(peer.id, data.event, data.mode ?? (data.group ? 'consumer-group' : 'consumer'), data.group, data.priority)
registerConsumer(
peer.id,
data.event,
normalizeConsumerMode(data.mode, data.group),
data.group,
normalizeConsumerPriority(data.priority),
)
return
}
@ -797,7 +917,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
return
}
unregisterConsumer(peer.id, data.event, data.mode ?? (data.group ? 'consumer-group' : 'consumer'), data.group)
unregisterConsumer(peer.id, data.event, normalizeConsumerMode(data.mode, data.group), data.group)
return
}
}
@ -816,6 +936,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
const shouldBypass = Boolean(event.route?.bypass && allowBypass && isDevtoolsPeer(p))
const destinations = shouldBypass ? undefined : collectDestinations(event)
const delivery = shouldBypass ? undefined : resolveDeliveryConfig(event)
const effectiveRoutingMiddleware = shouldBypass ? [] : routingMiddleware
const routingContext: RouteContext = {
event,
fromPeer: p,
@ -824,7 +945,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
}
let decision: RouteDecision | undefined
for (const middleware of routingMiddleware) {
for (const middleware of effectiveRoutingMiddleware) {
const result = middleware(routingContext)
if (result) {
decision = result
@ -886,11 +1007,16 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
continue
}
if (!other.authenticated) {
logger.withFields({ fromPeer: peer.id, toPeer: other.peer.id, toPeerName: other.name, event }).debug('not sending event to unauthenticated peer')
continue
}
if (!shouldBroadcast && targetIds && !targetIds.has(id)) {
continue
}
if (shouldBroadcast && destinations && destinations.length > 0 && !matchesDestinations(destinations, other)) {
if (shouldBroadcast && destinations !== undefined && !matchesDestinations(destinations, other)) {
continue
}
@ -913,6 +1039,10 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
close: (peer, details) => {
const p = peers.get(peer.id)
const now = Date.now()
const peerName = p?.name
const peerIndex = p?.index
const peerHealthy = p?.healthy
const peerMissedHeartbeats = p?.missedHeartbeats
const safeDetails = details ?? {}
const closeCode = typeof safeDetails.code === 'number' ? safeDetails.code : undefined
const closeReason = typeof safeDetails.reason === 'string' ? safeDetails.reason : undefined
@ -928,8 +1058,10 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
)
const likelySilentNetworkClose = closeCode === 1005
if (p)
if (p) {
peers.delete(peer.id)
unregisterModulePeer(p, 'connection closed')
}
logger.withFields({
peer: peer.id,
@ -940,10 +1072,10 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
closeWasClean,
activePeers: peers.size,
peerAuthenticated: p?.authenticated,
peerName: p?.name,
peerIndex: p?.index,
peerHealthy: p?.healthy,
peerMissedHeartbeats: p?.missedHeartbeats,
peerName,
peerIndex,
peerHealthy,
peerMissedHeartbeats,
heartbeatLastSeenAt,
heartbeatSilentForMs,
heartbeatTtlMs,
@ -951,20 +1083,36 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
likelyHeartbeatExpiry,
likelySilentNetworkClose,
}).log('closed')
peers.delete(peer.id)
},
}))
function closeAllPeers() {
logger.withFields({ totalPeers: peers.size }).log('closing all peers')
for (const peer of peers.values()) {
for (const peer of Array.from(peers.values())) {
logger.withFields({ peer: peer.peer.id, peerName: peer.name }).debug('closing peer')
peer.peer.close?.()
try {
peer.peer.close?.()
}
catch (error) {
logger.withFields({ peer: peer.peer.id, peerName: peer.name }).withError(error as Error).debug('failed to close peer during shutdown')
}
}
}
function dispose() {
if (disposed) {
return
}
disposed = true
clearInterval(healthCheckInterval)
closeAllPeers()
resetRoutingState()
}
return {
app,
closeAllPeers,
dispose,
}
}

View file

@ -13,6 +13,7 @@ function createPeer(options: {
plugin?: string
instanceId?: string
labels?: Record<string, string>
authenticated?: boolean
}): AuthenticatedPeer {
return {
peer: {
@ -21,7 +22,7 @@ function createPeer(options: {
request: { url: 'http://localhost', headers: new Headers() },
remoteAddress: '127.0.0.1',
},
authenticated: true,
authenticated: options.authenticated ?? true,
name: options.name,
identity: options.plugin && options.instanceId
? { kind: 'plugin', plugin: { id: options.plugin }, id: options.instanceId, labels: options.labels }
@ -57,6 +58,7 @@ describe('match-expression', () => {
expect(matchesLabelSelector('env=prod', { env: 'dev' })).toBe(false)
expect(matchesLabelSelector('feature', { feature: 'on' })).toBe(true)
expect(matchesLabelSelector('missing', { env: 'prod' })).toBe(false)
expect(matchesLabelSelector(' env = prod ', { env: 'prod' })).toBe(true)
})
it('matches label selector list', () => {
@ -167,6 +169,48 @@ describe('route middleware', () => {
expect([...decision!.targetIds]).toEqual(['peer-1'])
})
it('policy middleware excludes unauthenticated peers', () => {
const peers = new Map<string, AuthenticatedPeer>([
['peer-1', createPeer({ id: 'peer-1', name: 'telegram', plugin: 'telegram-bot', instanceId: 'telegram-1', labels: { env: 'prod' } })],
['peer-2', createPeer({ id: 'peer-2', name: 'stage-ui', plugin: 'stage-ui', instanceId: 'stage-ui-1', labels: { env: 'prod' }, authenticated: false })],
])
const policy = createPolicyMiddleware({ allowLabels: ['env=prod'] })
const decision = policy({
event: createSparkNotifyEvent(),
fromPeer: peers.get('peer-1')!,
peers,
destinations: undefined,
})
expect(decision).toBeDefined()
if (!decision || decision.type !== 'targets')
return
expect([...decision.targetIds]).toEqual(['peer-1'])
})
it('policy middleware does not authorize bypass by itself', () => {
const peers = new Map<string, AuthenticatedPeer>([
['peer-1', createPeer({ id: 'peer-1', name: 'telegram', plugin: 'telegram-bot', instanceId: 'telegram-1', labels: { env: 'prod' } })],
['peer-2', createPeer({ id: 'peer-2', name: 'stage-ui', plugin: 'stage-ui', instanceId: 'stage-ui-1', labels: { env: 'dev' } })],
])
const policy = createPolicyMiddleware({ allowLabels: ['env=prod'] })
const decision = policy({
event: createSparkNotifyEvent({ route: { bypass: true } }),
fromPeer: peers.get('peer-1')!,
peers,
destinations: undefined,
})
expect(decision).toBeDefined()
if (!decision || decision.type !== 'targets')
return
expect([...decision.targetIds]).toEqual(['peer-1'])
})
it('devtools peer detection uses label', () => {
const peer = createPeer({
id: 'peer-3',

View file

@ -32,13 +32,43 @@ function getPeerLabels(peer: AuthenticatedPeer) {
}
}
/**
* Detects whether a peer should be treated as a trusted devtools sender.
*
* Use when:
* - Checking whether route bypass is allowed for a peer
* - Applying devtools-only routing affordances
*
* Expects:
* - Peer labels to be sourced from authenticated identity metadata
*
* Returns:
* - `true` when the peer declares a devtools label or uses a devtools module name
*/
export function isDevtoolsPeer(peer: AuthenticatedPeer) {
const devtoolsLabel = getPeerLabels(peer).devtools
const isDevtoolsLabel = devtoolsLabel === 'true' || devtoolsLabel === '1'
return Boolean(isDevtoolsLabel || peer.name.includes('devtools'))
}
/**
* Evaluates whether a peer is allowed by the active routing policy.
*
* Use when:
* - Building a target list from the connected peer registry
* - Enforcing allow/deny lists before broadcasting an event
*
* Expects:
* - Unauthenticated peers must never be considered routable targets
*
* Returns:
* - `true` when the peer is authenticated and satisfies all policy constraints
*/
export function peerMatchesPolicy(peer: AuthenticatedPeer, policy: RoutingPolicy) {
if (!peer.authenticated) {
return false
}
const pluginId = peer.identity?.plugin?.id ?? ''
if (policy.allowPlugins?.length && !policy.allowPlugins.includes(pluginId)) {
@ -61,12 +91,21 @@ export function peerMatchesPolicy(peer: AuthenticatedPeer, policy: RoutingPolicy
return true
}
/**
* Creates a routing middleware from a static allow/deny policy.
*
* Use when:
* - Server-wide routing rules should be applied consistently
* - Destination filtering should be derived from peer metadata instead of event payloads
*
* Expects:
* - Route bypass authorization to be handled by the caller, not by the policy itself
*
* Returns:
* - A middleware that narrows delivery to the peers allowed by the policy
*/
export function createPolicyMiddleware(policy: RoutingPolicy): RouteMiddleware {
return ({ event, peers }) => {
if (event.route?.bypass) {
return
}
return ({ peers }) => {
const targetIds = new Set<string>()
for (const [id, peer] of peers.entries()) {
if (peerMatchesPolicy(peer, policy)) {
@ -78,6 +117,19 @@ export function createPolicyMiddleware(policy: RoutingPolicy): RouteMiddleware {
}
}
/**
* Resolves the destinations attached to an event.
*
* Use when:
* - Route-level destinations should override payload-level destinations
* - Delivery logic needs to distinguish between "broadcast" and "explicitly send nowhere"
*
* Expects:
* - An explicit empty `route.destinations` array is a meaningful override
*
* Returns:
* - The route destinations, payload destinations, or `undefined` when the event is unrestricted
*/
export function collectDestinations(event: WebSocketEvent | (Omit<WebSocketEvent, 'metadata'> & Partial<Pick<WebSocketEvent, 'metadata'>>)) {
if (event.route && 'destinations' in event.route) {
return event.route.destinations

View file

@ -18,7 +18,10 @@ function matchesGlob(glob: string, value?: string) {
}
export function matchesLabelSelector(selector: string, labels: Record<string, string>) {
const [key, value] = selector.split('=', 2)
const [rawKey, rawValue] = selector.split('=', 2)
const key = rawKey?.trim()
const value = rawValue?.trim()
if (!key) {
return false
}

View file

@ -1,3 +1,4 @@
import { Format, LogLevelString } from '@guiiai/logg'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const serveMocks = vi.hoisted(() => {
@ -10,15 +11,18 @@ const serveMocks = vi.hoisted(() => {
}))
const closeCall = vi.fn(async () => {})
const disposeCall = vi.fn(() => {})
const setupAppCall = vi.fn(() => ({
app: {
fetch: vi.fn(async () => ({ crossws: {} })),
},
closeAllPeers: vi.fn(),
dispose: disposeCall,
}))
return {
closeCall,
disposeCall,
rejectServe: (error: Error) => rejectServe?.(error),
resolveServe: () => resolveServe?.(),
serveCall,
@ -77,6 +81,7 @@ describe('createServer', async () => {
serveMocks.rejectServe(new Error('bind failed'))
await expect(firstStart).rejects.toThrow('bind failed')
expect(serveMocks.disposeCall).toHaveBeenCalledTimes(1)
const retryStart = server.start()
expect(serveMocks.serveCall).toHaveBeenCalledTimes(2)
@ -84,4 +89,37 @@ describe('createServer', async () => {
serveMocks.resolveServe()
await retryStart
})
it('merges nested config updates instead of replacing sibling logger settings', async () => {
const server = createServer({
hostname: '127.0.0.1',
port: 6121,
logger: {
app: { level: LogLevelString.Log },
websocket: { format: Format.Pretty },
},
})
server.updateConfig({
logger: {
app: { format: Format.Pretty },
},
})
const startTask = server.start()
serveMocks.resolveServe()
await startTask
expect(serveMocks.setupAppCall).toHaveBeenCalledWith(expect.objectContaining({
logger: {
app: {
level: LogLevelString.Log,
format: Format.Pretty,
},
websocket: {
format: Format.Pretty,
},
},
}))
})
})

View file

@ -32,9 +32,22 @@ export interface Server {
updateConfig: (newOptions: ServerOptions) => void
}
/**
* Collects local IP addresses that can be used to reach the server from the LAN.
*
* Use when:
* - Building connection hints for `0.0.0.0` listeners
* - Showing reachable addresses in logs or UI
*
* Expects:
* - Virtual interfaces should be ignored to reduce noisy or misleading addresses
*
* Returns:
* - A de-duplicated list of valid IP addresses discovered from the host network interfaces
*/
export function getLocalIPs(): string[] {
const interfaces = networkInterfaces()
const addresses: string[] = []
const addresses = new Set<string>()
const VIRTUAL_INTERFACE_PREFIXES = [
'vboxnet',
@ -63,13 +76,26 @@ export function getLocalIPs(): string[] {
const address = rawAddress.includes('%') ? rawAddress.split('%')[0] : rawAddress
if (isIP(address))
addresses.push(address)
addresses.add(address)
}
}
return addresses
return [...addresses]
}
/**
* Creates the websocket server controller for the AIRI runtime.
*
* Use when:
* - Starting, stopping, or restarting the standalone runtime server
* - Updating bind options between restarts
*
* Expects:
* - The returned controller to manage a single active server instance at a time
*
* Returns:
* - Lifecycle helpers for starting, stopping, restarting, and updating server options
*/
export function createServer(opts?: ServerOptions): Server {
let options = merge<ServerOptions>({ port: 6121, hostname: '127.0.0.1' }, opts)
@ -140,8 +166,7 @@ export function createServer(opts?: ServerOptions): Server {
try {
serverInstance = {
close: async (closeActiveConnections = false) => {
log.log('closing all peers')
h3App.closeAllPeers()
h3App.dispose()
log.log('closing server instance')
await instance.close(closeActiveConnections)
log.log('server instance closed')
@ -162,7 +187,7 @@ export function createServer(opts?: ServerOptions): Server {
}
catch (error) {
serverInstance = null
h3App.closeAllPeers()
h3App.dispose()
await instance.close(true).catch(() => {})
log.withError(error).error('failed to start WebSocket server')
throw error
@ -183,12 +208,16 @@ export function createServer(opts?: ServerOptions): Server {
await start()
}
async function updateConfig(newOptions: ServerOptions) {
options = { ...options, ...newOptions }
function updateConfig(newOptions: ServerOptions) {
options = merge<ServerOptions>(options, newOptions)
}
return {
getConnectionHost: () => {
if (options.hostname && options.hostname !== '0.0.0.0' && options.hostname !== '::') {
return [options.hostname]
}
return getLocalIPs()
},
start,

View file

@ -1,7 +1,7 @@
<script setup lang="ts">
import type { SpeechProvider } from '@xsai-ext/providers/utils'
import { getCachedWebGPUCapabilities } from '@proj-airi/stage-shared/webgpu'
import { getCachedWebGPUCapabilities, hasNavigatorWebGPU } from '@proj-airi/stage-shared/webgpu'
import {
SpeechPlayground,
SpeechProviderSettings,
@ -104,7 +104,7 @@ onMounted(async () => {
// NOTICE: Uses synchronous check for initial render. The cached result from
// detectWebGPU() is populated by the providers store during initialization.
const capabilities = getCachedWebGPUCapabilities()
hasWebGPU.value = capabilities?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
hasWebGPU.value = capabilities?.supported ?? hasNavigatorWebGPU()
fp16Supported.value = capabilities?.fp16Supported ?? false
try {

View file

@ -7,6 +7,14 @@
import { check as gpuuCheck, isWebGPUSupported as gpuuIsSupported } from 'gpuu/webgpu'
interface NavigatorWebGPU {
requestAdapter: () => Promise<unknown>
}
interface NavigatorWithOptionalWebGPU extends Navigator {
gpu?: NavigatorWebGPU
}
export interface WebGPUCapabilities {
/** Whether WebGPU is available in this environment */
supported: boolean
@ -21,6 +29,50 @@ export interface WebGPUCapabilities {
let cachedResult: WebGPUCapabilities | null = null
let pendingDetection: Promise<WebGPUCapabilities> | null = null
/**
* Returns the WebGPU navigator entry point when the current runtime exposes it.
*
* Use when:
* - browser or worker code needs guarded access to WebGPU
* - a caller only needs the `navigator.gpu` entry point, not a full capability probe
*
* Expects:
* - browser-like runtimes where `navigator` may be unavailable during SSR or tests
*
* Returns:
* - the WebGPU navigator entry point when available, otherwise `null`
*/
export function getNavigatorWebGPU(): NavigatorWebGPU | null {
// NOTICE:
// TypeScript's default DOM libs in this repo do not declare `Navigator.gpu`.
// Direct `navigator.gpu` access fails in consumers like `@proj-airi/ui-server-auth`
// and in `@proj-airi/stage-ui` worker compilation even though the runtime guard is valid.
// We centralize the structural cast here until the repo opts into WebGPU ambient types.
// Removal condition: remove this helper once the workspace TypeScript config includes
// WebGPU navigator typings everywhere that imports these modules.
if (typeof navigator === 'undefined' || !('gpu' in navigator))
return null
return (navigator as NavigatorWithOptionalWebGPU).gpu ?? null
}
/**
* Returns whether the current runtime exposes a WebGPU entry point on `navigator`.
*
* Use when:
* - synchronous code only needs a fast boolean feature probe
* - cached capability detection has not completed yet
*
* Expects:
* - browser-like runtimes where `navigator` may be unavailable
*
* Returns:
* - `true` when `navigator.gpu` is available, otherwise `false`
*/
export function hasNavigatorWebGPU(): boolean {
return getNavigatorWebGPU() != null
}
/**
* Detect WebGPU capabilities. The result is cached as a singleton
* after the first successful call -- safe to call repeatedly.

View file

@ -1,6 +1,8 @@
export {
detectWebGPU,
getCachedWebGPUCapabilities,
getNavigatorWebGPU,
hasNavigatorWebGPU,
isWebGPUSupported,
resetWebGPUCache,
} from './detect'

View file

@ -31,6 +31,7 @@ import {
TextStreamer,
WhisperForConditionalGeneration,
} from '@huggingface/transformers'
import { getNavigatorWebGPU } from '@proj-airi/stage-shared/webgpu'
import { MODEL_IDS, MODEL_NAMES } from '../inference/constants'
import { classifyError, isRecoverable } from '../inference/protocol'
@ -75,9 +76,10 @@ const MODEL_ID = MODEL_IDS.WHISPER
*/
async function detectWebGPUInWorker(): Promise<boolean> {
try {
if (typeof navigator === 'undefined' || !navigator.gpu)
const webgpu = getNavigatorWebGPU()
if (!webgpu)
return false
const adapter = await navigator.gpu.requestAdapter()
const adapter = await webgpu.requestAdapter()
return adapter != null
}
catch {

View file

@ -21,7 +21,7 @@ import type {
import type { AliyunRealtimeSpeechExtraOptions } from './providers/aliyun/stream-transcription'
import { isStageTamagotchi, isUrl } from '@proj-airi/stage-shared'
import { getCachedWebGPUCapabilities, isWebGPUSupported } from '@proj-airi/stage-shared/webgpu'
import { getCachedWebGPUCapabilities, hasNavigatorWebGPU, isWebGPUSupported } from '@proj-airi/stage-shared/webgpu'
import { computedAsync, useIntervalFn, useLocalStorage } from '@vueuse/core'
import {
createOpenAI,
@ -1715,7 +1715,7 @@ export const useProvidersStore = defineStore('providers', () => {
defaultOptions: () => {
const capabilities = getCachedWebGPUCapabilities()
const hasWebGPU = capabilities?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
const hasWebGPU = capabilities?.supported ?? hasNavigatorWebGPU()
const fp16Supported = capabilities?.fp16Supported ?? false
const model = getDefaultKokoroModel(hasWebGPU, fp16Supported)
return {
@ -1772,7 +1772,7 @@ export const useProvidersStore = defineStore('providers', () => {
capabilities: {
listModels: async (_config: Record<string, unknown>) => {
const caps = getCachedWebGPUCapabilities()
const hasWebGPU = caps?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
const hasWebGPU = caps?.supported ?? hasNavigatorWebGPU()
const fp16Supported = caps?.fp16Supported ?? false
return kokoroModelsToModelInfo(hasWebGPU, t, fp16Supported)
},
@ -1791,7 +1791,7 @@ export const useProvidersStore = defineStore('providers', () => {
// Validate platform requirements
if (modelDef.platform === 'webgpu') {
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? hasNavigatorWebGPU()
if (!hasWebGPU) {
throw new Error('WebGPU is required for this model but is not available in your browser')
}
@ -1835,7 +1835,7 @@ export const useProvidersStore = defineStore('providers', () => {
const modelDef = KOKORO_MODELS.find(m => m.id === modelId)
if (modelDef) {
if (modelDef.platform === 'webgpu') {
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? hasNavigatorWebGPU()
if (!hasWebGPU) {
throw new Error('WebGPU is required for this model but is not available in your browser')
}

View file

@ -18,6 +18,7 @@ import type {
} from '../../libs/inference/protocol'
import { AutoModel, AutoProcessor, env, RawImage } from '@huggingface/transformers'
import { getNavigatorWebGPU } from '@proj-airi/stage-shared/webgpu'
import { MODEL_IDS, MODEL_NAMES } from '../../libs/inference/constants'
import { classifyError, isRecoverable } from '../../libs/inference/protocol'
@ -80,9 +81,10 @@ function sendError(requestId: string, error: unknown, phase?: 'load' | 'inferenc
*/
async function detectWebGPUInWorker(): Promise<boolean> {
try {
if (typeof navigator === 'undefined' || !navigator.gpu)
const webgpu = getNavigatorWebGPU()
if (!webgpu)
return false
const adapter = await navigator.gpu.requestAdapter()
const adapter = await webgpu.requestAdapter()
return adapter != null
}
catch {