From 668b89927b8e050bfba89a4d22abdfd148b5afb2 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 14 May 2026 09:47:24 +0200 Subject: [PATCH] multi_agent_chat/middleware: real-graph regression test for partial-pause parallel routing --- .../test_parallel_partial_pause_routing.py | 254 ++++++++++++++++++ 1 file changed, 254 insertions(+) create mode 100644 surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py diff --git a/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py new file mode 100644 index 000000000..79210032b --- /dev/null +++ b/surfsense_backend/tests/unit/agents/multi_agent_chat/middleware/checkpointed_subagent_middleware/test_parallel_partial_pause_routing.py @@ -0,0 +1,254 @@ +"""Real-graph contract: one parallel branch completes while a sibling pauses with HITL. + +The two existing parallel-routing tests +(``test_parallel_resume_command_keying`` and +``test_parallel_heterogeneous_decisions``) both pause **all** branches +simultaneously. That's the easy case — every dispatched ``task`` call has a +matching pending interrupt, and the routing helpers see a uniform shape. + +Production rarely matches that uniform shape. The orchestrator typically +delegates "create a Linear ticket and summarize the user's recent activity": +one branch needs HITL, the other returns its result and exits. At the pause +moment:: + + state.values["messages"] += [ToolMessage(from-A)] # A merged in + state.interrupts = [Interrupt(value-from-B)] # B alone is pending + +So ``len(state.interrupts) < num_dispatched_tasks``. The slicer and +``build_lg_resume_map`` must: + +1. **Key off ``state.interrupts``, never off the originally dispatched tcids.** + A flat decisions list of length 1 must route only to B; if anything tries + to look up A in the resume map, langgraph rejects an unknown + ``Interrupt.id``. +2. **Leave A's contributions intact across resume.** A's ToolMessage was + committed at the pause; resuming the paused branch must not re-run A nor + drop its message. +3. **Drain the single pending interrupt.** Final ``state.interrupts`` is + empty regardless of whether sibling branches were paused. + +The langgraph semantics this test relies on were verified empirically in the +exploratory probe before this test was authored. +""" + +from __future__ import annotations + +import json +from typing import Annotated + +import pytest +from langchain.tools import ToolRuntime +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Command, Send, interrupt +from typing_extensions import TypedDict + +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import ( + build_lg_resume_map, + collect_pending_tool_calls, + slice_decisions_by_tool_call, +) +from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import ( + build_task_tool_with_parent_config, +) + + +class _SubState(TypedDict, total=False): + messages: Annotated[list, add_messages] + + +class _DispatchState(TypedDict, total=False): + messages: Annotated[list, add_messages] + tcid: str + desc: str + subtype: str + + +_QUICK_MARKER = "quick-subagent-finished-without-pausing" + + +def _build_quick_subagent(checkpointer: InMemorySaver): + """Subagent that completes synchronously without firing any interrupt.""" + + def quick_node(_state): + return {"messages": [AIMessage(content=_QUICK_MARKER)]} + + g = StateGraph(_SubState) + g.add_node("quick", quick_node) + g.add_edge(START, "quick") + g.add_edge("quick", END) + return g.compile(checkpointer=checkpointer) + + +def _build_pausing_subagent(checkpointer: InMemorySaver): + """Subagent that pauses with a single-action HITL bundle and records its resume payload.""" + + def hitl_node(_state): + decision = interrupt( + { + "action_requests": [ + {"name": "act_0", "args": {"i": 0}, "description": ""} + ], + "review_configs": [ + { + "action_name": "act_0", + "allowed_decisions": ["approve", "reject", "edit"], + } + ], + } + ) + return {"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]} + + g = StateGraph(_SubState) + g.add_node("hitl", hitl_node) + g.add_edge(START, "hitl") + g.add_edge("hitl", END) + return g.compile(checkpointer=checkpointer) + + +def _parent_with_two_branches(task_tool, *, dispatches, checkpointer): + def fanout(_state) -> list[Send]: + return [Send("call_task", d) for d in dispatches] + + async def call_task(state: _DispatchState, config: RunnableConfig): + rt = ToolRuntime( + state=state, + config=config, + context=None, + stream_writer=None, + tool_call_id=state["tcid"], + store=None, + ) + return await task_tool.coroutine( + description=state["desc"], subagent_type=state["subtype"], runtime=rt + ) + + g = StateGraph(_DispatchState) + g.add_node("call_task", call_task) + g.add_conditional_edges(START, fanout, ["call_task"]) + g.add_edge("call_task", END) + return g.compile(checkpointer=checkpointer) + + +def _quick_marker_count(state) -> int: + """How many messages anywhere in parent state contain the quick subagent's marker.""" + n = 0 + for msg in state.values.get("messages", []) or []: + content = getattr(msg, "content", "") + if isinstance(content, str) and _QUICK_MARKER in content: + n += 1 + return n + + +@pytest.mark.asyncio +async def test_partial_pause_routes_only_to_paused_branch_without_rerunning_completed_one(): + """One branch completes synchronously; the other pauses with HITL — resume routes only to B. + + Setup: + - Sub-A (``quick``): no interrupt, finishes immediately, writes a marker + message to parent state. + - Sub-B (``pausing``): interrupts with a 1-action HITL bundle. + + At pause, parent state has A's marker already merged in and exactly one + pending interrupt (B's). Resume sends a 1-element flat decisions list; + the routing helpers must not look up A in the resume map (would explode + with an unknown ``Interrupt.id``) and must not re-invoke A on resume + (would duplicate the marker). + """ + checkpointer = InMemorySaver() + + quick_sub = _build_quick_subagent(checkpointer) + pausing_sub = _build_pausing_subagent(checkpointer) + + task_tool = build_task_tool_with_parent_config( + [ + {"name": "quick-agent", "description": "instant", "runnable": quick_sub}, + { + "name": "pausing-agent", + "description": "needs review", + "runnable": pausing_sub, + }, + ] + ) + + parent = _parent_with_two_branches( + task_tool, + dispatches=[ + {"tcid": "tcid-A", "subtype": "quick-agent", "desc": "do A fast"}, + { + "tcid": "tcid-B", + "subtype": "pausing-agent", + "desc": "needs approval", + }, + ], + checkpointer=checkpointer, + ) + + config: dict = { + "configurable": {"thread_id": "partial-pause-thread"}, + "recursion_limit": 100, + } + await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config) + + paused = await parent.aget_state(config) + + assert len(paused.interrupts) == 1, ( + f"REGRESSION: expected exactly 1 pending interrupt (sub-B alone), " + f"got {len(paused.interrupts)}" + ) + + pending = collect_pending_tool_calls(paused) + assert pending == [("tcid-B", 1)], ( + f"REGRESSION: pending list contains stale tcids; got {pending!r}" + ) + + pre_resume_marker_count = _quick_marker_count(paused) + assert pre_resume_marker_count == 1, ( + f"REGRESSION: sub-A's contribution missing or duplicated at pause " + f"(found {pre_resume_marker_count}, expected 1)" + ) + + flat_decisions = [{"type": "approve"}] + by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending) + assert by_tool_call_id == {"tcid-B": {"decisions": [{"type": "approve"}]}}, ( + f"REGRESSION: slicer routed to a non-pending tcid: {by_tool_call_id!r}" + ) + + config["configurable"]["surfsense_resume_value"] = by_tool_call_id + lg_resume_map = build_lg_resume_map(paused, by_tool_call_id) + + assert set(lg_resume_map.keys()) == {paused.interrupts[0].id}, ( + f"REGRESSION: resume map keyed by an unknown Interrupt.id " + f"(would crash langgraph): {lg_resume_map!r}" + ) + + await parent.ainvoke(Command(resume=lg_resume_map), config) + + final = await parent.aget_state(config) + assert not final.interrupts, ( + f"REGRESSION: pending interrupts after resume: {final.interrupts!r}" + ) + + post_resume_marker_count = _quick_marker_count(final) + assert post_resume_marker_count == 1, ( + f"REGRESSION: sub-A re-ran on resume (marker count went " + f"{pre_resume_marker_count} → {post_resume_marker_count}); " + f"resume must touch only the paused branch." + ) + + payloads: list[dict] = [] + for msg in final.values.get("messages", []) or []: + content = getattr(msg, "content", None) + if isinstance(content, str): + try: + payloads.append(json.loads(content)) + except json.JSONDecodeError: + pass + + assert {"decisions": [{"type": "approve"}]} in payloads, ( + f"REGRESSION: sub-B did not receive its single approve on resume; " + f"payloads seen: {payloads!r}" + )