Add code review fix plan covering 11 issues across modularity, encapsulation, performance, and dead code (#62)

This commit is contained in:
Ali Khokhar 2026-03-01 00:45:33 -08:00 committed by GitHub
parent c54c57a742
commit aee9f0ad93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 107 additions and 114 deletions

View file

@ -114,8 +114,8 @@ async def lifespan(app: FastAPI):
"trees": saved_trees, "trees": saved_trees,
"node_to_tree": session_store.get_node_mapping(), "node_to_tree": session_store.get_node_mapping(),
}, },
queue_update_callback=message_handler._update_queue_positions, queue_update_callback=message_handler.update_queue_positions,
node_started_callback=message_handler._mark_node_processing, node_started_callback=message_handler.mark_node_processing,
) )
# Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart # Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart
if message_handler.tree_queue.cleanup_stale_nodes() > 0: if message_handler.tree_queue.cleanup_stale_nodes() > 0:

View file

@ -9,8 +9,7 @@ from providers.base import BaseProvider, ProviderConfig
from providers.exceptions import AuthenticationError from providers.exceptions import AuthenticationError
from providers.lmstudio import LMStudioProvider from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import OpenRouterProvider from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
from providers.open_router.client import OPENROUTER_BASE_URL
# Global provider instance (singleton) # Global provider instance (singleton)
_provider: BaseProvider | None = None _provider: BaseProvider | None = None

View file

@ -28,13 +28,13 @@ async def handle_stop_command(
if not node_id: if not node_id:
msg_id = await handler.platform.queue_send_message( msg_id = await handler.platform.queue_send_message(
incoming.chat_id, incoming.chat_id,
handler._format_status( handler.format_status(
"", "Stopped.", "Nothing to stop for that message." "", "Stopped.", "Nothing to stop for that message."
), ),
fire_and_forget=False, fire_and_forget=False,
message_thread_id=incoming.message_thread_id, message_thread_id=incoming.message_thread_id,
) )
handler._record_outgoing_message( handler.record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command" incoming.platform, incoming.chat_id, msg_id, "command"
) )
return return
@ -43,11 +43,11 @@ async def handle_stop_command(
noun = "request" if count == 1 else "requests" noun = "request" if count == 1 else "requests"
msg_id = await handler.platform.queue_send_message( msg_id = await handler.platform.queue_send_message(
incoming.chat_id, incoming.chat_id,
handler._format_status("", "Stopped.", f"Cancelled {count} {noun}."), handler.format_status("", "Stopped.", f"Cancelled {count} {noun}."),
fire_and_forget=False, fire_and_forget=False,
message_thread_id=incoming.message_thread_id, message_thread_id=incoming.message_thread_id,
) )
handler._record_outgoing_message( handler.record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command" incoming.platform, incoming.chat_id, msg_id, "command"
) )
return return
@ -56,13 +56,13 @@ async def handle_stop_command(
count = await handler.stop_all_tasks() count = await handler.stop_all_tasks()
msg_id = await handler.platform.queue_send_message( msg_id = await handler.platform.queue_send_message(
incoming.chat_id, incoming.chat_id,
handler._format_status( handler.format_status(
"", "Stopped.", f"Cancelled {count} pending or active requests." "", "Stopped.", f"Cancelled {count} pending or active requests."
), ),
fire_and_forget=False, fire_and_forget=False,
message_thread_id=incoming.message_thread_id, message_thread_id=incoming.message_thread_id,
) )
handler._record_outgoing_message( handler.record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command" incoming.platform, incoming.chat_id, msg_id, "command"
) )
@ -73,7 +73,7 @@ async def handle_stats_command(
"""Handle /stats command.""" """Handle /stats command."""
stats = handler.cli_manager.get_stats() stats = handler.cli_manager.get_stats()
tree_count = handler.tree_queue.get_tree_count() tree_count = handler.tree_queue.get_tree_count()
ctx = handler._get_render_ctx() ctx = handler.get_render_ctx()
msg_id = await handler.platform.queue_send_message( msg_id = await handler.platform.queue_send_message(
incoming.chat_id, incoming.chat_id,
"📊 " "📊 "
@ -85,7 +85,7 @@ async def handle_stats_command(
fire_and_forget=False, fire_and_forget=False,
message_thread_id=incoming.message_thread_id, message_thread_id=incoming.message_thread_id,
) )
handler._record_outgoing_message( handler.record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command" incoming.platform, incoming.chat_id, msg_id, "command"
) )
@ -149,7 +149,7 @@ async def _handle_clear_branch(
# 1) Cancel branch tasks (no stop_all) # 1) Cancel branch tasks (no stop_all)
cancelled = await handler.tree_queue.cancel_branch(branch_root_id) cancelled = await handler.tree_queue.cancel_branch(branch_root_id)
handler._update_cancelled_nodes_ui(cancelled) handler.update_cancelled_nodes_ui(cancelled)
# 2) Collect message IDs from branch nodes only # 2) Collect message IDs from branch nodes only
msg_ids: set[str] = set() msg_ids: set[str] = set()
@ -214,25 +214,23 @@ async def handle_clear_command(
await _delete_message_ids(handler, incoming.chat_id, msg_ids_to_del) await _delete_message_ids(handler, incoming.chat_id, msg_ids_to_del)
msg_id = await handler.platform.queue_send_message( msg_id = await handler.platform.queue_send_message(
incoming.chat_id, incoming.chat_id,
handler._format_status( handler.format_status("🗑", "Cleared.", "Voice note cancelled."),
"🗑", "Cleared.", "Voice note cancelled."
),
fire_and_forget=False, fire_and_forget=False,
message_thread_id=incoming.message_thread_id, message_thread_id=incoming.message_thread_id,
) )
handler._record_outgoing_message( handler.record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command" incoming.platform, incoming.chat_id, msg_id, "command"
) )
return return
msg_id = await handler.platform.queue_send_message( msg_id = await handler.platform.queue_send_message(
incoming.chat_id, incoming.chat_id,
handler._format_status( handler.format_status(
"🗑", "Cleared.", "Nothing to clear for that message." "🗑", "Cleared.", "Nothing to clear for that message."
), ),
fire_and_forget=False, fire_and_forget=False,
message_thread_id=incoming.message_thread_id, message_thread_id=incoming.message_thread_id,
) )
handler._record_outgoing_message( handler.record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command" incoming.platform, incoming.chat_id, msg_id, "command"
) )
return return
@ -278,6 +276,6 @@ async def handle_clear_command(
logger.warning(f"Failed to clear session store: {e}") logger.warning(f"Failed to clear session store: {e}")
handler.tree_queue = TreeQueueManager( handler.tree_queue = TreeQueueManager(
queue_update_callback=handler._update_queue_positions, queue_update_callback=handler.update_queue_positions,
node_started_callback=handler._mark_node_processing, node_started_callback=handler.mark_node_processing,
) )

View file

@ -119,8 +119,8 @@ class ClaudeMessageHandler:
self.cli_manager = cli_manager self.cli_manager = cli_manager
self.session_store = session_store self.session_store = session_store
self.tree_queue = TreeQueueManager( self.tree_queue = TreeQueueManager(
queue_update_callback=self._update_queue_positions, queue_update_callback=self.update_queue_positions,
node_started_callback=self._mark_node_processing, node_started_callback=self.mark_node_processing,
) )
is_discord = platform.name == "discord" is_discord = platform.name == "discord"
self._format_status_fn = ( self._format_status_fn = (
@ -138,13 +138,13 @@ class ClaudeMessageHandler:
) )
self._limit_chars = 1900 if is_discord else 3900 self._limit_chars = 1900 if is_discord else 3900
def _format_status(self, emoji: str, label: str, suffix: str | None = None) -> str: def format_status(self, emoji: str, label: str, suffix: str | None = None) -> str:
return self._format_status_fn(emoji, label, suffix) return self._format_status_fn(emoji, label, suffix)
def _parse_mode(self) -> str | None: def _parse_mode(self) -> str | None:
return self._parse_mode_val return self._parse_mode_val
def _get_render_ctx(self) -> RenderCtx: def get_render_ctx(self) -> RenderCtx:
return self._render_ctx_val return self._render_ctx_val
def _get_limit_chars(self) -> int: def _get_limit_chars(self) -> int:
@ -253,7 +253,7 @@ class ClaudeMessageHandler:
fire_and_forget=False, fire_and_forget=False,
message_thread_id=incoming.message_thread_id, message_thread_id=incoming.message_thread_id,
) )
self._record_outgoing_message( self.record_outgoing_message(
incoming.platform, incoming.chat_id, status_msg_id, "status" incoming.platform, incoming.chat_id, status_msg_id, "status"
) )
@ -298,13 +298,13 @@ class ClaudeMessageHandler:
await self.platform.queue_edit_message( await self.platform.queue_edit_message(
incoming.chat_id, incoming.chat_id,
status_msg_id, status_msg_id,
self._format_status( self.format_status(
"📋", "Queued", f"(position {queue_size}) - waiting..." "📋", "Queued", f"(position {queue_size}) - waiting..."
), ),
parse_mode=self._parse_mode(), parse_mode=self._parse_mode(),
) )
async def _update_queue_positions(self, tree: MessageTree) -> None: async def update_queue_positions(self, tree: MessageTree) -> None:
"""Refresh queued status messages after a dequeue.""" """Refresh queued status messages after a dequeue."""
try: try:
queued_ids = await tree.get_queue_snapshot() queued_ids = await tree.get_queue_snapshot()
@ -325,14 +325,14 @@ class ClaudeMessageHandler:
self.platform.queue_edit_message( self.platform.queue_edit_message(
node.incoming.chat_id, node.incoming.chat_id,
node.status_message_id, node.status_message_id,
self._format_status( self.format_status(
"📋", "Queued", f"(position {position}) - waiting..." "📋", "Queued", f"(position {position}) - waiting..."
), ),
parse_mode=self._parse_mode(), parse_mode=self._parse_mode(),
) )
) )
async def _mark_node_processing(self, tree: MessageTree, node_id: str) -> None: async def mark_node_processing(self, tree: MessageTree, node_id: str) -> None:
"""Update the dequeued node's status to processing immediately.""" """Update the dequeued node's status to processing immediately."""
node = tree.get_node(node_id) node = tree.get_node(node_id)
if not node or node.state == MessageState.ERROR: if not node or node.state == MessageState.ERROR:
@ -341,7 +341,7 @@ class ClaudeMessageHandler:
self.platform.queue_edit_message( self.platform.queue_edit_message(
node.incoming.chat_id, node.incoming.chat_id,
node.status_message_id, node.status_message_id,
self._format_status("🔄", "Processing..."), self.format_status("🔄", "Processing..."),
parse_mode=self._parse_mode(), parse_mode=self._parse_mode(),
) )
) )
@ -351,7 +351,7 @@ class ClaudeMessageHandler:
) -> tuple[TranscriptBuffer, RenderCtx]: ) -> tuple[TranscriptBuffer, RenderCtx]:
"""Create transcript buffer and render context for node processing.""" """Create transcript buffer and render context for node processing."""
transcript = TranscriptBuffer(show_tool_results=False) transcript = TranscriptBuffer(show_tool_results=False)
return transcript, self._get_render_ctx() return transcript, self.get_render_ctx()
async def _handle_session_info_event( async def _handle_session_info_event(
self, self,
@ -400,7 +400,7 @@ class ClaudeMessageHandler:
transcript.apply(parsed) transcript.apply(parsed)
had_transcript_events = True had_transcript_events = True
status = _get_status_for_event(ptype, parsed, self._format_status) status = _get_status_for_event(ptype, parsed, self.format_status)
if status is not None: if status is not None:
await update_ui(status) await update_ui(status)
last_status = status last_status = status
@ -410,7 +410,7 @@ class ClaudeMessageHandler:
if not had_transcript_events: if not had_transcript_events:
transcript.apply({"type": "text_chunk", "text": "Done."}) transcript.apply({"type": "text_chunk", "text": "Done."})
logger.info("HANDLER: Task complete, updating UI") logger.info("HANDLER: Task complete, updating UI")
await update_ui(self._format_status("", "Complete"), force=True) await update_ui(self.format_status("", "Complete"), force=True)
if tree and captured_session_id: if tree and captured_session_id:
await tree.update_state( await tree.update_state(
node_id, node_id,
@ -422,7 +422,7 @@ class ClaudeMessageHandler:
error_msg = parsed.get("message", "Unknown error") error_msg = parsed.get("message", "Unknown error")
logger.error(f"HANDLER: Error event received: {error_msg}") logger.error(f"HANDLER: Error event received: {error_msg}")
logger.info("HANDLER: Updating UI with error status") logger.info("HANDLER: Updating UI with error status")
await update_ui(self._format_status("", "Error"), force=True) await update_ui(self.format_status("", "Error"), force=True)
if tree: if tree:
await self._propagate_error_to_children( await self._propagate_error_to_children(
node_id, error_msg, "Parent task failed" node_id, error_msg, "Parent task failed"
@ -533,7 +533,7 @@ class ClaudeMessageHandler:
except RuntimeError as e: except RuntimeError as e:
transcript.apply({"type": "error", "message": str(e)}) transcript.apply({"type": "error", "message": str(e)})
await update_ui( await update_ui(
self._format_status("", "Session limit reached"), self.format_status("", "Session limit reached"),
force=True, force=True,
) )
if tree: if tree:
@ -592,10 +592,10 @@ class ClaudeMessageHandler:
cancel_reason = node.context.get("cancel_reason") cancel_reason = node.context.get("cancel_reason")
if cancel_reason == "stop": if cancel_reason == "stop":
await update_ui(self._format_status("", "Stopped."), force=True) await update_ui(self.format_status("", "Stopped."), force=True)
else: else:
transcript.apply({"type": "error", "message": "Task was cancelled"}) transcript.apply({"type": "error", "message": "Task was cancelled"})
await update_ui(self._format_status("", "Cancelled"), force=True) await update_ui(self.format_status("", "Cancelled"), force=True)
# Do not propagate cancellation to children; a reply-scoped "/stop" # Do not propagate cancellation to children; a reply-scoped "/stop"
# should only stop the targeted task. # should only stop the targeted task.
@ -609,7 +609,7 @@ class ClaudeMessageHandler:
) )
error_msg = str(e)[:200] error_msg = str(e)[:200]
transcript.apply({"type": "error", "message": error_msg}) transcript.apply({"type": "error", "message": error_msg})
await update_ui(self._format_status("💥", "Task Failed"), force=True) await update_ui(self.format_status("💥", "Task Failed"), force=True)
if tree: if tree:
await self._propagate_error_to_children( await self._propagate_error_to_children(
node_id, error_msg, "Parent task failed" node_id, error_msg, "Parent task failed"
@ -643,7 +643,7 @@ class ClaudeMessageHandler:
self.platform.queue_edit_message( self.platform.queue_edit_message(
child.incoming.chat_id, child.incoming.chat_id,
child.status_message_id, child.status_message_id,
self._format_status("", "Cancelled:", child_status_text), self.format_status("", "Cancelled:", child_status_text),
parse_mode=self._parse_mode(), parse_mode=self._parse_mode(),
) )
) )
@ -658,13 +658,13 @@ class ClaudeMessageHandler:
# Reply to existing tree # Reply to existing tree
if self.tree_queue.is_node_tree_busy(parent_node_id): if self.tree_queue.is_node_tree_busy(parent_node_id):
queue_size = self.tree_queue.get_queue_size(parent_node_id) + 1 queue_size = self.tree_queue.get_queue_size(parent_node_id) + 1
return self._format_status( return self.format_status(
"📋", "Queued", f"(position {queue_size}) - waiting..." "📋", "Queued", f"(position {queue_size}) - waiting..."
) )
return self._format_status("🔄", "Continuing conversation...") return self.format_status("🔄", "Continuing conversation...")
# New conversation # New conversation
return self._format_status("", "Launching new Claude CLI instance...") return self.format_status("", "Launching new Claude CLI instance...")
async def stop_all_tasks(self) -> int: async def stop_all_tasks(self) -> int:
""" """
@ -685,7 +685,7 @@ class ClaudeMessageHandler:
await self.cli_manager.stop_all() await self.cli_manager.stop_all()
# 3. Update UI and persist state for all cancelled nodes # 3. Update UI and persist state for all cancelled nodes
self._update_cancelled_nodes_ui(cancelled_nodes) self.update_cancelled_nodes_ui(cancelled_nodes)
return len(cancelled_nodes) return len(cancelled_nodes)
@ -703,10 +703,10 @@ class ClaudeMessageHandler:
node.set_context({"cancel_reason": "stop"}) node.set_context({"cancel_reason": "stop"})
cancelled_nodes = await self.tree_queue.cancel_node(node_id) cancelled_nodes = await self.tree_queue.cancel_node(node_id)
self._update_cancelled_nodes_ui(cancelled_nodes) self.update_cancelled_nodes_ui(cancelled_nodes)
return len(cancelled_nodes) return len(cancelled_nodes)
def _record_outgoing_message( def record_outgoing_message(
self, self,
platform: str, platform: str,
chat_id: str, chat_id: str,
@ -723,7 +723,7 @@ class ClaudeMessageHandler:
except Exception as e: except Exception as e:
logger.debug(f"Failed to record message_id: {e}") logger.debug(f"Failed to record message_id: {e}")
def _update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None: def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
"""Update status messages and persist tree state for cancelled nodes.""" """Update status messages and persist tree state for cancelled nodes."""
trees_to_save: dict[str, MessageTree] = {} trees_to_save: dict[str, MessageTree] = {}
for node in nodes: for node in nodes:
@ -731,7 +731,7 @@ class ClaudeMessageHandler:
self.platform.queue_edit_message( self.platform.queue_edit_message(
node.incoming.chat_id, node.incoming.chat_id,
node.status_message_id, node.status_message_id,
self._format_status("", "Stopped."), self.format_status("", "Stopped."),
parse_mode=self._parse_mode(), parse_mode=self._parse_mode(),
) )
) )

View file

@ -10,6 +10,7 @@ from __future__ import annotations
import json import json
import os import os
from abc import ABC, abstractmethod
from collections import deque from collections import deque
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -26,11 +27,11 @@ def _safe_json_dumps(obj: Any) -> str:
@dataclass @dataclass
class Segment: class Segment(ABC):
kind: str kind: str
def render(self, ctx: RenderCtx) -> str: @abstractmethod
raise NotImplementedError def render(self, ctx: RenderCtx) -> str: ...
@dataclass @dataclass
@ -88,23 +89,6 @@ class ToolCallSegment(Segment):
self.tool_use_id = str(tool_use_id or "") self.tool_use_id = str(tool_use_id or "")
self.name = str(name or "tool") self.name = str(name or "tool")
self.indent_level = max(0, int(indent_level)) self.indent_level = max(0, int(indent_level))
self._parts: list[str] = []
def set_initial_input(self, inp: Any) -> None:
if inp is None:
return
if isinstance(inp, str):
self._parts = [inp]
else:
self._parts = [_safe_json_dumps(inp)]
def append_input_delta(self, partial: str) -> None:
if partial:
self._parts.append(partial)
@property
def input_text(self) -> str:
return "".join(self._parts)
def render(self, ctx: RenderCtx) -> str: def render(self, ctx: RenderCtx) -> str:
name = ctx.code_inline(self.name) name = ctx.code_inline(self.name)
@ -444,17 +428,12 @@ class TranscriptBuffer:
seg = ToolCallSegment(tool_id, name) seg = ToolCallSegment(tool_id, name)
self._segments.append(seg) self._segments.append(seg)
seg.set_initial_input(ev.get("input"))
if idx >= 0: if idx >= 0:
self._open_tools_by_index[idx] = seg self._open_tools_by_index[idx] = seg
return return
if et == "tool_use_delta": if et == "tool_use_delta":
idx = int(ev.get("index", -1)) # Track open tool by index for tool_use_stop (closing state).
partial = str(ev.get("partial_json", "") or "")
seg = self._open_tools_by_index.get(idx)
if seg is not None:
seg.append_input_delta(partial)
return return
if et == "tool_use_stop": if et == "tool_use_stop":
@ -501,7 +480,6 @@ class TranscriptBuffer:
seg = ToolCallSegment(tool_id, name) seg = ToolCallSegment(tool_id, name)
self._segments.append(seg) self._segments.append(seg)
seg.set_initial_input(ev.get("input"))
seg.closed = True seg.closed = True
return return

View file

@ -17,21 +17,26 @@ from ..models import IncomingMessage
class _SnapshotQueue: class _SnapshotQueue:
"""Queue with snapshot/remove helpers, backed by a deque.""" """Queue with snapshot/remove helpers, backed by a deque and a set index."""
def __init__(self) -> None: def __init__(self) -> None:
self._deque: deque[str] = deque() self._deque: deque[str] = deque()
self._set: set[str] = set()
async def put(self, item: str) -> None: async def put(self, item: str) -> None:
self._deque.append(item) self._deque.append(item)
self._set.add(item)
def put_nowait(self, item: str) -> None: def put_nowait(self, item: str) -> None:
self._deque.append(item) self._deque.append(item)
self._set.add(item)
def get_nowait(self) -> str: def get_nowait(self) -> str:
if not self._deque: if not self._deque:
raise asyncio.QueueEmpty() raise asyncio.QueueEmpty()
return self._deque.popleft() item = self._deque.popleft()
self._set.discard(item)
return item
def qsize(self) -> int: def qsize(self) -> int:
return len(self._deque) return len(self._deque)
@ -41,12 +46,11 @@ class _SnapshotQueue:
return list(self._deque) return list(self._deque)
def remove_if_present(self, item: str) -> bool: def remove_if_present(self, item: str) -> bool:
"""Remove item from queue if present. Returns True if removed.""" """Remove item from queue if present (O(1) membership check). Returns True if removed."""
if item not in self._deque: if item not in self._set:
return False return False
items = [x for x in self._deque if x != item] self._set.discard(item)
self._deque.clear() self._deque = deque(x for x in self._deque if x != item)
self._deque.extend(items)
return True return True
@ -350,7 +354,7 @@ class MessageTree:
return True return True
return False return False
def _set_node_error_sync(self, node: MessageNode, error_message: str) -> None: def set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
"""Synchronously mark a node as ERROR. Caller must ensure no concurrent access.""" """Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
node.state = MessageState.ERROR node.state = MessageState.ERROR
node.error_message = error_message node.error_message = error_message
@ -371,7 +375,7 @@ class MessageTree:
break break
node = self._nodes.get(node_id) node = self._nodes.get(node_id)
if node: if node:
self._set_node_error_sync(node, error_message) self.set_node_error_sync(node, error_message)
nodes.append(node) nodes.append(node)
return nodes return nodes

View file

@ -67,7 +67,7 @@ class TreeQueueProcessor:
) -> None: ) -> None:
"""Process a single node and then check the queue.""" """Process a single node and then check the queue."""
# Skip if already in terminal state (e.g. from error propagation) # Skip if already in terminal state (e.g. from error propagation)
if node.state.value == MessageState.ERROR.value: if node.state == MessageState.ERROR:
logger.info( logger.info(
f"Skipping node {node.node_id} as it is already in state {node.state}" f"Skipping node {node.node_id} as it is already in state {node.state}"
) )

View file

@ -244,7 +244,7 @@ class TreeQueueManager:
MessageState.COMPLETED, MessageState.COMPLETED,
MessageState.ERROR, MessageState.ERROR,
): ):
tree._set_node_error_sync(node, "Cancelled by user") tree.set_node_error_sync(node, "Cancelled by user")
cancelled_nodes.append(node) cancelled_nodes.append(node)
# 2. Drain queue and mark nodes as cancelled # 2. Drain queue and mark nodes as cancelled
@ -259,7 +259,7 @@ class TreeQueueManager:
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS) node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
and node.node_id not in cancelled_ids and node.node_id not in cancelled_ids
): ):
tree._set_node_error_sync(node, "Stale task cleaned up") tree.set_node_error_sync(node, "Stale task cleaned up")
cleanup_count += 1 cleanup_count += 1
tree.reset_processing_state() tree.reset_processing_state()
@ -336,7 +336,7 @@ class TreeQueueManager:
for tree in self._repository.all_trees(): for tree in self._repository.all_trees():
for node in tree.all_nodes(): for node in tree.all_nodes():
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS): if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
tree._set_node_error_sync(node, "Lost during server restart") tree.set_node_error_sync(node, "Lost during server restart")
count += 1 count += 1
if count: if count:
logger.info(f"Cleaned up {count} stale nodes during startup") logger.info(f"Cleaned up {count} stale nodes during startup")

View file

@ -150,7 +150,11 @@ class TreeRepository:
return tree return tree
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]: def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
"""Get all message IDs (incoming + status) for a given platform/chat.""" """Get all message IDs (incoming + status) for a given platform/chat.
Note: O(total_nodes) scan. Acceptable because this is only called
from /clear (user-initiated, infrequent).
"""
msg_ids: set[str] = set() msg_ids: set[str] = set()
for tree in self._trees.values(): for tree in self._trees.values():
for node in tree.all_nodes(): for node in tree.all_nodes():

View file

@ -366,16 +366,19 @@ class SSEBuilder:
# Tool calls are harder to tokenize exactly without reconstruction, but we can approximate # Tool calls are harder to tokenize exactly without reconstruction, but we can approximate
# by tokenizing the json dumps of tool contents # by tokenizing the json dumps of tool contents
tool_tokens = 0 tool_tokens = 0
started_tool_count = 0
for state in self.blocks.tool_states.values(): for state in self.blocks.tool_states.values():
tool_tokens += len(ENCODER.encode(state.name)) tool_tokens += len(ENCODER.encode(state.name))
tool_tokens += len(ENCODER.encode("".join(state.contents))) tool_tokens += len(ENCODER.encode("".join(state.contents)))
tool_tokens += 15 # Control tokens overhead per tool tool_tokens += 15 # Control tokens overhead per tool
if state.started:
started_tool_count += 1
# Per-block overhead (~4 tokens per content block) # Per-block overhead (~4 tokens per content block)
block_count = ( block_count = (
(1 if accumulated_reasoning else 0) (1 if accumulated_reasoning else 0)
+ (1 if accumulated_text else 0) + (1 if accumulated_text else 0)
+ sum(1 for s in self.blocks.tool_states.values() if s.started) + started_tool_count
) )
block_overhead = block_count * 4 block_overhead = block_count * 4

View file

@ -16,27 +16,29 @@ from providers.common.text import extract_text_from_content
def generate_request_fingerprint(messages: list[Any]) -> str: def generate_request_fingerprint(messages: list[Any]) -> str:
"""Generate unique short hash for message content. """Generate unique short hash for message content.
Creates a SHA256 hash of all message content, returning an 8-char prefix Uses incremental SHA256 hashing to avoid building a large intermediate
that's sufficient for correlation without full content logging. string. Returns an 8-char hex prefix sufficient for correlation.
""" """
content_parts = [] h = hashlib.sha256()
sep = b"|"
for msg in messages: for msg in messages:
if hasattr(msg, "content"): if hasattr(msg, "content"):
content = msg.content content = msg.content
if isinstance(content, str): if isinstance(content, str):
content_parts.append(content) h.update(content.encode("utf-8"))
h.update(sep)
elif isinstance(content, list): elif isinstance(content, list):
for block in content: for block in content:
if hasattr(block, "text"): if hasattr(block, "text"):
content_parts.append(block.text) h.update(block.text.encode("utf-8"))
h.update(sep)
elif hasattr(block, "type"): elif hasattr(block, "type"):
content_parts.append(f"<{block.type}>") h.update(f"<{block.type}>".encode())
h.update(sep)
elif hasattr(msg, "role"): elif hasattr(msg, "role"):
content_parts.append(msg.role) h.update(msg.role.encode("utf-8"))
h.update(sep)
combined = "|".join(content_parts) return f"fp_{h.hexdigest()[:8]}"
hash_digest = hashlib.sha256(combined.encode("utf-8")).hexdigest()
return f"fp_{hash_digest[:8]}"
def get_last_user_message_preview(messages: list[Any], max_len: int = 100) -> str: def get_last_user_message_preview(messages: list[Any], max_len: int = 100) -> str:

View file

@ -1,5 +1,5 @@
"""OpenRouter provider - OpenAI-compatible API for hundreds of models.""" """OpenRouter provider - OpenAI-compatible API for hundreds of models."""
from .client import OpenRouterProvider from .client import OPENROUTER_BASE_URL, OpenRouterProvider
__all__ = ["OpenRouterProvider"] __all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"]

View file

@ -10,7 +10,6 @@ import httpx
from loguru import logger from loguru import logger
from openai import AsyncOpenAI from openai import AsyncOpenAI
from config.nim import NimSettings
from providers.base import BaseProvider, ProviderConfig from providers.base import BaseProvider, ProviderConfig
from providers.common import ( from providers.common import (
ContentType, ContentType,
@ -33,7 +32,7 @@ class OpenAICompatibleProvider(BaseProvider):
provider_name: str, provider_name: str,
base_url: str, base_url: str,
api_key: str, api_key: str,
nim_settings: NimSettings | None = None, nim_settings: Any = None,
): ):
super().__init__(config) super().__init__(config)
self._provider_name = provider_name self._provider_name = provider_name

View file

@ -30,13 +30,19 @@ class GlobalRateLimiter:
_instance: ClassVar[GlobalRateLimiter | None] = None _instance: ClassVar[GlobalRateLimiter | None] = None
def __new__(cls, *args: Any, **kwargs: Any) -> GlobalRateLimiter:
if cls._instance is not None:
return cls._instance
instance = super().__new__(cls)
return instance
def __init__( def __init__(
self, self,
rate_limit: int = 40, rate_limit: int = 40,
rate_window: float = 60.0, rate_window: float = 60.0,
max_concurrency: int = 5, max_concurrency: int = 5,
): ):
# Prevent double initialization in singleton # Prevent re-initialization on singleton reuse
if hasattr(self, "_initialized"): if hasattr(self, "_initialized"):
return return

View file

@ -247,7 +247,7 @@ async def test_update_queue_positions(handler, mock_platform):
await tree.enqueue("child_1") await tree.enqueue("child_1")
await tree.enqueue("child_2") await tree.enqueue("child_2")
await handler._update_queue_positions(tree) await handler.update_queue_positions(tree)
calls = mock_platform.queue_edit_message.call_args_list calls = mock_platform.queue_edit_message.call_args_list
assert len(calls) == 2 assert len(calls) == 2
@ -291,7 +291,7 @@ async def test_mark_node_processing(handler, mock_platform):
parent_id="root", parent_id="root",
) )
await handler._mark_node_processing(tree, "child") await handler.mark_node_processing(tree, "child")
mock_platform.queue_edit_message.assert_called_once() mock_platform.queue_edit_message.assert_called_once()
args, kwargs = mock_platform.queue_edit_message.call_args args, kwargs = mock_platform.queue_edit_message.call_args

View file

@ -115,7 +115,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi
# Snapshot error is swallowed. # Snapshot error is swallowed.
tree = MagicMock() tree = MagicMock()
tree.get_queue_snapshot = AsyncMock(side_effect=RuntimeError("boom")) tree.get_queue_snapshot = AsyncMock(side_effect=RuntimeError("boom"))
await handler._update_queue_positions(tree) await handler.update_queue_positions(tree)
platform.fire_and_forget.assert_not_called() platform.fire_and_forget.assert_not_called()
# Normal path: only PENDING nodes get an update. # Normal path: only PENDING nodes get an update.
@ -130,7 +130,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi
tree.get_queue_snapshot = AsyncMock(return_value=["n1", "n2"]) tree.get_queue_snapshot = AsyncMock(return_value=["n1", "n2"])
tree.get_node = MagicMock(side_effect=[node_pending, node_done]) tree.get_node = MagicMock(side_effect=[node_pending, node_done])
await handler._update_queue_positions(tree) await handler.update_queue_positions(tree)
assert platform.fire_and_forget.call_count == 1 assert platform.fire_and_forget.call_count == 1

View file

@ -37,8 +37,8 @@ async def test_reply_to_old_status_message_after_restore_routes_to_parent(
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2) handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
handler2.tree_queue = TreeQueueManager.from_dict( handler2.tree_queue = TreeQueueManager.from_dict(
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()}, {"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
queue_update_callback=handler2._update_queue_positions, queue_update_callback=handler2.update_queue_positions,
node_started_callback=handler2._mark_node_processing, node_started_callback=handler2.mark_node_processing,
) )
# Prevent background task scheduling; we only want to validate routing/tree mutation. # Prevent background task scheduling; we only want to validate routing/tree mutation.
@ -89,8 +89,8 @@ async def test_reply_to_old_status_message_without_mapping_creates_new_conversat
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2) handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
handler2.tree_queue = TreeQueueManager.from_dict( handler2.tree_queue = TreeQueueManager.from_dict(
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()}, {"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
queue_update_callback=handler2._update_queue_positions, queue_update_callback=handler2.update_queue_positions,
node_started_callback=handler2._mark_node_processing, node_started_callback=handler2.mark_node_processing,
) )
mock_platform.queue_send_message = AsyncMock(return_value="status_reply") mock_platform.queue_send_message = AsyncMock(return_value="status_reply")