change test

This commit is contained in:
qiyuxinlin 2025-04-22 12:50:39 +00:00
parent b17ab8653c
commit 3a044e6b14
4 changed files with 134 additions and 115 deletions

View file

@ -8,12 +8,57 @@ from datasets import load_dataset
import os
import concurrent.futures
import threading
import re
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['https_proxy'] = ''
os.environ['http_proxy'] = ''
hint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'
def extract_final_answer(text):
"""
提取模型预测的最终选项 A/B/C/D
支持自然语言多行markdown高亮非末尾结论等格式
"""
text = text.strip()
# 1. 显式语句匹配(优先)
explicit_patterns = [
r'Answer:\s*([A-D])\b',
r'Correct answer:\s*([A-D])\b',
r'The correct answer is\s*\*?\*?\s*([A-D])\b',
r'Answer is\s*([A-D])\b',
r'Therefore,\s*answer is\s*([A-D])\b',
r'Therefore,\s*the answer should be\s*(?:Option\s*)?([A-D])\b',
r'The answer should be\s*(?:Option\s*)?([A-D])\b',
r'Option\s+([A-D])\s+is correct',
]
for pat in explicit_patterns:
match = re.search(pat, text, re.IGNORECASE)
if match:
return match.group(1).upper()
# 2. markdown 强调 **C**, **C. something**
markdown_match = re.findall(r'\*\*\s*([A-D])[\.\s]?', text)
if markdown_match:
return markdown_match[-1].upper()
# 3. 查找单引号中的 'C' 或 "C"
quote_match = re.findall(r"['\"]([A-D])['\"]", text)
if quote_match:
return quote_match[-1].upper()
# 4. 倒数几行是否以 "C." 或 "C" 开头
lines = text.splitlines()
for line in reversed(lines[-5:]):
line = line.strip()
match = re.match(r'^([A-D])([.\s]|$)', line)
if match:
return match.group(1).upper()
# 再不行就返回 None
return None
class DataEvaluator:
def __init__(self):
self.data = []
@ -22,8 +67,10 @@ class DataEvaluator:
"""
从数据文件中加载数据每条记录对应一个实例
"""
ds = load_dataset(file_path, "all")
df = pd.DataFrame(ds['test'])
splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',
'dev': 'all/dev-00000-of-00001.parquet',
'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}
df = pd.read_parquet("hf://datasets/cais/mmlu/" + splits["test"])
for _, row in df.iterrows():
self.data.append(row.to_dict())
@ -73,6 +120,7 @@ def generate_text(api_url, question, model_name, stream=False):
def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):
start_total_time = time.time()
total_score = 0
total_exact_score = 0
results = []
file_lock = threading.Lock()
@ -85,6 +133,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
def worker(index, data_item):
nonlocal total_score
nonlocal total_exact_score
question = data_evaluator.get_prompt(data_item)
start_time = time.time()
try:
@ -95,13 +144,15 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
answer = chr(data_item['answer'] + 65)
processed_prediction = data_evaluator.post_processing(prediction)
score = data_evaluator.score(processed_prediction, answer)
exact_score = data_evaluator.score(extract_final_answer(prediction), answer)
elapsed_time = time.time() - start_time
result_data = {
"question_id": index,
"answer": answer,
"prediction": processed_prediction,
"real_prediction": prediction,
"full_prediction": prediction,
"score": score,
"exact_score": exact_score,
"time": elapsed_time
}
# 写入结果时加锁保证线程安全
@ -124,6 +175,7 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
if res is not None:
results.append(res)
total_score += res['score']
total_exact_score += res['exact_score']
total_time = time.time() - start_total_time
throughput = len(data_subset) / total_time if total_time > 0 else 0
@ -133,6 +185,8 @@ def main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_fi
log_f.write(f"Throughput: {throughput:.2f} requests per second\n")
average_score = total_score / len(data_subset) if data_subset else 0
log_f.write(f"Average Score: {average_score}\n")
average_exact_score = total_exact_score / len(data_subset) if data_subset else 0
log_f.write(f"Average Exact Score: {average_exact_score}\n")
log_f.write('-' * 40 + '\n')
print(f"Results saved to {result_file}")
@ -152,4 +206,4 @@ if __name__ == "__main__":
data_evaluator = DataEvaluator()
data_evaluator.load_data(args.file)
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)
main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)