Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)
implementation of chat routing for Ollama
This commit is contained in:
parent 9660b2cc1e
commit 68e7df3a25

1 changed file with 69 additions and 40 deletions
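As a rough sketch of how a client could consume the new streaming chat route once this change is applied (the endpoint path and the request/response fields are taken from the diff below; the host, port, model name, and the use of the requests library are assumptions made for illustration):

import json

import requests  # any HTTP client that supports streaming responses works

# Placeholder host/port -- point this at the running ktransformers server.
url = "http://localhost:8000/api/chat"
payload = {
    "model": "placeholder-model-name",  # required field on OllamaChatCompletionRequest
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,  # only the streaming path is implemented
}

with requests.post(url, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)  # one JSON object per line
        if chunk["done"]:         # final object: done=True, empty message
            break
        print(chunk["message"], end="", flush=True)
print()

Each streamed line is a standalone JSON object: the token travels in a plain "message" string, and the final object carries done=True with an empty message, matching what OllamaChatCompletionStreamResponse below serializes.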
@@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
-router = APIRouter(prefix='/api')
 
 
+router = APIRouter(prefix='/api')
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
 class OllamaGenerateCompletionRequest(BaseModel):
@@ -40,61 +40,95 @@ class OllamaGenerateCompletionRequest(BaseModel):
     keep_alive: Optional[str] = Field(
         "5m", description="Controls how long the model will stay loaded into memory following the request.")
 
 
 class OllamaGenerationStreamResponse(BaseModel):
     model: str
     created_at: str
     response: str
     done: bool = Field(...)
 
 
 class OllamaGenerationResponse(BaseModel):
     pass
 
 
 @router.post("/generate", tags=['ollama'])
 async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     id = str(uuid4())
 
     interface: BackendInterfaceBase = get_interface()
     print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
 
     config = Config()
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt,id):
-                d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False)
-                yield d.model_dump_json()+'\n'
-                # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
-                # yield f"{json.dumps(d)}\n"
-            # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
-            # yield f"{json.dumps(d)}\n"
-            d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True)
-            yield d.model_dump_json()+'\n'
-        return check_link_response(request,inner())
+            async for token in interface.inference(input.prompt, id):
+                d = OllamaGenerationStreamResponse(
+                    model=config.model_name,
+                    created_at=str(datetime.now()),
+                    response=token,
+                    done=False
+                )
+                yield d.model_dump_json() + '\n'
+            d = OllamaGenerationStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                response='',
+                done=True
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
     else:
         raise NotImplementedError
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
+class OllamaChatCompletionMessage(BaseModel):
+    role: str
+    content: str
+
 class OllamaChatCompletionRequest(BaseModel):
-    pass
+    model: str = Field(..., description="The model name, which is required.")
+    messages: List[OllamaChatCompletionMessage] = Field(
+        ..., description="A list of messages to generate a response for.")
+    stream: bool = Field(True, description="If true, the response will be streamed.")
 
 class OllamaChatCompletionStreamResponse(BaseModel):
-    pass
+    model: str
+    created_at: str
+    message: str
+    done: bool = Field(...)
 
 class OllamaChatCompletionResponse(BaseModel):
     pass
 
 
 @router.post("/chat", tags=['ollama'])
 async def chat(request: Request, input: OllamaChatCompletionRequest):
-    raise NotImplementedError
+    id = str(uuid4())
+    interface: BackendInterfaceBase = get_interface()
+    config = Config()
+
+    # Convert the messages into a single prompt string
+    prompt = ""
+    for msg in input.messages:
+        prompt += f"{msg.role}: {msg.content}\n"
+    prompt += "assistant:"
+
+    if input.stream:
+        async def inner():
+            async for token in interface.inference(prompt, id):
+                d = OllamaChatCompletionStreamResponse(
+                    model=config.model_name,
+                    created_at=str(datetime.now()),
+                    message=token,
+                    done=False
+                )
+                yield d.model_dump_json() + '\n'
+            d = OllamaChatCompletionStreamResponse(
+                model=config.model_name,
+                created_at=str(datetime.now()),
+                message='',
+                done=True
+            )
+            yield d.model_dump_json() + '\n'
+        return check_link_response(request, inner())
+    else:
+        raise NotImplementedError("Non-streaming chat is not implemented.")
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
 class OllamaModel(BaseModel):
@@ -103,9 +137,8 @@ class OllamaModel(BaseModel):
     size: int
     # TODO: fill the rest correctly
 
-
 # mock ollama
-@router.get("/tags",tags=['ollama'])
+@router.get("/tags", tags=['ollama'])
 async def tags():
     config = Config()
     # TODO: fill this correctly, although it does not effect Tabby
@@ -138,25 +171,21 @@ class OllamaShowResponse(BaseModel):
     class Config:
         protected_namespaces = ()
 
-
-
 @router.post("/show", tags=['ollama'])
 async def show(request: Request, input: OllamaShowRequest):
     config = Config()
     # TODO: Add more info in config to return, although it does not effect Tabby
     return OllamaShowResponse(
-        modelfile = "# Modelfile generated by ...",
-        parameters = " ",
-        template = " ",
-        details = OllamaShowDetial(
-            parent_model = " ",
-            format = "gguf",
-            family = " ",
-            families = [
-                " "
-            ],
-            parameter_size = " ",
-            quantization_level = " "
+        modelfile="# Modelfile generated by ...",
+        parameters=" ",
+        template=" ",
+        details=OllamaShowDetial(
+            parent_model=" ",
+            format="gguf",
+            family=" ",
+            families=[" "],
+            parameter_size=" ",
+            quantization_level=" "
         ),
-        model_info = OllamaModelInfo()
+        model_info=OllamaModelInfo()
     )