update speed test

qiyuxinlin 2025-04-22 07:38:05 +00:00
parent f5287e908a
commit b17ab8653c
4 changed files with 66 additions and 31 deletions

View file

@@ -13,16 +13,10 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
 from ktransformers.server.config.log import logger
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage

 # Define own data structure instead of importing from OpenAI
-class CompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-    prompt_tokens_details: Optional[Dict[str, Any]] = None
-    completion_tokens_details: Optional[Dict[str, Any]] = None

 class Choice(BaseModel):
     index: int
@@ -217,6 +211,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                         completion_tokens=raw_usage.decode_count,
                         total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                     )
+                    if create.return_speed:
+                        chunk.usage.prefill_time = res.prefill_time
+                        chunk.usage.decode_time = res.decode_time
+                    else:
+                        chunk.usage.__dict__.pop('prefill_time', None)
+                        chunk.usage.__dict__.pop('decode_time', None)
                     yield chunk
                 elif isinstance(res, tuple) and len(res) == 2:
                     token, finish_reason = res
@@ -377,8 +377,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                 usage = CompletionUsage(
                     prompt_tokens=raw_usage.prefill_count,
                     completion_tokens=raw_usage.decode_count,
-                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
                 )
+                if create.return_speed:
+                    usage.prefill_time = res.prefill_time
+                    usage.decode_time = res.decode_time
+                else:
+                    usage.__dict__.pop('prefill_time', None)
+                    usage.__dict__.pop('decode_time', None)
             elif isinstance(res, tuple) and len(res) == 2:
                 token, finish_reason = res
                 token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)
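
The `__dict__.pop(...)` calls above are what keep `prefill_time` and `decode_time` out of the response when `return_speed` is off: pydantic serializes whatever sits in the instance `__dict__`, so removing the keys omits the fields entirely instead of emitting them as null. A minimal sketch of that behavior (assumes pydantic v1-style `.dict()` serialization; the model here is an illustrative stand-in, not the full CompletionUsage):

    from typing import Optional
    from pydantic import BaseModel

    class Usage(BaseModel):
        prompt_tokens: int
        prefill_time: Optional[float] = None

    u = Usage(prompt_tokens=1024, prefill_time=4.0)
    print(u.dict())   # {'prompt_tokens': 1024, 'prefill_time': 4.0}

    # Popping the value out of __dict__ drops the key from the serialized
    # output entirely, rather than producing "prefill_time": null:
    u.__dict__.pop('prefill_time', None)
    print(u.dict())   # {'prompt_tokens': 1024}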

View file

@@ -2,14 +2,22 @@ from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config

 from ktransformers.server.schemas.base import Object
-from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion_chunk import Choice

 from uuid import uuid4

+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+    completion_tokens_details: Optional[Dict[str, Any]] = None
+    prefill_time: Optional[float] = None
+    decode_time: Optional[float] = None

 class Role(Enum):
     system = 'system'
@@ -58,16 +66,16 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
-    top_p: Optional[float] = Field(default=1.0)
+    temperature: Optional[float] = Field(default=Config().temperature)
+    top_p: Optional[float] = Field(default=Config().top_p)
     tools: Optional[List[Tool]] = None
     tool_choice: Optional[Union[str, Dict[str, Any]]] = None
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
-    max_tokens: Optional[int] = Field(default=50)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    return_speed: Optional[bool] = Field(default=False)

     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
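
With `return_speed` in the request schema, a client opts into the timing fields per request. A hypothetical non-streaming call (the endpoint and model name follow the test script later in this commit; the field values in the comment are made up):

    import requests

    resp = requests.post(
        "http://localhost:10002/v1/chat/completions",
        json={
            "model": "DeepSeek-V3",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": False,
            "return_speed": True,   # opt in to prefill_time/decode_time in usage
        },
    )
    usage = resp.json()["usage"]
    # Expected shape, e.g.:
    # {"prompt_tokens": 1024, "completion_tokens": 50, "total_tokens": 1074,
    #  "prefill_time": 4.0, "decode_time": 2.5}
    print(usage["prompt_tokens"] / usage["prefill_time"], "prefill tok/s")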

View file

@@ -1,17 +1,17 @@
 from typing import List, Optional
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config

 from ..base import Object

 class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
-    top_p: Optional[float] = Field(default=1)
-    max_tokens: Optional[int] = Field(default=50)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    temperature: Optional[float] = Field(default=Config().temperature)
+    top_p: Optional[float] = Field(default=Config().top_p)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
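
One thing to note about `Field(default=Config().temperature)`: the default is evaluated once, when the model class body runs at import time, not per request. A self-contained sketch of that Python/pydantic behavior (the Config class here is a stand-in, not the ktransformers one):

    from pydantic import BaseModel, Field

    class Config:
        temperature = 0.6

    class Create(BaseModel):
        temperature: float = Field(default=Config().temperature)

    Config.temperature = 1.0     # config changes after import...
    print(Create().temperature)  # ...still prints 0.6: the default was frozen
                                 # when the class was defined

A `default_factory=lambda: Config().temperature` would re-read the config on every instantiation instead.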

View file

@@ -12,6 +12,8 @@ from time import sleep
 decodesz = 128
 # Server URL (replace with your server URL)
 decodesz_list = [128]
+prefill_speeds = []
+decode_speeds = []
 ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.
 They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills.
 He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs.
@@ -43,7 +45,7 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
 The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt these people were obviously collecting for something yes, that would be it.
 The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
 Mr. Dursley always sat with his back to the window in his office on the ninth floor."""
-async def fetch_event_stream(session, request_id, prompt):
+async def fetch_event_stream(session, request_id, prompt, max_tokens):
     try:
         payload = {
             "messages": [
@@ -53,7 +55,9 @@ async def fetch_event_stream(session, request_id, prompt):
             "model": "DeepSeek-V3",
             "temperature": 0.3,
             "top_p": 1.0,
-            "stream": True
+            "stream": True,
+            "return_speed": True,
+            "max_tokens": max_tokens,
         }

         headers = {
@@ -70,6 +74,7 @@ async def fetch_event_stream(session, request_id, prompt):
         total_tokens = 0
         decode_start_time = None
         decode_end_time = None
+        usage_info = None

         async for line in response.content:
             try:
@@ -82,6 +87,10 @@ async def fetch_event_stream(session, request_id, prompt):
                     continue

                 response_data = json.loads(decoded_line)
+
+                if "usage" in response_data:
+                    usage_info = response_data["usage"]
+
                 choices = response_data.get("choices", [])
                 if not choices:
                     continue
@@ -107,34 +116,45 @@ async def fetch_event_stream(session, request_id, prompt):
             except Exception as e:
                 print(f"[Request {request_id}] Stream Error: {e}")

         if buffer.strip():
             print(f"[Request {request_id}] {buffer.strip()}")

-        if decode_start_time and decode_end_time and total_tokens > 0:
-            decode_time = decode_end_time - decode_start_time
-            decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-            print(f"[Request {request_id}] Speed: {decode_speed:.2f} tokens/s")
+        if usage_info:
+            if "prefill_time" in usage_info:
+                # print(f"[Request {request_id}] Usage:")
+                # for key, value in usage_info.items():
+                #     print(f"  {key}: {value}")
+                prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"]
+                decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"]
+                prefill_speeds.append(prefill_speed)
+                decode_speeds.append(decode_speed)
+                print(f'[Request {request_id}] prefill speed: {prefill_speed}')
+                print(f'[Request {request_id}] decode speed: {decode_speed}')

     except Exception as e:
         print(f"[Request {request_id}] Exception: {e}")

-async def main(concurrent_requests , prompt ):
+async def main(concurrent_requests , prompt, max_tokens):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)]
+        tasks = [fetch_event_stream(session, i , prompt, max_tokens) for i in range(concurrent_requests)]
         await asyncio.gather(*tasks)
+    if len(prefill_speeds) != 0:
+        import numpy as np
+        print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
+    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
     args = parser.parse_args()

     SERVER_URL = args.api_url
+    max_tokens = args.max_tokens
     if args.prompt_lens == 1024:
         prompt = ktansformer_prompt1024
     elif args.prompt_lens == 2048:
         prompt = ktansformer_prompt1024 * 2

-    asyncio.run(main(args.concurrent, prompt))
+    asyncio.run(main(args.concurrent, prompt, max_tokens))
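
For a concrete sense of the output, suppose the script is saved as speed_test.py (the filename is not shown in this diff) and run as `python speed_test.py --concurrent 4 --prompt_lens 1024 --max_tokens 512`. Each request then derives its speeds from the server-reported usage; with made-up numbers:

    usage = {"prompt_tokens": 1024, "completion_tokens": 512,
             "prefill_time": 4.0, "decode_time": 25.6}   # illustrative values

    prefill_speed = usage["prompt_tokens"] / usage["prefill_time"]      # 256.0 tokens/s
    decode_speed = usage["completion_tokens"] / usage["decode_time"]    # 20.0 tokens/s

np.average() over the per-request lists then yields the averages printed at the end of main().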