diff --git a/api/request_utils.py b/api/request_utils.py index 7f0d7d0..0a15838 100644 --- a/api/request_utils.py +++ b/api/request_utils.py @@ -1,6 +1,6 @@ """Request utility functions for API route handlers. -Contains token counting and re-exports detection/command utilities. +Contains token counting for API requests. """ import json @@ -9,29 +9,10 @@ from typing import List, Optional, Union import tiktoken -from .models.anthropic import MessagesRequest -from .detection import ( - is_quota_check_request, - is_title_generation_request, - is_prefix_detection_request, - is_suggestion_mode_request, - is_filepath_extraction_request, -) -from .command_utils import extract_command_prefix, extract_filepaths_from_command - logger = logging.getLogger(__name__) ENCODER = tiktoken.get_encoding("cl100k_base") -__all__ = [ - "is_quota_check_request", - "is_title_generation_request", - "is_prefix_detection_request", - "is_suggestion_mode_request", - "is_filepath_extraction_request", - "extract_command_prefix", - "extract_filepaths_from_command", - "get_token_count", -] +__all__ = ["get_token_count"] def get_token_count( diff --git a/messaging/handler.py b/messaging/handler.py index 9c84010..95ae242 100644 --- a/messaging/handler.py +++ b/messaging/handler.py @@ -10,7 +10,7 @@ import time import asyncio import logging import os -from typing import Optional, Tuple +from typing import List, Optional, Tuple from .base import MessagingPlatform, SessionManagerInterface from .models import IncomingMessage @@ -29,6 +29,47 @@ from .telegram_markdown import ( logger = logging.getLogger(__name__) +# Status message prefixes used to filter our own messages (ignore echo) +STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄") + +# Event types that update the transcript +TRANSCRIPT_EVENT_TYPES = ( + "thinking_start", + "thinking_delta", + "thinking_chunk", + "thinking_stop", + "text_start", + "text_delta", + "text_chunk", + "text_stop", + "tool_use_start", + "tool_use_delta", + "tool_use_stop", + "tool_use", + "tool_result", + "block_stop", + "error", +) + +# Event types -> (emoji, label) for status updates +_EVENT_STATUS_MAP = { + ("thinking_start", "thinking_delta", "thinking_chunk"): ("🧠", "Claude is thinking..."), + ("text_start", "text_delta", "text_chunk"): ("🧠", "Claude is working..."), + ("tool_result",): ("⏳", "Executing tools..."), +} + + +def _get_status_for_event(ptype: str, parsed: dict) -> Optional[str]: + """Return status string for event type, or None if no status update needed.""" + for types, (emoji, label) in _EVENT_STATUS_MAP.items(): + if ptype in types: + return format_status(emoji, label) + if ptype in ("tool_use_start", "tool_use_delta", "tool_use"): + if parsed.get("name") == "Task": + return format_status("🤖", "Subagent working...") + return format_status("⏳", "Executing tools...") + return None + class ClaudeMessageHandler: """ @@ -95,10 +136,7 @@ class ClaudeMessageHandler: return # Filter out status messages (our own messages) - if any( - incoming.text.startswith(p) - for p in ["⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄"] - ): + if any(incoming.text.startswith(p) for p in STATUS_MESSAGE_PREFIXES): return # Check if this is a reply to an existing node in a tree @@ -131,17 +169,9 @@ class ClaudeMessageHandler: reply_to=incoming.message_id, fire_and_forget=False, ) - try: - if status_msg_id: - self.session_store.record_message_id( - incoming.platform, - incoming.chat_id, - str(status_msg_id), - direction="out", - kind="status", - ) - except Exception as e: - logger.debug(f"Failed to record status message_id: {e}") + self._record_outgoing_message( + incoming.platform, incoming.chat_id, status_msg_id, "status" + ) # Create or extend tree if parent_node_id and tree and status_msg_id: @@ -285,45 +315,16 @@ class ClaudeMessageHandler: captured_session_id: Optional[str], ) -> Tuple[Optional[str], bool]: """Process a single parsed CLI event. Returns (last_status, had_transcript_events).""" - ptype = parsed.get("type") - transcript_types = ( - "thinking_start", - "thinking_delta", - "thinking_chunk", - "thinking_stop", - "text_start", - "text_delta", - "text_chunk", - "text_stop", - "tool_use_start", - "tool_use_delta", - "tool_use_stop", - "tool_use", - "tool_result", - "block_stop", - "error", - ) + ptype = parsed.get("type") or "" - if ptype in transcript_types: + if ptype in TRANSCRIPT_EVENT_TYPES: transcript.apply(parsed) had_transcript_events = True - if ptype in ("thinking_start", "thinking_delta", "thinking_chunk"): - await update_ui(format_status("🧠", "Claude is thinking...")) - last_status = format_status("🧠", "Claude is thinking...") - elif ptype in ("text_start", "text_delta", "text_chunk"): - await update_ui(format_status("🧠", "Claude is working...")) - last_status = format_status("🧠", "Claude is working...") - elif ptype in ("tool_use_start", "tool_use_delta", "tool_use"): - if parsed.get("name") == "Task": - await update_ui(format_status("🤖", "Subagent working...")) - last_status = format_status("🤖", "Subagent working...") - else: - await update_ui(format_status("⏳", "Executing tools...")) - last_status = format_status("⏳", "Executing tools...") - elif ptype == "tool_result": - await update_ui(format_status("⏳", "Executing tools...")) - last_status = format_status("⏳", "Executing tools...") + status = _get_status_for_event(ptype, parsed) + if status is not None: + await update_ui(status) + last_status = status elif ptype == "block_stop": await update_ui(last_status, force=True) elif ptype == "complete": @@ -584,20 +585,7 @@ class ClaudeMessageHandler: await self.cli_manager.stop_all() # 3. Update UI and persist state for all cancelled nodes - for node in cancelled_nodes: - self.platform.fire_and_forget( - self.platform.queue_edit_message( - node.incoming.chat_id, - node.status_message_id, - format_status("⏹", "Stopped."), - parse_mode="MarkdownV2", - ) - ) - - # Persist tree state - tree = self.tree_queue.get_tree_for_node(node.node_id) - if tree: - self.session_store.save_tree(tree.root_id, tree.to_dict()) + self._update_cancelled_nodes_ui(cancelled_nodes) return len(cancelled_nodes) @@ -615,8 +603,29 @@ class ClaudeMessageHandler: node.context = {"cancel_reason": "stop"} cancelled_nodes = await self.tree_queue.cancel_node(node_id) + self._update_cancelled_nodes_ui(cancelled_nodes) + return len(cancelled_nodes) - for node in cancelled_nodes: + def _record_outgoing_message( + self, + platform: str, + chat_id: str, + msg_id: Optional[str], + kind: str, + ) -> None: + """Record outgoing message ID for /clear. Best-effort, never raises.""" + if not msg_id: + return + try: + self.session_store.record_message_id( + platform, chat_id, str(msg_id), direction="out", kind=kind + ) + except Exception as e: + logger.debug(f"Failed to record message_id: {e}") + + def _update_cancelled_nodes_ui(self, nodes: List[MessageNode]) -> None: + """Update status messages and persist tree state for cancelled nodes.""" + for node in nodes: self.platform.fire_and_forget( self.platform.queue_edit_message( node.incoming.chat_id, @@ -625,13 +634,10 @@ class ClaudeMessageHandler: parse_mode="MarkdownV2", ) ) - tree = self.tree_queue.get_tree_for_node(node.node_id) if tree: self.session_store.save_tree(tree.root_id, tree.to_dict()) - return len(cancelled_nodes) - async def _handle_stop_command(self, incoming: IncomingMessage) -> None: """Handle /stop command from messaging platform.""" # Reply-scoped stop: reply "/stop" to stop only that task. @@ -646,17 +652,9 @@ class ClaudeMessageHandler: format_status("⏹", "Stopped.", "Nothing to stop for that message."), fire_and_forget=False, ) - try: - if msg_id: - self.session_store.record_message_id( - incoming.platform, - incoming.chat_id, - str(msg_id), - direction="out", - kind="command", - ) - except Exception: - pass + self._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) return count = await self.stop_task(node_id) @@ -666,17 +664,9 @@ class ClaudeMessageHandler: format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."), fire_and_forget=False, ) - try: - if msg_id: - self.session_store.record_message_id( - incoming.platform, - incoming.chat_id, - str(msg_id), - direction="out", - kind="command", - ) - except Exception: - pass + self._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) return # Global stop: legacy behavior (stop everything) @@ -688,17 +678,9 @@ class ClaudeMessageHandler: ), fire_and_forget=False, ) - try: - if msg_id: - self.session_store.record_message_id( - incoming.platform, - incoming.chat_id, - str(msg_id), - direction="out", - kind="command", - ) - except Exception: - pass + self._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) async def _handle_stats_command(self, incoming: IncomingMessage) -> None: """Handle /stats command.""" @@ -716,17 +698,9 @@ class ClaudeMessageHandler: + escape_md_v2(f"• Message Trees: {tree_count}"), fire_and_forget=False, ) - try: - if msg_id: - self.session_store.record_message_id( - incoming.platform, - incoming.chat_id, - str(msg_id), - direction="out", - kind="command", - ) - except Exception: - pass + self._record_outgoing_message( + incoming.platform, incoming.chat_id, msg_id, "command" + ) async def _handle_clear_command(self, incoming: IncomingMessage) -> None: """ diff --git a/messaging/session.py b/messaging/session.py index f34c442..a87c871 100644 --- a/messaging/session.py +++ b/messaging/session.py @@ -205,14 +205,6 @@ class SessionStore: if x.get("message_id") is not None ] - def clear_message_log_for_chat(self, platform: str, chat_id: str) -> None: - """Clear recorded message IDs for a single chat.""" - chat_key = self._make_chat_key(str(platform), str(chat_id)) - with self._lock: - self._message_log.pop(chat_key, None) - self._message_log_ids.pop(chat_key, None) - self._save() - def clear_all(self) -> None: """Clear all stored sessions/trees/mappings and persist an empty store.""" with self._lock: @@ -224,124 +216,6 @@ class SessionStore: self._message_log_ids.clear() self._save() - # ==================== Session Methods ==================== - - def save_session( - self, - session_id: str, - chat_id: str, - initial_msg_id: str, - platform: str = "telegram", - ) -> None: - """Save a new session mapping.""" - with self._lock: - now = datetime.now(timezone.utc).isoformat() - record = SessionRecord( - session_id=session_id, - chat_id=str(chat_id), - initial_msg_id=str(initial_msg_id), - last_msg_id=str(initial_msg_id), - platform=platform, - created_at=now, - updated_at=now, - ) - self._sessions[session_id] = record - self._msg_to_session[ - self._make_key(platform, str(chat_id), str(initial_msg_id)) - ] = session_id - self._save() - logger.info( - f"Saved session {session_id} for {platform} chat {chat_id}, msg {initial_msg_id}" - ) - - def get_session_by_msg( - self, chat_id: str, msg_id: str, platform: str = "telegram" - ) -> Optional[str]: - """Look up a session ID by a message that's part of that session.""" - with self._lock: - key = self._make_key(platform, str(chat_id), str(msg_id)) - return self._msg_to_session.get(key) - - def update_last_message(self, session_id: str, msg_id: str) -> None: - """Update the last message ID for a session.""" - with self._lock: - if session_id not in self._sessions: - logger.warning(f"Session {session_id} not found for update") - return - - record = self._sessions[session_id] - record.last_msg_id = str(msg_id) - record.updated_at = datetime.now(timezone.utc).isoformat() - new_key = self._make_key(record.platform, record.chat_id, str(msg_id)) - self._msg_to_session[new_key] = session_id - self._save() - logger.debug(f"Updated session {session_id} last_msg to {msg_id}") - - def rename_session(self, old_id: str, new_id: str) -> bool: - """Rename a session ID, migrating all message mappings.""" - with self._lock: - if old_id not in self._sessions: - logger.warning(f"Session {old_id} not found for rename to {new_id}") - return False - - record = self._sessions.pop(old_id) - record.session_id = new_id - record.updated_at = datetime.now(timezone.utc).isoformat() - self._sessions[new_id] = record - - items_to_update = [ - k for k, v in self._msg_to_session.items() if v == old_id - ] - for key in items_to_update: - self._msg_to_session[key] = new_id - - self._save() - logger.info( - f"Renamed session {old_id} to {new_id} ({len(items_to_update)} mappings updated)" - ) - return True - - def get_session_record(self, session_id: str) -> Optional[SessionRecord]: - """Get full session record.""" - with self._lock: - return self._sessions.get(session_id) - - def cleanup_old_sessions(self, max_age_days: int = 30) -> int: - """Remove sessions older than max_age_days.""" - with self._lock: - cutoff = datetime.now(timezone.utc) - removed = 0 - - to_remove = [] - for sid, record in self._sessions.items(): - try: - created = datetime.fromisoformat(record.created_at) - age_days = (cutoff - created).days - if age_days > max_age_days: - to_remove.append(sid) - except Exception: - pass - - for sid in to_remove: - record = self._sessions.pop(sid) - self._msg_to_session.pop( - self._make_key( - record.platform, record.chat_id, record.initial_msg_id - ), - None, - ) - self._msg_to_session.pop( - self._make_key(record.platform, record.chat_id, record.last_msg_id), - None, - ) - removed += 1 - - if removed: - self._save() - logger.info(f"Cleaned up {removed} old sessions") - - return removed - # ==================== Tree Methods ==================== def save_tree(self, root_id: str, tree_data: dict) -> None: @@ -367,14 +241,6 @@ class SessionStore: with self._lock: return self._trees.get(root_id) - def get_tree_by_node(self, node_id: str) -> Optional[dict]: - """Get the tree containing a node.""" - with self._lock: - root_id = self._node_to_tree.get(node_id) - if not root_id: - return None - return self._trees.get(root_id) - def get_tree_root_for_node(self, node_id: str) -> Optional[str]: """Get the root ID of the tree containing a node.""" with self._lock: @@ -386,20 +252,6 @@ class SessionStore: self._node_to_tree[node_id] = root_id self._save() - def update_tree_node(self, root_id: str, node_id: str, node_data: dict) -> None: - """Update a specific node in a tree.""" - with self._lock: - if root_id not in self._trees: - logger.warning(f"Tree {root_id} not found") - return - - if "nodes" not in self._trees[root_id]: - self._trees[root_id]["nodes"] = {} - - self._trees[root_id]["nodes"][node_id] = node_data - self._node_to_tree[node_id] = root_id - self._save() - def get_all_trees(self) -> Dict[str, dict]: """Get all stored trees (public accessor).""" with self._lock: diff --git a/messaging/tree_data.py b/messaging/tree_data.py index 079f813..81e8eb8 100644 --- a/messaging/tree_data.py +++ b/messaging/tree_data.py @@ -5,6 +5,8 @@ Contains MessageState, MessageNode, and MessageTree classes. import asyncio import logging +from collections import deque +from contextlib import asynccontextmanager from enum import Enum from datetime import datetime, timezone from typing import Dict, Optional, List, Any @@ -276,6 +278,42 @@ class MessageTree: """Get number of messages waiting in queue.""" return self._queue.qsize() + def remove_from_queue(self, node_id: str) -> bool: + """ + Remove node_id from the internal queue if present. + + Caller must hold the tree lock (e.g. via with_lock). + Returns True if node was removed, False if not in queue. + """ + queue_deque: deque = self._queue._queue # type: ignore[attr-defined] + if node_id not in queue_deque: + return False + self._queue._queue = deque(x for x in queue_deque if x != node_id) # type: ignore[attr-defined] + return True + + @asynccontextmanager + async def with_lock(self): + """Async context manager for tree lock. Use when multiple operations need atomicity.""" + async with self._lock: + yield + + def set_processing_state(self, node_id: Optional[str], is_processing: bool) -> None: + """Set processing state. Caller must hold lock for consistency with queue operations.""" + self._is_processing = is_processing + self._current_node_id = node_id if is_processing else None + + def clear_current_node(self) -> None: + """Clear the currently processing node ID. Caller must hold lock.""" + self._current_node_id = None + + def is_current_node(self, node_id: str) -> bool: + """Check if node_id is the currently processing node.""" + return self._current_node_id == node_id + + def put_queue_unlocked(self, node_id: str) -> None: + """Add node to queue. Caller must hold lock (e.g. via with_lock).""" + self._queue.put_nowait(node_id) + def cancel_current_task(self) -> bool: """Cancel the currently running task. Returns True if a task was cancelled.""" if self._current_task and not self._current_task.done(): diff --git a/messaging/tree_processor.py b/messaging/tree_processor.py index 71ee939..503176e 100644 --- a/messaging/tree_processor.py +++ b/messaging/tree_processor.py @@ -90,7 +90,7 @@ class TreeQueueProcessor: node.node_id, MessageState.ERROR, error_message=str(e) ) finally: - tree._current_node_id = None + tree.clear_current_node() # Check if there are more messages in the queue await self._process_next(tree, processor) @@ -102,16 +102,15 @@ class TreeQueueProcessor: """Process the next message in queue, if any.""" next_node_id = None node = None - async with tree._lock: + async with tree.with_lock(): next_node_id = await tree.dequeue() if not next_node_id: - # No more messages, mark tree as free - tree._is_processing = False + tree.set_processing_state(None, False) logger.debug(f"Tree {tree.root_id} queue empty, marking as free") return - tree._current_node_id = next_node_id + tree.set_processing_state(next_node_id, True) logger.info(f"Processing next queued node {next_node_id}") # Process next node (outside lock) @@ -143,17 +142,14 @@ class TreeQueueProcessor: Returns: True if queued, False if processing immediately """ - async with tree._lock: - if tree._is_processing: - # Tree is busy, queue the message - await tree._queue.put(node_id) - queue_size = tree._queue.qsize() + async with tree.with_lock(): + if tree.is_processing: + tree.put_queue_unlocked(node_id) + queue_size = tree.get_queue_size() logger.info(f"Queued node {node_id}, position {queue_size}") return True else: - # Tree is free, start processing - tree._is_processing = True - tree._current_node_id = node_id + tree.set_processing_state(node_id, True) # Process outside the lock node = tree.get_node(node_id) diff --git a/messaging/tree_queue.py b/messaging/tree_queue.py index eb59931..10dfdf6 100644 --- a/messaging/tree_queue.py +++ b/messaging/tree_queue.py @@ -6,7 +6,6 @@ Uses TreeRepository for data, TreeQueueProcessor for async logic. import asyncio import logging -from collections import deque from datetime import datetime, timezone from typing import Callable, Awaitable, List, Optional @@ -293,7 +292,7 @@ class TreeQueueManager: if not tree: return [] - async with tree._lock: + async with tree.with_lock(): node = tree.get_node(node_id) if not node: return [] @@ -301,18 +300,12 @@ class TreeQueueManager: if node.state in (MessageState.COMPLETED, MessageState.ERROR): return [] - # Cancel running task if this is the current node. - if tree._current_node_id == node_id: + if tree.is_current_node(node_id): self._processor.cancel_current(tree) - # Remove from queue if present (asyncio.Queue exposes its internal deque). try: - q = tree._queue._queue # type: ignore[attr-defined] - if q and node_id in q: - tree._queue._queue = deque(x for x in q if x != node_id) # type: ignore[attr-defined] + tree.remove_from_queue(node_id) except Exception: - # Best-effort: if we can't mutate the queue internals, the node will - # still be dequeued later and skipped due to state=ERROR. logger.debug( "Failed to remove node from queue; will rely on state=ERROR" ) diff --git a/tests/test_messaging.py b/tests/test_messaging.py index 98beeeb..1310c2d 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -65,145 +65,6 @@ class TestSessionStore: store = SessionStore(storage_path=str(tmp_path / "sessions.json")) assert store._sessions == {} - def test_save_and_get_session(self, tmp_path): - """Test saving and retrieving a session.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - - store.save_session( - session_id="sess_123", - chat_id="chat_456", - initial_msg_id="msg_789", - platform="telegram", - ) - - # Retrieve by message - found = store.get_session_by_msg("chat_456", "msg_789", "telegram") - assert found == "sess_123" - - # Verify persistence file created - assert os.path.exists(str(tmp_path / "sessions.json")) - - def test_update_last_message(self, tmp_path): - """Test updating last message in session.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - - store.save_session("sess_1", "chat_1", "msg_1", "telegram") - store.update_last_message("sess_1", "msg_2") - - # Should find session by new message too - found = store.get_session_by_msg("chat_1", "msg_2", "telegram") - assert found == "sess_1" - - # Original message mapping should still work - found_old = store.get_session_by_msg("chat_1", "msg_1", "telegram") - assert found_old == "sess_1" - - # Verify record updated - record = store.get_session_record("sess_1") - assert record is not None - assert record.last_msg_id == "msg_2" - - def test_update_last_message_unknown_session(self, tmp_path): - """Test updating unknown session does nothing.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - store.update_last_message("unknown", "msg_x") - # Should log warning but not crash - - def test_get_session_record(self, tmp_path): - """Test getting full session record.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - store.save_session("sess_1", "chat_1", "msg_1", "telegram") - - record = store.get_session_record("sess_1") - assert record is not None - assert record.session_id == "sess_1" - assert record.platform == "telegram" - - def test_session_not_found(self, tmp_path): - """Test getting non-existent session returns None.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - - found = store.get_session_by_msg("notexist", "notexist", "telegram") - assert found is None - - def test_rename_session(self, tmp_path): - """Test renaming a session.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - store.save_session("old_id", "c1", "m1", "telegram") - store.update_last_message("old_id", "m2") - - success = store.rename_session("old_id", "new_id") - assert success is True - - # Verify old id gone - assert store.get_session_record("old_id") is None - - # Verify new id exists - rec = store.get_session_record("new_id") - assert rec is not None - assert rec.session_id == "new_id" - - # Verify mappings point to new id - assert store.get_session_by_msg("c1", "m1", "telegram") == "new_id" - assert store.get_session_by_msg("c1", "m2", "telegram") == "new_id" - - def test_rename_unknown_session(self, tmp_path): - """Test renaming unknown session fails.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - success = store.rename_session("unknown", "new") - assert success is False - - def test_cleanup_old_sessions(self, tmp_path): - """Test cleaning up expired sessions.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - - # Create an old session manually - old_date = (datetime.now(timezone.utc) - timedelta(days=40)).isoformat() - store.save_session("old_sess", "c_old", "m_old") - # Manipulate the created_at directly - store._sessions["old_sess"].created_at = old_date - - # Create a new session - store.save_session("new_sess", "c_new", "m_new") - - # Cleanup - removed = store.cleanup_old_sessions(max_age_days=30) - assert removed == 1 - - assert store.get_session_record("old_sess") is None - assert store.get_session_by_msg("c_old", "m_old") is None - assert store.get_session_record("new_sess") is not None - - def test_cleanup_old_sessions_invalid_date(self, tmp_path): - """Test cleanup handles invalid date formats gracefully.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - store.save_session("bad_date_sess", "c", "m") - store._sessions["bad_date_sess"].created_at = "not-a-date" - - # Should not crash - store.cleanup_old_sessions(30) - # Should still exist because parsing failed so it wasn't removed (or default behavior) - # The code tries parsing, excepts, and continues, so it isn't removed. - assert store.get_session_record("bad_date_sess") is not None - # --- Tree Tests --- def test_save_and_get_tree(self, tmp_path): @@ -225,34 +86,6 @@ class TestSessionStore: assert store.get_tree_root_for_node("r1") == "r1" assert store.get_tree_root_for_node("n1") == "r1" - # Verify get_tree_by_node - assert store.get_tree_by_node("n1") == tree_data - assert store.get_tree_by_node("unknown") is None - - def test_update_tree_node(self, tmp_path): - """Test updating a specific node in a tree.""" - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - - store.save_tree("r1", {"nodes": {"r1": {}}}) - - # Add new node - store.update_tree_node("r1", "n2", {"data": "test"}) - - tree = store.get_tree("r1") - assert tree is not None - assert "n2" in tree["nodes"] - assert tree["nodes"]["n2"]["data"] == "test" - assert store.get_tree_root_for_node("n2") == "r1" - - def test_update_tree_node_unknown_tree(self, tmp_path): - from messaging.session import SessionStore - - store = SessionStore(storage_path=str(tmp_path / "sessions.json")) - store.update_tree_node("unknown_root", "n1", {}) - # Should not crash - def test_register_node(self, tmp_path): """Test manual node registration.""" from messaging.session import SessionStore @@ -296,7 +129,7 @@ class TestSessionStore: # --- Persistence & Edge Cases --- def test_load_existing_legacy_format(self, tmp_path): - """Test loading legacy session format (int IDs).""" + """Test loading legacy session format (int IDs) - backward compat.""" from messaging.session import SessionStore data = { @@ -318,12 +151,11 @@ class TestSessionStore: json.dump(data, f) store = SessionStore(storage_path=str(p)) - rec = store.get_session_record("s1") - assert rec is not None - + # Legacy sessions are loaded for backward compat; verify conversion + assert "s1" in store._sessions + rec = store._sessions["s1"] assert rec.chat_id == "123" # Converted to str assert rec.platform == "telegram" # Defaulted - assert store.get_session_by_msg("123", "100", "telegram") == "s1" def test_load_corrupt_file(self, tmp_path): """Test loading corrupt/invalid json file.""" @@ -342,13 +174,14 @@ class TestSessionStore: from messaging.session import SessionStore store = SessionStore(storage_path=str(tmp_path / "sessions.json")) + store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}}) # Mock open to raise exception with patch("builtins.open", side_effect=IOError("Disk full")): - store.save_session("s1", "c1", "m1") + store.save_tree("r2", {"root_id": "r2", "nodes": {"r2": {}}}) - # Should log error but not crash. Session should be in memory though. - assert "s1" in store._sessions + # Should log error but not crash. Tree should be in memory. + assert "r2" in store._trees class TestTreeQueueManager: diff --git a/tests/test_request_utils.py b/tests/test_request_utils.py index 8093814..6df9db9 100644 --- a/tests/test_request_utils.py +++ b/tests/test_request_utils.py @@ -3,13 +3,13 @@ import pytest from unittest.mock import MagicMock -from api.request_utils import ( +from api.detection import ( is_quota_check_request, is_title_generation_request, - extract_command_prefix, is_prefix_detection_request, - get_token_count, ) +from api.command_utils import extract_command_prefix +from api.request_utils import get_token_count from api.models.anthropic import MessagesRequest, Message diff --git a/tests/test_request_utils_filepaths_and_suggestions.py b/tests/test_request_utils_filepaths_and_suggestions.py index c3302f3..4a916de 100644 --- a/tests/test_request_utils_filepaths_and_suggestions.py +++ b/tests/test_request_utils_filepaths_and_suggestions.py @@ -2,11 +2,11 @@ import pytest from unittest.mock import MagicMock from api.models.anthropic import MessagesRequest, Message -from api.request_utils import ( +from api.detection import ( is_suggestion_mode_request, is_filepath_extraction_request, - extract_filepaths_from_command, ) +from api.command_utils import extract_filepaths_from_command def _mk_req(messages, tools=None): diff --git a/tests/test_session_store_edge_cases.py b/tests/test_session_store_edge_cases.py index d190291..e1cb590 100644 --- a/tests/test_session_store_edge_cases.py +++ b/tests/test_session_store_edge_cases.py @@ -85,44 +85,17 @@ class TestSessionStoreSaveEdgeCases: def test_save_io_error_handled(self, tmp_store): """Write failure in _save() is logged but doesn't raise.""" - tmp_store.save_session("s1", "c1", "m1") - # Make the path read-only dir to trigger write error + tmp_store.save_tree("r1", {"root_id": "r1", "nodes": {"r1": {}}}) with patch("builtins.open", side_effect=IOError("disk full")): - # _save() catches all exceptions and logs them tmp_store._save() # Should not raise -class TestSessionStoreOperationEdgeCases: - """Tests for edge cases in session operations.""" - - def test_rename_session_not_found(self, tmp_store): - """rename_session with non-existent old_id returns False.""" - result = tmp_store.rename_session("nonexistent", "new_id") - assert result is False - - def test_update_last_message_not_found(self, tmp_store): - """update_last_message with unknown session_id logs warning.""" - # Should not raise - tmp_store.update_last_message("nonexistent", "msg_1") - - def test_get_session_by_msg_not_found(self, tmp_store): - """Looking up non-existent message returns None.""" - result = tmp_store.get_session_by_msg("c1", "m999") - assert result is None - - def test_get_session_record_not_found(self, tmp_store): - """Getting non-existent session record returns None.""" - result = tmp_store.get_session_record("nonexistent") - assert result is None - - class TestSessionStoreClearAll: def test_clear_all_wipes_state_and_persists(self, tmp_path): path = str(tmp_path / "sessions.json") store = SessionStore(storage_path=path) - store.save_session("s1", "c1", "m1", platform="telegram") store.save_tree( "root1", { @@ -154,7 +127,6 @@ class TestSessionStoreClearAll: store.clear_all() - assert store.get_session_by_msg("c1", "m1", "telegram") is None assert store.get_all_trees() == {} assert store.get_node_mapping() == {} @@ -187,15 +159,6 @@ class TestSessionStoreClearAll: class TestSessionStoreCleanupEdgeCases: """Tests for cleanup with malformed data.""" - def test_cleanup_sessions_malformed_timestamp(self, tmp_store): - """Malformed created_at in cleanup doesn't crash.""" - tmp_store.save_session("s1", "c1", "m1") - # Corrupt the timestamp - tmp_store._sessions["s1"].created_at = "not-a-date" - # Should not crash - the except block silently skips bad records - removed = tmp_store.cleanup_old_sessions(max_age_days=0) - assert removed == 0 # Skipped due to parse error - def test_cleanup_trees_malformed_timestamp(self, tmp_store): """Malformed created_at in cleanup_old_trees doesn't crash.""" tmp_store._trees["root1"] = {"nodes": {"root1": {"created_at": "bad-date"}}} @@ -207,9 +170,3 @@ class TestSessionStoreCleanupEdgeCases: tmp_store._trees["root1"] = {"nodes": {"root1": {}}} removed = tmp_store.cleanup_old_trees(max_age_days=0) assert removed == 0 - - def test_update_tree_node_nonexistent_tree(self, tmp_store): - """Updating a node in a nonexistent tree logs warning.""" - tmp_store.update_tree_node("nonexistent", "node1", {"data": "test"}) - # Should not crash, tree not created - assert "nonexistent" not in tmp_store._trees diff --git a/tests/test_tree_queue.py b/tests/test_tree_queue.py index 1aaebd4..0272972 100644 --- a/tests/test_tree_queue.py +++ b/tests/test_tree_queue.py @@ -468,8 +468,8 @@ class TestSessionStoreTrees: assert retrieved is not None assert retrieved["root_id"] == "root_1" - def test_get_tree_by_node(self, tmp_path): - """Test getting tree by node ID.""" + def test_get_tree_by_root_id(self, tmp_path): + """Test getting tree by root ID and node mapping.""" from messaging.session import SessionStore store = SessionStore(storage_path=str(tmp_path / "sessions.json")) @@ -484,10 +484,10 @@ class TestSessionStoreTrees: store.save_tree("root", tree_data) - # Should find tree by child node - retrieved = store.get_tree_by_node("child") + retrieved = store.get_tree("root") assert retrieved is not None assert retrieved["root_id"] == "root" + assert store.get_tree_root_for_node("child") == "root" def test_register_node(self, tmp_path): """Test registering a node to a tree."""