mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Add code review fix plan covering 11 issues across modularity, encapsulation, performance, and dead code (#62)
This commit is contained in:
parent
c54c57a742
commit
aee9f0ad93
17 changed files with 107 additions and 114 deletions
|
|
@ -114,8 +114,8 @@ async def lifespan(app: FastAPI):
|
|||
"trees": saved_trees,
|
||||
"node_to_tree": session_store.get_node_mapping(),
|
||||
},
|
||||
queue_update_callback=message_handler._update_queue_positions,
|
||||
node_started_callback=message_handler._mark_node_processing,
|
||||
queue_update_callback=message_handler.update_queue_positions,
|
||||
node_started_callback=message_handler.mark_node_processing,
|
||||
)
|
||||
# Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart
|
||||
if message_handler.tree_queue.cleanup_stale_nodes() > 0:
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ from providers.base import BaseProvider, ProviderConfig
|
|||
from providers.exceptions import AuthenticationError
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
||||
from providers.open_router import OpenRouterProvider
|
||||
from providers.open_router.client import OPENROUTER_BASE_URL
|
||||
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
|
||||
|
||||
# Global provider instance (singleton)
|
||||
_provider: BaseProvider | None = None
|
||||
|
|
|
|||
|
|
@ -28,13 +28,13 @@ async def handle_stop_command(
|
|||
if not node_id:
|
||||
msg_id = await handler.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
handler._format_status(
|
||||
handler.format_status(
|
||||
"⏹", "Stopped.", "Nothing to stop for that message."
|
||||
),
|
||||
fire_and_forget=False,
|
||||
message_thread_id=incoming.message_thread_id,
|
||||
)
|
||||
handler._record_outgoing_message(
|
||||
handler.record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
return
|
||||
|
|
@ -43,11 +43,11 @@ async def handle_stop_command(
|
|||
noun = "request" if count == 1 else "requests"
|
||||
msg_id = await handler.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
handler._format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."),
|
||||
handler.format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."),
|
||||
fire_and_forget=False,
|
||||
message_thread_id=incoming.message_thread_id,
|
||||
)
|
||||
handler._record_outgoing_message(
|
||||
handler.record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
return
|
||||
|
|
@ -56,13 +56,13 @@ async def handle_stop_command(
|
|||
count = await handler.stop_all_tasks()
|
||||
msg_id = await handler.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
handler._format_status(
|
||||
handler.format_status(
|
||||
"⏹", "Stopped.", f"Cancelled {count} pending or active requests."
|
||||
),
|
||||
fire_and_forget=False,
|
||||
message_thread_id=incoming.message_thread_id,
|
||||
)
|
||||
handler._record_outgoing_message(
|
||||
handler.record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ async def handle_stats_command(
|
|||
"""Handle /stats command."""
|
||||
stats = handler.cli_manager.get_stats()
|
||||
tree_count = handler.tree_queue.get_tree_count()
|
||||
ctx = handler._get_render_ctx()
|
||||
ctx = handler.get_render_ctx()
|
||||
msg_id = await handler.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
"📊 "
|
||||
|
|
@ -85,7 +85,7 @@ async def handle_stats_command(
|
|||
fire_and_forget=False,
|
||||
message_thread_id=incoming.message_thread_id,
|
||||
)
|
||||
handler._record_outgoing_message(
|
||||
handler.record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
|
||||
|
|
@ -149,7 +149,7 @@ async def _handle_clear_branch(
|
|||
|
||||
# 1) Cancel branch tasks (no stop_all)
|
||||
cancelled = await handler.tree_queue.cancel_branch(branch_root_id)
|
||||
handler._update_cancelled_nodes_ui(cancelled)
|
||||
handler.update_cancelled_nodes_ui(cancelled)
|
||||
|
||||
# 2) Collect message IDs from branch nodes only
|
||||
msg_ids: set[str] = set()
|
||||
|
|
@ -214,25 +214,23 @@ async def handle_clear_command(
|
|||
await _delete_message_ids(handler, incoming.chat_id, msg_ids_to_del)
|
||||
msg_id = await handler.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
handler._format_status(
|
||||
"🗑", "Cleared.", "Voice note cancelled."
|
||||
),
|
||||
handler.format_status("🗑", "Cleared.", "Voice note cancelled."),
|
||||
fire_and_forget=False,
|
||||
message_thread_id=incoming.message_thread_id,
|
||||
)
|
||||
handler._record_outgoing_message(
|
||||
handler.record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
return
|
||||
msg_id = await handler.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
handler._format_status(
|
||||
handler.format_status(
|
||||
"🗑", "Cleared.", "Nothing to clear for that message."
|
||||
),
|
||||
fire_and_forget=False,
|
||||
message_thread_id=incoming.message_thread_id,
|
||||
)
|
||||
handler._record_outgoing_message(
|
||||
handler.record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
return
|
||||
|
|
@ -278,6 +276,6 @@ async def handle_clear_command(
|
|||
logger.warning(f"Failed to clear session store: {e}")
|
||||
|
||||
handler.tree_queue = TreeQueueManager(
|
||||
queue_update_callback=handler._update_queue_positions,
|
||||
node_started_callback=handler._mark_node_processing,
|
||||
queue_update_callback=handler.update_queue_positions,
|
||||
node_started_callback=handler.mark_node_processing,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -119,8 +119,8 @@ class ClaudeMessageHandler:
|
|||
self.cli_manager = cli_manager
|
||||
self.session_store = session_store
|
||||
self.tree_queue = TreeQueueManager(
|
||||
queue_update_callback=self._update_queue_positions,
|
||||
node_started_callback=self._mark_node_processing,
|
||||
queue_update_callback=self.update_queue_positions,
|
||||
node_started_callback=self.mark_node_processing,
|
||||
)
|
||||
is_discord = platform.name == "discord"
|
||||
self._format_status_fn = (
|
||||
|
|
@ -138,13 +138,13 @@ class ClaudeMessageHandler:
|
|||
)
|
||||
self._limit_chars = 1900 if is_discord else 3900
|
||||
|
||||
def _format_status(self, emoji: str, label: str, suffix: str | None = None) -> str:
|
||||
def format_status(self, emoji: str, label: str, suffix: str | None = None) -> str:
|
||||
return self._format_status_fn(emoji, label, suffix)
|
||||
|
||||
def _parse_mode(self) -> str | None:
|
||||
return self._parse_mode_val
|
||||
|
||||
def _get_render_ctx(self) -> RenderCtx:
|
||||
def get_render_ctx(self) -> RenderCtx:
|
||||
return self._render_ctx_val
|
||||
|
||||
def _get_limit_chars(self) -> int:
|
||||
|
|
@ -253,7 +253,7 @@ class ClaudeMessageHandler:
|
|||
fire_and_forget=False,
|
||||
message_thread_id=incoming.message_thread_id,
|
||||
)
|
||||
self._record_outgoing_message(
|
||||
self.record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, status_msg_id, "status"
|
||||
)
|
||||
|
||||
|
|
@ -298,13 +298,13 @@ class ClaudeMessageHandler:
|
|||
await self.platform.queue_edit_message(
|
||||
incoming.chat_id,
|
||||
status_msg_id,
|
||||
self._format_status(
|
||||
self.format_status(
|
||||
"📋", "Queued", f"(position {queue_size}) - waiting..."
|
||||
),
|
||||
parse_mode=self._parse_mode(),
|
||||
)
|
||||
|
||||
async def _update_queue_positions(self, tree: MessageTree) -> None:
|
||||
async def update_queue_positions(self, tree: MessageTree) -> None:
|
||||
"""Refresh queued status messages after a dequeue."""
|
||||
try:
|
||||
queued_ids = await tree.get_queue_snapshot()
|
||||
|
|
@ -325,14 +325,14 @@ class ClaudeMessageHandler:
|
|||
self.platform.queue_edit_message(
|
||||
node.incoming.chat_id,
|
||||
node.status_message_id,
|
||||
self._format_status(
|
||||
self.format_status(
|
||||
"📋", "Queued", f"(position {position}) - waiting..."
|
||||
),
|
||||
parse_mode=self._parse_mode(),
|
||||
)
|
||||
)
|
||||
|
||||
async def _mark_node_processing(self, tree: MessageTree, node_id: str) -> None:
|
||||
async def mark_node_processing(self, tree: MessageTree, node_id: str) -> None:
|
||||
"""Update the dequeued node's status to processing immediately."""
|
||||
node = tree.get_node(node_id)
|
||||
if not node or node.state == MessageState.ERROR:
|
||||
|
|
@ -341,7 +341,7 @@ class ClaudeMessageHandler:
|
|||
self.platform.queue_edit_message(
|
||||
node.incoming.chat_id,
|
||||
node.status_message_id,
|
||||
self._format_status("🔄", "Processing..."),
|
||||
self.format_status("🔄", "Processing..."),
|
||||
parse_mode=self._parse_mode(),
|
||||
)
|
||||
)
|
||||
|
|
@ -351,7 +351,7 @@ class ClaudeMessageHandler:
|
|||
) -> tuple[TranscriptBuffer, RenderCtx]:
|
||||
"""Create transcript buffer and render context for node processing."""
|
||||
transcript = TranscriptBuffer(show_tool_results=False)
|
||||
return transcript, self._get_render_ctx()
|
||||
return transcript, self.get_render_ctx()
|
||||
|
||||
async def _handle_session_info_event(
|
||||
self,
|
||||
|
|
@ -400,7 +400,7 @@ class ClaudeMessageHandler:
|
|||
transcript.apply(parsed)
|
||||
had_transcript_events = True
|
||||
|
||||
status = _get_status_for_event(ptype, parsed, self._format_status)
|
||||
status = _get_status_for_event(ptype, parsed, self.format_status)
|
||||
if status is not None:
|
||||
await update_ui(status)
|
||||
last_status = status
|
||||
|
|
@ -410,7 +410,7 @@ class ClaudeMessageHandler:
|
|||
if not had_transcript_events:
|
||||
transcript.apply({"type": "text_chunk", "text": "Done."})
|
||||
logger.info("HANDLER: Task complete, updating UI")
|
||||
await update_ui(self._format_status("✅", "Complete"), force=True)
|
||||
await update_ui(self.format_status("✅", "Complete"), force=True)
|
||||
if tree and captured_session_id:
|
||||
await tree.update_state(
|
||||
node_id,
|
||||
|
|
@ -422,7 +422,7 @@ class ClaudeMessageHandler:
|
|||
error_msg = parsed.get("message", "Unknown error")
|
||||
logger.error(f"HANDLER: Error event received: {error_msg}")
|
||||
logger.info("HANDLER: Updating UI with error status")
|
||||
await update_ui(self._format_status("❌", "Error"), force=True)
|
||||
await update_ui(self.format_status("❌", "Error"), force=True)
|
||||
if tree:
|
||||
await self._propagate_error_to_children(
|
||||
node_id, error_msg, "Parent task failed"
|
||||
|
|
@ -533,7 +533,7 @@ class ClaudeMessageHandler:
|
|||
except RuntimeError as e:
|
||||
transcript.apply({"type": "error", "message": str(e)})
|
||||
await update_ui(
|
||||
self._format_status("⏳", "Session limit reached"),
|
||||
self.format_status("⏳", "Session limit reached"),
|
||||
force=True,
|
||||
)
|
||||
if tree:
|
||||
|
|
@ -592,10 +592,10 @@ class ClaudeMessageHandler:
|
|||
cancel_reason = node.context.get("cancel_reason")
|
||||
|
||||
if cancel_reason == "stop":
|
||||
await update_ui(self._format_status("⏹", "Stopped."), force=True)
|
||||
await update_ui(self.format_status("⏹", "Stopped."), force=True)
|
||||
else:
|
||||
transcript.apply({"type": "error", "message": "Task was cancelled"})
|
||||
await update_ui(self._format_status("❌", "Cancelled"), force=True)
|
||||
await update_ui(self.format_status("❌", "Cancelled"), force=True)
|
||||
|
||||
# Do not propagate cancellation to children; a reply-scoped "/stop"
|
||||
# should only stop the targeted task.
|
||||
|
|
@ -609,7 +609,7 @@ class ClaudeMessageHandler:
|
|||
)
|
||||
error_msg = str(e)[:200]
|
||||
transcript.apply({"type": "error", "message": error_msg})
|
||||
await update_ui(self._format_status("💥", "Task Failed"), force=True)
|
||||
await update_ui(self.format_status("💥", "Task Failed"), force=True)
|
||||
if tree:
|
||||
await self._propagate_error_to_children(
|
||||
node_id, error_msg, "Parent task failed"
|
||||
|
|
@ -643,7 +643,7 @@ class ClaudeMessageHandler:
|
|||
self.platform.queue_edit_message(
|
||||
child.incoming.chat_id,
|
||||
child.status_message_id,
|
||||
self._format_status("❌", "Cancelled:", child_status_text),
|
||||
self.format_status("❌", "Cancelled:", child_status_text),
|
||||
parse_mode=self._parse_mode(),
|
||||
)
|
||||
)
|
||||
|
|
@ -658,13 +658,13 @@ class ClaudeMessageHandler:
|
|||
# Reply to existing tree
|
||||
if self.tree_queue.is_node_tree_busy(parent_node_id):
|
||||
queue_size = self.tree_queue.get_queue_size(parent_node_id) + 1
|
||||
return self._format_status(
|
||||
return self.format_status(
|
||||
"📋", "Queued", f"(position {queue_size}) - waiting..."
|
||||
)
|
||||
return self._format_status("🔄", "Continuing conversation...")
|
||||
return self.format_status("🔄", "Continuing conversation...")
|
||||
|
||||
# New conversation
|
||||
return self._format_status("⏳", "Launching new Claude CLI instance...")
|
||||
return self.format_status("⏳", "Launching new Claude CLI instance...")
|
||||
|
||||
async def stop_all_tasks(self) -> int:
|
||||
"""
|
||||
|
|
@ -685,7 +685,7 @@ class ClaudeMessageHandler:
|
|||
await self.cli_manager.stop_all()
|
||||
|
||||
# 3. Update UI and persist state for all cancelled nodes
|
||||
self._update_cancelled_nodes_ui(cancelled_nodes)
|
||||
self.update_cancelled_nodes_ui(cancelled_nodes)
|
||||
|
||||
return len(cancelled_nodes)
|
||||
|
||||
|
|
@ -703,10 +703,10 @@ class ClaudeMessageHandler:
|
|||
node.set_context({"cancel_reason": "stop"})
|
||||
|
||||
cancelled_nodes = await self.tree_queue.cancel_node(node_id)
|
||||
self._update_cancelled_nodes_ui(cancelled_nodes)
|
||||
self.update_cancelled_nodes_ui(cancelled_nodes)
|
||||
return len(cancelled_nodes)
|
||||
|
||||
def _record_outgoing_message(
|
||||
def record_outgoing_message(
|
||||
self,
|
||||
platform: str,
|
||||
chat_id: str,
|
||||
|
|
@ -723,7 +723,7 @@ class ClaudeMessageHandler:
|
|||
except Exception as e:
|
||||
logger.debug(f"Failed to record message_id: {e}")
|
||||
|
||||
def _update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
|
||||
def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
|
||||
"""Update status messages and persist tree state for cancelled nodes."""
|
||||
trees_to_save: dict[str, MessageTree] = {}
|
||||
for node in nodes:
|
||||
|
|
@ -731,7 +731,7 @@ class ClaudeMessageHandler:
|
|||
self.platform.queue_edit_message(
|
||||
node.incoming.chat_id,
|
||||
node.status_message_id,
|
||||
self._format_status("⏹", "Stopped."),
|
||||
self.format_status("⏹", "Stopped."),
|
||||
parse_mode=self._parse_mode(),
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
|
|
@ -26,11 +27,11 @@ def _safe_json_dumps(obj: Any) -> str:
|
|||
|
||||
|
||||
@dataclass
|
||||
class Segment:
|
||||
class Segment(ABC):
|
||||
kind: str
|
||||
|
||||
def render(self, ctx: RenderCtx) -> str:
|
||||
raise NotImplementedError
|
||||
@abstractmethod
|
||||
def render(self, ctx: RenderCtx) -> str: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -88,23 +89,6 @@ class ToolCallSegment(Segment):
|
|||
self.tool_use_id = str(tool_use_id or "")
|
||||
self.name = str(name or "tool")
|
||||
self.indent_level = max(0, int(indent_level))
|
||||
self._parts: list[str] = []
|
||||
|
||||
def set_initial_input(self, inp: Any) -> None:
|
||||
if inp is None:
|
||||
return
|
||||
if isinstance(inp, str):
|
||||
self._parts = [inp]
|
||||
else:
|
||||
self._parts = [_safe_json_dumps(inp)]
|
||||
|
||||
def append_input_delta(self, partial: str) -> None:
|
||||
if partial:
|
||||
self._parts.append(partial)
|
||||
|
||||
@property
|
||||
def input_text(self) -> str:
|
||||
return "".join(self._parts)
|
||||
|
||||
def render(self, ctx: RenderCtx) -> str:
|
||||
name = ctx.code_inline(self.name)
|
||||
|
|
@ -444,17 +428,12 @@ class TranscriptBuffer:
|
|||
seg = ToolCallSegment(tool_id, name)
|
||||
self._segments.append(seg)
|
||||
|
||||
seg.set_initial_input(ev.get("input"))
|
||||
if idx >= 0:
|
||||
self._open_tools_by_index[idx] = seg
|
||||
return
|
||||
|
||||
if et == "tool_use_delta":
|
||||
idx = int(ev.get("index", -1))
|
||||
partial = str(ev.get("partial_json", "") or "")
|
||||
seg = self._open_tools_by_index.get(idx)
|
||||
if seg is not None:
|
||||
seg.append_input_delta(partial)
|
||||
# Track open tool by index for tool_use_stop (closing state).
|
||||
return
|
||||
|
||||
if et == "tool_use_stop":
|
||||
|
|
@ -501,7 +480,6 @@ class TranscriptBuffer:
|
|||
seg = ToolCallSegment(tool_id, name)
|
||||
self._segments.append(seg)
|
||||
|
||||
seg.set_initial_input(ev.get("input"))
|
||||
seg.closed = True
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -17,21 +17,26 @@ from ..models import IncomingMessage
|
|||
|
||||
|
||||
class _SnapshotQueue:
|
||||
"""Queue with snapshot/remove helpers, backed by a deque."""
|
||||
"""Queue with snapshot/remove helpers, backed by a deque and a set index."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._deque: deque[str] = deque()
|
||||
self._set: set[str] = set()
|
||||
|
||||
async def put(self, item: str) -> None:
|
||||
self._deque.append(item)
|
||||
self._set.add(item)
|
||||
|
||||
def put_nowait(self, item: str) -> None:
|
||||
self._deque.append(item)
|
||||
self._set.add(item)
|
||||
|
||||
def get_nowait(self) -> str:
|
||||
if not self._deque:
|
||||
raise asyncio.QueueEmpty()
|
||||
return self._deque.popleft()
|
||||
item = self._deque.popleft()
|
||||
self._set.discard(item)
|
||||
return item
|
||||
|
||||
def qsize(self) -> int:
|
||||
return len(self._deque)
|
||||
|
|
@ -41,12 +46,11 @@ class _SnapshotQueue:
|
|||
return list(self._deque)
|
||||
|
||||
def remove_if_present(self, item: str) -> bool:
|
||||
"""Remove item from queue if present. Returns True if removed."""
|
||||
if item not in self._deque:
|
||||
"""Remove item from queue if present (O(1) membership check). Returns True if removed."""
|
||||
if item not in self._set:
|
||||
return False
|
||||
items = [x for x in self._deque if x != item]
|
||||
self._deque.clear()
|
||||
self._deque.extend(items)
|
||||
self._set.discard(item)
|
||||
self._deque = deque(x for x in self._deque if x != item)
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -350,7 +354,7 @@ class MessageTree:
|
|||
return True
|
||||
return False
|
||||
|
||||
def _set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
|
||||
def set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
|
||||
"""Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = error_message
|
||||
|
|
@ -371,7 +375,7 @@ class MessageTree:
|
|||
break
|
||||
node = self._nodes.get(node_id)
|
||||
if node:
|
||||
self._set_node_error_sync(node, error_message)
|
||||
self.set_node_error_sync(node, error_message)
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class TreeQueueProcessor:
|
|||
) -> None:
|
||||
"""Process a single node and then check the queue."""
|
||||
# Skip if already in terminal state (e.g. from error propagation)
|
||||
if node.state.value == MessageState.ERROR.value:
|
||||
if node.state == MessageState.ERROR:
|
||||
logger.info(
|
||||
f"Skipping node {node.node_id} as it is already in state {node.state}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ class TreeQueueManager:
|
|||
MessageState.COMPLETED,
|
||||
MessageState.ERROR,
|
||||
):
|
||||
tree._set_node_error_sync(node, "Cancelled by user")
|
||||
tree.set_node_error_sync(node, "Cancelled by user")
|
||||
cancelled_nodes.append(node)
|
||||
|
||||
# 2. Drain queue and mark nodes as cancelled
|
||||
|
|
@ -259,7 +259,7 @@ class TreeQueueManager:
|
|||
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
|
||||
and node.node_id not in cancelled_ids
|
||||
):
|
||||
tree._set_node_error_sync(node, "Stale task cleaned up")
|
||||
tree.set_node_error_sync(node, "Stale task cleaned up")
|
||||
cleanup_count += 1
|
||||
|
||||
tree.reset_processing_state()
|
||||
|
|
@ -336,7 +336,7 @@ class TreeQueueManager:
|
|||
for tree in self._repository.all_trees():
|
||||
for node in tree.all_nodes():
|
||||
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
|
||||
tree._set_node_error_sync(node, "Lost during server restart")
|
||||
tree.set_node_error_sync(node, "Lost during server restart")
|
||||
count += 1
|
||||
if count:
|
||||
logger.info(f"Cleaned up {count} stale nodes during startup")
|
||||
|
|
|
|||
|
|
@ -150,7 +150,11 @@ class TreeRepository:
|
|||
return tree
|
||||
|
||||
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
|
||||
"""Get all message IDs (incoming + status) for a given platform/chat."""
|
||||
"""Get all message IDs (incoming + status) for a given platform/chat.
|
||||
|
||||
Note: O(total_nodes) scan. Acceptable because this is only called
|
||||
from /clear (user-initiated, infrequent).
|
||||
"""
|
||||
msg_ids: set[str] = set()
|
||||
for tree in self._trees.values():
|
||||
for node in tree.all_nodes():
|
||||
|
|
|
|||
|
|
@ -366,16 +366,19 @@ class SSEBuilder:
|
|||
# Tool calls are harder to tokenize exactly without reconstruction, but we can approximate
|
||||
# by tokenizing the json dumps of tool contents
|
||||
tool_tokens = 0
|
||||
started_tool_count = 0
|
||||
for state in self.blocks.tool_states.values():
|
||||
tool_tokens += len(ENCODER.encode(state.name))
|
||||
tool_tokens += len(ENCODER.encode("".join(state.contents)))
|
||||
tool_tokens += 15 # Control tokens overhead per tool
|
||||
if state.started:
|
||||
started_tool_count += 1
|
||||
|
||||
# Per-block overhead (~4 tokens per content block)
|
||||
block_count = (
|
||||
(1 if accumulated_reasoning else 0)
|
||||
+ (1 if accumulated_text else 0)
|
||||
+ sum(1 for s in self.blocks.tool_states.values() if s.started)
|
||||
+ started_tool_count
|
||||
)
|
||||
block_overhead = block_count * 4
|
||||
|
||||
|
|
|
|||
|
|
@ -16,27 +16,29 @@ from providers.common.text import extract_text_from_content
|
|||
def generate_request_fingerprint(messages: list[Any]) -> str:
|
||||
"""Generate unique short hash for message content.
|
||||
|
||||
Creates a SHA256 hash of all message content, returning an 8-char prefix
|
||||
that's sufficient for correlation without full content logging.
|
||||
Uses incremental SHA256 hashing to avoid building a large intermediate
|
||||
string. Returns an 8-char hex prefix sufficient for correlation.
|
||||
"""
|
||||
content_parts = []
|
||||
h = hashlib.sha256()
|
||||
sep = b"|"
|
||||
for msg in messages:
|
||||
if hasattr(msg, "content"):
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
content_parts.append(content)
|
||||
h.update(content.encode("utf-8"))
|
||||
h.update(sep)
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if hasattr(block, "text"):
|
||||
content_parts.append(block.text)
|
||||
h.update(block.text.encode("utf-8"))
|
||||
h.update(sep)
|
||||
elif hasattr(block, "type"):
|
||||
content_parts.append(f"<{block.type}>")
|
||||
h.update(f"<{block.type}>".encode())
|
||||
h.update(sep)
|
||||
elif hasattr(msg, "role"):
|
||||
content_parts.append(msg.role)
|
||||
|
||||
combined = "|".join(content_parts)
|
||||
hash_digest = hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
||||
return f"fp_{hash_digest[:8]}"
|
||||
h.update(msg.role.encode("utf-8"))
|
||||
h.update(sep)
|
||||
return f"fp_{h.hexdigest()[:8]}"
|
||||
|
||||
|
||||
def get_last_user_message_preview(messages: list[Any], max_len: int = 100) -> str:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""OpenRouter provider - OpenAI-compatible API for hundreds of models."""
|
||||
|
||||
from .client import OpenRouterProvider
|
||||
from .client import OPENROUTER_BASE_URL, OpenRouterProvider
|
||||
|
||||
__all__ = ["OpenRouterProvider"]
|
||||
__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"]
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import httpx
|
|||
from loguru import logger
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from config.nim import NimSettings
|
||||
from providers.base import BaseProvider, ProviderConfig
|
||||
from providers.common import (
|
||||
ContentType,
|
||||
|
|
@ -33,7 +32,7 @@ class OpenAICompatibleProvider(BaseProvider):
|
|||
provider_name: str,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
nim_settings: NimSettings | None = None,
|
||||
nim_settings: Any = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
self._provider_name = provider_name
|
||||
|
|
|
|||
|
|
@ -30,13 +30,19 @@ class GlobalRateLimiter:
|
|||
|
||||
_instance: ClassVar[GlobalRateLimiter | None] = None
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> GlobalRateLimiter:
|
||||
if cls._instance is not None:
|
||||
return cls._instance
|
||||
instance = super().__new__(cls)
|
||||
return instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rate_limit: int = 40,
|
||||
rate_window: float = 60.0,
|
||||
max_concurrency: int = 5,
|
||||
):
|
||||
# Prevent double initialization in singleton
|
||||
# Prevent re-initialization on singleton reuse
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -247,7 +247,7 @@ async def test_update_queue_positions(handler, mock_platform):
|
|||
await tree.enqueue("child_1")
|
||||
await tree.enqueue("child_2")
|
||||
|
||||
await handler._update_queue_positions(tree)
|
||||
await handler.update_queue_positions(tree)
|
||||
|
||||
calls = mock_platform.queue_edit_message.call_args_list
|
||||
assert len(calls) == 2
|
||||
|
|
@ -291,7 +291,7 @@ async def test_mark_node_processing(handler, mock_platform):
|
|||
parent_id="root",
|
||||
)
|
||||
|
||||
await handler._mark_node_processing(tree, "child")
|
||||
await handler.mark_node_processing(tree, "child")
|
||||
|
||||
mock_platform.queue_edit_message.assert_called_once()
|
||||
args, kwargs = mock_platform.queue_edit_message.call_args
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi
|
|||
# Snapshot error is swallowed.
|
||||
tree = MagicMock()
|
||||
tree.get_queue_snapshot = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
await handler._update_queue_positions(tree)
|
||||
await handler.update_queue_positions(tree)
|
||||
platform.fire_and_forget.assert_not_called()
|
||||
|
||||
# Normal path: only PENDING nodes get an update.
|
||||
|
|
@ -130,7 +130,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi
|
|||
tree.get_queue_snapshot = AsyncMock(return_value=["n1", "n2"])
|
||||
tree.get_node = MagicMock(side_effect=[node_pending, node_done])
|
||||
|
||||
await handler._update_queue_positions(tree)
|
||||
await handler.update_queue_positions(tree)
|
||||
assert platform.fire_and_forget.call_count == 1
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ async def test_reply_to_old_status_message_after_restore_routes_to_parent(
|
|||
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
||||
handler2.tree_queue = TreeQueueManager.from_dict(
|
||||
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
||||
queue_update_callback=handler2._update_queue_positions,
|
||||
node_started_callback=handler2._mark_node_processing,
|
||||
queue_update_callback=handler2.update_queue_positions,
|
||||
node_started_callback=handler2.mark_node_processing,
|
||||
)
|
||||
|
||||
# Prevent background task scheduling; we only want to validate routing/tree mutation.
|
||||
|
|
@ -89,8 +89,8 @@ async def test_reply_to_old_status_message_without_mapping_creates_new_conversat
|
|||
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
||||
handler2.tree_queue = TreeQueueManager.from_dict(
|
||||
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
||||
queue_update_callback=handler2._update_queue_positions,
|
||||
node_started_callback=handler2._mark_node_processing,
|
||||
queue_update_callback=handler2.update_queue_positions,
|
||||
node_started_callback=handler2.mark_node_processing,
|
||||
)
|
||||
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue