server: fix OpenAI API compatibility for usage statistics in chat streams (#15444)

teo 2025-08-21 07:10:08 +09:00 committed by GitHub
parent 13aeb7aef2
commit 1bc664a26a
3 changed files with 105 additions and 82 deletions
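For illustration, this is roughly how the fixed behaviour looks from a client's point of view: a minimal sketch using the official openai Python client. The base URL, API key, and model name are placeholders for a local llama-server instance and are not part of this commit.

    # Minimal sketch: stream a chat completion and read usage from the final chunk.
    # Assumes llama-server listens on http://localhost:8080 (placeholder values below).
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

    stream = client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options={"include_usage": True},
    )

    content = ""
    for chunk in stream:
        if chunk.choices:
            # Normal chunks: a single choice carrying the content delta.
            content += chunk.choices[0].delta.content or ""
        elif chunk.usage is not None:
            # Per the OpenAI spec, the last chunk before [DONE] has an empty
            # `choices` array and carries the usage statistics.
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)

    print(content)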


@@ -911,6 +911,17 @@ struct server_task_result_cmpl_final : server_task_result {
             {"model", oaicompat_model},
             {"system_fingerprint", build_info},
             {"object", "chat.completion.chunk"},
+        });
+        // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
+        // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
+        deltas.push_back({
+            {"choices", json::array()},
+            {"created", t},
+            {"id", oaicompat_cmpl_id},
+            {"model", oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object", "chat.completion.chunk"},
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens", n_prompt_tokens},


@@ -72,6 +72,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
     content = ""
     last_cmpl_id = None
     for i, data in enumerate(res):
+        if data["choices"]:
             choice = data["choices"][0]
             if i == 0:
                 # Check first role message for stream=True
@@ -85,14 +86,15 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
                     last_cmpl_id = data["id"]
                 assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
             if choice["finish_reason"] in ["stop", "length"]:
-                assert data["usage"]["prompt_tokens"] == n_prompt
-                assert data["usage"]["completion_tokens"] == n_predicted
                 assert "content" not in choice["delta"]
                 assert match_regex(re_content, content)
                 assert choice["finish_reason"] == finish_reason
             else:
                 assert choice["finish_reason"] is None
                 content += choice["delta"]["content"] or ''
+        else:
+            assert data["usage"]["prompt_tokens"] == n_prompt
+            assert data["usage"]["completion_tokens"] == n_predicted

def test_chat_completion_with_openai_library():
@@ -278,7 +280,9 @@ def test_chat_completion_with_timings_per_token():
             assert data["choices"][0]["delta"]["role"] == "assistant"
             assert "timings" not in data, f'First event should not have timings: {data}'
         else:
+            if data["choices"]:
                 assert "role" not in data["choices"][0]["delta"]
+            else:
                 assert "timings" in data
                 assert "prompt_per_second" in data["timings"]
                 assert "predicted_per_second" in data["timings"]
@@ -332,6 +336,7 @@ def test_logprobs_stream():
     output_text = ''
     aggregated_text = ''
     for i, data in enumerate(res):
+        if data.choices:
             choice = data.choices[0]
             if i == 0:
                 # Check first role message for stream=True


@@ -318,6 +318,7 @@ class ServerProcess:
         arguments_parts = 0
         for chunk in self.make_stream_request(method, path, data, headers):
+            if chunk['choices']:
                 assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
                 choice = chunk['choices'][0]
                 if choice['delta'].get('content') is not None:
@@ -357,7 +358,13 @@ class ServerProcess:
                         tool_call['function']['arguments'] += fct['arguments']
                         arguments_parts += 1
                     tool_call_parts += 1
+            else:
+                # When `include_usage` is True (the default), we expect the last chunk of the stream
+                # immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
+                # and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
+                # the last chunk)
+                assert 'usage' in chunk, f"Expected 'usage' in chunk: {chunk}"
+                assert 'timings' in chunk, f"Expected 'timings' in chunk: {chunk}"
         print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
         result = dict(
             choices=[
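As a rough sketch of what a consumer of the raw SSE stream has to handle (this is not the actual make_stream_request implementation; the endpoint, payload, and defaults are assumptions):

    # Consume the raw SSE stream with `requests`, stopping at the [DONE] sentinel
    # and picking up usage from the final empty-choices chunk.
    import json
    import requests

    resp = requests.post(
        "http://localhost:8080/v1/chat/completions",   # placeholder endpoint
        json={
            "messages": [{"role": "user", "content": "Say hello"}],
            "stream": True,
            "stream_options": {"include_usage": True},
        },
        stream=True,
    )

    usage = None
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break  # end-of-stream sentinel
        chunk = json.loads(payload)
        if chunk["choices"]:
            pass  # regular delta chunk
        else:
            usage = chunk["usage"]  # final chunk: empty choices plus usage (and timings)

    print(usage)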