introduced per reply /stop command to stop just that task and fixed context leakage in reply session forking

This commit is contained in:
Alishahryar1 2026-02-13 14:53:48 -08:00
parent de7677ba6c
commit 459ce6e8fd
10 changed files with 513 additions and 40 deletions

View file

@ -32,7 +32,7 @@ class CLISession:
return self._is_busy
async def start_task(
self, prompt: str, session_id: Optional[str] = None
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
) -> AsyncGenerator[dict, None]:
"""
Start a new task or continue an existing session.
@ -66,6 +66,10 @@ class CLISession:
"claude",
"--resume",
session_id,
]
if fork_session:
cmd.append("--fork-session")
cmd += [
"-p",
prompt,
"--output-format",
@ -106,11 +110,36 @@ class CLISession:
session_id_extracted = False
buffer = bytearray()
while True:
chunk = await self.process.stdout.read(65536)
if not chunk:
if buffer:
line_str = buffer.decode("utf-8", errors="replace").strip()
try:
while True:
chunk = await self.process.stdout.read(65536)
if not chunk:
if buffer:
line_str = buffer.decode(
"utf-8", errors="replace"
).strip()
if line_str:
async for event in self._handle_line_gen(
line_str, session_id_extracted
):
if event.get("type") == "session_info":
session_id_extracted = True
yield event
break
buffer.extend(chunk)
while True:
newline_pos = buffer.find(b"\n")
if newline_pos == -1:
break
line = buffer[:newline_pos]
buffer = buffer[newline_pos + 1 :]
line_str = line.decode(
"utf-8", errors="replace"
).strip()
if line_str:
async for event in self._handle_line_gen(
line_str, session_id_extracted
@ -118,26 +147,13 @@ class CLISession:
if event.get("type") == "session_info":
session_id_extracted = True
yield event
break
buffer.extend(chunk)
while True:
newline_pos = buffer.find(b"\n")
if newline_pos == -1:
break
line = buffer[:newline_pos]
buffer = buffer[newline_pos + 1 :]
line_str = line.decode("utf-8", errors="replace").strip()
if line_str:
async for event in self._handle_line_gen(
line_str, session_id_extracted
):
if event.get("type") == "session_info":
session_id_extracted = True
yield event
except asyncio.CancelledError:
# Cancelling the handler task should not leave a Claude CLI
# subprocess running in the background.
try:
await asyncio.shield(self.stop())
finally:
raise
stderr_text = None
if self.process.stderr:

View file

@ -20,7 +20,7 @@ class CLISession(Protocol):
"""Protocol for CLI session - avoid circular import from cli package."""
def start_task(
self, prompt: str, session_id: Optional[str] = None
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
) -> AsyncGenerator[Dict, Any]:
"""Start a task in the CLI session."""
...
@ -60,6 +60,10 @@ class SessionManagerInterface(Protocol):
"""Stop all sessions."""
...
async def remove_session(self, session_id: str) -> bool:
"""Remove a session from the manager."""
...
def get_stats(self) -> dict:
"""Get session statistics."""
...

View file

@ -464,7 +464,7 @@ class ClaudeMessageHandler:
session_or_temp_id,
is_new,
) = await self.cli_manager.get_or_create_session(
session_id=parent_session_id # Fork from parent if available
session_id=None # Always create a fresh session per node
)
if is_new:
temp_session_id = session_or_temp_id
@ -485,7 +485,9 @@ class ClaudeMessageHandler:
logger.info(f"HANDLER: Starting CLI task processing for node {node_id}")
event_count = 0
async for event_data in cli_session.start_task(
incoming.text, session_id=captured_session_id
incoming.text,
session_id=parent_session_id,
fork_session=bool(parent_session_id),
):
if not isinstance(event_data, dict):
logger.warning(
@ -505,6 +507,15 @@ class ClaudeMessageHandler:
)
captured_session_id = real_session_id
temp_session_id = None
# Persist session_id early so replies can fork even if a task
# is stopped before completion.
if tree and captured_session_id:
await tree.update_state(
node_id,
MessageState.IN_PROGRESS,
session_id=captured_session_id,
)
self.session_store.save_tree(tree.root_id, tree.to_dict())
continue
parsed_list = parse_cli_event(event_data)
@ -560,11 +571,21 @@ class ClaudeMessageHandler:
except asyncio.CancelledError:
logger.warning(f"HANDLER: Task cancelled for node {node_id}")
components["errors"].append("Task was cancelled")
await update_ui(format_status("", "Cancelled"), force=True)
cancel_reason = None
if isinstance(node.context, dict):
cancel_reason = node.context.get("cancel_reason")
if cancel_reason == "stop":
await update_ui(format_status("", "Stopped."), force=True)
else:
components["errors"].append("Task was cancelled")
await update_ui(format_status("", "Cancelled"), force=True)
# Do not propagate cancellation to children; a reply-scoped "/stop"
# should only stop the targeted task.
if tree:
await self._propagate_error_to_children(
node_id, "Cancelled by user", "Parent task was stopped"
await tree.update_state(
node_id, MessageState.ERROR, error_message="Cancelled by user"
)
except Exception as e:
logger.error(
@ -581,6 +602,16 @@ class ClaudeMessageHandler:
logger.info(
f"HANDLER: _process_node completed for node {node_id}, errors={len(components['errors'])}"
)
# Free the session-manager slot. Session IDs are persisted in the tree and
# can be resumed later by ID; we don't need to keep a CLISession instance
# around after this node completes.
try:
if captured_session_id:
await self.cli_manager.remove_session(captured_session_id)
elif temp_session_id:
await self.cli_manager.remove_session(temp_session_id)
except Exception as e:
logger.debug(f"Failed to remove session for node {node_id}: {e}")
async def _propagate_error_to_children(
self,
@ -745,8 +776,61 @@ class ClaudeMessageHandler:
return len(cancelled_nodes)
async def stop_task(self, node_id: str) -> int:
"""
Stop a single queued or in-progress task node.
Used when the user replies "/stop" to a specific status/user message.
"""
tree = self.tree_queue.get_tree_for_node(node_id)
if tree:
node = tree.get_node(node_id)
if node and node.state not in (MessageState.COMPLETED, MessageState.ERROR):
# Used by _process_node cancellation path to render "Stopped."
node.context = {"cancel_reason": "stop"}
cancelled_nodes = await self.tree_queue.cancel_node(node_id)
for node in cancelled_nodes:
self.platform.fire_and_forget(
self.platform.queue_edit_message(
node.incoming.chat_id,
node.status_message_id,
format_status("", "Stopped."),
parse_mode="MarkdownV2",
)
)
tree = self.tree_queue.get_tree_for_node(node.node_id)
if tree:
self.session_store.save_tree(tree.root_id, tree.to_dict())
return len(cancelled_nodes)
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:
"""Handle /stop command from messaging platform."""
# Reply-scoped stop: reply "/stop" to stop only that task.
if incoming.is_reply() and incoming.reply_to_message_id:
reply_id = incoming.reply_to_message_id
tree = self.tree_queue.get_tree_for_node(reply_id)
node_id = self.tree_queue.resolve_parent_node_id(reply_id) if tree else None
if not node_id:
await self.platform.queue_send_message(
incoming.chat_id,
format_status("", "Stopped.", "Nothing to stop for that message."),
)
return
count = await self.stop_task(node_id)
noun = "request" if count == 1 else "requests"
await self.platform.queue_send_message(
incoming.chat_id,
format_status("", "Stopped.", f"Cancelled {count} {noun}."),
)
return
# Global stop: legacy behavior (stop everything)
count = await self.stop_all_tasks()
await self.platform.queue_send_message(
incoming.chat_id,

View file

@ -6,6 +6,8 @@ Uses TreeRepository for data, TreeQueueProcessor for async logic.
import asyncio
import logging
from collections import deque
from datetime import datetime, timezone
from typing import Callable, Awaitable, List, Optional
from .models import IncomingMessage
@ -295,6 +297,49 @@ class TreeQueueManager:
return cancelled_nodes
async def cancel_node(self, node_id: str) -> List[MessageNode]:
"""
Cancel a single node (queued or in-progress) without affecting other nodes.
- If the node is currently running, cancels the current asyncio task.
- If the node is queued, removes it from the queue.
- Marks the node as ERROR with "Cancelled by user".
Returns:
List containing the cancelled node if it was cancellable, else empty list.
"""
tree = self._repository.get_tree_for_node(node_id)
if not tree:
return []
async with tree._lock:
node = tree.get_node(node_id)
if not node:
return []
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
return []
# Cancel running task if this is the current node.
if tree._current_node_id == node_id:
self._processor.cancel_current(tree)
# Remove from queue if present (asyncio.Queue exposes its internal deque).
try:
q = tree._queue._queue # type: ignore[attr-defined]
if q and node_id in q:
tree._queue._queue = deque(x for x in q if x != node_id) # type: ignore[attr-defined]
except Exception:
# Best-effort: if we can't mutate the queue internals, the node will
# still be dequeued later and skipped due to state=ERROR.
logger.debug("Failed to remove node from queue; will rely on state=ERROR")
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
return [node]
async def cancel_all(self) -> List[MessageNode]:
"""Cancel all messages in all trees (async wrapper)."""
async with self._lock:

View file

@ -50,6 +50,7 @@ def mock_cli_manager():
manager.get_or_create_session = AsyncMock()
manager.register_real_session_id = AsyncMock(return_value=True)
manager.stop_all = AsyncMock()
manager.remove_session = AsyncMock(return_value=True)
manager.get_stats = MagicMock(
return_value={"active_sessions": 0, "max_sessions": 5}
)

View file

@ -259,6 +259,34 @@ class TestCLISession:
args = mock_exec.call_args[0]
assert "--resume" in args
assert "sess_abc" in args
assert "--fork-session" not in args
@pytest.mark.asyncio
async def test_start_task_with_session_resume_and_fork(self):
"""Test resuming an existing session and forking."""
from cli.session import CLISession
session = CLISession("/tmp", "http://localhost:8082/v1")
mock_process = AsyncMock()
mock_process.stdout.read.side_effect = [b""] # Immediate EOF
mock_process.stderr.read.return_value = b""
mock_process.wait.return_value = 0
with patch(
"asyncio.create_subprocess_exec", new_callable=AsyncMock
) as mock_exec:
mock_exec.return_value = mock_process
async for _ in session.start_task(
"Hello", session_id="sess_abc", fork_session=True
):
pass
args = mock_exec.call_args[0]
assert "--resume" in args
assert "sess_abc" in args
assert "--fork-session" in args
@pytest.mark.asyncio
async def test_start_task_process_failure_with_stderr(self):

View file

@ -29,6 +29,63 @@ async def test_handle_message_stop_command(
)
@pytest.mark.asyncio
async def test_handle_message_stop_command_reply_stops_only_target_node(
handler, mock_platform, mock_cli_manager, incoming_message_factory
):
# Create a tree with a root node and register its status message ID mapping.
root_incoming = incoming_message_factory(
text="do something", message_id="root_msg", reply_to_message_id=None
)
tree = await handler.tree_queue.create_tree(
node_id="root_msg",
incoming=root_incoming,
status_message_id="status_root",
)
handler.tree_queue.register_node("status_root", tree.root_id)
# Reply "/stop" to the status message; should stop only that node.
incoming = incoming_message_factory(
text="/stop",
message_id="stop_msg",
reply_to_message_id="status_root",
)
handler.stop_all_tasks = AsyncMock(return_value=999)
await handler.handle_message(incoming)
handler.stop_all_tasks.assert_not_called()
mock_cli_manager.stop_all.assert_not_called()
assert tree.get_node("root_msg").state == MessageState.ERROR
mock_platform.queue_send_message.assert_called_once_with(
incoming.chat_id,
"⏹ *Stopped\\.* Cancelled 1 request\\.",
)
@pytest.mark.asyncio
async def test_handle_message_stop_command_reply_unknown_does_not_stop_all(
handler, mock_platform, mock_cli_manager, incoming_message_factory
):
incoming = incoming_message_factory(
text="/stop",
message_id="stop_msg",
reply_to_message_id="unknown_msg",
)
handler.stop_all_tasks = AsyncMock(return_value=5)
await handler.handle_message(incoming)
handler.stop_all_tasks.assert_not_called()
mock_cli_manager.stop_all.assert_not_called()
mock_platform.queue_send_message.assert_called_once_with(
incoming.chat_id,
"⏹ *Stopped\\.* Nothing to stop for that message\\.",
)
@pytest.mark.asyncio
async def test_handle_message_stats_command(
handler, mock_platform, mock_cli_manager, incoming_message_factory

View file

@ -0,0 +1,127 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from messaging.handler import ClaudeMessageHandler
from messaging.tree_data import MessageState
async def _gen_session(events):
for e in events:
yield e
@pytest.fixture
def handler(mock_platform, mock_cli_manager, mock_session_store):
return ClaudeMessageHandler(mock_platform, mock_cli_manager, mock_session_store)
@pytest.mark.asyncio
async def test_sibling_replies_fork_from_parent_session_id(handler, mock_cli_manager, incoming_message_factory):
# Root node A with a known session_id.
root_incoming = incoming_message_factory(text="A", message_id="A")
tree = await handler.tree_queue.create_tree(node_id="A", incoming=root_incoming, status_message_id="status_A")
await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")
# Add two sibling replies R1 and R2 under A.
r1_incoming = incoming_message_factory(text="R1", message_id="R1", reply_to_message_id="A")
r2_incoming = incoming_message_factory(text="R2", message_id="R2", reply_to_message_id="A")
_, r1_node = await handler.tree_queue.add_to_tree("A", "R1", r1_incoming, "status_R1")
_, r2_node = await handler.tree_queue.add_to_tree("A", "R2", r2_incoming, "status_R2")
# Mock a fresh cli_session per node.
calls = []
async def _get_or_create_session(session_id=None):
cli_session = MagicMock()
async def _start_task(prompt, session_id=None, fork_session=False):
calls.append((prompt, session_id, fork_session))
child_sid = f"sess_{prompt}"
async for ev in _gen_session(
[
{"type": "session_info", "session_id": child_sid},
{"type": "exit", "code": 0, "stderr": None},
]
):
yield ev
cli_session.start_task = _start_task
return cli_session, f"pending_{len(calls)+1}", True
mock_cli_manager.get_or_create_session = AsyncMock(side_effect=_get_or_create_session)
await handler._process_node("R1", r1_node)
await handler._process_node("R2", r2_node)
# Both siblings must resume from the same parent session and fork.
assert calls[0][0] == "R1"
assert calls[0][1] == "sess_A"
assert calls[0][2] is True
assert calls[1][0] == "R2"
assert calls[1][1] == "sess_A"
assert calls[1][2] is True
@pytest.mark.asyncio
async def test_grandchild_reply_forks_from_branch_session(handler, mock_cli_manager, incoming_message_factory):
root_incoming = incoming_message_factory(text="A", message_id="A")
tree = await handler.tree_queue.create_tree(node_id="A", incoming=root_incoming, status_message_id="status_A")
await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")
r1_incoming = incoming_message_factory(text="R1", message_id="R1", reply_to_message_id="A")
_, r1_node = await handler.tree_queue.add_to_tree("A", "R1", r1_incoming, "status_R1")
calls = []
async def _get_or_create_session(session_id=None):
cli_session = MagicMock()
async def _start_task(prompt, session_id=None, fork_session=False):
calls.append((prompt, session_id, fork_session))
# R1 gets its own forked session id.
child_sid = "sess_R1"
async for ev in _gen_session(
[
{"type": "session_info", "session_id": child_sid},
{"type": "exit", "code": 0, "stderr": None},
]
):
yield ev
cli_session.start_task = _start_task
return cli_session, "pending_R1", True
mock_cli_manager.get_or_create_session = AsyncMock(side_effect=_get_or_create_session)
await handler._process_node("R1", r1_node)
assert r1_node.session_id == "sess_R1"
# Grandchild C1 replies to R1 and must fork from sess_R1, not sess_A.
c1_incoming = incoming_message_factory(text="C1", message_id="C1", reply_to_message_id="R1")
_, c1_node = await handler.tree_queue.add_to_tree("R1", "C1", c1_incoming, "status_C1")
async def _get_or_create_session_c1(session_id=None):
cli_session = MagicMock()
async def _start_task(prompt, session_id=None, fork_session=False):
calls.append((prompt, session_id, fork_session))
async for ev in _gen_session(
[
{"type": "session_info", "session_id": "sess_C1"},
{"type": "exit", "code": 0, "stderr": None},
]
):
yield ev
cli_session.start_task = _start_task
return cli_session, "pending_C1", True
mock_cli_manager.get_or_create_session = AsyncMock(side_effect=_get_or_create_session_c1)
await handler._process_node("C1", c1_node)
# The last call should be for C1 and must resume from sess_R1.
assert calls[-1][0] == "C1"
assert calls[-1][1] == "sess_R1"
assert calls[-1][2] is True

View file

@ -29,17 +29,18 @@ async def test_full_conversation_flow_single_user(
mock_session1 = MagicMock()
mock_session1.start_task.return_value = mock_async_gen(
[
{"type": "session_info", "session_id": "sess1"},
{
"type": "assistant",
"message": {"content": [{"type": "text", "text": "Reply 1"}]},
},
{"type": "exit", "code": 0},
{"type": "exit", "code": 0, "stderr": None},
]
)
mock_cli_manager.get_or_create_session.return_value = (
mock_session1,
"sess1",
False,
"pending_1",
True,
)
await handler_integration.handle_message(msg1)
@ -53,6 +54,9 @@ async def test_full_conversation_flow_single_user(
assert tree.get_node("m1").state.value == MessageState.COMPLETED.value
assert tree.get_node("m1").session_id == "sess1"
mock_session1.start_task.assert_called_with(
"message 1", session_id=None, fork_session=False
)
# 2. Reply to m1
msg2 = incoming_message_factory(
@ -64,18 +68,19 @@ async def test_full_conversation_flow_single_user(
mock_session2 = MagicMock()
mock_session2.start_task.return_value = mock_async_gen(
[
{"type": "session_info", "session_id": "sess2"},
{
"type": "assistant",
"message": {"content": [{"type": "text", "text": "Reply 2"}]},
},
{"type": "exit", "code": 0},
{"type": "exit", "code": 0, "stderr": None},
]
)
mock_cli_manager.get_or_create_session.reset_mock()
mock_cli_manager.get_or_create_session.return_value = (
mock_session2,
"sess2",
False,
"pending_2",
True,
)
await handler_integration.handle_message(msg2)
@ -88,7 +93,10 @@ async def test_full_conversation_flow_single_user(
assert tree.get_node("m2").state.value == MessageState.COMPLETED.value
assert tree.get_node("m2").parent_id == "m1"
mock_cli_manager.get_or_create_session.assert_called_with(session_id="sess1")
mock_cli_manager.get_or_create_session.assert_called_with(session_id=None)
mock_session2.start_task.assert_called_with(
"message 2", session_id="sess1", fork_session=True
)
@pytest.mark.asyncio

View file

@ -0,0 +1,103 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from messaging.handler import ClaudeMessageHandler
from messaging.session import SessionStore
from messaging.tree_queue import TreeQueueManager
from messaging.models import IncomingMessage
@pytest.mark.asyncio
async def test_reply_to_old_status_message_after_restore_routes_to_parent(tmp_path, mock_platform, mock_cli_manager):
# Build a persisted tree with a root node A and a bot status message id.
store_path = tmp_path / "sessions.json"
store = SessionStore(storage_path=str(store_path))
handler1 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store)
a_incoming = IncomingMessage(
text="A",
chat_id="chat_1",
user_id="user_1",
message_id="A",
platform="telegram",
)
tree = await handler1.tree_queue.create_tree("A", a_incoming, status_message_id="status_A")
handler1.tree_queue.register_node("status_A", tree.root_id)
store.register_node("status_A", tree.root_id)
store.save_tree(tree.root_id, tree.to_dict())
# "Restart": new store instance loads from disk, and we restore TreeQueueManager.
store2 = SessionStore(storage_path=str(store_path))
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
handler2.tree_queue = TreeQueueManager.from_dict(
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
queue_update_callback=handler2._update_queue_positions,
node_started_callback=handler2._mark_node_processing,
)
# Prevent background task scheduling; we only want to validate routing/tree mutation.
handler2.tree_queue.enqueue = AsyncMock(return_value=False)
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")
reply = IncomingMessage(
text="R1",
chat_id="chat_1",
user_id="user_1",
message_id="R1",
platform="telegram",
reply_to_message_id="status_A",
)
await handler2.handle_message(reply)
restored_tree = handler2.tree_queue.get_tree_for_node("A")
assert restored_tree is not None
node_r1 = restored_tree.get_node("R1")
assert node_r1 is not None
assert node_r1.parent_id == "A"
@pytest.mark.asyncio
async def test_reply_to_old_status_message_without_mapping_creates_new_conversation(tmp_path, mock_platform, mock_cli_manager):
store_path = tmp_path / "sessions.json"
store = SessionStore(storage_path=str(store_path))
handler1 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store)
a_incoming = IncomingMessage(
text="A",
chat_id="chat_1",
user_id="user_1",
message_id="A",
platform="telegram",
)
tree = await handler1.tree_queue.create_tree("A", a_incoming, status_message_id="status_A")
# Intentionally do NOT register "status_A" mapping.
store.save_tree(tree.root_id, tree.to_dict())
store2 = SessionStore(storage_path=str(store_path))
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
handler2.tree_queue = TreeQueueManager.from_dict(
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
queue_update_callback=handler2._update_queue_positions,
node_started_callback=handler2._mark_node_processing,
)
handler2.tree_queue.enqueue = AsyncMock(return_value=False)
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")
reply = IncomingMessage(
text="R1",
chat_id="chat_1",
user_id="user_1",
message_id="R1",
platform="telegram",
reply_to_message_id="status_A",
)
await handler2.handle_message(reply)
# Since the mapping is missing, this should be treated as a new conversation.
new_tree = handler2.tree_queue.get_tree_for_node("R1")
assert new_tree is not None
assert new_tree.root_id == "R1"