mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-19 08:00:25 +00:00
common/autoparser: fixes for newline handling / forced tool calls (#22654)
* chat/autoparser: the fixes * Move optspace() to chat-peg-parser, comment out server tests invalidated due to content now allowed with forced tool calls. * Trim whitespace on apply instead
This commit is contained in:
parent
994118a183
commit
a4701c98f7
10 changed files with 392 additions and 97 deletions
|
|
@ -79,7 +79,7 @@ def print_info(msg):
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def chat_completion(url, messages, tools=None, stream=False):
|
||||
def chat_completion(url, messages, tools=None, stream=False, force_tools=False):
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
|
|
@ -87,7 +87,10 @@ def chat_completion(url, messages, tools=None, stream=False):
|
|||
}
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
if force_tools:
|
||||
payload["tool_choice"] = "required"
|
||||
else:
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, stream=stream)
|
||||
|
|
@ -160,7 +163,13 @@ def chat_completion(url, messages, tools=None, stream=False):
|
|||
return result
|
||||
|
||||
|
||||
def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6):
|
||||
def all_tools_called(tools, all_tool_calls):
    """Return True once every tool offered in *tools* has been called.

    Both arguments are iterables of dicts whose ``["function"]["name"]``
    entry holds the tool name (the only keys this function reads).

    Args:
        tools: Tool definitions offered to the model. ``None`` or an empty
            list means there is nothing left to call, so the result is True.
        all_tool_calls: Tool calls collected so far.

    Returns:
        bool: True when the set of offered tool names is a subset of the
        set of called tool names.
    """
    offered_names = {tool["function"]["name"] for tool in (tools or ())}
    called_names = {call["function"]["name"] for call in (all_tool_calls or ())}
    # Subset rather than strict equality: a call to a name that is not in
    # `tools` must not make this report "not all called" forever.
    return offered_names <= called_names
|
||||
|
||||
|
||||
def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turns=6, force_tools=False):
|
||||
"""
|
||||
Drive the multi-turn tool-call loop:
|
||||
1. Send messages to model.
|
||||
|
|
@ -172,8 +181,8 @@ def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turn
|
|||
msgs = list(messages)
|
||||
all_tool_calls: list[dict] = []
|
||||
|
||||
for _ in range(max_turns):
|
||||
result = chat_completion(url, msgs, tools=tools, stream=stream)
|
||||
for t in range(max_turns):
|
||||
result = chat_completion(url, msgs, tools=tools, stream=stream, force_tools=(force_tools and not all_tools_called(tools, all_tool_calls)))
|
||||
if result is None:
|
||||
return all_tool_calls, None
|
||||
|
||||
|
|
@ -235,10 +244,10 @@ def run_agentic_loop(url, messages, tools, mock_tool_responses, stream, max_turn
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_test(url, test_case, stream):
|
||||
def run_test(url, test_case, stream, force_tools):
|
||||
name = test_case["name"]
|
||||
mode = f"{'stream' if stream else 'non-stream'}"
|
||||
print_header(f"{name} [{mode}]")
|
||||
print_header(f"{name} [{mode}, force_tools={force_tools}] ")
|
||||
|
||||
all_tool_calls, final_content = run_agentic_loop(
|
||||
url,
|
||||
|
|
@ -246,6 +255,7 @@ def run_test(url, test_case, stream):
|
|||
tools=test_case["tools"],
|
||||
mock_tool_responses=test_case["mock_tool_responses"],
|
||||
stream=stream,
|
||||
force_tools=force_tools
|
||||
)
|
||||
|
||||
if final_content is None and not all_tool_calls:
|
||||
|
|
@ -1093,6 +1103,9 @@ def main():
|
|||
parser.add_argument(
|
||||
"--stream-only", action="store_true", help="Only run streaming mode tests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-tools", action="store_true", help="Change tool mode to forced instead of auto"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
help="Run only the test whose name contains this substring (case-insensitive)",
|
||||
|
|
@ -1103,10 +1116,13 @@ def main():
|
|||
print_info(f"Testing server at {url}")
|
||||
|
||||
modes = []
|
||||
force_tools = False
|
||||
if not args.stream_only:
|
||||
modes.append(False)
|
||||
if not args.no_stream:
|
||||
modes.append(True)
|
||||
if args.force_tools:
|
||||
force_tools = True
|
||||
|
||||
cases: list[dict] = ALL_TEST_CASES
|
||||
if args.test:
|
||||
|
|
@ -1121,7 +1137,7 @@ def main():
|
|||
for stream in modes:
|
||||
for case in cases:
|
||||
total += 1
|
||||
if run_test(url, case, stream=stream):
|
||||
if run_test(url, case, stream=stream, force_tools=force_tools):
|
||||
passed += 1
|
||||
|
||||
color = GREEN if passed == total else RED
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue