mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00

change test

This commit is contained in:
parent b17ab8653c
commit 3a044e6b14

4 changed files with 134 additions and 115 deletions

@@ -25,19 +25,10 @@ class DataEvaluator:
         """
-        # Read the Parquet file
-        # dataset = load_dataset('parquet', data_files=file_path)
-        ds = load_dataset(file_path,"all")
-        df = pd.DataFrame(ds['test'])
-        # print(ds)
-        # # ds_1 = ds['train']
-        # ds_2 = ds['validation']
-        # ds_3 = ds['test']
-        # # Convert the dataset splits into Pandas DataFrames
-        # df_test = pd.DataFrame(ds['test'])
-        # df_val = pd.DataFrame(ds['validation'])
-
-        # for _, row in df.iterrows():
-        #     self.data.append(row.to_dict())
-        # df = pd.read_parquet(file_path)
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
 
         for _, row in df.iterrows():
             self.data.append(row.to_dict())
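Note: the new load path streams the MMLU test split directly from the Hugging Face Hub via pandas' hf:// URL support. A minimal standalone sketch of the same read, assuming huggingface_hub (which registers the hf:// filesystem) and a parquet engine such as pyarrow are installed:

    import pandas as pd

    # pandas resolves hf:// paths through the huggingface_hub fsspec filesystem
    df = pd.read_parquet("hf://datasets/cais/mmlu/all/test-00000-of-00001.parquet")

    # one evaluation instance per row, mirroring the loop above
    records = [row.to_dict() for _, row in df.iterrows()]
    print(len(records))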

@@ -8,12 +8,57 @@ from datasets import load_dataset
 import os
 import concurrent.futures
+import threading
+import re
 
 os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
 os.environ['https_proxy'] = ''
 os.environ['http_proxy'] = ''
 hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
 
+
+def extract_final_answer(text):
+    """
+    Extract the final option (e.g. A/B/C/D) predicted by the model.
+    Handles natural language, multi-line, markdown, highlighted, and non-final-line conclusions.
+    """
+    text = text.strip()
+
+    # 1. Explicit statements (checked first)
+    explicit_patterns = [
+        r'Answer:\s*([A-D])\b',
+        r'Correct answer:\s*([A-D])\b',
+        r'The correct answer is\s*\*?\*?\s*([A-D])\b',
+        r'Answer is\s*([A-D])\b',
+        r'Therefore,\s*answer is\s*([A-D])\b',
+        r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'The answer should be\s*(?:Option\s*)?([A-D])\b',
+        r'Option\s+([A-D])\s+is correct',
+    ]
+    for pat in explicit_patterns:
+        match = re.search(pat, text, re.IGNORECASE)
+        if match:
+            return match.group(1).upper()
+
+    # 2. Markdown emphasis: **C**, **C. something**
+    markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text)
+    if markdown_match:
+        return markdown_match[-1].upper()
+
+    # 3. A quoted 'C' or "C"
+    quote_match = re.findall(r"['\"]([A-D])['\"]", text)
+    if quote_match:
+        return quote_match[-1].upper()
+
+    # 4. Does one of the last few lines start with "C." or a bare "C"?
+    lines = text.splitlines()
+    for line in reversed(lines[-5:]):
+        line = line.strip()
+        match = re.match(r'^([A-D])([.\s]|$)', line)
+        if match:
+            return match.group(1).upper()
+
+    # Otherwise give up and return None
+    return None
+
 class DataEvaluator:
     def __init__(self):
         self.data = []
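A quick sanity check of the new extractor against the formats its branches target; the expected outputs follow directly from the regexes above:

    # 1. explicit statement
    print(extract_final_answer("Let's think. The correct answer is B."))     # -> 'B'
    # 2. markdown emphasis
    print(extract_final_answer("Comparing them...\n**C. Paris** fits."))     # -> 'C'
    # 4. bare letter on one of the last lines
    print(extract_final_answer("Rule out A and B.\nD"))                      # -> 'D'
    # no recognizable option
    print(extract_final_answer("I am not sure."))                            # -> None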

@@ -22,8 +67,10 @@ class DataEvaluator:
         """
         Load data from the data file; each record becomes one instance
         """
-        ds = load_dataset(file_path, "all")
-        df = pd.DataFrame(ds['test'])
+        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
+                  'dev': 'all/dev-00000-of-00001.parquet',
+                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
+        df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
         for _, row in df.iterrows():
             self.data.append(row.to_dict())

@@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
 def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
     start_total_time = time.time()
     total_score = 0
+    total_exact_score = 0
     results = []
     file_lock = threading.Lock()

@@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
 
     def worker(index, data_item):
         nonlocal total_score
+        nonlocal total_exact_score
         question = data_evaluator.get_prompt(data_item)
         start_time = time.time()
         try:

@@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
             answer = chr(data_item['answer'] + 65)
             processed_prediction = data_evaluator.post_processing(prediction)
             score = data_evaluator.score(processed_prediction, answer)
+            exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
             elapsed_time = time.time() - start_time
             result_data = {
                 "question_id": index,
                 "answer": answer,
                 "prediction": processed_prediction,
-                "real_prediction": prediction,
+                "full_prediction": prediction,
                 "score": score,
+                "exact_score": exact_score,
                 "time": elapsed_time
             }
             # Lock while writing results to keep it thread-safe

@@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
                 if res is not None:
                     results.append(res)
                     total_score += res['score']
+                    total_exact_score += res['exact_score']
 
     total_time = time.time() - start_total_time
     throughput = len(data_subset) / total_time if total_time > 0 else 0

@@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
         log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
         average_score = total_score / len(data_subset) if data_subset else 0
         log_f.write(f"Average Score: {average_score}\n")
+        average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
+        log_f.write(f"Average Exact Score: {average_exact_score}\n")
         log_f.write('-' * 40 + '\n')
 
     print(f"Results saved to {result_file}")
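The new exact-score bookkeeping follows the file's existing pattern: worker threads produce per-item results while a lock guards shared state. For reference, a minimal sketch of locked accumulation across a thread pool, with hypothetical stand-ins (evaluate is not the repo's function):

    import concurrent.futures
    import threading

    total_exact_score = 0
    lock = threading.Lock()

    def evaluate(item):
        # hypothetical: pretend every extracted answer matches
        return 1

    def worker(item):
        global total_exact_score
        score = evaluate(item)
        with lock:  # same role as file_lock in the diff
            total_exact_score += score

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        list(pool.map(worker, range(10)))
    print(total_exact_score)  # 10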

@@ -2,23 +2,18 @@ import asyncio
 import json
 import sys
 import aiohttp
 import random
 import argparse
-import yaml
-import os
-import time
-from time import sleep
-
-decodesz = 128
-# Server URL (replace with your server URL)
-SERVER_URL = "http://localhost:10002/v1/chat/completions"
-bf_list = [1]
-decodesz_list = [128]
-prompt_list = ['Please elaborate on modern world history.', 'Please introduce Harry Potter.', 'I want to learn Python. Please give me some advice.', 'Please tell me a joke ']
-async def fetch_event_stream(session, payload, request_id):
+prompt_list = [
+    'Please elaborate on modern world history.',
+    'Please introduce Harry Potter.',
+    'I want to learn Python. Please give me some advice.',
+    'Please tell me a joke '
+]
+
+
+async def fetch_event_stream(session, payload, request_id, stream):
     try:
-
-
         headers = {
             'accept': 'application/json',
             'Content-Type': 'application/json'

@@ -31,28 +26,20 @@ async def fetch_event_stream(session, payload, request_id):
                 print(f"Request {request_id}: Error, status {response.status}")
                 return
 
-            output_text = ""  # accumulates every token of this response
-            total_tokens = 0  # running token count
-            decode_start_time = None  # when the decode phase started
-            decode_end_time = None  # when decoding finished
+            output_text = ""
 
+            if stream:
                 async for line in response.content:
                     try:
                         decoded_line = line.decode("utf-8").strip()
 
-                        # skip empty lines
                         if not decoded_line or not decoded_line.startswith("data: "):
                             continue
 
-                        decoded_line = decoded_line[6:].strip()  # drop the `data: ` prefix
-
-                        # make sure the JSON payload is valid
+                        decoded_line = decoded_line[6:].strip()
                         if not decoded_line:
                             continue
 
-                        response_data = json.loads(decoded_line)  # parse the JSON
-
-                        # make sure choices exists
+                        response_data = json.loads(decoded_line)
                         choices = response_data.get("choices", [])
                         if not choices:
                             continue
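Context for the data: handling above: each streamed chunk arrives as a Server-Sent Events line, so the code drops the fixed 6-character "data: " prefix before json.loads. A self-contained sketch of that parse, assuming the OpenAI-style chat-completion chunk shape the loop expects:

    import json

    raw = b'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}\n'
    decoded = raw.decode("utf-8").strip()
    if decoded.startswith("data: "):
        body = decoded[6:].strip()  # len("data: ") == 6
        chunk = json.loads(body)
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""))  # -> Hi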

@@ -61,20 +48,13 @@ async def fetch_event_stream(session, payload, request_id):
                         token = delta.get("content", "")
 
                         if token:
-                            if decode_start_time is None:
-                                decode_start_time = time.time()  # mark the start of decoding
+                            output_text += token
+                            sys.stdout.write(token)
+                            sys.stdout.flush()
 
-                            output_text += token  # append the token
-                            sys.stdout.write(token)  # print the token as it arrives
-                            sys.stdout.flush()  # flush so it shows up in the terminal immediately
-                            total_tokens += 1  # bump the token count
-                            decode_end_time = time.time()  # refresh the decode end time on every token
-
-                        # check whether generation has finished
                         finish_reason = choices[0].get("finish_reason", None)
                         if finish_reason:
-                            # print(f"\nRequest {request_id}: Done")
-                            break  # stop consuming the stream
+                            break
 
                     except json.JSONDecodeError as e:
                         print(f"\nRequest {request_id}: JSON Decode Error - {e}")

@@ -82,53 +62,44 @@ async def fetch_event_stream(session, payload, request_id):
                     print(f"\nRequest {request_id}: List Index Error - choices is empty")
                 except Exception as e:
                     print(f"\nRequest {request_id}: Error parsing stream - {e}")
 
-            # compute the decode speed
-            if decode_start_time and decode_end_time and total_tokens > 0:
-                decode_time = decode_end_time - decode_start_time
-                decode_speed = total_tokens / decode_time if decode_time > 0 else 0
-                # print(f"Request {request_id}: Decode Speed = {decode_speed:.2f} tokens/s")
+            else:
+                # non-stream mode: receive the complete JSON response in one shot
+                response_data = await response.json()
+                choices = response_data.get("choices", [])
+                if choices:
+                    content = choices[0].get("message", {}).get("content", "")
+                    print(f"Request {request_id} Output:\n{content}")
+                    output_text += content
 
     except Exception as e:
         print(f"\nRequest {request_id}: Exception - {e}")
 
-async def main(prompt_id):
+async def main(prompt_id, model, stream, max_tokens, temperature, top_p):
     async with aiohttp.ClientSession() as session:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt_list[prompt_id]}
             ],
-            "model": "DeepSeek-V3",
-            "stream": True,
-            "max_completion_tokens": 2,
-            # "temperature": 0.3,
-            # "top_p": 1.0,
-            # "max_tokens" : 20,
+            "model": model,
+            "stream": stream,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "top_p": top_p
         }
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["temperature"] = 0.3
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["top_p"] = 1
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["max_tokens"] = 200
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
-        await asyncio.gather(*tasks)
-
-        payload["stream"] = False
-        tasks = [fetch_event_stream(session, payload, prompt_id)]
+        tasks = [fetch_event_stream(session, payload, prompt_id, stream)]
         await asyncio.gather(*tasks)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
-    parser.add_argument("--question_id", type=int, default=0, required=False)
+    parser.add_argument("--question_id", type=int, default=0)
+    parser.add_argument("--model", type=str, required=True)
+    parser.add_argument("--stream", type=bool, default=True)
+    parser.add_argument("--max_tokens", type=int, default=500)
+    parser.add_argument("--temperature", type=float, default=0.8)
+    parser.add_argument("--top_p", type=float, default=1)
+    parser.add_argument("--api_url", type=str, default="http://localhost:10006/v1/chat/completions", help="API URL")
 
     args = parser.parse_args()
-    output_file = "ktransformer_test_results.txt"
-    asyncio.run(main(args.question_id))
+    SERVER_URL = args.api_url
+    asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))
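One caveat on the new flags: argparse's type=bool does not parse strings, so --stream False still evaluates truthy (bool of any non-empty string is True). If the flag needs to be switchable from the command line, a store_false action is the usual fix; a sketch, not what this commit does:

    import argparse

    parser = argparse.ArgumentParser()
    # a boolean switch instead of type=bool, which treats "False" as True
    parser.add_argument("--no-stream", dest="stream", action="store_false")
    parser.set_defaults(stream=True)
    print(parser.parse_args(["--no-stream"]).stream)  # False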

@@ -45,14 +45,14 @@ They were whispering excitedly together. Mr. Dursley was enraged to see that a c
 The nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it.
 The traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.
 Mr. Dursley always sat with his back to the window in his office on the ninth floor."""
-async def fetch_event_stream(session, request_id, prompt, max_tokens):
+async def fetch_event_stream(session, request_id, prompt, max_tokens, model):
     try:
         payload = {
             "messages": [
                 {"role": "system", "content": ""},
                 {"role": "user", "content": prompt}
             ],
-            "model": "DeepSeek-V3",
+            "model": model,
             "temperature": 0.3,
             "top_p": 1.0,
             "stream": True,

@@ -134,17 +134,19 @@ async def fetch_event_stream(session, request_id, prompt, max_tokens):
         except Exception as e:
             print(f"[Request {request_id}] Exception: {e}")
 
-async def main(concurrent_requests, prompt, max_tokens):
+async def main(concurrent_requests, prompt, max_tokens, model):
     async with aiohttp.ClientSession() as session:
-        tasks = [fetch_event_stream(session, i, prompt, max_tokens) for i in range(concurrent_requests)]
+        tasks = [fetch_event_stream(session, i, prompt, max_tokens, model) for i in range(concurrent_requests)]
         await asyncio.gather(*tasks)
         if len(prefill_speeds) != 0:
             import numpy as np
             print(f"average prefill speed: {np.average(prefill_speeds)}\naverage decode speed: {np.average(decode_speeds)}")
+            print(f"concurrency: {len(prefill_speeds)}")
+            print(f"total prefill speed: {np.sum(prefill_speeds)}\n total decode speed: {np.sum(decode_speeds)}")
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Event Stream Request Tester")
     parser.add_argument("--concurrent", type=int, default=1, help="Number of concurrent requests")
+    parser.add_argument("--model", type=str, default="DeepSeek-V3", help="Model name", required=True)
     parser.add_argument("--prompt_lens", type=int, default=1024, help="prefill prompt lens, 1024 or 2048")
     parser.add_argument("--api_url", type=str, default="http://localhost:10002/v1/chat/completions", help="API URL")
     parser.add_argument("--max_tokens", type=int, default=50, help="max decode tokens")

@@ -152,9 +154,10 @@ if __name__ == "__main__":
     args = parser.parse_args()
     SERVER_URL = args.api_url
     max_tokens = args.max_tokens
+    model = args.model
     if args.prompt_lens == 1024:
         prompt = ktansformer_prompt1024
     elif args.prompt_lens == 2048:
         prompt = ktansformer_prompt1024 * 2
-    asyncio.run(main(args.concurrent, prompt, max_tokens))
+    asyncio.run(main(args.concurrent, prompt, max_tokens, model))
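On the new summary prints: np.average reports per-stream speed, while np.sum approximates aggregate throughput across the concurrent requests. A toy illustration with made-up speeds:

    import numpy as np

    decode_speeds = [12.3, 11.8, 12.0]  # hypothetical tokens/s, one entry per request

    print(f"average decode speed: {np.average(decode_speeds):.2f} tokens/s")  # per stream
    print(f"total decode speed: {np.sum(decode_speeds):.2f} tokens/s")        # across all streams
    print(f"concurrency: {len(decode_speeds)}")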