Remove over-engineering: drop tree_queue setter, _set_connected(), fi… (#63)

…x cancel_all() TOCTOU

- Remove tree_queue property setter (backward-compat hack; all callers
already migrated to replace_tree_queue()); keep property getter only
- Update 2 remaining tests that still used direct assignment to use
replace_tree_queue()
- Remove _set_connected() 1-line wrapper on DiscordPlatform; assign
_connected directly
- Fix cancel_all() TOCTOU: hold self._lock for the full loop so newly
created trees cannot slip through between the snapshot and cancellation

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Ali Khokhar 2026-03-01 12:34:00 -08:00 committed by GitHub
parent 25b329a3fc
commit fae8a2a044
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 9 additions and 19 deletions

View file

@ -157,11 +157,6 @@ class ClaudeMessageHandler:
"""Accessor for the current tree queue manager.""" """Accessor for the current tree queue manager."""
return self._tree_queue return self._tree_queue
@tree_queue.setter
def tree_queue(self, tree_queue: TreeQueueManager) -> None:
"""Backward-compatible setter routed through explicit replacement API."""
self.replace_tree_queue(tree_queue)
def replace_tree_queue(self, tree_queue: TreeQueueManager) -> None: def replace_tree_queue(self, tree_queue: TreeQueueManager) -> None:
"""Replace tree queue manager via explicit API.""" """Replace tree queue manager via explicit API."""
self._tree_queue = tree_queue self._tree_queue = tree_queue

View file

@ -66,7 +66,7 @@ if DISCORD_AVAILABLE and _discord_module is not None:
async def on_ready(self) -> None: async def on_ready(self) -> None:
"""Called when the bot is ready.""" """Called when the bot is ready."""
self._platform._set_connected(True) self._platform._connected = True
logger.info("Discord platform connected") logger.info("Discord platform connected")
async def on_message(self, message: Any) -> None: async def on_message(self, message: Any) -> None:
@ -118,10 +118,6 @@ class DiscordPlatform(MessagingPlatform):
self._pending_voice: dict[tuple[str, str], tuple[str, str]] = {} self._pending_voice: dict[tuple[str, str], tuple[str, str]] = {}
self._pending_voice_lock = asyncio.Lock() self._pending_voice_lock = asyncio.Lock()
def _set_connected(self, connected: bool) -> None:
"""Update connection state via an explicit accessor."""
self._connected = connected
async def _handle_client_message(self, message: Any) -> None: async def _handle_client_message(self, message: Any) -> None:
"""Adapter entry point used by the internal discord client.""" """Adapter entry point used by the internal discord client."""
await self._on_discord_message(message) await self._on_discord_message(message)
@ -367,7 +363,7 @@ class DiscordPlatform(MessagingPlatform):
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the bot.""" """Stop the bot."""
if self._client.is_closed(): if self._client.is_closed():
self._set_connected(False) self._connected = False
return return
await self._client.close() await self._client.close()
@ -379,7 +375,7 @@ class DiscordPlatform(MessagingPlatform):
with contextlib.suppress(asyncio.CancelledError): with contextlib.suppress(asyncio.CancelledError):
await self._start_task await self._start_task
self._set_connected(False) self._connected = False
logger.info("Discord platform stopped") logger.info("Discord platform stopped")
async def send_message( async def send_message(

View file

@ -314,7 +314,6 @@ class TreeQueueManager:
"""Cancel all messages in all trees.""" """Cancel all messages in all trees."""
async with self._lock: async with self._lock:
root_ids = list(self._repository.tree_ids()) root_ids = list(self._repository.tree_ids())
all_cancelled: list[MessageNode] = [] all_cancelled: list[MessageNode] = []
for root_id in root_ids: for root_id in root_ids:
all_cancelled.extend(await self.cancel_tree(root_id)) all_cancelled.extend(await self.cancel_tree(root_id))

View file

@ -24,7 +24,7 @@ def test_get_initial_status_reply_tree_busy_queued(handler):
mock_queue = MagicMock() mock_queue = MagicMock()
mock_queue.is_node_tree_busy.return_value = True mock_queue.is_node_tree_busy.return_value = True
mock_queue.get_queue_size.return_value = 2 mock_queue.get_queue_size.return_value = 2
handler.tree_queue = mock_queue handler.replace_tree_queue(mock_queue)
result = handler._get_initial_status(MagicMock(), "parent_1") result = handler._get_initial_status(MagicMock(), "parent_1")
assert "Queued" in result assert "Queued" in result
assert "position 3" in result assert "position 3" in result
@ -34,7 +34,7 @@ def test_get_initial_status_reply_tree_not_busy_continuing(handler):
"""Reply to tree when not busy returns continuing message.""" """Reply to tree when not busy returns continuing message."""
mock_queue = MagicMock() mock_queue = MagicMock()
mock_queue.is_node_tree_busy.return_value = False mock_queue.is_node_tree_busy.return_value = False
handler.tree_queue = mock_queue handler.replace_tree_queue(mock_queue)
result = handler._get_initial_status(MagicMock(), "parent_1") result = handler._get_initial_status(MagicMock(), "parent_1")
assert "Continuing" in result assert "Continuing" in result