free-claude-code/messaging/trees/data.py

482 lines
16 KiB
Python

"""Tree data structures for message queue.
Contains MessageState, MessageNode, and MessageTree classes.
"""
import asyncio
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from loguru import logger
from ..models import IncomingMessage
class _SnapshotQueue:
    """FIFO queue of node IDs with snapshot and removal helpers.

    Backed by a ``deque`` for ordering plus a per-item occurrence count, so
    membership checks in ``remove_if_present`` are O(1). Mirrors the subset of
    the ``asyncio.Queue`` API the tree uses (``put``, ``put_nowait``,
    ``get_nowait``, ``qsize``); ``get_nowait`` raises ``asyncio.QueueEmpty``
    when empty. Has no internal locking.
    """

    def __init__(self) -> None:
        self._deque: deque[str] = deque()
        # item -> number of copies currently in ``_deque``. A plain set loses
        # track of an item that was enqueued twice once one copy is popped.
        self._counts: dict[str, int] = {}

    def _push(self, item: str) -> None:
        """Append ``item`` to the tail and bump its occurrence count."""
        self._deque.append(item)
        self._counts[item] = self._counts.get(item, 0) + 1

    async def put(self, item: str) -> None:
        """Append ``item``; never blocks (async only for ``asyncio.Queue`` parity)."""
        self._push(item)

    def put_nowait(self, item: str) -> None:
        """Append ``item`` synchronously."""
        self._push(item)

    def get_nowait(self) -> str:
        """Pop and return the oldest item.

        Raises:
            asyncio.QueueEmpty: if the queue is empty.
        """
        if not self._deque:
            raise asyncio.QueueEmpty()
        item = self._deque.popleft()
        remaining = self._counts[item] - 1
        if remaining:
            self._counts[item] = remaining
        else:
            del self._counts[item]
        return item

    def qsize(self) -> int:
        """Return the number of items currently queued."""
        return len(self._deque)

    def get_snapshot(self) -> list[str]:
        """Return current queue contents in FIFO order (read-only copy)."""
        return list(self._deque)

    def remove_if_present(self, item: str) -> bool:
        """Remove every occurrence of ``item`` from the queue.

        Membership is checked in O(1); the removal itself rebuilds the deque.

        Returns:
            True if at least one occurrence was removed, False otherwise.
        """
        if self._counts.pop(item, 0) == 0:
            return False
        self._deque = deque(x for x in self._deque if x != item)
        return True
class MessageState(Enum):
    """Lifecycle stage of a message node.

    The string values are what ``MessageNode`` writes to and reads from its
    serialized form.
    """

    # Sitting in the queue until its turn comes.
    PENDING = "pending"
    # Claude is working on it right now.
    IN_PROGRESS = "in_progress"
    # Finished without error.
    COMPLETED = "completed"
    # Finished with a failure.
    ERROR = "error"
@dataclass
class MessageNode:
    """
    A node in the message tree.

    Each node represents a single message and tracks:
    - Its relationship to parent/children (``parent_id`` / ``children_ids``)
    - Its processing state (``state``, ``completed_at``, ``error_message``)
    - Claude session information (``session_id``)

    ``context`` is an in-memory slot only; ``to_dict`` does not write it.
    """

    node_id: str  # Unique ID (typically message_id)
    incoming: IncomingMessage  # The original message
    status_message_id: str  # Bot's status message ID
    state: MessageState = MessageState.PENDING
    parent_id: str | None = None  # Parent node ID (None for root)
    session_id: str | None = None  # Claude session ID (forked from parent)
    children_ids: list[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    completed_at: datetime | None = None
    error_message: str | None = None
    context: Any = None  # Additional context if needed; not serialized

    def set_context(self, context: Any) -> None:
        """Attach arbitrary in-memory context to this node."""
        self.context = context

    def to_dict(self) -> dict:
        """Convert to a dictionary for JSON serialization.

        The result is a snapshot: mutable members (``children_ids``) are
        copied, so later changes to the node do not leak into it.
        Timestamps are emitted as ISO-8601 strings; ``context`` is omitted.
        """
        incoming = self.incoming
        return {
            "node_id": self.node_id,
            "incoming": {
                "text": incoming.text,
                "chat_id": incoming.chat_id,
                "user_id": incoming.user_id,
                "message_id": incoming.message_id,
                "platform": incoming.platform,
                "reply_to_message_id": incoming.reply_to_message_id,
                "message_thread_id": incoming.message_thread_id,
                "username": incoming.username,
            },
            "status_message_id": self.status_message_id,
            "state": self.state.value,
            "parent_id": self.parent_id,
            "session_id": self.session_id,
            # Copy: returning the live list would alias node and snapshot.
            "children_ids": list(self.children_ids),
            "created_at": self.created_at.isoformat(),
            "completed_at": (
                self.completed_at.isoformat()
                if self.completed_at is not None
                else None
            ),
            "error_message": self.error_message,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "MessageNode":
        """Create a node from a dictionary produced by ``to_dict``.

        Optional keys (``parent_id``, ``session_id``, ``children_ids``,
        ``completed_at``, ``error_message`` and the optional incoming fields)
        fall back to ``None`` / an empty list when absent.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if ``state`` is not a valid ``MessageState`` value or a
                timestamp is not a valid ISO-8601 string.
        """
        incoming_data = data["incoming"]
        incoming = IncomingMessage(
            text=incoming_data["text"],
            chat_id=incoming_data["chat_id"],
            user_id=incoming_data["user_id"],
            message_id=incoming_data["message_id"],
            platform=incoming_data["platform"],
            reply_to_message_id=incoming_data.get("reply_to_message_id"),
            message_thread_id=incoming_data.get("message_thread_id"),
            username=incoming_data.get("username"),
        )
        completed_raw = data.get("completed_at")
        return cls(
            node_id=data["node_id"],
            incoming=incoming,
            status_message_id=data["status_message_id"],
            state=MessageState(data["state"]),
            parent_id=data.get("parent_id"),
            session_id=data.get("session_id"),
            # Copy so the node does not share (and later mutate) the caller's
            # list; tolerate an explicit null in the payload.
            children_ids=list(data.get("children_ids") or []),
            created_at=datetime.fromisoformat(data["created_at"]),
            completed_at=(
                datetime.fromisoformat(completed_raw) if completed_raw else None
            ),
            error_message=data.get("error_message"),
        )
class MessageTree:
    """
    A tree of message nodes with queue functionality.

    Provides:
    - O(1) node lookup via hashmap, plus an O(1) status-message -> node index
    - Per-tree FIFO message queue
    - Coroutine-level mutual exclusion via an ``asyncio.Lock``

    The lock serializes coroutines on a single event loop. It does not make
    the tree safe to share across OS threads, and it is not re-entrant.
    Methods documented as "Caller must hold lock" do not acquire it
    themselves; call them inside ``with_lock()``.
    """

    def __init__(self, root_node: MessageNode):
        """
        Initialize tree with a root node.

        Args:
            root_node: The root message node
        """
        self.root_id = root_node.node_id
        self._nodes: dict[str, MessageNode] = {root_node.node_id: root_node}
        # status_message_id -> node_id, for reverse lookup from bot replies.
        self._status_to_node: dict[str, str] = {
            root_node.status_message_id: root_node.node_id
        }
        self._queue: _SnapshotQueue = _SnapshotQueue()
        self._lock = asyncio.Lock()
        self._is_processing = False
        self._current_node_id: str | None = None
        self._current_task: asyncio.Task | None = None
        logger.debug(f"Created MessageTree with root {self.root_id}")

    def set_current_task(self, task: asyncio.Task | None) -> None:
        """Set the current processing task. Caller must hold lock."""
        self._current_task = task

    @property
    def is_processing(self) -> bool:
        """Check if tree is currently processing a message."""
        return self._is_processing

    async def add_node(
        self,
        node_id: str,
        incoming: IncomingMessage,
        status_message_id: str,
        parent_id: str,
    ) -> MessageNode:
        """
        Add a PENDING child node to the tree (acquires the tree lock).

        Args:
            node_id: Unique ID for the new node
            incoming: The incoming message
            status_message_id: Bot's status message ID
            parent_id: Parent node ID

        Returns:
            The created MessageNode

        Raises:
            ValueError: if ``parent_id`` is not in this tree.
        """
        async with self._lock:
            if parent_id not in self._nodes:
                raise ValueError(f"Parent node {parent_id} not found in tree")
            node = MessageNode(
                node_id=node_id,
                incoming=incoming,
                status_message_id=status_message_id,
                parent_id=parent_id,
                state=MessageState.PENDING,
            )
            self._nodes[node_id] = node
            self._status_to_node[status_message_id] = node_id
            self._nodes[parent_id].children_ids.append(node_id)
            logger.debug(f"Added node {node_id} as child of {parent_id}")
            return node

    def get_node(self, node_id: str) -> MessageNode | None:
        """Get a node by ID (O(1) lookup)."""
        return self._nodes.get(node_id)

    def get_root(self) -> MessageNode:
        """Get the root node."""
        return self._nodes[self.root_id]

    def get_children(self, node_id: str) -> list[MessageNode]:
        """Get all child nodes of a given node (unknown child IDs are skipped)."""
        node = self._nodes.get(node_id)
        if not node:
            return []
        return [self._nodes[cid] for cid in node.children_ids if cid in self._nodes]

    def get_parent(self, node_id: str) -> MessageNode | None:
        """Get the parent node, or None for the root / unknown nodes."""
        node = self._nodes.get(node_id)
        if not node or not node.parent_id:
            return None
        return self._nodes.get(node.parent_id)

    def get_parent_session_id(self, node_id: str) -> str | None:
        """
        Get the parent's session ID for forking.

        Returns None for root nodes.
        """
        parent = self.get_parent(node_id)
        return parent.session_id if parent else None

    async def update_state(
        self,
        node_id: str,
        state: MessageState,
        session_id: str | None = None,
        error_message: str | None = None,
    ) -> None:
        """Update a node's state (acquires the tree lock).

        ``session_id`` / ``error_message`` are only written when truthy.
        Entering COMPLETED or ERROR stamps ``completed_at``.
        """
        async with self._lock:
            node = self._nodes.get(node_id)
            if not node:
                logger.warning(f"Node {node_id} not found for state update")
                return
            node.state = state
            if session_id:
                node.session_id = session_id
            if error_message:
                node.error_message = error_message
            if state in (MessageState.COMPLETED, MessageState.ERROR):
                node.completed_at = datetime.now(UTC)
            logger.debug(f"Node {node_id} state -> {state.value}")

    async def enqueue(self, node_id: str) -> int:
        """
        Add a node to the processing queue (acquires the tree lock).

        Returns:
            Queue position (1-indexed)
        """
        async with self._lock:
            await self._queue.put(node_id)
            position = self._queue.qsize()
            logger.debug(f"Enqueued node {node_id}, position {position}")
            return position

    async def dequeue(self) -> str | None:
        """
        Get the next node ID from the queue.

        Returns None if queue is empty. Does not acquire the tree lock.
        """
        try:
            return self._queue.get_nowait()
        except asyncio.QueueEmpty:
            return None

    async def get_queue_snapshot(self) -> list[str]:
        """
        Get a snapshot of the current queue order (acquires the tree lock).

        Returns:
            List of node IDs in FIFO order.
        """
        async with self._lock:
            return self._queue.get_snapshot()

    def get_queue_size(self) -> int:
        """Get number of messages waiting in queue."""
        return self._queue.qsize()

    def remove_from_queue(self, node_id: str) -> bool:
        """
        Remove node_id from the internal queue if present.

        Caller must hold the tree lock (e.g. via with_lock).
        Returns True if node was removed, False if not in queue.
        """
        return self._queue.remove_if_present(node_id)

    @asynccontextmanager
    async def with_lock(self):
        """Async context manager for tree lock. Use when multiple operations need atomicity."""
        async with self._lock:
            yield

    def set_processing_state(self, node_id: str | None, is_processing: bool) -> None:
        """Set processing state. Caller must hold lock for consistency with queue operations."""
        self._is_processing = is_processing
        # The current node is only meaningful while processing.
        self._current_node_id = node_id if is_processing else None

    def clear_current_node(self) -> None:
        """Clear the currently processing node ID. Caller must hold lock."""
        self._current_node_id = None

    def is_current_node(self, node_id: str) -> bool:
        """Check if node_id is the currently processing node."""
        return self._current_node_id == node_id

    def put_queue_unlocked(self, node_id: str) -> None:
        """Add node to queue. Caller must hold lock (e.g. via with_lock)."""
        self._queue.put_nowait(node_id)

    def cancel_current_task(self) -> bool:
        """Cancel the currently running task. Returns True if a task was cancelled."""
        if self._current_task and not self._current_task.done():
            self._current_task.cancel()
            return True
        return False

    def set_node_error_sync(self, node: MessageNode, error_message: str) -> None:
        """Synchronously mark a node as ERROR. Caller must ensure no concurrent access."""
        node.state = MessageState.ERROR
        node.error_message = error_message
        node.completed_at = datetime.now(UTC)

    def drain_queue_and_mark_cancelled(
        self, error_message: str = "Cancelled by user"
    ) -> list[MessageNode]:
        """
        Drain the queue, mark each node as ERROR, and return affected nodes.

        Queued IDs with no matching node are dropped silently.
        Does not acquire lock; caller must ensure no concurrent queue access.
        """
        nodes: list[MessageNode] = []
        while True:
            try:
                node_id = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            node = self._nodes.get(node_id)
            if node:
                self.set_node_error_sync(node, error_message)
                nodes.append(node)
        return nodes

    def reset_processing_state(self) -> None:
        """Reset processing flags after cancel/cleanup."""
        self._is_processing = False
        self._current_node_id = None

    @property
    def current_node_id(self) -> str | None:
        """Get the ID of the node currently being processed."""
        return self._current_node_id

    def to_dict(self) -> dict:
        """Serialize the tree's nodes to a dictionary.

        Queue contents and processing state are not included.
        """
        return {
            "root_id": self.root_id,
            "nodes": {nid: node.to_dict() for nid, node in self._nodes.items()},
        }

    def _add_node_from_dict(self, node: MessageNode) -> None:
        """Register a deserialized node into the tree's internal indices."""
        self._nodes[node.node_id] = node
        self._status_to_node[node.status_message_id] = node.node_id

    @classmethod
    def from_dict(cls, data: dict) -> "MessageTree":
        """Deserialize a tree from a dictionary produced by ``to_dict``.

        The restored tree starts with an empty queue and no processing state.

        Raises:
            KeyError: if ``root_id`` / ``nodes`` are missing or the root's
                entry is absent from ``nodes``.
        """
        root_id = data["root_id"]
        nodes_data = data["nodes"]
        # Create root node first
        root_node = MessageNode.from_dict(nodes_data[root_id])
        tree = cls(root_node)
        # Add remaining nodes and build status->node index
        for node_id, node_data in nodes_data.items():
            if node_id != root_id:
                tree._add_node_from_dict(MessageNode.from_dict(node_data))
        return tree

    def all_nodes(self) -> list[MessageNode]:
        """Get all nodes in the tree."""
        return list(self._nodes.values())

    def has_node(self, node_id: str) -> bool:
        """Check if a node exists in this tree."""
        return node_id in self._nodes

    def find_node_by_status_message(self, status_msg_id: str) -> MessageNode | None:
        """Find the node that has this status message ID (O(1) lookup)."""
        node_id = self._status_to_node.get(status_msg_id)
        return self._nodes.get(node_id) if node_id else None

    def get_descendants(self, node_id: str) -> list[str]:
        """
        Get node_id and all descendant IDs (subtree), depth-first.

        Each ID appears at most once; a visited set guards against cycles or
        duplicate child entries (e.g. from corrupted serialized data), which
        would otherwise loop forever or repeat IDs.

        Returns:
            List of node IDs including the given node ([] if unknown).
        """
        if node_id not in self._nodes:
            return []
        result: list[str] = []
        seen: set[str] = set()
        stack = [node_id]
        while stack:
            nid = stack.pop()
            if nid in seen:
                continue
            seen.add(nid)
            result.append(nid)
            node = self._nodes.get(nid)
            if node:
                stack.extend(node.children_ids)
        return result

    def remove_branch(self, branch_root_id: str) -> list[MessageNode]:
        """
        Remove a subtree (branch_root and all descendants) from the tree.

        Updates parent's children_ids. Caller must hold lock for consistency.
        Does not acquire lock internally, and does not touch the queue or the
        current-node / current-task bookkeeping.

        Returns:
            List of removed nodes.
        """
        if branch_root_id not in self._nodes:
            return []
        parent = self.get_parent(branch_root_id)
        removed: list[MessageNode] = []
        for nid in self.get_descendants(branch_root_id):
            node = self._nodes.pop(nid, None)
            if node is None:
                continue
            removed.append(node)
            # Only drop the index entry if it still points at this node: a
            # reused status_message_id may now belong to a surviving node, and
            # an unconditional ``del`` could raise mid-loop and leave the
            # branch half-removed.
            if self._status_to_node.get(node.status_message_id) == nid:
                del self._status_to_node[node.status_message_id]
        if parent and branch_root_id in parent.children_ids:
            parent.children_ids = [
                c for c in parent.children_ids if c != branch_root_id
            ]
        logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
        return removed