mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
change test
This commit is contained in:
parent
b17ab8653c
commit
3a044e6b14
4 changed files with 134 additions and 115 deletions
|
@ -8,12 +8,57 @@ from datasets import load_dataset
|
|||
import os
|
||||
import concurrent.futures
|
||||
import threading
|
||||
import re
|
||||
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
os.environ['https_proxy'] = ''
|
||||
os.environ['http_proxy'] = ''
|
||||
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
|
||||
|
||||
|
||||
# Explicit "the answer is X" phrasings, tried in priority order.
# Compiled once at import time instead of on every call.
_EXPLICIT_ANSWER_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'Answer:\s*([A-D])\b',
        r'Correct answer:\s*([A-D])\b',
        r'The correct answer is\s*\*?\*?\s*([A-D])\b',
        r'Answer is\s*([A-D])\b',
        r'Therefore,\s*answer is\s*([A-D])\b',
        r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
        r'The answer should be\s*(?:Option\s*)?([A-D])\b',
        r'Option\s+([A-D])\s+is correct',
    )
)

# Markdown emphasis such as "**C**" or "**C. something**". The trailing word
# boundary keeps bold words like "**Analysis**" or "**Conclusion**" from being
# mistaken for the options A / C.
_MARKDOWN_ANSWER_RE = re.compile(r'\*\*\s*([A-D])\b')

# A lone option letter wrapped in single or double quotes: 'C' or "C".
_QUOTED_ANSWER_RE = re.compile(r"['\"]([A-D])['\"]")

# A line that starts with an option letter followed by ".", whitespace or EOL.
_LINE_START_ANSWER_RE = re.compile(r'^([A-D])([.\s]|$)')


def extract_final_answer(text):
    """
    Extract the final multiple-choice option (A/B/C/D) from a model reply.

    The reply may be free-form natural language, span several lines, use
    markdown emphasis, or state its conclusion somewhere other than the end.
    The following strategies are tried in order and the first hit wins:

    1. Explicit statements such as "Answer: C" or "The correct answer is C"
       (case-insensitive; the first listed pattern that matches is used).
    2. Markdown bold such as "**C**" / "**C. text**" (last occurrence).
    3. A quoted letter such as 'C' or "C" (last occurrence).
    4. One of the last five lines starting with "C." / "C " / a bare "C"
       (scanned from the bottom up).

    Args:
        text: Raw model output. Anything that is not a ``str`` (e.g. ``None``)
            is treated as "no answer".

    Returns:
        The option letter in upper case, or ``None`` if no option was found.
    """
    # The function already uses None as its "no answer" result, so a missing
    # or non-text prediction maps to None rather than raising on .strip().
    if not isinstance(text, str):
        return None

    text = text.strip()

    # 1. Explicit statements take priority.
    for pattern in _EXPLICIT_ANSWER_PATTERNS:
        match = pattern.search(text)
        if match:
            return match.group(1).upper()

    # 2. Markdown emphasis: the last bolded option is taken as the conclusion.
    markdown_hits = _MARKDOWN_ANSWER_RE.findall(text)
    if markdown_hits:
        return markdown_hits[-1].upper()

    # 3. A quoted option letter; again the last one wins.
    quoted_hits = _QUOTED_ANSWER_RE.findall(text)
    if quoted_hits:
        return quoted_hits[-1].upper()

    # 4. Scan the last few lines, bottom-up, for one that begins with an option.
    for line in reversed(text.splitlines()[-5:]):
        match = _LINE_START_ANSWER_RE.match(line.strip())
        if match:
            return match.group(1).upper()

    # Nothing recognisable was found.
    return None
|
||||
class DataEvaluator:
|
||||
def __init__(self):
|
||||
self.data = []
|
||||
|
@ -22,8 +67,10 @@ class DataEvaluator:
|
|||
"""
|
||||
从数据文件中加载数据,每条记录对应一个实例
|
||||
"""
|
||||
ds = load_dataset(file_path, "all")
|
||||
df = pd.DataFrame(ds['test'])
|
||||
splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
|
||||
'dev': 'all/dev-00000-of-00001.parquet',
|
||||
'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
|
||||
df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
|
||||
for _, row in df.iterrows():
|
||||
self.data.append(row.to_dict())
|
||||
|
||||
|
@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
|
|||
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
|
||||
start_total_time = time.time()
|
||||
total_score = 0
|
||||
total_exact_score = 0
|
||||
results = []
|
||||
file_lock = threading.Lock()
|
||||
|
||||
|
@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
|
|||
|
||||
def worker(index, data_item):
|
||||
nonlocal total_score
|
||||
nonlocal total_exact_score
|
||||
question = data_evaluator.get_prompt(data_item)
|
||||
start_time = time.time()
|
||||
try:
|
||||
|
@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
|
|||
answer = chr(data_item['answer'] + 65)
|
||||
processed_prediction = data_evaluator.post_processing(prediction)
|
||||
score = data_evaluator.score(processed_prediction, answer)
|
||||
exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
|
||||
elapsed_time = time.time() - start_time
|
||||
result_data = {
|
||||
"question_id": index,
|
||||
"answer": answer,
|
||||
"prediction": processed_prediction,
|
||||
"real_prediction": prediction,
|
||||
"full_prediction": prediction,
|
||||
"score": score,
|
||||
"exact_score": exact_score,
|
||||
"time": elapsed_time
|
||||
}
|
||||
# 写入结果时加锁保证线程安全
|
||||
|
@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
|
|||
if res is not None:
|
||||
results.append(res)
|
||||
total_score += res['score']
|
||||
total_exact_score += res['exact_score']
|
||||
|
||||
total_time = time.time() - start_total_time
|
||||
throughput = len(data_subset) / total_time if total_time > 0 else 0
|
||||
|
@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
|
|||
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
|
||||
average_score = total_score / len(data_subset) if data_subset else 0
|
||||
log_f.write(f"Average Score: {average_score}\n")
|
||||
average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
|
||||
log_f.write(f"Average Exact Score: {average_exact_score}\n")
|
||||
log_f.write('-' * 40 + '\n')
|
||||
|
||||
print(f"Results saved to {result_file}")
|
||||
|
@ -152,4 +206,4 @@ if __name__ == "__main__":
|
|||
data_evaluator = DataEvaluator()
|
||||
data_evaluator.load_data(args.file)
|
||||
|
||||
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
|
||||
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
|
Loading…
Add table
Add a link
Reference in a new issue