#!/usr/bin/env python3
import argparse
import json
import os
import re
import subprocess
import sys
import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict, field
from pathlib import Path
from queue import Queue
from typing import Dict, List, Optional, Any, Tuple
import requests
from tqdm import tqdm
import random
from math import sqrt
@dataclass
class ServerConfig:
    """Connection settings for a single inference server.

    Only ``url`` and ``threads`` are required; ``name`` falls back to an empty
    string when no label is given.
    """

    url: str  # address of the server endpoint
    threads: int  # presumably the number of concurrent workers for this server — confirm
    name: str = field(default="")  # optional human-readable label
def wilson_interval(correct: int, total: int, z: float = 1.96) -> Tuple[float, float]:
    """Return the Wilson score confidence interval for a binomial proportion.

    Args:
        correct: number of successes.
        total: number of trials. ``0`` yields the uninformative ``(0.0, 1.0)``.
        z: standard-normal quantile for the desired confidence level
            (1.96 corresponds to roughly 95%).

    Returns:
        ``(lower, upper)`` bounds, clamped to the closed interval ``[0.0, 1.0]``.
    """
    if total == 0:
        return (0.0, 1.0)
    p = correct / total
    z2 = z * z / total  # z^2 / n, shared by the centre and the margin
    center = (p + z2 / 2) / (1 + z2)
    margin = z * sqrt((p * (1 - p) + z2 / 4) / total) / (1 + z2)
    # The exact interval always lies inside [0, 1], but at p == 0 or p == 1
    # floating-point rounding can push a bound just outside (e.g. -1e-17, which
    # would be reported as "-0.0%"). Clamp so callers always get probabilities.
    return (max(0.0, center - margin), min(1.0, center + margin))
# Point the Hugging Face `datasets` cache at a fixed per-user directory (created
# eagerly at import time) and disable Hub telemetry. Both are plain environment
# variables, so they take effect only for HF code that reads them afterwards.
cache_dir = Path.home().joinpath(".cache", "huggingface", "datasets")
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ.update({
    "HF_DATASETS_CACHE": str(cache_dir),
    "HF_HUB_DISABLE_TELEMETRY": "1",
})
# Answer-extraction regexes keyed by dataset type (consumed by grading code
# outside this section).
#
# Each AIME pattern has two alternatives:
#   group 1 -> the integer inside a literal ``\boxed{...}``
#   group 2 -> fallback: any standalone integer
# The backslash before "boxed" must be escaped even in a raw string. In a regex,
# a bare ``\b`` is the word-boundary assertion, so an unescaped ``\boxed{`` would
# mean "word boundary + 'oxed{'". That can never match the text "\boxed{42}",
# because there is no word boundary between 'b' and 'o'. The braces are escaped
# for clarity.
#
# NOTE(review): "gsm8k" only matches bare non-negative integers even though the
# gsm8k prompt template asks for \boxed{}; "gpqa" has no entry here.
GRADER_PATTERNS = {
    "aime": r'\\boxed\{(\d+)\}|\b(\d+)\b',
    "aime2025": r'\\boxed\{(\d+)\}|\b(\d+)\b',
    "aime2026": r'\\boxed\{(\d+)\}|\b(\d+)\b',
    "gsm8k": r'\b(\d+)\b',
}
# Example answer strings per dataset type: numeric strings for the math
# datasets, option letters for GPQA. Their consumer is outside this section.
# Each dataset key gets its own (distinct) list object.
SAMPLE_ANSWERS = {
    **{
        dataset: ["42", "-123", "999"]
        for dataset in ("aime", "aime2025", "aime2026", "gsm8k")
    },
    "gpqa": ["A", "D", "C"],
}
# Prompt templates keyed by dataset type. Placeholders in single braces are
# presumably filled with str.format (the doubled "{{}}" would then render as a
# literal "{}" inside "\boxed{}") — confirm against the caller.
TEMPLATE_REGISTRY = {
# AIME variants: identical prompts that ask for the final answer in \boxed{}.
"aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
{question}
Remember to put your answer inside \\boxed{{}}.
""",
"aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
{question}
Remember to put your answer inside \\boxed{{}}.
""",
"aime2026": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
{question}
Remember to put your answer inside \\boxed{{}}.
""",
# GSM8K: question first, then the instruction.
# NOTE(review): GRADER_PATTERNS["gsm8k"] only matches bare integers, not the
# \boxed{} form requested here.
"gsm8k": """{question}
Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters.
""",
# GPQA: multiple choice; fields are Question/A/B/C/D (capitalised "Question"),
# presumably provided by the GPQA dataset records — verify against GpqaDataset.
"gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""",
}
class BaseDataset(ABC):
    """Common interface for the benchmark datasets used by the evaluator.

    Concrete subclasses must populate ``questions`` and implement the four
    accessors. ``len(dataset)`` reports how many records ``questions`` holds.
    """

    # One record per problem; the record layout is dataset-specific.
    questions: List[Dict]

    @abstractmethod
    def get_question(self, index: int) -> Dict:
        """Return the record for problem ``index``; it is what the other accessors receive."""

    @abstractmethod
    def get_question_text(self, question: Dict) -> str:
        """Return the question text contained in ``question``."""

    @abstractmethod
    def get_answer(self, question: Dict) -> str:
        """Return the reference answer for ``question`` as a string."""

    @abstractmethod
    def get_prompt(self, question: Dict) -> str:
        """Return the prompt to send to the model for ``question``."""

    def __len__(self) -> int:
        return len(self.questions)
@dataclass
class TaskState:
    """Everything tracked for one evaluation task, from prompt to graded result.

    Only ``task_id``, ``prompt`` and ``expected`` are required. Every other
    field starts empty/neutral and is presumably filled in as the task runs.
    """

    # --- identity and inputs ---
    task_id: str
    prompt: str
    expected: str
    question_text: str = ""
    # --- model output and grading ---
    response: Optional[str] = None
    answer: Optional[str] = None
    grader_log: Dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
    correct: bool = False
    status: str = "pending"
    # --- generation statistics ---
    tokens: Optional[int] = None
    tps_gen: Optional[float] = None
    t_gen_ms: Optional[float] = None  # generation time in milliseconds
    reasoning_content: Optional[str] = None
    # --- bookkeeping ---
    server_name: Optional[str] = None
    chunk_idx: int = 0
    problem_idx: int = 0
class EvalState:
def __init__(
self,
dataset_type: str,
sampling_config: Dict[str, Any],
output_file: Path = Path("llama-eval-state.json"),
model_name: Optional[str] = None
):
self.dataset_type = dataset_type
self.sampling_config = sampling_config
self.output_file = output_file
self.model_name = model_name
self.dataset: Optional[BaseDataset] = None
self.tasks: List[Tuple[int, str]] = []
self.all_tasks: List[Tuple[int, str]] = []
self.task_states: Dict[str, Any] = {}
self.total = 0
self.correct = 0
self.processed = 0
self.total_time: float = 0.0
self._lock = threading.Lock()
def load_dataset(self, seed: int = 1234):
if self.dataset_type == "aime":
self.dataset = AimeDataset()
elif self.dataset_type == "aime2025":
self.dataset = Aime2025Dataset()
elif self.dataset_type == "aime2026":
self.dataset = Aime2026Dataset()
elif self.dataset_type == "gsm8k":
self.dataset = Gsm8kDataset()
elif self.dataset_type == "gpqa":
self.dataset = GpqaDataset(variant="diamond", seed=seed)
else:
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234):
if self.dataset is None:
raise ValueError("Dataset not loaded. Call load_dataset() first.")
if n_cases is None:
n_cases = len(self.dataset)
dataset_size = len(self.dataset)
rng = random.Random(seed)
self.tasks = []
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
indices = list(range(dataset_size))
rng.shuffle(indices)
chunk_indices = indices[:chunk_size]
for i in chunk_indices:
task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
self.tasks.append((i, task_id))
self.all_tasks = list(self.tasks)
def get_case(self, index: int) -> Tuple[str, str, str]:
if self.dataset is None:
raise ValueError("Dataset not loaded.")
question = self.dataset.get_question(index)
question_text = self.dataset.get_question_text(question)
prompt = self.dataset.get_prompt(question)
expected = self.dataset.get_answer(question)
return question_text, prompt, expected
def add_result(
self,
task_id: str,
prompt: str,
expected: str,
response: Optional[str],
answer: Optional[str],
grader_log: Dict[str, Any],
correct: bool,
status: str,
tokens: Optional[int] = None,
tps_gen: Optional[float] = None,
t_gen_ms: Optional[float] = None,
reasoning_content: Optional[str] = None,
server_name: Optional[str] = None,
chunk_idx: int = 0,
problem_idx: int = 0,
):
with self._lock:
if "cases" not in self.task_states:
self.task_states["cases"] = {}
self.task_states["cases"][task_id] = {
"task_id": task_id,
"prompt": prompt,
"expected": expected,
"response": response,
"answer": answer,
"grader_log": grader_log,
"correct": correct,
"status": status,
"tokens": tokens,
"tps_gen": tps_gen,
"t_gen_ms": t_gen_ms,
"reasoning_content": reasoning_content,
"server_name": server_name,
"chunk_idx": chunk_idx,
"problem_idx": problem_idx,
}
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
def print_progress(self, task_state: TaskState, total_tasks: int, n_correct: int = 0):
display_answer = task_state.answer if task_state.answer else "N/A"
display_tokens = str(task_state.tokens) if task_state.tokens is not None else "N/A"
display_tps = f"{task_state.tps_gen:.1f}" if task_state.tps_gen is not None else "N/A"
display_t_gen = f"{task_state.t_gen_ms/1000:.1f}" if task_state.t_gen_ms is not None else "N/A"
display_server = task_state.server_name if task_state.server_name else "N/A"
success_ratio = n_correct / self.processed if self.processed > 0 else 0.0
first_line = task_state.question_text.split('\n')[0]
truncated_question = first_line[:43]
if len(first_line) > 43:
truncated_question += "..."
else:
truncated_question = truncated_question.ljust(43) + "..."
print(f"{self.processed:3}/{total_tasks:3} {task_state.task_id:<20} {self.dataset_type.upper()} {truncated_question:<40} {task_state.expected:<10} {display_answer:<10} {display_tokens:<6} {display_tps:<6} {display_t_gen:<8} {'✓' if task_state.correct else '✗'} [{n_correct:3}/{self.processed:3}, {success_ratio:.3f}] {display_server}")
def print_summary(self):
if self.total == 0:
print(f"\n{'='*60}")
print(f"Results: 0/0 correct (0.0%)")
print(f"{'='*60}")
else:
ci_lower, ci_upper = self.accuracy_ci()
print(f"\n{'='*60}")
print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%) [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]")
print(f"{'='*60}")
def dump(self):
with self._lock:
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
all_cases = {}
for i, task_id in tasks_to_save:
question_text, prompt, expected = self.get_case(i)
# Extract chunk_idx from task_id for pending cases
_parts = task_id.rsplit("_", 2)
_chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
if task_id in self.task_states.get("cases", {}):
all_cases[task_id] = self.task_states["cases"][task_id]
else:
all_cases[task_id] = {
"task_id": task_id,
"prompt": prompt,
"expected": expected,
"question_text": question_text,
"response": None,
"answer": None,
"grader_log": {},
"correct": False,
"status": "pending",
"tokens": None,
"tps_gen": None,
"t_gen_ms": None,
"reasoning_content": None,
"server_name": None,
"chunk_idx": _chunk_idx,
"problem_idx": i,
}
ci_lower, ci_upper = self.accuracy_ci()
data = {
"id": self.dataset_type,
"model_name": self.model_name,
"tasks": [tid for _, tid in tasks_to_save],
"task_states": {
"total": self.total,
"correct": self.correct,
"total_time": self.total_time,
"ci_lower": ci_lower,
"ci_upper": ci_upper,
"cases": all_cases,
},
"sampling_config": self.sampling_config
}
with open(self.output_file, "w") as f:
json.dump(data, f, indent=2)
self.dump_html(tasks_to_save, all_cases)
def dump_html(self, tasks_to_save: List[Tuple[int, str]], all_cases: Dict[str, Any]):
html_file = Path(str(self.output_file) + ".html")
cases = all_cases
completed = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
n_correct = sum(1 for c in completed.values() if c.get("correct", False))
n_incorrect = len(completed) - n_correct
n_pending = len(tasks_to_save) - len(completed)
accuracy = n_correct / len(completed) * 100 if completed else 0.0
ci_lower, ci_upper = wilson_interval(n_correct, len(completed)) if completed else (0.0, 1.0)
sampling_parts = []
for k, v in self.sampling_config.items():
if v is not None:
sampling_parts.append(f"{k}={v}")
sampling_str = ", ".join(sampling_parts) if sampling_parts else "default"
rows = []
for i, task_id in tasks_to_save:
case = cases.get(task_id, {})
status = case.get("status", "pending")
expected = case.get("expected", "")
answer = case.get("answer", "") if status == "ok" else ""
is_correct = case.get("correct", False) if status == "ok" else False
response = case.get("response", "") or ""
prompt = case.get("prompt", "") or ""
grader_log = case.get("grader_log", {})
if status == "ok":
status_class = "correct" if is_correct else "incorrect"
status_text = "✓" if is_correct else "✗"
elif status == "pending":
status_class = "pending"
status_text = "–"
else:
status_class = "error"
status_text = "!"
tokens = case.get("tokens")
tokens_str = str(tokens) if tokens is not None else ""
tps_gen = case.get("tps_gen")
tps_str = f"{tps_gen:.1f}" if tps_gen is not None else ""
t_gen_ms = case.get("t_gen_ms")
t_gen_str = f"{t_gen_ms/1000:.1f}" if t_gen_ms is not None else ""
reasoning_content = case.get("reasoning_content", "") or ""
server_name = case.get("server_name", "") or ""
escaped_response = self._escape_html(response)
escaped_prompt = self._escape_html(prompt)
escaped_reasoning = self._escape_html(reasoning_content)
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
escaped_server = self._escape_html(server_name)
answer_class = status_class if status == "ok" else ""
rows.append(f"""
{task_id}
{status_text}
{self._escape_html(expected)}
{self._escape_html(answer)}
{tokens_str}
{tps_str}
{t_gen_str}
{escaped_server}
Prompt
{escaped_prompt}
Response
{escaped_response}
{f'Reasoning
{escaped_reasoning}
' if escaped_reasoning else ''}
Grader
{grader_log_str}
""")
rows_html = "\n".join(rows)
# ---- per-problem summary table ----
problem_groups: Dict[int, List[Dict[str, Any]]] = {}
for _tid, _case in cases.items():
if _case.get("status") != "ok":
continue
_pidx = _case.get("problem_idx")
if _pidx is None:
_p_parts = _tid.rsplit("_", 2)
_pidx = int(_p_parts[-1]) if len(_p_parts) >= 3 else 0
problem_groups.setdefault(_pidx, []).append(_case)
summary_rows_html = ""
if problem_groups:
def _stat(v, fmt=".1f", avg_fmt=None):
if not v:
return ("–", "–", "–")
af = fmt if avg_fmt is None else avg_fmt
return (f"{min(v):{fmt}}", f"{sum(v)/len(v):{af}}", f"{max(v):{fmt}}")
summary_data = []
for pidx, g in problem_groups.items():
runs = len(g)
n_ok = sum(1 for c in g if c.get("correct", False))
toks = [c["tokens"] for c in g if c.get("tokens") is not None]
tps = [c["tps_gen"] for c in g if c.get("tps_gen") is not None]
tg = [c["t_gen_ms"] / 1000 for c in g if c.get("t_gen_ms") is not None]
summary_data.append((
pidx, runs, n_ok,
_stat(toks, "d", ".0f"),
_stat(tps),
_stat(tg),
))
summary_data.sort(key=lambda r: r[0]) # sort by problem index ascending
summary_rows_html = "\n".join(
f"""
{p:03d}
{r}
{n}/{r}
{tk[0]}
{tk[1]}
{tk[2]}
{tp[0]}
{tp[1]}
{tp[2]}
{tg[0]}
{tg[1]}
{tg[2]}
"""
for p, r, n, tk, tp, tg in summary_data
)
html_content = f"""
{self.dataset_type.upper()} Eval