Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 06:14:58 +00:00
[update] support openai chat completion api
parent 63b1c8525b
commit 299c4dca64
8 changed files with 166 additions and 83 deletions
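The diff below switches the server's inference generators to yield (token, finish_reason) tuples plus a trailing RawUsage record, and rebuilds the OpenAI-compatible chat completion route on top of the official openai types. As orientation, here is a client-side sketch (not part of the commit) of how the updated endpoint could be exercised; the base URL, port, model name, and API key are placeholders for whatever your server config uses, and stream termination may differ from the official API.

    # client_sketch.py -- hedged example, values below are assumptions
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="your_api_key")

    stream = client.chat.completions.create(
        model="ktransformers-model",                  # placeholder model name
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )

    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
        if chunk.usage is not None:                   # the handler sends usage in its final chunk
            print(f"\nprompt={chunk.usage.prompt_tokens} completion={chunk.usage.completion_tokens}")
            break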
@@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import check_link_response
 from ktransformers.server.backend.base import BackendInterfaceBase
 
+from ktransformers.server.schemas.endpoints.chat import RawUsage
+
 router = APIRouter(prefix='/api')
 
 # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
@@ -58,7 +60,11 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest):
 
     if input.stream:
         async def inner():
-            async for token in interface.inference(input.prompt, id):
+            async for res in interface.inference(input.prompt, id):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = OllamaGenerationStreamResponse(
                     model=config.model_name,
                     created_at=str(datetime.now()),
@@ -123,7 +129,11 @@ async def chat(request: Request, input: OllamaChatCompletionRequest):
     eval_count = 0  # count the number of generated tokens
     tokens = []
 
-    async for token in interface.inference(prompt, id):
+    async for res in interface.inference(prompt, id):
+        if isinstance(res, RawUsage):
+            raw_usage = res
+        else:
+            token, finish_reason = res
         d = OllamaChatCompletionStreamResponse(
             model=config.model_name,
             created_at=str(datetime.now()),
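The two hunks above adapt the Ollama-compatible /api routes to the new tuple protocol. A hedged client sketch follows, assuming the server mirrors the Ollama generate API it links to (newline-delimited JSON with a "response" field) and that the requests package is available; host, port, and model name are placeholders.

    # ollama_client_sketch.py -- assumptions: ndjson framing, placeholder host/port/model
    import json
    import requests

    with requests.post(
        "http://localhost:10002/api/generate",
        json={"model": "ktransformers-model", "prompt": "Why is the sky blue?", "stream": True},
        stream=True,
    ) as r:
        for line in r.iter_lines():
            if not line:
                continue
            payload = json.loads(line)
            print(payload.get("response", ""), end="", flush=True)
            if payload.get("done"):
                break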
@@ -5,10 +5,16 @@ from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
 
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from openai.types.chat import ChatCompletion
+from openai.types.completion_usage import CompletionUsage
+
+
 router = APIRouter()
 
 @router.get('/models', tags=['openai'])
@@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
     assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
 
     if create.stream:
+        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
         async def inner():
-            chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
-            async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-                chunk.set_token(token)
-                yield chunk
-        return chat_stream_response(request,inner())
-    else:
-        comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
-        comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
-        async for token in interface.inference(input_message,id,create.temperature,create.top_p):
-            comp.append_token(token)
-        return comp
+            chunk = ChatCompletionChunk(
+                id = id,
+                choices = [],
+                object = 'chat.completion.chunk',
+                created = int(time()),
+                model = Config().model_name,
+            )
+
+            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # at the end of inference, interface.inference() will return the usage of inference
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens = raw_usage.prefill_count,
+                        completion_tokens = raw_usage.decode_count,
+                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    )
+
+                    yield chunk
+                else:
+                    token, finish_reason = res
+                    choice = Choice(
+                        index = 0,
+                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
+                        finish_reason = finish_reason,
+                        logprobs = None,
+                    )
+                    chunk.choices = [choice]
+                    yield chunk
+
+        return chat_stream_response(request, inner())
+    else:
+        from openai.types.chat.chat_completion import Choice
+        from openai.types.chat.chat_completion_message import ChatCompletionMessage
+
+        content = ""
+        finish_reason = None
+        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens = raw_usage.prefill_count,
+                    completion_tokens = raw_usage.decode_count,
+                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                )
+            else:
+                token, finish_reason = res
+                content = content + token
+                finish_reason = finish_reason
+
+        choice = Choice(
+            index = 0,
+            finish_reason = finish_reason,
+            message = ChatCompletionMessage(
+                content=content,
+                role="assistant"
+            ))
+
+        chat_completion = ChatCompletion(
+            id = id,
+            choices = [choice],
+            created = int(time()),
+            model = Config().model_name,
+            object = 'chat.completion',
+            usage = usage
+        )
+
+        return chat_completion
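In the non-streaming branch above, tokens are accumulated into a single ChatCompletion and usage is filled from the trailing RawUsage record. A minimal client-side sketch under the same placeholder assumptions as before:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:10002/v1", api_key="your_api_key")  # placeholders
    resp = client.chat.completions.create(
        model="ktransformers-model",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(resp.choices[0].message.content)
    print(resp.usage.total_tokens)  # prefill + decode counts reported by the server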
@@ -6,6 +6,7 @@ from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import stream_response
 from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 router = APIRouter()
 
@@ -17,10 +18,13 @@ async def create_completion(request:Request,create:CompletionCreate):
     print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
 
 
     if create.stream:
         async def inner():
-            async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+                if isinstance(res, RawUsage):
+                    raw_usage = res
+                else:
+                    token, finish_reason = res
                 d = {'choices':[{'delta':{'content':token}}]}
                 yield f"data:{json.dumps(d)}\n\n"
         d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
@@ -28,6 +32,10 @@ async def create_completion(request:Request,create:CompletionCreate):
         return stream_response(request,inner())
     else:
         comp = CompletionObject(id=id,object='text_completion',created=int(time()))
-        async for token in interface.inference(create.prompt,id,create.temperature,create.top_p):
+        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+            else:
+                token, finish_reason = res
             comp.append_token(token)
         return comp
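The legacy completions route keeps its hand-rolled SSE framing (data:{json} lines with a delta payload) while adopting the same res unpacking. A hedged sketch of reading that stream; the /v1/completions path, host, and port are assumptions, and requests is not a project dependency.

    # legacy_completions_sketch.py -- assumptions: SSE framing as emitted by inner() above
    import json
    import requests

    with requests.post(
        "http://localhost:10002/v1/completions",
        json={"model": "ktransformers-model", "prompt": "Once upon a time", "stream": True},
        stream=True,
    ) as r:
        for line in r.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data:"):
                continue
            event = json.loads(line[len("data:"):])
            delta = event["choices"][0].get("delta", {})
            print(delta.get("content", ""), end="", flush=True)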
@@ -142,7 +142,7 @@ class ThreadContext:
         yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
         yield self.run.stream_response_with_event(RunObject.Status.in_progress)
 
-        async for token in self.interface.inference(local_messages,self.thread.id):
+        async for token, finish_reason in self.interface.inference(local_messages,self.thread.id):
             if self.run.status == RunObject.Status.cancelling:
                 logger.warn(f'Run {self.run.id} cancelling')
                 break
@@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 from typing import Optional
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
+from ktransformers.server.schemas.endpoints.chat import RawUsage
 
 warm_uped = False
 
@@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface):
         async with self._infer_lock:
             async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
+
+            # return this inference raw usage
+            yield RawUsage(
+                tokenize_time = self.profiler.get_timer_sec('tokenize'),
+                prefill_time = self.profiler.get_timer_sec('prefill'),
+                decode_time = self.profiler.get_timer_sec('decode'),
+                prefill_count = self.profiler.get_counter('prefill'),
+                decode_count = self.profiler.get_counter('decode'),
+            )
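This hunk is the source of the new protocol: KTransformersInterface forwards every (token, finish_reason) tuple from the parent class and then emits one RawUsage record built from the profiler. A sketch (not from the commit) of how a caller can consume that generator:

    from ktransformers.server.schemas.endpoints.chat import RawUsage

    async def collect(interface, messages, thread_id):
        text, finish_reason, usage = "", None, None
        async for res in interface.inference(messages, thread_id):
            if isinstance(res, RawUsage):
                usage = res                    # tokenize/prefill/decode timings and token counts
            else:
                token, finish_reason = res     # incremental text plus the stop cause
                text += token
        return text, finish_reason, usage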
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
@@ -348,10 +348,17 @@ class TransformersInterface(BackendInterfaceBase):
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+        else:  # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"
+
+
 
     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
@@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="",flush=True)
-            yield think
+            yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
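Taken together, the TransformersInterface changes make every yielded item a (text, finish_reason) pair: regular tokens carry None, and the terminal items report "stop" on an EOS token or "length" when max_new_tokens runs out. A simplified toy generator illustrating that contract (an assumption-level sketch, not the project's code):

    def fake_generate(tokens, max_new_tokens, eos="</s>"):
        # intermediate items are (text, None); the final item carries the finish reason
        for i, tok in enumerate(tokens):
            if tok == eos:
                yield "", "stop"
                return
            if i >= max_new_tokens:
                yield "", "length"
                return
            yield tok, None
        yield "", "length"  # ran out of budgeted output without hitting EOS

    print(list(fake_generate(["Hello", " world", "</s>"], max_new_tokens=8)))
    # [('Hello', None), (' world', None), ('', 'stop')]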
@@ -5,6 +5,7 @@ langchain >= 0.2.0
 blessed >= 1.20.0
 accelerate >= 0.31.0
 sentencepiece >= 0.1.97
+openai
 setuptools
 build
 ninja
@@ -1,10 +1,15 @@
 from typing import List, Optional
+from typing_extensions import Literal
 from enum import Enum
 
 from pydantic import BaseModel
 
 from ktransformers.server.schemas.base import Object
 
+from openai.types.completion_usage import CompletionUsage
+from openai.types.chat.chat_completion_chunk import Choice
+
+
 class Role(Enum):
     system = 'system'
     user = 'user'
@@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel):
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
-class FinishReason(Enum):
-    stop = 'stop'
-    length = 'length'
-
-class Choice(BaseModel):
-    index: int
-    message: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-class DeltaChoice(BaseModel):
-    index: int
-    delta: Message
-    logprobs: Optional[str] = None
-    finish_reason: FinishReason = None
-
-
-class Usage(BaseModel):
-    completion_tokens:int
-    prompt_tokens:int
-    total_tokens:int
-
-
-class ChatCompletionBase(Object):
-    created:int
-    model:str = 'not implmented'
-    system_fingerprint:str = 'not implmented'
-    usage: Optional[Usage] = None
-
-class ChatCompletionObject(ChatCompletionBase):
-    choices:List[Choice] = []
-
-    def append_token(self,token:str):
-        if len(self.choices) == 0:
-            self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
-        self.choices[0].message.content += token
-
-class ChatCompletionChunk(ChatCompletionBase):
-    choices:List[DeltaChoice] = []
-
-    def set_token(self,token:str):
-        self.choices = [
-            DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
-        ]
+class ChatCompletionChunk(BaseModel):
+    id: str
+    choices: List[Choice]
+    created: int
+    model: str
+    object: Literal["chat.completion.chunk"]
+    service_tier: Optional[Literal["scale", "default"]] = None
+    system_fingerprint: Optional[str] = None
+    usage: Optional[CompletionUsage] = None
 
     def to_stream_reply(self):
         return f"data: {self.model_dump_json()}\n\n"
+
+
+class RawUsage(BaseModel):
+    tokenize_time: float
+    prefill_time: float
+    decode_time: float
+    prefill_count: int
+    decode_count: int
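The schema module now reuses openai's Choice and CompletionUsage types and adds the RawUsage record that the backend emits. A small sketch (assuming the openai package and this module are importable) of how a RawUsage maps onto the usage-only final chunk that the streaming handler sends:

    from time import time
    from openai.types.completion_usage import CompletionUsage
    from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, RawUsage

    raw = RawUsage(tokenize_time=0.01, prefill_time=0.5, decode_time=2.0,
                   prefill_count=128, decode_count=64)

    final_chunk = ChatCompletionChunk(
        id="chatcmpl-example",               # hypothetical id
        choices=[],                          # no delta in the usage-only chunk
        created=int(time()),
        model="ktransformers-model",         # placeholder model name
        object="chat.completion.chunk",
        usage=CompletionUsage(
            prompt_tokens=raw.prefill_count,
            completion_tokens=raw.decode_count,
            total_tokens=raw.prefill_count + raw.decode_count,
        ),
    )
    print(final_chunk.to_stream_reply())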