mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-04-28 03:20:01 +00:00
introduced per reply /stop command to stop just that task and fixed context leakage in reply session forking
This commit is contained in:
parent
de7677ba6c
commit
459ce6e8fd
10 changed files with 513 additions and 40 deletions
|
|
@ -32,7 +32,7 @@ class CLISession:
|
|||
return self._is_busy
|
||||
|
||||
async def start_task(
|
||||
self, prompt: str, session_id: Optional[str] = None
|
||||
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Start a new task or continue an existing session.
|
||||
|
|
@ -66,6 +66,10 @@ class CLISession:
|
|||
"claude",
|
||||
"--resume",
|
||||
session_id,
|
||||
]
|
||||
if fork_session:
|
||||
cmd.append("--fork-session")
|
||||
cmd += [
|
||||
"-p",
|
||||
prompt,
|
||||
"--output-format",
|
||||
|
|
@ -106,11 +110,36 @@ class CLISession:
|
|||
session_id_extracted = False
|
||||
buffer = bytearray()
|
||||
|
||||
while True:
|
||||
chunk = await self.process.stdout.read(65536)
|
||||
if not chunk:
|
||||
if buffer:
|
||||
line_str = buffer.decode("utf-8", errors="replace").strip()
|
||||
try:
|
||||
while True:
|
||||
chunk = await self.process.stdout.read(65536)
|
||||
if not chunk:
|
||||
if buffer:
|
||||
line_str = buffer.decode(
|
||||
"utf-8", errors="replace"
|
||||
).strip()
|
||||
if line_str:
|
||||
async for event in self._handle_line_gen(
|
||||
line_str, session_id_extracted
|
||||
):
|
||||
if event.get("type") == "session_info":
|
||||
session_id_extracted = True
|
||||
yield event
|
||||
break
|
||||
|
||||
buffer.extend(chunk)
|
||||
|
||||
while True:
|
||||
newline_pos = buffer.find(b"\n")
|
||||
if newline_pos == -1:
|
||||
break
|
||||
|
||||
line = buffer[:newline_pos]
|
||||
buffer = buffer[newline_pos + 1 :]
|
||||
|
||||
line_str = line.decode(
|
||||
"utf-8", errors="replace"
|
||||
).strip()
|
||||
if line_str:
|
||||
async for event in self._handle_line_gen(
|
||||
line_str, session_id_extracted
|
||||
|
|
@ -118,26 +147,13 @@ class CLISession:
|
|||
if event.get("type") == "session_info":
|
||||
session_id_extracted = True
|
||||
yield event
|
||||
break
|
||||
|
||||
buffer.extend(chunk)
|
||||
|
||||
while True:
|
||||
newline_pos = buffer.find(b"\n")
|
||||
if newline_pos == -1:
|
||||
break
|
||||
|
||||
line = buffer[:newline_pos]
|
||||
buffer = buffer[newline_pos + 1 :]
|
||||
|
||||
line_str = line.decode("utf-8", errors="replace").strip()
|
||||
if line_str:
|
||||
async for event in self._handle_line_gen(
|
||||
line_str, session_id_extracted
|
||||
):
|
||||
if event.get("type") == "session_info":
|
||||
session_id_extracted = True
|
||||
yield event
|
||||
except asyncio.CancelledError:
|
||||
# Cancelling the handler task should not leave a Claude CLI
|
||||
# subprocess running in the background.
|
||||
try:
|
||||
await asyncio.shield(self.stop())
|
||||
finally:
|
||||
raise
|
||||
|
||||
stderr_text = None
|
||||
if self.process.stderr:
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ class CLISession(Protocol):
|
|||
"""Protocol for CLI session - avoid circular import from cli package."""
|
||||
|
||||
def start_task(
|
||||
self, prompt: str, session_id: Optional[str] = None
|
||||
self, prompt: str, session_id: Optional[str] = None, fork_session: bool = False
|
||||
) -> AsyncGenerator[Dict, Any]:
|
||||
"""Start a task in the CLI session."""
|
||||
...
|
||||
|
|
@ -60,6 +60,10 @@ class SessionManagerInterface(Protocol):
|
|||
"""Stop all sessions."""
|
||||
...
|
||||
|
||||
async def remove_session(self, session_id: str) -> bool:
|
||||
"""Remove a session from the manager."""
|
||||
...
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get session statistics."""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -464,7 +464,7 @@ class ClaudeMessageHandler:
|
|||
session_or_temp_id,
|
||||
is_new,
|
||||
) = await self.cli_manager.get_or_create_session(
|
||||
session_id=parent_session_id # Fork from parent if available
|
||||
session_id=None # Always create a fresh session per node
|
||||
)
|
||||
if is_new:
|
||||
temp_session_id = session_or_temp_id
|
||||
|
|
@ -485,7 +485,9 @@ class ClaudeMessageHandler:
|
|||
logger.info(f"HANDLER: Starting CLI task processing for node {node_id}")
|
||||
event_count = 0
|
||||
async for event_data in cli_session.start_task(
|
||||
incoming.text, session_id=captured_session_id
|
||||
incoming.text,
|
||||
session_id=parent_session_id,
|
||||
fork_session=bool(parent_session_id),
|
||||
):
|
||||
if not isinstance(event_data, dict):
|
||||
logger.warning(
|
||||
|
|
@ -505,6 +507,15 @@ class ClaudeMessageHandler:
|
|||
)
|
||||
captured_session_id = real_session_id
|
||||
temp_session_id = None
|
||||
# Persist session_id early so replies can fork even if a task
|
||||
# is stopped before completion.
|
||||
if tree and captured_session_id:
|
||||
await tree.update_state(
|
||||
node_id,
|
||||
MessageState.IN_PROGRESS,
|
||||
session_id=captured_session_id,
|
||||
)
|
||||
self.session_store.save_tree(tree.root_id, tree.to_dict())
|
||||
continue
|
||||
|
||||
parsed_list = parse_cli_event(event_data)
|
||||
|
|
@ -560,11 +571,21 @@ class ClaudeMessageHandler:
|
|||
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(f"HANDLER: Task cancelled for node {node_id}")
|
||||
components["errors"].append("Task was cancelled")
|
||||
await update_ui(format_status("❌", "Cancelled"), force=True)
|
||||
cancel_reason = None
|
||||
if isinstance(node.context, dict):
|
||||
cancel_reason = node.context.get("cancel_reason")
|
||||
|
||||
if cancel_reason == "stop":
|
||||
await update_ui(format_status("⏹", "Stopped."), force=True)
|
||||
else:
|
||||
components["errors"].append("Task was cancelled")
|
||||
await update_ui(format_status("❌", "Cancelled"), force=True)
|
||||
|
||||
# Do not propagate cancellation to children; a reply-scoped "/stop"
|
||||
# should only stop the targeted task.
|
||||
if tree:
|
||||
await self._propagate_error_to_children(
|
||||
node_id, "Cancelled by user", "Parent task was stopped"
|
||||
await tree.update_state(
|
||||
node_id, MessageState.ERROR, error_message="Cancelled by user"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
|
@ -581,6 +602,16 @@ class ClaudeMessageHandler:
|
|||
logger.info(
|
||||
f"HANDLER: _process_node completed for node {node_id}, errors={len(components['errors'])}"
|
||||
)
|
||||
# Free the session-manager slot. Session IDs are persisted in the tree and
|
||||
# can be resumed later by ID; we don't need to keep a CLISession instance
|
||||
# around after this node completes.
|
||||
try:
|
||||
if captured_session_id:
|
||||
await self.cli_manager.remove_session(captured_session_id)
|
||||
elif temp_session_id:
|
||||
await self.cli_manager.remove_session(temp_session_id)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to remove session for node {node_id}: {e}")
|
||||
|
||||
async def _propagate_error_to_children(
|
||||
self,
|
||||
|
|
@ -745,8 +776,61 @@ class ClaudeMessageHandler:
|
|||
|
||||
return len(cancelled_nodes)
|
||||
|
||||
async def stop_task(self, node_id: str) -> int:
|
||||
"""
|
||||
Stop a single queued or in-progress task node.
|
||||
|
||||
Used when the user replies "/stop" to a specific status/user message.
|
||||
"""
|
||||
tree = self.tree_queue.get_tree_for_node(node_id)
|
||||
if tree:
|
||||
node = tree.get_node(node_id)
|
||||
if node and node.state not in (MessageState.COMPLETED, MessageState.ERROR):
|
||||
# Used by _process_node cancellation path to render "Stopped."
|
||||
node.context = {"cancel_reason": "stop"}
|
||||
|
||||
cancelled_nodes = await self.tree_queue.cancel_node(node_id)
|
||||
|
||||
for node in cancelled_nodes:
|
||||
self.platform.fire_and_forget(
|
||||
self.platform.queue_edit_message(
|
||||
node.incoming.chat_id,
|
||||
node.status_message_id,
|
||||
format_status("⏹", "Stopped."),
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
)
|
||||
|
||||
tree = self.tree_queue.get_tree_for_node(node.node_id)
|
||||
if tree:
|
||||
self.session_store.save_tree(tree.root_id, tree.to_dict())
|
||||
|
||||
return len(cancelled_nodes)
|
||||
|
||||
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:
|
||||
"""Handle /stop command from messaging platform."""
|
||||
# Reply-scoped stop: reply "/stop" to stop only that task.
|
||||
if incoming.is_reply() and incoming.reply_to_message_id:
|
||||
reply_id = incoming.reply_to_message_id
|
||||
tree = self.tree_queue.get_tree_for_node(reply_id)
|
||||
node_id = self.tree_queue.resolve_parent_node_id(reply_id) if tree else None
|
||||
|
||||
if not node_id:
|
||||
await self.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
format_status("⏹", "Stopped.", "Nothing to stop for that message."),
|
||||
)
|
||||
return
|
||||
|
||||
count = await self.stop_task(node_id)
|
||||
noun = "request" if count == 1 else "requests"
|
||||
await self.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
format_status("⏹", "Stopped.", f"Cancelled {count} {noun}."),
|
||||
)
|
||||
return
|
||||
|
||||
# Global stop: legacy behavior (stop everything)
|
||||
count = await self.stop_all_tasks()
|
||||
await self.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ Uses TreeRepository for data, TreeQueueProcessor for async logic.
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Awaitable, List, Optional
|
||||
|
||||
from .models import IncomingMessage
|
||||
|
|
@ -295,6 +297,49 @@ class TreeQueueManager:
|
|||
|
||||
return cancelled_nodes
|
||||
|
||||
async def cancel_node(self, node_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel a single node (queued or in-progress) without affecting other nodes.
|
||||
|
||||
- If the node is currently running, cancels the current asyncio task.
|
||||
- If the node is queued, removes it from the queue.
|
||||
- Marks the node as ERROR with "Cancelled by user".
|
||||
|
||||
Returns:
|
||||
List containing the cancelled node if it was cancellable, else empty list.
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(node_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
async with tree._lock:
|
||||
node = tree.get_node(node_id)
|
||||
if not node:
|
||||
return []
|
||||
|
||||
if node.state in (MessageState.COMPLETED, MessageState.ERROR):
|
||||
return []
|
||||
|
||||
# Cancel running task if this is the current node.
|
||||
if tree._current_node_id == node_id:
|
||||
self._processor.cancel_current(tree)
|
||||
|
||||
# Remove from queue if present (asyncio.Queue exposes its internal deque).
|
||||
try:
|
||||
q = tree._queue._queue # type: ignore[attr-defined]
|
||||
if q and node_id in q:
|
||||
tree._queue._queue = deque(x for x in q if x != node_id) # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
# Best-effort: if we can't mutate the queue internals, the node will
|
||||
# still be dequeued later and skipped due to state=ERROR.
|
||||
logger.debug("Failed to remove node from queue; will rely on state=ERROR")
|
||||
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
return [node]
|
||||
|
||||
async def cancel_all(self) -> List[MessageNode]:
|
||||
"""Cancel all messages in all trees (async wrapper)."""
|
||||
async with self._lock:
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ def mock_cli_manager():
|
|||
manager.get_or_create_session = AsyncMock()
|
||||
manager.register_real_session_id = AsyncMock(return_value=True)
|
||||
manager.stop_all = AsyncMock()
|
||||
manager.remove_session = AsyncMock(return_value=True)
|
||||
manager.get_stats = MagicMock(
|
||||
return_value={"active_sessions": 0, "max_sessions": 5}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -259,6 +259,34 @@ class TestCLISession:
|
|||
args = mock_exec.call_args[0]
|
||||
assert "--resume" in args
|
||||
assert "sess_abc" in args
|
||||
assert "--fork-session" not in args
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_task_with_session_resume_and_fork(self):
|
||||
"""Test resuming an existing session and forking."""
|
||||
from cli.session import CLISession
|
||||
|
||||
session = CLISession("/tmp", "http://localhost:8082/v1")
|
||||
|
||||
mock_process = AsyncMock()
|
||||
mock_process.stdout.read.side_effect = [b""] # Immediate EOF
|
||||
mock_process.stderr.read.return_value = b""
|
||||
mock_process.wait.return_value = 0
|
||||
|
||||
with patch(
|
||||
"asyncio.create_subprocess_exec", new_callable=AsyncMock
|
||||
) as mock_exec:
|
||||
mock_exec.return_value = mock_process
|
||||
|
||||
async for _ in session.start_task(
|
||||
"Hello", session_id="sess_abc", fork_session=True
|
||||
):
|
||||
pass
|
||||
|
||||
args = mock_exec.call_args[0]
|
||||
assert "--resume" in args
|
||||
assert "sess_abc" in args
|
||||
assert "--fork-session" in args
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_task_process_failure_with_stderr(self):
|
||||
|
|
|
|||
|
|
@ -29,6 +29,63 @@ async def test_handle_message_stop_command(
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_stop_command_reply_stops_only_target_node(
|
||||
handler, mock_platform, mock_cli_manager, incoming_message_factory
|
||||
):
|
||||
# Create a tree with a root node and register its status message ID mapping.
|
||||
root_incoming = incoming_message_factory(
|
||||
text="do something", message_id="root_msg", reply_to_message_id=None
|
||||
)
|
||||
tree = await handler.tree_queue.create_tree(
|
||||
node_id="root_msg",
|
||||
incoming=root_incoming,
|
||||
status_message_id="status_root",
|
||||
)
|
||||
handler.tree_queue.register_node("status_root", tree.root_id)
|
||||
|
||||
# Reply "/stop" to the status message; should stop only that node.
|
||||
incoming = incoming_message_factory(
|
||||
text="/stop",
|
||||
message_id="stop_msg",
|
||||
reply_to_message_id="status_root",
|
||||
)
|
||||
|
||||
handler.stop_all_tasks = AsyncMock(return_value=999)
|
||||
|
||||
await handler.handle_message(incoming)
|
||||
|
||||
handler.stop_all_tasks.assert_not_called()
|
||||
mock_cli_manager.stop_all.assert_not_called()
|
||||
assert tree.get_node("root_msg").state == MessageState.ERROR
|
||||
mock_platform.queue_send_message.assert_called_once_with(
|
||||
incoming.chat_id,
|
||||
"⏹ *Stopped\\.* Cancelled 1 request\\.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_stop_command_reply_unknown_does_not_stop_all(
|
||||
handler, mock_platform, mock_cli_manager, incoming_message_factory
|
||||
):
|
||||
incoming = incoming_message_factory(
|
||||
text="/stop",
|
||||
message_id="stop_msg",
|
||||
reply_to_message_id="unknown_msg",
|
||||
)
|
||||
|
||||
handler.stop_all_tasks = AsyncMock(return_value=5)
|
||||
|
||||
await handler.handle_message(incoming)
|
||||
|
||||
handler.stop_all_tasks.assert_not_called()
|
||||
mock_cli_manager.stop_all.assert_not_called()
|
||||
mock_platform.queue_send_message.assert_called_once_with(
|
||||
incoming.chat_id,
|
||||
"⏹ *Stopped\\.* Nothing to stop for that message\\.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_stats_command(
|
||||
handler, mock_platform, mock_cli_manager, incoming_message_factory
|
||||
|
|
|
|||
127
tests/test_handler_context_isolation.py
Normal file
127
tests/test_handler_context_isolation.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from messaging.handler import ClaudeMessageHandler
|
||||
from messaging.tree_data import MessageState
|
||||
|
||||
|
||||
async def _gen_session(events):
|
||||
for e in events:
|
||||
yield e
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler(mock_platform, mock_cli_manager, mock_session_store):
|
||||
return ClaudeMessageHandler(mock_platform, mock_cli_manager, mock_session_store)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sibling_replies_fork_from_parent_session_id(handler, mock_cli_manager, incoming_message_factory):
|
||||
# Root node A with a known session_id.
|
||||
root_incoming = incoming_message_factory(text="A", message_id="A")
|
||||
tree = await handler.tree_queue.create_tree(node_id="A", incoming=root_incoming, status_message_id="status_A")
|
||||
await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")
|
||||
|
||||
# Add two sibling replies R1 and R2 under A.
|
||||
r1_incoming = incoming_message_factory(text="R1", message_id="R1", reply_to_message_id="A")
|
||||
r2_incoming = incoming_message_factory(text="R2", message_id="R2", reply_to_message_id="A")
|
||||
_, r1_node = await handler.tree_queue.add_to_tree("A", "R1", r1_incoming, "status_R1")
|
||||
_, r2_node = await handler.tree_queue.add_to_tree("A", "R2", r2_incoming, "status_R2")
|
||||
|
||||
# Mock a fresh cli_session per node.
|
||||
calls = []
|
||||
|
||||
async def _get_or_create_session(session_id=None):
|
||||
cli_session = MagicMock()
|
||||
|
||||
async def _start_task(prompt, session_id=None, fork_session=False):
|
||||
calls.append((prompt, session_id, fork_session))
|
||||
child_sid = f"sess_{prompt}"
|
||||
async for ev in _gen_session(
|
||||
[
|
||||
{"type": "session_info", "session_id": child_sid},
|
||||
{"type": "exit", "code": 0, "stderr": None},
|
||||
]
|
||||
):
|
||||
yield ev
|
||||
|
||||
cli_session.start_task = _start_task
|
||||
return cli_session, f"pending_{len(calls)+1}", True
|
||||
|
||||
mock_cli_manager.get_or_create_session = AsyncMock(side_effect=_get_or_create_session)
|
||||
|
||||
await handler._process_node("R1", r1_node)
|
||||
await handler._process_node("R2", r2_node)
|
||||
|
||||
# Both siblings must resume from the same parent session and fork.
|
||||
assert calls[0][0] == "R1"
|
||||
assert calls[0][1] == "sess_A"
|
||||
assert calls[0][2] is True
|
||||
|
||||
assert calls[1][0] == "R2"
|
||||
assert calls[1][1] == "sess_A"
|
||||
assert calls[1][2] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grandchild_reply_forks_from_branch_session(handler, mock_cli_manager, incoming_message_factory):
|
||||
root_incoming = incoming_message_factory(text="A", message_id="A")
|
||||
tree = await handler.tree_queue.create_tree(node_id="A", incoming=root_incoming, status_message_id="status_A")
|
||||
await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")
|
||||
|
||||
r1_incoming = incoming_message_factory(text="R1", message_id="R1", reply_to_message_id="A")
|
||||
_, r1_node = await handler.tree_queue.add_to_tree("A", "R1", r1_incoming, "status_R1")
|
||||
|
||||
calls = []
|
||||
|
||||
async def _get_or_create_session(session_id=None):
|
||||
cli_session = MagicMock()
|
||||
|
||||
async def _start_task(prompt, session_id=None, fork_session=False):
|
||||
calls.append((prompt, session_id, fork_session))
|
||||
# R1 gets its own forked session id.
|
||||
child_sid = "sess_R1"
|
||||
async for ev in _gen_session(
|
||||
[
|
||||
{"type": "session_info", "session_id": child_sid},
|
||||
{"type": "exit", "code": 0, "stderr": None},
|
||||
]
|
||||
):
|
||||
yield ev
|
||||
|
||||
cli_session.start_task = _start_task
|
||||
return cli_session, "pending_R1", True
|
||||
|
||||
mock_cli_manager.get_or_create_session = AsyncMock(side_effect=_get_or_create_session)
|
||||
|
||||
await handler._process_node("R1", r1_node)
|
||||
assert r1_node.session_id == "sess_R1"
|
||||
|
||||
# Grandchild C1 replies to R1 and must fork from sess_R1, not sess_A.
|
||||
c1_incoming = incoming_message_factory(text="C1", message_id="C1", reply_to_message_id="R1")
|
||||
_, c1_node = await handler.tree_queue.add_to_tree("R1", "C1", c1_incoming, "status_C1")
|
||||
|
||||
async def _get_or_create_session_c1(session_id=None):
|
||||
cli_session = MagicMock()
|
||||
|
||||
async def _start_task(prompt, session_id=None, fork_session=False):
|
||||
calls.append((prompt, session_id, fork_session))
|
||||
async for ev in _gen_session(
|
||||
[
|
||||
{"type": "session_info", "session_id": "sess_C1"},
|
||||
{"type": "exit", "code": 0, "stderr": None},
|
||||
]
|
||||
):
|
||||
yield ev
|
||||
|
||||
cli_session.start_task = _start_task
|
||||
return cli_session, "pending_C1", True
|
||||
|
||||
mock_cli_manager.get_or_create_session = AsyncMock(side_effect=_get_or_create_session_c1)
|
||||
|
||||
await handler._process_node("C1", c1_node)
|
||||
|
||||
# The last call should be for C1 and must resume from sess_R1.
|
||||
assert calls[-1][0] == "C1"
|
||||
assert calls[-1][1] == "sess_R1"
|
||||
assert calls[-1][2] is True
|
||||
|
|
@ -29,17 +29,18 @@ async def test_full_conversation_flow_single_user(
|
|||
mock_session1 = MagicMock()
|
||||
mock_session1.start_task.return_value = mock_async_gen(
|
||||
[
|
||||
{"type": "session_info", "session_id": "sess1"},
|
||||
{
|
||||
"type": "assistant",
|
||||
"message": {"content": [{"type": "text", "text": "Reply 1"}]},
|
||||
},
|
||||
{"type": "exit", "code": 0},
|
||||
{"type": "exit", "code": 0, "stderr": None},
|
||||
]
|
||||
)
|
||||
mock_cli_manager.get_or_create_session.return_value = (
|
||||
mock_session1,
|
||||
"sess1",
|
||||
False,
|
||||
"pending_1",
|
||||
True,
|
||||
)
|
||||
|
||||
await handler_integration.handle_message(msg1)
|
||||
|
|
@ -53,6 +54,9 @@ async def test_full_conversation_flow_single_user(
|
|||
|
||||
assert tree.get_node("m1").state.value == MessageState.COMPLETED.value
|
||||
assert tree.get_node("m1").session_id == "sess1"
|
||||
mock_session1.start_task.assert_called_with(
|
||||
"message 1", session_id=None, fork_session=False
|
||||
)
|
||||
|
||||
# 2. Reply to m1
|
||||
msg2 = incoming_message_factory(
|
||||
|
|
@ -64,18 +68,19 @@ async def test_full_conversation_flow_single_user(
|
|||
mock_session2 = MagicMock()
|
||||
mock_session2.start_task.return_value = mock_async_gen(
|
||||
[
|
||||
{"type": "session_info", "session_id": "sess2"},
|
||||
{
|
||||
"type": "assistant",
|
||||
"message": {"content": [{"type": "text", "text": "Reply 2"}]},
|
||||
},
|
||||
{"type": "exit", "code": 0},
|
||||
{"type": "exit", "code": 0, "stderr": None},
|
||||
]
|
||||
)
|
||||
mock_cli_manager.get_or_create_session.reset_mock()
|
||||
mock_cli_manager.get_or_create_session.return_value = (
|
||||
mock_session2,
|
||||
"sess2",
|
||||
False,
|
||||
"pending_2",
|
||||
True,
|
||||
)
|
||||
|
||||
await handler_integration.handle_message(msg2)
|
||||
|
|
@ -88,7 +93,10 @@ async def test_full_conversation_flow_single_user(
|
|||
|
||||
assert tree.get_node("m2").state.value == MessageState.COMPLETED.value
|
||||
assert tree.get_node("m2").parent_id == "m1"
|
||||
mock_cli_manager.get_or_create_session.assert_called_with(session_id="sess1")
|
||||
mock_cli_manager.get_or_create_session.assert_called_with(session_id=None)
|
||||
mock_session2.start_task.assert_called_with(
|
||||
"message 2", session_id="sess1", fork_session=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
103
tests/test_restart_reply_restore.py
Normal file
103
tests/test_restart_reply_restore.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from messaging.handler import ClaudeMessageHandler
|
||||
from messaging.session import SessionStore
|
||||
from messaging.tree_queue import TreeQueueManager
|
||||
from messaging.models import IncomingMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_to_old_status_message_after_restore_routes_to_parent(tmp_path, mock_platform, mock_cli_manager):
|
||||
# Build a persisted tree with a root node A and a bot status message id.
|
||||
store_path = tmp_path / "sessions.json"
|
||||
store = SessionStore(storage_path=str(store_path))
|
||||
|
||||
handler1 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store)
|
||||
a_incoming = IncomingMessage(
|
||||
text="A",
|
||||
chat_id="chat_1",
|
||||
user_id="user_1",
|
||||
message_id="A",
|
||||
platform="telegram",
|
||||
)
|
||||
tree = await handler1.tree_queue.create_tree("A", a_incoming, status_message_id="status_A")
|
||||
handler1.tree_queue.register_node("status_A", tree.root_id)
|
||||
store.register_node("status_A", tree.root_id)
|
||||
store.save_tree(tree.root_id, tree.to_dict())
|
||||
|
||||
# "Restart": new store instance loads from disk, and we restore TreeQueueManager.
|
||||
store2 = SessionStore(storage_path=str(store_path))
|
||||
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
||||
handler2.tree_queue = TreeQueueManager.from_dict(
|
||||
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
||||
queue_update_callback=handler2._update_queue_positions,
|
||||
node_started_callback=handler2._mark_node_processing,
|
||||
)
|
||||
|
||||
# Prevent background task scheduling; we only want to validate routing/tree mutation.
|
||||
handler2.tree_queue.enqueue = AsyncMock(return_value=False)
|
||||
|
||||
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")
|
||||
|
||||
reply = IncomingMessage(
|
||||
text="R1",
|
||||
chat_id="chat_1",
|
||||
user_id="user_1",
|
||||
message_id="R1",
|
||||
platform="telegram",
|
||||
reply_to_message_id="status_A",
|
||||
)
|
||||
|
||||
await handler2.handle_message(reply)
|
||||
|
||||
restored_tree = handler2.tree_queue.get_tree_for_node("A")
|
||||
assert restored_tree is not None
|
||||
node_r1 = restored_tree.get_node("R1")
|
||||
assert node_r1 is not None
|
||||
assert node_r1.parent_id == "A"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_to_old_status_message_without_mapping_creates_new_conversation(tmp_path, mock_platform, mock_cli_manager):
|
||||
store_path = tmp_path / "sessions.json"
|
||||
store = SessionStore(storage_path=str(store_path))
|
||||
|
||||
handler1 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store)
|
||||
a_incoming = IncomingMessage(
|
||||
text="A",
|
||||
chat_id="chat_1",
|
||||
user_id="user_1",
|
||||
message_id="A",
|
||||
platform="telegram",
|
||||
)
|
||||
tree = await handler1.tree_queue.create_tree("A", a_incoming, status_message_id="status_A")
|
||||
# Intentionally do NOT register "status_A" mapping.
|
||||
store.save_tree(tree.root_id, tree.to_dict())
|
||||
|
||||
store2 = SessionStore(storage_path=str(store_path))
|
||||
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
||||
handler2.tree_queue = TreeQueueManager.from_dict(
|
||||
{"trees": store2.get_all_trees(), "node_to_tree": store2.get_node_mapping()},
|
||||
queue_update_callback=handler2._update_queue_positions,
|
||||
node_started_callback=handler2._mark_node_processing,
|
||||
)
|
||||
handler2.tree_queue.enqueue = AsyncMock(return_value=False)
|
||||
mock_platform.queue_send_message = AsyncMock(return_value="status_reply")
|
||||
|
||||
reply = IncomingMessage(
|
||||
text="R1",
|
||||
chat_id="chat_1",
|
||||
user_id="user_1",
|
||||
message_id="R1",
|
||||
platform="telegram",
|
||||
reply_to_message_id="status_A",
|
||||
)
|
||||
|
||||
await handler2.handle_message(reply)
|
||||
|
||||
# Since the mapping is missing, this should be treated as a new conversation.
|
||||
new_tree = handler2.tree_queue.get_tree_for_node("R1")
|
||||
assert new_tree is not None
|
||||
assert new_tree.root_id == "R1"
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue