Refactor done by z-ai/glm5

This commit is contained in:
Alishahryar1 2026-02-14 18:47:26 -08:00
parent 754ca99314
commit 64e5b10612
11 changed files with 157 additions and 533 deletions

View file

@ -1,6 +1,6 @@
"""Request utility functions for API route handlers.
Contains token counting and re-exports detection/command utilities.
Contains token counting for API requests.
"""
import json
@ -9,29 +9,10 @@ from typing import List, Optional, Union
import tiktoken
from .models.anthropic import MessagesRequest
from .detection import (
is_quota_check_request,
is_title_generation_request,
is_prefix_detection_request,
is_suggestion_mode_request,
is_filepath_extraction_request,
)
from .command_utils import extract_command_prefix, extract_filepaths_from_command
logger = logging.getLogger(__name__)
ENCODER = tiktoken.get_encoding("cl100k_base")
__all__ = [
"is_quota_check_request",
"is_title_generation_request",
"is_prefix_detection_request",
"is_suggestion_mode_request",
"is_filepath_extraction_request",
"extract_command_prefix",
"extract_filepaths_from_command",
"get_token_count",
]
__all__ = ["get_token_count"]
def get_token_count(

View file

@ -10,7 +10,7 @@ import time
import asyncio
import logging
import os
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from .base import MessagingPlatform, SessionManagerInterface
from .models import IncomingMessage
@ -29,6 +29,47 @@ from .telegram_markdown import (
logger = logging.getLogger(__name__)
# Status message prefixes used to filter our own messages (ignore echo)
STATUS_MESSAGE_PREFIXES = ("", "💭", "🔧", "", "", "🚀", "🤖", "📋", "📊", "🔄")
# Event types that update the transcript
TRANSCRIPT_EVENT_TYPES = (
"thinking_start",
"thinking_delta",
"thinking_chunk",
"thinking_stop",
"text_start",
"text_delta",
"text_chunk",
"text_stop",
"tool_use_start",
"tool_use_delta",
"tool_use_stop",
"tool_use",
"tool_result",
"block_stop",
"error",
)
# Event types -> (emoji, label) for status updates
_EVENT_STATUS_MAP = {
("thinking_start", "thinking_delta", "thinking_chunk"): ("🧠", "Claude is thinking..."),
("text_start", "text_delta", "text_chunk"): ("🧠", "Claude is working..."),
("tool_result",): ("", "Executing tools..."),
}
def _get_status_for_event(ptype: str, parsed: dict) -> Optional[str]:
"""Return status string for event type, or None if no status update needed."""
for types, (emoji, label) in _EVENT_STATUS_MAP.items():
if ptype in types:
return format_status(emoji, label)
if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
if parsed.get("name") == "Task":
return format_status("🤖", "Subagent working...")
return format_status("", "Executing tools...")
return None
class ClaudeMessageHandler:
"""
@ -95,10 +136,7 @@ class ClaudeMessageHandler:
return
# Filter out status messages (our own messages)
if any(
incoming.text.startswith(p)
for p in ["", "💭", "🔧", "", "", "🚀", "🤖", "📋", "📊", "🔄"]
):
if any(incoming.text.startswith(p) for p in STATUS_MESSAGE_PREFIXES):
return
# Check if this is a reply to an existing node in a tree
@ -131,17 +169,9 @@ class ClaudeMessageHandler:
reply_to=incoming.message_id,
fire_and_forget=False,
)
try:
if status_msg_id:
self.session_store.record_message_id(
incoming.platform,
incoming.chat_id,
str(status_msg_id),
direction="out",
kind="status",
)
except Exception as e:
logger.debug(f"Failed to record status message_id: {e}")
self._record_outgoing_message(
incoming.platform, incoming.chat_id, status_msg_id, "status"
)
# Create or extend tree
if parent_node_id and tree and status_msg_id:
@ -285,45 +315,16 @@ class ClaudeMessageHandler:
captured_session_id: Optional[str],
) -> Tuple[Optional[str], bool]:
"""Process a single parsed CLI event. Returns (last_status, had_transcript_events)."""
ptype = parsed.get("type")
transcript_types = (
"thinking_start",
"thinking_delta",
"thinking_chunk",
"thinking_stop",
"text_start",
"text_delta",
"text_chunk",
"text_stop",
"tool_use_start",
"tool_use_delta",
"tool_use_stop",
"tool_use",
"tool_result",
"block_stop",
"error",
)
ptype = parsed.get("type") or ""
if ptype in transcript_types:
if ptype in TRANSCRIPT_EVENT_TYPES:
transcript.apply(parsed)
had_transcript_events = True
if ptype in ("thinking_start", "thinking_delta", "thinking_chunk"):
await update_ui(format_status("🧠", "Claude is thinking..."))
last_status = format_status("🧠", "Claude is thinking...")
elif ptype in ("text_start", "text_delta", "text_chunk"):
await update_ui(format_status("🧠", "Claude is working..."))
last_status = format_status("🧠", "Claude is working...")
elif ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
if parsed.get("name") == "Task":
await update_ui(format_status("🤖", "Subagent working..."))
last_status = format_status("🤖", "Subagent working...")
else:
await update_ui(format_status("", "Executing tools..."))
last_status = format_status("", "Executing tools...")
elif ptype == "tool_result":
await update_ui(format_status("", "Executing tools..."))
last_status = format_status("", "Executing tools...")
status = _get_status_for_event(ptype, parsed)
if status is not None:
await update_ui(status)
last_status = status
elif ptype == "block_stop":
await update_ui(last_status, force=True)
elif ptype == "complete":
@ -584,20 +585,7 @@ class ClaudeMessageHandler:
await self.cli_manager.stop_all()
# 3. Update UI and persist state for all cancelled nodes
for node in cancelled_nodes:
self.platform.fire_and_forget(
self.platform.queue_edit_message(
node.incoming.chat_id,
node.status_message_id,
format_status("", "Stopped."),
parse_mode="MarkdownV2",
)
)
# Persist tree state
tree = self.tree_queue.get_tree_for_node(node.node_id)
if tree:
self.session_store.save_tree(tree.root_id, tree.to_dict())
self._update_cancelled_nodes_ui(cancelled_nodes)
return len(cancelled_nodes)
@ -615,8 +603,29 @@ class ClaudeMessageHandler:
node.context = {"cancel_reason": "stop"}
cancelled_nodes = await self.tree_queue.cancel_node(node_id)
self._update_cancelled_nodes_ui(cancelled_nodes)
return len(cancelled_nodes)
for node in cancelled_nodes:
def _record_outgoing_message(
self,
platform: str,
chat_id: str,
msg_id: Optional[str],
kind: str,
) -> None:
"""Record outgoing message ID for /clear. Best-effort, never raises."""
if not msg_id:
return
try:
self.session_store.record_message_id(
platform, chat_id, str(msg_id), direction="out", kind=kind
)
except Exception as e:
logger.debug(f"Failed to record message_id: {e}")
def _update_cancelled_nodes_ui(self, nodes: List[MessageNode]) -> None:
"""Update status messages and persist tree state for cancelled nodes."""
for node in nodes:
self.platform.fire_and_forget(
self.platform.queue_edit_message(
node.incoming.chat_id,
@ -625,13 +634,10 @@ class ClaudeMessageHandler:
parse_mode="MarkdownV2",
)
)
tree = self.tree_queue.get_tree_for_node(node.node_id)
if tree:
self.session_store.save_tree(tree.root_id, tree.to_dict())
return len(cancelled_nodes)
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:
"""Handle /stop command from messaging platform."""
# Reply-scoped stop: reply "/stop" to stop only that task.
@ -646,17 +652,9 @@ class ClaudeMessageHandler:
format_status("", "Stopped.", "Nothing to stop for that message."),
fire_and_forget=False,
)
try:
if msg_id:
self.session_store.record_message_id(
incoming.platform,
incoming.chat_id,
str(msg_id),
direction="out",
kind="command",
)
except Exception:
pass
self._record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command"
)
return
count = await self.stop_task(node_id)
@ -666,17 +664,9 @@ class ClaudeMessageHandler:
format_status("", "Stopped.", f"Cancelled {count} {noun}."),
fire_and_forget=False,
)
try:
if msg_id:
self.session_store.record_message_id(
incoming.platform,
incoming.chat_id,
str(msg_id),
direction="out",
kind="command",
)
except Exception:
pass
self._record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command"
)
return
# Global stop: legacy behavior (stop everything)
@ -688,17 +678,9 @@ class ClaudeMessageHandler:
),
fire_and_forget=False,
)
try:
if msg_id:
self.session_store.record_message_id(
incoming.platform,
incoming.chat_id,
str(msg_id),
direction="out",
kind="command",
)
except Exception:
pass
self._record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command"
)
async def _handle_stats_command(self, incoming: IncomingMessage) -> None:
"""Handle /stats command."""
@ -716,17 +698,9 @@ class ClaudeMessageHandler:
+ escape_md_v2(f"• Message Trees: {tree_count}"),
fire_and_forget=False,
)
try:
if msg_id:
self.session_store.record_message_id(
incoming.platform,
incoming.chat_id,
str(msg_id),
direction="out",
kind="command",
)
except Exception:
pass
self._record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command"
)
async def _handle_clear_command(self, incoming: IncomingMessage) -> None:
"""

View file

@ -205,14 +205,6 @@ class SessionStore:
if x.get("message_id") is not None
]
def clear_message_log_for_chat(self, platform: str, chat_id: str) -> None:
"""Clear recorded message IDs for a single chat."""
chat_key = self._make_chat_key(str(platform), str(chat_id))
with self._lock:
self._message_log.pop(chat_key, None)
self._message_log_ids.pop(chat_key, None)
self._save()
def clear_all(self) -> None:
"""Clear all stored sessions/trees/mappings and persist an empty store."""
with self._lock:
@ -224,124 +216,6 @@ class SessionStore:
self._message_log_ids.clear()
self._save()
# ==================== Session Methods ====================
def save_session(
self,
session_id: str,
chat_id: str,
initial_msg_id: str,
platform: str = "telegram",
) -> None:
"""Save a new session mapping."""
with self._lock:
now = datetime.now(timezone.utc).isoformat()
record = SessionRecord(
session_id=session_id,
chat_id=str(chat_id),
initial_msg_id=str(initial_msg_id),
last_msg_id=str(initial_msg_id),
platform=platform,
created_at=now,
updated_at=now,
)
self._sessions[session_id] = record
self._msg_to_session[
self._make_key(platform, str(chat_id), str(initial_msg_id))
] = session_id
self._save()
logger.info(
f"Saved session {session_id} for {platform} chat {chat_id}, msg {initial_msg_id}"
)
def get_session_by_msg(
self, chat_id: str, msg_id: str, platform: str = "telegram"
) -> Optional[str]:
"""Look up a session ID by a message that's part of that session."""
with self._lock:
key = self._make_key(platform, str(chat_id), str(msg_id))
return self._msg_to_session.get(key)
def update_last_message(self, session_id: str, msg_id: str) -> None:
"""Update the last message ID for a session."""
with self._lock:
if session_id not in self._sessions:
logger.warning(f"Session {session_id} not found for update")
return
record = self._sessions[session_id]
record.last_msg_id = str(msg_id)
record.updated_at = datetime.now(timezone.utc).isoformat()
new_key = self._make_key(record.platform, record.chat_id, str(msg_id))
self._msg_to_session[new_key] = session_id
self._save()
logger.debug(f"Updated session {session_id} last_msg to {msg_id}")
def rename_session(self, old_id: str, new_id: str) -> bool:
"""Rename a session ID, migrating all message mappings."""
with self._lock:
if old_id not in self._sessions:
logger.warning(f"Session {old_id} not found for rename to {new_id}")
return False
record = self._sessions.pop(old_id)
record.session_id = new_id
record.updated_at = datetime.now(timezone.utc).isoformat()
self._sessions[new_id] = record
items_to_update = [
k for k, v in self._msg_to_session.items() if v == old_id
]
for key in items_to_update:
self._msg_to_session[key] = new_id
self._save()
logger.info(
f"Renamed session {old_id} to {new_id} ({len(items_to_update)} mappings updated)"
)
return True
def get_session_record(self, session_id: str) -> Optional[SessionRecord]:
"""Get full session record."""
with self._lock:
return self._sessions.get(session_id)
def cleanup_old_sessions(self, max_age_days: int = 30) -> int:
"""Remove sessions older than max_age_days."""
with self._lock:
cutoff = datetime.now(timezone.utc)
removed = 0
to_remove = []
for sid, record in self._sessions.items():
try:
created = datetime.fromisoformat(record.created_at)
age_days = (cutoff - created).days
if age_days > max_age_days:
to_remove.append(sid)
except Exception:
pass
for sid in to_remove:
record = self._sessions.pop(sid)
self._msg_to_session.pop(
self._make_key(
record.platform, record.chat_id, record.initial_msg_id
),
None,
)
self._msg_to_session.pop(
self._make_key(record.platform, record.chat_id, record.last_msg_id),
None,
)
removed += 1
if removed:
self._save()
logger.info(f"Cleaned up {removed} old sessions")
return removed
# ==================== Tree Methods ====================
def save_tree(self, root_id: str, tree_data: dict) -> None:
@ -367,14 +241,6 @@ class SessionStore:
with self._lock:
return self._trees.get(root_id)
def get_tree_by_node(self, node_id: str) -> Optional[dict]:
"""Get the tree containing a node."""
with self._lock:
root_id = self._node_to_tree.get(node_id)
if not root_id:
return None
return self._trees.get(root_id)
def get_tree_root_for_node(self, node_id: str) -> Optional[str]:
"""Get the root ID of the tree containing a node."""
with self._lock:
@ -386,20 +252,6 @@ class SessionStore:
self._node_to_tree[node_id] = root_id
self._save()
def update_tree_node(self, root_id: str, node_id: str, node_data: dict) -> None:
"""Update a specific node in a tree."""
with self._lock:
if root_id not in self._trees:
logger.warning(f"Tree {root_id} not found")
return
if "nodes" not in self._trees[root_id]:
self._trees[root_id]["nodes"] = {}
self._trees[root_id]["nodes"][node_id] = node_data
self._node_to_tree[node_id] = root_id
self._save()
def get_all_trees(self) -> Dict[str, dict]:
"""Get all stored trees (public accessor)."""
with self._lock:

View file

@ -5,6 +5,8 @@ Contains MessageState, MessageNode, and MessageTree classes.
import asyncio
import logging
from collections import deque
from contextlib import asynccontextmanager
from enum import Enum
from datetime import datetime, timezone
from typing import Dict, Optional, List, Any
@ -276,6 +278,42 @@ class MessageTree:
"""Get number of messages waiting in queue."""
return self._queue.qsize()
def remove_from_queue(self, node_id: str) -> bool:
"""
Remove node_id from the internal queue if present.
Caller must hold the tree lock (e.g. via with_lock).
Returns True if node was removed, False if not in queue.
"""
queue_deque: deque = self._queue._queue # type: ignore[attr-defined]
if node_id not in queue_deque:
return False
self._queue._queue = deque(x for x in queue_deque if x != node_id) # type: ignore[attr-defined]
return True
@asynccontextmanager
async def with_lock(self):
"""Async context manager for tree lock. Use when multiple operations need atomicity."""
async with self._lock:
yield
def set_processing_state(self, node_id: Optional[str], is_processing: bool) -> None:
"""Set processing state. Caller must hold lock for consistency with queue operations."""
self._is_processing = is_processing
self._current_node_id = node_id if is_processing else None
def clear_current_node(self) -> None:
"""Clear the currently processing node ID. Caller must hold lock."""
self._current_node_id = None
def is_current_node(self, node_id: str) -> bool:
"""Check if node_id is the currently processing node."""
return self._current_node_id == node_id
def put_queue_unlocked(self, node_id: str) -> None:
"""Add node to queue. Caller must hold lock (e.g. via with_lock)."""
self._queue.put_nowait(node_id)
def cancel_current_task(self) -> bool:
"""Cancel the currently running task. Returns True if a task was cancelled."""
if self._current_task and not self._current_task.done():

View file

@ -90,7 +90,7 @@ class TreeQueueProcessor:
node.node_id, MessageState.ERROR, error_message=str(e)
)
finally:
tree._current_node_id = None
tree.clear_current_node()
# Check if there are more messages in the queue
await self._process_next(tree, processor)
@ -102,16 +102,15 @@ class TreeQueueProcessor:
"""Process the next message in queue, if any."""
next_node_id = None
node = None
async with tree._lock:
async with tree.with_lock():
next_node_id = await tree.dequeue()
if not next_node_id:
# No more messages, mark tree as free
tree._is_processing = False
tree.set_processing_state(None, False)
logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
return
tree._current_node_id = next_node_id
tree.set_processing_state(next_node_id, True)
logger.info(f"Processing next queued node {next_node_id}")
# Process next node (outside lock)
@ -143,17 +142,14 @@ class TreeQueueProcessor:
Returns:
True if queued, False if processing immediately
"""
async with tree._lock:
if tree._is_processing:
# Tree is busy, queue the message
await tree._queue.put(node_id)
queue_size = tree._queue.qsize()
async with tree.with_lock():
if tree.is_processing:
tree.put_queue_unlocked(node_id)
queue_size = tree.get_queue_size()
logger.info(f"Queued node {node_id}, position {queue_size}")
return True
else:
# Tree is free, start processing
tree._is_processing = True
tree._current_node_id = node_id
tree.set_processing_state(node_id, True)
# Process outside the lock
node = tree.get_node(node_id)

View file

@ -6,7 +6,6 @@ Uses TreeRepository for data, TreeQueueProcessor for async logic.
import asyncio
import logging
from collections import deque
from datetime import datetime, timezone
from typing import Callable, Awaitable, List, Optional
@ -293,7 +292,7 @@ class TreeQueueManager:
if not tree:
return []
async with tree._lock:
async with tree.with_lock():
node = tree.get_node(node_id)
if not node:
return []
@ -301,18 +300,12 @@ class TreeQueueManager:
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
return []
# Cancel running task if this is the current node.
if tree._current_node_id == node_id:
if tree.is_current_node(node_id):
self._processor.cancel_current(tree)
# Remove from queue if present (asyncio.Queue exposes its internal deque).
try:
q = tree._queue._queue # type: ignore[attr-defined]
if q and node_id in q:
tree._queue._queue = deque(x for x in q if x != node_id) # type: ignore[attr-defined]
tree.remove_from_queue(node_id)
except Exception:
# Best-effort: if we can't mutate the queue internals, the node will
# still be dequeued later and skipped due to state=ERROR.
logger.debug(
"Failed to remove node from queue; will rely on state=ERROR"
)

View file

@ -65,145 +65,6 @@ class TestSessionStore:
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
assert store._sessions == {}
def test_save_and_get_session(self, tmp_path):
"""Test saving and retrieving a session."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.save_session(
session_id="sess_123",
chat_id="chat_456",
initial_msg_id="msg_789",
platform="telegram",
)
# Retrieve by message
found = store.get_session_by_msg("chat_456", "msg_789", "telegram")
assert found == "sess_123"
# Verify persistence file created
assert os.path.exists(str(tmp_path / "sessions.json"))
def test_update_last_message(self, tmp_path):
"""Test updating last message in session."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.save_session("sess_1", "chat_1", "msg_1", "telegram")
store.update_last_message("sess_1", "msg_2")
# Should find session by new message too
found = store.get_session_by_msg("chat_1", "msg_2", "telegram")
assert found == "sess_1"
# Original message mapping should still work
found_old = store.get_session_by_msg("chat_1", "msg_1", "telegram")
assert found_old == "sess_1"
# Verify record updated
record = store.get_session_record("sess_1")
assert record is not None
assert record.last_msg_id == "msg_2"
def test_update_last_message_unknown_session(self, tmp_path):
"""Test updating unknown session does nothing."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.update_last_message("unknown", "msg_x")
# Should log warning but not crash
def test_get_session_record(self, tmp_path):
"""Test getting full session record."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.save_session("sess_1", "chat_1", "msg_1", "telegram")
record = store.get_session_record("sess_1")
assert record is not None
assert record.session_id == "sess_1"
assert record.platform == "telegram"
def test_session_not_found(self, tmp_path):
"""Test getting non-existent session returns None."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
found = store.get_session_by_msg("notexist", "notexist", "telegram")
assert found is None
def test_rename_session(self, tmp_path):
"""Test renaming a session."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.save_session("old_id", "c1", "m1", "telegram")
store.update_last_message("old_id", "m2")
success = store.rename_session("old_id", "new_id")
assert success is True
# Verify old id gone
assert store.get_session_record("old_id") is None
# Verify new id exists
rec = store.get_session_record("new_id")
assert rec is not None
assert rec.session_id == "new_id"
# Verify mappings point to new id
assert store.get_session_by_msg("c1", "m1", "telegram") == "new_id"
assert store.get_session_by_msg("c1", "m2", "telegram") == "new_id"
def test_rename_unknown_session(self, tmp_path):
"""Test renaming unknown session fails."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
success = store.rename_session("unknown", "new")
assert success is False
def test_cleanup_old_sessions(self, tmp_path):
"""Test cleaning up expired sessions."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
# Create an old session manually
old_date = (datetime.now(timezone.utc) - timedelta(days=40)).isoformat()
store.save_session("old_sess", "c_old", "m_old")
# Manipulate the created_at directly
store._sessions["old_sess"].created_at = old_date
# Create a new session
store.save_session("new_sess", "c_new", "m_new")
# Cleanup
removed = store.cleanup_old_sessions(max_age_days=30)
assert removed == 1
assert store.get_session_record("old_sess") is None
assert store.get_session_by_msg("c_old", "m_old") is None
assert store.get_session_record("new_sess") is not None
def test_cleanup_old_sessions_invalid_date(self, tmp_path):
"""Test cleanup handles invalid date formats gracefully."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.save_session("bad_date_sess", "c", "m")
store._sessions["bad_date_sess"].created_at = "not-a-date"
# Should not crash
store.cleanup_old_sessions(30)
# Should still exist because parsing failed so it wasn't removed (or default behavior)
# The code tries parsing, excepts, and continues, so it isn't removed.
assert store.get_session_record("bad_date_sess") is not None
# --- Tree Tests ---
def test_save_and_get_tree(self, tmp_path):
@ -225,34 +86,6 @@ class TestSessionStore:
assert store.get_tree_root_for_node("r1") == "r1"
assert store.get_tree_root_for_node("n1") == "r1"
# Verify get_tree_by_node
assert store.get_tree_by_node("n1") == tree_data
assert store.get_tree_by_node("unknown") is None
def test_update_tree_node(self, tmp_path):
"""Test updating a specific node in a tree."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.save_tree("r1", {"nodes": {"r1": {}}})
# Add new node
store.update_tree_node("r1", "n2", {"data": "test"})
tree = store.get_tree("r1")
assert tree is not None
assert "n2" in tree["nodes"]
assert tree["nodes"]["n2"]["data"] == "test"
assert store.get_tree_root_for_node("n2") == "r1"
def test_update_tree_node_unknown_tree(self, tmp_path):
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.update_tree_node("unknown_root", "n1", {})
# Should not crash
def test_register_node(self, tmp_path):
"""Test manual node registration."""
from messaging.session import SessionStore
@ -296,7 +129,7 @@ class TestSessionStore:
# --- Persistence & Edge Cases ---
def test_load_existing_legacy_format(self, tmp_path):
"""Test loading legacy session format (int IDs)."""
"""Test loading legacy session format (int IDs) - backward compat."""
from messaging.session import SessionStore
data = {
@ -318,12 +151,11 @@ class TestSessionStore:
json.dump(data, f)
store = SessionStore(storage_path=str(p))
rec = store.get_session_record("s1")
assert rec is not None
# Legacy sessions are loaded for backward compat; verify conversion
assert "s1" in store._sessions
rec = store._sessions["s1"]
assert rec.chat_id == "123" # Converted to str
assert rec.platform == "telegram" # Defaulted
assert store.get_session_by_msg("123", "100", "telegram") == "s1"
def test_load_corrupt_file(self, tmp_path):
"""Test loading corrupt/invalid json file."""
@ -342,13 +174,14 @@ class TestSessionStore:
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
# Mock open to raise exception
with patch("builtins.open", side_effect=IOError("Disk full")):
store.save_session("s1", "c1", "m1")
store.save_tree("r2", {"root_id": "r2", "nodes": {"r2": {}}})
# Should log error but not crash. Session should be in memory though.
assert "s1" in store._sessions
# Should log error but not crash. Tree should be in memory.
assert "r2" in store._trees
class TestTreeQueueManager:

View file

@ -3,13 +3,13 @@
import pytest
from unittest.mock import MagicMock
from api.request_utils import (
from api.detection import (
is_quota_check_request,
is_title_generation_request,
extract_command_prefix,
is_prefix_detection_request,
get_token_count,
)
from api.command_utils import extract_command_prefix
from api.request_utils import get_token_count
from api.models.anthropic import MessagesRequest, Message

View file

@ -2,11 +2,11 @@ import pytest
from unittest.mock import MagicMock
from api.models.anthropic import MessagesRequest, Message
from api.request_utils import (
from api.detection import (
is_suggestion_mode_request,
is_filepath_extraction_request,
extract_filepaths_from_command,
)
from api.command_utils import extract_filepaths_from_command
def _mk_req(messages, tools=None):

View file

@ -85,44 +85,17 @@ class TestSessionStoreSaveEdgeCases:
def test_save_io_error_handled(self, tmp_store):
"""Write failure in _save() is logged but doesn't raise."""
tmp_store.save_session("s1", "c1", "m1")
# Make the path read-only dir to trigger write error
tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}})
with patch("builtins.open", side_effect=IOError("disk full")):
# _save() catches all exceptions and logs them
tmp_store._save()
# Should not raise
class TestSessionStoreOperationEdgeCases:
"""Tests for edge cases in session operations."""
def test_rename_session_not_found(self, tmp_store):
"""rename_session with non-existent old_id returns False."""
result = tmp_store.rename_session("nonexistent", "new_id")
assert result is False
def test_update_last_message_not_found(self, tmp_store):
"""update_last_message with unknown session_id logs warning."""
# Should not raise
tmp_store.update_last_message("nonexistent", "msg_1")
def test_get_session_by_msg_not_found(self, tmp_store):
"""Looking up non-existent message returns None."""
result = tmp_store.get_session_by_msg("c1", "m999")
assert result is None
def test_get_session_record_not_found(self, tmp_store):
"""Getting non-existent session record returns None."""
result = tmp_store.get_session_record("nonexistent")
assert result is None
class TestSessionStoreClearAll:
def test_clear_all_wipes_state_and_persists(self, tmp_path):
path = str(tmp_path / "sessions.json")
store = SessionStore(storage_path=path)
store.save_session("s1", "c1", "m1", platform="telegram")
store.save_tree(
"root1",
{
@ -154,7 +127,6 @@ class TestSessionStoreClearAll:
store.clear_all()
assert store.get_session_by_msg("c1", "m1", "telegram") is None
assert store.get_all_trees() == {}
assert store.get_node_mapping() == {}
@ -187,15 +159,6 @@ class TestSessionStoreClearAll:
class TestSessionStoreCleanupEdgeCases:
"""Tests for cleanup with malformed data."""
def test_cleanup_sessions_malformed_timestamp(self, tmp_store):
"""Malformed created_at in cleanup doesn't crash."""
tmp_store.save_session("s1", "c1", "m1")
# Corrupt the timestamp
tmp_store._sessions["s1"].created_at = "not-a-date"
# Should not crash - the except block silently skips bad records
removed = tmp_store.cleanup_old_sessions(max_age_days=0)
assert removed == 0 # Skipped due to parse error
def test_cleanup_trees_malformed_timestamp(self, tmp_store):
"""Malformed created_at in cleanup_old_trees doesn't crash."""
tmp_store._trees["root1"] = {"nodes": {"root1": {"created_at": "bad-date"}}}
@ -207,9 +170,3 @@ class TestSessionStoreCleanupEdgeCases:
tmp_store._trees["root1"] = {"nodes": {"root1": {}}}
removed = tmp_store.cleanup_old_trees(max_age_days=0)
assert removed == 0
def test_update_tree_node_nonexistent_tree(self, tmp_store):
"""Updating a node in a nonexistent tree logs warning."""
tmp_store.update_tree_node("nonexistent", "node1", {"data": "test"})
# Should not crash, tree not created
assert "nonexistent" not in tmp_store._trees

View file

@ -468,8 +468,8 @@ class TestSessionStoreTrees:
assert retrieved is not None
assert retrieved["root_id"] == "root_1"
def test_get_tree_by_node(self, tmp_path):
"""Test getting tree by node ID."""
def test_get_tree_by_root_id(self, tmp_path):
"""Test getting tree by root ID and node mapping."""
from messaging.session import SessionStore
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
@ -484,10 +484,10 @@ class TestSessionStoreTrees:
store.save_tree("root", tree_data)
# Should find tree by child node
retrieved = store.get_tree_by_node("child")
retrieved = store.get_tree("root")
assert retrieved is not None
assert retrieved["root_id"] == "root"
assert store.get_tree_root_for_node("child") == "root"
def test_register_node(self, tmp_path):
"""Test registering a node to a tree."""