Merge pull request #978 from cyhasuka/main

Feat: Support Non-streaming chat in Ollama backend
wang jiahao 2025-04-17 14:34:35 +08:00 committed by GitHub
commit 6e4da83d4b


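In practice, this change lets Ollama-compatible clients call the chat endpoint with "stream": false and receive a single JSON reply instead of NDJSON chunks. A rough usage sketch follows; the base URL and model name are placeholders assumed for illustration, not values taken from this repository:

# Hypothetical non-streaming chat request; host/port and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:11434/api/chat",   # assumed base URL of the server
    json={
        "model": "my-model",             # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,                 # exercise the new non-streaming branch
    },
)
body = resp.json()                       # a single JSON object with done=True
print(body["message"]["content"])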
@@ -49,7 +49,10 @@ class OllamaGenerationStreamResponse(BaseModel):
     done: bool = Field(...)
 
 class OllamaGenerationResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    response: str
+    done: bool
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
@@ -81,8 +84,21 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
                 yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError
+        complete_response = ""
+        async for res in interface.inference(input.prompt, id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                complete_response += token
+        response = OllamaGenerationResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            response=complete_response,
+            done=True
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
 class OllamaChatCompletionMessage(BaseModel):
     role: str
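For comparison with the streaming branch kept above, a client can now consume /api/generate either way. A minimal sketch, assuming the server is reachable at the URL below and that "my-model" stands in for the configured model name:

import json
import requests

URL = "http://localhost:11434/api/generate"   # assumed base URL
payload = {"model": "my-model", "prompt": "Why is the sky blue?"}  # placeholder model

# stream=True: one JSON object per line (NDJSON), emitted as tokens arrive
with requests.post(URL, json={**payload, "stream": True}, stream=True) as r:
    for line in r.iter_lines():
        if line:
            chunk = json.loads(line)
            print(chunk["response"], end="", flush=True)

# stream=False: the new path returns a single JSON object with the full text
data = requests.post(URL, json={**payload, "stream": False}).json()
print(data["response"])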
@@ -106,10 +122,17 @@ class OllamaChatCompletionStreamResponse(BaseModel):
     eval_count: Optional[int] = Field(None, description="Number of tokens generated")
     eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 class OllamaChatCompletionResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
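To illustrate the schema the hunk above fills in, here is a trimmed, self-contained copy of the response model with made-up values; it is not the repository's code, only a sketch of how the pydantic model serializes:

# Simplified copy of OllamaChatCompletionResponse, for illustration only.
from typing import Optional
from pydantic import BaseModel, Field

class OllamaChatCompletionResponse(BaseModel):
    model: str
    created_at: str
    message: dict
    done: bool
    total_duration: Optional[int] = Field(None)
    eval_count: Optional[int] = Field(None)
    eval_duration: Optional[int] = Field(None)

example = OllamaChatCompletionResponse(
    model="my-model",                                # placeholder values
    created_at="2025-04-17 14:34:35",
    message={"role": "assistant", "content": "Hi!"},
    done=True,
    eval_count=3,
    eval_duration=120_000_000,                       # nanoseconds
)
print(example.model_dump_json())   # unset optional fields serialize as null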
@@ -164,8 +187,39 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
                 yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError("Non-streaming chat is not implemented.")
+        start_time = time()
+        complete_response = ""
+        eval_count = 0
+        async for res in interface.inference(prompt, id):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                complete_response += token
+                eval_count += 1
+        end_time = time()
+        total_duration = int((end_time - start_time) * 1_000_000_000)
+        prompt_eval_count = len(prompt.split())
+        eval_duration = total_duration
+        prompt_eval_duration = 0
+        load_duration = 0
+        response = OllamaChatCompletionResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            message={"role": "assistant", "content": complete_response},
+            done=True,
+            total_duration=total_duration,
+            load_duration=load_duration,
+            prompt_eval_count=prompt_eval_count,
+            prompt_eval_duration=prompt_eval_duration,
+            eval_count=eval_count,
+            eval_duration=eval_duration
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
     name: str
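The metrics in the non-streaming chat response above can be used to derive throughput on the client side. A small sketch follows; note that in this implementation prompt_eval_count is approximated by whitespace-splitting the prompt and eval_duration is set equal to total_duration, so the figure is a rough estimate:

# Derive tokens/sec from the Ollama-style metrics returned by /api/chat.
# 'body' is assumed to be the parsed JSON of a non-streaming chat response.
def tokens_per_second(body: dict) -> float:
    eval_count = body.get("eval_count") or 0
    eval_duration_ns = body.get("eval_duration") or 1   # avoid division by zero
    return eval_count / (eval_duration_ns / 1_000_000_000)

# Example with made-up numbers: 42 tokens generated in 1.5 s -> 28.0 tok/s
print(tokens_per_second({"eval_count": 42, "eval_duration": 1_500_000_000}))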
@@ -224,4 +278,4 @@ async def show(request: Request, input: OllamaShowRequest):
             quantization_level=" "
         ),
         model_info=OllamaModelInfo()
-    )
+    )