Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-14 09:09:42 +00:00
Merge branch 'main' into temperature_top_p_from_request
This commit is contained in: commit 26f7b4af11
54 changed files with 1573 additions and 159 deletions
@@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import check_link_response
from ktransformers.server.backend.base import BackendInterfaceBase

router = APIRouter(prefix='/api')


# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
class OllamaGenerateCompletionRequest(BaseModel):
@@ -40,61 +40,121 @@ class OllamaGenerateCompletionRequest(BaseModel):
    keep_alive: Optional[str] = Field(
        "5m", description="Controls how long the model will stay loaded into memory following the request.")


class OllamaGenerationStreamResponse(BaseModel):
    model: str
    created_at: str
    response: str
    done: bool = Field(...)


class OllamaGenerationResponse(BaseModel):
    pass


@router.post("/generate", tags=['ollama'])
async def generate(request: Request, input: OllamaGenerateCompletionRequest):
    id = str(uuid4())

    interface: BackendInterfaceBase = get_interface()
    print(f'COMPLETION INPUT:----\n{input.prompt}\n----')

    config = Config()

    if input.stream:
        async def inner():
            async for token in interface.inference(input.prompt,id):
                d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False)
                yield d.model_dump_json()+'\n'
            # d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
            # yield f"{json.dumps(d)}\n"
            # d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
            # yield f"{json.dumps(d)}\n"
            d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True)
            yield d.model_dump_json()+'\n'
        return check_link_response(request,inner())
            async for token in interface.inference(input.prompt, id):
                d = OllamaGenerationStreamResponse(
                    model=config.model_name,
                    created_at=str(datetime.now()),
                    response=token,
                    done=False
                )
                yield d.model_dump_json() + '\n'
            d = OllamaGenerationStreamResponse(
                model=config.model_name,
                created_at=str(datetime.now()),
                response='',
                done=True
            )
            yield d.model_dump_json() + '\n'
        return check_link_response(request, inner())
    else:
        raise NotImplementedError
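
A note on consuming the streaming route above: each line of the response body is a standalone JSON object (one OllamaGenerationStreamResponse per line), and the final line carries done=True with an empty response. A minimal client sketch, assuming a locally running server; the address and model name below are placeholders, not values from this diff:

import json
import requests

# Placeholder address and model name; point these at the running ktransformers server.
resp = requests.post(
    "http://localhost:10002/api/generate",
    json={"model": "some-model", "prompt": "Hello", "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)  # one OllamaGenerationStreamResponse per line
    print(chunk["response"], end="", flush=True)
    if chunk["done"]:  # the final chunk has done=True and an empty response
        break
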

# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion


class OllamaChatCompletionMessage(BaseModel):
    role: str
    content: str


class OllamaChatCompletionRequest(BaseModel):
    pass

    model: str = Field(..., description="The model name, which is required.")
    messages: List[OllamaChatCompletionMessage] = Field(
        ..., description="A list of messages to generate a response for.")
    stream: bool = Field(True, description="If true, the response will be streamed.")


class OllamaChatCompletionStreamResponse(BaseModel):
    pass

    model: str
    created_at: str
    message: dict
    done: bool = Field(...)
    total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
    load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
    prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
    prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
    eval_count: Optional[int] = Field(None, description="Number of tokens generated")
    eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
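
These duration fields follow Ollama's convention of reporting times in nanoseconds, so a client can derive generation speed from the closing chunk. A small illustrative sketch of that arithmetic (not code from this diff):

def tokens_per_second(eval_count: int, eval_duration_ns: int) -> float:
    # Ollama-style metric: generated tokens divided by generation time in seconds.
    if eval_duration_ns <= 0:
        return 0.0
    return eval_count / (eval_duration_ns / 1_000_000_000)

# Example: 128 tokens generated over 4 seconds (4e9 ns) -> 32.0 tokens/s.
print(tokens_per_second(128, 4_000_000_000))
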

class OllamaChatCompletionResponse(BaseModel):
    pass


@router.post("/chat", tags=['ollama'])
async def chat(request: Request, input: OllamaChatCompletionRequest):
    raise NotImplementedError

    id = str(uuid4())
    interface: BackendInterfaceBase = get_interface()
    config = Config()

    # Convert the messages into a single prompt string
    prompt = ""
    for msg in input.messages:
        prompt += f"{msg.role}: {msg.content}\n"
    prompt += "assistant:"

    if input.stream:
        async def inner():
            start_time = time()  # record the start time (in seconds)
            eval_count = 0  # number of generated tokens
            tokens = []

            async for token in interface.inference(prompt, id):
                d = OllamaChatCompletionStreamResponse(
                    model=config.model_name,
                    created_at=str(datetime.now()),
                    message={"role": "assistant", "content": token},
                    done=False
                )
                yield d.model_dump_json() + '\n'
                eval_count += 1  # count each generated token for the final stats

            # Compute performance metrics
            end_time = time()
            total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
            prompt_eval_count = len(prompt.split())  # rough estimate of the prompt token count
            eval_duration = total_duration  # assume all time was spent generating (simplification)
            prompt_eval_duration = 0  # assume no separate prompt-evaluation time
            load_duration = 0  # load time unknown, report zero

            d = OllamaChatCompletionStreamResponse(
                model=config.model_name,
                created_at=str(datetime.now()),
                message={},
                done=True,
                total_duration=total_duration,
                load_duration=load_duration,
                prompt_eval_count=prompt_eval_count,
                prompt_eval_duration=prompt_eval_duration,
                eval_count=eval_count,
                eval_duration=eval_duration
            )
            yield d.model_dump_json() + '\n'
        return check_link_response(request, inner())
    else:
        raise NotImplementedError("Non-streaming chat is not implemented.")
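
For completeness, a client of the /api/chat route above accumulates message content from each chunk and reads the timing fields from the closing chunk. This is an illustrative sketch only; the URL and model name are placeholders:

import json
import requests

# Placeholder URL and model name; the payload mirrors OllamaChatCompletionRequest.
resp = requests.post(
    "http://localhost:10002/api/chat",
    json={
        "model": "some-model",
        "messages": [{"role": "user", "content": "Hi there"}],
        "stream": True,
    },
    stream=True,
)
answer = ""
for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    if chunk["done"]:
        # The closing chunk carries the timing fields instead of content.
        print(f"\n{chunk['eval_count']} tokens in {chunk['total_duration']} ns")
        break
    answer += chunk["message"]["content"]
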

# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
class OllamaModel(BaseModel):

@@ -103,9 +163,8 @@ class OllamaModel(BaseModel):
    size: int
    # TODO: fill the rest correctly


# mock ollama
@router.get("/tags",tags=['ollama'])
@router.get("/tags", tags=['ollama'])
async def tags():
    config = Config()
    # TODO: fill this correctly, although it does not affect Tabby

@@ -138,25 +197,21 @@ class OllamaShowResponse(BaseModel):
    class Config:
        protected_namespaces = ()


@router.post("/show", tags=['ollama'])
async def show(request: Request, input: OllamaShowRequest):
    config = Config()
    # TODO: Add more info in config to return, although it does not affect Tabby
    return OllamaShowResponse(
        modelfile = "# Modelfile generated by ...",
        parameters = " ",
        template = " ",
        details = OllamaShowDetial(
            parent_model = " ",
            format = "gguf",
            family = " ",
            families = [
                " "
            ],
            parameter_size = " ",
            quantization_level = " "
        modelfile="# Modelfile generated by ...",
        parameters=" ",
        template=" ",
        details=OllamaShowDetial(
            parent_model=" ",
            format="gguf",
            family=" ",
            families=[" "],
            parameter_size=" ",
            quantization_level=" "
        ),
        model_info = OllamaModelInfo()
        model_info=OllamaModelInfo()
    )

@@ -25,6 +25,9 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):

    input_message = [json.loads(m.model_dump_json()) for m in create.messages]

    if Config().api_key != '':
        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key

    if create.stream:
        async def inner():
            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
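
A consequence of the api_key check above: the key travels in the Authorization header, and only the last whitespace-separated token is compared, so a conventional "Bearer <key>" value works. A hypothetical client call; the route, model name, and key below are placeholders:

import requests

# Placeholder route, model, and key; the header format matches what the assert above expects.
resp = requests.post(
    "http://localhost:10002/v1/chat/completions",
    headers={"Authorization": "Bearer my-secret-key"},
    json={"model": "some-model", "messages": [{"role": "user", "content": "ping"}], "stream": False},
)
print(resp.status_code)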