mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 03:18:31 +00:00
fix(models): preserve reasoning_content in patched openai payloads
This commit is contained in:
parent
b107444878
commit
d59fa191c0
2 changed files with 102 additions and 22 deletions
|
|
@ -1,15 +1,19 @@
|
|||
"""Patched ChatOpenAI that preserves thought_signature for Gemini thinking models.
|
||||
"""Patched ChatOpenAI that preserves non-standard thinking fields.
|
||||
|
||||
When using Gemini with thinking enabled via an OpenAI-compatible gateway (e.g.
|
||||
Vertex AI, Google AI Studio, or any proxy), the API requires that the
|
||||
``thought_signature`` field on tool-call objects is echoed back verbatim in
|
||||
every subsequent request.
|
||||
|
||||
The OpenAI-compatible gateway stores the raw tool-call dicts (including
|
||||
``thought_signature``) in ``additional_kwargs["tool_calls"]``, but standard
|
||||
``langchain_openai.ChatOpenAI`` only serialises the standard fields (``id``,
|
||||
``type``, ``function``) into the outgoing payload, silently dropping the
|
||||
signature. That causes an HTTP 400 ``INVALID_ARGUMENT`` error:
|
||||
OpenAI-compatible gateways often return assistant-only metadata that must be
|
||||
echoed back on later turns:
|
||||
|
||||
- Gemini-style ``thought_signature`` on tool calls
|
||||
- DeepSeek/Kimi-style ``reasoning_content`` on assistant messages
|
||||
|
||||
Standard ``langchain_openai.ChatOpenAI`` serializes only the standard fields,
|
||||
silently dropping those gateway-specific thinking fields. That causes request
|
||||
validation failures in multi-turn tool-call flows.
|
||||
|
||||
Unable to submit request because function call `<tool>` in the N. content
|
||||
block is missing a `thought_signature`.
|
||||
|
|
@ -29,13 +33,13 @@ from langchain_openai import ChatOpenAI
|
|||
|
||||
|
||||
class PatchedChatOpenAI(ChatOpenAI):
|
||||
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
|
||||
"""ChatOpenAI with gateway-specific thinking field preservation.
|
||||
|
||||
When using Gemini with thinking enabled via an OpenAI-compatible gateway,
|
||||
the API expects ``thought_signature`` to be present on tool-call objects in
|
||||
multi-turn conversations. This patched version restores those signatures
|
||||
from ``AIMessage.additional_kwargs["tool_calls"]`` into the serialised
|
||||
request payload before it is sent to the API.
|
||||
When using thinking-enabled models via an OpenAI-compatible gateway, the
|
||||
API may expect prior assistant metadata to be echoed back verbatim in
|
||||
subsequent requests. This patched version restores those fields from the
|
||||
original ``AIMessage`` objects into the serialized request payload before
|
||||
it is sent to the API.
|
||||
|
||||
Usage in ``config.yaml``::
|
||||
|
||||
|
|
@ -80,17 +84,30 @@ class PatchedChatOpenAI(ChatOpenAI):
|
|||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_tool_call_signatures(payload_msg, orig_msg)
|
||||
_restore_assistant_gateway_fields(payload_msg, orig_msg)
|
||||
else:
|
||||
# Fallback: match assistant-role entries positionally against AIMessages.
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_tool_call_signatures(payload_msg, ai_msg)
|
||||
_restore_assistant_gateway_fields(payload_msg, ai_msg)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _restore_assistant_gateway_fields(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
"""Re-inject non-standard assistant fields required by OpenAI-compatible gateways."""
|
||||
_restore_reasoning_content(payload_msg, orig_msg)
|
||||
_restore_tool_call_signatures(payload_msg, orig_msg)
|
||||
|
||||
|
||||
def _restore_reasoning_content(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
"""Re-inject ``reasoning_content`` onto outgoing assistant messages."""
|
||||
reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning_content is not None:
|
||||
payload_msg["reasoning_content"] = reasoning_content
|
||||
|
||||
|
||||
def _restore_tool_call_signatures(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
"""Re-inject ``thought_signature`` onto tool-call objects in *payload_msg*.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,25 @@
|
|||
"""Tests for deerflow.models.patched_openai.PatchedChatOpenAI.
|
||||
|
||||
These tests verify that _restore_tool_call_signatures correctly re-injects
|
||||
``thought_signature`` onto tool-call objects stored in
|
||||
``additional_kwargs["tool_calls"]``, covering id-based matching, positional
|
||||
fallback, camelCase keys, and several edge-cases.
|
||||
These tests verify that the patched provider correctly re-injects gateway-
|
||||
specific assistant fields, covering:
|
||||
|
||||
- ``reasoning_content`` restoration onto assistant messages
|
||||
- ``thought_signature`` restoration onto tool calls
|
||||
- id-based matching, positional fallback, camelCase keys, and edge-cases
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.models.patched_openai import _restore_tool_call_signatures
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from deerflow.models.patched_openai import (
|
||||
PatchedChatOpenAI,
|
||||
_restore_reasoning_content,
|
||||
_restore_tool_call_signatures,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
|
@ -46,11 +55,40 @@ def _ai_msg_with_raw_tool_calls(raw_tool_calls: list[dict]) -> AIMessage:
|
|||
return AIMessage(content="", additional_kwargs={"tool_calls": raw_tool_calls})
|
||||
|
||||
|
||||
def _make_model(**kwargs):
|
||||
return PatchedChatOpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="test-key",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core: signed tool-call restoration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reasoning_content_restored_on_assistant_message():
|
||||
payload_msg = {"role": "assistant", "content": "Answer"}
|
||||
orig = AIMessage(
|
||||
content="Answer",
|
||||
additional_kwargs={"reasoning_content": "Reason first, answer second."},
|
||||
)
|
||||
|
||||
_restore_reasoning_content(payload_msg, orig)
|
||||
|
||||
assert payload_msg["reasoning_content"] == "Reason first, answer second."
|
||||
|
||||
|
||||
def test_reasoning_content_noop_when_absent():
|
||||
payload_msg = {"role": "assistant", "content": "Answer"}
|
||||
orig = AIMessage(content="Answer", additional_kwargs={})
|
||||
|
||||
_restore_reasoning_content(payload_msg, orig)
|
||||
|
||||
assert "reasoning_content" not in payload_msg
|
||||
|
||||
|
||||
def test_tool_call_signature_restored_by_id():
|
||||
"""thought_signature is copied to the payload tool-call matched by id."""
|
||||
payload_msg = {"role": "assistant", "content": None, "tool_calls": [PAYLOAD_TC_1.copy()]}
|
||||
|
|
@ -172,5 +210,30 @@ def test_tool_call_multiple_sequential_signatures():
|
|||
assert payload_tc_b["thought_signature"] == "SIG_STEP2=="
|
||||
|
||||
|
||||
# Integration behavior for PatchedChatOpenAI is validated indirectly via
|
||||
# _restore_tool_call_signatures unit coverage above.
|
||||
def test_get_request_payload_restores_reasoning_content_and_tool_signatures():
|
||||
model = _make_model()
|
||||
payload = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [PAYLOAD_TC_1.copy()],
|
||||
}
|
||||
]
|
||||
}
|
||||
orig = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"reasoning_content": "Plan the next tool call carefully.",
|
||||
"tool_calls": [RAW_TC_SIGNED],
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(ChatOpenAI, "_get_request_payload", return_value=payload):
|
||||
with patch.object(model, "_convert_input") as mock_convert:
|
||||
mock_convert.return_value = MagicMock(to_messages=lambda: [orig])
|
||||
result = model._get_request_payload([orig])
|
||||
|
||||
assistant_msg = result["messages"][0]
|
||||
assert assistant_msg["reasoning_content"] == "Plan the next tool call carefully."
|
||||
assert assistant_msg["tool_calls"][0]["thought_signature"] == "SIG_A=="
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue