diff --git a/ktransformers/server/api/ollama/completions.py b/ktransformers/server/api/ollama/completions.py index 0ff6183..3c37c54 100644 --- a/ktransformers/server/api/ollama/completions.py +++ b/ktransformers/server/api/ollama/completions.py @@ -13,6 +13,8 @@ from ktransformers.server.utils.create_interface import get_interface from ktransformers.server.schemas.assistants.streaming import check_link_response from ktransformers.server.backend.base import BackendInterfaceBase +from ktransformers.server.schemas.endpoints.chat import RawUsage + router = APIRouter(prefix='/api') # https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion @@ -58,14 +60,18 @@ async def generate(request: Request, input: OllamaGenerateCompletionRequest): if input.stream: async def inner(): - async for token in interface.inference(input.prompt, id): - d = OllamaGenerationStreamResponse( - model=config.model_name, - created_at=str(datetime.now()), - response=token, - done=False - ) - yield d.model_dump_json() + '\n' + async for res in interface.inference(input.prompt, id): + if isinstance(res, RawUsage): + raw_usage = res + else: + token, finish_reason = res + d = OllamaGenerationStreamResponse( + model=config.model_name, + created_at=str(datetime.now()), + response=token, + done=False + ) + yield d.model_dump_json() + '\n' d = OllamaGenerationStreamResponse( model=config.model_name, created_at=str(datetime.now()), @@ -123,14 +129,18 @@ async def chat(request: Request, input: OllamaChatCompletionRequest): eval_count = 0 # 统计生成的 token 数量 tokens = [] - async for token in interface.inference(prompt, id): - d = OllamaChatCompletionStreamResponse( - model=config.model_name, - created_at=str(datetime.now()), - message={"role": "assistant", "content": token}, - done=False - ) - yield d.model_dump_json() + '\n' + async for res in interface.inference(prompt, id): + if isinstance(res, RawUsage): + raw_usage = res + else: + token, finish_reason = res + d = OllamaChatCompletionStreamResponse( + model=config.model_name, + created_at=str(datetime.now()), + message={"role": "assistant", "content": token}, + done=False + ) + yield d.model_dump_json() + '\n' # 计算性能数据 end_time = time() total_duration = int((end_time - start_time) * 1_000_000_000) # 转换为纳秒 diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py index bb155de..c9f7bfc 100644 --- a/ktransformers/server/api/openai/endpoints/chat.py +++ b/ktransformers/server/api/openai/endpoints/chat.py @@ -5,10 +5,16 @@ from fastapi import APIRouter from fastapi.requests import Request from ktransformers.server.utils.create_interface import get_interface from ktransformers.server.schemas.assistants.streaming import chat_stream_response -from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage +from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate +from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.config.config import Config +from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk +from openai.types.chat import ChatCompletion +from openai.types.completion_usage import CompletionUsage + + router = APIRouter() @router.get('/models', tags=['openai']) @@ -29,15 +35,76 @@ async def chat_completion(request:Request,create:ChatCompletionCreate): assert request.headers.get('Authorization', '').split()[-1] == Config().api_key if create.stream: + from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta + async def inner(): - chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time())) - async for token in interface.inference(input_message,id,create.temperature,create.top_p): - chunk.set_token(token) - yield chunk - return chat_stream_response(request,inner()) + chunk = ChatCompletionChunk( + id = id, + choices = [], + object = 'chat.completion.chunk', + created = int(time()), + model = Config().model_name, + ) + + async for res in interface.inference(input_message,id, create.temperature, create.top_p): + if isinstance(res, RawUsage): + # at the end of inference, interface.inference() will return the usage of inference + raw_usage = res + chunk.choices = [] + chunk.usage = CompletionUsage( + prompt_tokens = raw_usage.prefill_count, + completion_tokens = raw_usage.decode_count, + total_tokens = raw_usage.prefill_count + raw_usage.decode_count + ) + + yield chunk + + else: + token, finish_reason = res + choice = Choice( + index = 0, + delta = ChoiceDelta(content=token, role=None, tool_calls=None), + finish_reason = finish_reason, + logprobs = None, + ) + chunk.choices = [choice] + yield chunk + + return chat_stream_response(request, inner()) else: - comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time())) - comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2) - async for token in interface.inference(input_message,id,create.temperature,create.top_p): - comp.append_token(token) - return comp + from openai.types.chat.chat_completion import Choice + from openai.types.chat.chat_completion_message import ChatCompletionMessage + + content = "" + finish_reason = None + async for res in interface.inference(input_message,id,create.temperature,create.top_p): + if isinstance(res, RawUsage): + raw_usage = res + usage = CompletionUsage( + prompt_tokens = raw_usage.prefill_count, + completion_tokens = raw_usage.decode_count, + total_tokens = raw_usage.prefill_count + raw_usage.decode_count + ) + else: + token, finish_reason = res + content = content + token + finish_reason = finish_reason + + choice = Choice( + index = 0, + finish_reason = finish_reason, + message = ChatCompletionMessage( + content=content, + role="assistant" + )) + + chat_completion = ChatCompletion( + id = id, + choices = [choice], + created = int(time()), + model = Config().model_name, + object = 'chat.completion', + usage = usage + ) + + return chat_completion diff --git a/ktransformers/server/api/openai/legacy/completions.py b/ktransformers/server/api/openai/legacy/completions.py index fe250f4..7ce2d2a 100644 --- a/ktransformers/server/api/openai/legacy/completions.py +++ b/ktransformers/server/api/openai/legacy/completions.py @@ -6,6 +6,7 @@ from fastapi.requests import Request from ktransformers.server.utils.create_interface import get_interface from ktransformers.server.schemas.assistants.streaming import stream_response from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject +from ktransformers.server.schemas.endpoints.chat import RawUsage router = APIRouter() @@ -17,17 +18,24 @@ async def create_completion(request:Request,create:CompletionCreate): print(f'COMPLETION INPUT:----\n{create.prompt}\n----') - if create.stream: async def inner(): - async for token in interface.inference(create.prompt,id,create.temperature,create.top_p): - d = {'choices':[{'delta':{'content':token}}]} - yield f"data:{json.dumps(d)}\n\n" + async for res in interface.inference(create.prompt,id,create.temperature,create.top_p): + if isinstance(res, RawUsage): + raw_usage = res + else: + token, finish_reason = res + d = {'choices':[{'delta':{'content':token}}]} + yield f"data:{json.dumps(d)}\n\n" d = {'choices':[{'delta':{'content':''},'finish_reason':''}]} yield f"data:{json.dumps(d)}\n\n" return stream_response(request,inner()) else: comp = CompletionObject(id=id,object='text_completion',created=int(time())) - async for token in interface.inference(create.prompt,id,create.temperature,create.top_p): - comp.append_token(token) + async for res in interface.inference(create.prompt,id,create.temperature,create.top_p): + if isinstance(res, RawUsage): + raw_usage = res + else: + token, finish_reason = res + comp.append_token(token) return comp diff --git a/ktransformers/server/backend/base.py b/ktransformers/server/backend/base.py index 4cbcdfa..5148a48 100644 --- a/ktransformers/server/backend/base.py +++ b/ktransformers/server/backend/base.py @@ -142,7 +142,7 @@ class ThreadContext: yield reply_message.stream_response_with_event(MessageObject.Status.in_progress) yield self.run.stream_response_with_event(RunObject.Status.in_progress) - async for token in self.interface.inference(local_messages,self.thread.id): + async for token, finish_reason in self.interface.inference(local_messages,self.thread.id): if self.run.status == RunObject.Status.cancelling: logger.warn(f'Run {self.run.id} cancelling') break diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py index 2cbbb99..1752a3c 100644 --- a/ktransformers/server/backend/interfaces/ktransformers.py +++ b/ktransformers/server/backend/interfaces/ktransformers.py @@ -16,6 +16,7 @@ from ktransformers.local_chat import custom_models, default_optimize_rules from ktransformers.util.utils import get_device from typing import Optional from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton +from ktransformers.server.schemas.endpoints.chat import RawUsage warm_uped = False @@ -231,3 +232,12 @@ class KTransformersInterface(TransformersInterface): async with self._infer_lock: async for v in super().inference(local_messages, thread_id, temperature, top_p): yield v + + # return this inference raw usage + yield RawUsage( + tokenize_time = self.profiler.get_timer_sec('tokenize'), + prefill_time = self.profiler.get_timer_sec('prefill'), + decode_time = self.profiler.get_timer_sec('decode'), + prefill_count = self.profiler.get_counter('prefill'), + decode_count = self.profiler.get_counter('decode'), + ) \ No newline at end of file diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py index 7200ef4..7e59804 100644 --- a/ktransformers/server/backend/interfaces/transformers.py +++ b/ktransformers/server/backend/interfaces/transformers.py @@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase): logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}") if(self.max_new_tokens <= 0): logger.warning("max_new_tokens is less than 0") - yield self.streamer.end() + yield self.streamer.end(), "length" return logger.info(f"max_new_tokens: {self.max_new_tokens}") self.profiler.set_counter("decode", 0) @@ -348,10 +348,17 @@ class TransformersInterface(BackendInterfaceBase): next_token = self.decode_one_tokens() self.profiler.inc("decode") if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token): + yield self.streamer.end(), None + yield "", "stop" assert self.args.batch_size == 1 break - yield self.append_new_tokens(next_token) - yield self.streamer.end() + yield self.append_new_tokens(next_token), None + + else: # for's else, if output get max new tokens + yield self.streamer.end(), None + yield "", "length" + + def check_is_new(self, thread_id: str): if not self.use_static_cache: @@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase): if Config().user_force_think: think = '\n' print(think, end="",flush=True) - yield think + yield think, None for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p): # output think token after prefill done if t is not None: print(t, end="",flush=True) - yield t + yield t, None self.profiler.pause_timer("prefill") self.profiler.create_and_start_timer("decode") - for t in self.generate(): + for t, finish_reason in self.generate(): if t is not None: print(t, end="",flush=True) - yield t + yield t, finish_reason print("") self.profiler.pause_timer("decode") self.report_last_time_performance() diff --git a/ktransformers/server/requirements.txt b/ktransformers/server/requirements.txt index d324cf2..9a4c9c5 100644 --- a/ktransformers/server/requirements.txt +++ b/ktransformers/server/requirements.txt @@ -5,6 +5,7 @@ langchain >= 0.2.0 blessed >= 1.20.0 accelerate >= 0.31.0 sentencepiece >= 0.1.97 +openai setuptools build ninja diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py index e5d8f95..eb0081a 100644 --- a/ktransformers/server/schemas/endpoints/chat.py +++ b/ktransformers/server/schemas/endpoints/chat.py @@ -1,10 +1,15 @@ from typing import List, Optional +from typing_extensions import Literal from enum import Enum from pydantic import BaseModel from ktransformers.server.schemas.base import Object +from openai.types.completion_usage import CompletionUsage +from openai.types.chat.chat_completion_chunk import Choice + + class Role(Enum): system = 'system' user = 'user' @@ -31,50 +36,25 @@ class ChatCompletionCreate(BaseModel): def get_tokenizer_messages(self): return [m.to_tokenizer_message() for m in self.messages] -class FinishReason(Enum): - stop = 'stop' - length = 'length' -class Choice(BaseModel): - index: int - message: Message - logprobs: Optional[str] = None - finish_reason: FinishReason = None +class ChatCompletionChunk(BaseModel): + id: str + choices: List[Choice] + created: int + model: str + object: Literal["chat.completion.chunk"] + service_tier: Optional[Literal["scale", "default"]] = None + system_fingerprint: Optional[str] = None + usage: Optional[CompletionUsage] = None -class DeltaChoice(BaseModel): - index: int - delta: Message - logprobs: Optional[str] = None - finish_reason: FinishReason = None - - -class Usage(BaseModel): - completion_tokens:int - prompt_tokens:int - total_tokens:int - - -class ChatCompletionBase(Object): - created:int - model:str = 'not implmented' - system_fingerprint:str = 'not implmented' - usage: Optional[Usage] = None - -class ChatCompletionObject(ChatCompletionBase): - choices:List[Choice] = [] - - def append_token(self,token:str): - if len(self.choices) == 0: - self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant))) - self.choices[0].message.content += token - -class ChatCompletionChunk(ChatCompletionBase): - choices:List[DeltaChoice] = [] - - def set_token(self,token:str): - self.choices = [ - DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant)) - ] def to_stream_reply(self): return f"data: {self.model_dump_json()}\n\n" + + +class RawUsage(BaseModel): + tokenize_time: float + prefill_time: float + decode_time: float + prefill_count: int + decode_count: int