Merge pull request #835 from BITcyman/fix-openai_chat_completion

[fix] support openai chat completion api
This commit is contained in:
wang jiahao 2025-03-07 17:22:00 +08:00 committed by GitHub
commit 96d75d53df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 166 additions and 83 deletions

View file

@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import check_link_response
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.schemas.endpoints.chat import RawUsage
router = APIRouter(prefix='/api')
# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@ -58,14 +60,18 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
if input.stream:
async def inner():
async for token in interface.inference(input.prompt, id):
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response=token,
done=False
)
yield d.model_dump_json() + '\n'
async for res in interface.inference(input.prompt, id):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response=token,
done=False
)
yield d.model_dump_json() + '\n'
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
@ -123,14 +129,18 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
eval_count = 0 # 统计生成的 token 数量
tokens = []
async for token in interface.inference(prompt, id):
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message={"role": "assistant", "content": token},
done=False
)
yield d.model_dump_json() + '\n'
async for res in interface.inference(prompt, id):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message={"role": "assistant", "content": token},
done=False
)
yield d.model_dump_json() + '\n'
# 计算性能数据
end_time = time()
total_duration = int((end_time - start_time) * 1_000_000_000) # 转换为纳秒

View file

@ -5,10 +5,16 @@ from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
from openai.types.chat import ChatCompletion
from openai.types.completion_usage import CompletionUsage
router = APIRouter()
@router.get('/models', tags=['openai'])
@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
if create.stream:
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
async def inner():
chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
async for token in interface.inference(input_message,id,create.temperature,create.top_p):
chunk.set_token(token)
yield chunk
return chat_stream_response(request,inner())
chunk = ChatCompletionChunk(
id = id,
choices = [],
object = 'chat.completion.chunk',
created = int(time()),
model = Config().model_name,
)
async for res in interface.inference(input_message,id, create.temperature, create.top_p):
if isinstance(res, RawUsage):
# at the end of inference, interface.inference() will return the usage of inference
raw_usage = res
chunk.choices = []
chunk.usage = CompletionUsage(
prompt_tokens = raw_usage.prefill_count,
completion_tokens = raw_usage.decode_count,
total_tokens = raw_usage.prefill_count + raw_usage.decode_count
)
yield chunk
else:
token, finish_reason = res
choice = Choice(
index = 0,
delta = ChoiceDelta(content=token, role=None, tool_calls=None),
finish_reason = finish_reason,
logprobs = None,
)
chunk.choices = [choice]
yield chunk
return chat_stream_response(request, inner())
else:
comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
async for token in interface.inference(input_message,id,create.temperature,create.top_p):
comp.append_token(token)
return comp
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
content = ""
finish_reason = None
async for res in interface.inference(input_message,id,create.temperature,create.top_p):
if isinstance(res, RawUsage):
raw_usage = res
usage = CompletionUsage(
prompt_tokens = raw_usage.prefill_count,
completion_tokens = raw_usage.decode_count,
total_tokens = raw_usage.prefill_count + raw_usage.decode_count
)
else:
token, finish_reason = res
content = content + token
finish_reason = finish_reason
choice = Choice(
index = 0,
finish_reason = finish_reason,
message = ChatCompletionMessage(
content=content,
role="assistant"
))
chat_completion = ChatCompletion(
id = id,
choices = [choice],
created = int(time()),
model = Config().model_name,
object = 'chat.completion',
usage = usage
)
return chat_completion

View file

@ -6,6 +6,7 @@ from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
from ktransformers.server.schemas.endpoints.chat import RawUsage
router = APIRouter()
@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate):
print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
if create.stream:
async def inner():
async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
d = {'choices':[{'delta':{'content':token}}]}
yield f"data:{json.dumps(d)}\n\n"
async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
d = {'choices':[{'delta':{'content':token}}]}
yield f"data:{json.dumps(d)}\n\n"
d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
yield f"data:{json.dumps(d)}\n\n"
return stream_response(request,inner())
else:
comp = CompletionObject(id=id,object='text_completion',created=int(time()))
async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
comp.append_token(token)
async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
if isinstance(res, RawUsage):
raw_usage = res
else:
token, finish_reason = res
comp.append_token(token)
return comp

View file

@ -142,7 +142,7 @@ class ThreadContext:
yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
yield self.run.stream_response_with_event(RunObject.Status.in_progress)
async for token in self.interface.inference(local_messages,self.thread.id):
async for token, finish_reason in self.interface.inference(local_messages,self.thread.id):
if self.run.status == RunObject.Status.cancelling:
logger.warn(f'Run {self.run.id} cancelling')
break

View file

@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
from typing import Optional
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
from ktransformers.server.schemas.endpoints.chat import RawUsage
warm_uped = False
@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
async with self._infer_lock:
async for v in super().inference(local_messages, thread_id, temperature, top_p):
yield v
# return this inference raw usage
yield RawUsage(
tokenize_time = self.profiler.get_timer_sec('tokenize'),
prefill_time = self.profiler.get_timer_sec('prefill'),
decode_time = self.profiler.get_timer_sec('decode'),
prefill_count = self.profiler.get_counter('prefill'),
decode_count = self.profiler.get_counter('decode'),
)

View file

@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
if(self.max_new_tokens <= 0):
logger.warning("max_new_tokens is less than 0")
yield self.streamer.end()
yield self.streamer.end(), "length"
return
logger.info(f"max_new_tokens: {self.max_new_tokens}")
self.profiler.set_counter("decode", 0)
@ -348,10 +348,17 @@ class TransformersInterface(BackendInterfaceBase):
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
yield self.streamer.end(), None
yield "", "stop"
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token)
yield self.streamer.end()
yield self.append_new_tokens(next_token), None
else: # for's else, if output get max new tokens
yield self.streamer.end(), None
yield "", "length"
def check_is_new(self, thread_id: str):
if not self.use_static_cache:
@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think
yield think, None
for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
# output think token after prefill done
if t is not None:
print(t, end="",flush=True)
yield t
yield t, None
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t in self.generate():
for t, finish_reason in self.generate():
if t is not None:
print(t, end="",flush=True)
yield t
yield t, finish_reason
print("")
self.profiler.pause_timer("decode")
self.report_last_time_performance()

View file

@ -5,6 +5,7 @@ langchain >= 0.2.0
blessed >= 1.20.0
accelerate >= 0.31.0
sentencepiece >= 0.1.97
openai
setuptools
build
ninja

View file

@ -1,10 +1,15 @@
from typing import List, Optional
from typing_extensions import Literal
from enum import Enum
from pydantic import BaseModel
from ktransformers.server.schemas.base import Object
from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_chunk import Choice
class Role(Enum):
system = 'system'
user = 'user'
@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]
class FinishReason(Enum):
stop = 'stop'
length = 'length'
class Choice(BaseModel):
index: int
message: Message
logprobs: Optional[str] = None
finish_reason: FinishReason = None
class ChatCompletionChunk(BaseModel):
id: str
choices: List[Choice]
created: int
model: str
object: Literal["chat.completion.chunk"]
service_tier: Optional[Literal["scale", "default"]] = None
system_fingerprint: Optional[str] = None
usage: Optional[CompletionUsage] = None
class DeltaChoice(BaseModel):
index: int
delta: Message
logprobs: Optional[str] = None
finish_reason: FinishReason = None
class Usage(BaseModel):
completion_tokens:int
prompt_tokens:int
total_tokens:int
class ChatCompletionBase(Object):
created:int
model:str = 'not implmented'
system_fingerprint:str = 'not implmented'
usage: Optional[Usage] = None
class ChatCompletionObject(ChatCompletionBase):
choices:List[Choice] = []
def append_token(self,token:str):
if len(self.choices) == 0:
self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
self.choices[0].message.content += token
class ChatCompletionChunk(ChatCompletionBase):
choices:List[DeltaChoice] = []
def set_token(self,token:str):
self.choices = [
DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
]
def to_stream_reply(self):
return f"data: {self.model_dump_json()}\n\n"
class RawUsage(BaseModel):
tokenize_time: float
prefill_time: float
decode_time: float
prefill_count: int
decode_count: int