From 1bc664a26a1d93a48baf2483a7ff95291ca8bcb8 Mon Sep 17 00:00:00 2001
From: teo
Date: Thu, 21 Aug 2025 07:10:08 +0900
Subject: [PATCH] server: fix OpenAI API compatibility for usage statistics in
 chat streams (#15444)
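
With this change, a chat completion stream that reports usage statistics ends
with one extra `chat.completion.chunk` whose `choices` array is empty and
whose `usage` object carries the token counts, as the OpenAI streaming spec
requires; per the updated test expectations below, llama-server also attaches
its `timings` object to that final chunk. A sketch of the expected stream tail
(ids, timestamps, fingerprints, and token counts are illustrative
placeholders; the usage object is abridged):

    data: {"choices":[{"finish_reason":"stop","index":0,"delta":{}}],"created":1755727808,"id":"chatcmpl-xyz123","model":"gpt-3.5-turbo","system_fingerprint":"b1234-0000000","object":"chat.completion.chunk"}

    data: {"choices":[],"created":1755727808,"id":"chatcmpl-xyz123","model":"gpt-3.5-turbo","system_fingerprint":"b1234-0000000","object":"chat.completion.chunk","usage":{"completion_tokens":12,"prompt_tokens":8}}

    data: [DONE]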
---
 tools/server/server.cpp                        | 11 +++
 .../server/tests/unit/test_chat_completion.py  | 89 ++++++++++---------
 tools/server/tests/utils.py                    | 87 +++++++++---------
 3 files changed, 105 insertions(+), 82 deletions(-)
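
Note for clients: code that unconditionally indexes `chunk.choices[0]` will
now encounter the final empty-`choices` usage chunk. A minimal consumption
sketch using the `openai` Python client (base URL, API key, and model name are
placeholders for a local llama-server):

    import openai

    client = openai.OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key")
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices:
            # Content-bearing chunk: llama-server emits exactly one choice per event.
            print(chunk.choices[0].delta.content or "", end="")
        elif chunk.usage is not None:
            # Final chunk: empty `choices`, usage statistics attached.
            print(f"\nprompt={chunk.usage.prompt_tokens}, completion={chunk.usage.completion_tokens}")
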
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index ab88f3d26..35b060674 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -911,6 +911,17 @@ struct server_task_result_cmpl_final : server_task_result {
             {"model",              oaicompat_model},
             {"system_fingerprint", build_info},
             {"object",             "chat.completion.chunk"},
+        });
+
+        // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
+        // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
+        deltas.push_back({
+            {"choices",            json::array()},
+            {"created",            t},
+            {"id",                 oaicompat_cmpl_id},
+            {"model",              oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object",             "chat.completion.chunk"},
             {"usage", json {
                 {"completion_tokens", n_decoded},
                 {"prompt_tokens",     n_prompt_tokens},
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 6c6f64f5e..509c024b7 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -72,27 +72,29 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
     content = ""
     last_cmpl_id = None
     for i, data in enumerate(res):
-        choice = data["choices"][0]
-        if i == 0:
-            # Check first role message for stream=True
-            assert choice["delta"]["content"] is None
-            assert choice["delta"]["role"] == "assistant"
+        if data["choices"]:
+            choice = data["choices"][0]
+            if i == 0:
+                # Check first role message for stream=True
+                assert choice["delta"]["content"] is None
+                assert choice["delta"]["role"] == "assistant"
+            else:
+                assert "role" not in choice["delta"]
+            assert data["system_fingerprint"].startswith("b")
+            assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+            if last_cmpl_id is None:
+                last_cmpl_id = data["id"]
+            assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
+            if choice["finish_reason"] in ["stop", "length"]:
+                assert "content" not in choice["delta"]
+                assert match_regex(re_content, content)
+                assert choice["finish_reason"] == finish_reason
+            else:
+                assert choice["finish_reason"] is None
+                content += choice["delta"]["content"] or ''
         else:
-            assert "role" not in choice["delta"]
-        assert data["system_fingerprint"].startswith("b")
-        assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
-        if last_cmpl_id is None:
-            last_cmpl_id = data["id"]
-        assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
-        if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
-            assert "content" not in choice["delta"]
-            assert match_regex(re_content, content)
-            assert choice["finish_reason"] == finish_reason
-        else:
-            assert choice["finish_reason"] is None
-            content += choice["delta"]["content"] or ''


 def test_chat_completion_with_openai_library():
@@ -278,12 +280,14 @@ def test_chat_completion_with_timings_per_token():
             assert data["choices"][0]["delta"]["role"] == "assistant"
             assert "timings" not in data, f'First event should not have timings: {data}'
         else:
-            assert "role" not in data["choices"][0]["delta"]
-            assert "timings" in data
-            assert "prompt_per_second" in data["timings"]
-            assert "predicted_per_second" in data["timings"]
-            assert "predicted_n" in data["timings"]
-            assert data["timings"]["predicted_n"] <= 10
+            if data["choices"]:
+                assert "role" not in data["choices"][0]["delta"]
+            else:
+                assert "timings" in data
+                assert "prompt_per_second" in data["timings"]
+                assert "predicted_per_second" in data["timings"]
+                assert "predicted_n" in data["timings"]
+                assert data["timings"]["predicted_n"] <= 10


 def test_logprobs():
@@ -332,24 +336,25 @@ def test_logprobs_stream():
     output_text = ''
     aggregated_text = ''
     for i, data in enumerate(res):
-        choice = data.choices[0]
-        if i == 0:
-            # Check first role message for stream=True
-            assert choice.delta.content is None
-            assert choice.delta.role == "assistant"
-        else:
-            assert choice.delta.role is None
-            if choice.finish_reason is None:
-                if choice.delta.content:
-                    output_text += choice.delta.content
-                assert choice.logprobs is not None
-                assert choice.logprobs.content is not None
-                for token in choice.logprobs.content:
-                    aggregated_text += token.token
-                    assert token.logprob <= 0.0
-                    assert token.bytes is not None
-                    assert token.top_logprobs is not None
-                    assert len(token.top_logprobs) > 0
+        if data.choices:
+            choice = data.choices[0]
+            if i == 0:
+                # Check first role message for stream=True
+                assert choice.delta.content is None
+                assert choice.delta.role == "assistant"
+            else:
+                assert choice.delta.role is None
+                if choice.finish_reason is None:
+                    if choice.delta.content:
+                        output_text += choice.delta.content
+                    assert choice.logprobs is not None
+                    assert choice.logprobs.content is not None
+                    for token in choice.logprobs.content:
+                        aggregated_text += token.token
+                        assert token.logprob <= 0.0
+                        assert token.bytes is not None
+                        assert token.top_logprobs is not None
+                        assert len(token.top_logprobs) > 0

     assert aggregated_text == output_text
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index 49277e600..5f42bcae6 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -318,46 +318,53 @@ class ServerProcess:
         arguments_parts = 0

         for chunk in self.make_stream_request(method, path, data, headers):
-            assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
-            choice = chunk['choices'][0]
-            if choice['delta'].get('content') is not None:
-                assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
-                content.append(choice['delta']['content'])
-                content_parts += 1
-            if choice['delta'].get('reasoning_content') is not None:
-                assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
-                reasoning_content.append(choice['delta']['reasoning_content'])
-                reasoning_content_parts += 1
-            if choice['delta'].get('finish_reason') is not None:
-                finish_reason = choice['delta']['finish_reason']
-            for tc in choice['delta'].get('tool_calls', []):
-                if 'function' not in tc:
-                    raise ValueError(f"Expected function type, got {tc['type']}")
-                if tc['index'] >= len(tool_calls):
-                    assert 'id' in tc
-                    assert tc.get('type') == 'function'
-                    assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
-                        f"Expected function call with name, got {tc.get('function')}"
-                    tool_calls.append(dict(
-                        id="",
-                        type="function",
-                        function=dict(
-                            name="",
-                            arguments="",
-                        )
-                    ))
-                tool_call = tool_calls[tc['index']]
-                if tc.get('id') is not None:
-                    tool_call['id'] = tc['id']
-                fct = tc['function']
-                assert 'id' not in fct, f"Function call should not have id: {fct}"
-                if fct.get('name') is not None:
-                    tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
-                if fct.get('arguments') is not None:
-                    tool_call['function']['arguments'] += fct['arguments']
-                    arguments_parts += 1
-                tool_call_parts += 1
-
+            if chunk['choices']:
+                assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
+                choice = chunk['choices'][0]
+                if choice['delta'].get('content') is not None:
+                    assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
+                    content.append(choice['delta']['content'])
+                    content_parts += 1
+                if choice['delta'].get('reasoning_content') is not None:
+                    assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
+                    reasoning_content.append(choice['delta']['reasoning_content'])
+                    reasoning_content_parts += 1
+                if choice['delta'].get('finish_reason') is not None:
+                    finish_reason = choice['delta']['finish_reason']
+                for tc in choice['delta'].get('tool_calls', []):
+                    if 'function' not in tc:
+                        raise ValueError(f"Expected function type, got {tc['type']}")
+                    if tc['index'] >= len(tool_calls):
+                        assert 'id' in tc
+                        assert tc.get('type') == 'function'
+                        assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
+                            f"Expected function call with name, got {tc.get('function')}"
+                        tool_calls.append(dict(
+                            id="",
+                            type="function",
+                            function=dict(
+                                name="",
+                                arguments="",
+                            )
+                        ))
+                    tool_call = tool_calls[tc['index']]
+                    if tc.get('id') is not None:
+                        tool_call['id'] = tc['id']
+                    fct = tc['function']
+                    assert 'id' not in fct, f"Function call should not have id: {fct}"
+                    if fct.get('name') is not None:
+                        tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
+                    if fct.get('arguments') is not None:
+                        tool_call['function']['arguments'] += fct['arguments']
+                        arguments_parts += 1
+                    tool_call_parts += 1
+            else:
+                # When `include_usage` is True (the default), we expect the last chunk of the stream
+                # immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
+                # and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
+                # the last chunk)
+                assert 'usage' in chunk, f"Expected usage in last chunk: {chunk}"
+                assert 'timings' in chunk, f"Expected timings in last chunk: {chunk}"
         print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
         result = dict(
             choices=[