From b17ab8653c20592ed42f253341be7205fcc14c08 Mon Sep 17 00:00:00 2001 From: qiyuxinlin <1668068727@qq.com> Date: Tue, 22 Apr 2025 07:38:05 +0000 Subject: [PATCH 1/3] update speed test --- .../server/api/openai/endpoints/chat.py | 25 +++++++----- .../server/schemas/endpoints/chat.py | 22 ++++++---- .../server/schemas/legacy/completions.py | 10 ++--- ktransformers/tests/test_speed.py | 40 ++++++++++++++----- 4 files changed, 66 insertions(+), 31 deletions(-) diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py index a5eb986..ea1e815 100644 --- a/ktransformers/server/api/openai/endpoints/chat.py +++ b/ktransformers/server/api/openai/endpoints/chat.py @@ -13,16 +13,10 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.config.config import Config from ktransformers.server.config.log import logger - -from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk +from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage # Define own data structure instead of importing from OpenAI -class CompletionUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - prompt_tokens_details: Optional[Dict[str, Any]] = None - completion_tokens_details: Optional[Dict[str, Any]] = None + class Choice(BaseModel): index: int @@ -217,6 +211,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): completion_tokens=raw_usage.decode_count, total_tokens=raw_usage.prefill_count + raw_usage.decode_count ) + if create.return_speed: + chunk.usage.prefill_time = res.prefill_time + chunk.usage.decode_time = res.decode_time + else: + chunk.usage.__dict__.pop('prefill_time', None) + chunk.usage.__dict__.pop('decode_time', None) yield chunk elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res @@ -377,8 +377,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): usage = CompletionUsage( prompt_tokens=raw_usage.prefill_count, completion_tokens=raw_usage.decode_count, - total_tokens=raw_usage.prefill_count + raw_usage.decode_count + total_tokens=raw_usage.prefill_count + raw_usage.decode_count, ) + if create.return_speed: + usage.prefill_time = res.prefill_time + usage.decode_time = res.decode_time + else: + usage.__dict__.pop('prefill_time', None) + usage.__dict__.pop('decode_time', None) + elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token) diff --git a/ktransformers/server/schemas/endpoints/chat.py b/ktransformers/server/schemas/endpoints/chat.py index 643c81c..d37e342 100644 --- a/ktransformers/server/schemas/endpoints/chat.py +++ b/ktransformers/server/schemas/endpoints/chat.py @@ -2,14 +2,22 @@ from typing import List, Optional, Union, Dict, Any from typing_extensions import Literal from enum import Enum from pydantic import BaseModel, Field - +from ktransformers.server.config.config import Config from ktransformers.server.schemas.base import Object -from openai.types.completion_usage import CompletionUsage + from openai.types.chat.chat_completion_chunk import Choice from uuid import uuid4 +class CompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + prompt_tokens_details: Optional[Dict[str, Any]] = None + completion_tokens_details: Optional[Dict[str, Any]] = None + prefill_time: Optional[float] = None + decode_time: Optional[float] = None class Role(Enum): system = 'system' @@ -58,16 +66,16 @@ class ChatCompletionCreate(BaseModel): messages: List[Message] model: str stream: bool = False - temperature: Optional[float] = Field(default=0.6) - top_p: Optional[float] = Field(default=1.0) + temperature: Optional[float] = Field(default=Config().temperature) + top_p: Optional[float] = Field(default=Config().top_p) tools: Optional[List[Tool]] = None tool_choice: Optional[Union[str, Dict[str, Any]]] = None stream_options: Optional[Dict[str, Any]] = None frequency_penalty: float = 0 presence_penalty: float = 0 - max_tokens: Optional[int] = Field(default=50) - max_completion_tokens: Optional[int] = Field(default=50) - + max_tokens: Optional[int] = Field(default=Config().max_new_tokens) + max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens) + return_speed: Optional[bool] = Field(default=False) def get_tokenizer_messages(self): return [m.to_tokenizer_message() for m in self.messages] diff --git a/ktransformers/server/schemas/legacy/completions.py b/ktransformers/server/schemas/legacy/completions.py index 2d83212..a728cb1 100644 --- a/ktransformers/server/schemas/legacy/completions.py +++ b/ktransformers/server/schemas/legacy/completions.py @@ -1,17 +1,17 @@ from typing import List, Optional from enum import Enum from pydantic import BaseModel, Field - +from ktransformers.server.config.config import Config from ..base import Object class CompletionCreate(BaseModel): model: str prompt: str | List[str] stream: bool = False - temperature: Optional[float] = Field(default=0.6) - top_p: Optional[float] = Field(default=1) - max_tokens: Optional[int] = Field(default=50) - max_completion_tokens: Optional[int] = Field(default=50) + temperature: Optional[float] = Field(default=Config().temperature) + top_p: Optional[float] = Field(default=Config().top_p) + max_tokens: Optional[int] = Field(default=Config().max_new_tokens) + max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens) def get_tokenizer_messages(self): if isinstance(self.prompt,List): diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py index dbdf999..3e7f849 100644 --- a/ktransformers/tests/test_speed.py +++ b/ktransformers/tests/test_speed.py @@ -12,6 +12,8 @@ from time import sleep decodesz = 128 # Server URL (replace with your server URL) decodesz_list = [128] +prefill_speeds = [] +decode_speeds = [] ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. @@ -43,7 +45,7 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. Dursley always sat with his back to the window in his office on the ninth floor.""" -async def fetch_event_stream(session, request_id, prompt): +async def fetch_event_stream(session, request_id, prompt, max_tokens): try: payload = { "messages": [ @@ -53,7 +55,9 @@ async def fetch_event_stream(session, request_id, prompt): "model": "DeepSeek-V3", "temperature": 0.3, "top_p": 1.0, - "stream": True + "stream": True, + "return_speed": True, + "max_tokens": max_tokens, } headers = { @@ -70,6 +74,7 @@ async def fetch_event_stream(session, request_id, prompt): total_tokens = 0 decode_start_time = None decode_end_time = None + usage_info = None async for line in response.content: try: @@ -82,6 +87,10 @@ async def fetch_event_stream(session, request_id, prompt): continue response_data = json.loads(decoded_line) + + if "usage" in response_data: + usage_info = response_data["usage"] + choices = response_data.get("choices", []) if not choices: continue @@ -107,34 +116,45 @@ async def fetch_event_stream(session, request_id, prompt): except Exception as e: print(f"[Request {request_id}] Stream Error: {e}") - if buffer.strip(): print(f"[Request {request_id}] {buffer.strip()}") - if decode_start_time and decode_end_time and total_tokens > 0: - decode_time = decode_end_time - decode_start_time - decode_speed = total_tokens / decode_time if decode_time > 0 else 0 - print(f"[Request {request_id}] Speed: {decode_speed:.2f} tokens/s") + if usage_info: + if "prefill_time" in usage_info: + # print(f"[Request {request_id}] Usage:") + # for key, value in usage_info.items(): + # print(f" {key}: {value}") + prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"] + decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"] + prefill_speeds.append(prefill_speed) + decode_speeds.append(decode_speed) + print(f'[Request {request_id}] prefill speed: {prefill_speed}') + print(f'[Request {request_id}] decode speed: {decode_speed}') except Exception as e: print(f"[Request {request_id}] Exception: {e}") -async def main(concurrent_requests , prompt ): +async def main(concurrent_requests , prompt, max_tokens): async with aiohttp.ClientSession() as session: - tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)] + tasks = [fetch_event_stream(session, i , prompt, max_tokens) for i in range(concurrent_requests)] await asyncio.gather(*tasks) + if len(prefill_speeds) != 0: + import numpy as np + print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Event Stream Request Tester") parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") + parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") args = parser.parse_args() SERVER_URL = args.api_url + max_tokens = args.max_tokens if args.prompt_lens == 1024: prompt = ktansformer_prompt1024 elif args.prompt_lens == 2048: prompt = ktansformer_prompt1024 * 2 - asyncio.run(main(args.concurrent, prompt)) + asyncio.run(main(args.concurrent, prompt, max_tokens)) From 4f9950e30c73b3f39f9d533aa71cbd88ce9c8acd Mon Sep 17 00:00:00 2001 From: qiyuxinlin <1668068727@qq.com> Date: Tue, 22 Apr 2025 09:25:44 +0000 Subject: [PATCH 2/3] kill serve lead to kill sched and engine --- .../backend/interfaces/balance_serve.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py index 74c680d..008431e 100644 --- a/ktransformers/server/backend/interfaces/balance_serve.py +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -46,6 +46,8 @@ import pickle import subprocess import tempfile import atexit +import signal + ktransformer_rules_dir = ( os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/") @@ -55,6 +57,7 @@ default_optimize_rules = { "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml", } + async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer): streamer = TextStreamer(tokenizer) while True: @@ -293,10 +296,6 @@ class BalanceServeInterface(BackendInterfaceBase): kvcache_event.wait() - def cleanup(): - if sched_process.poll() is None: - sched_process.terminate() - with tempfile.NamedTemporaryFile(delete=False) as temp_file: pickle.dump(args, temp_file) temp_file_path = temp_file.name @@ -311,7 +310,27 @@ class BalanceServeInterface(BackendInterfaceBase): stderr=log ) print("sched_rpc started with PID:", sched_process.pid) - atexit.register(cleanup) + + def signal_handler(signum, frame): + print(f"Received signal {signum}, shutting down...") + cleanup() + os._exit(0) + + def cleanup(): + print("Cleaning up...") + + for p in processes: + if p.is_alive(): + print(f"Terminating subprocess {p.pid}") + p.terminate() + p.join() + + if sched_process and sched_process.poll() is None: + print(f"Terminating sched_process {sched_process.pid}") + sched_process.terminate() + sched_process.wait() + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) start_event.wait() From 3a044e6b1418df00ef8352e2dfe037cda32e7e57 Mon Sep 17 00:00:00 2001 From: qiyuxinlin <1668068727@qq.com> Date: Tue, 22 Apr 2025 12:50:39 +0000 Subject: [PATCH 3/3] change test --- ktransformers/tests/mmlu_test.py | 17 +-- ktransformers/tests/mmlu_test_multi.py | 62 +++++++++- ktransformers/tests/test_client.py | 155 ++++++++++--------------- ktransformers/tests/test_speed.py | 15 ++- 4 files changed, 134 insertions(+), 115 deletions(-) diff --git a/ktransformers/tests/mmlu_test.py b/ktransformers/tests/mmlu_test.py index 452cbbf..36baada 100644 --- a/ktransformers/tests/mmlu_test.py +++ b/ktransformers/tests/mmlu_test.py @@ -25,19 +25,10 @@ class DataEvaluator: """ # 读取 Parquet 文件 # dataset = load_dataset('parquet', data_files=file_path) - ds = load_dataset(file_path,"all") - df = pd.DataFrame(ds['test']) - # print(ds) - # # ds_1 = ds['train'] - # ds_2 = ds['validation'] - # ds_3 = ds['test'] - # # 将数据集转换为 Pandas DataFrame - # df_test = pd.DataFrame(ds['test']) - # df_val = pd.DataFrame(ds['validation']) - - # for _, row in df.iterrows(): - # self.data.append(row.to_dict()) - # df = pd.read_parquet(file_path) + splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet', + 'dev': 'all/dev-00000-of-00001.parquet', + 'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'} + df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"]) for _, row in df.iterrows(): self.data.append(row.to_dict()) diff --git a/ktransformers/tests/mmlu_test_multi.py b/ktransformers/tests/mmlu_test_multi.py index 06c75ab..9033afe 100644 --- a/ktransformers/tests/mmlu_test_multi.py +++ b/ktransformers/tests/mmlu_test_multi.py @@ -8,12 +8,57 @@ from datasets import load_dataset import os import concurrent.futures import threading +import re os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' os.environ['https_proxy'] = '' os.environ['http_proxy'] = '' hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.' + +def extract_final_answer(text): + """ + 提取模型预测的最终选项(如 A/B/C/D) + 支持自然语言、多行、markdown、高亮、非末尾结论等格式 + """ + text = text.strip() + + # 1. 显式语句匹配(优先) + explicit_patterns = [ + r'Answer:\s*([A-D])\b', + r'Correct answer:\s*([A-D])\b', + r'The correct answer is\s*\*?\*?\s*([A-D])\b', + r'Answer is\s*([A-D])\b', + r'Therefore,\s*answer is\s*([A-D])\b', + r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b', + r'The answer should be\s*(?:Option\s*)?([A-D])\b', + r'Option\s+([A-D])\s+is correct', + ] + for pat in explicit_patterns: + match = re.search(pat, text, re.IGNORECASE) + if match: + return match.group(1).upper() + + # 2. markdown 强调 **C**, **C. something** + markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text) + if markdown_match: + return markdown_match[-1].upper() + + # 3. 查找单引号中的 'C' 或 "C" + quote_match = re.findall(r"['\"]([A-D])['\"]", text) + if quote_match: + return quote_match[-1].upper() + + # 4. 倒数几行是否以 "C." 或 "C" 开头 + lines = text.splitlines() + for line in reversed(lines[-5:]): + line = line.strip() + match = re.match(r'^([A-D])([.\s]|$)', line) + if match: + return match.group(1).upper() + + # 再不行就返回 None + return None class DataEvaluator: def __init__(self): self.data = [] @@ -22,8 +67,10 @@ class DataEvaluator: """ 从数据文件中加载数据,每条记录对应一个实例 """ - ds = load_dataset(file_path, "all") - df = pd.DataFrame(ds['test']) + splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet', + 'dev': 'all/dev-00000-of-00001.parquet', + 'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'} + df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"]) for _, row in df.iterrows(): self.data.append(row.to_dict()) @@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False): def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name): start_total_time = time.time() total_score = 0 + total_exact_score = 0 results = [] file_lock = threading.Lock() @@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi def worker(index, data_item): nonlocal total_score + nonlocal total_exact_score question = data_evaluator.get_prompt(data_item) start_time = time.time() try: @@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi answer = chr(data_item['answer'] + 65) processed_prediction = data_evaluator.post_processing(prediction) score = data_evaluator.score(processed_prediction, answer) + exact_score = data_evaluator.score(extract_final_answer(prediction), answer) elapsed_time = time.time() - start_time result_data = { "question_id": index, "answer": answer, "prediction": processed_prediction, - "real_prediction": prediction, + "full_prediction": prediction, "score": score, + "exact_score": exact_score, "time": elapsed_time } # 写入结果时加锁保证线程安全 @@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi if res is not None: results.append(res) total_score += res['score'] + total_exact_score += res['exact_score'] total_time = time.time() - start_total_time throughput = len(data_subset) / total_time if total_time > 0 else 0 @@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi log_f.write(f"Throughput: {throughput:.2f} requests per second\n") average_score = total_score / len(data_subset) if data_subset else 0 log_f.write(f"Average Score: {average_score}\n") + average_exact_score = total_exact_score / len(data_subset) if data_subset else 0 + log_f.write(f"Average Exact Score: {average_exact_score}\n") log_f.write('-' * 40 + '\n') print(f"Results saved to {result_file}") @@ -152,4 +206,4 @@ if __name__ == "__main__": data_evaluator = DataEvaluator() data_evaluator.load_data(args.file) - main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model) + main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model) \ No newline at end of file diff --git a/ktransformers/tests/test_client.py b/ktransformers/tests/test_client.py index 1f6b684..4ad560b 100644 --- a/ktransformers/tests/test_client.py +++ b/ktransformers/tests/test_client.py @@ -2,23 +2,18 @@ import asyncio import json import sys import aiohttp -import random import argparse -import yaml -import os -import time -from time import sleep -decodesz = 128 -# Server URL (replace with your server URL) -SERVER_URL = "http://localhost:10002/v1/chat/completions" -bf_list = [1] -decodesz_list = [128] -prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke '] -async def fetch_event_stream(session, payload, request_id): +prompt_list = [ + 'Please elaborate on modern world history.', + 'Please introduce Harry Potter.', + 'I want to learn Python. Please give me some advice.', + 'Please tell me a joke ' +] + + +async def fetch_event_stream(session, payload, request_id, stream): try: - - headers = { 'accept': 'application/json', 'Content-Type': 'application/json' @@ -31,104 +26,80 @@ async def fetch_event_stream(session, payload, request_id): print(f"Request {request_id}: Error, status {response.status}") return - output_text = "" # 存储当前 response 的所有 token - total_tokens = 0 # 统计总 tokens 数 - decode_start_time = None # 记录 decode 阶段开始时间 - decode_end_time = None # 记录 decode 结束时间 + output_text = "" - async for line in response.content: - try: - decoded_line = line.decode("utf-8").strip() + if stream: + async for line in response.content: + try: + decoded_line = line.decode("utf-8").strip() + if not decoded_line or not decoded_line.startswith("data: "): + continue - # 过滤空行 - if not decoded_line or not decoded_line.startswith("data: "): - continue + decoded_line = decoded_line[6:].strip() + if not decoded_line: + continue - decoded_line = decoded_line[6:].strip() # 去掉 `data: ` + response_data = json.loads(decoded_line) + choices = response_data.get("choices", []) + if not choices: + continue - # 确保 JSON 数据是合法的 - if not decoded_line: - continue + delta = choices[0].get("delta", {}) + token = delta.get("content", "") - response_data = json.loads(decoded_line) # 解析 JSON + if token: + output_text += token + sys.stdout.write(token) + sys.stdout.flush() - # 确保 choices 存在 - choices = response_data.get("choices", []) - if not choices: - continue + finish_reason = choices[0].get("finish_reason", None) + if finish_reason: + break - delta = choices[0].get("delta", {}) - token = delta.get("content", "") - - if token: - if decode_start_time is None: - decode_start_time = time.time() # 记录 decode 开始时间 - - output_text += token # 追加 token - sys.stdout.write(token) # 直接输出 token - sys.stdout.flush() # 立即刷新,确保 token 立刻出现在终端 - total_tokens += 1 # 增加 token 计数 - decode_end_time = time.time() # 每次收到 token,更新 decode 结束时间 - - # 检查是否完成 - finish_reason = choices[0].get("finish_reason", None) - if finish_reason: - # print(f"\nRequest {request_id}: Done") - break # 结束流式处理 - - except json.JSONDecodeError as e: - print(f"\nRequest {request_id}: JSON Decode Error - {e}") - except IndexError: - print(f"\nRequest {request_id}: List Index Error - choices is empty") - except Exception as e: - print(f"\nRequest {request_id}: Error parsing stream - {e}") - - # 计算 decode 速度 - if decode_start_time and decode_end_time and total_tokens > 0: - decode_time = decode_end_time - decode_start_time - decode_speed = total_tokens / decode_time if decode_time > 0 else 0 - # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s") + except json.JSONDecodeError as e: + print(f"\nRequest {request_id}: JSON Decode Error - {e}") + except IndexError: + print(f"\nRequest {request_id}: List Index Error - choices is empty") + except Exception as e: + print(f"\nRequest {request_id}: Error parsing stream - {e}") + else: + # 非 stream 模式下,一次性接收完整 json + response_data = await response.json() + choices = response_data.get("choices", []) + if choices: + content = choices[0].get("message", {}).get("content", "") + print(f"Request {request_id} Output:\n{content}") + output_text += content except Exception as e: print(f"\nRequest {request_id}: Exception - {e}") -async def main(prompt_id): +async def main(prompt_id, model, stream, max_tokens, temperature, top_p): async with aiohttp.ClientSession() as session: payload = { "messages": [ {"role": "system", "content": ""}, {"role": "user", "content": prompt_list[prompt_id]} ], - "model": "DeepSeek-V3", - "stream": True, - "max_completion_tokens": 2, - # "temperature": 0.3, - # "top_p": 1.0, - # "max_tokens" : 20, + "model": model, + "stream": stream, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p } - tasks = [fetch_event_stream(session, payload, prompt_id)] - await asyncio.gather(*tasks) - - payload["temperature"] = 0.3 - tasks = [fetch_event_stream(session, payload, prompt_id)] - await asyncio.gather(*tasks) - - payload["top_p"] = 1 - tasks = [fetch_event_stream(session, payload, prompt_id)] - await asyncio.gather(*tasks) - - payload["max_tokens"] = 200 - tasks = [fetch_event_stream(session, payload, prompt_id)] - await asyncio.gather(*tasks) - - payload["stream"] = False - tasks = [fetch_event_stream(session, payload, prompt_id)] + tasks = [fetch_event_stream(session, payload, prompt_id, stream)] await asyncio.gather(*tasks) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Event Stream Request Tester") - parser.add_argument("--question_id", type=int, default=0, required=False) + parser.add_argument("--question_id", type=int, default=0) + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--stream", type=bool, default=True) + parser.add_argument("--max_tokens", type=int, default=500) + parser.add_argument("--temperature", type=float, default=0.8) + parser.add_argument("--top_p", type=float, default=1) + parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL") + args = parser.parse_args() - output_file = "ktransformer_test_results.txt" - asyncio.run(main(args.question_id)) + SERVER_URL = args.api_url + asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p)) diff --git a/ktransformers/tests/test_speed.py b/ktransformers/tests/test_speed.py index 3e7f849..8b552b5 100644 --- a/ktransformers/tests/test_speed.py +++ b/ktransformers/tests/test_speed.py @@ -45,14 +45,14 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills. Mr. Dursley always sat with his back to the window in his office on the ninth floor.""" -async def fetch_event_stream(session, request_id, prompt, max_tokens): +async def fetch_event_stream(session, request_id, prompt, max_tokens, model): try: payload = { "messages": [ {"role": "system", "content": ""}, {"role": "user", "content": prompt} ], - "model": "DeepSeek-V3", + "model": model, "temperature": 0.3, "top_p": 1.0, "stream": True, @@ -134,17 +134,19 @@ async def fetch_event_stream(session, request_id, prompt, max_tokens): except Exception as e: print(f"[Request {request_id}] Exception: {e}") -async def main(concurrent_requests , prompt, max_tokens): +async def main(concurrent_requests , prompt, max_tokens, model): async with aiohttp.ClientSession() as session: - tasks = [fetch_event_stream(session, i , prompt, max_tokens) for i in range(concurrent_requests)] + tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)] await asyncio.gather(*tasks) if len(prefill_speeds) != 0: import numpy as np - print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}") + print(f"concurrency: {len(prefill_speeds)}") + print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Event Stream Request Tester") parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests") + parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True) parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048") parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL") parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens") @@ -152,9 +154,10 @@ if __name__ == "__main__": args = parser.parse_args() SERVER_URL = args.api_url max_tokens = args.max_tokens + model = args.model if args.prompt_lens == 1024: prompt = ktansformer_prompt1024 elif args.prompt_lens == 2048: prompt = ktansformer_prompt1024 * 2 - asyncio.run(main(args.concurrent, prompt, max_tokens)) + asyncio.run(main(args.concurrent, prompt, max_tokens, model))