free-claude-code/tests/messaging/test_handler_context_isolation.py
2026-02-19 20:38:11 -08:00

158 lines
5.1 KiB
Python

from unittest.mock import AsyncMock, MagicMock
import pytest
from messaging.handler import ClaudeMessageHandler
from messaging.trees.data import MessageState
async def _gen_session(events):
for e in events:
yield e
@pytest.fixture
def handler(mock_platform, mock_cli_manager, mock_session_store):
return ClaudeMessageHandler(mock_platform, mock_cli_manager, mock_session_store)
@pytest.mark.asyncio
async def test_sibling_replies_fork_from_parent_session_id(
handler, mock_cli_manager, incoming_message_factory
):
# Root node A with a known session_id.
root_incoming = incoming_message_factory(text="A", message_id="A")
tree = await handler.tree_queue.create_tree(
node_id="A", incoming=root_incoming, status_message_id="status_A"
)
await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")
# Add two sibling replies R1 and R2 under A.
r1_incoming = incoming_message_factory(
text="R1", message_id="R1", reply_to_message_id="A"
)
r2_incoming = incoming_message_factory(
text="R2", message_id="R2", reply_to_message_id="A"
)
_, r1_node = await handler.tree_queue.add_to_tree(
"A", "R1", r1_incoming, "status_R1"
)
_, r2_node = await handler.tree_queue.add_to_tree(
"A", "R2", r2_incoming, "status_R2"
)
# Mock a fresh cli_session per node.
calls = []
async def _get_or_create_session(session_id=None):
cli_session = MagicMock()
async def _start_task(prompt, session_id=None, fork_session=False):
calls.append((prompt, session_id, fork_session))
child_sid = f"sess_{prompt}"
async for ev in _gen_session(
[
{"type": "session_info", "session_id": child_sid},
{"type": "exit", "code": 0, "stderr": None},
]
):
yield ev
cli_session.start_task = _start_task
return cli_session, f"pending_{len(calls) + 1}", True
mock_cli_manager.get_or_create_session = AsyncMock(
side_effect=_get_or_create_session
)
await handler._process_node("R1", r1_node)
await handler._process_node("R2", r2_node)
# Both siblings must resume from the same parent session and fork.
assert calls[0][0] == "R1"
assert calls[0][1] == "sess_A"
assert calls[0][2] is True
assert calls[1][0] == "R2"
assert calls[1][1] == "sess_A"
assert calls[1][2] is True
@pytest.mark.asyncio
async def test_grandchild_reply_forks_from_branch_session(
handler, mock_cli_manager, incoming_message_factory
):
root_incoming = incoming_message_factory(text="A", message_id="A")
tree = await handler.tree_queue.create_tree(
node_id="A", incoming=root_incoming, status_message_id="status_A"
)
await tree.update_state("A", MessageState.COMPLETED, session_id="sess_A")
r1_incoming = incoming_message_factory(
text="R1", message_id="R1", reply_to_message_id="A"
)
_, r1_node = await handler.tree_queue.add_to_tree(
"A", "R1", r1_incoming, "status_R1"
)
calls = []
async def _get_or_create_session(session_id=None):
cli_session = MagicMock()
async def _start_task(prompt, session_id=None, fork_session=False):
calls.append((prompt, session_id, fork_session))
# R1 gets its own forked session id.
child_sid = "sess_R1"
async for ev in _gen_session(
[
{"type": "session_info", "session_id": child_sid},
{"type": "exit", "code": 0, "stderr": None},
]
):
yield ev
cli_session.start_task = _start_task
return cli_session, "pending_R1", True
mock_cli_manager.get_or_create_session = AsyncMock(
side_effect=_get_or_create_session
)
await handler._process_node("R1", r1_node)
assert r1_node.session_id == "sess_R1"
# Grandchild C1 replies to R1 and must fork from sess_R1, not sess_A.
c1_incoming = incoming_message_factory(
text="C1", message_id="C1", reply_to_message_id="R1"
)
_, c1_node = await handler.tree_queue.add_to_tree(
"R1", "C1", c1_incoming, "status_C1"
)
async def _get_or_create_session_c1(session_id=None):
cli_session = MagicMock()
async def _start_task(prompt, session_id=None, fork_session=False):
calls.append((prompt, session_id, fork_session))
async for ev in _gen_session(
[
{"type": "session_info", "session_id": "sess_C1"},
{"type": "exit", "code": 0, "stderr": None},
]
):
yield ev
cli_session.start_task = _start_task
return cli_session, "pending_C1", True
mock_cli_manager.get_or_create_session = AsyncMock(
side_effect=_get_or_create_session_c1
)
await handler._process_node("C1", c1_node)
# The last call should be for C1 and must resume from sess_R1.
assert calls[-1][0] == "C1"
assert calls[-1][1] == "sess_R1"
assert calls[-1][2] is True