refactor: extract shared utilities, fix send_data signature & plugins.py bug

- Consolidate ConnectionIdentity, _ws_debug_enabled(), ws_debug() into ws.py as single-source exports, removing duplicate definitions in ws_manager.py and state_monitor.py
- Make send_data() optional args keyword-only to prevent positional argument confusion with the instance method signature
- Fix clear_plugin_cache in plugins.py: wrong parameter name (event_name → event_type) and stale namespace (/webui → /ws)
This commit is contained in:
keyboardstaff 2026-03-28 00:38:46 -07:00
parent 0749ddc932
commit 04d930ab02
4 changed files with 38 additions and 40 deletions

View file

@ -206,9 +206,9 @@ def clear_plugin_cache(plugin_names: list[str] | None = None):
DeferredTask().start_task(
send_data,
endpoint_name="/webui",
event_name="clear_cache",
data={"areas": areas},
"clear_cache",
{"areas": areas},
endpoint_name="/ws",
)

View file

@ -14,26 +14,12 @@ from helpers.state_snapshot import (
advance_state_request_after_snapshot,
build_snapshot_from_request,
)
from helpers.ws import ConnectionNotFoundError
from helpers.ws import ConnectionIdentity, ConnectionNotFoundError, _ws_debug_enabled, ws_debug
if TYPE_CHECKING: # pragma: no cover - hints only
from helpers.ws_manager import WsManager
ConnectionIdentity = tuple[str, str] # (namespace, sid)
def _ws_debug_enabled() -> bool:
value = os.getenv("A0_WS_DEBUG", "").strip().lower()
return value in {"1", "true", "yes", "on"}
def _debug_log(message: str) -> None:
if not _ws_debug_enabled():
return
PrintStyle.debug(message)
@dataclass
class ConnectionProjection:
namespace: str
@ -73,7 +59,7 @@ class StateMonitor:
# Use the manager's dispatcher loop for all scheduling so mark_dirty can be
# invoked safely from non-async contexts and other threads.
self._dispatcher_loop = getattr(manager, "_dispatcher_loop", None)
_debug_log(
ws_debug(
f"[StateMonitor] bind_manager handler_id={handler_id or self._emit_handler_id}"
)
@ -83,7 +69,7 @@ class StateMonitor:
self._projections.setdefault(
identity, ConnectionProjection(namespace=namespace, sid=sid)
)
_debug_log(f"[StateMonitor] register_sid namespace={namespace} sid={sid}")
ws_debug(f"[StateMonitor] register_sid namespace={namespace} sid={sid}")
def unregister_sid(self, namespace: str, sid: str) -> None:
identity: ConnectionIdentity = (namespace, sid)
@ -95,7 +81,7 @@ class StateMonitor:
if task is not None:
task.cancel()
self._projections.pop(identity, None)
_debug_log(f"[StateMonitor] unregister_sid namespace={namespace} sid={sid}")
ws_debug(f"[StateMonitor] unregister_sid namespace={namespace} sid={sid}")
def mark_dirty_all(self, *, reason: str | None = None) -> None:
wave_id = None
@ -142,7 +128,7 @@ class StateMonitor:
projection.request = request
projection.seq_base = seq_base
projection.seq = seq_base
_debug_log(
ws_debug(
f"[StateMonitor] update_projection namespace={namespace} sid={sid} context={request.context!r} "
f"log_from={request.log_from} notifications_from={request.notifications_from} "
f"timezone={request.timezone!r} seq_base={seq_base}"
@ -221,7 +207,7 @@ class StateMonitor:
self.debounce_seconds, self._on_debounce_fire, identity
)
self._debounce_handles[identity] = handle
_debug_log(
ws_debug(
f"[StateMonitor] schedule_push namespace={projection.namespace} sid={projection.sid} "
f"delay_s={self.debounce_seconds} "
f"dirty={projection.dirty_version} pushed={projection.pushed_version} "
@ -298,7 +284,7 @@ class StateMonitor:
if isinstance(snapshot.get("logs"), list)
else None
)
_debug_log(
ws_debug(
f"[StateMonitor] emit state_push namespace={namespace} sid={sid} seq={seq} "
f"context={request.context!r} logs_len={logs_len} "
f"reason={dirty_reason!r} wave={dirty_wave_id!r}"
@ -312,13 +298,13 @@ class StateMonitor:
)
except ConnectionNotFoundError:
# Sid was removed before the emit; treat as benign.
_debug_log(
ws_debug(
f"[StateMonitor] emit skipped: sid not found namespace={namespace} sid={sid}"
)
return
except RuntimeError:
# Dispatcher loop may be closing (e.g., during shutdown or test teardown).
_debug_log(
ws_debug(
f"[StateMonitor] emit skipped: dispatcher closing namespace={namespace} sid={sid}"
)
return
@ -341,7 +327,7 @@ class StateMonitor:
if not follow_up:
return
_debug_log(
ws_debug(
f"[StateMonitor] follow_up_push namespace={namespace} sid={sid} dirty={dirty_version} pushed={pushed_version}"
)
try:

View file

@ -1,3 +1,4 @@
import os
import threading
import uuid
from abc import abstractmethod
@ -17,10 +18,24 @@ if TYPE_CHECKING:
from helpers.ws_manager import WsManager
# Utilities
# Shared types and utilities
from helpers.network import is_loopback_address
ConnectionIdentity = tuple[str, str] # (namespace, sid)
def _ws_debug_enabled() -> bool:
"""Check A0_WS_DEBUG env var — lightweight, no heavy imports."""
value = os.getenv("A0_WS_DEBUG", "").strip().lower()
return value in {"1", "true", "yes", "on"}
def ws_debug(message: str) -> None:
"""Log *message* via :class:`PrintStyle` when ``A0_WS_DEBUG`` is active."""
if _ws_debug_enabled():
PrintStyle.debug(message)
class ConnectionNotFoundError(RuntimeError):
"""Raised when attempting to emit to a non-existent WebSocket connection."""

View file

@ -15,13 +15,7 @@ import uuid
from helpers.defer import DeferredTask
from helpers.print_style import PrintStyle
from helpers import runtime
from helpers.ws import ConnectionNotFoundError, WsHandler
def _ws_debug_enabled() -> bool:
"""Check A0_WS_DEBUG env var — no heavyweight imports needed."""
value = os.getenv("A0_WS_DEBUG", "").strip().lower()
return value in {"1", "true", "yes", "on"}
from helpers.ws import ConnectionIdentity, ConnectionNotFoundError, WsHandler, _ws_debug_enabled, ws_debug
# Event validation
@ -182,9 +176,16 @@ _shared_ws_manager: WsManager | None = None
async def send_data(
event_type: str,
data: dict[str, Any],
*,
endpoint_name: str = "/ws",
connection_id: str | None = None,
) -> None:
"""Convenience wrapper around :pymeth:`WsManager.send_data`.
All optional parameters are keyword-only to match the instance method's
``(endpoint_name, event_type, data, connection_id)`` order and avoid
positional confusion between the two signatures.
"""
manager = get_shared_ws_manager()
await manager.send_data(endpoint_name, event_type, data, connection_id)
@ -222,9 +223,6 @@ class ConnectionInfo:
last_activity: datetime = field(default_factory=_utcnow)
ConnectionIdentity = tuple[str, str] # (namespace, sid)
@dataclass
class _HandlerExecution:
handler: WsHandler
@ -262,8 +260,7 @@ class WsManager:
# Internal: development-only debug logging to avoid noise in production
def _debug(self, message: str) -> None:
if _ws_debug_enabled():
PrintStyle.debug(message)
ws_debug(message)
def _ensure_dispatcher_loop(self) -> None:
if self._dispatcher_loop is None: