Fixed encapsulation violations

This commit is contained in:
Alishahryar1 2026-03-01 04:28:22 -08:00
parent 302ee28585
commit 35a2760f6e
12 changed files with 129 additions and 96 deletions

View file

@ -109,13 +109,15 @@ async def lifespan(app: FastAPI):
logger.info(f"Restoring {len(saved_trees)} conversation trees...")
from messaging.trees.queue_manager import TreeQueueManager
message_handler.tree_queue = TreeQueueManager.from_dict(
{
"trees": saved_trees,
"node_to_tree": session_store.get_node_mapping(),
},
queue_update_callback=message_handler.update_queue_positions,
node_started_callback=message_handler.mark_node_processing,
message_handler.replace_tree_queue(
TreeQueueManager.from_dict(
{
"trees": saved_trees,
"node_to_tree": session_store.get_node_mapping(),
},
queue_update_callback=message_handler.update_queue_positions,
node_started_callback=message_handler.mark_node_processing,
)
)
# Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart
if message_handler.tree_queue.cleanup_stale_nodes() > 0:

View file

@ -275,7 +275,9 @@ async def handle_clear_command(
except Exception as e:
logger.warning(f"Failed to clear session store: {e}")
handler.tree_queue = TreeQueueManager(
queue_update_callback=handler.update_queue_positions,
node_started_callback=handler.mark_node_processing,
handler.replace_tree_queue(
TreeQueueManager(
queue_update_callback=handler.update_queue_positions,
node_started_callback=handler.mark_node_processing,
)
)

View file

@ -120,7 +120,7 @@ class ClaudeMessageHandler:
self.platform = platform
self.cli_manager = cli_manager
self.session_store = session_store
self.tree_queue = TreeQueueManager(
self._tree_queue = TreeQueueManager(
queue_update_callback=self.update_queue_positions,
node_started_callback=self.mark_node_processing,
)
@ -152,6 +152,22 @@ class ClaudeMessageHandler:
def _get_limit_chars(self) -> int:
return self._limit_chars
@property
def tree_queue(self) -> TreeQueueManager:
"""Accessor for the current tree queue manager."""
return self._tree_queue
@tree_queue.setter
def tree_queue(self, tree_queue: TreeQueueManager) -> None:
"""Backward-compatible setter routed through explicit replacement API."""
self.replace_tree_queue(tree_queue)
def replace_tree_queue(self, tree_queue: TreeQueueManager) -> None:
"""Replace tree queue manager via explicit API."""
self._tree_queue = tree_queue
self._tree_queue.set_queue_update_callback(self.update_queue_positions)
self._tree_queue.set_node_started_callback(self.mark_node_processing)
async def handle_message(self, incoming: IncomingMessage) -> None:
"""
Main entry point for handling an incoming message.

View file

@ -66,12 +66,12 @@ if DISCORD_AVAILABLE and _discord_module is not None:
async def on_ready(self) -> None:
"""Called when the bot is ready."""
self._platform._connected = True
self._platform._set_connected(True)
logger.info("Discord platform connected")
async def on_message(self, message: Any) -> None:
"""Handle incoming Discord messages."""
await self._platform._on_discord_message(message)
await self._platform._handle_client_message(message)
else:
_DiscordClient = None
@ -118,6 +118,14 @@ class DiscordPlatform(MessagingPlatform):
self._pending_voice: dict[tuple[str, str], tuple[str, str]] = {}
self._pending_voice_lock = asyncio.Lock()
def _set_connected(self, connected: bool) -> None:
"""Update connection state via an explicit accessor."""
self._connected = connected
async def _handle_client_message(self, message: Any) -> None:
"""Adapter entry point used by the internal discord client."""
await self._on_discord_message(message)
async def _register_pending_voice(
self, chat_id: str, voice_msg_id: str, status_msg_id: str
) -> None:
@ -359,7 +367,7 @@ class DiscordPlatform(MessagingPlatform):
async def stop(self) -> None:
"""Stop the bot."""
if self._client.is_closed():
self._connected = False
self._set_connected(False)
return
await self._client.close()
@ -371,7 +379,7 @@ class DiscordPlatform(MessagingPlatform):
with contextlib.suppress(asyncio.CancelledError):
await self._start_task
self._connected = False
self._set_connected(False)
logger.info("Discord platform stopped")
async def send_message(

View file

@ -90,7 +90,8 @@ class TreeQueueProcessor:
error_message=get_user_facing_error_message(e),
)
finally:
tree.clear_current_node()
async with tree.with_lock():
tree.clear_current_node()
# Check if there are more messages in the queue
await self._process_next(tree, processor)

View file

@ -6,7 +6,6 @@ Uses TreeRepository for data, TreeQueueProcessor for async logic.
import asyncio
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from loguru import logger
@ -222,7 +221,7 @@ class TreeQueueManager:
return affected
def cancel_tree(self, root_id: str) -> list[MessageNode]:
async def cancel_tree(self, root_id: str) -> list[MessageNode]:
"""
Cancel all queued and in-progress messages in a tree.
@ -235,34 +234,35 @@ class TreeQueueManager:
cancelled_nodes = []
# 1. Cancel running task
if tree.cancel_current_task():
current_id = tree.current_node_id
if current_id:
node = tree.get_node(current_id)
if node and node.state not in (
MessageState.COMPLETED,
MessageState.ERROR,
):
tree.set_node_error_sync(node, "Cancelled by user")
cancelled_nodes.append(node)
# 2. Drain queue and mark nodes as cancelled
queue_nodes = tree.drain_queue_and_mark_cancelled()
cancelled_nodes.extend(queue_nodes)
cancelled_ids = {n.node_id for n in cancelled_nodes}
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
cleanup_count = 0
for node in tree.all_nodes():
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
and node.node_id not in cancelled_ids
):
tree.set_node_error_sync(node, "Stale task cleaned up")
cleanup_count += 1
async with tree.with_lock():
# 1. Cancel running task
if tree.cancel_current_task():
current_id = tree.current_node_id
if current_id:
node = tree.get_node(current_id)
if node and node.state not in (
MessageState.COMPLETED,
MessageState.ERROR,
):
tree.set_node_error_sync(node, "Cancelled by user")
cancelled_nodes.append(node)
tree.reset_processing_state()
# 2. Drain queue and mark nodes as cancelled
queue_nodes = tree.drain_queue_and_mark_cancelled()
cancelled_nodes.extend(queue_nodes)
cancelled_ids = {n.node_id for n in cancelled_nodes}
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
for node in tree.all_nodes():
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
and node.node_id not in cancelled_ids
):
tree.set_node_error_sync(node, "Stale task cleaned up")
cleanup_count += 1
tree.reset_processing_state()
if cancelled_nodes:
logger.info(
@ -306,25 +306,18 @@ class TreeQueueManager:
"Failed to remove node from queue; will rely on state=ERROR"
)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(UTC)
tree.set_node_error_sync(node, "Cancelled by user")
return [node]
async def cancel_all(self) -> list[MessageNode]:
"""Cancel all messages in all trees (async wrapper)."""
"""Cancel all messages in all trees."""
async with self._lock:
return self.cancel_all_sync()
root_ids = list(self._repository.tree_ids())
def cancel_all_sync(self) -> list[MessageNode]:
"""
Cancel all messages in all trees (synchronous/locked version).
NOTE: Must be called with self._lock held.
"""
all_cancelled = []
for root_id in list(self._repository.tree_ids()):
all_cancelled.extend(self.cancel_tree(root_id))
all_cancelled: list[MessageNode] = []
for root_id in root_ids:
all_cancelled.extend(await self.cancel_tree(root_id))
return all_cancelled
def cleanup_stale_nodes(self) -> int:
@ -388,15 +381,11 @@ class TreeQueueManager:
if tree.is_current_node(nid):
self._processor.cancel_current(tree)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(UTC)
tree.set_node_error_sync(node, "Cancelled by user")
cancelled.append(node)
else:
tree.remove_from_queue(nid)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(UTC)
tree.set_node_error_sync(node, "Cancelled by user")
cancelled.append(node)
if cancelled:
@ -421,7 +410,7 @@ class TreeQueueManager:
root_id = tree.root_id
if branch_root_id == root_id:
cancelled = self.cancel_tree(root_id)
cancelled = await self.cancel_tree(root_id)
removed_tree = self._repository.remove_tree(root_id)
if removed_tree:
return (removed_tree.all_nodes(), root_id, True)

View file

@ -308,7 +308,7 @@ async def test_stop_all_tasks(handler, mock_cli_manager, mock_platform):
mock_node.status_message_id = "status_1"
with patch.object(
handler.tree_queue, "cancel_all_sync", MagicMock(return_value=[mock_node])
handler.tree_queue, "cancel_all", AsyncMock(return_value=[mock_node])
):
count = await handler.stop_all_tasks()

View file

@ -220,14 +220,15 @@ async def test_handle_message_reply_with_tree_but_no_parent_treated_as_new():
handler = ClaudeMessageHandler(platform, cli_manager, session_store)
# Force "tree exists but parent can't be resolved" branch.
handler.tree_queue = MagicMock()
handler.tree_queue.get_tree_for_node.return_value = object()
handler.tree_queue.resolve_parent_node_id.return_value = None
handler.tree_queue.create_tree = AsyncMock(
mock_queue = MagicMock()
mock_queue.get_tree_for_node.return_value = object()
mock_queue.resolve_parent_node_id.return_value = None
mock_queue.create_tree = AsyncMock(
return_value=MagicMock(root_id="root", to_dict=MagicMock(return_value={"t": 1}))
)
handler.tree_queue.register_node = MagicMock()
handler.tree_queue.enqueue = AsyncMock(return_value=False)
mock_queue.register_node = MagicMock()
mock_queue.enqueue = AsyncMock(return_value=False)
handler.replace_tree_queue(mock_queue)
incoming = IncomingMessage(
text="reply",
@ -239,7 +240,7 @@ async def test_handle_message_reply_with_tree_but_no_parent_treated_as_new():
)
await handler.handle_message(incoming)
handler.tree_queue.create_tree.assert_awaited_once()
mock_queue.create_tree.assert_awaited_once()
@pytest.mark.asyncio
@ -271,8 +272,9 @@ async def test_update_ui_handles_transcript_render_exception():
cli_manager.get_stats.return_value = {"active_sessions": 0}
handler = ClaudeMessageHandler(platform, cli_manager, session_store)
handler.tree_queue = MagicMock()
handler.tree_queue.get_tree_for_node.return_value = None
mock_queue = MagicMock()
mock_queue.get_tree_for_node.return_value = None
handler.replace_tree_queue(mock_queue)
incoming = IncomingMessage(
text="hi",
@ -306,14 +308,15 @@ async def test_handle_message_incoming_text_none_safe():
session_store = MagicMock()
handler = ClaudeMessageHandler(platform, cli_manager, session_store)
handler.tree_queue = MagicMock()
handler.tree_queue.get_tree_for_node.return_value = None
handler.tree_queue.resolve_parent_node_id.return_value = None
handler.tree_queue.create_tree = AsyncMock(
mock_queue = MagicMock()
mock_queue.get_tree_for_node.return_value = None
mock_queue.resolve_parent_node_id.return_value = None
mock_queue.create_tree = AsyncMock(
return_value=MagicMock(root_id="root", to_dict=MagicMock(return_value={"t": 1}))
)
handler.tree_queue.register_node = MagicMock()
handler.tree_queue.enqueue = AsyncMock(return_value=True)
mock_queue.register_node = MagicMock()
mock_queue.enqueue = AsyncMock(return_value=True)
handler.replace_tree_queue(mock_queue)
incoming = MagicMock()
incoming.text = None
@ -325,7 +328,7 @@ async def test_handle_message_incoming_text_none_safe():
incoming.is_reply = MagicMock(return_value=False)
await handler.handle_message(incoming)
handler.tree_queue.create_tree.assert_awaited_once()
mock_queue.create_tree.assert_awaited_once()
@pytest.mark.asyncio
@ -390,8 +393,9 @@ async def test_handler_update_ui_edit_failure_does_not_crash():
session_store = MagicMock()
handler = ClaudeMessageHandler(platform, cli_manager, session_store)
handler.tree_queue = MagicMock()
handler.tree_queue.get_tree_for_node.return_value = None
mock_queue = MagicMock()
mock_queue.get_tree_for_node.return_value = None
handler.replace_tree_queue(mock_queue)
incoming = IncomingMessage(
text="hi",

View file

@ -187,10 +187,11 @@ class TestTreeQueueManager:
# First message should process immediately, not queue
assert was_queued is False
def test_cancel_tree_empty(self):
@pytest.mark.asyncio
async def test_cancel_tree_empty(self):
"""Test cancelling non-existent tree."""
from messaging.trees.queue_manager import TreeQueueManager
mgr = TreeQueueManager()
cancelled = mgr.cancel_tree("nonexistent")
cancelled = await mgr.cancel_tree("nonexistent")
assert cancelled == []

View file

@ -35,10 +35,15 @@ async def test_reply_to_old_status_message_after_restore_routes_to_parent(
# "Restart": new store instance loads from disk, and we restore TreeQueueManager.
store2 = SessionStore(storage_path=str(store_path))
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
handler2.tree_queue = TreeQueueManager.from_dict(
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
queue_update_callback=handler2.update_queue_positions,
node_started_callback=handler2.mark_node_processing,
handler2.replace_tree_queue(
TreeQueueManager.from_dict(
{
"trees": store2.get_all_trees(),
"node_to_tree": store2.get_node_mapping(),
},
queue_update_callback=handler2.update_queue_positions,
node_started_callback=handler2.mark_node_processing,
)
)
# Prevent background task scheduling; we only want to validate routing/tree mutation.
@ -87,10 +92,15 @@ async def test_reply_to_old_status_message_without_mapping_creates_new_conversat
store2 = SessionStore(storage_path=str(store_path))
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
handler2.tree_queue = TreeQueueManager.from_dict(
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
queue_update_callback=handler2.update_queue_positions,
node_started_callback=handler2.mark_node_processing,
handler2.replace_tree_queue(
TreeQueueManager.from_dict(
{
"trees": store2.get_all_trees(),
"node_to_tree": store2.get_node_mapping(),
},
queue_update_callback=handler2.update_queue_positions,
node_started_callback=handler2.mark_node_processing,
)
)
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")

View file

@ -408,7 +408,7 @@ class TestTreeQueueManagerConcurrency:
await mgr.enqueue("c2", slow_processor)
# Cancel the tree
cancelled = mgr.cancel_tree("root")
cancelled = await mgr.cancel_tree("root")
assert len(cancelled) >= 1 # At least the current + queued
# Tree should no longer be processing
@ -418,7 +418,7 @@ class TestTreeQueueManagerConcurrency:
async def test_cancel_nonexistent_tree(self):
"""cancel_tree for nonexistent tree returns empty list."""
mgr = TreeQueueManager()
result = mgr.cancel_tree("nonexistent")
result = await mgr.cancel_tree("nonexistent")
assert result == []
@pytest.mark.asyncio

View file

@ -514,7 +514,7 @@ class TestTreeQueueManager:
await manager.enqueue("m1", slow_processor)
# Cancel
cancelled = manager.cancel_tree("m1")
cancelled = await manager.cancel_tree("m1")
assert len(cancelled) == 1
processing_complete.set()