Merge pull request #835 from BITcyman/fix-openai_chat_completion

[fix] support openai chat completion api
wang jiahao 2025-03-07 17:22:00 +08:00 committed by GitHub
commit 96d75d53df
8 changed files with 166 additions and 83 deletions
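
With this change the server's OpenAI-compatible chat completion route streams genuine ChatCompletionChunk objects, closes the stream with a usage-only chunk built from RawUsage, and returns a real ChatCompletion with actual token counts in the non-streaming case. A minimal sketch of exercising the endpoint with the official openai client follows; the base URL, port, API key, and model name are assumptions for illustration (adjust them to your deployment), not values taken from this PR:

    # Sketch: call the updated endpoint with the official `openai` client.
    # base_url, api_key and the model name below are assumptions, not part of this PR.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="anything")

    stream = client.chat.completions.create(
        model="DeepSeek-V3",  # hypothetical model name
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices:  # ordinary delta chunks
            print(chunk.choices[0].delta.content or "", end="", flush=True)
        if chunk.usage:    # final usage-only chunk emitted at the end of the stream
            print(f"\n[usage] prompt={chunk.usage.prompt_tokens} "
                  f"completion={chunk.usage.completion_tokens} "
                  f"total={chunk.usage.total_tokens}")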


@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 router = APIRouter(prefix='/api')

 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@@ -58,7 +60,11 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = OllamaGenerationStreamResponse(
                     model=config.model_name,
                     created_at=str(datetime.now()),
@@ -123,7 +129,11 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
             eval_count = 0  # count of generated tokens
             tokens = []
-            async for token in interface.inference(prompt, id):
+            async for res in interface.inference(prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = OllamaChatCompletionStreamResponse(
                     model=config.model_name,
                     created_at=str(datetime.now()),


@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage

 router = APIRouter()

 @router.get('/models', tags=['openai'])
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     assert request.headers.get('Authorization', '').split()[-1] == Config().api_key

     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-                chunk.set_token(token)
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() yields the RawUsage of this request
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
                     yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+
         return chat_stream_response(request, inner())
     else:
-        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
-        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-            comp.append_token(token)
-        return comp
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content=content,
+                role="assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
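
For comparison, the non-streaming branch above now accumulates the generated text and fills usage from the RawUsage counters instead of the old hard-coded Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2) placeholder. A short sketch of reading those fields from the client side, under the same assumed base URL and model name as in the earlier example:

    # Sketch: non-streaming call; base_url and model name are assumptions.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="anything")
    resp = client.chat.completions.create(
        model="DeepSeek-V3",  # hypothetical model name
        messages=[{"role": "user", "content": "Say hi"}],
    )
    print(resp.choices[0].message.content)
    print(resp.choices[0].finish_reason)  # "stop" on EOS, "length" when max_new_tokens is hit
    print(resp.usage)                     # real prompt/completion/total token counts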


@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 router = APIRouter()
@@ -17,10 +18,13 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')

     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
             d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +32,10 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
             comp.append_token(token)
         return comp


@@ -142,7 +142,7 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)

-        async for token in self.interface.inference(local_messages,self.thread.id):
+        async for token, finish_reason in self.interface.inference(local_messages,self.thread.id):
             if self.run.status == RunObject.Status.cancelling:
                 logger.warn(f'Run {self.run.id} cancelling')
                 break


@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage

 warm_uped = False
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+            # yield the raw usage stats of this inference run
+            yield RawUsage(
+                tokenize_time = self.profiler.get_timer_sec('tokenize'),
+                prefill_time = self.profiler.get_timer_sec('prefill'),
+                decode_time = self.profiler.get_timer_sec('decode'),
+                prefill_count = self.profiler.get_counter('prefill'),
+                decode_count = self.profiler.get_counter('decode'),
+            )
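
RawUsage exposes the profiler's timers and counters to whatever consumes the generator, so callers can derive prefill and decode throughput in addition to token counts. A small illustrative sketch, with made-up numbers standing in for real profiler output:

    # Sketch: deriving throughput from a RawUsage object (values are illustrative).
    from ktransformers.server.schemas.endpoints.chat import RawUsage

    raw = RawUsage(tokenize_time=0.02, prefill_time=0.85, decode_time=3.40,
                   prefill_count=512, decode_count=128)

    prefill_tps = raw.prefill_count / raw.prefill_time if raw.prefill_time else 0.0
    decode_tps = raw.decode_count / raw.decode_time if raw.decode_time else 0.0
    print(f"prefill: {prefill_tps:.1f} tok/s, decode: {decode_tps:.1f} tok/s")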


@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
@@ -348,10 +348,17 @@
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+        else:  # for-else: runs only when the loop exhausts max_new_tokens without breaking
+            yield self.streamer.end(), None
+            yield "", "length"

     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
@@ -391,20 +398,20 @@
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="",flush=True)
-            yield think
+            yield think, None

         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
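
The "length" finish reason above comes from Python's for-else: the else branch runs only when the decode loop exhausts its token budget without hitting break, while "stop" is yielded just before the break on EOS. A stripped-down, self-contained sketch of that control flow (names and values are illustrative, not the real interface):

    # Illustration of the for-else pattern used to pick finish_reason.
    def generate_demo(tokens, eos, max_new_tokens):
        """Toy stand-in for generate(): yields (token, finish_reason) tuples."""
        for _ in range(max_new_tokens):
            token = tokens.pop(0) if tokens else eos
            if token == eos:
                yield "", "stop"      # EOS reached; break skips the else block
                break
            yield token, None         # ordinary token, no finish reason yet
        else:
            yield "", "length"        # loop ran out of budget without breaking

    print(list(generate_demo(["a", "b"], eos="<eos>", max_new_tokens=4)))
    # -> [('a', None), ('b', None), ('', 'stop')]
    print(list(generate_demo(["a", "b", "c"], eos="<eos>", max_new_tokens=2)))
    # -> [('a', None), ('b', None), ('', 'length')]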


@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja


@@ -1,10 +1,15 @@
 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum

 from pydantic import BaseModel

 from ktransformers.server.schemas.base import Object

+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice
+
 class Role(Enum):
     system = 'system'
     user = 'user'
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]

-class FinishReason(Enum):
-    stop = 'stop'
-    length = 'length'
-
-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class Usage(BaseModel):
-    completion_tokens:int
-    prompt_tokens:int
-    total_tokens:int
-
-class ChatCompletionBase(Object):
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
     created: int
-    model:str = 'not implmented'
-    system_fingerprint:str = 'not implmented'
-    usage: Optional[Usage] = None
-
-class ChatCompletionObject(ChatCompletionBase):
-    choices:List[Choice] = []
-
-    def append_token(self,token:str):
-        if len(self.choices) == 0:
-            self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
-        self.choices[0].message.content += token
-
-class ChatCompletionChunk(ChatCompletionBase):
-    choices:List[DeltaChoice] = []
-
-    def set_token(self,token:str):
-        self.choices = [
-            DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
-        ]
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None

     def to_stream_reply(self):
         return f"data: {self.model_dump_json()}\n\n"
+
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int
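
Since ChatCompletionChunk is now a plain pydantic model mirroring the openai types, a final usage-only chunk can be constructed directly and framed for server-sent events with to_stream_reply(). A minimal sketch, assuming the imports above resolve in an installed ktransformers environment and using made-up values:

    # Sketch: build a usage-only chunk and frame it as an SSE event (values are made up).
    from time import time
    from openai.types.completion_usage import CompletionUsage
    from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk

    chunk = ChatCompletionChunk(
        id="chatcmpl-demo",                     # made-up id
        choices=[],                             # the final usage chunk carries no choices
        created=int(time()),
        model="demo-model",                     # made-up model name
        object="chat.completion.chunk",
        usage=CompletionUsage(prompt_tokens=512, completion_tokens=128, total_tokens=640),
    )
    print(chunk.to_stream_reply())              # -> 'data: {...}\n\n', ready to send over SSE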