diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py index a5eb986..ea1e815 100644 --- a/ktransformers/server/api/openai/endpoints/chat.py +++ b/ktransformers/server/api/openai/endpoints/chat.py @@ -13,16 +13,10 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.config.config import Config from ktransformers.server.config.log import logger - -from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk +from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage # Define own data structure instead of importing from OpenAI -class CompletionUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - prompt_tokens_details: Optional[Dict[str, Any]] = None - completion_tokens_details: Optional[Dict[str, Any]] = None + class Choice(BaseModel): index: int @@ -217,6 +211,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): completion_tokens=raw_usage.decode_count, total_tokens=raw_usage.prefill_count + raw_usage.decode_count ) + if create.return_speed: + chunk.usage.prefill_time = res.prefill_time + chunk.usage.decode_time = res.decode_time + else: + chunk.usage.__dict__.pop('prefill_time', None) + chunk.usage.__dict__.pop('decode_time', None) yield chunk elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res @@ -377,8 +377,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): usage = CompletionUsage( prompt_tokens=raw_usage.prefill_count, completion_tokens=raw_usage.decode_count, - total_tokens=raw_usage.prefill_count + raw_usage.decode_count + total_tokens=raw_usage.prefill_count + raw_usage.decode_count, ) + if create.return_speed: + usage.prefill_time = res.prefill_time + usage.decode_time = res.decode_time + else: + usage.__dict__.pop('prefill_time', None) + usage.__dict__.pop('decode_time', None) + elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token) diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py index 643c81c..d37e342 100644 --- a/ktransformers/server/schemas/endpoints/chat.py +++ b/ktransformers/server/schemas/endpoints/chat.py @@ -2,14 +2,22 @@ from typing import List, Optional, Union, Dict, Any from typing_extensions import Literal from enum import Enum from pydantic import BaseModel, Field - +from ktransformers.server.config.config import Config from ktransformers.server.schemas.base import Object -from openai.types.completion_usage import CompletionUsage + from openai.types.chat.chat_completion_chunk import Choice from uuid import uuid4 +class CompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + prompt_tokens_details: Optional[Dict[str, Any]] = None + completion_tokens_details: Optional[Dict[str, Any]] = None + prefill_time: Optional[float] = None + decode_time: Optional[float] = None class Role(Enum): system = 'system' @@ -58,16 +66,16 @@ class ChatCompletionCreate(BaseModel): messages: List[Message] model: str stream: bool = False - temperature: Optional[float] = Field(default=0.6) - top_p: Optional[float] = Field(default=1.0) + temperature: Optional[float] = Field(default=Config().temperature) + top_p: Optional[float] = Field(default=Config().top_p) tools: Optional[List[Tool]] = None tool_choice: Optional[Union[str, Dict[str, Any]]] = None stream_options: Optional[Dict[str, Any]] = None frequency_penalty: float = 0 presence_penalty: float = 0 - max_tokens: Optional[int] = Field(default=50) - max_completion_tokens: Optional[int] = Field(default=50) - + max_tokens: Optional[int] = Field(default=Config().max_new_tokens) + max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens) + return_speed: Optional[bool] = Field(default=False) def get_tokenizer_messages(self): return [m.to_tokenizer_message() for m in self.messages] diff --git a/ktransformers/server/schemas/legacy/completions.py b/ktransformers/server/schemas/legacy/completions.py index 2d83212..a728cb1 100644 --- a/ktransformers/server/schemas/legacy/completions.py +++ b/ktransformers/server/schemas/legacy/completions.py @@ -1,17 +1,17 @@ from typing import List, Optional from enum import Enum from pydantic import BaseModel, Field - +from ktransformers.server.config.config import Config from ..base import Object class CompletionCreate(BaseModel): model: str prompt: str | List[str] stream: bool = False - temperature: Optional[float] = Field(default=0.6) - top_p: Optional[float] = Field(default=1) - max_tokens: Optional[int] = Field(default=50) - max_completion_tokens: Optional[int] = Field(default=50) + temperature: Optional[float] = Field(default=Config().temperature) + top_p: Optional[float] = Field(default=Config().top_p) + max_tokens: Optional[int] = Field(default=Config().max_new_tokens) + max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens) def get_tokenizer_messages(self): if isinstance(self.prompt,List): diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py index dbdf999..3e7f849 100644 --- a/ktransformers/tests/test_speed.py +++ b/ktransformers/tests/test_speed.py @@ -12,6 +12,8 @@ from time import sleep decodesz = 128 # Server URL (replace with your server URL) decodesz_list = [128] +prefill_speeds = [] +decode_speeds = [] ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. @@ -43,7 +45,7 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. Dursley always sat with his back to the window in his office on the ninth floor.""" -async def fetch_event_stream(session, request_id, prompt): +async def fetch_event_stream(session, request_id, prompt, max_tokens): try: payload = { "messages": [ @@ -53,7 +55,9 @@ async def fetch_event_stream(session, request_id, prompt): "model": "DeepSeek-V3", "temperature": 0.3, "top_p": 1.0, - "stream": True + "stream": True, + "return_speed": True, + "max_tokens": max_tokens, } headers = { @@ -70,6 +74,7 @@ async def fetch_event_stream(session, request_id, prompt): total_tokens = 0 decode_start_time = None decode_end_time = None + usage_info = None async for line in response.content: try: @@ -82,6 +87,10 @@ async def fetch_event_stream(session, request_id, prompt): continue response_data = json.loads(decoded_line) + + if "usage" in response_data: + usage_info = response_data["usage"] + choices = response_data.get("choices", []) if not choices: continue @@ -107,34 +116,45 @@ async def fetch_event_stream(session, request_id, prompt): except Exception as e: print(f"[Request {request_id}] Stream Error: {e}") - if buffer.strip(): print(f"[Request {request_id}] {buffer.strip()}") - if decode_start_time and decode_end_time and total_tokens > 0: - decode_time = decode_end_time - decode_start_time - decode_speed = total_tokens / decode_time if decode_time > 0 else 0 - print(f"[Request {request_id}] Speed: {decode_speed:.2f} tokens/s") + if usage_info: + if "prefill_time" in usage_info: + # print(f"[Request {request_id}] Usage:") + # for key, value in usage_info.items(): + # print(f" {key}: {value}") + prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"] + decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"] + prefill_speeds.append(prefill_speed) + decode_speeds.append(decode_speed) + print(f'[Request {request_id}] prefill speed: {prefill_speed}') + print(f'[Request {request_id}] decode speed: {decode_speed}') except Exception as e: print(f"[Request {request_id}] Exception: {e}") -async def main(concurrent_requests , prompt ): +async def main(concurrent_requests , prompt, max_tokens): async with aiohttp.ClientSession() as session: - tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)] + tasks = [fetch_event_stream(session, i , prompt, max_tokens) for i in range(concurrent_requests)] await asyncio.gather(*tasks) + if len(prefill_speeds) != 0: + import numpy as np + print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Event Stream Request Tester") parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") + parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") args = parser.parse_args() SERVER_URL = args.api_url + max_tokens = args.max_tokens if args.prompt_lens == 1024: prompt = ktansformer_prompt1024 elif args.prompt_lens == 2048: prompt = ktansformer_prompt1024 * 2 - asyncio.run(main(args.concurrent, prompt)) + asyncio.run(main(args.concurrent, prompt, max_tokens))