diff --git a/api/app.py b/api/app.py index a915323..7aff6e4 100644 --- a/api/app.py +++ b/api/app.py @@ -114,8 +114,8 @@ async def lifespan(app: FastAPI): "trees": saved_trees, "node_to_tree": session_store.get_node_mapping(), }, - queue_update_callback=message_handler._update_queue_positions, - node_started_callback=message_handler._mark_node_processing, + queue_update_callback=message_handler.update_queue_positions, + node_started_callback=message_handler.mark_node_processing, ) # Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart if message_handler.tree_queue.cleanup_stale_nodes() > 0: diff --git a/api/dependencies.py b/api/dependencies.py index 084dd8b..a79c0c4 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -9,8 +9,7 @@ from providers.base import BaseProvider, ProviderConfig from providers.exceptions import AuthenticationError from providers.lmstudio import LMStudioProvider from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider -from providers.open_router import OpenRouterProvider -from providers.open_router.client import OPENROUTER_BASE_URL +from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider # Global provider instance (singleton) _provider: BaseProvider | None = None diff --git a/messaging/commands.py b/messaging/commands.py index 10c5932..a7eca4b 100644 --- a/messaging/commands.py +++ b/messaging/commands.py @@ -28,13 +28,13 @@ async def handle_stop_command( if not node_id: msg_id = await handler.platform.queue_send_message( incoming.chat_id, - handler._format_status( + handler.format_status( "⏹", "Stopped.", "Nothing to stop for that message." ), fire_and_forget=False, message_thread_id=incoming.message_thread_id, ) - handler._record_outgoing_message( + handler.record_outgoing_message( incoming.platform, incoming.chat_id, msg_id, "command" ) return @@ -43,11 +43,11 @@ async def handle_stop_command( noun = "request" if count == 1 else "requests" msg_id = await handler.platform.queue_send_message( incoming.chat_id, - handler._format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."), + handler.format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."), fire_and_forget=False, message_thread_id=incoming.message_thread_id, ) - handler._record_outgoing_message( + handler.record_outgoing_message( incoming.platform, incoming.chat_id, msg_id, "command" ) return @@ -56,13 +56,13 @@ async def handle_stop_command( count = await handler.stop_all_tasks() msg_id = await handler.platform.queue_send_message( incoming.chat_id, - handler._format_status( + handler.format_status( "⏹", "Stopped.", f"Cancelled {count} pending or active requests." ), fire_and_forget=False, message_thread_id=incoming.message_thread_id, ) - handler._record_outgoing_message( + handler.record_outgoing_message( incoming.platform, incoming.chat_id, msg_id, "command" ) @@ -73,7 +73,7 @@ async def handle_stats_command( """Handle /stats command.""" stats = handler.cli_manager.get_stats() tree_count = handler.tree_queue.get_tree_count() - ctx = handler._get_render_ctx() + ctx = handler.get_render_ctx() msg_id = await handler.platform.queue_send_message( incoming.chat_id, "📊 " @@ -85,7 +85,7 @@ async def handle_stats_command( fire_and_forget=False, message_thread_id=incoming.message_thread_id, ) - handler._record_outgoing_message( + handler.record_outgoing_message( incoming.platform, incoming.chat_id, msg_id, "command" ) @@ -149,7 +149,7 @@ async def _handle_clear_branch( # 1) Cancel branch tasks (no stop_all) cancelled = await handler.tree_queue.cancel_branch(branch_root_id) - handler._update_cancelled_nodes_ui(cancelled) + handler.update_cancelled_nodes_ui(cancelled) # 2) Collect message IDs from branch nodes only msg_ids: set[str] = set() @@ -214,25 +214,23 @@ async def handle_clear_command( await _delete_message_ids(handler, incoming.chat_id, msg_ids_to_del) msg_id = await handler.platform.queue_send_message( incoming.chat_id, - handler._format_status( - "🗑", "Cleared.", "Voice note cancelled." - ), + handler.format_status("🗑", "Cleared.", "Voice note cancelled."), fire_and_forget=False, message_thread_id=incoming.message_thread_id, ) - handler._record_outgoing_message( + handler.record_outgoing_message( incoming.platform, incoming.chat_id, msg_id, "command" ) return msg_id = await handler.platform.queue_send_message( incoming.chat_id, - handler._format_status( + handler.format_status( "🗑", "Cleared.", "Nothing to clear for that message." ), fire_and_forget=False, message_thread_id=incoming.message_thread_id, ) - handler._record_outgoing_message( + handler.record_outgoing_message( incoming.platform, incoming.chat_id, msg_id, "command" ) return @@ -278,6 +276,6 @@ async def handle_clear_command( logger.warning(f"Failed to clear session store: {e}") handler.tree_queue = TreeQueueManager( - queue_update_callback=handler._update_queue_positions, - node_started_callback=handler._mark_node_processing, + queue_update_callback=handler.update_queue_positions, + node_started_callback=handler.mark_node_processing, ) diff --git a/messaging/handler.py b/messaging/handler.py index 19a6e8d..91ce9c9 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -119,8 +119,8 @@ class ClaudeMessageHandler: self.cli_manager = cli_manager self.session_store = session_store self.tree_queue = TreeQueueManager( - queue_update_callback=self._update_queue_positions, - node_started_callback=self._mark_node_processing, + queue_update_callback=self.update_queue_positions, + node_started_callback=self.mark_node_processing, ) is_discord = platform.name == "discord" self._format_status_fn = ( @@ -138,13 +138,13 @@ class ClaudeMessageHandler: ) self._limit_chars = 1900 if is_discord else 3900 - def _format_status(self, emoji: str, label: str, suffix: str | None = None) -> str: + def format_status(self, emoji: str, label: str, suffix: str | None = None) -> str: return self._format_status_fn(emoji, label, suffix) def _parse_mode(self) -> str | None: return self._parse_mode_val - def _get_render_ctx(self) -> RenderCtx: + def get_render_ctx(self) -> RenderCtx: return self._render_ctx_val def _get_limit_chars(self) -> int: @@ -253,7 +253,7 @@ class ClaudeMessageHandler: fire_and_forget=False, message_thread_id=incoming.message_thread_id, ) - self._record_outgoing_message( + self.record_outgoing_message( incoming.platform, incoming.chat_id, status_msg_id, "status" ) @@ -298,13 +298,13 @@ class ClaudeMessageHandler: await self.platform.queue_edit_message( incoming.chat_id, status_msg_id, - self._format_status( + self.format_status( "📋", "Queued", f"(position {queue_size}) - waiting..." ), parse_mode=self._parse_mode(), ) - async def _update_queue_positions(self, tree: MessageTree) -> None: + async def update_queue_positions(self, tree: MessageTree) -> None: """Refresh queued status messages after a dequeue.""" try: queued_ids = await tree.get_queue_snapshot() @@ -325,14 +325,14 @@ class ClaudeMessageHandler: self.platform.queue_edit_message( node.incoming.chat_id, node.status_message_id, - self._format_status( + self.format_status( "📋", "Queued", f"(position {position}) - waiting..." ), parse_mode=self._parse_mode(), ) ) - async def _mark_node_processing(self, tree: MessageTree, node_id: str) -> None: + async def mark_node_processing(self, tree: MessageTree, node_id: str) -> None: """Update the dequeued node's status to processing immediately.""" node = tree.get_node(node_id) if not node or node.state == MessageState.ERROR: @@ -341,7 +341,7 @@ class ClaudeMessageHandler: self.platform.queue_edit_message( node.incoming.chat_id, node.status_message_id, - self._format_status("🔄", "Processing..."), + self.format_status("🔄", "Processing..."), parse_mode=self._parse_mode(), ) ) @@ -351,7 +351,7 @@ class ClaudeMessageHandler: ) -> tuple[TranscriptBuffer, RenderCtx]: """Create transcript buffer and render context for node processing.""" transcript = TranscriptBuffer(show_tool_results=False) - return transcript, self._get_render_ctx() + return transcript, self.get_render_ctx() async def _handle_session_info_event( self, @@ -400,7 +400,7 @@ class ClaudeMessageHandler: transcript.apply(parsed) had_transcript_events = True - status = _get_status_for_event(ptype, parsed, self._format_status) + status = _get_status_for_event(ptype, parsed, self.format_status) if status is not None: await update_ui(status) last_status = status @@ -410,7 +410,7 @@ class ClaudeMessageHandler: if not had_transcript_events: transcript.apply({"type": "text_chunk", "text": "Done."}) logger.info("HANDLER: Task complete, updating UI") - await update_ui(self._format_status("✅", "Complete"), force=True) + await update_ui(self.format_status("✅", "Complete"), force=True) if tree and captured_session_id: await tree.update_state( node_id, @@ -422,7 +422,7 @@ class ClaudeMessageHandler: error_msg = parsed.get("message", "Unknown error") logger.error(f"HANDLER: Error event received: {error_msg}") logger.info("HANDLER: Updating UI with error status") - await update_ui(self._format_status("❌", "Error"), force=True) + await update_ui(self.format_status("❌", "Error"), force=True) if tree: await self._propagate_error_to_children( node_id, error_msg, "Parent task failed" @@ -533,7 +533,7 @@ class ClaudeMessageHandler: except RuntimeError as e: transcript.apply({"type": "error", "message": str(e)}) await update_ui( - self._format_status("⏳", "Session limit reached"), + self.format_status("⏳", "Session limit reached"), force=True, ) if tree: @@ -592,10 +592,10 @@ class ClaudeMessageHandler: cancel_reason = node.context.get("cancel_reason") if cancel_reason == "stop": - await update_ui(self._format_status("⏹", "Stopped."), force=True) + await update_ui(self.format_status("⏹", "Stopped."), force=True) else: transcript.apply({"type": "error", "message": "Task was cancelled"}) - await update_ui(self._format_status("❌", "Cancelled"), force=True) + await update_ui(self.format_status("❌", "Cancelled"), force=True) # Do not propagate cancellation to children; a reply-scoped "/stop" # should only stop the targeted task. @@ -609,7 +609,7 @@ class ClaudeMessageHandler: ) error_msg = str(e)[:200] transcript.apply({"type": "error", "message": error_msg}) - await update_ui(self._format_status("💥", "Task Failed"), force=True) + await update_ui(self.format_status("💥", "Task Failed"), force=True) if tree: await self._propagate_error_to_children( node_id, error_msg, "Parent task failed" @@ -643,7 +643,7 @@ class ClaudeMessageHandler: self.platform.queue_edit_message( child.incoming.chat_id, child.status_message_id, - self._format_status("❌", "Cancelled:", child_status_text), + self.format_status("❌", "Cancelled:", child_status_text), parse_mode=self._parse_mode(), ) ) @@ -658,13 +658,13 @@ class ClaudeMessageHandler: # Reply to existing tree if self.tree_queue.is_node_tree_busy(parent_node_id): queue_size = self.tree_queue.get_queue_size(parent_node_id) + 1 - return self._format_status( + return self.format_status( "📋", "Queued", f"(position {queue_size}) - waiting..." ) - return self._format_status("🔄", "Continuing conversation...") + return self.format_status("🔄", "Continuing conversation...") # New conversation - return self._format_status("⏳", "Launching new Claude CLI instance...") + return self.format_status("⏳", "Launching new Claude CLI instance...") async def stop_all_tasks(self) -> int: """ @@ -685,7 +685,7 @@ class ClaudeMessageHandler: await self.cli_manager.stop_all() # 3. Update UI and persist state for all cancelled nodes - self._update_cancelled_nodes_ui(cancelled_nodes) + self.update_cancelled_nodes_ui(cancelled_nodes) return len(cancelled_nodes) @@ -703,10 +703,10 @@ class ClaudeMessageHandler: node.set_context({"cancel_reason": "stop"}) cancelled_nodes = await self.tree_queue.cancel_node(node_id) - self._update_cancelled_nodes_ui(cancelled_nodes) + self.update_cancelled_nodes_ui(cancelled_nodes) return len(cancelled_nodes) - def _record_outgoing_message( + def record_outgoing_message( self, platform: str, chat_id: str, @@ -723,7 +723,7 @@ class ClaudeMessageHandler: except Exception as e: logger.debug(f"Failed to record message_id: {e}") - def _update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None: + def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None: """Update status messages and persist tree state for cancelled nodes.""" trees_to_save: dict[str, MessageTree] = {} for node in nodes: @@ -731,7 +731,7 @@ class ClaudeMessageHandler: self.platform.queue_edit_message( node.incoming.chat_id, node.status_message_id, - self._format_status("⏹", "Stopped."), + self.format_status("⏹", "Stopped."), parse_mode=self._parse_mode(), ) ) diff --git a/messaging/transcript.py b/messaging/transcript.py index 8274100..073f40b 100644 --- a/messaging/transcript.py +++ b/messaging/transcript.py @@ -10,6 +10,7 @@ from __future__ import annotations import json import os +from abc import ABC, abstractmethod from collections import deque from collections.abc import Callable, Iterable from dataclasses import dataclass, field @@ -26,11 +27,11 @@ def _safe_json_dumps(obj: Any) -> str: @dataclass -class Segment: +class Segment(ABC): kind: str - def render(self, ctx: RenderCtx) -> str: - raise NotImplementedError + @abstractmethod + def render(self, ctx: RenderCtx) -> str: ... @dataclass @@ -88,23 +89,6 @@ class ToolCallSegment(Segment): self.tool_use_id = str(tool_use_id or "") self.name = str(name or "tool") self.indent_level = max(0, int(indent_level)) - self._parts: list[str] = [] - - def set_initial_input(self, inp: Any) -> None: - if inp is None: - return - if isinstance(inp, str): - self._parts = [inp] - else: - self._parts = [_safe_json_dumps(inp)] - - def append_input_delta(self, partial: str) -> None: - if partial: - self._parts.append(partial) - - @property - def input_text(self) -> str: - return "".join(self._parts) def render(self, ctx: RenderCtx) -> str: name = ctx.code_inline(self.name) @@ -444,17 +428,12 @@ class TranscriptBuffer: seg = ToolCallSegment(tool_id, name) self._segments.append(seg) - seg.set_initial_input(ev.get("input")) if idx >= 0: self._open_tools_by_index[idx] = seg return if et == "tool_use_delta": - idx = int(ev.get("index", -1)) - partial = str(ev.get("partial_json", "") or "") - seg = self._open_tools_by_index.get(idx) - if seg is not None: - seg.append_input_delta(partial) + # Track open tool by index for tool_use_stop (closing state). return if et == "tool_use_stop": @@ -501,7 +480,6 @@ class TranscriptBuffer: seg = ToolCallSegment(tool_id, name) self._segments.append(seg) - seg.set_initial_input(ev.get("input")) seg.closed = True return diff --git a/messaging/trees/data.py b/messaging/trees/data.py index e2cd2bd..d5db01d 100644 --- a/messaging/trees/data.py +++ b/messaging/trees/data.py @@ -17,21 +17,26 @@ from ..models import IncomingMessage class _SnapshotQueue: - """Queue with snapshot/remove helpers, backed by a deque.""" + """Queue with snapshot/remove helpers, backed by a deque and a set index.""" def __init__(self) -> None: self._deque: deque[str] = deque() + self._set: set[str] = set() async def put(self, item: str) -> None: self._deque.append(item) + self._set.add(item) def put_nowait(self, item: str) -> None: self._deque.append(item) + self._set.add(item) def get_nowait(self) -> str: if not self._deque: raise asyncio.QueueEmpty() - return self._deque.popleft() + item = self._deque.popleft() + self._set.discard(item) + return item def qsize(self) -> int: return len(self._deque) @@ -41,12 +46,11 @@ class _SnapshotQueue: return list(self._deque) def remove_if_present(self, item: str) -> bool: - """Remove item from queue if present. Returns True if removed.""" - if item not in self._deque: + """Remove item from queue if present (O(1) membership check). Returns True if removed.""" + if item not in self._set: return False - items = [x for x in self._deque if x != item] - self._deque.clear() - self._deque.extend(items) + self._set.discard(item) + self._deque = deque(x for x in self._deque if x != item) return True @@ -350,7 +354,7 @@ class MessageTree: return True return False - def _set_node_error_sync(self, node: MessageNode, error_message: str) -> None: + def set_node_error_sync(self, node: MessageNode, error_message: str) -> None: """Synchronously mark a node as ERROR. Caller must ensure no concurrent access.""" node.state = MessageState.ERROR node.error_message = error_message @@ -371,7 +375,7 @@ class MessageTree: break node = self._nodes.get(node_id) if node: - self._set_node_error_sync(node, error_message) + self.set_node_error_sync(node, error_message) nodes.append(node) return nodes diff --git a/messaging/trees/processor.py b/messaging/trees/processor.py index 2e8cdf8..371915d 100644 --- a/messaging/trees/processor.py +++ b/messaging/trees/processor.py @@ -67,7 +67,7 @@ class TreeQueueProcessor: ) -> None: """Process a single node and then check the queue.""" # Skip if already in terminal state (e.g. from error propagation) - if node.state.value == MessageState.ERROR.value: + if node.state == MessageState.ERROR: logger.info( f"Skipping node {node.node_id} as it is already in state {node.state}" ) diff --git a/messaging/trees/queue_manager.py b/messaging/trees/queue_manager.py index 6226476..2358a18 100644 --- a/messaging/trees/queue_manager.py +++ b/messaging/trees/queue_manager.py @@ -244,7 +244,7 @@ class TreeQueueManager: MessageState.COMPLETED, MessageState.ERROR, ): - tree._set_node_error_sync(node, "Cancelled by user") + tree.set_node_error_sync(node, "Cancelled by user") cancelled_nodes.append(node) # 2. Drain queue and mark nodes as cancelled @@ -259,7 +259,7 @@ class TreeQueueManager: node.state in (MessageState.PENDING, MessageState.IN_PROGRESS) and node.node_id not in cancelled_ids ): - tree._set_node_error_sync(node, "Stale task cleaned up") + tree.set_node_error_sync(node, "Stale task cleaned up") cleanup_count += 1 tree.reset_processing_state() @@ -336,7 +336,7 @@ class TreeQueueManager: for tree in self._repository.all_trees(): for node in tree.all_nodes(): if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS): - tree._set_node_error_sync(node, "Lost during server restart") + tree.set_node_error_sync(node, "Lost during server restart") count += 1 if count: logger.info(f"Cleaned up {count} stale nodes during startup") diff --git a/messaging/trees/repository.py b/messaging/trees/repository.py index faa6da4..629fa90 100644 --- a/messaging/trees/repository.py +++ b/messaging/trees/repository.py @@ -150,7 +150,11 @@ class TreeRepository: return tree def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]: - """Get all message IDs (incoming + status) for a given platform/chat.""" + """Get all message IDs (incoming + status) for a given platform/chat. + + Note: O(total_nodes) scan. Acceptable because this is only called + from /clear (user-initiated, infrequent). + """ msg_ids: set[str] = set() for tree in self._trees.values(): for node in tree.all_nodes(): diff --git a/providers/common/sse_builder.py b/providers/common/sse_builder.py index 201f224..ae80789 100644 --- a/providers/common/sse_builder.py +++ b/providers/common/sse_builder.py @@ -366,16 +366,19 @@ class SSEBuilder: # Tool calls are harder to tokenize exactly without reconstruction, but we can approximate # by tokenizing the json dumps of tool contents tool_tokens = 0 + started_tool_count = 0 for state in self.blocks.tool_states.values(): tool_tokens += len(ENCODER.encode(state.name)) tool_tokens += len(ENCODER.encode("".join(state.contents))) tool_tokens += 15 # Control tokens overhead per tool + if state.started: + started_tool_count += 1 # Per-block overhead (~4 tokens per content block) block_count = ( (1 if accumulated_reasoning else 0) + (1 if accumulated_text else 0) - + sum(1 for s in self.blocks.tool_states.values() if s.started) + + started_tool_count ) block_overhead = block_count * 4 diff --git a/providers/logging_utils.py b/providers/logging_utils.py index 5cc5dff..1a891b4 100644 --- a/providers/logging_utils.py +++ b/providers/logging_utils.py @@ -16,27 +16,29 @@ from providers.common.text import extract_text_from_content def generate_request_fingerprint(messages: list[Any]) -> str: """Generate unique short hash for message content. - Creates a SHA256 hash of all message content, returning an 8-char prefix - that's sufficient for correlation without full content logging. + Uses incremental SHA256 hashing to avoid building a large intermediate + string. Returns an 8-char hex prefix sufficient for correlation. """ - content_parts = [] + h = hashlib.sha256() + sep = b"|" for msg in messages: if hasattr(msg, "content"): content = msg.content if isinstance(content, str): - content_parts.append(content) + h.update(content.encode("utf-8")) + h.update(sep) elif isinstance(content, list): for block in content: if hasattr(block, "text"): - content_parts.append(block.text) + h.update(block.text.encode("utf-8")) + h.update(sep) elif hasattr(block, "type"): - content_parts.append(f"<{block.type}>") + h.update(f"<{block.type}>".encode()) + h.update(sep) elif hasattr(msg, "role"): - content_parts.append(msg.role) - - combined = "|".join(content_parts) - hash_digest = hashlib.sha256(combined.encode("utf-8")).hexdigest() - return f"fp_{hash_digest[:8]}" + h.update(msg.role.encode("utf-8")) + h.update(sep) + return f"fp_{h.hexdigest()[:8]}" def get_last_user_message_preview(messages: list[Any], max_len: int = 100) -> str: diff --git a/providers/open_router/__init__.py b/providers/open_router/__init__.py index 2253357..a72244a 100644 --- a/providers/open_router/__init__.py +++ b/providers/open_router/__init__.py @@ -1,5 +1,5 @@ """OpenRouter provider - OpenAI-compatible API for hundreds of models.""" -from .client import OpenRouterProvider +from .client import OPENROUTER_BASE_URL, OpenRouterProvider -__all__ = ["OpenRouterProvider"] +__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"] diff --git a/providers/openai_compat.py b/providers/openai_compat.py index 60e4d0c..ae11c62 100644 --- a/providers/openai_compat.py +++ b/providers/openai_compat.py @@ -10,7 +10,6 @@ import httpx from loguru import logger from openai import AsyncOpenAI -from config.nim import NimSettings from providers.base import BaseProvider, ProviderConfig from providers.common import ( ContentType, @@ -33,7 +32,7 @@ class OpenAICompatibleProvider(BaseProvider): provider_name: str, base_url: str, api_key: str, - nim_settings: NimSettings | None = None, + nim_settings: Any = None, ): super().__init__(config) self._provider_name = provider_name diff --git a/providers/rate_limit.py b/providers/rate_limit.py index ae19f36..8eae7bf 100644 --- a/providers/rate_limit.py +++ b/providers/rate_limit.py @@ -30,13 +30,19 @@ class GlobalRateLimiter: _instance: ClassVar[GlobalRateLimiter | None] = None + def __new__(cls, *args: Any, **kwargs: Any) -> GlobalRateLimiter: + if cls._instance is not None: + return cls._instance + instance = super().__new__(cls) + return instance + def __init__( self, rate_limit: int = 40, rate_window: float = 60.0, max_concurrency: int = 5, ): - # Prevent double initialization in singleton + # Prevent re-initialization on singleton reuse if hasattr(self, "_initialized"): return diff --git a/tests/messaging/test_handler.py b/tests/messaging/test_handler.py index 92fb1e2..b67bc08 100644 --- a/tests/messaging/test_handler.py +++ b/tests/messaging/test_handler.py @@ -247,7 +247,7 @@ async def test_update_queue_positions(handler, mock_platform): await tree.enqueue("child_1") await tree.enqueue("child_2") - await handler._update_queue_positions(tree) + await handler.update_queue_positions(tree) calls = mock_platform.queue_edit_message.call_args_list assert len(calls) == 2 @@ -291,7 +291,7 @@ async def test_mark_node_processing(handler, mock_platform): parent_id="root", ) - await handler._mark_node_processing(tree, "child") + await handler.mark_node_processing(tree, "child") mock_platform.queue_edit_message.assert_called_once() args, kwargs = mock_platform.queue_edit_message.call_args diff --git a/tests/messaging/test_handler_markdown_and_status_edges.py b/tests/messaging/test_handler_markdown_and_status_edges.py index fd58ec1..efd6eb0 100644 --- a/tests/messaging/test_handler_markdown_and_status_edges.py +++ b/tests/messaging/test_handler_markdown_and_status_edges.py @@ -115,7 +115,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi # Snapshot error is swallowed. tree = MagicMock() tree.get_queue_snapshot = AsyncMock(side_effect=RuntimeError("boom")) - await handler._update_queue_positions(tree) + await handler.update_queue_positions(tree) platform.fire_and_forget.assert_not_called() # Normal path: only PENDING nodes get an update. @@ -130,7 +130,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi tree.get_queue_snapshot = AsyncMock(return_value=["n1", "n2"]) tree.get_node = MagicMock(side_effect=[node_pending, node_done]) - await handler._update_queue_positions(tree) + await handler.update_queue_positions(tree) assert platform.fire_and_forget.call_count == 1 diff --git a/tests/messaging/test_restart_reply_restore.py b/tests/messaging/test_restart_reply_restore.py index f52183a..5e1bb11 100644 --- a/tests/messaging/test_restart_reply_restore.py +++ b/tests/messaging/test_restart_reply_restore.py @@ -37,8 +37,8 @@ async def test_reply_to_old_status_message_after_restore_routes_to_parent( handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2) handler2.tree_queue = TreeQueueManager.from_dict( {"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()}, - queue_update_callback=handler2._update_queue_positions, - node_started_callback=handler2._mark_node_processing, + queue_update_callback=handler2.update_queue_positions, + node_started_callback=handler2.mark_node_processing, ) # Prevent background task scheduling; we only want to validate routing/tree mutation. @@ -89,8 +89,8 @@ async def test_reply_to_old_status_message_without_mapping_creates_new_conversat handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2) handler2.tree_queue = TreeQueueManager.from_dict( {"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()}, - queue_update_callback=handler2._update_queue_positions, - node_started_callback=handler2._mark_node_processing, + queue_update_callback=handler2.update_queue_positions, + node_started_callback=handler2.mark_node_processing, ) mock_platform.queue_send_message = AsyncMock(return_value="status_reply")