mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 03:59:27 +00:00
multi_agent_chat/middleware: real-graph regression test for partial-pause parallel routing
This commit is contained in:
parent
8e10f38f32
commit
668b89927b
1 changed files with 254 additions and 0 deletions
|
|
@ -0,0 +1,254 @@
|
|||
"""Real-graph contract: one parallel branch completes while a sibling pauses with HITL.
|
||||
|
||||
The two existing parallel-routing tests
|
||||
(``test_parallel_resume_command_keying`` and
|
||||
``test_parallel_heterogeneous_decisions``) both pause **all** branches
|
||||
simultaneously. That's the easy case — every dispatched ``task`` call has a
|
||||
matching pending interrupt, and the routing helpers see a uniform shape.
|
||||
|
||||
Production rarely matches that uniform shape. The orchestrator typically
|
||||
delegates "create a Linear ticket and summarize the user's recent activity":
|
||||
one branch needs HITL, the other returns its result and exits. At the pause
|
||||
moment::
|
||||
|
||||
state.values["messages"] += [ToolMessage(from-A)] # A merged in
|
||||
state.interrupts = [Interrupt(value-from-B)] # B alone is pending
|
||||
|
||||
So ``len(state.interrupts) < num_dispatched_tasks``. The slicer and
|
||||
``build_lg_resume_map`` must:
|
||||
|
||||
1. **Key off ``state.interrupts``, never off the originally dispatched tcids.**
|
||||
A flat decisions list of length 1 must route only to B; if anything tries
|
||||
to look up A in the resume map, langgraph rejects an unknown
|
||||
``Interrupt.id``.
|
||||
2. **Leave A's contributions intact across resume.** A's ToolMessage was
|
||||
committed at the pause; resuming the paused branch must not re-run A nor
|
||||
drop its message.
|
||||
3. **Drain the single pending interrupt.** Final ``state.interrupts`` is
|
||||
empty regardless of whether sibling branches were paused.
|
||||
|
||||
The langgraph semantics this test relies on were verified empirically in the
|
||||
exploratory probe before this test was authored.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Command, Send, interrupt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.resume_routing import (
|
||||
build_lg_resume_map,
|
||||
collect_pending_tool_calls,
|
||||
slice_decisions_by_tool_call,
|
||||
)
|
||||
from app.agents.multi_agent_chat.middleware.main_agent.checkpointed_subagent_middleware.task_tool import (
|
||||
build_task_tool_with_parent_config,
|
||||
)
|
||||
|
||||
|
||||
class _SubState(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
|
||||
|
||||
class _DispatchState(TypedDict, total=False):
|
||||
messages: Annotated[list, add_messages]
|
||||
tcid: str
|
||||
desc: str
|
||||
subtype: str
|
||||
|
||||
|
||||
_QUICK_MARKER = "quick-subagent-finished-without-pausing"
|
||||
|
||||
|
||||
def _build_quick_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that completes synchronously without firing any interrupt."""
|
||||
|
||||
def quick_node(_state):
|
||||
return {"messages": [AIMessage(content=_QUICK_MARKER)]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("quick", quick_node)
|
||||
g.add_edge(START, "quick")
|
||||
g.add_edge("quick", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _build_pausing_subagent(checkpointer: InMemorySaver):
|
||||
"""Subagent that pauses with a single-action HITL bundle and records its resume payload."""
|
||||
|
||||
def hitl_node(_state):
|
||||
decision = interrupt(
|
||||
{
|
||||
"action_requests": [
|
||||
{"name": "act_0", "args": {"i": 0}, "description": ""}
|
||||
],
|
||||
"review_configs": [
|
||||
{
|
||||
"action_name": "act_0",
|
||||
"allowed_decisions": ["approve", "reject", "edit"],
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
return {"messages": [AIMessage(content=json.dumps(decision, sort_keys=True))]}
|
||||
|
||||
g = StateGraph(_SubState)
|
||||
g.add_node("hitl", hitl_node)
|
||||
g.add_edge(START, "hitl")
|
||||
g.add_edge("hitl", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _parent_with_two_branches(task_tool, *, dispatches, checkpointer):
|
||||
def fanout(_state) -> list[Send]:
|
||||
return [Send("call_task", d) for d in dispatches]
|
||||
|
||||
async def call_task(state: _DispatchState, config: RunnableConfig):
|
||||
rt = ToolRuntime(
|
||||
state=state,
|
||||
config=config,
|
||||
context=None,
|
||||
stream_writer=None,
|
||||
tool_call_id=state["tcid"],
|
||||
store=None,
|
||||
)
|
||||
return await task_tool.coroutine(
|
||||
description=state["desc"], subagent_type=state["subtype"], runtime=rt
|
||||
)
|
||||
|
||||
g = StateGraph(_DispatchState)
|
||||
g.add_node("call_task", call_task)
|
||||
g.add_conditional_edges(START, fanout, ["call_task"])
|
||||
g.add_edge("call_task", END)
|
||||
return g.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
def _quick_marker_count(state) -> int:
|
||||
"""How many messages anywhere in parent state contain the quick subagent's marker."""
|
||||
n = 0
|
||||
for msg in state.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", "")
|
||||
if isinstance(content, str) and _QUICK_MARKER in content:
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_pause_routes_only_to_paused_branch_without_rerunning_completed_one():
|
||||
"""One branch completes synchronously; the other pauses with HITL — resume routes only to B.
|
||||
|
||||
Setup:
|
||||
- Sub-A (``quick``): no interrupt, finishes immediately, writes a marker
|
||||
message to parent state.
|
||||
- Sub-B (``pausing``): interrupts with a 1-action HITL bundle.
|
||||
|
||||
At pause, parent state has A's marker already merged in and exactly one
|
||||
pending interrupt (B's). Resume sends a 1-element flat decisions list;
|
||||
the routing helpers must not look up A in the resume map (would explode
|
||||
with an unknown ``Interrupt.id``) and must not re-invoke A on resume
|
||||
(would duplicate the marker).
|
||||
"""
|
||||
checkpointer = InMemorySaver()
|
||||
|
||||
quick_sub = _build_quick_subagent(checkpointer)
|
||||
pausing_sub = _build_pausing_subagent(checkpointer)
|
||||
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
[
|
||||
{"name": "quick-agent", "description": "instant", "runnable": quick_sub},
|
||||
{
|
||||
"name": "pausing-agent",
|
||||
"description": "needs review",
|
||||
"runnable": pausing_sub,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
parent = _parent_with_two_branches(
|
||||
task_tool,
|
||||
dispatches=[
|
||||
{"tcid": "tcid-A", "subtype": "quick-agent", "desc": "do A fast"},
|
||||
{
|
||||
"tcid": "tcid-B",
|
||||
"subtype": "pausing-agent",
|
||||
"desc": "needs approval",
|
||||
},
|
||||
],
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
|
||||
config: dict = {
|
||||
"configurable": {"thread_id": "partial-pause-thread"},
|
||||
"recursion_limit": 100,
|
||||
}
|
||||
await parent.ainvoke({"messages": [HumanMessage(content="seed")]}, config)
|
||||
|
||||
paused = await parent.aget_state(config)
|
||||
|
||||
assert len(paused.interrupts) == 1, (
|
||||
f"REGRESSION: expected exactly 1 pending interrupt (sub-B alone), "
|
||||
f"got {len(paused.interrupts)}"
|
||||
)
|
||||
|
||||
pending = collect_pending_tool_calls(paused)
|
||||
assert pending == [("tcid-B", 1)], (
|
||||
f"REGRESSION: pending list contains stale tcids; got {pending!r}"
|
||||
)
|
||||
|
||||
pre_resume_marker_count = _quick_marker_count(paused)
|
||||
assert pre_resume_marker_count == 1, (
|
||||
f"REGRESSION: sub-A's contribution missing or duplicated at pause "
|
||||
f"(found {pre_resume_marker_count}, expected 1)"
|
||||
)
|
||||
|
||||
flat_decisions = [{"type": "approve"}]
|
||||
by_tool_call_id = slice_decisions_by_tool_call(flat_decisions, pending)
|
||||
assert by_tool_call_id == {"tcid-B": {"decisions": [{"type": "approve"}]}}, (
|
||||
f"REGRESSION: slicer routed to a non-pending tcid: {by_tool_call_id!r}"
|
||||
)
|
||||
|
||||
config["configurable"]["surfsense_resume_value"] = by_tool_call_id
|
||||
lg_resume_map = build_lg_resume_map(paused, by_tool_call_id)
|
||||
|
||||
assert set(lg_resume_map.keys()) == {paused.interrupts[0].id}, (
|
||||
f"REGRESSION: resume map keyed by an unknown Interrupt.id "
|
||||
f"(would crash langgraph): {lg_resume_map!r}"
|
||||
)
|
||||
|
||||
await parent.ainvoke(Command(resume=lg_resume_map), config)
|
||||
|
||||
final = await parent.aget_state(config)
|
||||
assert not final.interrupts, (
|
||||
f"REGRESSION: pending interrupts after resume: {final.interrupts!r}"
|
||||
)
|
||||
|
||||
post_resume_marker_count = _quick_marker_count(final)
|
||||
assert post_resume_marker_count == 1, (
|
||||
f"REGRESSION: sub-A re-ran on resume (marker count went "
|
||||
f"{pre_resume_marker_count} → {post_resume_marker_count}); "
|
||||
f"resume must touch only the paused branch."
|
||||
)
|
||||
|
||||
payloads: list[dict] = []
|
||||
for msg in final.values.get("messages", []) or []:
|
||||
content = getattr(msg, "content", None)
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
payloads.append(json.loads(content))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
assert {"decisions": [{"type": "approve"}]} in payloads, (
|
||||
f"REGRESSION: sub-B did not receive its single approve on resume; "
|
||||
f"payloads seen: {payloads!r}"
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue