free-claude-code/messaging/tree_queue.py
Cursor Agent 959d1bdf5b Phase 5: Encapsulation - Repository and MessageTree API
TreeRepository:
- Add has_node(node_id) -> bool
- Add tree_count() -> int

MessageTree:
- Add cancel_current_task() -> bool
- Add drain_queue_and_mark_cancelled() -> List[MessageNode]
- Add reset_processing_state()
- Add current_node_id property

TreeQueueManager:
- Remove _trees and _node_to_tree properties
- add_to_tree: use has_node() and get_tree_for_node()
- get_tree_count: use repository.tree_count()
- cancel_tree: use tree.cancel_current_task(), drain_queue_and_mark_cancelled(), reset_processing_state()

TreeQueueProcessor:
- cancel_current: delegate to tree.cancel_current_task()

Tests: update to use get_tree_count(), has_node(), get_tree_for_node()

Co-authored-by: Ali Khokhar <alishahryar2@gmail.com>
2026-02-15 01:38:35 +00:00

400 lines
14 KiB
Python

"""Tree-Based Message Queue Manager - Refactored.
Coordinates data access, async processing, and error handling.
Uses TreeRepository for data, TreeQueueProcessor for async logic.
"""
import asyncio
import logging
from collections import deque
from datetime import datetime, timezone
from typing import Callable, Awaitable, List, Optional
from .models import IncomingMessage
from .tree_data import MessageState, MessageNode, MessageTree
from .tree_repository import TreeRepository
from .tree_processor import TreeQueueProcessor
# Backward compatibility: re-export moved classes.
# MessageState, MessageNode and MessageTree are imported from .tree_data above;
# listing them in __all__ declares them part of this module's public API so
# code that still imports them from here keeps working.
__all__ = [
"TreeQueueManager",
"MessageState",
"MessageNode",
"MessageTree",
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class TreeQueueManager:
    """
    Manages multiple message trees. Facade that coordinates components.

    Each new conversation creates a new tree.
    Replies to existing messages add nodes to existing trees.

    Components:
        - TreeRepository: Data access layer
        - TreeQueueProcessor: Async queue processing
    """

    def __init__(
        self,
        queue_update_callback: Optional[
            Callable[[MessageTree], Awaitable[None]]
        ] = None,
        node_started_callback: Optional[
            Callable[[MessageTree, str], Awaitable[None]]
        ] = None,
    ):
        """
        Create an empty manager.

        Args:
            queue_update_callback: Optional async callback, forwarded to the
                TreeQueueProcessor.
            node_started_callback: Optional async callback, forwarded to the
                TreeQueueProcessor.
        """
        self._repository = TreeRepository()
        self._processor = TreeQueueProcessor(
            queue_update_callback=queue_update_callback,
            node_started_callback=node_started_callback,
        )
        # Guards repository mutations made through this manager.
        self._lock = asyncio.Lock()
        logger.info("TreeQueueManager initialized")

    @staticmethod
    def _mark_cancelled(node: MessageNode) -> None:
        """Mark *node* as ERROR / "Cancelled by user" and stamp completion time."""
        node.state = MessageState.ERROR
        node.error_message = "Cancelled by user"
        node.completed_at = datetime.now(timezone.utc)

    async def create_tree(
        self,
        node_id: str,
        incoming: IncomingMessage,
        status_message_id: str,
    ) -> MessageTree:
        """
        Create a new tree with a root node.

        Args:
            node_id: ID for the root node
            incoming: The incoming message
            status_message_id: Bot's status message ID

        Returns:
            The created MessageTree
        """
        async with self._lock:
            root_node = MessageNode(
                node_id=node_id,
                incoming=incoming,
                status_message_id=status_message_id,
                state=MessageState.PENDING,
            )
            tree = MessageTree(root_node)
            # The tree is stored under its root node's ID.
            self._repository.add_tree(node_id, tree)
            logger.info(f"Created new tree with root {node_id}")
            return tree

    async def add_to_tree(
        self,
        parent_node_id: str,
        node_id: str,
        incoming: IncomingMessage,
        status_message_id: str,
    ) -> tuple[MessageTree, MessageNode]:
        """
        Add a reply as a child node to an existing tree.

        Args:
            parent_node_id: ID of the parent message
            node_id: ID for the new node
            incoming: The incoming reply message
            status_message_id: Bot's status message ID

        Returns:
            Tuple of (tree, new_node)

        Raises:
            ValueError: If the parent node is not registered in any tree.
        """
        async with self._lock:
            if not self._repository.has_node(parent_node_id):
                raise ValueError(f"Parent node {parent_node_id} not found in any tree")
            tree = self._repository.get_tree_for_node(parent_node_id)
            if tree is None:
                raise ValueError(f"Parent node {parent_node_id} not found in any tree")

        # Add node (tree has its own lock) - outside manager lock to avoid deadlock
        node = await tree.add_node(
            node_id=node_id,
            incoming=incoming,
            status_message_id=status_message_id,
            parent_id=parent_node_id,
        )

        # Re-acquire the manager lock only for the repository mutation.
        async with self._lock:
            self._repository.register_node(node_id, tree.root_id)

        logger.info(f"Added node {node_id} to tree {tree.root_id}")
        return tree, node

    def get_tree(self, root_id: str) -> Optional[MessageTree]:
        """Get a tree by its root ID."""
        return self._repository.get_tree(root_id)

    def get_tree_for_node(self, node_id: str) -> Optional[MessageTree]:
        """Get the tree containing a given node."""
        return self._repository.get_tree_for_node(node_id)

    def get_node(self, node_id: str) -> Optional[MessageNode]:
        """Get a node from any tree."""
        return self._repository.get_node(node_id)

    def resolve_parent_node_id(self, msg_id: str) -> Optional[str]:
        """Resolve a message ID to the actual parent node ID."""
        return self._repository.resolve_parent_node_id(msg_id)

    def is_tree_busy(self, root_id: str) -> bool:
        """Check if a tree is currently processing."""
        return self._repository.is_tree_busy(root_id)

    def is_node_tree_busy(self, node_id: str) -> bool:
        """Check if the tree containing a node is busy."""
        return self._repository.is_node_tree_busy(node_id)

    async def enqueue(
        self,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """
        Enqueue a node for processing.

        If the tree is not busy, processing starts immediately.
        If busy, the message is queued.

        Args:
            node_id: Node to process
            processor: Async function to process the node

        Returns:
            True if queued, False if processing immediately. False is also
            returned (after logging an error) when no tree contains
            ``node_id``; in that case nothing is processed.
        """
        tree = self._repository.get_tree_for_node(node_id)
        if tree is None:
            logger.error(f"No tree found for node {node_id}")
            return False

        return await self._processor.enqueue_and_start(tree, node_id, processor)

    def get_queue_size(self, node_id: str) -> int:
        """Get queue size for the tree containing a node."""
        return self._repository.get_queue_size(node_id)

    def get_pending_children(self, node_id: str) -> List[MessageNode]:
        """Get all pending child nodes (recursively) of a given node."""
        return self._repository.get_pending_children(node_id)

    async def mark_node_error(
        self,
        node_id: str,
        error_message: str,
        propagate_to_children: bool = True,
    ) -> List[MessageNode]:
        """
        Mark a node as ERROR and optionally propagate to pending children.

        Args:
            node_id: The node to mark as error
            error_message: Error description
            propagate_to_children: If True, also mark pending children as error

        Returns:
            List of all nodes marked as error (including children). Empty if
            no tree contains ``node_id``.
        """
        tree = self._repository.get_tree_for_node(node_id)
        if tree is None:
            return []

        affected: List[MessageNode] = []

        node = tree.get_node(node_id)
        if node is not None:
            await tree.update_state(
                node_id, MessageState.ERROR, error_message=error_message
            )
            affected.append(node)

        if propagate_to_children:
            for child in self._repository.get_pending_children(node_id):
                await tree.update_state(
                    child.node_id,
                    MessageState.ERROR,
                    error_message=f"Parent failed: {error_message}",
                )
                affected.append(child)

        return affected

    def cancel_tree(self, root_id: str) -> List[MessageNode]:
        """
        Cancel all queued and in-progress messages in a tree.

        Updates node states to ERROR and returns list of affected nodes
        that were actually active or in the current processing queue.
        Other PENDING / IN_PROGRESS nodes found in the tree are also marked
        ERROR ("Stale task cleaned up") but are not included in the result.

        Args:
            root_id: Root ID of the tree to cancel

        Returns:
            The cancelled running node (if any) followed by the nodes drained
            from the queue. Empty if the tree does not exist.
        """
        tree = self._repository.get_tree(root_id)
        if tree is None:
            return []

        cancelled_nodes: List[MessageNode] = []

        # 1. Cancel running task. Read the current node ID *before* cancelling
        #    so the node is still identified even if cancelling clears it.
        current_id = tree.current_node_id
        if tree.cancel_current_task() and current_id:
            node = tree.get_node(current_id)
            if node and node.state not in (
                MessageState.COMPLETED,
                MessageState.ERROR,
            ):
                # Same transition as cancel_node(), including completed_at.
                self._mark_cancelled(node)
                cancelled_nodes.append(node)

        # 2. Drain queue and mark nodes as cancelled
        cancelled_nodes.extend(tree.drain_queue_and_mark_cancelled())

        # 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR.
        #    Membership is tested against a set of IDs rather than scanning
        #    the list of nodes for every node in the tree.
        cancelled_ids = {cancelled.node_id for cancelled in cancelled_nodes}
        cleanup_count = 0
        for node in tree.all_nodes():
            if (
                node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
                and node.node_id not in cancelled_ids
            ):
                node.state = MessageState.ERROR
                node.error_message = "Stale task cleaned up"
                cleanup_count += 1

        tree.reset_processing_state()

        if cancelled_nodes:
            logger.info(
                f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}"
            )
        if cleanup_count:
            logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}")

        return cancelled_nodes

    async def cancel_node(self, node_id: str) -> List[MessageNode]:
        """
        Cancel a single node (queued or in-progress) without affecting other nodes.

        - If the node is currently running, cancels the current asyncio task.
        - If the node is queued, removes it from the queue.
        - Marks the node as ERROR with "Cancelled by user".

        Args:
            node_id: ID of the node to cancel

        Returns:
            List containing the cancelled node if it was cancellable, else empty list
            (unknown node, or node already COMPLETED / ERROR).
        """
        tree = self._repository.get_tree_for_node(node_id)
        if tree is None:
            return []

        # NOTE: relies on the tree's private lock and queue attributes.
        async with tree._lock:
            node = tree.get_node(node_id)
            if node is None:
                return []
            if node.state in (MessageState.COMPLETED, MessageState.ERROR):
                return []

            # Cancel running task if this is the current node.
            if tree.current_node_id == node_id:
                self._processor.cancel_current(tree)

            # Remove from queue if present (asyncio.Queue exposes its internal deque).
            try:
                pending = tree._queue._queue  # type: ignore[attr-defined]
                if pending and node_id in pending:
                    tree._queue._queue = deque(  # type: ignore[attr-defined]
                        queued for queued in pending if queued != node_id
                    )
            except Exception:
                # Best-effort: if we can't mutate the queue internals, the node will
                # still be dequeued later and skipped due to state=ERROR.
                logger.debug(
                    "Failed to remove node from queue; will rely on state=ERROR",
                    exc_info=True,
                )

            self._mark_cancelled(node)
            return [node]

    async def cancel_all(self) -> List[MessageNode]:
        """Cancel all messages in all trees (async wrapper)."""
        async with self._lock:
            return self.cancel_all_sync()

    def cancel_all_sync(self) -> List[MessageNode]:
        """
        Cancel all messages in all trees (synchronous/locked version).

        NOTE: Must be called with self._lock held.

        Returns:
            Concatenated results of cancel_tree() for every tree.
        """
        all_cancelled: List[MessageNode] = []
        # Iterate over a snapshot of the IDs.
        for root_id in list(self._repository.tree_ids()):
            all_cancelled.extend(self.cancel_tree(root_id))
        return all_cancelled

    def cleanup_stale_nodes(self) -> int:
        """
        Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR.

        Used on startup to reconcile restored state.

        Returns:
            Number of nodes that were marked as ERROR.
        """
        count = 0
        for tree in self._repository.all_trees():
            for node in tree.all_nodes():
                if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
                    node.state = MessageState.ERROR
                    node.error_message = "Lost during server restart"
                    count += 1
        if count:
            logger.info(f"Cleaned up {count} stale nodes during startup")
        return count

    def get_tree_count(self) -> int:
        """Get the number of active message trees."""
        return self._repository.tree_count()

    def set_queue_update_callback(
        self,
        queue_update_callback: Optional[Callable[[MessageTree], Awaitable[None]]],
    ) -> None:
        """Set callback for queue position updates."""
        self._processor.set_queue_update_callback(queue_update_callback)

    def set_node_started_callback(
        self,
        node_started_callback: Optional[Callable[[MessageTree, str], Awaitable[None]]],
    ) -> None:
        """Set callback for when a queued node starts processing."""
        self._processor.set_node_started_callback(node_started_callback)

    def register_node(self, node_id: str, root_id: str) -> None:
        """Register a node ID to a tree (for external mapping)."""
        self._repository.register_node(node_id, root_id)

    def to_dict(self) -> dict:
        """Serialize all trees."""
        return self._repository.to_dict()

    @classmethod
    def from_dict(
        cls,
        data: dict,
        queue_update_callback: Optional[
            Callable[[MessageTree], Awaitable[None]]
        ] = None,
        node_started_callback: Optional[
            Callable[[MessageTree, str], Awaitable[None]]
        ] = None,
    ) -> "TreeQueueManager":
        """
        Deserialize from dictionary.

        Args:
            data: Serialized trees, passed to TreeRepository.from_dict()
            queue_update_callback: Optional callback forwarded to the constructor
            node_started_callback: Optional callback forwarded to the constructor

        Returns:
            A new manager whose repository is rebuilt from ``data``.
        """
        manager = cls(
            queue_update_callback=queue_update_callback,
            node_started_callback=node_started_callback,
        )
        # Replace the empty repository created by __init__ with the restored one.
        manager._repository = TreeRepository.from_dict(data)
        return manager