mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-05 20:19:51 +00:00)

commit b17ab8653c (parent f5287e908a): update speed test

4 changed files with 66 additions and 31 deletions
@@ -13,16 +13,10 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
 from ktransformers.server.config.log import logger
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage

 # Define own data structure instead of importing from OpenAI
-class CompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-    prompt_tokens_details: Optional[Dict[str, Any]] = None
-    completion_tokens_details: Optional[Dict[str, Any]] = None

 class Choice(BaseModel):
     index: int
@@ -217,6 +211,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                 completion_tokens=raw_usage.decode_count,
                 total_tokens=raw_usage.prefill_count + raw_usage.decode_count
             )
+            if create.return_speed:
+                chunk.usage.prefill_time = res.prefill_time
+                chunk.usage.decode_time = res.decode_time
+            else:
+                chunk.usage.__dict__.pop('prefill_time', None)
+                chunk.usage.__dict__.pop('decode_time', None)
             yield chunk
         elif isinstance(res, tuple) and len(res) == 2:
             token, finish_reason = res
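With the new flag, the streaming path only keeps the timing fields on the final usage chunk when the client asked for them; otherwise they are popped so the chunk keeps the stock OpenAI usage shape. A minimal sketch of the two resulting usage payloads (field names taken from this commit, numbers purely illustrative):

# Illustrative only: roughly what chunk.usage could serialize to.
usage_with_speed = {
    "prompt_tokens": 1024,
    "completion_tokens": 256,
    "total_tokens": 1280,
    "prefill_time": 1.9,   # seconds spent in prefill (example value)
    "decode_time": 18.3,   # seconds spent in decode (example value)
}

# With return_speed left at its default (False) the extra keys are removed,
# so the usage block matches the standard OpenAI layout.
usage_without_speed = {
    "prompt_tokens": 1024,
    "completion_tokens": 256,
    "total_tokens": 1280,
}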
@@ -377,8 +377,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
             usage = CompletionUsage(
                 prompt_tokens=raw_usage.prefill_count,
                 completion_tokens=raw_usage.decode_count,
-                total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
             )
+            if create.return_speed:
+                usage.prefill_time = res.prefill_time
+                usage.decode_time = res.decode_time
+            else:
+                usage.__dict__.pop('prefill_time', None)
+                usage.__dict__.pop('decode_time', None)
+
         elif isinstance(res, tuple) and len(res) == 2:
             token, finish_reason = res
             token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)
@@ -2,14 +2,22 @@ from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config
 from ktransformers.server.schemas.base import Object

-from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion_chunk import Choice

 from uuid import uuid4

+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+    completion_tokens_details: Optional[Dict[str, Any]] = None
+    prefill_time: Optional[float] = None
+    decode_time: Optional[float] = None
+
 class Role(Enum):
     system = 'system'
@@ -58,16 +66,16 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
-    top_p: Optional[float] = Field(default=1.0)
+    temperature: Optional[float] = Field(default=Config().temperature)
+    top_p: Optional[float] = Field(default=Config().top_p)
     tools: Optional[List[Tool]] = None
     tool_choice: Optional[Union[str, Dict[str, Any]]] = None
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
-    max_tokens: Optional[int] = Field(default=50)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    return_speed: Optional[bool] = Field(default=False)
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]

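Two behavioural notes follow from this hunk: request defaults for temperature, top_p and the token limits now track the server-side Config() values instead of the hard-coded 0.6 / 1.0 / 50, and clients can opt into the timing fields with the new return_speed flag. A minimal sketch of a request body that exercises this (model name and flags as used by the test script later in this commit; message content illustrative):

payload = {
    "model": "DeepSeek-V3",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
    "return_speed": True,   # ask the server to attach prefill_time / decode_time to usage
    "max_tokens": 256,      # explicit override; omitted fields fall back to Config() defaults
    # temperature / top_p omitted -> Config().temperature / Config().top_p are used
}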
@@ -1,17 +1,17 @@
 from typing import List, Optional
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config
 from ..base import Object

 class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
-    top_p: Optional[float] = Field(default=1)
-    max_tokens: Optional[int] = Field(default=50)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    temperature: Optional[float] = Field(default=Config().temperature)
+    top_p: Optional[float] = Field(default=Config().top_p)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)

     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
@@ -12,6 +12,8 @@ from time import sleep
 decodesz = 128
 # Server URL (replace with your server URL)
 decodesz_list = [128]
+prefill_speeds = []
+decode_speeds = []
 ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.
 They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills.
 He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs.
@@ -43,7 +45,7 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
 The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it.
 The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
 Mr. Dursley always sat with his back to the window in his office on the ninth floor."""
-async def fetch_event_stream(session, request_id, prompt):
+async def fetch_event_stream(session, request_id, prompt, max_tokens):
     try:
         payload = {
             "messages": [
@@ -53,7 +55,9 @@ async def fetch_event_stream(session, request_id, prompt):
             "model": "DeepSeek-V3",
             "temperature": 0.3,
             "top_p": 1.0,
-            "stream": True
+            "stream": True,
+            "return_speed": True,
+            "max_tokens": max_tokens,
         }

         headers = {
@@ -70,6 +74,7 @@ async def fetch_event_stream(session, request_id, prompt):
         total_tokens = 0
         decode_start_time = None
         decode_end_time = None
+        usage_info = None

         async for line in response.content:
             try:
@@ -82,6 +87,10 @@ async def fetch_event_stream(session, request_id, prompt):
                     continue

                 response_data = json.loads(decoded_line)
+
+                if "usage" in response_data:
+                    usage_info = response_data["usage"]
+
                 choices = response_data.get("choices", [])
                 if not choices:
                     continue
@@ -107,34 +116,45 @@
             except Exception as e:
                 print(f"[Request {request_id}] Stream Error: {e}")


         if buffer.strip():
             print(f"[Request {request_id}] {buffer.strip()}")

-        if decode_start_time and decode_end_time and total_tokens > 0:
-            decode_time = decode_end_time - decode_start_time
-            decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-            print(f"[Request {request_id}] Speed: {decode_speed:.2f} tokens/s")
+        if usage_info:
+            if "prefill_time" in usage_info:
+                # print(f"[Request {request_id}] Usage:")
+                # for key, value in usage_info.items():
+                #     print(f" {key}: {value}")
+                prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"]
+                decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"]
+                prefill_speeds.append(prefill_speed)
+                decode_speeds.append(decode_speed)
+                print(f'[Request {request_id}] prefill speed: {prefill_speed}')
+                print(f'[Request {request_id}] decode speed: {decode_speed}')

     except Exception as e:
         print(f"[Request {request_id}] Exception: {e}")

-async def main(concurrent_requests , prompt ):
+async def main(concurrent_requests , prompt, max_tokens):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)]
+        tasks = [fetch_event_stream(session, i , prompt, max_tokens) for i in range(concurrent_requests)]
         await asyncio.gather(*tasks)
+
+    if len(prefill_speeds) != 0:
+        import numpy as np
+        print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
+    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")

     args = parser.parse_args()
     SERVER_URL = args.api_url
+    max_tokens = args.max_tokens
     if args.prompt_lens == 1024:
         prompt = ktansformer_prompt1024
     elif args.prompt_lens == 2048:
         prompt = ktansformer_prompt1024 * 2
-    asyncio.run(main(args.concurrent, prompt))
+    asyncio.run(main(args.concurrent, prompt, max_tokens))
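With these changes the tester derives throughput from the server-reported usage block instead of client-side timestamps, and gains a --max_tokens knob; a run would look like python <path to this script> --concurrent 2 --prompt_lens 1024 --max_tokens 128 (the script's path is not shown in this view). A worked example of the per-request arithmetic it performs, with made-up numbers:

# Sketch of the speed calculation added above (values are illustrative).
usage_info = {"prompt_tokens": 1024, "completion_tokens": 128,
              "prefill_time": 2.0, "decode_time": 12.8}

prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"]    # 512.0 tokens/s
decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"]  # 10.0 tokens/s

Per-request figures are appended to prefill_speeds / decode_speeds, and main() prints their averages after all concurrent requests finish.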