mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 09:34:37 +00:00
server: fix OpenAI API compatibility for usage statistics in chat streams (#15444)
This commit is contained in:
parent
13aeb7aef2
commit
1bc664a26a
3 changed files with 105 additions and 82 deletions
|
@ -911,6 +911,17 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
{"model", oaicompat_model},
|
{"model", oaicompat_model},
|
||||||
{"system_fingerprint", build_info},
|
{"system_fingerprint", build_info},
|
||||||
{"object", "chat.completion.chunk"},
|
{"object", "chat.completion.chunk"},
|
||||||
|
});
|
||||||
|
|
||||||
|
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
|
||||||
|
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
|
||||||
|
deltas.push_back({
|
||||||
|
{"choices", json::array()},
|
||||||
|
{"created", t},
|
||||||
|
{"id", oaicompat_cmpl_id},
|
||||||
|
{"model", oaicompat_model},
|
||||||
|
{"system_fingerprint", build_info},
|
||||||
|
{"object", "chat.completion.chunk"},
|
||||||
{"usage", json {
|
{"usage", json {
|
||||||
{"completion_tokens", n_decoded},
|
{"completion_tokens", n_decoded},
|
||||||
{"prompt_tokens", n_prompt_tokens},
|
{"prompt_tokens", n_prompt_tokens},
|
||||||
|
|
|
@ -72,6 +72,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
||||||
content = ""
|
content = ""
|
||||||
last_cmpl_id = None
|
last_cmpl_id = None
|
||||||
for i, data in enumerate(res):
|
for i, data in enumerate(res):
|
||||||
|
if data["choices"]:
|
||||||
choice = data["choices"][0]
|
choice = data["choices"][0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# Check first role message for stream=True
|
# Check first role message for stream=True
|
||||||
|
@ -85,14 +86,15 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
||||||
last_cmpl_id = data["id"]
|
last_cmpl_id = data["id"]
|
||||||
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
|
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
|
||||||
if choice["finish_reason"] in ["stop", "length"]:
|
if choice["finish_reason"] in ["stop", "length"]:
|
||||||
assert data["usage"]["prompt_tokens"] == n_prompt
|
|
||||||
assert data["usage"]["completion_tokens"] == n_predicted
|
|
||||||
assert "content" not in choice["delta"]
|
assert "content" not in choice["delta"]
|
||||||
assert match_regex(re_content, content)
|
assert match_regex(re_content, content)
|
||||||
assert choice["finish_reason"] == finish_reason
|
assert choice["finish_reason"] == finish_reason
|
||||||
else:
|
else:
|
||||||
assert choice["finish_reason"] is None
|
assert choice["finish_reason"] is None
|
||||||
content += choice["delta"]["content"] or ''
|
content += choice["delta"]["content"] or ''
|
||||||
|
else:
|
||||||
|
assert data["usage"]["prompt_tokens"] == n_prompt
|
||||||
|
assert data["usage"]["completion_tokens"] == n_predicted
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_with_openai_library():
|
def test_chat_completion_with_openai_library():
|
||||||
|
@ -278,7 +280,9 @@ def test_chat_completion_with_timings_per_token():
|
||||||
assert data["choices"][0]["delta"]["role"] == "assistant"
|
assert data["choices"][0]["delta"]["role"] == "assistant"
|
||||||
assert "timings" not in data, f'First event should not have timings: {data}'
|
assert "timings" not in data, f'First event should not have timings: {data}'
|
||||||
else:
|
else:
|
||||||
|
if data["choices"]:
|
||||||
assert "role" not in data["choices"][0]["delta"]
|
assert "role" not in data["choices"][0]["delta"]
|
||||||
|
else:
|
||||||
assert "timings" in data
|
assert "timings" in data
|
||||||
assert "prompt_per_second" in data["timings"]
|
assert "prompt_per_second" in data["timings"]
|
||||||
assert "predicted_per_second" in data["timings"]
|
assert "predicted_per_second" in data["timings"]
|
||||||
|
@ -332,6 +336,7 @@ def test_logprobs_stream():
|
||||||
output_text = ''
|
output_text = ''
|
||||||
aggregated_text = ''
|
aggregated_text = ''
|
||||||
for i, data in enumerate(res):
|
for i, data in enumerate(res):
|
||||||
|
if data.choices:
|
||||||
choice = data.choices[0]
|
choice = data.choices[0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
# Check first role message for stream=True
|
# Check first role message for stream=True
|
||||||
|
|
|
@ -318,6 +318,7 @@ class ServerProcess:
|
||||||
arguments_parts = 0
|
arguments_parts = 0
|
||||||
|
|
||||||
for chunk in self.make_stream_request(method, path, data, headers):
|
for chunk in self.make_stream_request(method, path, data, headers):
|
||||||
|
if chunk['choices']:
|
||||||
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
|
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
|
||||||
choice = chunk['choices'][0]
|
choice = chunk['choices'][0]
|
||||||
if choice['delta'].get('content') is not None:
|
if choice['delta'].get('content') is not None:
|
||||||
|
@ -357,7 +358,13 @@ class ServerProcess:
|
||||||
tool_call['function']['arguments'] += fct['arguments']
|
tool_call['function']['arguments'] += fct['arguments']
|
||||||
arguments_parts += 1
|
arguments_parts += 1
|
||||||
tool_call_parts += 1
|
tool_call_parts += 1
|
||||||
|
else:
|
||||||
|
# When `include_usage` is True (the default), we expect the last chunk of the stream
|
||||||
|
# immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
|
||||||
|
# and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
|
||||||
|
# the last chunk)
|
||||||
|
assert 'usage' in chunk, f"Expected finish_reason in chunk: {chunk}"
|
||||||
|
assert 'timings' in chunk, f"Expected finish_reason in chunk: {chunk}"
|
||||||
print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
|
print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
|
||||||
result = dict(
|
result = dict(
|
||||||
choices=[
|
choices=[
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue