Mirror of https://github.com/agent0ai/agent-zero.git (synced 2026-04-28 03:30:23 +00:00)

Merge pull request #1428 from 3clyp50/dirtyjson
Dispatch tool calls at first completed JSON object

Commit 2da44168da. 6 changed files with 265 additions and 39 deletions.
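
The change lets a response-stream callback return a value: when the callback returns a non-None string, that string becomes the canonical response text and the model wrapper stops reading the stream. A minimal sketch of that contract, using illustrative names (consume, cb) rather than repository APIs:

# Sketch only: a simplified stand-in for the streaming loop in models.py.
import asyncio
from typing import Awaitable, Callable

StopCallback = Callable[[str, str], Awaitable[str | None]]

async def consume(chunks: list[str], callback: StopCallback) -> str:
    full = ""
    for chunk in chunks:  # stands in for the async LLM chunk stream
        full += chunk
        stop = await callback(chunk, full)
        if stop is not None:
            return stop  # canonical snapshot; later chunks are never read
    return full

async def demo() -> None:
    async def cb(chunk: str, full: str) -> str | None:
        # stop once the accumulated text contains a closed root object
        end = full.find("}")
        return full[: end + 1] if end != -1 else None

    print(await consume(['{"a"', ":1}", " noise"], cb))  # {"a":1}

asyncio.run(demo())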

agent.py (26 changes)
@@ -388,6 +388,7 @@ class Agent:
         self.context.streaming_agent = self  # mark self as current streamer
         self.loop_data.iteration += 1
         self.loop_data.params_temporary = {}  # clear temporary params
+        last_response_stream_full = ""

         # call message_loop_start extensions
         await extension.call_extensions_async(
@@ -425,12 +426,32 @@ class Agent:
                 await self.handle_reasoning_stream(stream_data["full"])

             async def stream_callback(chunk: str, full: str):
+                nonlocal last_response_stream_full
                 await self.handle_intervention()
                 # output the agent response stream
                 if chunk == full:
                     printer.print("Response: ")  # start of response
                 # Pass chunk and full data to extensions for processing
                 stream_data = {"chunk": chunk, "full": full}
+                stop_response: str | None = None
+
+                snapshot = extract_tools.extract_json_root_string(full)
+                if snapshot:
+                    parsed_snapshot = extract_tools.json_parse_dirty(snapshot)
+                    if parsed_snapshot is not None:
+                        try:
+                            await self.validate_tool_request(parsed_snapshot)
+                        except Exception:
+                            pass
+                        else:
+                            previous_full = last_response_stream_full
+                            stream_data["full"] = snapshot
+                            if snapshot.startswith(previous_full):
+                                stream_data["chunk"] = snapshot[len(previous_full) :]
+                            else:
+                                stream_data["chunk"] = snapshot
+                            stop_response = snapshot
+
                 await extension.call_extensions_async(
                     "response_stream_chunk",
                     self,
@@ -442,6 +463,9 @@ class Agent:
                 printer.stream(stream_data["chunk"])
                 # Use the potentially modified full text for downstream processing
                 await self.handle_response_stream(stream_data["full"])
+                last_response_stream_full = stream_data["full"]
+                if stop_response is not None:
+                    return stop_response

             # call main LLM
             agent_response, _reasoning = await self.call_chat_model(
|
@ -770,7 +794,7 @@ class Agent:
|
|||
async def call_chat_model(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
response_callback: Callable[[str, str], Awaitable[None]] | None = None,
|
||||
response_callback: Callable[[str, str], Awaitable[str | None]] | None = None,
|
||||
reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
|
||||
background: bool = False,
|
||||
explicit_caching: bool = True,
|
||||
|
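
In stream_callback above, once a snapshot parses and validates as a tool request, it replaces the accumulated text, and the chunk is recomputed as the suffix the UI has not yet seen. A standalone sketch of that delta rule (snapshot_delta is a hypothetical helper mirroring the startswith slicing above):

# Hypothetical helper; agent.py inlines this logic in stream_callback.
def snapshot_delta(previous_full: str, snapshot: str) -> str:
    if snapshot.startswith(previous_full):
        return snapshot[len(previous_full):]  # emit only the unseen suffix
    return snapshot  # snapshot diverged from what was streamed; re-emit whole

assert snapshot_delta('{"tool_name":"resp', '{"tool_name":"response"}') == 'onse"}'
assert snapshot_delta("", '{"a":1}') == '{"a":1}'
assert snapshot_delta("unrelated prefix", '{"a":1}') == '{"a":1}'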

helpers/dirty_json.py
@@ -25,6 +25,14 @@ class DirtyJson:
         self.current_char = None
         self.result = None
         self.stack = []
+        self.completed = False
+        self._parsing_started = False
+
+    def _pop_stack(self, root_closed: bool = False):
+        """Pop from the parsing stack and mark completed only on an explicit root close."""
+        self.stack.pop()
+        if root_closed and self._parsing_started and not self.stack:
+            self.completed = True

     @staticmethod
     def parse_string(json_string):
@@ -95,6 +103,8 @@ class DirtyJson:
         self._advance()

     def _parse(self):
+        if self.completed and not self.stack:
+            return
         if self.result is None:
             self.result = self._parse_value()
         else:
@@ -102,6 +112,8 @@ class DirtyJson:

     def _continue_parsing(self):
         while self.current_char is not None:
+            if self.completed and not self.stack:
+                return
             if isinstance(self.result, dict):
                 self._parse_object_content()
             elif isinstance(self.result, list):
@@ -114,7 +126,9 @@ class DirtyJson:
     def _parse_value(self):
         self._skip_whitespace()
         if self.current_char == "{":
-            if self._peek(1) == "{":  # Handle {{
+            # Only treat doubled braces as a wrapper at the root; nested objects
+            # must keep their closing braces paired correctly.
+            if not self.stack and self._peek(1) == "{":  # Handle {{
                 self._advance(2)
             return self._parse_object()
         elif self.current_char == "[":
@@ -153,6 +167,7 @@ class DirtyJson:
         obj = {}
         self._advance()  # Skip opening brace
         self.stack.append(obj)
+        self._parsing_started = True
         self._parse_object_content()
         return obj
@@ -160,14 +175,16 @@ class DirtyJson:
         while self.current_char is not None:
             self._skip_whitespace()
             if self.current_char == "}":
-                if self._peek(1) == "}":  # Handle }}
+                # Root-level wrapper outputs may end in "}}"; nested objects must
+                # still close one brace at a time.
+                if len(self.stack) == 1 and self._peek(1) == "}":  # Handle }}
                     self._advance(2)
                 else:
                     self._advance()
-                self.stack.pop()
+                self._pop_stack(root_closed=True)
                 return
             if self.current_char is None:
-                self.stack.pop()
+                self._pop_stack()
                 return  # End of input reached while parsing object

             key = self._parse_key()
@@ -190,7 +207,7 @@ class DirtyJson:
                 continue
             elif self.current_char != "}":
                 if self.current_char is None:
-                    self.stack.pop()
+                    self._pop_stack()
                     return  # End of input reached after value
                 continue
@@ -216,6 +233,7 @@ class DirtyJson:
         arr = []
         self._advance()  # Skip opening bracket
         self.stack.append(arr)
+        self._parsing_started = True
         self._parse_array_content()
         return arr
@@ -224,7 +242,7 @@ class DirtyJson:
         self._skip_whitespace()
         if self.current_char == "]":
             self._advance()
-            self.stack.pop()
+            self._pop_stack(root_closed=True)
             return
         value = self._parse_value()
         self.stack[-1].append(value)
@@ -236,10 +254,10 @@ class DirtyJson:
         if self.current_char is None or self.current_char == "]":
             if self.current_char == "]":
                 self._advance()
-                self.stack.pop()
+                self._pop_stack(root_closed=True)
                 return
             elif self.current_char != "]":
-                self.stack.pop()
+                self._pop_stack()
                 return

     def _parse_string(self):
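
With _pop_stack in place, completed flips to True only when the root container is closed by an explicit "}" or "]"; hitting end of input merely pops the stack and leaves completed False. The behavior, as the new tests in tests/test_dirty_json.py below exercise it (assuming the repository root is on sys.path):

# Mirrors tests/test_dirty_json.py; requires the repository root on sys.path.
from helpers.dirty_json import DirtyJson

truncated = DirtyJson()
truncated.parse('{"tool_name":"x","tool_args":{}')  # EOF before the root "}"
print(truncated.completed)  # False: the nested object closed, the root did not

closed = DirtyJson()
closed.parse('{"tool_name":"x","tool_args":{}}')
print(closed.completed)  # True: explicit root close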

helpers/extract_tools.py
@@ -19,6 +19,28 @@ def json_parse_dirty(json: str) -> dict[str, Any] | None:
             return None
     return None


+def extract_json_root_string(content: str) -> str | None:
+    if not content or not isinstance(content, str):
+        return None
+
+    start = content.find("{")
+    if start == -1:
+        return None
+    first_array = content.find("[")
+    if first_array != -1 and first_array < start:
+        return None
+
+    parser = DirtyJson()
+    try:
+        parser.parse(content[start:])
+    except Exception:
+        return None
+
+    if not parser.completed:
+        return None
+
+    return content[start : start + parser.index]
+
+
 def extract_json_object_string(content):
     start = content.find("{")
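
extract_json_root_string returns the exact substring of the input that spans the first complete root object (hence content[start : start + parser.index]), and returns None for unterminated objects or when an array opens first. Usage matching the new tests:

# Mirrors tests/test_stream_tool_early_stop.py; repository root on sys.path.
from helpers import extract_tools

text = 'prefix {"tool_args":{"text":"brace } inside"}} trailing noise'
print(extract_tools.extract_json_root_string(text))
# {"tool_args":{"text":"brace } inside"}}  (a quoted "}" does not close the object)

print(extract_tools.extract_json_root_string('{"unclosed":"value"'))  # None
print(extract_tools.extract_json_root_string('[{"inside":"array"}]'))  # None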

models.py (70 changes)
@@ -475,7 +475,7 @@ class LiteLLMChatWrapper(SimpleChatModel):
         system_message="",
         user_message="",
         messages: List[BaseMessage] | None = None,
-        response_callback: Callable[[str, str], Awaitable[None]] | None = None,
+        response_callback: Callable[[str, str], Awaitable[str | None]] | None = None,
         reasoning_callback: Callable[[str, str], Awaitable[None]] | None = None,
         tokens_callback: Callable[[str, int], Awaitable[None]] | None = None,
         rate_limiter_callback: (
@@ -526,36 +526,46 @@ class LiteLLMChatWrapper(SimpleChatModel):

         if stream:
             # iterate over chunks
-            async for chunk in _completion:  # type: ignore
-                got_any_chunk = True
-                # parse chunk
-                parsed = _parse_chunk(chunk)
-                output = result.add_chunk(parsed)
+            stop_response: str | None = None
+            try:
+                async for chunk in _completion:  # type: ignore
+                    got_any_chunk = True
+                    # parse chunk
+                    parsed = _parse_chunk(chunk)
+                    output = result.add_chunk(parsed)

-                # collect reasoning delta and call callbacks
-                if output["reasoning_delta"]:
-                    if reasoning_callback:
-                        await reasoning_callback(output["reasoning_delta"], result.reasoning)
-                    if tokens_callback:
-                        await tokens_callback(
-                            output["reasoning_delta"],
-                            approximate_tokens(output["reasoning_delta"]),
-                        )
-                    # Add output tokens to rate limiter if configured
-                    if limiter:
-                        limiter.add(output=approximate_tokens(output["reasoning_delta"]))
-                # collect response delta and call callbacks
-                if output["response_delta"]:
-                    if response_callback:
-                        await response_callback(output["response_delta"], result.response)
-                    if tokens_callback:
-                        await tokens_callback(
-                            output["response_delta"],
-                            approximate_tokens(output["response_delta"]),
-                        )
-                    # Add output tokens to rate limiter if configured
-                    if limiter:
-                        limiter.add(output=approximate_tokens(output["response_delta"]))
+                    # collect reasoning delta and call callbacks
+                    if output["reasoning_delta"]:
+                        if reasoning_callback:
+                            await reasoning_callback(output["reasoning_delta"], result.reasoning)
+                        if tokens_callback:
+                            await tokens_callback(
+                                output["reasoning_delta"],
+                                approximate_tokens(output["reasoning_delta"]),
+                            )
+                        # Add output tokens to rate limiter if configured
+                        if limiter:
+                            limiter.add(output=approximate_tokens(output["reasoning_delta"]))
+                    # collect response delta and call callbacks
+                    if output["response_delta"]:
+                        if response_callback:
+                            stop_response = await response_callback(
+                                output["response_delta"], result.response
+                            )
+                        if tokens_callback:
+                            await tokens_callback(
+                                output["response_delta"],
+                                approximate_tokens(output["response_delta"]),
+                            )
+                        # Add output tokens to rate limiter if configured
+                        if limiter:
+                            limiter.add(output=approximate_tokens(output["response_delta"]))
+                    if stop_response is not None:
+                        result.response = stop_response
+                        break
+            finally:
+                if stop_response is not None and hasattr(_completion, "aclose"):
+                    await _completion.aclose()  # type: ignore[attr-defined]

         # non-stream response
         else:
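
The finally: block matters because breaking out of an async for does not finalize the underlying generator; awaiting aclose() runs its cleanup immediately instead of leaving it to garbage collection. A minimal stand-in illustration (fake_stream is not a repository API):

# Why early stop calls aclose(): deterministic cleanup of the chunk stream.
import asyncio

async def fake_stream():
    try:
        for i in range(10):
            yield i
    finally:
        print("stream closed")  # e.g. the HTTP connection being released

async def demo() -> None:
    gen = fake_stream()
    async for value in gen:
        if value == 2:  # plays the role of stop_response being set
            break
    await gen.aclose()  # mirrors the finally: block above

asyncio.run(demo())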

tests/test_dirty_json.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+
+import pytest
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+from helpers.dirty_json import DirtyJson
+
+
+@pytest.mark.parametrize(
+    ("payload", "expected"),
+    [
+        (
+            '{"tool_name":"x","tool_args":{}}',
+            {"tool_name": "x", "tool_args": {}},
+        ),
+        ("[1, 2, 3]", [1, 2, 3]),
+    ],
+)
+def test_completed_true_when_root_is_explicitly_closed(payload, expected) -> None:
+    parser = DirtyJson()
+
+    assert parser.parse(payload) == expected
+    assert parser.completed is True
+
+
+def test_completed_false_when_root_hits_eof_before_closing() -> None:
+    parser = DirtyJson()
+
+    assert parser.parse('{"tool_name":"x","tool_args":{}') == {
+        "tool_name": "x",
+        "tool_args": {},
+    }
+    assert parser.completed is False
+
+
+def test_completed_remains_true_after_trailing_content() -> None:
+    parser = DirtyJson()
+
+    assert parser.feed('{"tool_name":"x","tool_args":{}}') == {
+        "tool_name": "x",
+        "tool_args": {},
+    }
+    assert parser.completed is True
+
+    assert parser.feed(" trailing noise") == {
+        "tool_name": "x",
+        "tool_args": {},
+    }
+
+    assert parser.completed is True

tests/test_stream_tool_early_stop.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+import sys
+from pathlib import Path
+
+import pytest
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+if str(PROJECT_ROOT) not in sys.path:
+    sys.path.insert(0, str(PROJECT_ROOT))
+
+import models
+from helpers import extract_tools
+
+
+def _chunk(content: str) -> dict:
+    return {"choices": [{"delta": {"content": content}, "message": {}}]}
+
+
+class _AsyncChunkStream:
+    def __init__(self, chunks: list[dict]):
+        self._chunks = chunks
+        self.index = 0
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.index >= len(self._chunks):
+            raise StopAsyncIteration
+        chunk = self._chunks[self.index]
+        self.index += 1
+        return chunk
+
+
+def test_extract_json_root_string_returns_canonical_snapshot():
+    text = (
+        'prefix {"tool_name":"response","tool_args":{"text":"brace } inside"}} '
+        "trailing noise"
+    )
+
+    root = extract_tools.extract_json_root_string(text)
+
+    assert root == '{"tool_name":"response","tool_args":{"text":"brace } inside"}}'
+    assert extract_tools.json_parse_dirty(root)["tool_args"]["text"] == "brace } inside"
+    assert extract_tools.extract_json_root_string(
+        '{"tool_name":"response","tool_args":{"text":"missing"'
+    ) is None
+    assert extract_tools.extract_json_root_string('[{"tool_name":"response"}]') is None
+
+
+@pytest.mark.asyncio
+async def test_unified_call_stops_after_canonical_root_snapshot(monkeypatch):
+    stream = _AsyncChunkStream(
+        [
+            _chunk(
+                '{"tool_name":"response","tool_args":{"text":"hello"}} trailing text'
+            ),
+            _chunk(" unreachable"),
+        ]
+    )
+
+    async def fake_acompletion(*args, **kwargs):
+        assert kwargs["stream"] is True
+        return stream
+
+    async def fake_rate_limiter(*args, **kwargs):
+        return None
+
+    monkeypatch.setattr(models, "acompletion", fake_acompletion)
+    monkeypatch.setattr(models, "apply_rate_limiter", fake_rate_limiter)
+
+    wrapper = models.LiteLLMChatWrapper(
+        model="test-model",
+        provider="openai",
+        model_config=None,
+    )
+
+    seen: list[tuple[str, str]] = []
+
+    async def response_callback(chunk: str, full: str):
+        seen.append((chunk, full))
+        snapshot = extract_tools.extract_json_root_string(full)
+        if snapshot:
+            return snapshot
+        return None
+
+    response, reasoning = await wrapper.unified_call(
+        messages=[],
+        response_callback=response_callback,
+    )
+
+    assert response == '{"tool_name":"response","tool_args":{"text":"hello"}}'
+    assert reasoning == ""
+    assert stream.index == 1
+    assert len(seen) == 1
+    assert seen[0][1] == '{"tool_name":"response","tool_args":{"text":"hello"}} trailing text'