Merge pull request #1428 from 3clyp50/dirtyjson

Dispatch tool calls at first completed JSON object
Jan Tomášek, 2026-04-03 17:05:48 +02:00 (committed by GitHub)
commit 2da44168da
6 changed files with 265 additions and 39 deletions
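The change, in brief: while the LLM response streams in, the agent now tries to extract the first complete JSON root object from the accumulated text; as soon as one parses, the tool call is dispatched and the provider stream is shut down instead of waiting for trailing tokens. A minimal sketch of that consumer-side idea, using the `extract_tools.extract_json_root_string` helper this PR adds (the chunk values are illustrative):

```python
# Sketch only: stop consuming a token stream at the first complete JSON root.
from helpers import extract_tools

full = ""
for chunk in ['{"tool_name":"resp', 'onse","tool_args":{}}', " trailing text"]:
    full += chunk
    snapshot = extract_tools.extract_json_root_string(full)
    if snapshot is not None:
        print(snapshot)  # {"tool_name":"response","tool_args":{}}
        break  # dispatch the tool call now; later chunks are never consumed
```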

agent.py

@@ -388,6 +388,7 @@ class Agent:
self.context.streaming_agent = self # mark self as current streamer
self.loop_data.iteration += 1
self.loop_data.params_temporary = {} # clear temporary params
last_response_stream_full = ""
# call message_loop_start extensions
await extension.call_extensions_async(
@@ -425,12 +426,32 @@
await self.handle_reasoning_stream(stream_data["full"])
async def stream_callback(chunk: str, full: str):
nonlocal last_response_stream_full
await self.handle_intervention()
# output the agent response stream
if chunk == full:
printer.print("Response: ") # start of response
# Pass chunk and full data to extensions for processing
stream_data = {"chunk": chunk, "full": full}
stop_response: str | None = None
snapshot = extract_tools.extract_json_root_string(full)
if snapshot:
parsed_snapshot = extract_tools.json_parse_dirty(snapshot)
if parsed_snapshot is not None:
try:
await self.validate_tool_request(parsed_snapshot)
except Exception:
pass
else:
previous_full = last_response_stream_full
stream_data["full"] = snapshot
if snapshot.startswith(previous_full):
stream_data["chunk"] = snapshot[len(previous_full) :]
else:
stream_data["chunk"] = snapshot
stop_response = snapshot
await extension.call_extensions_async(
"response_stream_chunk",
self,
@@ -442,6 +463,9 @@
printer.stream(stream_data["chunk"])
# Use the potentially modified full text for downstream processing
await self.handle_response_stream(stream_data["full"])
last_response_stream_full = stream_data["full"]
if stop_response is not None:
return stop_response
# call main LLM
agent_response, _reasoning = await self.call_chat_model(
@@ -770,7 +794,7 @@
async def call_chat_model(
self,
messages: list[BaseMessage],
response_callback: Callable[[str, str], Awaitable[None]] | None = None,
response_callback: Callable[[str, str], Awaitable[str | None]] | None = None,
reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
background: bool = False,
explicit_caching: bool = True,

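Two details in the agent.py hunks above carry the contract: `stream_callback` may now return a string (hence the `Awaitable[str | None]` signature change), and a non-`None` return is propagated by the model wrapper as a stop signal. The chunk/full bookkeeping keeps downstream extensions seeing consistent deltas even when the canonical snapshot replaces the raw stream text; a small sketch of that suffix computation, with hypothetical values:

```python
# Sketch of the delta logic in stream_callback: if the canonical snapshot
# extends what was previously emitted, forward only the new suffix.
previous_full = '{"tool_name":"resp'     # hypothetical earlier emitted text
snapshot = '{"tool_name":"response"}'    # hypothetical canonical snapshot
chunk = snapshot[len(previous_full):] if snapshot.startswith(previous_full) else snapshot
assert chunk == 'onse"}'
```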
helpers/dirty_json.py

@@ -25,6 +25,14 @@ class DirtyJson:
self.current_char = None
self.result = None
self.stack = []
self.completed = False
self._parsing_started = False
def _pop_stack(self, root_closed: bool = False):
"""Pop from the parsing stack and mark completed only on an explicit root close."""
self.stack.pop()
if root_closed and self._parsing_started and not self.stack:
self.completed = True
@staticmethod
def parse_string(json_string):
@@ -95,6 +103,8 @@
self._advance()
def _parse(self):
if self.completed and not self.stack:
return
if self.result is None:
self.result = self._parse_value()
else:
@@ -102,6 +112,8 @@
def _continue_parsing(self):
while self.current_char is not None:
if self.completed and not self.stack:
return
if isinstance(self.result, dict):
self._parse_object_content()
elif isinstance(self.result, list):
@@ -114,7 +126,9 @@
def _parse_value(self):
self._skip_whitespace()
if self.current_char == "{":
if self._peek(1) == "{": # Handle {{
# Only treat doubled braces as a wrapper at the root; nested objects
# must keep their closing braces paired correctly.
if not self.stack and self._peek(1) == "{": # Handle {{
self._advance(2)
return self._parse_object()
elif self.current_char == "[":
@@ -153,6 +167,7 @@
obj = {}
self._advance() # Skip opening brace
self.stack.append(obj)
self._parsing_started = True
self._parse_object_content()
return obj
@@ -160,14 +175,16 @@
while self.current_char is not None:
self._skip_whitespace()
if self.current_char == "}":
if self._peek(1) == "}": # Handle }}
# Root-level wrapper outputs may end in "}}"; nested objects must
# still close one brace at a time.
if len(self.stack) == 1 and self._peek(1) == "}": # Handle }}
self._advance(2)
else:
self._advance()
self.stack.pop()
self._pop_stack(root_closed=True)
return
if self.current_char is None:
self.stack.pop()
self._pop_stack()
return # End of input reached while parsing object
key = self._parse_key()
@@ -190,7 +207,7 @@
continue
elif self.current_char != "}":
if self.current_char is None:
self.stack.pop()
self._pop_stack()
return # End of input reached after value
continue
@@ -216,6 +233,7 @@
arr = []
self._advance() # Skip opening bracket
self.stack.append(arr)
self._parsing_started = True
self._parse_array_content()
return arr
@@ -224,7 +242,7 @@
self._skip_whitespace()
if self.current_char == "]":
self._advance()
self.stack.pop()
self._pop_stack(root_closed=True)
return
value = self._parse_value()
self.stack[-1].append(value)
@@ -236,10 +254,10 @@
if self.current_char is None or self.current_char == "]":
if self.current_char == "]":
self._advance()
self.stack.pop()
self._pop_stack(root_closed=True)
return
elif self.current_char != "]":
self.stack.pop()
self._pop_stack()
return
def _parse_string(self):

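The net effect of `_pop_stack` plus the `completed` flag: the parser still tolerates truncated input as before, but it now records whether the root value was explicitly closed. A quick sketch of the semantics, mirroring the tests added later in this PR:

```python
# Sketch: completed is True only when the root is explicitly closed.
from helpers.dirty_json import DirtyJson

p = DirtyJson()
p.parse('{"tool_args": {"x": 1}}')  # root brace explicitly closed
assert p.completed is True

q = DirtyJson()
q.parse('{"tool_args": {"x": 1}')   # EOF before the root closes
assert q.completed is False         # value still parsed, but not complete
```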
helpers/extract_tools.py

@@ -19,6 +19,28 @@ def json_parse_dirty(json: str) -> dict[str, Any] | None:
return None
return None
def extract_json_root_string(content: str) -> str | None:
if not content or not isinstance(content, str):
return None
start = content.find("{")
if start == -1:
return None
first_array = content.find("[")
if first_array != -1 and first_array < start:
return None
parser = DirtyJson()
try:
parser.parse(content[start:])
except Exception:
return None
if not parser.completed:
return None
return content[start : start + parser.index]
def extract_json_object_string(content):
start = content.find("{")

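`extract_json_root_string` returns the exact source substring of the first completed root object, and only when the root is an object rather than an array. A few illustrative assertions of that contract (consistent with the new tests below; the input strings are made up):

```python
from helpers import extract_tools

text = 'noise {"tool_name":"x","tool_args":{}} tail'
# Exact substring of the first completed root; trailing text excluded.
assert extract_tools.extract_json_root_string(text) == '{"tool_name":"x","tool_args":{}}'
# None until the root brace actually closes...
assert extract_tools.extract_json_root_string('{"tool_name":"x"') is None
# ...and None when an array opens before the first object brace.
assert extract_tools.extract_json_root_string('[{"tool_name":"x"}]') is None
```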
models.py

@@ -475,7 +475,7 @@ class LiteLLMChatWrapper(SimpleChatModel):
system_message="",
user_message="",
messages: List[BaseMessage] | None = None,
response_callback: Callable[[str, str], Awaitable[None]] | None = None,
response_callback: Callable[[str, str], Awaitable[str | None]] | None = None,
reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
rate_limiter_callback: (
@@ -526,36 +526,46 @@
if stream:
# iterate over chunks
async for chunk in _completion: # type: ignore
got_any_chunk = True
# parse chunk
parsed = _parse_chunk(chunk)
output = result.add_chunk(parsed)
stop_response: str | None = None
try:
async for chunk in _completion: # type: ignore
got_any_chunk = True
# parse chunk
parsed = _parse_chunk(chunk)
output = result.add_chunk(parsed)
# collect reasoning delta and call callbacks
if output["reasoning_delta"]:
if reasoning_callback:
await reasoning_callback(output["reasoning_delta"], result.reasoning)
if tokens_callback:
await tokens_callback(
output["reasoning_delta"],
approximate_tokens(output["reasoning_delta"]),
)
# Add output tokens to rate limiter if configured
if limiter:
limiter.add(output=approximate_tokens(output["reasoning_delta"]))
# collect response delta and call callbacks
if output["response_delta"]:
if response_callback:
await response_callback(output["response_delta"], result.response)
if tokens_callback:
await tokens_callback(
output["response_delta"],
approximate_tokens(output["response_delta"]),
)
# Add output tokens to rate limiter if configured
if limiter:
limiter.add(output=approximate_tokens(output["response_delta"]))
# collect reasoning delta and call callbacks
if output["reasoning_delta"]:
if reasoning_callback:
await reasoning_callback(output["reasoning_delta"], result.reasoning)
if tokens_callback:
await tokens_callback(
output["reasoning_delta"],
approximate_tokens(output["reasoning_delta"]),
)
# Add output tokens to rate limiter if configured
if limiter:
limiter.add(output=approximate_tokens(output["reasoning_delta"]))
# collect response delta and call callbacks
if output["response_delta"]:
if response_callback:
stop_response = await response_callback(
output["response_delta"], result.response
)
if tokens_callback:
await tokens_callback(
output["response_delta"],
approximate_tokens(output["response_delta"]),
)
# Add output tokens to rate limiter if configured
if limiter:
limiter.add(output=approximate_tokens(output["response_delta"]))
if stop_response is not None:
result.response = stop_response
break
finally:
if stop_response is not None and hasattr(_completion, "aclose"):
await _completion.aclose() # type: ignore[attr-defined]
# non-stream response
else:

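On the wrapper side the contract is symmetric: when `response_callback` returns a string, that string becomes the final response, the chunk loop breaks, and the provider stream is closed early when it exposes `aclose`. A minimal callback implementing this PR's early-stop policy might look like this sketch:

```python
# Sketch: a response_callback that requests early stop at the first
# complete JSON root (returning None keeps the stream going).
from helpers import extract_tools

async def stop_at_first_root(chunk: str, full: str) -> str | None:
    return extract_tools.extract_json_root_string(full)
```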
tests/test_dirty_json.py (new file)

@@ -0,0 +1,56 @@
from __future__ import annotations
import sys
from pathlib import Path
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from helpers.dirty_json import DirtyJson
@pytest.mark.parametrize(
("payload", "expected"),
[
(
'{"tool_name":"x","tool_args":{}}',
{"tool_name": "x", "tool_args": {}},
),
("[1, 2, 3]", [1, 2, 3]),
],
)
def test_completed_true_when_root_is_explicitly_closed(payload, expected) -> None:
parser = DirtyJson()
assert parser.parse(payload) == expected
assert parser.completed is True
def test_completed_false_when_root_hits_eof_before_closing() -> None:
parser = DirtyJson()
assert parser.parse('{"tool_name":"x","tool_args":{}') == {
"tool_name": "x",
"tool_args": {},
}
assert parser.completed is False
def test_completed_remains_true_after_trailing_content() -> None:
parser = DirtyJson()
assert parser.feed('{"tool_name":"x","tool_args":{}}') == {
"tool_name": "x",
"tool_args": {},
}
assert parser.completed is True
assert parser.feed(" trailing noise") == {
"tool_name": "x",
"tool_args": {},
}
assert parser.completed is True

tests/ (new test file; filename not shown)

@@ -0,0 +1,96 @@
import sys
from pathlib import Path
import pytest
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
import models
from helpers import extract_tools
def _chunk(content: str) -> dict:
return {"choices": [{"delta": {"content": content}, "message": {}}]}
class _AsyncChunkStream:
def __init__(self, chunks: list[dict]):
self._chunks = chunks
self.index = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.index >= len(self._chunks):
raise StopAsyncIteration
chunk = self._chunks[self.index]
self.index += 1
return chunk
def test_extract_json_root_string_returns_canonical_snapshot():
text = (
'prefix {"tool_name":"response","tool_args":{"text":"brace } inside"}} '
"trailing noise"
)
root = extract_tools.extract_json_root_string(text)
assert root == '{"tool_name":"response","tool_args":{"text":"brace } inside"}}'
assert extract_tools.json_parse_dirty(root)["tool_args"]["text"] == "brace } inside"
assert extract_tools.extract_json_root_string(
'{"tool_name":"response","tool_args":{"text":"missing"'
) is None
assert extract_tools.extract_json_root_string('[{"tool_name":"response"}]') is None
@pytest.mark.asyncio
async def test_unified_call_stops_after_canonical_root_snapshot(monkeypatch):
stream = _AsyncChunkStream(
[
_chunk(
'{"tool_name":"response","tool_args":{"text":"hello"}} trailing text'
),
_chunk(" unreachable"),
]
)
async def fake_acompletion(*args, **kwargs):
assert kwargs["stream"] is True
return stream
async def fake_rate_limiter(*args, **kwargs):
return None
monkeypatch.setattr(models, "acompletion", fake_acompletion)
monkeypatch.setattr(models, "apply_rate_limiter", fake_rate_limiter)
wrapper = models.LiteLLMChatWrapper(
model="test-model",
provider="openai",
model_config=None,
)
seen: list[tuple[str, str]] = []
async def response_callback(chunk: str, full: str):
seen.append((chunk, full))
snapshot = extract_tools.extract_json_root_string(full)
if snapshot:
return snapshot
return None
response, reasoning = await wrapper.unified_call(
messages=[],
response_callback=response_callback,
)
assert response == '{"tool_name":"response","tool_args":{"text":"hello"}}'
assert reasoning == ""
assert stream.index == 1
assert len(seen) == 1
assert seen[0][1] == '{"tool_name":"response","tool_args":{"text":"hello"}} trailing text'