Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 22:05:30 +00:00)

Commit f7d939313b: Merge remote-tracking branch 'origin/main' into check-para

8 changed files with 219 additions and 145 deletions
@@ -14,15 +14,10 @@ from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
 from ktransformers.server.config.log import logger
 from fastapi.responses import JSONResponse
-from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
+from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage
 
 # Define own data structure instead of importing from OpenAI
-class CompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-    prompt_tokens_details: Optional[Dict[str, Any]] = None
-    completion_tokens_details: Optional[Dict[str, Any]] = None
 
 class Choice(BaseModel):
     index: int

@@ -267,6 +262,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                     completion_tokens=raw_usage.decode_count,
                     total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                 )
+                if create.return_speed:
+                    chunk.usage.prefill_time = res.prefill_time
+                    chunk.usage.decode_time = res.decode_time
+                else:
+                    chunk.usage.__dict__.pop('prefill_time', None)
+                    chunk.usage.__dict__.pop('decode_time', None)
                 yield chunk
             elif isinstance(res, tuple) and len(res) == 2:
                 token, finish_reason = res

@@ -427,8 +428,15 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                 usage = CompletionUsage(
                     prompt_tokens=raw_usage.prefill_count,
                     completion_tokens=raw_usage.decode_count,
-                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
                 )
+                if create.return_speed:
+                    usage.prefill_time = res.prefill_time
+                    usage.decode_time = res.decode_time
+                else:
+                    usage.__dict__.pop('prefill_time', None)
+                    usage.__dict__.pop('decode_time', None)
+
             elif isinstance(res, tuple) and len(res) == 2:
                 token, finish_reason = res
                 token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)
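Taken together, these hunks move CompletionUsage into the shared schemas module and let clients opt into timing data through the new return_speed flag: when it is set, the usage payload carries prefill_time and decode_time alongside the token counts. A minimal client-side sketch of turning such a payload into throughput numbers (the field names come from this diff; the helper and sample values are illustrative):

# Hypothetical helper: derive speeds from a usage dict returned with return_speed=True.
def speeds_from_usage(usage: dict) -> tuple[float, float]:
    prefill_speed = usage["prompt_tokens"] / usage["prefill_time"]    # tokens/s during prefill
    decode_speed = usage["completion_tokens"] / usage["decode_time"]  # tokens/s during decode
    return prefill_speed, decode_speed

# Example: a usage payload as it might appear in the final chunk.
usage = {"prompt_tokens": 1024, "completion_tokens": 256,
         "total_tokens": 1280, "prefill_time": 0.8, "decode_time": 12.5}
print(speeds_from_usage(usage))  # (1280.0, 20.48)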
@@ -46,6 +46,8 @@ import pickle
 import subprocess
 import tempfile
 import atexit
+import signal
+
 
 ktransformer_rules_dir = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")

@@ -55,6 +57,7 @@ default_optimize_rules = {
     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
 }
 
+
 async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
     streamer = TextStreamer(tokenizer)
     while True:

@@ -293,10 +296,6 @@ class BalanceServeInterface(BackendInterfaceBase):
         kvcache_event.wait()
 
 
-        def cleanup():
-            if sched_process.poll() is None:
-                sched_process.terminate()
-
         with tempfile.NamedTemporaryFile(delete=False) as temp_file:
             pickle.dump(args, temp_file)
             temp_file_path = temp_file.name

@@ -311,7 +310,27 @@ class BalanceServeInterface(BackendInterfaceBase):
             stderr=log
         )
         print("sched_rpc started with PID:", sched_process.pid)
-        atexit.register(cleanup)
+
+        def signal_handler(signum, frame):
+            print(f"Received signal {signum}, shutting down...")
+            cleanup()
+            os._exit(0)
+
+        def cleanup():
+            print("Cleaning up...")
+
+            for p in processes:
+                if p.is_alive():
+                    print(f"Terminating subprocess {p.pid}")
+                    p.terminate()
+                    p.join()
+
+            if sched_process and sched_process.poll() is None:
+                print(f"Terminating sched_process {sched_process.pid}")
+                sched_process.terminate()
+                sched_process.wait()
+
+        signal.signal(signal.SIGINT, signal_handler)
+        signal.signal(signal.SIGTERM, signal_handler)
+
         start_event.wait()
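The shutdown path here moves from a single atexit hook to explicit SIGINT/SIGTERM handlers, which matters because atexit callbacks do not run when a process is killed by an unhandled SIGTERM. A standalone sketch of the same pattern, with illustrative process names rather than the repo's:

import os
import signal
import subprocess
from multiprocessing import Process

workers: list[Process] = []                  # worker processes (illustrative)
sched = subprocess.Popen(["sleep", "1000"])  # stand-in for the sched_rpc subprocess

def cleanup():
    for p in workers:
        if p.is_alive():
            p.terminate()
            p.join()
    if sched.poll() is None:  # scheduler still running?
        sched.terminate()
        sched.wait()

def handler(signum, frame):
    cleanup()
    os._exit(0)  # hard exit after cleanup, mirroring the diff

signal.signal(signal.SIGINT, handler)
signal.signal(signal.SIGTERM, handler)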
@@ -2,14 +2,22 @@ from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config
 from ktransformers.server.schemas.base import Object
 
-from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion_chunk import Choice
 
 from uuid import uuid4
 
+
+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+    completion_tokens_details: Optional[Dict[str, Any]] = None
+    prefill_time: Optional[float] = None
+    decode_time: Optional[float] = None
+
 class Role(Enum):
     system = 'system'

@@ -58,16 +66,16 @@ class ChatCompletionCreate(BaseModel):
     messages: List[Message]
     model: str
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
+    temperature: Optional[float] = Field(default=Config().temperature)
-    top_p: Optional[float] = Field(default=1.0)
+    top_p: Optional[float] = Field(default=Config().top_p)
     tools: Optional[List[Tool]] = None
     tool_choice: Optional[Union[str, Dict[str, Any]]] = None
     stream_options: Optional[Dict[str, Any]] = None
     frequency_penalty: float = 0
     presence_penalty: float = 0
-    max_tokens: Optional[int] = Field(default=50)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
+    return_speed: Optional[bool] = Field(default=False)
 
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
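One thing to watch with defaults like Field(default=Config().temperature): the default is computed once, when the class is defined at import time. If Config values can change over a process's lifetime, Pydantic's default_factory defers the lookup to instantiation time instead. A minimal sketch of the difference, with a stand-in Config rather than the repo's:

from typing import Optional
from pydantic import BaseModel, Field

class Config:  # stand-in for ktransformers' Config singleton
    temperature = 0.6

class ChatCreate(BaseModel):
    # evaluated once, at class definition:
    temperature_static: Optional[float] = Field(default=Config().temperature)
    # evaluated on every instantiation:
    temperature_lazy: Optional[float] = Field(default_factory=lambda: Config().temperature)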
@@ -1,17 +1,17 @@
 from typing import List, Optional
 from enum import Enum
 from pydantic import BaseModel, Field
+from ktransformers.server.config.config import Config
 from ..base import Object
 
 class CompletionCreate(BaseModel):
     model: str
     prompt: str | List[str]
     stream: bool = False
-    temperature: Optional[float] = Field(default=0.6)
+    temperature: Optional[float] = Field(default=Config().temperature)
-    top_p: Optional[float] = Field(default=1)
+    top_p: Optional[float] = Field(default=Config().top_p)
-    max_tokens: Optional[int] = Field(default=50)
+    max_tokens: Optional[int] = Field(default=Config().max_new_tokens)
-    max_completion_tokens: Optional[int] = Field(default=50)
+    max_completion_tokens: Optional[int] = Field(default=Config().max_new_tokens)
 
     def get_tokenizer_messages(self):
         if isinstance(self.prompt,List):
@@ -25,19 +25,10 @@ class DataEvaluator:
         """
         # Read the Parquet file
         # dataset = load_dataset('parquet', data_files=file_path)
-        ds = load_dataset(file_path,"all")
-        df = pd.DataFrame(ds['test'])
-        # print(ds)
-        # # ds_1 = ds['train']
-        # ds_2 = ds['validation']
-        # ds_3 = ds['test']
-        # # Convert the dataset to a Pandas DataFrame
-        # df_test = pd.DataFrame(ds['test'])
-        # df_val = pd.DataFrame(ds['validation'])
-
-        # for _, row in df.iterrows():
-        #     self.data.append(row.to_dict())
-        # df = pd.read_parquet(file_path)
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
 
         for _, row in df.iterrows():
             self.data.append(row.to_dict())
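Both evaluators now read MMLU straight from the Hub via pandas' hf:// protocol, which resolves through huggingface_hub's fsspec integration; that means huggingface_hub (and pyarrow for parquet) must be installed alongside pandas. Assuming those packages, loading the test split reduces to:

import pandas as pd

# Requires: pip install pandas pyarrow huggingface_hub
df = pd.read_parquet("hf://datasets/cais/mmlu/all/test-00000-of-00001.parquet")
print(df.columns.tolist())  # e.g. question / subject / choices / answer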
@@ -8,12 +8,57 @@ from datasets import load_dataset
 import os
 import concurrent.futures
 import threading
+import re
 
 os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
 os.environ['https_proxy'] = ''
 os.environ['http_proxy'] = ''
 hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
 
+
+def extract_final_answer(text):
+    """
+    Extract the model's final predicted option (e.g. A/B/C/D).
+    Handles natural language, multi-line, markdown, highlighted, and non-final-position conclusions.
+    """
+    text = text.strip()
+
+    # 1. Explicit answer statements (checked first)
+    explicit_patterns = [
+        r'Answer:\s*([A-D])\b',
+        r'Correct answer:\s*([A-D])\b',
+        r'The correct answer is\s*\*?\*?\s*([A-D])\b',
+        r'Answer is\s*([A-D])\b',
+        r'Therefore,\s*answer is\s*([A-D])\b',
+        r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'The answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'Option\s+([A-D])\s+is correct',
+    ]
+    for pat in explicit_patterns:
+        match = re.search(pat, text, re.IGNORECASE)
+        if match:
+            return match.group(1).upper()
+
+    # 2. Markdown emphasis: **C**, **C. something**
+    markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text)
+    if markdown_match:
+        return markdown_match[-1].upper()
+
+    # 3. A quoted 'C' or "C"
+    quote_match = re.findall(r"['\"]([A-D])['\"]", text)
+    if quote_match:
+        return quote_match[-1].upper()
+
+    # 4. One of the last few lines starting with "C." or "C"
+    lines = text.splitlines()
+    for line in reversed(lines[-5:]):
+        line = line.strip()
+        match = re.match(r'^([A-D])([.\s]|$)', line)
+        if match:
+            return match.group(1).upper()
+
+    # Fall back to None if nothing matched
+    return None
+
 class DataEvaluator:
     def __init__(self):
         self.data = []
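The new extract_final_answer is what backs the "exact score" introduced below; a quick illustration of its tiers, with outputs following from the regexes above:

print(extract_final_answer("Let us reason it out. The correct answer is **B**."))  # 'B' (explicit statement)
print(extract_final_answer("Some analysis...\n\nC. Paris"))                        # 'C' (last-lines fallback)
print(extract_final_answer("I am not sure."))                                      # None (no tier matches)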
@@ -22,8 +67,10 @@ class DataEvaluator:
         """
         Load data from the data file; each record corresponds to one instance
         """
-        ds = load_dataset(file_path, "all")
-        df = pd.DataFrame(ds['test'])
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
         for _, row in df.iterrows():
             self.data.append(row.to_dict())

@@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
 def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
     start_total_time = time.time()
     total_score = 0
+    total_exact_score = 0
     results = []
     file_lock = threading.Lock()

@@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
     def worker(index, data_item):
         nonlocal total_score
+        nonlocal total_exact_score
         question = data_evaluator.get_prompt(data_item)
         start_time = time.time()
         try:

@@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
             answer = chr(data_item['answer'] + 65)
             processed_prediction = data_evaluator.post_processing(prediction)
             score = data_evaluator.score(processed_prediction, answer)
+            exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
             elapsed_time = time.time() - start_time
             result_data = {
                 "question_id": index,
                 "answer": answer,
                 "prediction": processed_prediction,
-                "real_prediction": prediction,
+                "full_prediction": prediction,
                 "score": score,
+                "exact_score": exact_score,
                 "time": elapsed_time
             }
             # Lock while writing results to keep it thread-safe

@@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
             if res is not None:
                 results.append(res)
                 total_score += res['score']
+                total_exact_score += res['exact_score']
 
     total_time = time.time() - start_total_time
     throughput = len(data_subset) / total_time if total_time > 0 else 0

@@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
         log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
         average_score = total_score / len(data_subset) if data_subset else 0
         log_f.write(f"Average Score: {average_score}\n")
+        average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
+        log_f.write(f"Average Exact Score: {average_exact_score}\n")
         log_f.write('-' * 40 + '\n')
 
     print(f"Results saved to {result_file}")
@@ -2,23 +2,18 @@ import asyncio
 import json
 import sys
 import aiohttp
-import random
 import argparse
-import yaml
-import os
-import time
-from time import sleep
 
-decodesz = 128
-# Server URL (replace with your server URL)
-SERVER_URL = "http://localhost:10002/v1/chat/completions"
-bf_list = [1]
-decodesz_list = [128]
-prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
-async def fetch_event_stream(session, payload, request_id):
+prompt_list = [
+    'Please elaborate on modern world history.',
+    'Please introduce Harry Potter.',
+    'I want to learn Python. Please give me some advice.',
+    'Please tell me a joke '
+]
+
+
+async def fetch_event_stream(session, payload, request_id, stream):
     try:
         headers = {
             'accept': 'application/json',
             'Content-Type': 'application/json'

@@ -31,28 +26,20 @@ async def fetch_event_stream(session, payload, request_id):
                 print(f"Request {request_id}: Error, status {response.status}")
                 return
 
-            output_text = ""  # accumulates every token of the current response
+            output_text = ""
-            total_tokens = 0  # running token count
-            decode_start_time = None  # decode phase start time
-            decode_end_time = None  # decode phase end time
 
+            if stream:
                 async for line in response.content:
                     try:
                         decoded_line = line.decode("utf-8").strip()
-                        # skip empty lines
                         if not decoded_line or not decoded_line.startswith("data: "):
                             continue
-                        decoded_line = decoded_line[6:].strip()  # strip the `data: ` prefix
+                        decoded_line = decoded_line[6:].strip()
-                        # make sure the JSON payload is valid
                         if not decoded_line:
                             continue
-                        response_data = json.loads(decoded_line)  # parse the JSON
+                        response_data = json.loads(decoded_line)
-                        # make sure choices is present
                         choices = response_data.get("choices", [])
                         if not choices:
                             continue

@@ -61,20 +48,13 @@ async def fetch_event_stream(session, payload, request_id):
                         token = delta.get("content", "")
 
                         if token:
-                            if decode_start_time is None:
-                                decode_start_time = time.time()  # mark decode start
-                            output_text += token  # append the token
-                            sys.stdout.write(token)  # print the token directly
-                            sys.stdout.flush()  # flush so the token shows up immediately
-                            total_tokens += 1  # bump the token count
-                            decode_end_time = time.time()  # refresh decode end time on each token
+                            output_text += token
+                            sys.stdout.write(token)
+                            sys.stdout.flush()
 
-                        # check whether the stream is finished
                         finish_reason = choices[0].get("finish_reason", None)
                         if finish_reason:
-                            # print(f"\nRequest {request_id}: Done")
-                            break  # stop processing the stream
+                            break
 
                     except json.JSONDecodeError as e:
                         print(f"\nRequest {request_id}: JSON Decode Error - {e}")

@@ -82,53 +62,44 @@ async def fetch_event_stream(session, payload, request_id):
                         print(f"\nRequest {request_id}: List Index Error - choices is empty")
                     except Exception as e:
                         print(f"\nRequest {request_id}: Error parsing stream - {e}")
-
-            # compute the decode speed
-            if decode_start_time and decode_end_time and total_tokens > 0:
-                decode_time = decode_end_time - decode_start_time
-                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
+            else:
+                # In non-stream mode, receive the complete JSON in one shot
+                response_data = await response.json()
+                choices = response_data.get("choices", [])
+                if choices:
+                    content = choices[0].get("message", {}).get("content", "")
+                    print(f"Request {request_id} Output:\n{content}")
+                    output_text += content
 
     except Exception as e:
         print(f"\nRequest {request_id}: Exception - {e}")
 
-async def main(prompt_id):
+async def main(prompt_id, model, stream, max_tokens, temperature, top_p):
     async with aiohttp.ClientSession() as session:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt_list[prompt_id]}
             ],
-            "model": "DeepSeek-V3",
+            "model": model,
-            "stream": True,
+            "stream": stream,
-            "max_completion_tokens": 2,
+            "max_tokens": max_tokens,
-            # "temperature": 0.3,
+            "temperature": temperature,
-            # "top_p": 1.0,
+            "top_p": top_p
-            # "max_tokens" : 20,
         }
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["temperature"] = 0.3
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["top_p"] = 1
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["max_tokens"] = 200
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["stream"] = False
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        tasks = [fetch_event_stream(session, payload, prompt_id, stream)]
         await asyncio.gather(*tasks)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
-    parser.add_argument("--question_id", type=int, default=0, required=False)
+    parser.add_argument("--question_id", type=int, default=0)
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--stream", type=bool, default=True)
+    parser.add_argument("--max_tokens", type=int, default=500)
+    parser.add_argument("--temperature", type=float, default=0.8)
+    parser.add_argument("--top_p", type=float, default=1)
+    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
 
     args = parser.parse_args()
-    output_file = "ktransformer_test_results.txt"
-    asyncio.run(main(args.question_id))
+    SERVER_URL = args.api_url
+    asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))
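One caveat worth noting on the new flags: argparse's type=bool does not parse strings the way one might expect, because bool("False") is True, so passing --stream False still yields True. A common workaround, not part of this diff, is an explicit string-to-bool converter:

import argparse

def str2bool(v: str) -> bool:
    # Accept the usual spellings of true/false on the command line
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--stream", type=str2bool, default=True)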
@@ -12,6 +12,8 @@ from time import sleep
 decodesz = 128
 # Server URL (replace with your server URL)
 decodesz_list = [128]
+prefill_speeds = []
+decode_speeds = []
 ktansformer_prompt1024="""Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much.
 They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills.
 He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs.

@@ -43,17 +45,19 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
 The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it.
 The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
 Mr. Dursley always sat with his back to the window in his office on the ninth floor."""
-async def fetch_event_stream(session, request_id, prompt):
+async def fetch_event_stream(session, request_id, prompt, max_tokens, model):
     try:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt}
             ],
-            "model": "DeepSeek-V3",
+            "model": model,
             "temperature": 0.3,
             "top_p": 1.0,
-            "stream": True
+            "stream": True,
+            "return_speed": True,
+            "max_tokens": max_tokens,
         }
 
         headers = {

@@ -70,6 +74,7 @@ async def fetch_event_stream(session, request_id, prompt):
             total_tokens = 0
             decode_start_time = None
             decode_end_time = None
+            usage_info = None
 
             async for line in response.content:
                 try:

@@ -82,6 +87,10 @@ async def fetch_event_stream(session, request_id, prompt):
                         continue
 
                     response_data = json.loads(decoded_line)
+
+                    if "usage" in response_data:
+                        usage_info = response_data["usage"]
+
                     choices = response_data.get("choices", [])
                     if not choices:
                         continue

@@ -107,34 +116,48 @@ async def fetch_event_stream(session, request_id, prompt):
                 except Exception as e:
                     print(f"[Request {request_id}] Stream Error: {e}")
 
             if buffer.strip():
                 print(f"[Request {request_id}] {buffer.strip()}")
 
-            if decode_start_time and decode_end_time and total_tokens > 0:
-                decode_time = decode_end_time - decode_start_time
-                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-                print(f"[Request {request_id}] Speed: {decode_speed:.2f} tokens/s")
+            if usage_info:
+                if "prefill_time" in usage_info:
+                    # print(f"[Request {request_id}] Usage:")
+                    # for key, value in usage_info.items():
+                    #     print(f"  {key}: {value}")
+                    prefill_speed = usage_info["prompt_tokens"] / usage_info["prefill_time"]
+                    decode_speed = usage_info["completion_tokens"] / usage_info["decode_time"]
+                    prefill_speeds.append(prefill_speed)
+                    decode_speeds.append(decode_speed)
+                    print(f'[Request {request_id}] prefill speed: {prefill_speed}')
+                    print(f'[Request {request_id}] decode speed: {decode_speed}')
 
     except Exception as e:
         print(f"[Request {request_id}] Exception: {e}")
 
-async def main(concurrent_requests , prompt ):
+async def main(concurrent_requests , prompt, max_tokens, model):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, i , prompt) for i in range(concurrent_requests)]
+        tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)]
         await asyncio.gather(*tasks)
+    if len(prefill_speeds) != 0:
+        import numpy as np
+        print(f"concurrency: {len(prefill_speeds)}")
+        print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}")
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
+    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
+    parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")
 
     args = parser.parse_args()
     SERVER_URL = args.api_url
+    max_tokens = args.max_tokens
+    model = args.model
     if args.prompt_lens == 1024:
         prompt = ktansformer_prompt1024
     elif args.prompt_lens == 2048:
         prompt = ktansformer_prompt1024 * 2
-    asyncio.run(main(args.concurrent, prompt))
+    asyncio.run(main(args.concurrent, prompt, max_tokens, model))
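With these changes the stress script is driven entirely from the command line, and the per-request prefill/decode speeds it sums come from the server-reported usage rather than client-side timing. Assuming the file is saved as stress_test.py (the name is illustrative), a typical run against a local endpoint could look like:

python stress_test.py --concurrent 4 --model DeepSeek-V3 --prompt_lens 1024 --max_tokens 50 \
    --api_url http://localhost:10002/v1/chat/completions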