server: fix OpenAI API compatibility for usage statistics in chat streams (#15444)

This commit is contained in:
teo 2025-08-21 07:10:08 +09:00 committed by GitHub
parent 13aeb7aef2
commit 1bc664a26a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 105 additions and 82 deletions

View file

@ -911,6 +911,17 @@ struct server_task_result_cmpl_final : server_task_result {
{"model", oaicompat_model}, {"model", oaicompat_model},
{"system_fingerprint", build_info}, {"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}, {"object", "chat.completion.chunk"},
});
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
deltas.push_back({
{"choices", json::array()},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
{"usage", json { {"usage", json {
{"completion_tokens", n_decoded}, {"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens}, {"prompt_tokens", n_prompt_tokens},

View file

@ -72,27 +72,29 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
content = "" content = ""
last_cmpl_id = None last_cmpl_id = None
for i, data in enumerate(res): for i, data in enumerate(res):
choice = data["choices"][0] if data["choices"]:
if i == 0: choice = data["choices"][0]
# Check first role message for stream=True if i == 0:
assert choice["delta"]["content"] is None # Check first role message for stream=True
assert choice["delta"]["role"] == "assistant" assert choice["delta"]["content"] is None
assert choice["delta"]["role"] == "assistant"
else:
assert "role" not in choice["delta"]
assert data["system_fingerprint"].startswith("b")
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
if last_cmpl_id is None:
last_cmpl_id = data["id"]
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
if choice["finish_reason"] in ["stop", "length"]:
assert "content" not in choice["delta"]
assert match_regex(re_content, content)
assert choice["finish_reason"] == finish_reason
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"] or ''
else: else:
assert "role" not in choice["delta"]
assert data["system_fingerprint"].startswith("b")
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
if last_cmpl_id is None:
last_cmpl_id = data["id"]
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
if choice["finish_reason"] in ["stop", "length"]:
assert data["usage"]["prompt_tokens"] == n_prompt assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted assert data["usage"]["completion_tokens"] == n_predicted
assert "content" not in choice["delta"]
assert match_regex(re_content, content)
assert choice["finish_reason"] == finish_reason
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"] or ''
def test_chat_completion_with_openai_library(): def test_chat_completion_with_openai_library():
@ -278,12 +280,14 @@ def test_chat_completion_with_timings_per_token():
assert data["choices"][0]["delta"]["role"] == "assistant" assert data["choices"][0]["delta"]["role"] == "assistant"
assert "timings" not in data, f'First event should not have timings: {data}' assert "timings" not in data, f'First event should not have timings: {data}'
else: else:
assert "role" not in data["choices"][0]["delta"] if data["choices"]:
assert "timings" in data assert "role" not in data["choices"][0]["delta"]
assert "prompt_per_second" in data["timings"] else:
assert "predicted_per_second" in data["timings"] assert "timings" in data
assert "predicted_n" in data["timings"] assert "prompt_per_second" in data["timings"]
assert data["timings"]["predicted_n"] <= 10 assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
def test_logprobs(): def test_logprobs():
@ -332,24 +336,25 @@ def test_logprobs_stream():
output_text = '' output_text = ''
aggregated_text = '' aggregated_text = ''
for i, data in enumerate(res): for i, data in enumerate(res):
choice = data.choices[0] if data.choices:
if i == 0: choice = data.choices[0]
# Check first role message for stream=True if i == 0:
assert choice.delta.content is None # Check first role message for stream=True
assert choice.delta.role == "assistant" assert choice.delta.content is None
else: assert choice.delta.role == "assistant"
assert choice.delta.role is None else:
if choice.finish_reason is None: assert choice.delta.role is None
if choice.delta.content: if choice.finish_reason is None:
output_text += choice.delta.content if choice.delta.content:
assert choice.logprobs is not None output_text += choice.delta.content
assert choice.logprobs.content is not None assert choice.logprobs is not None
for token in choice.logprobs.content: assert choice.logprobs.content is not None
aggregated_text += token.token for token in choice.logprobs.content:
assert token.logprob <= 0.0 aggregated_text += token.token
assert token.bytes is not None assert token.logprob <= 0.0
assert token.top_logprobs is not None assert token.bytes is not None
assert len(token.top_logprobs) > 0 assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text assert aggregated_text == output_text

View file

@ -318,46 +318,53 @@ class ServerProcess:
arguments_parts = 0 arguments_parts = 0
for chunk in self.make_stream_request(method, path, data, headers): for chunk in self.make_stream_request(method, path, data, headers):
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}' if chunk['choices']:
choice = chunk['choices'][0] assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
if choice['delta'].get('content') is not None: choice = chunk['choices'][0]
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!' if choice['delta'].get('content') is not None:
content.append(choice['delta']['content']) assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
content_parts += 1 content.append(choice['delta']['content'])
if choice['delta'].get('reasoning_content') is not None: content_parts += 1
assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!' if choice['delta'].get('reasoning_content') is not None:
reasoning_content.append(choice['delta']['reasoning_content']) assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
reasoning_content_parts += 1 reasoning_content.append(choice['delta']['reasoning_content'])
if choice['delta'].get('finish_reason') is not None: reasoning_content_parts += 1
finish_reason = choice['delta']['finish_reason'] if choice['delta'].get('finish_reason') is not None:
for tc in choice['delta'].get('tool_calls', []): finish_reason = choice['delta']['finish_reason']
if 'function' not in tc: for tc in choice['delta'].get('tool_calls', []):
raise ValueError(f"Expected function type, got {tc['type']}") if 'function' not in tc:
if tc['index'] >= len(tool_calls): raise ValueError(f"Expected function type, got {tc['type']}")
assert 'id' in tc if tc['index'] >= len(tool_calls):
assert tc.get('type') == 'function' assert 'id' in tc
assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \ assert tc.get('type') == 'function'
f"Expected function call with name, got {tc.get('function')}" assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
tool_calls.append(dict( f"Expected function call with name, got {tc.get('function')}"
id="", tool_calls.append(dict(
type="function", id="",
function=dict( type="function",
name="", function=dict(
arguments="", name="",
) arguments="",
)) )
tool_call = tool_calls[tc['index']] ))
if tc.get('id') is not None: tool_call = tool_calls[tc['index']]
tool_call['id'] = tc['id'] if tc.get('id') is not None:
fct = tc['function'] tool_call['id'] = tc['id']
assert 'id' not in fct, f"Function call should not have id: {fct}" fct = tc['function']
if fct.get('name') is not None: assert 'id' not in fct, f"Function call should not have id: {fct}"
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name'] if fct.get('name') is not None:
if fct.get('arguments') is not None: tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
tool_call['function']['arguments'] += fct['arguments'] if fct.get('arguments') is not None:
arguments_parts += 1 tool_call['function']['arguments'] += fct['arguments']
tool_call_parts += 1 arguments_parts += 1
tool_call_parts += 1
else:
# When `include_usage` is True (the default), we expect the last chunk of the stream
# immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
# and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
# the last chunk)
assert 'usage' in chunk, f"Expected finish_reason in chunk: {chunk}"
assert 'timings' in chunk, f"Expected finish_reason in chunk: {chunk}"
print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts') print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
result = dict( result = dict(
choices=[ choices=[