mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Add code review fix plan covering 11 issues across modularity, encapsulation, performance, and dead code (#62)
This commit is contained in:
parent
c54c57a742
commit
aee9f0ad93
17 changed files with 107 additions and 114 deletions
|
|
@ -114,8 +114,8 @@ async def lifespan(app: FastAPI):
|
||||||
"trees": saved_trees,
|
"trees": saved_trees,
|
||||||
"node_to_tree": session_store.get_node_mapping(),
|
"node_to_tree": session_store.get_node_mapping(),
|
||||||
},
|
},
|
||||||
queue_update_callback=message_handler._update_queue_positions,
|
queue_update_callback=message_handler.update_queue_positions,
|
||||||
node_started_callback=message_handler._mark_node_processing,
|
node_started_callback=message_handler.mark_node_processing,
|
||||||
)
|
)
|
||||||
# Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart
|
# Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart
|
||||||
if message_handler.tree_queue.cleanup_stale_nodes() > 0:
|
if message_handler.tree_queue.cleanup_stale_nodes() > 0:
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,7 @@ from providers.base import BaseProvider, ProviderConfig
|
||||||
from providers.exceptions import AuthenticationError
|
from providers.exceptions import AuthenticationError
|
||||||
from providers.lmstudio import LMStudioProvider
|
from providers.lmstudio import LMStudioProvider
|
||||||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
||||||
from providers.open_router import OpenRouterProvider
|
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
|
||||||
from providers.open_router.client import OPENROUTER_BASE_URL
|
|
||||||
|
|
||||||
# Global provider instance (singleton)
|
# Global provider instance (singleton)
|
||||||
_provider: BaseProvider | None = None
|
_provider: BaseProvider | None = None
|
||||||
|
|
|
||||||
|
|
@ -28,13 +28,13 @@ async def handle_stop_command(
|
||||||
if not node_id:
|
if not node_id:
|
||||||
msg_id = await handler.platform.queue_send_message(
|
msg_id = await handler.platform.queue_send_message(
|
||||||
incoming.chat_id,
|
incoming.chat_id,
|
||||||
handler._format_status(
|
handler.format_status(
|
||||||
"⏹", "Stopped.", "Nothing to stop for that message."
|
"⏹", "Stopped.", "Nothing to stop for that message."
|
||||||
),
|
),
|
||||||
fire_and_forget=False,
|
fire_and_forget=False,
|
||||||
message_thread_id=incoming.message_thread_id,
|
message_thread_id=incoming.message_thread_id,
|
||||||
)
|
)
|
||||||
handler._record_outgoing_message(
|
handler.record_outgoing_message(
|
||||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
@ -43,11 +43,11 @@ async def handle_stop_command(
|
||||||
noun = "request" if count == 1 else "requests"
|
noun = "request" if count == 1 else "requests"
|
||||||
msg_id = await handler.platform.queue_send_message(
|
msg_id = await handler.platform.queue_send_message(
|
||||||
incoming.chat_id,
|
incoming.chat_id,
|
||||||
handler._format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."),
|
handler.format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."),
|
||||||
fire_and_forget=False,
|
fire_and_forget=False,
|
||||||
message_thread_id=incoming.message_thread_id,
|
message_thread_id=incoming.message_thread_id,
|
||||||
)
|
)
|
||||||
handler._record_outgoing_message(
|
handler.record_outgoing_message(
|
||||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
@ -56,13 +56,13 @@ async def handle_stop_command(
|
||||||
count = await handler.stop_all_tasks()
|
count = await handler.stop_all_tasks()
|
||||||
msg_id = await handler.platform.queue_send_message(
|
msg_id = await handler.platform.queue_send_message(
|
||||||
incoming.chat_id,
|
incoming.chat_id,
|
||||||
handler._format_status(
|
handler.format_status(
|
||||||
"⏹", "Stopped.", f"Cancelled {count} pending or active requests."
|
"⏹", "Stopped.", f"Cancelled {count} pending or active requests."
|
||||||
),
|
),
|
||||||
fire_and_forget=False,
|
fire_and_forget=False,
|
||||||
message_thread_id=incoming.message_thread_id,
|
message_thread_id=incoming.message_thread_id,
|
||||||
)
|
)
|
||||||
handler._record_outgoing_message(
|
handler.record_outgoing_message(
|
||||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -73,7 +73,7 @@ async def handle_stats_command(
|
||||||
"""Handle /stats command."""
|
"""Handle /stats command."""
|
||||||
stats = handler.cli_manager.get_stats()
|
stats = handler.cli_manager.get_stats()
|
||||||
tree_count = handler.tree_queue.get_tree_count()
|
tree_count = handler.tree_queue.get_tree_count()
|
||||||
ctx = handler._get_render_ctx()
|
ctx = handler.get_render_ctx()
|
||||||
msg_id = await handler.platform.queue_send_message(
|
msg_id = await handler.platform.queue_send_message(
|
||||||
incoming.chat_id,
|
incoming.chat_id,
|
||||||
"📊 "
|
"📊 "
|
||||||
|
|
@ -85,7 +85,7 @@ async def handle_stats_command(
|
||||||
fire_and_forget=False,
|
fire_and_forget=False,
|
||||||
message_thread_id=incoming.message_thread_id,
|
message_thread_id=incoming.message_thread_id,
|
||||||
)
|
)
|
||||||
handler._record_outgoing_message(
|
handler.record_outgoing_message(
|
||||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -149,7 +149,7 @@ async def _handle_clear_branch(
|
||||||
|
|
||||||
# 1) Cancel branch tasks (no stop_all)
|
# 1) Cancel branch tasks (no stop_all)
|
||||||
cancelled = await handler.tree_queue.cancel_branch(branch_root_id)
|
cancelled = await handler.tree_queue.cancel_branch(branch_root_id)
|
||||||
handler._update_cancelled_nodes_ui(cancelled)
|
handler.update_cancelled_nodes_ui(cancelled)
|
||||||
|
|
||||||
# 2) Collect message IDs from branch nodes only
|
# 2) Collect message IDs from branch nodes only
|
||||||
msg_ids: set[str] = set()
|
msg_ids: set[str] = set()
|
||||||
|
|
@ -214,25 +214,23 @@ async def handle_clear_command(
|
||||||
await _delete_message_ids(handler, incoming.chat_id, msg_ids_to_del)
|
await _delete_message_ids(handler, incoming.chat_id, msg_ids_to_del)
|
||||||
msg_id = await handler.platform.queue_send_message(
|
msg_id = await handler.platform.queue_send_message(
|
||||||
incoming.chat_id,
|
incoming.chat_id,
|
||||||
handler._format_status(
|
handler.format_status("🗑", "Cleared.", "Voice note cancelled."),
|
||||||
"🗑", "Cleared.", "Voice note cancelled."
|
|
||||||
),
|
|
||||||
fire_and_forget=False,
|
fire_and_forget=False,
|
||||||
message_thread_id=incoming.message_thread_id,
|
message_thread_id=incoming.message_thread_id,
|
||||||
)
|
)
|
||||||
handler._record_outgoing_message(
|
handler.record_outgoing_message(
|
||||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
msg_id = await handler.platform.queue_send_message(
|
msg_id = await handler.platform.queue_send_message(
|
||||||
incoming.chat_id,
|
incoming.chat_id,
|
||||||
handler._format_status(
|
handler.format_status(
|
||||||
"🗑", "Cleared.", "Nothing to clear for that message."
|
"🗑", "Cleared.", "Nothing to clear for that message."
|
||||||
),
|
),
|
||||||
fire_and_forget=False,
|
fire_and_forget=False,
|
||||||
message_thread_id=incoming.message_thread_id,
|
message_thread_id=incoming.message_thread_id,
|
||||||
)
|
)
|
||||||
handler._record_outgoing_message(
|
handler.record_outgoing_message(
|
||||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
@ -278,6 +276,6 @@ async def handle_clear_command(
|
||||||
logger.warning(f"Failed to clear session store: {e}")
|
logger.warning(f"Failed to clear session store: {e}")
|
||||||
|
|
||||||
handler.tree_queue = TreeQueueManager(
|
handler.tree_queue = TreeQueueManager(
|
||||||
queue_update_callback=handler._update_queue_positions,
|
queue_update_callback=handler.update_queue_positions,
|
||||||
node_started_callback=handler._mark_node_processing,
|
node_started_callback=handler.mark_node_processing,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -119,8 +119,8 @@ class ClaudeMessageHandler:
|
||||||
self.cli_manager = cli_manager
|
self.cli_manager = cli_manager
|
||||||
self.session_store = session_store
|
self.session_store = session_store
|
||||||
self.tree_queue = TreeQueueManager(
|
self.tree_queue = TreeQueueManager(
|
||||||
queue_update_callback=self._update_queue_positions,
|
queue_update_callback=self.update_queue_positions,
|
||||||
node_started_callback=self._mark_node_processing,
|
node_started_callback=self.mark_node_processing,
|
||||||
)
|
)
|
||||||
is_discord = platform.name == "discord"
|
is_discord = platform.name == "discord"
|
||||||
self._format_status_fn = (
|
self._format_status_fn = (
|
||||||
|
|
@ -138,13 +138,13 @@ class ClaudeMessageHandler:
|
||||||
)
|
)
|
||||||
self._limit_chars = 1900 if is_discord else 3900
|
self._limit_chars = 1900 if is_discord else 3900
|
||||||
|
|
||||||
def _format_status(self, emoji: str, label: str, suffix: str | None = None) -> str:
|
def format_status(self, emoji: str, label: str, suffix: str | None = None) -> str:
|
||||||
return self._format_status_fn(emoji, label, suffix)
|
return self._format_status_fn(emoji, label, suffix)
|
||||||
|
|
||||||
def _parse_mode(self) -> str | None:
|
def _parse_mode(self) -> str | None:
|
||||||
return self._parse_mode_val
|
return self._parse_mode_val
|
||||||
|
|
||||||
def _get_render_ctx(self) -> RenderCtx:
|
def get_render_ctx(self) -> RenderCtx:
|
||||||
return self._render_ctx_val
|
return self._render_ctx_val
|
||||||
|
|
||||||
def _get_limit_chars(self) -> int:
|
def _get_limit_chars(self) -> int:
|
||||||
|
|
@ -253,7 +253,7 @@ class ClaudeMessageHandler:
|
||||||
fire_and_forget=False,
|
fire_and_forget=False,
|
||||||
message_thread_id=incoming.message_thread_id,
|
message_thread_id=incoming.message_thread_id,
|
||||||
)
|
)
|
||||||
self._record_outgoing_message(
|
self.record_outgoing_message(
|
||||||
incoming.platform, incoming.chat_id, status_msg_id, "status"
|
incoming.platform, incoming.chat_id, status_msg_id, "status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -298,13 +298,13 @@ class ClaudeMessageHandler:
|
||||||
await self.platform.queue_edit_message(
|
await self.platform.queue_edit_message(
|
||||||
incoming.chat_id,
|
incoming.chat_id,
|
||||||
status_msg_id,
|
status_msg_id,
|
||||||
self._format_status(
|
self.format_status(
|
||||||
"📋", "Queued", f"(position {queue_size}) - waiting..."
|
"📋", "Queued", f"(position {queue_size}) - waiting..."
|
||||||
),
|
),
|
||||||
parse_mode=self._parse_mode(),
|
parse_mode=self._parse_mode(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _update_queue_positions(self, tree: MessageTree) -> None:
|
async def update_queue_positions(self, tree: MessageTree) -> None:
|
||||||
"""Refresh queued status messages after a dequeue."""
|
"""Refresh queued status messages after a dequeue."""
|
||||||
try:
|
try:
|
||||||
queued_ids = await tree.get_queue_snapshot()
|
queued_ids = await tree.get_queue_snapshot()
|
||||||
|
|
@ -325,14 +325,14 @@ class ClaudeMessageHandler:
|
||||||
self.platform.queue_edit_message(
|
self.platform.queue_edit_message(
|
||||||
node.incoming.chat_id,
|
node.incoming.chat_id,
|
||||||
node.status_message_id,
|
node.status_message_id,
|
||||||
self._format_status(
|
self.format_status(
|
||||||
"📋", "Queued", f"(position {position}) - waiting..."
|
"📋", "Queued", f"(position {position}) - waiting..."
|
||||||
),
|
),
|
||||||
parse_mode=self._parse_mode(),
|
parse_mode=self._parse_mode(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _mark_node_processing(self, tree: MessageTree, node_id: str) -> None:
|
async def mark_node_processing(self, tree: MessageTree, node_id: str) -> None:
|
||||||
"""Update the dequeued node's status to processing immediately."""
|
"""Update the dequeued node's status to processing immediately."""
|
||||||
node = tree.get_node(node_id)
|
node = tree.get_node(node_id)
|
||||||
if not node or node.state == MessageState.ERROR:
|
if not node or node.state == MessageState.ERROR:
|
||||||
|
|
@ -341,7 +341,7 @@ class ClaudeMessageHandler:
|
||||||
self.platform.queue_edit_message(
|
self.platform.queue_edit_message(
|
||||||
node.incoming.chat_id,
|
node.incoming.chat_id,
|
||||||
node.status_message_id,
|
node.status_message_id,
|
||||||
self._format_status("🔄", "Processing..."),
|
self.format_status("🔄", "Processing..."),
|
||||||
parse_mode=self._parse_mode(),
|
parse_mode=self._parse_mode(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -351,7 +351,7 @@ class ClaudeMessageHandler:
|
||||||
) -> tuple[TranscriptBuffer, RenderCtx]:
|
) -> tuple[TranscriptBuffer, RenderCtx]:
|
||||||
"""Create transcript buffer and render context for node processing."""
|
"""Create transcript buffer and render context for node processing."""
|
||||||
transcript = TranscriptBuffer(show_tool_results=False)
|
transcript = TranscriptBuffer(show_tool_results=False)
|
||||||
return transcript, self._get_render_ctx()
|
return transcript, self.get_render_ctx()
|
||||||
|
|
||||||
async def _handle_session_info_event(
|
async def _handle_session_info_event(
|
||||||
self,
|
self,
|
||||||
|
|
@ -400,7 +400,7 @@ class ClaudeMessageHandler:
|
||||||
transcript.apply(parsed)
|
transcript.apply(parsed)
|
||||||
had_transcript_events = True
|
had_transcript_events = True
|
||||||
|
|
||||||
status = _get_status_for_event(ptype, parsed, self._format_status)
|
status = _get_status_for_event(ptype, parsed, self.format_status)
|
||||||
if status is not None:
|
if status is not None:
|
||||||
await update_ui(status)
|
await update_ui(status)
|
||||||
last_status = status
|
last_status = status
|
||||||
|
|
@ -410,7 +410,7 @@ class ClaudeMessageHandler:
|
||||||
if not had_transcript_events:
|
if not had_transcript_events:
|
||||||
transcript.apply({"type": "text_chunk", "text": "Done."})
|
transcript.apply({"type": "text_chunk", "text": "Done."})
|
||||||
logger.info("HANDLER: Task complete, updating UI")
|
logger.info("HANDLER: Task complete, updating UI")
|
||||||
await update_ui(self._format_status("✅", "Complete"), force=True)
|
await update_ui(self.format_status("✅", "Complete"), force=True)
|
||||||
if tree and captured_session_id:
|
if tree and captured_session_id:
|
||||||
await tree.update_state(
|
await tree.update_state(
|
||||||
node_id,
|
node_id,
|
||||||
|
|
@ -422,7 +422,7 @@ class ClaudeMessageHandler:
|
||||||
error_msg = parsed.get("message", "Unknown error")
|
error_msg = parsed.get("message", "Unknown error")
|
||||||
logger.error(f"HANDLER: Error event received: {error_msg}")
|
logger.error(f"HANDLER: Error event received: {error_msg}")
|
||||||
logger.info("HANDLER: Updating UI with error status")
|
logger.info("HANDLER: Updating UI with error status")
|
||||||
await update_ui(self._format_status("❌", "Error"), force=True)
|
await update_ui(self.format_status("❌", "Error"), force=True)
|
||||||
if tree:
|
if tree:
|
||||||
await self._propagate_error_to_children(
|
await self._propagate_error_to_children(
|
||||||
node_id, error_msg, "Parent task failed"
|
node_id, error_msg, "Parent task failed"
|
||||||
|
|
@ -533,7 +533,7 @@ class ClaudeMessageHandler:
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
transcript.apply({"type": "error", "message": str(e)})
|
transcript.apply({"type": "error", "message": str(e)})
|
||||||
await update_ui(
|
await update_ui(
|
||||||
self._format_status("⏳", "Session limit reached"),
|
self.format_status("⏳", "Session limit reached"),
|
||||||
force=True,
|
force=True,
|
||||||
)
|
)
|
||||||
if tree:
|
if tree:
|
||||||
|
|
@ -592,10 +592,10 @@ class ClaudeMessageHandler:
|
||||||
cancel_reason = node.context.get("cancel_reason")
|
cancel_reason = node.context.get("cancel_reason")
|
||||||
|
|
||||||
if cancel_reason == "stop":
|
if cancel_reason == "stop":
|
||||||
await update_ui(self._format_status("⏹", "Stopped."), force=True)
|
await update_ui(self.format_status("⏹", "Stopped."), force=True)
|
||||||
else:
|
else:
|
||||||
transcript.apply({"type": "error", "message": "Task was cancelled"})
|
transcript.apply({"type": "error", "message": "Task was cancelled"})
|
||||||
await update_ui(self._format_status("❌", "Cancelled"), force=True)
|
await update_ui(self.format_status("❌", "Cancelled"), force=True)
|
||||||
|
|
||||||
# Do not propagate cancellation to children; a reply-scoped "/stop"
|
# Do not propagate cancellation to children; a reply-scoped "/stop"
|
||||||
# should only stop the targeted task.
|
# should only stop the targeted task.
|
||||||
|
|
@ -609,7 +609,7 @@ class ClaudeMessageHandler:
|
||||||
)
|
)
|
||||||
error_msg = str(e)[:200]
|
error_msg = str(e)[:200]
|
||||||
transcript.apply({"type": "error", "message": error_msg})
|
transcript.apply({"type": "error", "message": error_msg})
|
||||||
await update_ui(self._format_status("💥", "Task Failed"), force=True)
|
await update_ui(self.format_status("💥", "Task Failed"), force=True)
|
||||||
if tree:
|
if tree:
|
||||||
await self._propagate_error_to_children(
|
await self._propagate_error_to_children(
|
||||||
node_id, error_msg, "Parent task failed"
|
node_id, error_msg, "Parent task failed"
|
||||||
|
|
@ -643,7 +643,7 @@ class ClaudeMessageHandler:
|
||||||
self.platform.queue_edit_message(
|
self.platform.queue_edit_message(
|
||||||
child.incoming.chat_id,
|
child.incoming.chat_id,
|
||||||
child.status_message_id,
|
child.status_message_id,
|
||||||
self._format_status("❌", "Cancelled:", child_status_text),
|
self.format_status("❌", "Cancelled:", child_status_text),
|
||||||
parse_mode=self._parse_mode(),
|
parse_mode=self._parse_mode(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -658,13 +658,13 @@ class ClaudeMessageHandler:
|
||||||
# Reply to existing tree
|
# Reply to existing tree
|
||||||
if self.tree_queue.is_node_tree_busy(parent_node_id):
|
if self.tree_queue.is_node_tree_busy(parent_node_id):
|
||||||
queue_size = self.tree_queue.get_queue_size(parent_node_id) + 1
|
queue_size = self.tree_queue.get_queue_size(parent_node_id) + 1
|
||||||
return self._format_status(
|
return self.format_status(
|
||||||
"📋", "Queued", f"(position {queue_size}) - waiting..."
|
"📋", "Queued", f"(position {queue_size}) - waiting..."
|
||||||
)
|
)
|
||||||
return self._format_status("🔄", "Continuing conversation...")
|
return self.format_status("🔄", "Continuing conversation...")
|
||||||
|
|
||||||
# New conversation
|
# New conversation
|
||||||
return self._format_status("⏳", "Launching new Claude CLI instance...")
|
return self.format_status("⏳", "Launching new Claude CLI instance...")
|
||||||
|
|
||||||
async def stop_all_tasks(self) -> int:
|
async def stop_all_tasks(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -685,7 +685,7 @@ class ClaudeMessageHandler:
|
||||||
await self.cli_manager.stop_all()
|
await self.cli_manager.stop_all()
|
||||||
|
|
||||||
# 3. Update UI and persist state for all cancelled nodes
|
# 3. Update UI and persist state for all cancelled nodes
|
||||||
self._update_cancelled_nodes_ui(cancelled_nodes)
|
self.update_cancelled_nodes_ui(cancelled_nodes)
|
||||||
|
|
||||||
return len(cancelled_nodes)
|
return len(cancelled_nodes)
|
||||||
|
|
||||||
|
|
@ -703,10 +703,10 @@ class ClaudeMessageHandler:
|
||||||
node.set_context({"cancel_reason": "stop"})
|
node.set_context({"cancel_reason": "stop"})
|
||||||
|
|
||||||
cancelled_nodes = await self.tree_queue.cancel_node(node_id)
|
cancelled_nodes = await self.tree_queue.cancel_node(node_id)
|
||||||
self._update_cancelled_nodes_ui(cancelled_nodes)
|
self.update_cancelled_nodes_ui(cancelled_nodes)
|
||||||
return len(cancelled_nodes)
|
return len(cancelled_nodes)
|
||||||
|
|
||||||
def _record_outgoing_message(
|
def record_outgoing_message(
|
||||||
self,
|
self,
|
||||||
platform: str,
|
platform: str,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
|
|
@ -723,7 +723,7 @@ class ClaudeMessageHandler:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Failed to record message_id: {e}")
|
logger.debug(f"Failed to record message_id: {e}")
|
||||||
|
|
||||||
def _update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
|
def update_cancelled_nodes_ui(self, nodes: list[MessageNode]) -> None:
|
||||||
"""Update status messages and persist tree state for cancelled nodes."""
|
"""Update status messages and persist tree state for cancelled nodes."""
|
||||||
trees_to_save: dict[str, MessageTree] = {}
|
trees_to_save: dict[str, MessageTree] = {}
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
|
|
@ -731,7 +731,7 @@ class ClaudeMessageHandler:
|
||||||
self.platform.queue_edit_message(
|
self.platform.queue_edit_message(
|
||||||
node.incoming.chat_id,
|
node.incoming.chat_id,
|
||||||
node.status_message_id,
|
node.status_message_id,
|
||||||
self._format_status("⏹", "Stopped."),
|
self.format_status("⏹", "Stopped."),
|
||||||
parse_mode=self._parse_mode(),
|
parse_mode=self._parse_mode(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Callable, Iterable
|
from collections.abc import Callable, Iterable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
@ -26,11 +27,11 @@ def _safe_json_dumps(obj: Any) -> str:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Segment:
|
class Segment(ABC):
|
||||||
kind: str
|
kind: str
|
||||||
|
|
||||||
def render(self, ctx: RenderCtx) -> str:
|
@abstractmethod
|
||||||
raise NotImplementedError
|
def render(self, ctx: RenderCtx) -> str: ...
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -88,23 +89,6 @@ class ToolCallSegment(Segment):
|
||||||
self.tool_use_id = str(tool_use_id or "")
|
self.tool_use_id = str(tool_use_id or "")
|
||||||
self.name = str(name or "tool")
|
self.name = str(name or "tool")
|
||||||
self.indent_level = max(0, int(indent_level))
|
self.indent_level = max(0, int(indent_level))
|
||||||
self._parts: list[str] = []
|
|
||||||
|
|
||||||
def set_initial_input(self, inp: Any) -> None:
|
|
||||||
if inp is None:
|
|
||||||
return
|
|
||||||
if isinstance(inp, str):
|
|
||||||
self._parts = [inp]
|
|
||||||
else:
|
|
||||||
self._parts = [_safe_json_dumps(inp)]
|
|
||||||
|
|
||||||
def append_input_delta(self, partial: str) -> None:
|
|
||||||
if partial:
|
|
||||||
self._parts.append(partial)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_text(self) -> str:
|
|
||||||
return "".join(self._parts)
|
|
||||||
|
|
||||||
def render(self, ctx: RenderCtx) -> str:
|
def render(self, ctx: RenderCtx) -> str:
|
||||||
name = ctx.code_inline(self.name)
|
name = ctx.code_inline(self.name)
|
||||||
|
|
@ -444,17 +428,12 @@ class TranscriptBuffer:
|
||||||
seg = ToolCallSegment(tool_id, name)
|
seg = ToolCallSegment(tool_id, name)
|
||||||
self._segments.append(seg)
|
self._segments.append(seg)
|
||||||
|
|
||||||
seg.set_initial_input(ev.get("input"))
|
|
||||||
if idx >= 0:
|
if idx >= 0:
|
||||||
self._open_tools_by_index[idx] = seg
|
self._open_tools_by_index[idx] = seg
|
||||||
return
|
return
|
||||||
|
|
||||||
if et == "tool_use_delta":
|
if et == "tool_use_delta":
|
||||||
idx = int(ev.get("index", -1))
|
# Track open tool by index for tool_use_stop (closing state).
|
||||||
partial = str(ev.get("partial_json", "") or "")
|
|
||||||
seg = self._open_tools_by_index.get(idx)
|
|
||||||
if seg is not None:
|
|
||||||
seg.append_input_delta(partial)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
if et == "tool_use_stop":
|
if et == "tool_use_stop":
|
||||||
|
|
@ -501,7 +480,6 @@ class TranscriptBuffer:
|
||||||
seg = ToolCallSegment(tool_id, name)
|
seg = ToolCallSegment(tool_id, name)
|
||||||
self._segments.append(seg)
|
self._segments.append(seg)
|
||||||
|
|
||||||
seg.set_initial_input(ev.get("input"))
|
|
||||||
seg.closed = True
|
seg.closed = True
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,21 +17,26 @@ from ..models import IncomingMessage
|
||||||
|
|
||||||
|
|
||||||
class _SnapshotQueue:
|
class _SnapshotQueue:
|
||||||
"""Queue with snapshot/remove helpers, backed by a deque."""
|
"""Queue with snapshot/remove helpers, backed by a deque and a set index."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._deque: deque[str] = deque()
|
self._deque: deque[str] = deque()
|
||||||
|
self._set: set[str] = set()
|
||||||
|
|
||||||
async def put(self, item: str) -> None:
|
async def put(self, item: str) -> None:
|
||||||
self._deque.append(item)
|
self._deque.append(item)
|
||||||
|
self._set.add(item)
|
||||||
|
|
||||||
def put_nowait(self, item: str) -> None:
|
def put_nowait(self, item: str) -> None:
|
||||||
self._deque.append(item)
|
self._deque.append(item)
|
||||||
|
self._set.add(item)
|
||||||
|
|
||||||
def get_nowait(self) -> str:
|
def get_nowait(self) -> str:
|
||||||
if not self._deque:
|
if not self._deque:
|
||||||
raise asyncio.QueueEmpty()
|
raise asyncio.QueueEmpty()
|
||||||
return self._deque.popleft()
|
item = self._deque.popleft()
|
||||||
|
self._set.discard(item)
|
||||||
|
return item
|
||||||
|
|
||||||
def qsize(self) -> int:
|
def qsize(self) -> int:
|
||||||
return len(self._deque)
|
return len(self._deque)
|
||||||
|
|
@ -41,12 +46,11 @@ class _SnapshotQueue:
|
||||||
return list(self._deque)
|
return list(self._deque)
|
||||||
|
|
||||||
def remove_if_present(self, item: str) -> bool:
|
def remove_if_present(self, item: str) -> bool:
|
||||||
"""Remove item from queue if present. Returns True if removed."""
|
"""Remove item from queue if present (O(1) membership check). Returns True if removed."""
|
||||||
if item not in self._deque:
|
if item not in self._set:
|
||||||
return False
|
return False
|
||||||
items = [x for x in self._deque if x != item]
|
self._set.discard(item)
|
||||||
self._deque.clear()
|
self._deque = deque(x for x in self._deque if x != item)
|
||||||
self._deque.extend(items)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -350,7 +354,7 @@ class MessageTree:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
|
def set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
|
||||||
"""Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
|
"""Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
|
||||||
node.state = MessageState.ERROR
|
node.state = MessageState.ERROR
|
||||||
node.error_message = error_message
|
node.error_message = error_message
|
||||||
|
|
@ -371,7 +375,7 @@ class MessageTree:
|
||||||
break
|
break
|
||||||
node = self._nodes.get(node_id)
|
node = self._nodes.get(node_id)
|
||||||
if node:
|
if node:
|
||||||
self._set_node_error_sync(node, error_message)
|
self.set_node_error_sync(node, error_message)
|
||||||
nodes.append(node)
|
nodes.append(node)
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ class TreeQueueProcessor:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a single node and then check the queue."""
|
"""Process a single node and then check the queue."""
|
||||||
# Skip if already in terminal state (e.g. from error propagation)
|
# Skip if already in terminal state (e.g. from error propagation)
|
||||||
if node.state.value == MessageState.ERROR.value:
|
if node.state == MessageState.ERROR:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Skipping node {node.node_id} as it is already in state {node.state}"
|
f"Skipping node {node.node_id} as it is already in state {node.state}"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,7 @@ class TreeQueueManager:
|
||||||
MessageState.COMPLETED,
|
MessageState.COMPLETED,
|
||||||
MessageState.ERROR,
|
MessageState.ERROR,
|
||||||
):
|
):
|
||||||
tree._set_node_error_sync(node, "Cancelled by user")
|
tree.set_node_error_sync(node, "Cancelled by user")
|
||||||
cancelled_nodes.append(node)
|
cancelled_nodes.append(node)
|
||||||
|
|
||||||
# 2. Drain queue and mark nodes as cancelled
|
# 2. Drain queue and mark nodes as cancelled
|
||||||
|
|
@ -259,7 +259,7 @@ class TreeQueueManager:
|
||||||
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
|
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
|
||||||
and node.node_id not in cancelled_ids
|
and node.node_id not in cancelled_ids
|
||||||
):
|
):
|
||||||
tree._set_node_error_sync(node, "Stale task cleaned up")
|
tree.set_node_error_sync(node, "Stale task cleaned up")
|
||||||
cleanup_count += 1
|
cleanup_count += 1
|
||||||
|
|
||||||
tree.reset_processing_state()
|
tree.reset_processing_state()
|
||||||
|
|
@ -336,7 +336,7 @@ class TreeQueueManager:
|
||||||
for tree in self._repository.all_trees():
|
for tree in self._repository.all_trees():
|
||||||
for node in tree.all_nodes():
|
for node in tree.all_nodes():
|
||||||
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
|
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
|
||||||
tree._set_node_error_sync(node, "Lost during server restart")
|
tree.set_node_error_sync(node, "Lost during server restart")
|
||||||
count += 1
|
count += 1
|
||||||
if count:
|
if count:
|
||||||
logger.info(f"Cleaned up {count} stale nodes during startup")
|
logger.info(f"Cleaned up {count} stale nodes during startup")
|
||||||
|
|
|
||||||
|
|
@ -150,7 +150,11 @@ class TreeRepository:
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
|
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> set[str]:
|
||||||
"""Get all message IDs (incoming + status) for a given platform/chat."""
|
"""Get all message IDs (incoming + status) for a given platform/chat.
|
||||||
|
|
||||||
|
Note: O(total_nodes) scan. Acceptable because this is only called
|
||||||
|
from /clear (user-initiated, infrequent).
|
||||||
|
"""
|
||||||
msg_ids: set[str] = set()
|
msg_ids: set[str] = set()
|
||||||
for tree in self._trees.values():
|
for tree in self._trees.values():
|
||||||
for node in tree.all_nodes():
|
for node in tree.all_nodes():
|
||||||
|
|
|
||||||
|
|
@ -366,16 +366,19 @@ class SSEBuilder:
|
||||||
# Tool calls are harder to tokenize exactly without reconstruction, but we can approximate
|
# Tool calls are harder to tokenize exactly without reconstruction, but we can approximate
|
||||||
# by tokenizing the json dumps of tool contents
|
# by tokenizing the json dumps of tool contents
|
||||||
tool_tokens = 0
|
tool_tokens = 0
|
||||||
|
started_tool_count = 0
|
||||||
for state in self.blocks.tool_states.values():
|
for state in self.blocks.tool_states.values():
|
||||||
tool_tokens += len(ENCODER.encode(state.name))
|
tool_tokens += len(ENCODER.encode(state.name))
|
||||||
tool_tokens += len(ENCODER.encode("".join(state.contents)))
|
tool_tokens += len(ENCODER.encode("".join(state.contents)))
|
||||||
tool_tokens += 15 # Control tokens overhead per tool
|
tool_tokens += 15 # Control tokens overhead per tool
|
||||||
|
if state.started:
|
||||||
|
started_tool_count += 1
|
||||||
|
|
||||||
# Per-block overhead (~4 tokens per content block)
|
# Per-block overhead (~4 tokens per content block)
|
||||||
block_count = (
|
block_count = (
|
||||||
(1 if accumulated_reasoning else 0)
|
(1 if accumulated_reasoning else 0)
|
||||||
+ (1 if accumulated_text else 0)
|
+ (1 if accumulated_text else 0)
|
||||||
+ sum(1 for s in self.blocks.tool_states.values() if s.started)
|
+ started_tool_count
|
||||||
)
|
)
|
||||||
block_overhead = block_count * 4
|
block_overhead = block_count * 4
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,27 +16,29 @@ from providers.common.text import extract_text_from_content
|
||||||
def generate_request_fingerprint(messages: list[Any]) -> str:
|
def generate_request_fingerprint(messages: list[Any]) -> str:
|
||||||
"""Generate unique short hash for message content.
|
"""Generate unique short hash for message content.
|
||||||
|
|
||||||
Creates a SHA256 hash of all message content, returning an 8-char prefix
|
Uses incremental SHA256 hashing to avoid building a large intermediate
|
||||||
that's sufficient for correlation without full content logging.
|
string. Returns an 8-char hex prefix sufficient for correlation.
|
||||||
"""
|
"""
|
||||||
content_parts = []
|
h = hashlib.sha256()
|
||||||
|
sep = b"|"
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if hasattr(msg, "content"):
|
if hasattr(msg, "content"):
|
||||||
content = msg.content
|
content = msg.content
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
content_parts.append(content)
|
h.update(content.encode("utf-8"))
|
||||||
|
h.update(sep)
|
||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
for block in content:
|
for block in content:
|
||||||
if hasattr(block, "text"):
|
if hasattr(block, "text"):
|
||||||
content_parts.append(block.text)
|
h.update(block.text.encode("utf-8"))
|
||||||
|
h.update(sep)
|
||||||
elif hasattr(block, "type"):
|
elif hasattr(block, "type"):
|
||||||
content_parts.append(f"<{block.type}>")
|
h.update(f"<{block.type}>".encode())
|
||||||
|
h.update(sep)
|
||||||
elif hasattr(msg, "role"):
|
elif hasattr(msg, "role"):
|
||||||
content_parts.append(msg.role)
|
h.update(msg.role.encode("utf-8"))
|
||||||
|
h.update(sep)
|
||||||
combined = "|".join(content_parts)
|
return f"fp_{h.hexdigest()[:8]}"
|
||||||
hash_digest = hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
|
||||||
return f"fp_{hash_digest[:8]}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_user_message_preview(messages: list[Any], max_len: int = 100) -> str:
|
def get_last_user_message_preview(messages: list[Any], max_len: int = 100) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""OpenRouter provider - OpenAI-compatible API for hundreds of models."""
|
"""OpenRouter provider - OpenAI-compatible API for hundreds of models."""
|
||||||
|
|
||||||
from .client import OpenRouterProvider
|
from .client import OPENROUTER_BASE_URL, OpenRouterProvider
|
||||||
|
|
||||||
__all__ = ["OpenRouterProvider"]
|
__all__ = ["OPENROUTER_BASE_URL", "OpenRouterProvider"]
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from config.nim import NimSettings
|
|
||||||
from providers.base import BaseProvider, ProviderConfig
|
from providers.base import BaseProvider, ProviderConfig
|
||||||
from providers.common import (
|
from providers.common import (
|
||||||
ContentType,
|
ContentType,
|
||||||
|
|
@ -33,7 +32,7 @@ class OpenAICompatibleProvider(BaseProvider):
|
||||||
provider_name: str,
|
provider_name: str,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
nim_settings: NimSettings | None = None,
|
nim_settings: Any = None,
|
||||||
):
|
):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self._provider_name = provider_name
|
self._provider_name = provider_name
|
||||||
|
|
|
||||||
|
|
@ -30,13 +30,19 @@ class GlobalRateLimiter:
|
||||||
|
|
||||||
_instance: ClassVar[GlobalRateLimiter | None] = None
|
_instance: ClassVar[GlobalRateLimiter | None] = None
|
||||||
|
|
||||||
|
def __new__(cls, *args: Any, **kwargs: Any) -> GlobalRateLimiter:
|
||||||
|
if cls._instance is not None:
|
||||||
|
return cls._instance
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
return instance
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
rate_limit: int = 40,
|
rate_limit: int = 40,
|
||||||
rate_window: float = 60.0,
|
rate_window: float = 60.0,
|
||||||
max_concurrency: int = 5,
|
max_concurrency: int = 5,
|
||||||
):
|
):
|
||||||
# Prevent double initialization in singleton
|
# Prevent re-initialization on singleton reuse
|
||||||
if hasattr(self, "_initialized"):
|
if hasattr(self, "_initialized"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -247,7 +247,7 @@ async def test_update_queue_positions(handler, mock_platform):
|
||||||
await tree.enqueue("child_1")
|
await tree.enqueue("child_1")
|
||||||
await tree.enqueue("child_2")
|
await tree.enqueue("child_2")
|
||||||
|
|
||||||
await handler._update_queue_positions(tree)
|
await handler.update_queue_positions(tree)
|
||||||
|
|
||||||
calls = mock_platform.queue_edit_message.call_args_list
|
calls = mock_platform.queue_edit_message.call_args_list
|
||||||
assert len(calls) == 2
|
assert len(calls) == 2
|
||||||
|
|
@ -291,7 +291,7 @@ async def test_mark_node_processing(handler, mock_platform):
|
||||||
parent_id="root",
|
parent_id="root",
|
||||||
)
|
)
|
||||||
|
|
||||||
await handler._mark_node_processing(tree, "child")
|
await handler.mark_node_processing(tree, "child")
|
||||||
|
|
||||||
mock_platform.queue_edit_message.assert_called_once()
|
mock_platform.queue_edit_message.assert_called_once()
|
||||||
args, kwargs = mock_platform.queue_edit_message.call_args
|
args, kwargs = mock_platform.queue_edit_message.call_args
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi
|
||||||
# Snapshot error is swallowed.
|
# Snapshot error is swallowed.
|
||||||
tree = MagicMock()
|
tree = MagicMock()
|
||||||
tree.get_queue_snapshot = AsyncMock(side_effect=RuntimeError("boom"))
|
tree.get_queue_snapshot = AsyncMock(side_effect=RuntimeError("boom"))
|
||||||
await handler._update_queue_positions(tree)
|
await handler.update_queue_positions(tree)
|
||||||
platform.fire_and_forget.assert_not_called()
|
platform.fire_and_forget.assert_not_called()
|
||||||
|
|
||||||
# Normal path: only PENDING nodes get an update.
|
# Normal path: only PENDING nodes get an update.
|
||||||
|
|
@ -130,7 +130,7 @@ async def test_update_queue_positions_handles_snapshot_error_and_skips_non_pendi
|
||||||
tree.get_queue_snapshot = AsyncMock(return_value=["n1", "n2"])
|
tree.get_queue_snapshot = AsyncMock(return_value=["n1", "n2"])
|
||||||
tree.get_node = MagicMock(side_effect=[node_pending, node_done])
|
tree.get_node = MagicMock(side_effect=[node_pending, node_done])
|
||||||
|
|
||||||
await handler._update_queue_positions(tree)
|
await handler.update_queue_positions(tree)
|
||||||
assert platform.fire_and_forget.call_count == 1
|
assert platform.fire_and_forget.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,8 @@ async def test_reply_to_old_status_message_after_restore_routes_to_parent(
|
||||||
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
||||||
handler2.tree_queue = TreeQueueManager.from_dict(
|
handler2.tree_queue = TreeQueueManager.from_dict(
|
||||||
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
||||||
queue_update_callback=handler2._update_queue_positions,
|
queue_update_callback=handler2.update_queue_positions,
|
||||||
node_started_callback=handler2._mark_node_processing,
|
node_started_callback=handler2.mark_node_processing,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prevent background task scheduling; we only want to validate routing/tree mutation.
|
# Prevent background task scheduling; we only want to validate routing/tree mutation.
|
||||||
|
|
@ -89,8 +89,8 @@ async def test_reply_to_old_status_message_without_mapping_creates_new_conversat
|
||||||
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
||||||
handler2.tree_queue = TreeQueueManager.from_dict(
|
handler2.tree_queue = TreeQueueManager.from_dict(
|
||||||
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
||||||
queue_update_callback=handler2._update_queue_positions,
|
queue_update_callback=handler2.update_queue_positions,
|
||||||
node_started_callback=handler2._mark_node_processing,
|
node_started_callback=handler2.mark_node_processing,
|
||||||
)
|
)
|
||||||
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")
|
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue