Update completions.py

Author: Yuhao Tsui (committed by GitHub)
Date:   2025-03-06 11:16:33 +08:00
Parent: 1bcfce8cad
Commit: d050d8655f


@@ -47,7 +47,10 @@ class OllamaGenerationStreamResponse(BaseModel):
     done: bool = Field(...)

 class OllamaGenerationResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    response: str
+    done: bool

 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
@@ -75,8 +78,17 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError
+        complete_response = ""
+        async for token in interface.inference(input.prompt, id):
+            complete_response += token
+        response = OllamaGenerationResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            response=complete_response,
+            done=True
+        )
+        return response

 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
 class OllamaChatCompletionMessage(BaseModel):
     role: str
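For reference, a minimal client-side sketch of how the new non-streaming generate branch could be exercised. This is an assumption-laden illustration: the host, port, and route prefix below are placeholders (the actual mount point depends on the deployment), and the request body mirrors the Ollama generate API (model, prompt, stream), which OllamaGenerateCompletionRequest is assumed to accept.

    import requests  # any HTTP client works; requests is used here for brevity

    resp = requests.post(
        "http://localhost:8080/api/generate",   # assumed host, port, and prefix
        json={
            "model": "my-model",                 # placeholder model name
            "prompt": "Why is the sky blue?",
            "stream": False,                     # exercises the new non-streaming path
        },
    )
    body = resp.json()
    # "response" and "done" are the fields added to OllamaGenerationResponse above
    print(body["response"], body["done"])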
@@ -100,10 +112,17 @@ class OllamaChatCompletionStreamResponse(BaseModel):
     eval_count: Optional[int] = Field(None, description="Number of tokens generated")
     eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")

 class OllamaChatCompletionResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")

 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
@@ -154,8 +173,35 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError("Non-streaming chat is not implemented.")
+        start_time = time()
+        complete_response = ""
+        eval_count = 0
+        async for token in interface.inference(prompt, id):
+            complete_response += token
+            eval_count += 1
+        end_time = time()
+
+        total_duration = int((end_time - start_time) * 1_000_000_000)
+        prompt_eval_count = len(prompt.split())
+        eval_duration = total_duration
+        prompt_eval_duration = 0
+        load_duration = 0
+
+        response = OllamaChatCompletionResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            message={"role": "assistant", "content": complete_response},
+            done=True,
+            total_duration=total_duration,
+            load_duration=load_duration,
+            prompt_eval_count=prompt_eval_count,
+            prompt_eval_duration=prompt_eval_duration,
+            eval_count=eval_count,
+            eval_duration=eval_duration
+        )
+
+        return response

 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
     name: str
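Similarly, a hedged sketch of a client call against the new non-streaming chat branch, under the same assumptions about the mount point and the request schema (messages and stream fields, mirroring the Ollama chat API); names and values are illustrative only.

    import requests

    resp = requests.post(
        "http://localhost:8080/api/chat",   # assumed host, port, and prefix
        json={
            "model": "my-model",            # placeholder model name
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": False,                # exercises the new non-streaming path
        },
    )
    body = resp.json()
    # message content plus the timing/count fields added to OllamaChatCompletionResponse
    print(body["message"]["content"], body["eval_count"], body["total_duration"])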
@@ -214,4 +260,4 @@ async def show(request: Request, input: OllamaShowRequest):
             quantization_level=" "
         ),
         model_info=OllamaModelInfo()
-    )
+    )