implementation of chat routing for Ollama

This commit is contained in:
swu-hyk 2025-02-26 17:05:00 +08:00
parent 9660b2cc1e
commit 68e7df3a25

View file

@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import check_link_response
from ktransformers.server.backend.base import BackendInterfaceBase
router = APIRouter(prefix='/api')
router = APIRouter(prefix='/api')
# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
class OllamaGenerateCompletionRequest(BaseModel):
@ -40,61 +40,95 @@ class OllamaGenerateCompletionRequest(BaseModel):
keep_alive: Optional[str] = Field(
"5m", description="Controls how long the model will stay loaded into memory following the request.")
class OllamaGenerationStreamResponse(BaseModel):
model: str
created_at: str
response: str
done: bool = Field(...)
class OllamaGenerationResponse(BaseModel):
pass
@router.post("/generate", tags=['ollama'])
async def generate(request: Request, input: OllamaGenerateCompletionRequest):
id = str(uuid4())
interface: BackendInterfaceBase = get_interface()
print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
config = Config()
if input.stream:
async def inner():
async for token in interface.inference(input.prompt,id):
d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False)
yield d.model_dump_json()+'\n'
# d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
# yield f"{json.dumps(d)}\n"
# d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
# yield f"{json.dumps(d)}\n"
d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True)
yield d.model_dump_json()+'\n'
return check_link_response(request,inner())
async for token in interface.inference(input.prompt, id):
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response=token,
done=False
)
yield d.model_dump_json() + '\n'
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response='',
done=True
)
yield d.model_dump_json() + '\n'
return check_link_response(request, inner())
else:
raise NotImplementedError
# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
class OllamaChatCompletionMessage(BaseModel):
role: str
content: str
class OllamaChatCompletionRequest(BaseModel):
pass
model: str = Field(..., description="The model name, which is required.")
messages: List[OllamaChatCompletionMessage] = Field(
..., description="A list of messages to generate a response for.")
stream: bool = Field(True, description="If true, the response will be streamed.")
class OllamaChatCompletionStreamResponse(BaseModel):
pass
model: str
created_at: str
message: str
done: bool = Field(...)
class OllamaChatCompletionResponse(BaseModel):
pass
@router.post("/chat", tags=['ollama'])
async def chat(request: Request, input: OllamaChatCompletionRequest):
raise NotImplementedError
id = str(uuid4())
interface: BackendInterfaceBase = get_interface()
config = Config()
# 将消息转换为提示字符串
prompt = ""
for msg in input.messages:
prompt += f"{msg.role}: {msg.content}\n"
prompt += "assistant:"
if input.stream:
async def inner():
async for token in interface.inference(prompt, id):
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message=token,
done=False
)
yield d.model_dump_json() + '\n'
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message='',
done=True
)
yield d.model_dump_json() + '\n'
return check_link_response(request, inner())
else:
raise NotImplementedError("Non-streaming chat is not implemented.")
# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
class OllamaModel(BaseModel):
@ -103,9 +137,8 @@ class OllamaModel(BaseModel):
size: int
# TODO: fill the rest correctly
# mock ollama
@router.get("/tags",tags=['ollama'])
@router.get("/tags", tags=['ollama'])
async def tags():
config = Config()
# TODO: fill this correctly, although it does not effect Tabby
@ -138,25 +171,21 @@ class OllamaShowResponse(BaseModel):
class Config:
protected_namespaces = ()
@router.post("/show", tags=['ollama'])
async def show(request: Request, input: OllamaShowRequest):
config = Config()
# TODO: Add more info in config to return, although it does not effect Tabby
return OllamaShowResponse(
modelfile = "# Modelfile generated by ...",
parameters = " ",
template = " ",
details = OllamaShowDetial(
parent_model = " ",
format = "gguf",
family = " ",
families = [
" "
],
parameter_size = " ",
quantization_level = " "
modelfile="# Modelfile generated by ...",
parameters=" ",
template=" ",
details=OllamaShowDetial(
parent_model=" ",
format="gguf",
family=" ",
families=[" "],
parameter_size=" ",
quantization_level=" "
),
model_info = OllamaModelInfo()
model_info=OllamaModelInfo()
)