mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
Refactor done by z-ai/glm5
This commit is contained in:
parent
754ca99314
commit
64e5b10612
11 changed files with 157 additions and 533 deletions
|
|
@ -1,6 +1,6 @@
|
|||
"""Request utility functions for API route handlers.
|
||||
|
||||
Contains token counting and re-exports detection/command utilities.
|
||||
Contains token counting for API requests.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@ -9,29 +9,10 @@ from typing import List, Optional, Union
|
|||
|
||||
import tiktoken
|
||||
|
||||
from .models.anthropic import MessagesRequest
|
||||
from .detection import (
|
||||
is_quota_check_request,
|
||||
is_title_generation_request,
|
||||
is_prefix_detection_request,
|
||||
is_suggestion_mode_request,
|
||||
is_filepath_extraction_request,
|
||||
)
|
||||
from .command_utils import extract_command_prefix, extract_filepaths_from_command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ENCODER = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
__all__ = [
|
||||
"is_quota_check_request",
|
||||
"is_title_generation_request",
|
||||
"is_prefix_detection_request",
|
||||
"is_suggestion_mode_request",
|
||||
"is_filepath_extraction_request",
|
||||
"extract_command_prefix",
|
||||
"extract_filepaths_from_command",
|
||||
"get_token_count",
|
||||
]
|
||||
__all__ = ["get_token_count"]
|
||||
|
||||
|
||||
def get_token_count(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import time
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from .base import MessagingPlatform, SessionManagerInterface
|
||||
from .models import IncomingMessage
|
||||
|
|
@ -29,6 +29,47 @@ from .telegram_markdown import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Status message prefixes used to filter our own messages (ignore echo)
|
||||
STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄")
|
||||
|
||||
# Event types that update the transcript
|
||||
TRANSCRIPT_EVENT_TYPES = (
|
||||
"thinking_start",
|
||||
"thinking_delta",
|
||||
"thinking_chunk",
|
||||
"thinking_stop",
|
||||
"text_start",
|
||||
"text_delta",
|
||||
"text_chunk",
|
||||
"text_stop",
|
||||
"tool_use_start",
|
||||
"tool_use_delta",
|
||||
"tool_use_stop",
|
||||
"tool_use",
|
||||
"tool_result",
|
||||
"block_stop",
|
||||
"error",
|
||||
)
|
||||
|
||||
# Event types -> (emoji, label) for status updates
|
||||
_EVENT_STATUS_MAP = {
|
||||
("thinking_start", "thinking_delta", "thinking_chunk"): ("🧠", "Claude is thinking..."),
|
||||
("text_start", "text_delta", "text_chunk"): ("🧠", "Claude is working..."),
|
||||
("tool_result",): ("⏳", "Executing tools..."),
|
||||
}
|
||||
|
||||
|
||||
def _get_status_for_event(ptype: str, parsed: dict) -> Optional[str]:
|
||||
"""Return status string for event type, or None if no status update needed."""
|
||||
for types, (emoji, label) in _EVENT_STATUS_MAP.items():
|
||||
if ptype in types:
|
||||
return format_status(emoji, label)
|
||||
if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
|
||||
if parsed.get("name") == "Task":
|
||||
return format_status("🤖", "Subagent working...")
|
||||
return format_status("⏳", "Executing tools...")
|
||||
return None
|
||||
|
||||
|
||||
class ClaudeMessageHandler:
|
||||
"""
|
||||
|
|
@ -95,10 +136,7 @@ class ClaudeMessageHandler:
|
|||
return
|
||||
|
||||
# Filter out status messages (our own messages)
|
||||
if any(
|
||||
incoming.text.startswith(p)
|
||||
for p in ["⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄"]
|
||||
):
|
||||
if any(incoming.text.startswith(p) for p in STATUS_MESSAGE_PREFIXES):
|
||||
return
|
||||
|
||||
# Check if this is a reply to an existing node in a tree
|
||||
|
|
@ -131,17 +169,9 @@ class ClaudeMessageHandler:
|
|||
reply_to=incoming.message_id,
|
||||
fire_and_forget=False,
|
||||
)
|
||||
try:
|
||||
if status_msg_id:
|
||||
self.session_store.record_message_id(
|
||||
incoming.platform,
|
||||
incoming.chat_id,
|
||||
str(status_msg_id),
|
||||
direction="out",
|
||||
kind="status",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to record status message_id: {e}")
|
||||
self._record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, status_msg_id, "status"
|
||||
)
|
||||
|
||||
# Create or extend tree
|
||||
if parent_node_id and tree and status_msg_id:
|
||||
|
|
@ -285,45 +315,16 @@ class ClaudeMessageHandler:
|
|||
captured_session_id: Optional[str],
|
||||
) -> Tuple[Optional[str], bool]:
|
||||
"""Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
|
||||
ptype = parsed.get("type")
|
||||
transcript_types = (
|
||||
"thinking_start",
|
||||
"thinking_delta",
|
||||
"thinking_chunk",
|
||||
"thinking_stop",
|
||||
"text_start",
|
||||
"text_delta",
|
||||
"text_chunk",
|
||||
"text_stop",
|
||||
"tool_use_start",
|
||||
"tool_use_delta",
|
||||
"tool_use_stop",
|
||||
"tool_use",
|
||||
"tool_result",
|
||||
"block_stop",
|
||||
"error",
|
||||
)
|
||||
ptype = parsed.get("type") or ""
|
||||
|
||||
if ptype in transcript_types:
|
||||
if ptype in TRANSCRIPT_EVENT_TYPES:
|
||||
transcript.apply(parsed)
|
||||
had_transcript_events = True
|
||||
|
||||
if ptype in ("thinking_start", "thinking_delta", "thinking_chunk"):
|
||||
await update_ui(format_status("🧠", "Claude is thinking..."))
|
||||
last_status = format_status("🧠", "Claude is thinking...")
|
||||
elif ptype in ("text_start", "text_delta", "text_chunk"):
|
||||
await update_ui(format_status("🧠", "Claude is working..."))
|
||||
last_status = format_status("🧠", "Claude is working...")
|
||||
elif ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
|
||||
if parsed.get("name") == "Task":
|
||||
await update_ui(format_status("🤖", "Subagent working..."))
|
||||
last_status = format_status("🤖", "Subagent working...")
|
||||
else:
|
||||
await update_ui(format_status("⏳", "Executing tools..."))
|
||||
last_status = format_status("⏳", "Executing tools...")
|
||||
elif ptype == "tool_result":
|
||||
await update_ui(format_status("⏳", "Executing tools..."))
|
||||
last_status = format_status("⏳", "Executing tools...")
|
||||
status = _get_status_for_event(ptype, parsed)
|
||||
if status is not None:
|
||||
await update_ui(status)
|
||||
last_status = status
|
||||
elif ptype == "block_stop":
|
||||
await update_ui(last_status, force=True)
|
||||
elif ptype == "complete":
|
||||
|
|
@ -584,20 +585,7 @@ class ClaudeMessageHandler:
|
|||
await self.cli_manager.stop_all()
|
||||
|
||||
# 3. Update UI and persist state for all cancelled nodes
|
||||
for node in cancelled_nodes:
|
||||
self.platform.fire_and_forget(
|
||||
self.platform.queue_edit_message(
|
||||
node.incoming.chat_id,
|
||||
node.status_message_id,
|
||||
format_status("⏹", "Stopped."),
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
)
|
||||
|
||||
# Persist tree state
|
||||
tree = self.tree_queue.get_tree_for_node(node.node_id)
|
||||
if tree:
|
||||
self.session_store.save_tree(tree.root_id, tree.to_dict())
|
||||
self._update_cancelled_nodes_ui(cancelled_nodes)
|
||||
|
||||
return len(cancelled_nodes)
|
||||
|
||||
|
|
@ -615,8 +603,29 @@ class ClaudeMessageHandler:
|
|||
node.context = {"cancel_reason": "stop"}
|
||||
|
||||
cancelled_nodes = await self.tree_queue.cancel_node(node_id)
|
||||
self._update_cancelled_nodes_ui(cancelled_nodes)
|
||||
return len(cancelled_nodes)
|
||||
|
||||
for node in cancelled_nodes:
|
||||
def _record_outgoing_message(
|
||||
self,
|
||||
platform: str,
|
||||
chat_id: str,
|
||||
msg_id: Optional[str],
|
||||
kind: str,
|
||||
) -> None:
|
||||
"""Record outgoing message ID for /clear. Best-effort, never raises."""
|
||||
if not msg_id:
|
||||
return
|
||||
try:
|
||||
self.session_store.record_message_id(
|
||||
platform, chat_id, str(msg_id), direction="out", kind=kind
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to record message_id: {e}")
|
||||
|
||||
def _update_cancelled_nodes_ui(self, nodes: List[MessageNode]) -> None:
|
||||
"""Update status messages and persist tree state for cancelled nodes."""
|
||||
for node in nodes:
|
||||
self.platform.fire_and_forget(
|
||||
self.platform.queue_edit_message(
|
||||
node.incoming.chat_id,
|
||||
|
|
@ -625,13 +634,10 @@ class ClaudeMessageHandler:
|
|||
parse_mode="MarkdownV2",
|
||||
)
|
||||
)
|
||||
|
||||
tree = self.tree_queue.get_tree_for_node(node.node_id)
|
||||
if tree:
|
||||
self.session_store.save_tree(tree.root_id, tree.to_dict())
|
||||
|
||||
return len(cancelled_nodes)
|
||||
|
||||
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:
|
||||
"""Handle /stop command from messaging platform."""
|
||||
# Reply-scoped stop: reply "/stop" to stop only that task.
|
||||
|
|
@ -646,17 +652,9 @@ class ClaudeMessageHandler:
|
|||
format_status("⏹", "Stopped.", "Nothing to stop for that message."),
|
||||
fire_and_forget=False,
|
||||
)
|
||||
try:
|
||||
if msg_id:
|
||||
self.session_store.record_message_id(
|
||||
incoming.platform,
|
||||
incoming.chat_id,
|
||||
str(msg_id),
|
||||
direction="out",
|
||||
kind="command",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
return
|
||||
|
||||
count = await self.stop_task(node_id)
|
||||
|
|
@ -666,17 +664,9 @@ class ClaudeMessageHandler:
|
|||
format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."),
|
||||
fire_and_forget=False,
|
||||
)
|
||||
try:
|
||||
if msg_id:
|
||||
self.session_store.record_message_id(
|
||||
incoming.platform,
|
||||
incoming.chat_id,
|
||||
str(msg_id),
|
||||
direction="out",
|
||||
kind="command",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
return
|
||||
|
||||
# Global stop: legacy behavior (stop everything)
|
||||
|
|
@ -688,17 +678,9 @@ class ClaudeMessageHandler:
|
|||
),
|
||||
fire_and_forget=False,
|
||||
)
|
||||
try:
|
||||
if msg_id:
|
||||
self.session_store.record_message_id(
|
||||
incoming.platform,
|
||||
incoming.chat_id,
|
||||
str(msg_id),
|
||||
direction="out",
|
||||
kind="command",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
|
||||
async def _handle_stats_command(self, incoming: IncomingMessage) -> None:
|
||||
"""Handle /stats command."""
|
||||
|
|
@ -716,17 +698,9 @@ class ClaudeMessageHandler:
|
|||
+ escape_md_v2(f"• Message Trees: {tree_count}"),
|
||||
fire_and_forget=False,
|
||||
)
|
||||
try:
|
||||
if msg_id:
|
||||
self.session_store.record_message_id(
|
||||
incoming.platform,
|
||||
incoming.chat_id,
|
||||
str(msg_id),
|
||||
direction="out",
|
||||
kind="command",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
|
||||
async def _handle_clear_command(self, incoming: IncomingMessage) -> None:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -205,14 +205,6 @@ class SessionStore:
|
|||
if x.get("message_id") is not None
|
||||
]
|
||||
|
||||
def clear_message_log_for_chat(self, platform: str, chat_id: str) -> None:
|
||||
"""Clear recorded message IDs for a single chat."""
|
||||
chat_key = self._make_chat_key(str(platform), str(chat_id))
|
||||
with self._lock:
|
||||
self._message_log.pop(chat_key, None)
|
||||
self._message_log_ids.pop(chat_key, None)
|
||||
self._save()
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Clear all stored sessions/trees/mappings and persist an empty store."""
|
||||
with self._lock:
|
||||
|
|
@ -224,124 +216,6 @@ class SessionStore:
|
|||
self._message_log_ids.clear()
|
||||
self._save()
|
||||
|
||||
# ==================== Session Methods ====================
|
||||
|
||||
def save_session(
|
||||
self,
|
||||
session_id: str,
|
||||
chat_id: str,
|
||||
initial_msg_id: str,
|
||||
platform: str = "telegram",
|
||||
) -> None:
|
||||
"""Save a new session mapping."""
|
||||
with self._lock:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
record = SessionRecord(
|
||||
session_id=session_id,
|
||||
chat_id=str(chat_id),
|
||||
initial_msg_id=str(initial_msg_id),
|
||||
last_msg_id=str(initial_msg_id),
|
||||
platform=platform,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self._sessions[session_id] = record
|
||||
self._msg_to_session[
|
||||
self._make_key(platform, str(chat_id), str(initial_msg_id))
|
||||
] = session_id
|
||||
self._save()
|
||||
logger.info(
|
||||
f"Saved session {session_id} for {platform} chat {chat_id}, msg {initial_msg_id}"
|
||||
)
|
||||
|
||||
def get_session_by_msg(
|
||||
self, chat_id: str, msg_id: str, platform: str = "telegram"
|
||||
) -> Optional[str]:
|
||||
"""Look up a session ID by a message that's part of that session."""
|
||||
with self._lock:
|
||||
key = self._make_key(platform, str(chat_id), str(msg_id))
|
||||
return self._msg_to_session.get(key)
|
||||
|
||||
def update_last_message(self, session_id: str, msg_id: str) -> None:
|
||||
"""Update the last message ID for a session."""
|
||||
with self._lock:
|
||||
if session_id not in self._sessions:
|
||||
logger.warning(f"Session {session_id} not found for update")
|
||||
return
|
||||
|
||||
record = self._sessions[session_id]
|
||||
record.last_msg_id = str(msg_id)
|
||||
record.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
new_key = self._make_key(record.platform, record.chat_id, str(msg_id))
|
||||
self._msg_to_session[new_key] = session_id
|
||||
self._save()
|
||||
logger.debug(f"Updated session {session_id} last_msg to {msg_id}")
|
||||
|
||||
def rename_session(self, old_id: str, new_id: str) -> bool:
|
||||
"""Rename a session ID, migrating all message mappings."""
|
||||
with self._lock:
|
||||
if old_id not in self._sessions:
|
||||
logger.warning(f"Session {old_id} not found for rename to {new_id}")
|
||||
return False
|
||||
|
||||
record = self._sessions.pop(old_id)
|
||||
record.session_id = new_id
|
||||
record.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
self._sessions[new_id] = record
|
||||
|
||||
items_to_update = [
|
||||
k for k, v in self._msg_to_session.items() if v == old_id
|
||||
]
|
||||
for key in items_to_update:
|
||||
self._msg_to_session[key] = new_id
|
||||
|
||||
self._save()
|
||||
logger.info(
|
||||
f"Renamed session {old_id} to {new_id} ({len(items_to_update)} mappings updated)"
|
||||
)
|
||||
return True
|
||||
|
||||
def get_session_record(self, session_id: str) -> Optional[SessionRecord]:
|
||||
"""Get full session record."""
|
||||
with self._lock:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def cleanup_old_sessions(self, max_age_days: int = 30) -> int:
|
||||
"""Remove sessions older than max_age_days."""
|
||||
with self._lock:
|
||||
cutoff = datetime.now(timezone.utc)
|
||||
removed = 0
|
||||
|
||||
to_remove = []
|
||||
for sid, record in self._sessions.items():
|
||||
try:
|
||||
created = datetime.fromisoformat(record.created_at)
|
||||
age_days = (cutoff - created).days
|
||||
if age_days > max_age_days:
|
||||
to_remove.append(sid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for sid in to_remove:
|
||||
record = self._sessions.pop(sid)
|
||||
self._msg_to_session.pop(
|
||||
self._make_key(
|
||||
record.platform, record.chat_id, record.initial_msg_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
self._msg_to_session.pop(
|
||||
self._make_key(record.platform, record.chat_id, record.last_msg_id),
|
||||
None,
|
||||
)
|
||||
removed += 1
|
||||
|
||||
if removed:
|
||||
self._save()
|
||||
logger.info(f"Cleaned up {removed} old sessions")
|
||||
|
||||
return removed
|
||||
|
||||
# ==================== Tree Methods ====================
|
||||
|
||||
def save_tree(self, root_id: str, tree_data: dict) -> None:
|
||||
|
|
@ -367,14 +241,6 @@ class SessionStore:
|
|||
with self._lock:
|
||||
return self._trees.get(root_id)
|
||||
|
||||
def get_tree_by_node(self, node_id: str) -> Optional[dict]:
|
||||
"""Get the tree containing a node."""
|
||||
with self._lock:
|
||||
root_id = self._node_to_tree.get(node_id)
|
||||
if not root_id:
|
||||
return None
|
||||
return self._trees.get(root_id)
|
||||
|
||||
def get_tree_root_for_node(self, node_id: str) -> Optional[str]:
|
||||
"""Get the root ID of the tree containing a node."""
|
||||
with self._lock:
|
||||
|
|
@ -386,20 +252,6 @@ class SessionStore:
|
|||
self._node_to_tree[node_id] = root_id
|
||||
self._save()
|
||||
|
||||
def update_tree_node(self, root_id: str, node_id: str, node_data: dict) -> None:
|
||||
"""Update a specific node in a tree."""
|
||||
with self._lock:
|
||||
if root_id not in self._trees:
|
||||
logger.warning(f"Tree {root_id} not found")
|
||||
return
|
||||
|
||||
if "nodes" not in self._trees[root_id]:
|
||||
self._trees[root_id]["nodes"] = {}
|
||||
|
||||
self._trees[root_id]["nodes"][node_id] = node_data
|
||||
self._node_to_tree[node_id] = root_id
|
||||
self._save()
|
||||
|
||||
def get_all_trees(self) -> Dict[str, dict]:
|
||||
"""Get all stored trees (public accessor)."""
|
||||
with self._lock:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ Contains MessageState, MessageNode, and MessageTree classes.
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional, List, Any
|
||||
|
|
@ -276,6 +278,42 @@ class MessageTree:
|
|||
"""Get number of messages waiting in queue."""
|
||||
return self._queue.qsize()
|
||||
|
||||
def remove_from_queue(self, node_id: str) -> bool:
|
||||
"""
|
||||
Remove node_id from the internal queue if present.
|
||||
|
||||
Caller must hold the tree lock (e.g. via with_lock).
|
||||
Returns True if node was removed, False if not in queue.
|
||||
"""
|
||||
queue_deque: deque = self._queue._queue # type: ignore[attr-defined]
|
||||
if node_id not in queue_deque:
|
||||
return False
|
||||
self._queue._queue = deque(x for x in queue_deque if x != node_id) # type: ignore[attr-defined]
|
||||
return True
|
||||
|
||||
@asynccontextmanager
|
||||
async def with_lock(self):
|
||||
"""Async context manager for tree lock. Use when multiple operations need atomicity."""
|
||||
async with self._lock:
|
||||
yield
|
||||
|
||||
def set_processing_state(self, node_id: Optional[str], is_processing: bool) -> None:
|
||||
"""Set processing state. Caller must hold lock for consistency with queue operations."""
|
||||
self._is_processing = is_processing
|
||||
self._current_node_id = node_id if is_processing else None
|
||||
|
||||
def clear_current_node(self) -> None:
|
||||
"""Clear the currently processing node ID. Caller must hold lock."""
|
||||
self._current_node_id = None
|
||||
|
||||
def is_current_node(self, node_id: str) -> bool:
|
||||
"""Check if node_id is the currently processing node."""
|
||||
return self._current_node_id == node_id
|
||||
|
||||
def put_queue_unlocked(self, node_id: str) -> None:
|
||||
"""Add node to queue. Caller must hold lock (e.g. via with_lock)."""
|
||||
self._queue.put_nowait(node_id)
|
||||
|
||||
def cancel_current_task(self) -> bool:
|
||||
"""Cancel the currently running task. Returns True if a task was cancelled."""
|
||||
if self._current_task and not self._current_task.done():
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class TreeQueueProcessor:
|
|||
node.node_id, MessageState.ERROR, error_message=str(e)
|
||||
)
|
||||
finally:
|
||||
tree._current_node_id = None
|
||||
tree.clear_current_node()
|
||||
# Check if there are more messages in the queue
|
||||
await self._process_next(tree, processor)
|
||||
|
||||
|
|
@ -102,16 +102,15 @@ class TreeQueueProcessor:
|
|||
"""Process the next message in queue, if any."""
|
||||
next_node_id = None
|
||||
node = None
|
||||
async with tree._lock:
|
||||
async with tree.with_lock():
|
||||
next_node_id = await tree.dequeue()
|
||||
|
||||
if not next_node_id:
|
||||
# No more messages, mark tree as free
|
||||
tree._is_processing = False
|
||||
tree.set_processing_state(None, False)
|
||||
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
|
||||
return
|
||||
|
||||
tree._current_node_id = next_node_id
|
||||
tree.set_processing_state(next_node_id, True)
|
||||
logger.info(f"Processing next queued node {next_node_id}")
|
||||
|
||||
# Process next node (outside lock)
|
||||
|
|
@ -143,17 +142,14 @@ class TreeQueueProcessor:
|
|||
Returns:
|
||||
True if queued, False if processing immediately
|
||||
"""
|
||||
async with tree._lock:
|
||||
if tree._is_processing:
|
||||
# Tree is busy, queue the message
|
||||
await tree._queue.put(node_id)
|
||||
queue_size = tree._queue.qsize()
|
||||
async with tree.with_lock():
|
||||
if tree.is_processing:
|
||||
tree.put_queue_unlocked(node_id)
|
||||
queue_size = tree.get_queue_size()
|
||||
logger.info(f"Queued node {node_id}, position {queue_size}")
|
||||
return True
|
||||
else:
|
||||
# Tree is free, start processing
|
||||
tree._is_processing = True
|
||||
tree._current_node_id = node_id
|
||||
tree.set_processing_state(node_id, True)
|
||||
|
||||
# Process outside the lock
|
||||
node = tree.get_node(node_id)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ Uses TreeRepository for data, TreeQueueProcessor for async logic.
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Awaitable, List, Optional
|
||||
|
||||
|
|
@ -293,7 +292,7 @@ class TreeQueueManager:
|
|||
if not tree:
|
||||
return []
|
||||
|
||||
async with tree._lock:
|
||||
async with tree.with_lock():
|
||||
node = tree.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
|
|
@ -301,18 +300,12 @@ class TreeQueueManager:
|
|||
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
|
||||
return []
|
||||
|
||||
# Cancel running task if this is the current node.
|
||||
if tree._current_node_id == node_id:
|
||||
if tree.is_current_node(node_id):
|
||||
self._processor.cancel_current(tree)
|
||||
|
||||
# Remove from queue if present (asyncio.Queue exposes its internal deque).
|
||||
try:
|
||||
q = tree._queue._queue # type: ignore[attr-defined]
|
||||
if q and node_id in q:
|
||||
tree._queue._queue = deque(x for x in q if x != node_id) # type: ignore[attr-defined]
|
||||
tree.remove_from_queue(node_id)
|
||||
except Exception:
|
||||
# Best-effort: if we can't mutate the queue internals, the node will
|
||||
# still be dequeued later and skipped due to state=ERROR.
|
||||
logger.debug(
|
||||
"Failed to remove node from queue; will rely on state=ERROR"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -65,145 +65,6 @@ class TestSessionStore:
|
|||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
assert store._sessions == {}
|
||||
|
||||
def test_save_and_get_session(self, tmp_path):
|
||||
"""Test saving and retrieving a session."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
store.save_session(
|
||||
session_id="sess_123",
|
||||
chat_id="chat_456",
|
||||
initial_msg_id="msg_789",
|
||||
platform="telegram",
|
||||
)
|
||||
|
||||
# Retrieve by message
|
||||
found = store.get_session_by_msg("chat_456", "msg_789", "telegram")
|
||||
assert found == "sess_123"
|
||||
|
||||
# Verify persistence file created
|
||||
assert os.path.exists(str(tmp_path / "sessions.json"))
|
||||
|
||||
def test_update_last_message(self, tmp_path):
|
||||
"""Test updating last message in session."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
store.save_session("sess_1", "chat_1", "msg_1", "telegram")
|
||||
store.update_last_message("sess_1", "msg_2")
|
||||
|
||||
# Should find session by new message too
|
||||
found = store.get_session_by_msg("chat_1", "msg_2", "telegram")
|
||||
assert found == "sess_1"
|
||||
|
||||
# Original message mapping should still work
|
||||
found_old = store.get_session_by_msg("chat_1", "msg_1", "telegram")
|
||||
assert found_old == "sess_1"
|
||||
|
||||
# Verify record updated
|
||||
record = store.get_session_record("sess_1")
|
||||
assert record is not None
|
||||
assert record.last_msg_id == "msg_2"
|
||||
|
||||
def test_update_last_message_unknown_session(self, tmp_path):
|
||||
"""Test updating unknown session does nothing."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
store.update_last_message("unknown", "msg_x")
|
||||
# Should log warning but not crash
|
||||
|
||||
def test_get_session_record(self, tmp_path):
|
||||
"""Test getting full session record."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
store.save_session("sess_1", "chat_1", "msg_1", "telegram")
|
||||
|
||||
record = store.get_session_record("sess_1")
|
||||
assert record is not None
|
||||
assert record.session_id == "sess_1"
|
||||
assert record.platform == "telegram"
|
||||
|
||||
def test_session_not_found(self, tmp_path):
|
||||
"""Test getting non-existent session returns None."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
found = store.get_session_by_msg("notexist", "notexist", "telegram")
|
||||
assert found is None
|
||||
|
||||
def test_rename_session(self, tmp_path):
|
||||
"""Test renaming a session."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
store.save_session("old_id", "c1", "m1", "telegram")
|
||||
store.update_last_message("old_id", "m2")
|
||||
|
||||
success = store.rename_session("old_id", "new_id")
|
||||
assert success is True
|
||||
|
||||
# Verify old id gone
|
||||
assert store.get_session_record("old_id") is None
|
||||
|
||||
# Verify new id exists
|
||||
rec = store.get_session_record("new_id")
|
||||
assert rec is not None
|
||||
assert rec.session_id == "new_id"
|
||||
|
||||
# Verify mappings point to new id
|
||||
assert store.get_session_by_msg("c1", "m1", "telegram") == "new_id"
|
||||
assert store.get_session_by_msg("c1", "m2", "telegram") == "new_id"
|
||||
|
||||
def test_rename_unknown_session(self, tmp_path):
|
||||
"""Test renaming unknown session fails."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
success = store.rename_session("unknown", "new")
|
||||
assert success is False
|
||||
|
||||
def test_cleanup_old_sessions(self, tmp_path):
|
||||
"""Test cleaning up expired sessions."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
# Create an old session manually
|
||||
old_date = (datetime.now(timezone.utc) - timedelta(days=40)).isoformat()
|
||||
store.save_session("old_sess", "c_old", "m_old")
|
||||
# Manipulate the created_at directly
|
||||
store._sessions["old_sess"].created_at = old_date
|
||||
|
||||
# Create a new session
|
||||
store.save_session("new_sess", "c_new", "m_new")
|
||||
|
||||
# Cleanup
|
||||
removed = store.cleanup_old_sessions(max_age_days=30)
|
||||
assert removed == 1
|
||||
|
||||
assert store.get_session_record("old_sess") is None
|
||||
assert store.get_session_by_msg("c_old", "m_old") is None
|
||||
assert store.get_session_record("new_sess") is not None
|
||||
|
||||
def test_cleanup_old_sessions_invalid_date(self, tmp_path):
|
||||
"""Test cleanup handles invalid date formats gracefully."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
store.save_session("bad_date_sess", "c", "m")
|
||||
store._sessions["bad_date_sess"].created_at = "not-a-date"
|
||||
|
||||
# Should not crash
|
||||
store.cleanup_old_sessions(30)
|
||||
# Should still exist because parsing failed so it wasn't removed (or default behavior)
|
||||
# The code tries parsing, excepts, and continues, so it isn't removed.
|
||||
assert store.get_session_record("bad_date_sess") is not None
|
||||
|
||||
# --- Tree Tests ---
|
||||
|
||||
def test_save_and_get_tree(self, tmp_path):
|
||||
|
|
@ -225,34 +86,6 @@ class TestSessionStore:
|
|||
assert store.get_tree_root_for_node("r1") == "r1"
|
||||
assert store.get_tree_root_for_node("n1") == "r1"
|
||||
|
||||
# Verify get_tree_by_node
|
||||
assert store.get_tree_by_node("n1") == tree_data
|
||||
assert store.get_tree_by_node("unknown") is None
|
||||
|
||||
def test_update_tree_node(self, tmp_path):
|
||||
"""Test updating a specific node in a tree."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
store.save_tree("r1", {"nodes": {"r1": {}}})
|
||||
|
||||
# Add new node
|
||||
store.update_tree_node("r1", "n2", {"data": "test"})
|
||||
|
||||
tree = store.get_tree("r1")
|
||||
assert tree is not None
|
||||
assert "n2" in tree["nodes"]
|
||||
assert tree["nodes"]["n2"]["data"] == "test"
|
||||
assert store.get_tree_root_for_node("n2") == "r1"
|
||||
|
||||
def test_update_tree_node_unknown_tree(self, tmp_path):
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
store.update_tree_node("unknown_root", "n1", {})
|
||||
# Should not crash
|
||||
|
||||
def test_register_node(self, tmp_path):
|
||||
"""Test manual node registration."""
|
||||
from messaging.session import SessionStore
|
||||
|
|
@ -296,7 +129,7 @@ class TestSessionStore:
|
|||
# --- Persistence & Edge Cases ---
|
||||
|
||||
def test_load_existing_legacy_format(self, tmp_path):
|
||||
"""Test loading legacy session format (int IDs)."""
|
||||
"""Test loading legacy session format (int IDs) - backward compat."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
data = {
|
||||
|
|
@ -318,12 +151,11 @@ class TestSessionStore:
|
|||
json.dump(data, f)
|
||||
|
||||
store = SessionStore(storage_path=str(p))
|
||||
rec = store.get_session_record("s1")
|
||||
assert rec is not None
|
||||
|
||||
# Legacy sessions are loaded for backward compat; verify conversion
|
||||
assert "s1" in store._sessions
|
||||
rec = store._sessions["s1"]
|
||||
assert rec.chat_id == "123" # Converted to str
|
||||
assert rec.platform == "telegram" # Defaulted
|
||||
assert store.get_session_by_msg("123", "100", "telegram") == "s1"
|
||||
|
||||
def test_load_corrupt_file(self, tmp_path):
|
||||
"""Test loading corrupt/invalid json file."""
|
||||
|
|
@ -342,13 +174,14 @@ class TestSessionStore:
|
|||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
|
||||
|
||||
# Mock open to raise exception
|
||||
with patch("builtins.open", side_effect=IOError("Disk full")):
|
||||
store.save_session("s1", "c1", "m1")
|
||||
store.save_tree("r2", {"root_id": "r2", "nodes": {"r2": {}}})
|
||||
|
||||
# Should log error but not crash. Session should be in memory though.
|
||||
assert "s1" in store._sessions
|
||||
# Should log error but not crash. Tree should be in memory.
|
||||
assert "r2" in store._trees
|
||||
|
||||
|
||||
class TestTreeQueueManager:
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from api.request_utils import (
|
||||
from api.detection import (
|
||||
is_quota_check_request,
|
||||
is_title_generation_request,
|
||||
extract_command_prefix,
|
||||
is_prefix_detection_request,
|
||||
get_token_count,
|
||||
)
|
||||
from api.command_utils import extract_command_prefix
|
||||
from api.request_utils import get_token_count
|
||||
from api.models.anthropic import MessagesRequest, Message
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import pytest
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from api.models.anthropic import MessagesRequest, Message
|
||||
from api.request_utils import (
|
||||
from api.detection import (
|
||||
is_suggestion_mode_request,
|
||||
is_filepath_extraction_request,
|
||||
extract_filepaths_from_command,
|
||||
)
|
||||
from api.command_utils import extract_filepaths_from_command
|
||||
|
||||
|
||||
def _mk_req(messages, tools=None):
|
||||
|
|
|
|||
|
|
@ -85,44 +85,17 @@ class TestSessionStoreSaveEdgeCases:
|
|||
|
||||
def test_save_io_error_handled(self, tmp_store):
|
||||
"""Write failure in _save() is logged but doesn't raise."""
|
||||
tmp_store.save_session("s1", "c1", "m1")
|
||||
# Make the path read-only dir to trigger write error
|
||||
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
|
||||
with patch("builtins.open", side_effect=IOError("disk full")):
|
||||
# _save() catches all exceptions and logs them
|
||||
tmp_store._save()
|
||||
# Should not raise
|
||||
|
||||
|
||||
class TestSessionStoreOperationEdgeCases:
|
||||
"""Tests for edge cases in session operations."""
|
||||
|
||||
def test_rename_session_not_found(self, tmp_store):
|
||||
"""rename_session with non-existent old_id returns False."""
|
||||
result = tmp_store.rename_session("nonexistent", "new_id")
|
||||
assert result is False
|
||||
|
||||
def test_update_last_message_not_found(self, tmp_store):
|
||||
"""update_last_message with unknown session_id logs warning."""
|
||||
# Should not raise
|
||||
tmp_store.update_last_message("nonexistent", "msg_1")
|
||||
|
||||
def test_get_session_by_msg_not_found(self, tmp_store):
|
||||
"""Looking up non-existent message returns None."""
|
||||
result = tmp_store.get_session_by_msg("c1", "m999")
|
||||
assert result is None
|
||||
|
||||
def test_get_session_record_not_found(self, tmp_store):
|
||||
"""Getting non-existent session record returns None."""
|
||||
result = tmp_store.get_session_record("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSessionStoreClearAll:
|
||||
def test_clear_all_wipes_state_and_persists(self, tmp_path):
|
||||
path = str(tmp_path / "sessions.json")
|
||||
store = SessionStore(storage_path=path)
|
||||
|
||||
store.save_session("s1", "c1", "m1", platform="telegram")
|
||||
store.save_tree(
|
||||
"root1",
|
||||
{
|
||||
|
|
@ -154,7 +127,6 @@ class TestSessionStoreClearAll:
|
|||
|
||||
store.clear_all()
|
||||
|
||||
assert store.get_session_by_msg("c1", "m1", "telegram") is None
|
||||
assert store.get_all_trees() == {}
|
||||
assert store.get_node_mapping() == {}
|
||||
|
||||
|
|
@ -187,15 +159,6 @@ class TestSessionStoreClearAll:
|
|||
class TestSessionStoreCleanupEdgeCases:
|
||||
"""Tests for cleanup with malformed data."""
|
||||
|
||||
def test_cleanup_sessions_malformed_timestamp(self, tmp_store):
|
||||
"""Malformed created_at in cleanup doesn't crash."""
|
||||
tmp_store.save_session("s1", "c1", "m1")
|
||||
# Corrupt the timestamp
|
||||
tmp_store._sessions["s1"].created_at = "not-a-date"
|
||||
# Should not crash - the except block silently skips bad records
|
||||
removed = tmp_store.cleanup_old_sessions(max_age_days=0)
|
||||
assert removed == 0 # Skipped due to parse error
|
||||
|
||||
def test_cleanup_trees_malformed_timestamp(self, tmp_store):
|
||||
"""Malformed created_at in cleanup_old_trees doesn't crash."""
|
||||
tmp_store._trees["root1"] = {"nodes": {"root1": {"created_at": "bad-date"}}}
|
||||
|
|
@ -207,9 +170,3 @@ class TestSessionStoreCleanupEdgeCases:
|
|||
tmp_store._trees["root1"] = {"nodes": {"root1": {}}}
|
||||
removed = tmp_store.cleanup_old_trees(max_age_days=0)
|
||||
assert removed == 0
|
||||
|
||||
def test_update_tree_node_nonexistent_tree(self, tmp_store):
|
||||
"""Updating a node in a nonexistent tree logs warning."""
|
||||
tmp_store.update_tree_node("nonexistent", "node1", {"data": "test"})
|
||||
# Should not crash, tree not created
|
||||
assert "nonexistent" not in tmp_store._trees
|
||||
|
|
|
|||
|
|
@ -468,8 +468,8 @@ class TestSessionStoreTrees:
|
|||
assert retrieved is not None
|
||||
assert retrieved["root_id"] == "root_1"
|
||||
|
||||
def test_get_tree_by_node(self, tmp_path):
|
||||
"""Test getting tree by node ID."""
|
||||
def test_get_tree_by_root_id(self, tmp_path):
|
||||
"""Test getting tree by root ID and node mapping."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
|
@ -484,10 +484,10 @@ class TestSessionStoreTrees:
|
|||
|
||||
store.save_tree("root", tree_data)
|
||||
|
||||
# Should find tree by child node
|
||||
retrieved = store.get_tree_by_node("child")
|
||||
retrieved = store.get_tree("root")
|
||||
assert retrieved is not None
|
||||
assert retrieved["root_id"] == "root"
|
||||
assert store.get_tree_root_for_node("child") == "root"
|
||||
|
||||
def test_register_node(self, tmp_path):
|
||||
"""Test registering a node to a tree."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue