Mirror of https://github.com/kvcache-ai/ktransformers.git
Synced 2025-09-14 09:09:42 +00:00
Update completions.py
This commit is contained in:
parent 1bcfce8cad
commit d050d8655f
1 changed file with 54 additions and 8 deletions
completions.py

@@ -47,7 +47,10 @@ class OllamaGenerationStreamResponse(BaseModel):
     done: bool = Field(...)
 
 class OllamaGenerationResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    response: str
+    done: bool
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
@@ -75,8 +78,17 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError
+        complete_response = ""
+        async for token in interface.inference(input.prompt, id):
+            complete_response += token
+        response = OllamaGenerationResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            response=complete_response,
+            done=True
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
 class OllamaChatCompletionMessage(BaseModel):
     role: str
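For context, the hunk above gives /generate a real non-streaming path: the whole completion is accumulated and returned as a single OllamaGenerationResponse instead of raising NotImplementedError. A minimal client sketch follows; the localhost address, port 10002, and the /api prefix are assumptions about a typical local ktransformers deployment, not something this diff specifies.

import json
import urllib.request

# Hypothetical non-streaming call against the branch added above.
# Host, port, and the /api prefix are assumed; adjust to your server config.
payload = json.dumps({
    "model": "anything",            # the reply's model field is filled from config.model_name
    "prompt": "Why is the sky blue?",
    "stream": False,                # selects the new non-streaming branch
}).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:10002/api/generate",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())

# The new OllamaGenerationResponse carries model, created_at, response, done.
print(body["done"], body["response"])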
@@ -100,10 +112,17 @@ class OllamaChatCompletionStreamResponse(BaseModel):
     eval_count: Optional[int] = Field(None, description="Number of tokens generated")
     eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
-
-
 class OllamaChatCompletionResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: dict
+    done: bool
+    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
+    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
+    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
+    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
+    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
+    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
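The fleshed-out OllamaChatCompletionResponse mirrors the metadata fields of Ollama's documented chat response, with all durations expressed in nanoseconds. Below is a standalone sketch of what such an object serializes to; the class is re-declared locally (field descriptions dropped) purely so the snippet runs on its own, and the values are made up.

from typing import Optional
from pydantic import BaseModel

# Local stand-in mirroring the OllamaChatCompletionResponse fields added above.
class ChatReply(BaseModel):
    model: str
    created_at: str
    message: dict
    done: bool
    total_duration: Optional[int] = None        # nanoseconds
    load_duration: Optional[int] = None         # nanoseconds
    prompt_eval_count: Optional[int] = None
    prompt_eval_duration: Optional[int] = None  # nanoseconds
    eval_count: Optional[int] = None
    eval_duration: Optional[int] = None         # nanoseconds

example = ChatReply(
    model="some-model",                         # the server fills in config.model_name here
    created_at="2025-09-14 09:09:42",
    message={"role": "assistant", "content": "Hello!"},
    done=True,
    total_duration=1_500_000_000,               # 1.5 s
    eval_count=42,
    eval_duration=1_500_000_000,
)
print(example.model_dump_json())                # the JSON a non-streaming /chat reply would carry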
@@ -154,8 +173,35 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
             yield d.model_dump_json() + '\n'
         return check_link_response(request, inner())
     else:
-        raise NotImplementedError("Non-streaming chat is not implemented.")
+        start_time = time()
+        complete_response = ""
+        eval_count = 0
+
+        async for token in interface.inference(prompt, id):
+            complete_response += token
+            eval_count += 1
+
+        end_time = time()
+        total_duration = int((end_time - start_time) * 1_000_000_000)
+        prompt_eval_count = len(prompt.split())
+        eval_duration = total_duration
+        prompt_eval_duration = 0
+        load_duration = 0
+
+        response = OllamaChatCompletionResponse(
+            model=config.model_name,
+            created_at=str(datetime.now()),
+            message={"role": "assistant", "content": complete_response},
+            done=True,
+            total_duration=total_duration,
+            load_duration=load_duration,
+            prompt_eval_count=prompt_eval_count,
+            prompt_eval_duration=prompt_eval_duration,
+            eval_count=eval_count,
+            eval_duration=eval_duration
+        )
+        return response
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
     name: str
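Similarly, /chat now has a non-streaming branch that times the whole generation and reports the Ollama-style counters; note that prompt_eval_count is approximated with a whitespace split and eval_duration simply reuses total_duration. A client sketch, again assuming a local server under /api on port 10002, shows how those fields can be turned into a rough tokens-per-second figure:

import json
import urllib.request

# Hypothetical non-streaming chat request; the endpoint location is an assumption.
payload = json.dumps({
    "model": "anything",
    "messages": [{"role": "user", "content": "Say hi in one short sentence."}],
    "stream": False,
}).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:10002/api/chat",
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    reply = json.loads(resp.read())

print(reply["message"]["content"])

# Durations are in nanoseconds, so tokens/s = eval_count / eval_duration * 1e9.
if reply.get("eval_count") and reply.get("eval_duration"):
    print("approx tokens/s:", reply["eval_count"] / reply["eval_duration"] * 1e9)

Because eval_duration here is the wall-clock total rather than pure decode time, the resulting figure is only indicative.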
@@ -214,4 +260,4 @@ async def show(request: Request, input: OllamaShowRequest):
             quantization_level=" "
         ),
         model_info=OllamaModelInfo()
     )