mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-04-28 03:30:23 +00:00
fix: resolve option whitelist, memory leak, task tracking, and dispatch unification
- Fix Memory Leaks: Resolved SID retention in _known_sids after disconnection (stale entries are now evicted once their disconnect time exceeds BUFFER_TTL), and kept references to the broadcast tasks created in _schedule_lifecycle_broadcast, which were previously left unreferenced; each task is discarded from the tracking set when it completes.
- Unify Dispatching Paths: Unified client and server event dispatching through the process_client_event() method to ensure diagnostic consistency.
- Optimization & Cleanup: Expanded the _OPTION_KEYS whitelist, removed dead code (iter_event_types), and deleted the unused websocket export.
- Robustness: Added handling for None responses in process_client_event so that empty results do not clutter responses.
- Testing: Added test cases to verify SID TTL expiration and stale SID cleanup on disconnect.
This commit is contained in:
parent
1160195fb5
commit
b351de456e
4 changed files with 274 additions and 33 deletions
|
|
@ -300,37 +300,56 @@ class WsHandler:
|
|||
for sid, handlers in _active_handlers.items()
|
||||
}
|
||||
|
||||
mgr = self._manager
|
||||
aggregated: list[dict[str, Any]] = []
|
||||
for sid, handlers in snapshot.items():
|
||||
ctx = _ws_contexts.get(sid)
|
||||
sid_results: list[dict[str, Any]] = []
|
||||
security_errors: list[dict[str, Any]] = []
|
||||
passing: list[WsHandler] = []
|
||||
for _path, instance in handlers.items():
|
||||
if ctx is not None:
|
||||
error = _check_security(type(instance), ctx)
|
||||
if error is not None:
|
||||
sid_results.append({
|
||||
security_errors.append({
|
||||
"handlerId": instance.identifier,
|
||||
"ok": False,
|
||||
"correlationId": cid,
|
||||
"error": error,
|
||||
})
|
||||
continue
|
||||
try:
|
||||
result = await instance.process(event, dict(data, correlationId=cid), sid)
|
||||
if result is not None:
|
||||
passing.append(instance)
|
||||
|
||||
if mgr is not None and passing:
|
||||
result = await mgr.process_client_event(
|
||||
self._namespace, event,
|
||||
dict(data, correlationId=cid), sid,
|
||||
handlers=passing,
|
||||
)
|
||||
sid_results = security_errors + result.get("results", [])
|
||||
else:
|
||||
# Fallback: inline processing
|
||||
sid_results = list(security_errors)
|
||||
for _path, instance in handlers.items():
|
||||
if instance not in passing:
|
||||
continue
|
||||
try:
|
||||
result = await instance.process(
|
||||
event, dict(data, correlationId=cid), sid,
|
||||
)
|
||||
if result is not None:
|
||||
sid_results.append({
|
||||
"handlerId": instance.identifier,
|
||||
"ok": True,
|
||||
"correlationId": cid,
|
||||
"data": result,
|
||||
})
|
||||
except Exception as e:
|
||||
sid_results.append({
|
||||
"handlerId": instance.identifier,
|
||||
"ok": True,
|
||||
"ok": False,
|
||||
"correlationId": cid,
|
||||
"data": result,
|
||||
"error": {"code": "HANDLER_ERROR", "error": str(e)},
|
||||
})
|
||||
except Exception as e:
|
||||
sid_results.append({
|
||||
"handlerId": instance.identifier,
|
||||
"ok": False,
|
||||
"correlationId": cid,
|
||||
"error": {"code": "HANDLER_ERROR", "error": str(e)},
|
||||
})
|
||||
aggregated.append({
|
||||
"sid": sid,
|
||||
"correlationId": cid,
|
||||
|
|
@ -530,24 +549,47 @@ def register_ws_namespace(
|
|||
return _error_response("NO_HANDLERS",
|
||||
"No handlers activated", correlation_id)
|
||||
|
||||
# Unwrap nested payload (mirrors WsManager.route_event):
|
||||
# frontend sends {ts, data: {actual fields...}, correlationId}
|
||||
# Pre-filter handlers through security checks
|
||||
passing_handlers: list[WsHandler] = []
|
||||
security_errors: list[dict[str, Any]] = []
|
||||
for path, instance in activated.items():
|
||||
error = _check_security(type(instance), ctx)
|
||||
if error is not None:
|
||||
security_errors.append({
|
||||
"handlerId": instance.identifier,
|
||||
"ok": False,
|
||||
"correlationId": correlation_id,
|
||||
"error": error,
|
||||
})
|
||||
else:
|
||||
passing_handlers.append(instance)
|
||||
|
||||
# Delegate to WsManager for unified processing pipeline
|
||||
# (worker thread isolation, diagnostic events, WsResult support)
|
||||
if manager is not None and passing_handlers:
|
||||
result = await manager.process_client_event(
|
||||
NAMESPACE, event, incoming, sid,
|
||||
handlers=passing_handlers,
|
||||
)
|
||||
if security_errors:
|
||||
result["results"] = security_errors + result.get("results", [])
|
||||
return result
|
||||
|
||||
# All handlers failed security or no manager — return collected errors
|
||||
if not passing_handlers:
|
||||
return {"correlationId": correlation_id, "results": security_errors}
|
||||
|
||||
# Fallback: inline processing (no manager — should not happen in practice)
|
||||
handler_payload: dict[str, Any]
|
||||
if "data" in incoming and isinstance(incoming.get("data"), dict):
|
||||
handler_payload = dict(incoming["data"])
|
||||
else:
|
||||
handler_payload = dict(incoming)
|
||||
handler_payload["correlationId"] = correlation_id
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
results: list[dict[str, Any]] = list(security_errors)
|
||||
for path, instance in activated.items():
|
||||
error = _check_security(type(instance), ctx)
|
||||
if error is not None:
|
||||
results.append({
|
||||
"handlerId": instance.identifier,
|
||||
"ok": False,
|
||||
"correlationId": correlation_id,
|
||||
"error": error,
|
||||
})
|
||||
if instance not in passing_handlers:
|
||||
continue
|
||||
try:
|
||||
result = await instance.process(event, handler_payload, sid)
|
||||
|
|
|
|||
|
|
@ -247,6 +247,7 @@ class WsManager:
|
|||
defaultdict(deque)
|
||||
)
|
||||
self._known_sids: Set[ConnectionIdentity] = set()
|
||||
self._disconnect_times: Dict[ConnectionIdentity, datetime] = {}
|
||||
self._identifier: str = f"{self.__class__.__module__}.{self.__class__.__name__}"
|
||||
# Session tracking (single-user default)
|
||||
self.user_to_sids: defaultdict[str, Set[ConnectionIdentity]] = defaultdict(set)
|
||||
|
|
@ -257,6 +258,7 @@ class WsManager:
|
|||
self._diagnostics_enabled: bool = runtime.is_development()
|
||||
self._dispatcher_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._handler_worker: DeferredTask | None = None
|
||||
self._lifecycle_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
# Internal: development-only debug logging to avoid noise in production
|
||||
def _debug(self, message: str) -> None:
|
||||
|
|
@ -414,7 +416,23 @@ class WsManager:
|
|||
except Exception as exc: # pragma: no cover - diagnostic
|
||||
self._debug(f"Failed to broadcast lifecycle event {event_type}: {exc}")
|
||||
|
||||
asyncio.create_task(_broadcast())
|
||||
task = asyncio.create_task(_broadcast())
|
||||
self._lifecycle_tasks.add(task)
|
||||
task.add_done_callback(self._lifecycle_tasks.discard)
|
||||
|
||||
def _sweep_stale_sids(self) -> None:
|
||||
"""Remove _known_sids entries whose disconnect exceeds BUFFER_TTL."""
|
||||
now = _utcnow()
|
||||
with self.lock:
|
||||
stale = [
|
||||
identity
|
||||
for identity, dt in self._disconnect_times.items()
|
||||
if identity not in self.connections and (now - dt) > BUFFER_TTL
|
||||
]
|
||||
for identity in stale:
|
||||
self._known_sids.discard(identity)
|
||||
self._disconnect_times.pop(identity, None)
|
||||
self.buffers.pop(identity, None)
|
||||
|
||||
def _normalize_handler_filter(self, value: Any, field_name: str) -> Set[str] | None:
|
||||
if value is None:
|
||||
|
|
@ -513,12 +531,133 @@ class WsManager:
|
|||
f"Registered handler {handler.identifier} namespace={namespace}"
|
||||
)
|
||||
|
||||
def iter_event_types(self, namespace: str) -> Iterable[str]:
|
||||
return []
|
||||
|
||||
def iter_namespaces(self) -> list[str]:
|
||||
return list(self.handlers.keys())
|
||||
|
||||
async def process_client_event(
|
||||
self,
|
||||
namespace: str,
|
||||
event_type: str,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
*,
|
||||
handlers: list[WsHandler],
|
||||
) -> dict[str, Any]:
|
||||
"""Process a client-originated event through provided handler instances.
|
||||
|
||||
Unlike ``route_event`` which selects from globally registered handlers,
|
||||
this accepts pre-selected instances (e.g. per-connection activated
|
||||
handlers that have already passed security checks).
|
||||
"""
|
||||
self._ensure_dispatcher_loop()
|
||||
incoming = dict(data or {})
|
||||
correlation_id = self._resolve_correlation_id(incoming)
|
||||
|
||||
if "data" in incoming and isinstance(incoming.get("data"), dict):
|
||||
handler_payload = dict(incoming["data"])
|
||||
if "excludeSids" in incoming:
|
||||
handler_payload["excludeSids"] = incoming["excludeSids"]
|
||||
else:
|
||||
handler_payload = dict(incoming)
|
||||
handler_payload["correlationId"] = correlation_id
|
||||
|
||||
if not handlers:
|
||||
error = self._build_error_result(
|
||||
handler_id=self._identifier,
|
||||
code="NO_HANDLERS",
|
||||
message="No handlers available after security filtering",
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
return {"correlationId": correlation_id, "results": [error]}
|
||||
|
||||
with self.lock:
|
||||
info = self.connections.get((namespace, sid))
|
||||
if info:
|
||||
info.last_activity = _utcnow()
|
||||
|
||||
executions = await asyncio.gather(
|
||||
*[
|
||||
self._invoke_handler(handler, event_type, dict(handler_payload), sid)
|
||||
for handler in handlers
|
||||
]
|
||||
)
|
||||
|
||||
results: List[dict[str, Any]] = []
|
||||
for execution in executions:
|
||||
handler = execution.handler
|
||||
value = execution.value
|
||||
duration_ms = execution.duration_ms
|
||||
|
||||
if isinstance(value, Exception):
|
||||
PrintStyle.error(
|
||||
f"Error in handler {handler.identifier} for '{event_type}' "
|
||||
f"(correlation {correlation_id}): {value}"
|
||||
)
|
||||
results.append(
|
||||
self._build_error_result(
|
||||
handler_id=handler.identifier,
|
||||
code="HANDLER_ERROR",
|
||||
message="Internal server error",
|
||||
details=str(value),
|
||||
correlation_id=correlation_id,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if isinstance(value, WsResult):
|
||||
results.append(
|
||||
value.as_result(
|
||||
handler_id=handler.identifier,
|
||||
fallback_correlation_id=correlation_id,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# Skip handlers that return None — they opted out of contributing
|
||||
# a result (fire-and-forget semantics, matching legacy _dispatch).
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if isinstance(value, dict):
|
||||
helper_result = WsResult(ok=True, data=value)
|
||||
else:
|
||||
helper_result = WsResult(ok=True, data={"result": value})
|
||||
|
||||
results.append(
|
||||
helper_result.as_result(
|
||||
handler_id=handler.identifier,
|
||||
fallback_correlation_id=correlation_id,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
)
|
||||
|
||||
await self._publish_diagnostic_event(
|
||||
lambda: {
|
||||
"kind": "inbound",
|
||||
"sourceNamespace": namespace,
|
||||
"namespace": namespace,
|
||||
"eventType": event_type,
|
||||
"sid": sid,
|
||||
"correlationId": correlation_id,
|
||||
"timestamp": self._timestamp(),
|
||||
"handlerCount": len(handlers),
|
||||
"durationMs": sum(
|
||||
(exec.duration_ms or 0.0) for exec in executions
|
||||
),
|
||||
"resultSummary": self._summarize_results(results),
|
||||
"payloadSummary": self._summarize_payload(handler_payload),
|
||||
}
|
||||
)
|
||||
|
||||
response = {"correlationId": correlation_id, "results": results}
|
||||
self._debug(
|
||||
f"Completed client event namespace={namespace} '{event_type}' "
|
||||
f"sid={sid} correlation={correlation_id}"
|
||||
)
|
||||
return response
|
||||
|
||||
async def _invoke_handler(
|
||||
self,
|
||||
handler: WsHandler,
|
||||
|
|
@ -551,6 +690,7 @@ class WsManager:
|
|||
with self.lock:
|
||||
self.connections[identity] = ConnectionInfo(namespace=namespace, sid=sid)
|
||||
self._known_sids.add(identity)
|
||||
self._disconnect_times.pop(identity, None)
|
||||
self.sid_to_user[identity] = user_bucket
|
||||
self.user_to_sids[self._ALL_USERS_BUCKET].add(identity)
|
||||
self.user_to_sids[user_bucket].add(identity)
|
||||
|
|
@ -600,7 +740,9 @@ class WsManager:
|
|||
identity: ConnectionIdentity = (namespace, sid)
|
||||
with self.lock:
|
||||
self.connections.pop(identity, None)
|
||||
# Keep identity in _known_sids so emit_to buffers instead of raising
|
||||
# Keep identity in _known_sids so emit_to buffers instead of raising;
|
||||
# record disconnect time for TTL-based cleanup
|
||||
self._disconnect_times[identity] = _utcnow()
|
||||
# session tracking cleanup
|
||||
user_bucket = self.sid_to_user.pop(identity, None)
|
||||
if self._ALL_USERS_BUCKET in self.user_to_sids:
|
||||
|
|
@ -633,6 +775,7 @@ class WsManager:
|
|||
self._schedule_lifecycle_broadcast(
|
||||
namespace, LIFECYCLE_DISCONNECT_EVENT, lifecycle_payload
|
||||
)
|
||||
self._sweep_stale_sids()
|
||||
|
||||
async def route_event(
|
||||
self,
|
||||
|
|
@ -1112,6 +1255,14 @@ class WsManager:
|
|||
with self.lock:
|
||||
connected = identity in self.connections
|
||||
known = identity in self._known_sids or identity in self.buffers
|
||||
# Evict if disconnect has exceeded BUFFER_TTL
|
||||
if not connected and known:
|
||||
dt = self._disconnect_times.get(identity)
|
||||
if dt is not None and (_utcnow() - dt) > BUFFER_TTL:
|
||||
self._known_sids.discard(identity)
|
||||
self._disconnect_times.pop(identity, None)
|
||||
self.buffers.pop(identity, None)
|
||||
known = False
|
||||
|
||||
if connected:
|
||||
self._debug(
|
||||
|
|
|
|||
|
|
@ -396,6 +396,56 @@ async def test_flush_buffer_delivers_and_logs(monkeypatch):
|
|||
assert (NAMESPACE, "sid-1") not in manager.buffers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_known_sid_expires_after_buffer_ttl(monkeypatch):
|
||||
"""After BUFFER_TTL, a disconnected sid is swept from _known_sids and emit_to raises."""
|
||||
socketio = FakeSocketIOServer()
|
||||
manager = WsManager(socketio, threading.RLock())
|
||||
|
||||
await manager.handle_connect(NAMESPACE, "sid-stale")
|
||||
await manager.handle_disconnect(NAMESPACE, "sid-stale")
|
||||
|
||||
# Immediately after disconnect, buffering still works
|
||||
await manager.emit_to(NAMESPACE, "sid-stale", "event", {"x": 1})
|
||||
assert (NAMESPACE, "sid-stale") in manager.buffers
|
||||
|
||||
from datetime import timedelta, timezone, datetime
|
||||
|
||||
future = datetime.now(timezone.utc) + BUFFER_TTL + timedelta(seconds=10)
|
||||
monkeypatch.setattr("helpers.ws_manager._utcnow", lambda: future)
|
||||
|
||||
# After TTL, emit_to should raise because the sid is no longer known
|
||||
with pytest.raises(ConnectionNotFoundError):
|
||||
await manager.emit_to(NAMESPACE, "sid-stale", "event", {"x": 2})
|
||||
|
||||
# _known_sids and buffers should be cleaned
|
||||
assert (NAMESPACE, "sid-stale") not in manager._known_sids
|
||||
assert (NAMESPACE, "sid-stale") not in manager.buffers
|
||||
assert (NAMESPACE, "sid-stale") not in manager._disconnect_times
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sweep_cleans_stale_sids_on_disconnect(monkeypatch):
|
||||
"""_sweep_stale_sids runs during handle_disconnect and cleans expired entries."""
|
||||
socketio = FakeSocketIOServer()
|
||||
manager = WsManager(socketio, threading.RLock())
|
||||
|
||||
await manager.handle_connect(NAMESPACE, "old-sid")
|
||||
await manager.handle_disconnect(NAMESPACE, "old-sid")
|
||||
|
||||
from datetime import timedelta, timezone, datetime
|
||||
|
||||
future = datetime.now(timezone.utc) + BUFFER_TTL + timedelta(seconds=10)
|
||||
monkeypatch.setattr("helpers.ws_manager._utcnow", lambda: future)
|
||||
|
||||
# A new connect/disconnect triggers sweep which cleans old-sid
|
||||
await manager.handle_connect(NAMESPACE, "new-sid")
|
||||
await manager.handle_disconnect(NAMESPACE, "new-sid")
|
||||
|
||||
assert (NAMESPACE, "old-sid") not in manager._known_sids
|
||||
assert (NAMESPACE, "old-sid") not in manager._disconnect_times
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_excludes_multiple_sids():
|
||||
socketio = FakeSocketIOServer()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ const MAX_PAYLOAD_BYTES = 50 * 1024 * 1024; // 50MB hard cap per contract
|
|||
const DEFAULT_TIMEOUT_MS = 0;
|
||||
|
||||
const _UUID_HEX = [..."0123456789abcdef"];
|
||||
const _OPTION_KEYS = new Set(["correlationId"]);
|
||||
const _OPTION_KEYS = new Set(["correlationId", "includeHandlers", "excludeHandlers", "excludeSids"]);
|
||||
|
||||
/**
|
||||
* @param {unknown} value
|
||||
|
|
@ -744,5 +744,3 @@ export function getNamespacedClient(namespace) {
|
|||
_namespacedClients.set(key, client);
|
||||
return client;
|
||||
}
|
||||
|
||||
export const websocket = getNamespacedClient("/");
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue