fixed /stop

This commit is contained in:
Alishahryar1 2026-01-31 14:39:24 -08:00
parent 8ce86f4267
commit 78d0276d03
3 changed files with 52 additions and 10 deletions

View file

@ -99,6 +99,14 @@ async def lifespan(app: FastAPI):
"node_to_tree": session_store._node_to_tree,
}
)
# Reconcile restored state - anything PENDING/IN_PROGRESS is lost across restart
if message_handler.tree_queue.cleanup_stale_nodes() > 0:
# Sync back and save
session_store._trees = message_handler.tree_queue.to_dict()["trees"]
session_store._node_to_tree = message_handler.tree_queue.to_dict()[
"node_to_tree"
]
session_store._save()
# Wire up the handler
messaging_platform.on_message(message_handler.handle_message)

View file

@ -491,9 +491,9 @@ class ClaudeMessageHandler:
logger.info("Stopping all CLI sessions...")
await self.cli_manager.stop_all()
# 3. Update UI for all cancelled nodes
# 3. Update UI and persist state for all cancelled nodes
for node in cancelled_nodes:
# Fire and forget to avoid blocking the cleanup process
# Fire and forget UI update
self.platform.fire_and_forget(
self.platform.queue_edit_message(
node.incoming.chat_id,
@ -503,6 +503,11 @@ class ClaudeMessageHandler:
)
)
# Persist tree state
tree = self.tree_queue.get_tree_for_node(node.node_id)
if tree:
self.session_store.save_tree(tree.root_id, tree.to_dict())
return len(cancelled_nodes)
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:

View file

@ -226,7 +226,8 @@ class TreeQueueManager:
"""
Cancel all queued and in-progress messages in a tree.
Updates node states to ERROR and returns list of affected nodes.
Updates node states to ERROR and returns list of affected nodes
that were actually active or in the current processing queue.
"""
tree = self._repository.get_tree(root_id)
if not tree:
@ -234,7 +235,7 @@ class TreeQueueManager:
cancelled_nodes = []
# Cancel running task via processor
# 1. Cancel running task via processor
if self._processor.cancel_current(tree):
if tree._current_node_id:
node = tree.get_node(tree._current_node_id)
@ -246,7 +247,7 @@ class TreeQueueManager:
node.error_message = "Cancelled by user"
cancelled_nodes.append(node)
# Clear queue and update states
# 2. Clear queue and update states
while not tree._queue.empty():
try:
node_id = tree._queue.get_nowait()
@ -258,17 +259,29 @@ class TreeQueueManager:
except asyncio.QueueEmpty:
break
# Also cancel any PENDING nodes that weren't in queue
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
# This handles stale nodes from previous sessions without counting them as "cancelled now"
# unless they were in the active queue above.
cleanup_count = 0
for node in tree.all_nodes():
if node.state == MessageState.PENDING and node not in cancelled_nodes:
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
and node not in cancelled_nodes
):
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
cancelled_nodes.append(node)
node.error_message = "Stale task cleaned up"
cleanup_count += 1
tree._is_processing = False
tree._current_node_id = None
logger.info(f"Cancelled {len(cancelled_nodes)} nodes in tree {root_id}")
if cancelled_nodes:
logger.info(
f"Cancelled {len(cancelled_nodes)} active nodes in tree {root_id}"
)
if cleanup_count:
logger.info(f"Cleaned up {cleanup_count} stale nodes in tree {root_id}")
return cancelled_nodes
async def cancel_all(self) -> List[MessageNode]:
@ -286,6 +299,22 @@ class TreeQueueManager:
all_cancelled.extend(self.cancel_tree(root_id))
return all_cancelled
def cleanup_stale_nodes(self) -> int:
"""
Mark any PENDING or IN_PROGRESS nodes in all trees as ERROR.
Used on startup to reconcile restored state.
"""
count = 0
for tree in self._repository.all_trees():
for node in tree.all_nodes():
if node.state in (MessageState.PENDING, MessageState.IN_PROGRESS):
node.state = MessageState.ERROR
node.error_message = "Lost during server restart"
count += 1
if count:
logger.info(f"Cleaned up {count} stale nodes during startup")
return count
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree (for external mapping)."""
self._repository.register_node(node_id, root_id)