Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
Commit f7d939313b: Merge remote-tracking branch 'origin/main' into check-para
8 changed files with 219 additions and 145 deletions
@@ -14,15 +14,10 @@ from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger
from fastapi.responses import JSONResponse
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage

# Define own data structure instead of importing from OpenAI
class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None


class Choice(BaseModel):
    index: int
@@ -267,6 +262,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                    completion_tokens=raw_usage.decode_count,
                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                )
                if create.return_speed:
                    chunk.usage.prefill_time = res.prefill_time
                    chunk.usage.decode_time = res.decode_time
                else:
                    chunk.usage.__dict__.pop('prefill_time', None)
                    chunk.usage.__dict__.pop('decode_time', None)
                yield chunk
            elif isinstance(res, tuple) and len(res) == 2:
                token, finish_reason = res
@@ -427,8 +428,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
            usage = CompletionUsage(
                prompt_tokens=raw_usage.prefill_count,
                completion_tokens=raw_usage.decode_count,
                total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
            )
            if create.return_speed:
                usage.prefill_time = res.prefill_time
                usage.decode_time = res.decode_time
            else:
                usage.__dict__.pop('prefill_time', None)
                usage.__dict__.pop('decode_time', None)

        elif isinstance(res, tuple) and len(res) == 2:
            token, finish_reason = res
            token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)
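For context, a minimal client-side sketch of how the return_speed flag added above is exercised (not part of the diff): the endpoint URL and model name are placeholders borrowed from the test scripts later in this commit, and prefill_time / decode_time are the optional usage fields this change introduces.

import requests  # assumes a ktransformers server is running locally

payload = {
    "messages": [{"role": "user", "content": "Hello"}],
    "model": "DeepSeek-V3",      # placeholder model name
    "stream": False,
    "max_tokens": 32,
    "return_speed": True,        # ask the server to report timing in `usage`
}
resp = requests.post("http://localhost:10002/v1/chat/completions", json=payload, timeout=300)
usage = resp.json().get("usage", {})
if "prefill_time" in usage and "decode_time" in usage:
    # same arithmetic the benchmark script at the end of this commit uses
    print("prefill speed:", usage["prompt_tokens"] / usage["prefill_time"], "tokens/s")
    print("decode speed:", usage["completion_tokens"] / usage["decode_time"], "tokens/s")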
@@ -46,6 +46,8 @@ import pickle
import subprocess
import tempfile
import atexit
import signal


ktransformer_rules_dir = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
@@ -55,6 +57,7 @@ default_optimize_rules = {
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}


async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
    streamer = TextStreamer(tokenizer)
    while True:
@@ -293,10 +296,6 @@ class BalanceServeInterface(BackendInterfaceBase):
        kvcache_event.wait()


        def cleanup():
            if sched_process.poll() is None:
                sched_process.terminate()

        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            pickle.dump(args, temp_file)
            temp_file_path = temp_file.name
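A self-contained sketch of the temp-file hand-off used above, with illustrative names only (the real code pickles the server args and passes the file path to the sched_rpc subprocess):

import pickle
import tempfile

args = {"port": 10002, "model_path": "/path/to/model"}   # stand-in for the real args object

with tempfile.NamedTemporaryFile(delete=False) as temp_file:
    pickle.dump(args, temp_file)
    temp_file_path = temp_file.name

# later, in the child process, the object is restored from the path it was given
with open(temp_file_path, "rb") as f:
    restored = pickle.load(f)
assert restored == args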
@@ -311,7 +310,27 @@ class BalanceServeInterface(BackendInterfaceBase):
            stderr=log
        )
        print("sched_rpc started with PID:", sched_process.pid)
        atexit.register(cleanup)

        def signal_handler(signum, frame):
            print(f"Received signal {signum}, shutting down...")
            cleanup()
            os._exit(0)

        def cleanup():
            print("Cleaning up...")

            for p in processes:
                if p.is_alive():
                    print(f"Terminating subprocess {p.pid}")
                    p.terminate()
                    p.join()

            if sched_process and sched_process.poll() is None:
                print(f"Terminating sched_process {sched_process.pid}")
                sched_process.terminate()
                sched_process.wait()
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        start_event.wait()
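The shutdown pattern introduced above, reduced to a runnable sketch (the worker command is a stand-in; the real code terminates the sched_rpc subprocess and the multiprocessing workers). cleanup() is invoked explicitly in the signal handler because os._exit() skips atexit handlers.

import atexit
import os
import signal
import subprocess
import sys

worker = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(3600)"])

def cleanup():
    print("Cleaning up...")
    if worker.poll() is None:
        worker.terminate()
        worker.wait()

def signal_handler(signum, frame):
    print(f"Received signal {signum}, shutting down...")
    cleanup()          # os._exit() bypasses atexit, so clean up here first
    os._exit(0)

atexit.register(cleanup)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)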
@@ -2,14 +2,22 @@ from typing import List, Optional, Union, Dict, Any
from typing_extensions import Literal
from enum import Enum
from pydantic import BaseModel, Field

from ktransformers.server.config.config import Config
from ktransformers.server.schemas.base import Object

from openai.types.completion_usage import CompletionUsage

from openai.types.chat.chat_completion_chunk import Choice

from uuid import uuid4

class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None
    prefill_time: Optional[float] = None
    decode_time: Optional[float] = None

class Role(Enum):
    system = 'system'
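A small sketch of how the optional timing fields on this CompletionUsage model behave (field names are taken from the diff, the values are made up; model_dump is pydantic v2, use .dict() on v1):

from typing import Any, Dict, Optional
from pydantic import BaseModel

class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: Optional[Dict[str, Any]] = None
    completion_tokens_details: Optional[Dict[str, Any]] = None
    prefill_time: Optional[float] = None
    decode_time: Optional[float] = None

usage = CompletionUsage(prompt_tokens=128, completion_tokens=64, total_tokens=192)
usage.prefill_time = 0.42
usage.decode_time = 1.87
# the timing fields only show up when they were actually set
print(usage.model_dump(exclude_none=True))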
@@ -58,16 +66,16 @@ class ChatCompletionCreate(BaseModel):
    messages: List[Message]
    model: str
    stream: bool = False
    temperature: Optional[float] = Field(default=0.6)
    top_p: Optional[float] = Field(default=1.0)
    temperature: Optional[float] = Field(default=Config().temperature)
    top_p: Optional[float] = Field(default=Config().top_p)
    tools: Optional[List[Tool]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    stream_options: Optional[Dict[str, Any]] = None
    frequency_penalty: float = 0
    presence_penalty: float = 0
    max_tokens: Optional[int] = Field(default=50)
    max_completion_tokens: Optional[int] = Field(default=50)

    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
    return_speed: Optional[bool] = Field(default=False)
    def get_tokenizer_messages(self):
        return [m.to_tokenizer_message() for m in self.messages]
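One thing to note about the new defaults above: Field(default=Config().temperature) is evaluated once, when the class is defined, so later changes to the config are not picked up by already-imported request models. A sketch of the alternative, default_factory, using a hypothetical Settings class (not the ktransformers Config):

from typing import Optional
from pydantic import BaseModel, Field

class Settings:                     # hypothetical stand-in for Config
    temperature = 0.6
    max_new_tokens = 50

settings = Settings()

class CreateRequest(BaseModel):     # illustrative model, not from the repo
    temperature: Optional[float] = Field(default_factory=lambda: settings.temperature)
    max_tokens: Optional[int] = Field(default_factory=lambda: settings.max_new_tokens)

settings.temperature = 0.3
print(CreateRequest().temperature)  # 0.3 -- read at instantiation time, not at import time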
@@ -1,17 +1,17 @@
from typing import List, Optional
from enum import Enum
from pydantic import BaseModel, Field

from ktransformers.server.config.config import Config
from ..base import Object

class CompletionCreate(BaseModel):
    model: str
    prompt: str | List[str]
    stream: bool = False
    temperature: Optional[float] = Field(default=0.6)
    top_p: Optional[float] = Field(default=1)
    max_tokens: Optional[int] = Field(default=50)
    max_completion_tokens: Optional[int] = Field(default=50)
    temperature: Optional[float] = Field(default=Config().temperature)
    top_p: Optional[float] = Field(default=Config().top_p)
    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)

    def get_tokenizer_messages(self):
        if isinstance(self.prompt,List):
@@ -25,19 +25,10 @@ class DataEvaluator:
        """
        # Read the Parquet file
        # dataset = load_dataset('parquet', data_files=file_path)
        ds = load_dataset(file_path,"all")
        df = pd.DataFrame(ds['test'])
        # print(ds)
        # # ds_1 = ds['train']
        # ds_2 = ds['validation']
        # ds_3 = ds['test']
        # # Convert the dataset into a Pandas DataFrame
        # df_test = pd.DataFrame(ds['test'])
        # df_val = pd.DataFrame(ds['validation'])

        # for _, row in df.iterrows():
        #     self.data.append(row.to_dict())
        # df = pd.read_parquet(file_path)
        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
                  'dev': 'all/dev-00000-of-00001.parquet',
                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])

        for _, row in df.iterrows():
            self.data.append(row.to_dict())
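A standalone sketch of the data-loading path switched to above: read one MMLU split straight from the Hugging Face hub as Parquet (assumes pandas plus fsspec/huggingface_hub are installed; column names are those of cais/mmlu):

import pandas as pd

splits = {
    'test': 'all/test-00000-of-00001.parquet',
    'validation': 'all/validation-00000-of-00001.parquet',
    'dev': 'all/dev-00000-of-00001.parquet',
    'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet',
}
df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["validation"])
records = [row.to_dict() for _, row in df.iterrows()]
print(len(records), list(records[0].keys()))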
@@ -8,12 +8,57 @@ from datasets import load_dataset
import os
import concurrent.futures
import threading
import re

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'


def extract_final_answer(text):
    """
    Extract the model's final predicted option (e.g. A/B/C/D).
    Supports natural language, multi-line, markdown, highlighted, and non-final-position conclusions.
    """
    text = text.strip()

    # 1. Explicit statement match (checked first)
    explicit_patterns = [
        r'Answer:\s*([A-D])\b',
        r'Correct answer:\s*([A-D])\b',
        r'The correct answer is\s*\*?\*?\s*([A-D])\b',
        r'Answer is\s*([A-D])\b',
        r'Therefore,\s*answer is\s*([A-D])\b',
        r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
        r'The answer should be\s*(?:Option\s*)?([A-D])\b',
        r'Option\s+([A-D])\s+is correct',
    ]
    for pat in explicit_patterns:
        match = re.search(pat, text, re.IGNORECASE)
        if match:
            return match.group(1).upper()

    # 2. Markdown emphasis: **C**, **C. something**
    markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text)
    if markdown_match:
        return markdown_match[-1].upper()

    # 3. Look for 'C' or "C" inside quotes
    quote_match = re.findall(r"['\"]([A-D])['\"]", text)
    if quote_match:
        return quote_match[-1].upper()

    # 4. Check whether one of the last few lines starts with "C." or "C"
    lines = text.splitlines()
    for line in reversed(lines[-5:]):
        line = line.strip()
        match = re.match(r'^([A-D])([.\s]|$)', line)
        if match:
            return match.group(1).upper()

    # If nothing matches, return None
    return None
class DataEvaluator:
    def __init__(self):
        self.data = []
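A quick usage check of extract_final_answer as defined above (the sample strings are made up; each exercises a different branch):

samples = [
    "Let's reason step by step... The correct answer is **B**.",       # explicit statement
    "Option C is correct because the others contradict the premise.",  # "Option X is correct"
    "A",                                                                # bare letter on the last line
    "I am not sure.",                                                   # nothing extractable
]
for s in samples:
    print(repr(s), "->", extract_final_answer(s))
# expected: B, C, A, None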
@@ -22,8 +67,10 @@ class DataEvaluator:
        """
        Load data from the data file; each record corresponds to one instance.
        """
        ds = load_dataset(file_path, "all")
        df = pd.DataFrame(ds['test'])
        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
                  'dev': 'all/dev-00000-of-00001.parquet',
                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
        for _, row in df.iterrows():
            self.data.append(row.to_dict())
@@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
    start_total_time = time.time()
    total_score = 0
    total_exact_score = 0
    results = []
    file_lock = threading.Lock()
@@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):

    def worker(index, data_item):
        nonlocal total_score
        nonlocal total_exact_score
        question = data_evaluator.get_prompt(data_item)
        start_time = time.time()
        try:
@@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
            answer = chr(data_item['answer'] + 65)
            processed_prediction = data_evaluator.post_processing(prediction)
            score = data_evaluator.score(processed_prediction, answer)
            exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
            elapsed_time = time.time() - start_time
            result_data = {
                "question_id": index,
                "answer": answer,
                "prediction": processed_prediction,
                "real_prediction": prediction,
                "full_prediction": prediction,
                "score": score,
                "exact_score": exact_score,
                "time": elapsed_time
            }
            # Take the lock when writing results to keep it thread-safe
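For reference on the scoring above: the MMLU answer column is an integer index, and chr(data_item['answer'] + 65) turns it into the option letter that the extracted prediction is compared against.

for idx in range(4):
    print(idx, "->", chr(idx + 65))   # 0 -> A, 1 -> B, 2 -> C, 3 -> D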
@@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
                if res is not None:
                    results.append(res)
                    total_score += res['score']
                    total_exact_score += res['exact_score']

    total_time = time.time() - start_total_time
    throughput = len(data_subset) / total_time if total_time > 0 else 0
@@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
        log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
        average_score = total_score / len(data_subset) if data_subset else 0
        log_f.write(f"Average Score: {average_score}\n")
        average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
        log_f.write(f"Average Exact Score: {average_exact_score}\n")
        log_f.write('-' * 40 + '\n')

    print(f"Results saved to {result_file}")
@@ -152,4 +206,4 @@ if __name__ == "__main__":
    data_evaluator = DataEvaluator()
    data_evaluator.load_data(args.file)

    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
@@ -2,23 +2,18 @@ import asyncio
import json
import sys
import aiohttp
import random
import argparse
import yaml
import os
import time
from time import sleep

decodesz = 128
# Server URL (replace with your server URL)
SERVER_URL = "http://localhost:10002/v1/chat/completions"
bf_list = [1]
decodesz_list = [128]
prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
async def fetch_event_stream(session, payload, request_id):
prompt_list = [
    'Please elaborate on modern world history.',
    'Please introduce Harry Potter.',
    'I want to learn Python. Please give me some advice.',
    'Please tell me a joke '
]


async def fetch_event_stream(session, payload, request_id, stream):
    try:

        headers = {
            'accept': 'application/json',
            'Content-Type': 'application/json'
@@ -31,104 +26,80 @@ async def fetch_event_stream(session, payload, request_id):
                print(f"Request {request_id}: Error, status {response.status}")
                return

            output_text = ""  # stores all tokens of the current response
            total_tokens = 0  # total token count
            decode_start_time = None  # decode phase start time
            decode_end_time = None  # decode phase end time
            output_text = ""

            async for line in response.content:
                try:
                    decoded_line = line.decode("utf-8").strip()
            if stream:
                async for line in response.content:
                    try:
                        decoded_line = line.decode("utf-8").strip()
                        if not decoded_line or not decoded_line.startswith("data: "):
                            continue

                    # skip empty lines
                    if not decoded_line or not decoded_line.startswith("data: "):
                        continue
                        decoded_line = decoded_line[6:].strip()
                        if not decoded_line:
                            continue

                    decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
                        response_data = json.loads(decoded_line)
                        choices = response_data.get("choices", [])
                        if not choices:
                            continue

                    # make sure the JSON data is valid
                    if not decoded_line:
                        continue
                        delta = choices[0].get("delta", {})
                        token = delta.get("content", "")

                    response_data = json.loads(decoded_line)  # parse the JSON
                        if token:
                            output_text += token
                            sys.stdout.write(token)
                            sys.stdout.flush()

                    # make sure choices exists
                    choices = response_data.get("choices", [])
                    if not choices:
                        continue
                        finish_reason = choices[0].get("finish_reason", None)
                        if finish_reason:
                            break

                    delta = choices[0].get("delta", {})
                    token = delta.get("content", "")

                    if token:
                        if decode_start_time is None:
                            decode_start_time = time.time()  # record decode start time

                        output_text += token  # append the token
                        sys.stdout.write(token)  # print the token directly
                        sys.stdout.flush()  # flush immediately so the token appears in the terminal right away
                        total_tokens += 1  # increment the token count
                        decode_end_time = time.time()  # update decode end time each time a token arrives

                    # check whether generation has finished
                    finish_reason = choices[0].get("finish_reason", None)
                    if finish_reason:
                        # print(f"\nRequest {request_id}: Done")
                        break  # stop stream processing

                except json.JSONDecodeError as e:
                    print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                except IndexError:
                    print(f"\nRequest {request_id}: List Index Error - choices is empty")
                except Exception as e:
                    print(f"\nRequest {request_id}: Error parsing stream - {e}")

            # compute decode speed
            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
                    except json.JSONDecodeError as e:
                        print(f"\nRequest {request_id}: JSON Decode Error - {e}")
                    except IndexError:
                        print(f"\nRequest {request_id}: List Index Error - choices is empty")
                    except Exception as e:
                        print(f"\nRequest {request_id}: Error parsing stream - {e}")
            else:
                # non-stream mode: receive the complete JSON in one shot
                response_data = await response.json()
                choices = response_data.get("choices", [])
                if choices:
                    content = choices[0].get("message", {}).get("content", "")
                    print(f"Request {request_id} Output:\n{content}")
                    output_text += content

    except Exception as e:
        print(f"\nRequest {request_id}: Exception - {e}")

async def main(prompt_id):
async def main(prompt_id, model, stream, max_tokens, temperature, top_p):
    async with aiohttp.ClientSession() as session:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt_list[prompt_id]}
            ],
            "model": "DeepSeek-V3",
            "stream": True,
            "max_completion_tokens": 2,
            # "temperature": 0.3,
            # "top_p": 1.0,
            # "max_tokens" : 20,
            "model": model,
            "stream": stream,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p
        }
        tasks = [fetch_event_stream(session, payload, prompt_id)]
        await asyncio.gather(*tasks)

        payload["temperature"] = 0.3
        tasks = [fetch_event_stream(session, payload, prompt_id)]
        await asyncio.gather(*tasks)

        payload["top_p"] = 1
        tasks = [fetch_event_stream(session, payload, prompt_id)]
        await asyncio.gather(*tasks)

        payload["max_tokens"] = 200
        tasks = [fetch_event_stream(session, payload, prompt_id)]
        await asyncio.gather(*tasks)

        payload["stream"] = False
        tasks = [fetch_event_stream(session, payload, prompt_id)]
        tasks = [fetch_event_stream(session, payload, prompt_id, stream)]
        await asyncio.gather(*tasks)

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--question_id", type=int, default=0, required=False)
    parser.add_argument("--question_id", type=int, default=0)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--stream", type=bool, default=True)
    parser.add_argument("--max_tokens", type=int, default=500)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_p", type=float, default=1)
    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")

    args = parser.parse_args()
    output_file = "ktransformer_test_results.txt"
    asyncio.run(main(args.question_id))
    SERVER_URL = args.api_url
    asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))
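A caveat on the new --stream flag above, shown as a runnable sketch (not part of the commit): argparse with type=bool applies bool() to the raw string, so any non-empty value, including "False", still enables streaming; BooleanOptionalAction (Python 3.9+) is the usual way around that.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--stream", type=bool, default=True)
print(parser.parse_args(["--stream", "False"]).stream)     # True -- bool("False") is truthy

parser2 = argparse.ArgumentParser()
parser2.add_argument("--stream", action=argparse.BooleanOptionalAction, default=True)
print(parser2.parse_args(["--no-stream"]).stream)           # False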
@@ -12,6 +12,8 @@ from time import sleep
decodesz = 128
# Server URL (replace with your server URL)
decodesz_list = [128]
prefill_speeds = []
decode_speeds = []
ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.
They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills.
He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs.
@@ -43,17 +45,19 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it.
The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
Mr. Dursley always sat with his back to the window in his office on the ninth floor."""
async def fetch_event_stream(session, request_id, prompt):
async def fetch_event_stream(session, request_id, prompt, max_tokens, model):
    try:
        payload = {
            "messages": [
                {"role": "system", "content": ""},
                {"role": "user", "content": prompt}
            ],
            "model": "DeepSeek-V3",
            "model": model,
            "temperature": 0.3,
            "top_p": 1.0,
            "stream": True
            "stream": True,
            "return_speed": True,
            "max_tokens": max_tokens,
        }

        headers = {
@@ -70,6 +74,7 @@ async def fetch_event_stream(session, request_id, prompt):
            total_tokens = 0
            decode_start_time = None
            decode_end_time = None
            usage_info = None

            async for line in response.content:
                try:
@@ -82,6 +87,10 @@ async def fetch_event_stream(session, request_id, prompt):
                        continue

                    response_data = json.loads(decoded_line)

                    if "usage" in response_data:
                        usage_info = response_data["usage"]

                    choices = response_data.get("choices", [])
                    if not choices:
                        continue
@@ -107,34 +116,48 @@ async def fetch_event_stream(session, request_id, prompt):
                except Exception as e:
                    print(f"[Request {request_id}] Stream Error: {e}")


            if buffer.strip():
                print(f"[Request {request_id}] {buffer.strip()}")

            if decode_start_time and decode_end_time and total_tokens > 0:
                decode_time = decode_end_time - decode_start_time
                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
                print(f"[Request {request_id}] Speed: {decode_speed:.2f} tokens/s")
            if usage_info:
                if "prefill_time" in usage_info:
                    # print(f"[Request {request_id}] Usage:")
                    # for key, value in usage_info.items():
                    #     print(f" {key}: {value}")
                    prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"]
                    decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"]
                    prefill_speeds.append(prefill_speed)
                    decode_speeds.append(decode_speed)
                    print(f'[Request {request_id}] prefill speed: {prefill_speed}')
                    print(f'[Request {request_id}] decode speed: {decode_speed}')

    except Exception as e:
        print(f"[Request {request_id}] Exception: {e}")

async def main(concurrent_requests , prompt ):
async def main(concurrent_requests , prompt, max_tokens, model):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)]
        tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)]
        await asyncio.gather(*tasks)
        if len(prefill_speeds) != 0:
            import numpy as np
            print(f"concurrency: {len(prefill_speeds)}")
            print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Event Stream Request Tester")
    parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
    parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
    parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")

    args = parser.parse_args()
    SERVER_URL = args.api_url
    max_tokens = args.max_tokens
    model = args.model
    if args.prompt_lens == 1024:
        prompt = ktansformer_prompt1024
    elif args.prompt_lens == 2048:
        prompt = ktansformer_prompt1024 * 2
    asyncio.run(main(args.concurrent, prompt))
    asyncio.run(main(args.concurrent, prompt, max_tokens, model))
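A worked example of the speed aggregation above, with made-up usage numbers for two concurrent requests (times in seconds, as reported when return_speed is set):

import numpy as np

usages = [
    {"prompt_tokens": 1024, "prefill_time": 2.0, "completion_tokens": 50, "decode_time": 5.0},
    {"prompt_tokens": 1024, "prefill_time": 2.5, "completion_tokens": 50, "decode_time": 4.0},
]
prefill_speeds = [u["prompt_tokens"] / u["prefill_time"] for u in usages]     # [512.0, 409.6]
decode_speeds = [u["completion_tokens"] / u["decode_time"] for u in usages]   # [10.0, 12.5]
print("total prefill speed:", np.sum(prefill_speeds))   # 921.6 tokens/s
print("total decode speed:", np.sum(decode_speeds))     # 22.5 tokens/s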