fix: expand option whitelist, plug memory leaks, track broadcast tasks, and unify dispatch paths

- Fix Memory Leaks: Resolved SID retention in _known_sids after disconnection and cleaned up unreferenced broadcast tasks in _schedule_lifecycle_broadcast.
- Unify Dispatching Paths: Unified client and server event dispatching through the process_client_event() method to ensure diagnostic consistency.
- Optimization & Cleanup: Expanded the _OPTION_KEYS whitelist, removed dead code (iter_event_types), and deleted unused websocket exports.
- Robustness: Added handling for None responses in process_client_event to prevent cluttering responses with empty results.
- Testing: Added test cases to verify SID TTL expiration and stale SID cleanup on disconnect.
This commit is contained in:
keyboardstaff 2026-03-27 01:21:45 -07:00
parent 1160195fb5
commit b351de456e
4 changed files with 274 additions and 33 deletions

View file

@ -300,37 +300,56 @@ class WsHandler:
for sid, handlers in _active_handlers.items()
}
mgr = self._manager
aggregated: list[dict[str, Any]] = []
for sid, handlers in snapshot.items():
ctx = _ws_contexts.get(sid)
sid_results: list[dict[str, Any]] = []
security_errors: list[dict[str, Any]] = []
passing: list[WsHandler] = []
for _path, instance in handlers.items():
if ctx is not None:
error = _check_security(type(instance), ctx)
if error is not None:
sid_results.append({
security_errors.append({
"handlerId": instance.identifier,
"ok": False,
"correlationId": cid,
"error": error,
})
continue
try:
result = await instance.process(event, dict(data, correlationId=cid), sid)
if result is not None:
passing.append(instance)
if mgr is not None and passing:
result = await mgr.process_client_event(
self._namespace, event,
dict(data, correlationId=cid), sid,
handlers=passing,
)
sid_results = security_errors + result.get("results", [])
else:
# Fallback: inline processing
sid_results = list(security_errors)
for _path, instance in handlers.items():
if instance not in passing:
continue
try:
result = await instance.process(
event, dict(data, correlationId=cid), sid,
)
if result is not None:
sid_results.append({
"handlerId": instance.identifier,
"ok": True,
"correlationId": cid,
"data": result,
})
except Exception as e:
sid_results.append({
"handlerId": instance.identifier,
"ok": True,
"ok": False,
"correlationId": cid,
"data": result,
"error": {"code": "HANDLER_ERROR", "error": str(e)},
})
except Exception as e:
sid_results.append({
"handlerId": instance.identifier,
"ok": False,
"correlationId": cid,
"error": {"code": "HANDLER_ERROR", "error": str(e)},
})
aggregated.append({
"sid": sid,
"correlationId": cid,
@ -530,24 +549,47 @@ def register_ws_namespace(
return _error_response("NO_HANDLERS",
"No handlers activated", correlation_id)
# Unwrap nested payload (mirrors WsManager.route_event):
# frontend sends {ts, data: {actual fields...}, correlationId}
# Pre-filter handlers through security checks
passing_handlers: list[WsHandler] = []
security_errors: list[dict[str, Any]] = []
for path, instance in activated.items():
error = _check_security(type(instance), ctx)
if error is not None:
security_errors.append({
"handlerId": instance.identifier,
"ok": False,
"correlationId": correlation_id,
"error": error,
})
else:
passing_handlers.append(instance)
# Delegate to WsManager for unified processing pipeline
# (worker thread isolation, diagnostic events, WsResult support)
if manager is not None and passing_handlers:
result = await manager.process_client_event(
NAMESPACE, event, incoming, sid,
handlers=passing_handlers,
)
if security_errors:
result["results"] = security_errors + result.get("results", [])
return result
# All handlers failed security or no manager — return collected errors
if not passing_handlers:
return {"correlationId": correlation_id, "results": security_errors}
# Fallback: inline processing (no manager — should not happen in practice)
handler_payload: dict[str, Any]
if "data" in incoming and isinstance(incoming.get("data"), dict):
handler_payload = dict(incoming["data"])
else:
handler_payload = dict(incoming)
handler_payload["correlationId"] = correlation_id
results: list[dict[str, Any]] = []
results: list[dict[str, Any]] = list(security_errors)
for path, instance in activated.items():
error = _check_security(type(instance), ctx)
if error is not None:
results.append({
"handlerId": instance.identifier,
"ok": False,
"correlationId": correlation_id,
"error": error,
})
if instance not in passing_handlers:
continue
try:
result = await instance.process(event, handler_payload, sid)

View file

@ -247,6 +247,7 @@ class WsManager:
defaultdict(deque)
)
self._known_sids: Set[ConnectionIdentity] = set()
self._disconnect_times: Dict[ConnectionIdentity, datetime] = {}
self._identifier: str = f"{self.__class__.__module__}.{self.__class__.__name__}"
# Session tracking (single-user default)
self.user_to_sids: defaultdict[str, Set[ConnectionIdentity]] = defaultdict(set)
@ -257,6 +258,7 @@ class WsManager:
self._diagnostics_enabled: bool = runtime.is_development()
self._dispatcher_loop: asyncio.AbstractEventLoop | None = None
self._handler_worker: DeferredTask | None = None
self._lifecycle_tasks: Set[asyncio.Task] = set()
# Internal: development-only debug logging to avoid noise in production
def _debug(self, message: str) -> None:
@ -414,7 +416,23 @@ class WsManager:
except Exception as exc: # pragma: no cover - diagnostic
self._debug(f"Failed to broadcast lifecycle event {event_type}: {exc}")
asyncio.create_task(_broadcast())
task = asyncio.create_task(_broadcast())
self._lifecycle_tasks.add(task)
task.add_done_callback(self._lifecycle_tasks.discard)
def _sweep_stale_sids(self) -> None:
    """Evict disconnected identities whose TTL grace period has elapsed.

    Scans ``_disconnect_times`` under the manager lock; every identity
    that is no longer connected and whose disconnect happened more than
    ``BUFFER_TTL`` ago loses its ``_known_sids`` membership, its recorded
    disconnect time, and any pending message buffer.
    """
    current_time = _utcnow()
    with self.lock:
        expired = []
        for identity, disconnected_at in self._disconnect_times.items():
            if identity in self.connections:
                # Reconnected in the meantime — keep its bookkeeping.
                continue
            if (current_time - disconnected_at) > BUFFER_TTL:
                expired.append(identity)
        # Second pass: never mutate the dict while iterating it.
        for identity in expired:
            self._known_sids.discard(identity)
            self._disconnect_times.pop(identity, None)
            self.buffers.pop(identity, None)
def _normalize_handler_filter(self, value: Any, field_name: str) -> Set[str] | None:
if value is None:
@ -513,12 +531,133 @@ class WsManager:
f"Registered handler {handler.identifier} namespace={namespace}"
)
def iter_event_types(self, namespace: str) -> Iterable[str]:
    """Return the event types known for *namespace*.

    The manager keeps no per-event-type registry, so this always yields
    an empty collection regardless of the namespace given.
    """
    empty: list[str] = []
    return empty
def iter_namespaces(self) -> list[str]:
    """Return every namespace that currently has registered handlers."""
    return [namespace for namespace in self.handlers]
async def process_client_event(
    self,
    namespace: str,
    event_type: str,
    data: dict[str, Any],
    sid: str,
    *,
    handlers: list[WsHandler],
) -> dict[str, Any]:
    """Process a client-originated event through provided handler instances.

    Unlike ``route_event`` which selects from globally registered handlers,
    this accepts pre-selected instances (e.g. per-connection activated
    handlers that have already passed security checks).

    Args:
        namespace: Namespace the event arrived on.
        event_type: Name of the client event.
        data: Raw incoming payload; may be ``None``. The actual fields may
            be nested under a ``"data"`` key (frontend envelope).
        sid: Connection sid that emitted the event.
        handlers: Pre-filtered handler instances to invoke concurrently.

    Returns:
        ``{"correlationId": ..., "results": [...]}`` with one entry per
        handler outcome. Handlers that return ``None`` contribute no entry
        (fire-and-forget semantics).
    """
    self._ensure_dispatcher_loop()
    incoming = dict(data or {})
    correlation_id = self._resolve_correlation_id(incoming)
    # Unwrap the frontend envelope {ts, data: {...}, correlationId}; carry
    # excludeSids from the outer envelope through to the handler payload.
    if "data" in incoming and isinstance(incoming.get("data"), dict):
        handler_payload = dict(incoming["data"])
        if "excludeSids" in incoming:
            handler_payload["excludeSids"] = incoming["excludeSids"]
    else:
        handler_payload = dict(incoming)
    handler_payload["correlationId"] = correlation_id
    if not handlers:
        error = self._build_error_result(
            handler_id=self._identifier,
            code="NO_HANDLERS",
            message="No handlers available after security filtering",
            correlation_id=correlation_id,
        )
        return {"correlationId": correlation_id, "results": [error]}
    # Refresh activity under the lock so idle tracking sees this traffic.
    with self.lock:
        info = self.connections.get((namespace, sid))
        if info:
            info.last_activity = _utcnow()
    # Each handler receives its own shallow copy of the payload so one
    # handler's mutations cannot leak into another's input.
    executions = await asyncio.gather(
        *[
            self._invoke_handler(handler, event_type, dict(handler_payload), sid)
            for handler in handlers
        ]
    )
    results: List[dict[str, Any]] = []
    for execution in executions:
        handler = execution.handler
        value = execution.value
        duration_ms = execution.duration_ms
        if isinstance(value, Exception):
            # _invoke_handler surfaces handler exceptions as values; map
            # them to a sanitized error result (details kept server-side).
            PrintStyle.error(
                f"Error in handler {handler.identifier} for '{event_type}' "
                f"(correlation {correlation_id}): {value}"
            )
            results.append(
                self._build_error_result(
                    handler_id=handler.identifier,
                    code="HANDLER_ERROR",
                    message="Internal server error",
                    details=str(value),
                    correlation_id=correlation_id,
                    duration_ms=duration_ms,
                )
            )
            continue
        if isinstance(value, WsResult):
            results.append(
                value.as_result(
                    handler_id=handler.identifier,
                    fallback_correlation_id=correlation_id,
                    duration_ms=duration_ms,
                )
            )
            continue
        # Skip handlers that return None — they opted out of contributing
        # a result (fire-and-forget semantics, matching legacy _dispatch).
        if value is None:
            continue
        # Coerce plain returns into the WsResult envelope: dicts become the
        # data payload directly, scalars are wrapped under "result".
        if isinstance(value, dict):
            helper_result = WsResult(ok=True, data=value)
        else:
            helper_result = WsResult(ok=True, data={"result": value})
        results.append(
            helper_result.as_result(
                handler_id=handler.identifier,
                fallback_correlation_id=correlation_id,
                duration_ms=duration_ms,
            )
        )
    # Lazily-built diagnostic payload; only materialized when diagnostics
    # are consumed downstream.
    await self._publish_diagnostic_event(
        lambda: {
            "kind": "inbound",
            "sourceNamespace": namespace,
            "namespace": namespace,
            "eventType": event_type,
            "sid": sid,
            "correlationId": correlation_id,
            "timestamp": self._timestamp(),
            "handlerCount": len(handlers),
            # "execution" (not "exec") — avoid shadowing the builtin.
            "durationMs": sum(
                (execution.duration_ms or 0.0) for execution in executions
            ),
            "resultSummary": self._summarize_results(results),
            "payloadSummary": self._summarize_payload(handler_payload),
        }
    )
    response = {"correlationId": correlation_id, "results": results}
    self._debug(
        f"Completed client event namespace={namespace} '{event_type}' "
        f"sid={sid} correlation={correlation_id}"
    )
    return response
async def _invoke_handler(
self,
handler: WsHandler,
@ -551,6 +690,7 @@ class WsManager:
with self.lock:
self.connections[identity] = ConnectionInfo(namespace=namespace, sid=sid)
self._known_sids.add(identity)
self._disconnect_times.pop(identity, None)
self.sid_to_user[identity] = user_bucket
self.user_to_sids[self._ALL_USERS_BUCKET].add(identity)
self.user_to_sids[user_bucket].add(identity)
@ -600,7 +740,9 @@ class WsManager:
identity: ConnectionIdentity = (namespace, sid)
with self.lock:
self.connections.pop(identity, None)
# Keep identity in _known_sids so emit_to buffers instead of raising
# Keep identity in _known_sids so emit_to buffers instead of raising;
# record disconnect time for TTL-based cleanup
self._disconnect_times[identity] = _utcnow()
# session tracking cleanup
user_bucket = self.sid_to_user.pop(identity, None)
if self._ALL_USERS_BUCKET in self.user_to_sids:
@ -633,6 +775,7 @@ class WsManager:
self._schedule_lifecycle_broadcast(
namespace, LIFECYCLE_DISCONNECT_EVENT, lifecycle_payload
)
self._sweep_stale_sids()
async def route_event(
self,
@ -1112,6 +1255,14 @@ class WsManager:
with self.lock:
connected = identity in self.connections
known = identity in self._known_sids or identity in self.buffers
# Evict if disconnect has exceeded BUFFER_TTL
if not connected and known:
dt = self._disconnect_times.get(identity)
if dt is not None and (_utcnow() - dt) > BUFFER_TTL:
self._known_sids.discard(identity)
self._disconnect_times.pop(identity, None)
self.buffers.pop(identity, None)
known = False
if connected:
self._debug(

View file

@ -396,6 +396,56 @@ async def test_flush_buffer_delivers_and_logs(monkeypatch):
assert (NAMESPACE, "sid-1") not in manager.buffers
@pytest.mark.asyncio
async def test_known_sid_expires_after_buffer_ttl(monkeypatch):
    """After BUFFER_TTL, a disconnected sid is swept from _known_sids and emit_to raises."""
    from datetime import datetime, timedelta, timezone

    manager = WsManager(FakeSocketIOServer(), threading.RLock())
    identity = (NAMESPACE, "sid-stale")

    await manager.handle_connect(NAMESPACE, "sid-stale")
    await manager.handle_disconnect(NAMESPACE, "sid-stale")

    # Right after disconnect the sid is still known, so emits are buffered.
    await manager.emit_to(NAMESPACE, "sid-stale", "event", {"x": 1})
    assert identity in manager.buffers

    # Jump the clock past the TTL window.
    expired_clock = datetime.now(timezone.utc) + BUFFER_TTL + timedelta(seconds=10)
    monkeypatch.setattr("helpers.ws_manager._utcnow", lambda: expired_clock)

    # The sid is no longer known, so emit_to must raise...
    with pytest.raises(ConnectionNotFoundError):
        await manager.emit_to(NAMESPACE, "sid-stale", "event", {"x": 2})

    # ...and every bookkeeping entry for it must be gone.
    assert identity not in manager._known_sids
    assert identity not in manager.buffers
    assert identity not in manager._disconnect_times
@pytest.mark.asyncio
async def test_sweep_cleans_stale_sids_on_disconnect(monkeypatch):
    """_sweep_stale_sids runs during handle_disconnect and cleans expired entries."""
    from datetime import datetime, timedelta, timezone

    manager = WsManager(FakeSocketIOServer(), threading.RLock())

    # Leave one sid disconnected so it accumulates a disconnect timestamp.
    await manager.handle_connect(NAMESPACE, "old-sid")
    await manager.handle_disconnect(NAMESPACE, "old-sid")

    # Advance the clock beyond the TTL window.
    sweep_time = datetime.now(timezone.utc) + BUFFER_TTL + timedelta(seconds=10)
    monkeypatch.setattr("helpers.ws_manager._utcnow", lambda: sweep_time)

    # Any later connect/disconnect cycle triggers the sweep for old-sid.
    await manager.handle_connect(NAMESPACE, "new-sid")
    await manager.handle_disconnect(NAMESPACE, "new-sid")

    assert (NAMESPACE, "old-sid") not in manager._known_sids
    assert (NAMESPACE, "old-sid") not in manager._disconnect_times
@pytest.mark.asyncio
async def test_broadcast_excludes_multiple_sids():
socketio = FakeSocketIOServer()

View file

@ -5,7 +5,7 @@ const MAX_PAYLOAD_BYTES = 50 * 1024 * 1024; // 50MB hard cap per contract
const DEFAULT_TIMEOUT_MS = 0;
const _UUID_HEX = [..."0123456789abcdef"];
const _OPTION_KEYS = new Set(["correlationId"]);
const _OPTION_KEYS = new Set(["correlationId", "includeHandlers", "excludeHandlers", "excludeSids"]);
/**
* @param {unknown} value
@ -744,5 +744,3 @@ export function getNamespacedClient(namespace) {
_namespacedClients.set(key, client);
return client;
}
export const websocket = getNamespacedClient("/");