Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
Merge pull request #835 from BITcyman/fix-openai_chat_completion
[fix] support openai chat completion api

Commit 96d75d53df: 8 changed files with 166 additions and 83 deletions.
@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
+from ktransformers.server.backend.base import BackendInterfaceBase
 
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@@ -58,14 +60,18 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
-                d = OllamaGenerationStreamResponse(
-                    model=config.model_name,
-                    created_at=str(datetime.now()),
-                    response=token,
-                    done=False
-                )
-                yield d.model_dump_json() + '\n'
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaGenerationStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        response=token,
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
             d = OllamaGenerationStreamResponse(
                 model=config.model_name,
                 created_at=str(datetime.now()),
@@ -123,14 +129,18 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
             eval_count = 0  # count the number of generated tokens
             tokens = []
 
-            async for token in interface.inference(prompt, id):
-                d = OllamaChatCompletionStreamResponse(
-                    model=config.model_name,
-                    created_at=str(datetime.now()),
-                    message={"role": "assistant", "content": token},
-                    done=False
-                )
-                yield d.model_dump_json() + '\n'
+            async for res in interface.inference(prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = OllamaChatCompletionStreamResponse(
+                        model=config.model_name,
+                        created_at=str(datetime.now()),
+                        message={"role": "assistant", "content": token},
+                        done=False
+                    )
+                    yield d.model_dump_json() + '\n'
             # compute performance statistics
             end_time = time()
             total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
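
Note: after this change, interface.inference() no longer yields bare token strings; it yields (token, finish_reason) tuples while decoding and a final RawUsage object carrying profiler statistics. The sketch below is a minimal, hypothetical consumer of that contract (the interface, prompt, and thread id are assumed to come from the surrounding endpoint code, as in the hunks above); it illustrates the pattern rather than copying the repository's handlers.

from ktransformers.server.schemas.endpoints.chat import RawUsage

async def collect(interface, prompt, thread_id):
    # Accumulate streamed text and keep the trailing usage record,
    # mirroring the branching used by the Ollama endpoints above.
    text, finish_reason, usage = "", None, None
    async for res in interface.inference(prompt, thread_id):
        if isinstance(res, RawUsage):   # last item: profiler-backed usage
            usage = res
        else:                           # regular streaming item
            token, finish_reason = res
            text += token
    return text, finish_reason, usage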
@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
 
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage
+
 
 router = APIRouter()
 
 @router.get('/models', tags=['openai'])
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
 
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-                chunk.set_token(token)
-                yield chunk
-        return chat_stream_response(request,inner())
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+
+            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+
+                    yield chunk
+
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+
+        return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
-        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content=content,
+                role="assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
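
Note: with the streaming branch above, the endpoint emits OpenAI-style chat.completion.chunk objects and a final chunk whose usage field is filled from RawUsage. A minimal client sketch with the official openai Python package follows; the base URL (including whether the route is mounted under /v1), the port, the API key, and the model name are placeholders for whatever the server is actually configured with, not values taken from this diff.

from openai import OpenAI

# Assumed server address and credentials; adjust to your deployment.
client = OpenAI(base_url="http://localhost:10002/v1", api_key="your_api_key")

stream = client.chat.completions.create(
    model="placeholder-model",                 # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices:                          # token deltas
        print(chunk.choices[0].delta.content or "", end="", flush=True)
    if chunk.usage:                            # final chunk carries the usage
        print("\n", chunk.usage)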
@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter()
 
@@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
 
 
-
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
-                d = {'choices':[{'delta':{'content':token}}]}
-                yield f"data:{json.dumps(d)}\n\n"
+            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
+                    d = {'choices':[{'delta':{'content':token}}]}
+                    yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
             yield f"data:{json.dumps(d)}\n\n"
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
-            comp.append_token(token)
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
+                comp.append_token(token)
         return comp
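
Note: the legacy completions route keeps its hand-rolled "data: ..." server-sent-events framing, now skipping the trailing RawUsage item. A rough client sketch using the requests library is below; the URL path, port, and payload fields are assumptions for illustration, not values taken from this diff.

import json
import requests

# Hypothetical endpoint URL; adjust host, port and path to your server.
resp = requests.post(
    "http://localhost:10002/v1/completions",
    json={"model": "placeholder-model", "prompt": "Hello", "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if not line or not line.startswith(b"data:"):
        continue                                   # skip blank keep-alive lines
    payload = json.loads(line[len(b"data:"):])     # strip the SSE prefix
    delta = payload["choices"][0]["delta"].get("content", "")
    print(delta, end="", flush=True)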
@@ -142,7 +142,7 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)
 
-        async for token in self.interface.inference(local_messages,self.thread.id):
+        async for token, finish_reason in self.interface.inference(local_messages,self.thread.id):
             if self.run.status == RunObject.Status.cancelling:
                 logger.warn(f'Run {self.run.id} cancelling')
                 break
@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 warm_uped = False
 
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+        # return this inference raw usage
+        yield RawUsage(
+            tokenize_time = self.profiler.get_timer_sec('tokenize'),
+            prefill_time = self.profiler.get_timer_sec('prefill'),
+            decode_time = self.profiler.get_timer_sec('decode'),
+            prefill_count = self.profiler.get_counter('prefill'),
+            decode_count = self.profiler.get_counter('decode'),
+        )
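
Note: RawUsage is filled from the interface's profiler, so the server-side prefill/decode token counts and timings travel with every response. A small sketch of how such a record could be summarized, for example into tokens-per-second figures, is below; the summarize helper is hypothetical and not part of this PR.

from ktransformers.server.schemas.endpoints.chat import RawUsage

def summarize(usage: RawUsage) -> dict:
    # Derive rough throughput numbers from the profiler-backed fields.
    prefill_tps = usage.prefill_count / usage.prefill_time if usage.prefill_time else 0.0
    decode_tps = usage.decode_count / usage.decode_time if usage.decode_time else 0.0
    return {
        "prompt_tokens": usage.prefill_count,
        "completion_tokens": usage.decode_count,
        "total_tokens": usage.prefill_count + usage.decode_count,
        "prefill_tok_per_s": prefill_tps,
        "decode_tok_per_s": decode_tps,
    }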
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
@@ -348,10 +348,17 @@ class TransformersInterface(BackendInterfaceBase):
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+
+        else: # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"
+
+
 
     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
@@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="",flush=True)
-            yield think
+            yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
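
Note: the hunks above switch every yield in generate() and inference() from a bare string to a (token, finish_reason) pair, where finish_reason stays None while tokens are streaming and becomes "stop" (EOS reached) or "length" (max_new_tokens reached) on the last item; the KTransformers subclass then appends a RawUsage object. A stripped-down, hypothetical generator illustrating that ordering, not a copy of the repository's implementation:

async def toy_inference(tokens, hit_eos: bool):
    # Stream (token, finish_reason) pairs; the reason is None until the end.
    for t in tokens:
        yield t, None
    # The final pair names the reason the stream stopped.
    yield "", ("stop" if hit_eos else "length")
    # (In KTransformersInterface a RawUsage object would follow here.)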
@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
@@ -1,10 +1,15 @@
 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum
 
 from pydantic import BaseModel
 
 from ktransformers.server.schemas.base import Object
 
+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice
+
+
 class Role(Enum):
     system = 'system'
     user = 'user'
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
-class FinishReason(Enum):
-    stop = 'stop'
-    length = 'length'
-
-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
+    created: int
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None
 
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-
-class Usage(BaseModel):
-    completion_tokens:int
-    prompt_tokens:int
-    total_tokens:int
-
-
-class ChatCompletionBase(Object):
-    created:int
-    model:str = 'not implmented'
-    system_fingerprint:str = 'not implmented'
-    usage: Optional[Usage] = None
-
-class ChatCompletionObject(ChatCompletionBase):
-    choices:List[Choice] = []
-
-    def append_token(self,token:str):
-        if len(self.choices) == 0:
-            self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
-        self.choices[0].message.content += token
-
-class ChatCompletionChunk(ChatCompletionBase):
-    choices:List[DeltaChoice] = []
-
-    def set_token(self,token:str):
-        self.choices = [
-            DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
-        ]
-
-    def to_stream_reply(self):
-        return f"data: {self.model_dump_json()}\n\n"
-
-
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int
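
Note: the rewritten schema module drops the local Choice/DeltaChoice/Usage hierarchy and keeps a thin ChatCompletionChunk wrapper over the openai types plus the new RawUsage record. A quick sketch of building one chunk and one usage record by hand, assuming the post-merge module layout; all values are made up for illustration.

from time import time
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, RawUsage

chunk = ChatCompletionChunk(
    id="chat-123",                                  # made-up request id
    object="chat.completion.chunk",
    created=int(time()),
    model="placeholder-model",
    choices=[Choice(index=0,
                    delta=ChoiceDelta(content="Hi", role="assistant"),
                    finish_reason=None)],
)
print(chunk.model_dump_json())

usage = RawUsage(tokenize_time=0.01, prefill_time=0.2, decode_time=1.5,
                 prefill_count=42, decode_count=128)
print(usage.model_dump_json())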