mirror of
https://github.com/moeru-ai/airi.git
synced 2026-05-19 08:10:45 +00:00
fix(server-runtime): enforce auth & routing safety, fix lifecycle leaks (#1686)
This commit is contained in:
parent
e165829b19
commit
3f829421d3
13 changed files with 455 additions and 83 deletions
|
|
@ -1 +1 @@
|
|||
sha256-69tCpJaxRUnyR9CrHlmWJWEWLDpYzyzska7ui9++QoY=
|
||||
sha256-+ruQJso6gF5n4vdu9xTMiYoTP6q3xRrCVsGrEEggQqc=
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import { Buffer } from 'node:buffer'
|
|||
import { timingSafeEqual } from 'node:crypto'
|
||||
|
||||
import { availableLogLevelStrings, Format, LogLevelString, logLevelStringToLogLevelMap, useLogg } from '@guiiai/logg'
|
||||
import { errorMessageFrom } from '@moeru/std'
|
||||
import {
|
||||
createInvalidJsonServerErrorMessage,
|
||||
ServerErrorMessages,
|
||||
|
|
@ -122,17 +123,61 @@ export interface ConsumerSelectionCandidate {
|
|||
healthy?: boolean
|
||||
}
|
||||
|
||||
function isConsumerDeliveryMode(mode: unknown): mode is 'consumer' | 'consumer-group' {
|
||||
return mode === 'consumer' || mode === 'consumer-group'
|
||||
}
|
||||
|
||||
function normalizeConsumerMode(mode: unknown, group?: string): 'consumer' | 'consumer-group' {
|
||||
if (isConsumerDeliveryMode(mode)) {
|
||||
return mode
|
||||
}
|
||||
|
||||
return group ? 'consumer-group' : 'consumer'
|
||||
}
|
||||
|
||||
function normalizeConsumerPriority(priority: unknown) {
|
||||
return typeof priority === 'number' && Number.isFinite(priority)
|
||||
? priority
|
||||
: 0
|
||||
}
|
||||
|
||||
// helper send function
|
||||
function send(peer: Peer, event: WebSocketEvent<Record<string, unknown>> | string) {
|
||||
peer.send(typeof event === 'string' ? event : stringify(event))
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects raw websocket heartbeat control frames surfaced as text payloads.
|
||||
*
|
||||
* Use when:
|
||||
* - A websocket runtime forwards ping/pong frames through the normal message callback
|
||||
* - The runtime should ignore transport heartbeats instead of treating them as protocol JSON
|
||||
*
|
||||
* Expects:
|
||||
* - Raw text payloads such as `ping` and `pong`
|
||||
*
|
||||
* Returns:
|
||||
* - The heartbeat kind when the text is a control frame, otherwise `undefined`
|
||||
*/
|
||||
export function detectHeartbeatControlFrame(text: string): MessageHeartbeatKind | undefined {
|
||||
if (text === MessageHeartbeatKind.Ping || text === MessageHeartbeatKind.Pong) {
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves the effective delivery configuration for an event.
|
||||
*
|
||||
* Use when:
|
||||
* - Protocol defaults should be merged with route-level overrides
|
||||
* - Delivery mode selection needs to happen before routing or consumer dispatch
|
||||
*
|
||||
* Expects:
|
||||
* - Route delivery to override protocol metadata field-by-field
|
||||
*
|
||||
* Returns:
|
||||
* - The merged delivery config or `undefined` when the event has no delivery rules
|
||||
*/
|
||||
export function resolveDeliveryConfig(event: WebSocketEvent): DeliveryConfig | undefined {
|
||||
const eventMetadata = getProtocolEventMetadata(event.type)
|
||||
const defaultDelivery = eventMetadata?.delivery
|
||||
|
|
@ -170,6 +215,19 @@ function sortConsumers(entries: Array<Pick<ConsumerSelectionCandidate, 'peerId'
|
|||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects a concrete consumer peer for consumer-style delivery modes.
|
||||
*
|
||||
* Use when:
|
||||
* - An event should be sent to exactly one registered consumer
|
||||
* - Sticky or round-robin routing needs to be resolved against the live peer registry
|
||||
*
|
||||
* Expects:
|
||||
* - Candidates to already describe the authenticated and health state of each peer
|
||||
*
|
||||
* Returns:
|
||||
* - The selected peer id, or `undefined` when no eligible consumer is available
|
||||
*/
|
||||
export function selectConsumerPeerId(options: {
|
||||
eventType: string
|
||||
fromPeerId: string
|
||||
|
|
@ -241,6 +299,19 @@ export interface AppOptions {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes logger settings from explicit options and environment variables.
|
||||
*
|
||||
* Use when:
|
||||
* - The runtime should support config-driven and env-driven logging
|
||||
* - App and websocket logger settings need consistent defaults
|
||||
*
|
||||
* Expects:
|
||||
* - Explicit websocket settings to override app-level defaults
|
||||
*
|
||||
* Returns:
|
||||
* - The resolved app and websocket logger configuration
|
||||
*/
|
||||
export function normalizeLoggerConfig(options?: AppOptions) {
|
||||
const appLogLevel = optionOrEnv(options?.logger?.app?.level, 'LOG_LEVEL', LogLevelString.Log, { validator: (value): value is LogLevelString => availableLogLevelStrings.includes(value as LogLevelString) })
|
||||
const appLogFormat = optionOrEnv(options?.logger?.app?.format, 'LOG_FORMAT', Format.Pretty, { validator: (value): value is Format => Object.values(Format).includes(value as Format) })
|
||||
|
|
@ -255,7 +326,20 @@ export function normalizeLoggerConfig(options?: AppOptions) {
|
|||
}
|
||||
}
|
||||
|
||||
export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () => void } {
|
||||
/**
|
||||
* Creates the H3 websocket application and its in-memory peer registry.
|
||||
*
|
||||
* Use when:
|
||||
* - Embedding the AIRI websocket runtime inside a server process
|
||||
* - Spinning up a testable application instance before binding a socket listener
|
||||
*
|
||||
* Expects:
|
||||
* - Caller lifecycle management to invoke `dispose` when the app is no longer needed
|
||||
*
|
||||
* Returns:
|
||||
* - The H3 app plus cleanup helpers for peer shutdown and timer disposal
|
||||
*/
|
||||
export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () => void, dispose: () => void } {
|
||||
const instanceId = options?.instanceId || optionOrEnv(undefined, 'SERVER_INSTANCE_ID', nanoid())
|
||||
const authToken = optionOrEnv(options?.auth?.token, 'AUTHENTICATION_TOKEN', '')
|
||||
|
||||
|
|
@ -284,8 +368,41 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
const HEALTH_CHECK_MISSES_UNHEALTHY = 5
|
||||
const HEALTH_CHECK_MISSES_DEAD = HEALTH_CHECK_MISSES_UNHEALTHY * 2
|
||||
const healthCheckIntervalMs = Math.max(5_000, Math.floor(heartbeatTtlMs / HEALTH_CHECK_MISSES_UNHEALTHY))
|
||||
let disposed = false
|
||||
|
||||
setInterval(() => {
|
||||
function broadcastPeerHealthy(peerInfo: AuthenticatedPeer, parentId?: string) {
|
||||
if (!peerInfo.name || !peerInfo.identity) {
|
||||
return
|
||||
}
|
||||
|
||||
broadcastToAuthenticated({
|
||||
type: 'registry:modules:health:healthy',
|
||||
data: { name: peerInfo.name, index: peerInfo.index, identity: peerInfo.identity },
|
||||
metadata: createServerEventMetadata(instanceId, parentId),
|
||||
})
|
||||
}
|
||||
|
||||
function markPeerAlive(peerInfo: AuthenticatedPeer, options?: { parentId?: string, logMessage?: string }) {
|
||||
peerInfo.lastHeartbeatAt = Date.now()
|
||||
peerInfo.missedHeartbeats = 0
|
||||
|
||||
if (peerInfo.healthy === false && peerInfo.authenticated) {
|
||||
peerInfo.healthy = true
|
||||
logger.withFields({ peer: peerInfo.peer.id, peerName: peerInfo.name }).debug(options?.logMessage ?? 'peer activity recovered, marking healthy')
|
||||
broadcastPeerHealthy(peerInfo, options?.parentId)
|
||||
}
|
||||
}
|
||||
|
||||
function resetRoutingState() {
|
||||
peers.clear()
|
||||
peersByModule.clear()
|
||||
consumerRegistry.clear()
|
||||
consumerKeysByPeer.clear()
|
||||
deliveryRoundRobinCursor.clear()
|
||||
stickyAssignments.clear()
|
||||
}
|
||||
|
||||
const healthCheckInterval = setInterval(() => {
|
||||
const now = Date.now()
|
||||
for (const [id, peerInfo] of peers.entries()) {
|
||||
if (!peerInfo.lastHeartbeatAt) {
|
||||
|
|
@ -325,6 +442,9 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
}
|
||||
}
|
||||
}, healthCheckIntervalMs)
|
||||
if (typeof healthCheckInterval === 'object') {
|
||||
healthCheckInterval.unref?.()
|
||||
}
|
||||
|
||||
function registerModulePeer(p: AuthenticatedPeer, name: string, index?: number) {
|
||||
if (!peersByModule.has(name)) {
|
||||
|
|
@ -361,7 +481,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
event,
|
||||
group: normalizedGroup,
|
||||
peerId,
|
||||
priority: priority ?? 0,
|
||||
priority: normalizeConsumerPriority(priority),
|
||||
registeredAt: Date.now(),
|
||||
})
|
||||
|
||||
|
|
@ -466,33 +586,45 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
return peers.get(selectedPeerId)
|
||||
}
|
||||
|
||||
function unregisterModulePeer(p: AuthenticatedPeer, reason?: string) {
|
||||
unregisterPeerConsumers(p.peer.id)
|
||||
function unregisterModuleRegistration(
|
||||
peerInfo: AuthenticatedPeer,
|
||||
options?: { reason?: string, unregisterConsumers?: boolean },
|
||||
) {
|
||||
if (options?.unregisterConsumers !== false) {
|
||||
unregisterPeerConsumers(peerInfo.peer.id)
|
||||
}
|
||||
|
||||
if (!p.name)
|
||||
if (!peerInfo.name)
|
||||
return
|
||||
|
||||
const group = peersByModule.get(p.name)
|
||||
const group = peersByModule.get(peerInfo.name)
|
||||
if (group) {
|
||||
group.delete(p.index)
|
||||
group.delete(peerInfo.index)
|
||||
|
||||
if (group.size === 0) {
|
||||
peersByModule.delete(p.name)
|
||||
peersByModule.delete(peerInfo.name)
|
||||
}
|
||||
}
|
||||
|
||||
// broadcast module:de-announced to all authenticated peers
|
||||
if (p.identity) {
|
||||
if (peerInfo.identity) {
|
||||
broadcastToAuthenticated({
|
||||
type: 'module:de-announced',
|
||||
data: { name: p.name, index: p.index, identity: p.identity, reason },
|
||||
data: { name: peerInfo.name, index: peerInfo.index, identity: peerInfo.identity, reason: options?.reason },
|
||||
metadata: createServerEventMetadata(instanceId),
|
||||
})
|
||||
}
|
||||
|
||||
peerInfo.name = ''
|
||||
peerInfo.index = undefined
|
||||
|
||||
broadcastRegistrySync()
|
||||
}
|
||||
|
||||
function unregisterModulePeer(peerInfo: AuthenticatedPeer, reason?: string) {
|
||||
unregisterModuleRegistration(peerInfo, { reason })
|
||||
}
|
||||
|
||||
function listKnownModules() {
|
||||
return Array.from(peers.values())
|
||||
.filter(peerInfo => peerInfo.name && peerInfo.identity)
|
||||
|
|
@ -553,19 +685,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
// liveness only so they do not leak into the application event protocol.
|
||||
if (controlFrame) {
|
||||
if (authenticatedPeer) {
|
||||
authenticatedPeer.lastHeartbeatAt = Date.now()
|
||||
authenticatedPeer.missedHeartbeats = 0
|
||||
|
||||
if (authenticatedPeer.healthy === false && authenticatedPeer.name && authenticatedPeer.identity) {
|
||||
authenticatedPeer.healthy = true
|
||||
logger.withFields({ peer: peer.id, peerName: authenticatedPeer.name })
|
||||
.debug('ping/pong recovered, marking healthy')
|
||||
broadcastToAuthenticated({
|
||||
type: 'registry:modules:health:healthy',
|
||||
data: { name: authenticatedPeer.name, index: authenticatedPeer.index, identity: authenticatedPeer.identity },
|
||||
metadata: createServerEventMetadata(instanceId),
|
||||
})
|
||||
}
|
||||
markPeerAlive(authenticatedPeer, { logMessage: 'ping/pong recovered, marking healthy' })
|
||||
}
|
||||
|
||||
return
|
||||
|
|
@ -592,7 +712,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
event = potentialEvent as WebSocketEvent
|
||||
}
|
||||
catch (err) {
|
||||
const errorMessage = err instanceof Error ? err.message : String(err)
|
||||
const errorMessage = errorMessageFrom(err) ?? 'Unknown JSON parsing error'
|
||||
send(peer, RESPONSES.error(createInvalidJsonServerErrorMessage(errorMessage), instanceId))
|
||||
|
||||
return
|
||||
|
|
@ -606,8 +726,9 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
}).debug('received event')
|
||||
|
||||
if (authenticatedPeer) {
|
||||
authenticatedPeer.lastHeartbeatAt = Date.now()
|
||||
if (event.metadata?.source) {
|
||||
markPeerAlive(authenticatedPeer, { parentId: event.metadata?.event.id })
|
||||
|
||||
if (authenticatedPeer.authenticated && event.metadata?.source) {
|
||||
authenticatedPeer.identity = event.metadata.source
|
||||
}
|
||||
}
|
||||
|
|
@ -616,19 +737,12 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
case 'transport:connection:heartbeat': {
|
||||
const p = peers.get(peer.id)
|
||||
if (p) {
|
||||
p.lastHeartbeatAt = Date.now()
|
||||
p.missedHeartbeats = 0
|
||||
markPeerAlive(p, {
|
||||
parentId: event.metadata?.event.id,
|
||||
logMessage: 'heartbeat recovered, marking healthy',
|
||||
})
|
||||
|
||||
// recover from unhealthy → healthy
|
||||
if (p.healthy === false && p.name && p.identity) {
|
||||
p.healthy = true
|
||||
logger.withFields({ peer: peer.id, peerName: p.name }).debug('heartbeat recovered, marking healthy')
|
||||
broadcastToAuthenticated({
|
||||
type: 'registry:modules:health:healthy',
|
||||
data: { name: p.name, index: p.index, identity: p.identity },
|
||||
metadata: createServerEventMetadata(instanceId, event.metadata?.event.id),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if (event.data.kind === MessageHeartbeatKind.Ping) {
|
||||
|
|
@ -664,9 +778,6 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
return
|
||||
}
|
||||
|
||||
unregisterModulePeer(p, 're-announcing')
|
||||
|
||||
// verify
|
||||
const { name, index, identity } = event.data as { name: string, index?: number, identity?: MetadataEventSource }
|
||||
if (!name || typeof name !== 'string') {
|
||||
send(peer, RESPONSES.error(ServerErrorMessages.moduleAnnounceNameInvalid, instanceId))
|
||||
|
|
@ -691,11 +802,14 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
return
|
||||
}
|
||||
|
||||
unregisterModuleRegistration(p, {
|
||||
reason: 're-announcing',
|
||||
unregisterConsumers: false,
|
||||
})
|
||||
|
||||
p.name = name
|
||||
p.index = index
|
||||
if (identity) {
|
||||
p.identity = identity
|
||||
}
|
||||
p.identity = identity
|
||||
|
||||
registerModulePeer(p, name, index)
|
||||
|
||||
|
|
@ -775,7 +889,13 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
return
|
||||
}
|
||||
|
||||
registerConsumer(peer.id, data.event, data.mode ?? (data.group ? 'consumer-group' : 'consumer'), data.group, data.priority)
|
||||
registerConsumer(
|
||||
peer.id,
|
||||
data.event,
|
||||
normalizeConsumerMode(data.mode, data.group),
|
||||
data.group,
|
||||
normalizeConsumerPriority(data.priority),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -797,7 +917,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
return
|
||||
}
|
||||
|
||||
unregisterConsumer(peer.id, data.event, data.mode ?? (data.group ? 'consumer-group' : 'consumer'), data.group)
|
||||
unregisterConsumer(peer.id, data.event, normalizeConsumerMode(data.mode, data.group), data.group)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -816,6 +936,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
const shouldBypass = Boolean(event.route?.bypass && allowBypass && isDevtoolsPeer(p))
|
||||
const destinations = shouldBypass ? undefined : collectDestinations(event)
|
||||
const delivery = shouldBypass ? undefined : resolveDeliveryConfig(event)
|
||||
const effectiveRoutingMiddleware = shouldBypass ? [] : routingMiddleware
|
||||
const routingContext: RouteContext = {
|
||||
event,
|
||||
fromPeer: p,
|
||||
|
|
@ -824,7 +945,7 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
}
|
||||
|
||||
let decision: RouteDecision | undefined
|
||||
for (const middleware of routingMiddleware) {
|
||||
for (const middleware of effectiveRoutingMiddleware) {
|
||||
const result = middleware(routingContext)
|
||||
if (result) {
|
||||
decision = result
|
||||
|
|
@ -886,11 +1007,16 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
continue
|
||||
}
|
||||
|
||||
if (!other.authenticated) {
|
||||
logger.withFields({ fromPeer: peer.id, toPeer: other.peer.id, toPeerName: other.name, event }).debug('not sending event to unauthenticated peer')
|
||||
continue
|
||||
}
|
||||
|
||||
if (!shouldBroadcast && targetIds && !targetIds.has(id)) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (shouldBroadcast && destinations && destinations.length > 0 && !matchesDestinations(destinations, other)) {
|
||||
if (shouldBroadcast && destinations !== undefined && !matchesDestinations(destinations, other)) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -913,6 +1039,10 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
close: (peer, details) => {
|
||||
const p = peers.get(peer.id)
|
||||
const now = Date.now()
|
||||
const peerName = p?.name
|
||||
const peerIndex = p?.index
|
||||
const peerHealthy = p?.healthy
|
||||
const peerMissedHeartbeats = p?.missedHeartbeats
|
||||
const safeDetails = details ?? {}
|
||||
const closeCode = typeof safeDetails.code === 'number' ? safeDetails.code : undefined
|
||||
const closeReason = typeof safeDetails.reason === 'string' ? safeDetails.reason : undefined
|
||||
|
|
@ -928,8 +1058,10 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
)
|
||||
const likelySilentNetworkClose = closeCode === 1005
|
||||
|
||||
if (p)
|
||||
if (p) {
|
||||
peers.delete(peer.id)
|
||||
unregisterModulePeer(p, 'connection closed')
|
||||
}
|
||||
|
||||
logger.withFields({
|
||||
peer: peer.id,
|
||||
|
|
@ -940,10 +1072,10 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
closeWasClean,
|
||||
activePeers: peers.size,
|
||||
peerAuthenticated: p?.authenticated,
|
||||
peerName: p?.name,
|
||||
peerIndex: p?.index,
|
||||
peerHealthy: p?.healthy,
|
||||
peerMissedHeartbeats: p?.missedHeartbeats,
|
||||
peerName,
|
||||
peerIndex,
|
||||
peerHealthy,
|
||||
peerMissedHeartbeats,
|
||||
heartbeatLastSeenAt,
|
||||
heartbeatSilentForMs,
|
||||
heartbeatTtlMs,
|
||||
|
|
@ -951,20 +1083,36 @@ export function setupApp(options?: AppOptions): { app: H3, closeAllPeers: () =>
|
|||
likelyHeartbeatExpiry,
|
||||
likelySilentNetworkClose,
|
||||
}).log('closed')
|
||||
peers.delete(peer.id)
|
||||
},
|
||||
}))
|
||||
|
||||
function closeAllPeers() {
|
||||
logger.withFields({ totalPeers: peers.size }).log('closing all peers')
|
||||
for (const peer of peers.values()) {
|
||||
for (const peer of Array.from(peers.values())) {
|
||||
logger.withFields({ peer: peer.peer.id, peerName: peer.name }).debug('closing peer')
|
||||
peer.peer.close?.()
|
||||
try {
|
||||
peer.peer.close?.()
|
||||
}
|
||||
catch (error) {
|
||||
logger.withFields({ peer: peer.peer.id, peerName: peer.name }).withError(error as Error).debug('failed to close peer during shutdown')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function dispose() {
|
||||
if (disposed) {
|
||||
return
|
||||
}
|
||||
|
||||
disposed = true
|
||||
clearInterval(healthCheckInterval)
|
||||
closeAllPeers()
|
||||
resetRoutingState()
|
||||
}
|
||||
|
||||
return {
|
||||
app,
|
||||
closeAllPeers,
|
||||
dispose,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ function createPeer(options: {
|
|||
plugin?: string
|
||||
instanceId?: string
|
||||
labels?: Record<string, string>
|
||||
authenticated?: boolean
|
||||
}): AuthenticatedPeer {
|
||||
return {
|
||||
peer: {
|
||||
|
|
@ -21,7 +22,7 @@ function createPeer(options: {
|
|||
request: { url: 'http://localhost', headers: new Headers() },
|
||||
remoteAddress: '127.0.0.1',
|
||||
},
|
||||
authenticated: true,
|
||||
authenticated: options.authenticated ?? true,
|
||||
name: options.name,
|
||||
identity: options.plugin && options.instanceId
|
||||
? { kind: 'plugin', plugin: { id: options.plugin }, id: options.instanceId, labels: options.labels }
|
||||
|
|
@ -57,6 +58,7 @@ describe('match-expression', () => {
|
|||
expect(matchesLabelSelector('env=prod', { env: 'dev' })).toBe(false)
|
||||
expect(matchesLabelSelector('feature', { feature: 'on' })).toBe(true)
|
||||
expect(matchesLabelSelector('missing', { env: 'prod' })).toBe(false)
|
||||
expect(matchesLabelSelector(' env = prod ', { env: 'prod' })).toBe(true)
|
||||
})
|
||||
|
||||
it('matches label selector list', () => {
|
||||
|
|
@ -167,6 +169,48 @@ describe('route middleware', () => {
|
|||
expect([...decision!.targetIds]).toEqual(['peer-1'])
|
||||
})
|
||||
|
||||
it('policy middleware excludes unauthenticated peers', () => {
|
||||
const peers = new Map<string, AuthenticatedPeer>([
|
||||
['peer-1', createPeer({ id: 'peer-1', name: 'telegram', plugin: 'telegram-bot', instanceId: 'telegram-1', labels: { env: 'prod' } })],
|
||||
['peer-2', createPeer({ id: 'peer-2', name: 'stage-ui', plugin: 'stage-ui', instanceId: 'stage-ui-1', labels: { env: 'prod' }, authenticated: false })],
|
||||
])
|
||||
|
||||
const policy = createPolicyMiddleware({ allowLabels: ['env=prod'] })
|
||||
const decision = policy({
|
||||
event: createSparkNotifyEvent(),
|
||||
fromPeer: peers.get('peer-1')!,
|
||||
peers,
|
||||
destinations: undefined,
|
||||
})
|
||||
|
||||
expect(decision).toBeDefined()
|
||||
if (!decision || decision.type !== 'targets')
|
||||
return
|
||||
|
||||
expect([...decision.targetIds]).toEqual(['peer-1'])
|
||||
})
|
||||
|
||||
it('policy middleware does not authorize bypass by itself', () => {
|
||||
const peers = new Map<string, AuthenticatedPeer>([
|
||||
['peer-1', createPeer({ id: 'peer-1', name: 'telegram', plugin: 'telegram-bot', instanceId: 'telegram-1', labels: { env: 'prod' } })],
|
||||
['peer-2', createPeer({ id: 'peer-2', name: 'stage-ui', plugin: 'stage-ui', instanceId: 'stage-ui-1', labels: { env: 'dev' } })],
|
||||
])
|
||||
|
||||
const policy = createPolicyMiddleware({ allowLabels: ['env=prod'] })
|
||||
const decision = policy({
|
||||
event: createSparkNotifyEvent({ route: { bypass: true } }),
|
||||
fromPeer: peers.get('peer-1')!,
|
||||
peers,
|
||||
destinations: undefined,
|
||||
})
|
||||
|
||||
expect(decision).toBeDefined()
|
||||
if (!decision || decision.type !== 'targets')
|
||||
return
|
||||
|
||||
expect([...decision.targetIds]).toEqual(['peer-1'])
|
||||
})
|
||||
|
||||
it('devtools peer detection uses label', () => {
|
||||
const peer = createPeer({
|
||||
id: 'peer-3',
|
||||
|
|
|
|||
|
|
@ -32,13 +32,43 @@ function getPeerLabels(peer: AuthenticatedPeer) {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects whether a peer should be treated as a trusted devtools sender.
|
||||
*
|
||||
* Use when:
|
||||
* - Checking whether route bypass is allowed for a peer
|
||||
* - Applying devtools-only routing affordances
|
||||
*
|
||||
* Expects:
|
||||
* - Peer labels to be sourced from authenticated identity metadata
|
||||
*
|
||||
* Returns:
|
||||
* - `true` when the peer declares a devtools label or uses a devtools module name
|
||||
*/
|
||||
export function isDevtoolsPeer(peer: AuthenticatedPeer) {
|
||||
const devtoolsLabel = getPeerLabels(peer).devtools
|
||||
const isDevtoolsLabel = devtoolsLabel === 'true' || devtoolsLabel === '1'
|
||||
return Boolean(isDevtoolsLabel || peer.name.includes('devtools'))
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates whether a peer is allowed by the active routing policy.
|
||||
*
|
||||
* Use when:
|
||||
* - Building a target list from the connected peer registry
|
||||
* - Enforcing allow/deny lists before broadcasting an event
|
||||
*
|
||||
* Expects:
|
||||
* - Unauthenticated peers must never be considered routable targets
|
||||
*
|
||||
* Returns:
|
||||
* - `true` when the peer is authenticated and satisfies all policy constraints
|
||||
*/
|
||||
export function peerMatchesPolicy(peer: AuthenticatedPeer, policy: RoutingPolicy) {
|
||||
if (!peer.authenticated) {
|
||||
return false
|
||||
}
|
||||
|
||||
const pluginId = peer.identity?.plugin?.id ?? ''
|
||||
|
||||
if (policy.allowPlugins?.length && !policy.allowPlugins.includes(pluginId)) {
|
||||
|
|
@ -61,12 +91,21 @@ export function peerMatchesPolicy(peer: AuthenticatedPeer, policy: RoutingPolicy
|
|||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a routing middleware from a static allow/deny policy.
|
||||
*
|
||||
* Use when:
|
||||
* - Server-wide routing rules should be applied consistently
|
||||
* - Destination filtering should be derived from peer metadata instead of event payloads
|
||||
*
|
||||
* Expects:
|
||||
* - Route bypass authorization to be handled by the caller, not by the policy itself
|
||||
*
|
||||
* Returns:
|
||||
* - A middleware that narrows delivery to the peers allowed by the policy
|
||||
*/
|
||||
export function createPolicyMiddleware(policy: RoutingPolicy): RouteMiddleware {
|
||||
return ({ event, peers }) => {
|
||||
if (event.route?.bypass) {
|
||||
return
|
||||
}
|
||||
|
||||
return ({ peers }) => {
|
||||
const targetIds = new Set<string>()
|
||||
for (const [id, peer] of peers.entries()) {
|
||||
if (peerMatchesPolicy(peer, policy)) {
|
||||
|
|
@ -78,6 +117,19 @@ export function createPolicyMiddleware(policy: RoutingPolicy): RouteMiddleware {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves the destinations attached to an event.
|
||||
*
|
||||
* Use when:
|
||||
* - Route-level destinations should override payload-level destinations
|
||||
* - Delivery logic needs to distinguish between "broadcast" and "explicitly send nowhere"
|
||||
*
|
||||
* Expects:
|
||||
* - An explicit empty `route.destinations` array is a meaningful override
|
||||
*
|
||||
* Returns:
|
||||
* - The route destinations, payload destinations, or `undefined` when the event is unrestricted
|
||||
*/
|
||||
export function collectDestinations(event: WebSocketEvent | (Omit<WebSocketEvent, 'metadata'> & Partial<Pick<WebSocketEvent, 'metadata'>>)) {
|
||||
if (event.route && 'destinations' in event.route) {
|
||||
return event.route.destinations
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@ function matchesGlob(glob: string, value?: string) {
|
|||
}
|
||||
|
||||
export function matchesLabelSelector(selector: string, labels: Record<string, string>) {
|
||||
const [key, value] = selector.split('=', 2)
|
||||
const [rawKey, rawValue] = selector.split('=', 2)
|
||||
const key = rawKey?.trim()
|
||||
const value = rawValue?.trim()
|
||||
|
||||
if (!key) {
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import { Format, LogLevelString } from '@guiiai/logg'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const serveMocks = vi.hoisted(() => {
|
||||
|
|
@ -10,15 +11,18 @@ const serveMocks = vi.hoisted(() => {
|
|||
}))
|
||||
|
||||
const closeCall = vi.fn(async () => {})
|
||||
const disposeCall = vi.fn(() => {})
|
||||
const setupAppCall = vi.fn(() => ({
|
||||
app: {
|
||||
fetch: vi.fn(async () => ({ crossws: {} })),
|
||||
},
|
||||
closeAllPeers: vi.fn(),
|
||||
dispose: disposeCall,
|
||||
}))
|
||||
|
||||
return {
|
||||
closeCall,
|
||||
disposeCall,
|
||||
rejectServe: (error: Error) => rejectServe?.(error),
|
||||
resolveServe: () => resolveServe?.(),
|
||||
serveCall,
|
||||
|
|
@ -77,6 +81,7 @@ describe('createServer', async () => {
|
|||
serveMocks.rejectServe(new Error('bind failed'))
|
||||
|
||||
await expect(firstStart).rejects.toThrow('bind failed')
|
||||
expect(serveMocks.disposeCall).toHaveBeenCalledTimes(1)
|
||||
|
||||
const retryStart = server.start()
|
||||
expect(serveMocks.serveCall).toHaveBeenCalledTimes(2)
|
||||
|
|
@ -84,4 +89,37 @@ describe('createServer', async () => {
|
|||
serveMocks.resolveServe()
|
||||
await retryStart
|
||||
})
|
||||
|
||||
it('merges nested config updates instead of replacing sibling logger settings', async () => {
|
||||
const server = createServer({
|
||||
hostname: '127.0.0.1',
|
||||
port: 6121,
|
||||
logger: {
|
||||
app: { level: LogLevelString.Log },
|
||||
websocket: { format: Format.Pretty },
|
||||
},
|
||||
})
|
||||
|
||||
server.updateConfig({
|
||||
logger: {
|
||||
app: { format: Format.Pretty },
|
||||
},
|
||||
})
|
||||
|
||||
const startTask = server.start()
|
||||
serveMocks.resolveServe()
|
||||
await startTask
|
||||
|
||||
expect(serveMocks.setupAppCall).toHaveBeenCalledWith(expect.objectContaining({
|
||||
logger: {
|
||||
app: {
|
||||
level: LogLevelString.Log,
|
||||
format: Format.Pretty,
|
||||
},
|
||||
websocket: {
|
||||
format: Format.Pretty,
|
||||
},
|
||||
},
|
||||
}))
|
||||
})
|
||||
})
|
||||
|
|
|
|||
|
|
@ -32,9 +32,22 @@ export interface Server {
|
|||
updateConfig: (newOptions: ServerOptions) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Collects local IP addresses that can be used to reach the server from the LAN.
|
||||
*
|
||||
* Use when:
|
||||
* - Building connection hints for `0.0.0.0` listeners
|
||||
* - Showing reachable addresses in logs or UI
|
||||
*
|
||||
* Expects:
|
||||
* - Virtual interfaces should be ignored to reduce noisy or misleading addresses
|
||||
*
|
||||
* Returns:
|
||||
* - A de-duplicated list of valid IP addresses discovered from the host network interfaces
|
||||
*/
|
||||
export function getLocalIPs(): string[] {
|
||||
const interfaces = networkInterfaces()
|
||||
const addresses: string[] = []
|
||||
const addresses = new Set<string>()
|
||||
|
||||
const VIRTUAL_INTERFACE_PREFIXES = [
|
||||
'vboxnet',
|
||||
|
|
@ -63,13 +76,26 @@ export function getLocalIPs(): string[] {
|
|||
|
||||
const address = rawAddress.includes('%') ? rawAddress.split('%')[0] : rawAddress
|
||||
if (isIP(address))
|
||||
addresses.push(address)
|
||||
addresses.add(address)
|
||||
}
|
||||
}
|
||||
|
||||
return addresses
|
||||
return [...addresses]
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the websocket server controller for the AIRI runtime.
|
||||
*
|
||||
* Use when:
|
||||
* - Starting, stopping, or restarting the standalone runtime server
|
||||
* - Updating bind options between restarts
|
||||
*
|
||||
* Expects:
|
||||
* - The returned controller to manage a single active server instance at a time
|
||||
*
|
||||
* Returns:
|
||||
* - Lifecycle helpers for starting, stopping, restarting, and updating server options
|
||||
*/
|
||||
export function createServer(opts?: ServerOptions): Server {
|
||||
let options = merge<ServerOptions>({ port: 6121, hostname: '127.0.0.1' }, opts)
|
||||
|
||||
|
|
@ -140,8 +166,7 @@ export function createServer(opts?: ServerOptions): Server {
|
|||
try {
|
||||
serverInstance = {
|
||||
close: async (closeActiveConnections = false) => {
|
||||
log.log('closing all peers')
|
||||
h3App.closeAllPeers()
|
||||
h3App.dispose()
|
||||
log.log('closing server instance')
|
||||
await instance.close(closeActiveConnections)
|
||||
log.log('server instance closed')
|
||||
|
|
@ -162,7 +187,7 @@ export function createServer(opts?: ServerOptions): Server {
|
|||
}
|
||||
catch (error) {
|
||||
serverInstance = null
|
||||
h3App.closeAllPeers()
|
||||
h3App.dispose()
|
||||
await instance.close(true).catch(() => {})
|
||||
log.withError(error).error('failed to start WebSocket server')
|
||||
throw error
|
||||
|
|
@ -183,12 +208,16 @@ export function createServer(opts?: ServerOptions): Server {
|
|||
await start()
|
||||
}
|
||||
|
||||
async function updateConfig(newOptions: ServerOptions) {
|
||||
options = { ...options, ...newOptions }
|
||||
function updateConfig(newOptions: ServerOptions) {
|
||||
options = merge<ServerOptions>(options, newOptions)
|
||||
}
|
||||
|
||||
return {
|
||||
getConnectionHost: () => {
|
||||
if (options.hostname && options.hostname !== '0.0.0.0' && options.hostname !== '::') {
|
||||
return [options.hostname]
|
||||
}
|
||||
|
||||
return getLocalIPs()
|
||||
},
|
||||
start,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<script setup lang="ts">
|
||||
import type { SpeechProvider } from '@xsai-ext/providers/utils'
|
||||
|
||||
import { getCachedWebGPUCapabilities } from '@proj-airi/stage-shared/webgpu'
|
||||
import { getCachedWebGPUCapabilities, hasNavigatorWebGPU } from '@proj-airi/stage-shared/webgpu'
|
||||
import {
|
||||
SpeechPlayground,
|
||||
SpeechProviderSettings,
|
||||
|
|
@ -104,7 +104,7 @@ onMounted(async () => {
|
|||
// NOTICE: Uses synchronous check for initial render. The cached result from
|
||||
// detectWebGPU() is populated by the providers store during initialization.
|
||||
const capabilities = getCachedWebGPUCapabilities()
|
||||
hasWebGPU.value = capabilities?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
|
||||
hasWebGPU.value = capabilities?.supported ?? hasNavigatorWebGPU()
|
||||
fp16Supported.value = capabilities?.fp16Supported ?? false
|
||||
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,14 @@
|
|||
|
||||
import { check as gpuuCheck, isWebGPUSupported as gpuuIsSupported } from 'gpuu/webgpu'
|
||||
|
||||
interface NavigatorWebGPU {
|
||||
requestAdapter: () => Promise<unknown>
|
||||
}
|
||||
|
||||
interface NavigatorWithOptionalWebGPU extends Navigator {
|
||||
gpu?: NavigatorWebGPU
|
||||
}
|
||||
|
||||
export interface WebGPUCapabilities {
|
||||
/** Whether WebGPU is available in this environment */
|
||||
supported: boolean
|
||||
|
|
@ -21,6 +29,50 @@ export interface WebGPUCapabilities {
|
|||
let cachedResult: WebGPUCapabilities | null = null
|
||||
let pendingDetection: Promise<WebGPUCapabilities> | null = null
|
||||
|
||||
/**
|
||||
* Returns the WebGPU navigator entry point when the current runtime exposes it.
|
||||
*
|
||||
* Use when:
|
||||
* - browser or worker code needs guarded access to WebGPU
|
||||
* - a caller only needs the `navigator.gpu` entry point, not a full capability probe
|
||||
*
|
||||
* Expects:
|
||||
* - browser-like runtimes where `navigator` may be unavailable during SSR or tests
|
||||
*
|
||||
* Returns:
|
||||
* - the WebGPU navigator entry point when available, otherwise `null`
|
||||
*/
|
||||
export function getNavigatorWebGPU(): NavigatorWebGPU | null {
|
||||
// NOTICE:
|
||||
// TypeScript's default DOM libs in this repo do not declare `Navigator.gpu`.
|
||||
// Direct `navigator.gpu` access fails in consumers like `@proj-airi/ui-server-auth`
|
||||
// and in `@proj-airi/stage-ui` worker compilation even though the runtime guard is valid.
|
||||
// We centralize the structural cast here until the repo opts into WebGPU ambient types.
|
||||
// Removal condition: remove this helper once the workspace TypeScript config includes
|
||||
// WebGPU navigator typings everywhere that imports these modules.
|
||||
if (typeof navigator === 'undefined' || !('gpu' in navigator))
|
||||
return null
|
||||
|
||||
return (navigator as NavigatorWithOptionalWebGPU).gpu ?? null
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns whether the current runtime exposes a WebGPU entry point on `navigator`.
|
||||
*
|
||||
* Use when:
|
||||
* - synchronous code only needs a fast boolean feature probe
|
||||
* - cached capability detection has not completed yet
|
||||
*
|
||||
* Expects:
|
||||
* - browser-like runtimes where `navigator` may be unavailable
|
||||
*
|
||||
* Returns:
|
||||
* - `true` when `navigator.gpu` is available, otherwise `false`
|
||||
*/
|
||||
export function hasNavigatorWebGPU(): boolean {
|
||||
return getNavigatorWebGPU() != null
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect WebGPU capabilities. The result is cached as a singleton
|
||||
* after the first successful call -- safe to call repeatedly.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
export {
|
||||
detectWebGPU,
|
||||
getCachedWebGPUCapabilities,
|
||||
getNavigatorWebGPU,
|
||||
hasNavigatorWebGPU,
|
||||
isWebGPUSupported,
|
||||
resetWebGPUCache,
|
||||
} from './detect'
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import {
|
|||
TextStreamer,
|
||||
WhisperForConditionalGeneration,
|
||||
} from '@huggingface/transformers'
|
||||
import { getNavigatorWebGPU } from '@proj-airi/stage-shared/webgpu'
|
||||
|
||||
import { MODEL_IDS, MODEL_NAMES } from '../inference/constants'
|
||||
import { classifyError, isRecoverable } from '../inference/protocol'
|
||||
|
|
@ -75,9 +76,10 @@ const MODEL_ID = MODEL_IDS.WHISPER
|
|||
*/
|
||||
async function detectWebGPUInWorker(): Promise<boolean> {
|
||||
try {
|
||||
if (typeof navigator === 'undefined' || !navigator.gpu)
|
||||
const webgpu = getNavigatorWebGPU()
|
||||
if (!webgpu)
|
||||
return false
|
||||
const adapter = await navigator.gpu.requestAdapter()
|
||||
const adapter = await webgpu.requestAdapter()
|
||||
return adapter != null
|
||||
}
|
||||
catch {
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import type {
|
|||
import type { AliyunRealtimeSpeechExtraOptions } from './providers/aliyun/stream-transcription'
|
||||
|
||||
import { isStageTamagotchi, isUrl } from '@proj-airi/stage-shared'
|
||||
import { getCachedWebGPUCapabilities, isWebGPUSupported } from '@proj-airi/stage-shared/webgpu'
|
||||
import { getCachedWebGPUCapabilities, hasNavigatorWebGPU, isWebGPUSupported } from '@proj-airi/stage-shared/webgpu'
|
||||
import { computedAsync, useIntervalFn, useLocalStorage } from '@vueuse/core'
|
||||
import {
|
||||
createOpenAI,
|
||||
|
|
@ -1715,7 +1715,7 @@ export const useProvidersStore = defineStore('providers', () => {
|
|||
|
||||
defaultOptions: () => {
|
||||
const capabilities = getCachedWebGPUCapabilities()
|
||||
const hasWebGPU = capabilities?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
|
||||
const hasWebGPU = capabilities?.supported ?? hasNavigatorWebGPU()
|
||||
const fp16Supported = capabilities?.fp16Supported ?? false
|
||||
const model = getDefaultKokoroModel(hasWebGPU, fp16Supported)
|
||||
return {
|
||||
|
|
@ -1772,7 +1772,7 @@ export const useProvidersStore = defineStore('providers', () => {
|
|||
capabilities: {
|
||||
listModels: async (_config: Record<string, unknown>) => {
|
||||
const caps = getCachedWebGPUCapabilities()
|
||||
const hasWebGPU = caps?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
|
||||
const hasWebGPU = caps?.supported ?? hasNavigatorWebGPU()
|
||||
const fp16Supported = caps?.fp16Supported ?? false
|
||||
return kokoroModelsToModelInfo(hasWebGPU, t, fp16Supported)
|
||||
},
|
||||
|
|
@ -1791,7 +1791,7 @@ export const useProvidersStore = defineStore('providers', () => {
|
|||
|
||||
// Validate platform requirements
|
||||
if (modelDef.platform === 'webgpu') {
|
||||
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
|
||||
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? hasNavigatorWebGPU()
|
||||
if (!hasWebGPU) {
|
||||
throw new Error('WebGPU is required for this model but is not available in your browser')
|
||||
}
|
||||
|
|
@ -1835,7 +1835,7 @@ export const useProvidersStore = defineStore('providers', () => {
|
|||
const modelDef = KOKORO_MODELS.find(m => m.id === modelId)
|
||||
if (modelDef) {
|
||||
if (modelDef.platform === 'webgpu') {
|
||||
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? (typeof navigator !== 'undefined' && !!navigator.gpu)
|
||||
const hasWebGPU = getCachedWebGPUCapabilities()?.supported ?? hasNavigatorWebGPU()
|
||||
if (!hasWebGPU) {
|
||||
throw new Error('WebGPU is required for this model but is not available in your browser')
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import type {
|
|||
} from '../../libs/inference/protocol'
|
||||
|
||||
import { AutoModel, AutoProcessor, env, RawImage } from '@huggingface/transformers'
|
||||
import { getNavigatorWebGPU } from '@proj-airi/stage-shared/webgpu'
|
||||
|
||||
import { MODEL_IDS, MODEL_NAMES } from '../../libs/inference/constants'
|
||||
import { classifyError, isRecoverable } from '../../libs/inference/protocol'
|
||||
|
|
@ -80,9 +81,10 @@ function sendError(requestId: string, error: unknown, phase?: 'load' | 'inferenc
|
|||
*/
|
||||
async function detectWebGPUInWorker(): Promise<boolean> {
|
||||
try {
|
||||
if (typeof navigator === 'undefined' || !navigator.gpu)
|
||||
const webgpu = getNavigatorWebGPU()
|
||||
if (!webgpu)
|
||||
return false
|
||||
const adapter = await navigator.gpu.requestAdapter()
|
||||
const adapter = await webgpu.requestAdapter()
|
||||
return adapter != null
|
||||
}
|
||||
catch {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue