free-claude-code/messaging/trees/processor.py
2026-03-01 04:28:22 -08:00

165 lines
5.7 KiB
Python

"""Async queue processor for message trees.
Handles the async processing lifecycle of tree nodes.
"""
import asyncio
from collections.abc import Awaitable, Callable
from loguru import logger
from providers.common import get_user_facing_error_message
from .data import MessageNode, MessageState, MessageTree
class TreeQueueProcessor:
    """Drive the async processing queue of a single message tree.

    The tree object owns the data (nodes, queue, processing flags, current
    task); this class only contains the asynchronous control flow that
    starts a node, reacts to its completion or failure, and moves on to the
    next queued node.
    """

    def __init__(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None = None,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]]
        | None = None,
    ):
        """Store the optional notification callbacks.

        Args:
            queue_update_callback: Awaited with the tree after a queued node
                has been started, so queue positions can be refreshed.
            node_started_callback: Awaited with the tree and the node id
                when a previously queued node starts processing.
        """
        self._queue_update_callback = queue_update_callback
        self._node_started_callback = node_started_callback

    def set_queue_update_callback(
        self,
        queue_update_callback: Callable[[MessageTree], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used to refresh queue positions."""
        self._queue_update_callback = queue_update_callback

    def set_node_started_callback(
        self,
        node_started_callback: Callable[[MessageTree, str], Awaitable[None]] | None,
    ) -> None:
        """Update the callback used when a queued node starts processing."""
        self._node_started_callback = node_started_callback

    async def _notify_queue_updated(self, tree: MessageTree) -> None:
        """Invoke the queue update callback if set.

        Callback failures are logged and swallowed so that a broken
        notification never interrupts queue processing.
        """
        if not self._queue_update_callback:
            return
        try:
            await self._queue_update_callback(tree)
        except Exception as e:
            logger.warning(f"Queue update callback failed: {e}")

    async def _notify_node_started(self, tree: MessageTree, node_id: str) -> None:
        """Invoke the node started callback if set.

        Callback failures are logged and swallowed so that a broken
        notification never interrupts queue processing.
        """
        if not self._node_started_callback:
            return
        try:
            await self._node_started_callback(tree, node_id)
        except Exception as e:
            logger.warning(f"Node started callback failed: {e}")

    async def process_node(
        self,
        tree: MessageTree,
        node: MessageNode,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Process a single node and then check the queue.

        Args:
            tree: The tree the node belongs to.
            node: The node to process.
            processor: Async function awaited with ``(node_id, node)``.

        Raises:
            asyncio.CancelledError: Re-raised when the task is cancelled
                while the processor is running.
        """
        # Skip if already in terminal state (e.g. from error propagation)
        if node.state == MessageState.ERROR:
            logger.info(
                f"Skipping node {node.node_id} as it is already in state {node.state}"
            )
            # Still need to check for next messages
            await self._process_next(tree, processor)
            return

        try:
            await processor(node.node_id, node)
        except asyncio.CancelledError:
            logger.info(f"Task for node {node.node_id} was cancelled")
            raise
        except Exception as e:
            # Any other failure is recorded on the node instead of being
            # propagated, so the rest of the queue can still be processed.
            logger.error(f"Error processing node {node.node_id}: {e}")
            await tree.update_state(
                node.node_id,
                MessageState.ERROR,
                error_message=get_user_facing_error_message(e),
            )
        finally:
            # Runs on success, failure and cancellation alike.
            async with tree.with_lock():
                tree.clear_current_node()
            # Check if there are more messages in the queue
            await self._process_next(tree, processor)

    async def _process_next(
        self,
        tree: MessageTree,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> None:
        """Start the next queued node, or mark the tree as free.

        Queued ids whose node cannot be found in the tree are skipped with a
        warning. Marking the tree as processing for such an id would leave
        it busy forever, because no task would ever run to release it.

        Args:
            tree: The tree whose queue is drained.
            processor: Async function passed on to ``process_node``.
        """
        async with tree.with_lock():
            while True:
                next_node_id = await tree.dequeue()
                if not next_node_id:
                    tree.set_processing_state(None, False)
                    logger.debug(f"Tree {tree.root_id} queue empty, marking as free")
                    return

                node = tree.get_node(next_node_id)
                if node:
                    break
                logger.warning(
                    f"Queued node {next_node_id} not found in tree "
                    f"{tree.root_id}, skipping"
                )

            tree.set_processing_state(next_node_id, True)
            logger.info(f"Processing next queued node {next_node_id}")
            tree.set_current_task(
                asyncio.create_task(self.process_node(tree, node, processor))
            )

        # Notify that this node has started processing and refresh queue positions.
        await self._notify_node_started(tree, next_node_id)
        await self._notify_queue_updated(tree)

    async def enqueue_and_start(
        self,
        tree: MessageTree,
        node_id: str,
        processor: Callable[[str, MessageNode], Awaitable[None]],
    ) -> bool:
        """Enqueue a node or start processing immediately.

        Args:
            tree: The message tree
            node_id: Node to process
            processor: Async function to process the node

        Returns:
            True if queued, False if processing immediately. False is also
            returned when the tree is idle but ``node_id`` is not found; in
            that case nothing is started and the tree stays free.
        """
        async with tree.with_lock():
            if tree.is_processing:
                tree.put_queue_unlocked(node_id)
                queue_size = tree.get_queue_size()
                logger.info(f"Queued node {node_id}, position {queue_size}")
                return True

            # Resolve the node before claiming the tree: claiming it for a
            # node that does not exist would block every later enqueue.
            node = tree.get_node(node_id)
            if not node:
                logger.warning(
                    f"Node {node_id} not found in tree {tree.root_id}, "
                    "nothing to process"
                )
                return False

            tree.set_processing_state(node_id, True)

        # Process outside the lock
        tree.set_current_task(
            asyncio.create_task(self.process_node(tree, node, processor))
        )
        return False

    def cancel_current(self, tree: MessageTree) -> bool:
        """Cancel the currently running task in a tree.

        Delegates to ``tree.cancel_current_task()`` and returns its result
        (presumably True when a task was actually cancelled; confirm against
        the tree implementation).
        """
        return tree.cancel_current_task()