diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index e833070ee..4bdd239c0 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -149,6 +149,8 @@ class TaskState: t_gen_ms: Optional[float] = None reasoning_content: Optional[str] = None server_name: Optional[str] = None + chunk_idx: int = 0 + problem_idx: int = 0 class EvalState: @@ -233,7 +235,9 @@ class EvalState: tps_gen: Optional[float] = None, t_gen_ms: Optional[float] = None, reasoning_content: Optional[str] = None, - server_name: Optional[str] = None + server_name: Optional[str] = None, + chunk_idx: int = 0, + problem_idx: int = 0, ): with self._lock: if "cases" not in self.task_states: @@ -252,7 +256,9 @@ class EvalState: "tps_gen": tps_gen, "t_gen_ms": t_gen_ms, "reasoning_content": reasoning_content, - "server_name": server_name + "server_name": server_name, + "chunk_idx": chunk_idx, + "problem_idx": problem_idx, } self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False)) @@ -289,6 +295,9 @@ class EvalState: all_cases = {} for i, task_id in tasks_to_save: question_text, prompt, expected = self.get_case(i) + # Extract chunk_idx from task_id for pending cases + _parts = task_id.rsplit("_", 2) + _chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0 if task_id in self.task_states.get("cases", {}): all_cases[task_id] = self.task_states["cases"][task_id] else: @@ -306,7 +315,9 @@ class EvalState: "tps_gen": None, "t_gen_ms": None, "reasoning_content": None, - "server_name": None + "server_name": None, + "chunk_idx": _chunk_idx, + "problem_idx": i, } ci_lower, ci_upper = self.accuracy_ci() @@ -382,11 +393,12 @@ class EvalState: grader_log_str = self._escape_html(json.dumps(grader_log, indent=2)) escaped_server = self._escape_html(server_name) + answer_class = status_class if status == "ok" else "" rows.append(f""" {task_id} {status_text} {self._escape_html(expected)} - {self._escape_html(answer)} + {self._escape_html(answer)} {tokens_str} {tps_str} {t_gen_str} @@ -405,6 +417,53 @@ class EvalState: rows_html = "\n".join(rows) + # ---- per-problem summary table ---- + problem_groups: Dict[int, List[Dict[str, Any]]] = {} + for _tid, _case in cases.items(): + if _case.get("status") != "ok": + continue + _pidx = _case.get("problem_idx") + if _pidx is None: + _p_parts = _tid.rsplit("_", 2) + _pidx = int(_p_parts[-1]) if len(_p_parts) >= 3 else 0 + problem_groups.setdefault(_pidx, []).append(_case) + + summary_rows_html = "" + if problem_groups: + def _stat(v, fmt=".1f", avg_fmt=None): + if not v: + return ("–", "–", "–") + af = fmt if avg_fmt is None else avg_fmt + return (f"{min(v):{fmt}}", f"{sum(v)/len(v):{af}}", f"{max(v):{fmt}}") + + summary_data = [] + for pidx, g in problem_groups.items(): + runs = len(g) + n_ok = sum(1 for c in g if c.get("correct", False)) + toks = [c["tokens"] for c in g if c.get("tokens") is not None] + tps = [c["tps_gen"] for c in g if c.get("tps_gen") is not None] + tg = [c["t_gen_ms"] / 1000 for c in g if c.get("t_gen_ms") is not None] + summary_data.append(( + pidx, runs, n_ok, + _stat(toks, "d", ".0f"), + _stat(tps), + _stat(tg), + )) + + summary_data.sort(key=lambda r: r[0]) # sort by problem index ascending + + summary_rows_html = "\n".join( + f""" + {p:03d} + {r} + {n}/{r} + {tk[0]}{tk[1]}{tk[2]} + {tp[0]}{tp[1]}{tp[2]} + {tg[0]}{tg[1]}{tg[2]} + """ + for p, r, n, tk, tp, tg in summary_data + ) + html_content = f""" @@ -412,10 +471,10 @@ class EvalState: {self.dataset_type.upper()} Eval
- {self.dataset_type.upper()} - Model: {self.model_name or 'N/A'} - Accuracy: {accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%] - Correct: {n_correct} / {len(completed)} - Pending: {n_pending} - Time: {self.total_time:.1f}s - Sampling: {sampling_str} +
Dataset
{self.dataset_type.upper()}
+
Model
{self.model_name or 'N/A'}
+
Accuracy
{accuracy:.1f}% [{ci_lower*100:.1f}%, {ci_upper*100:.1f}%]
+
Correct
{n_correct} / {len(completed)}
+
Pending
{n_pending}
+
Time
{self.total_time:.1f}s
+
Sampling
{sampling_str}
+
+
+ + +
+
+ + + + + + + + + + + + + + + {rows_html} + +
IDGoldAnswerTokensT/sGen sServer
+
+
+ + + + + + + + + + + + + + + + + + + + + {summary_rows_html} + +
ProblemRunsCorrectTokensT/sGen s
minavgmaxminavgmaxminavgmax
- - - - - - - - - - - - - - - {rows_html} - -
IDGoldAnswerTokensT/sGen sServer
""" @@ -1062,12 +1172,19 @@ class Processor: ) -> TaskState: question_text, prompt, expected = eval_state.get_case(i) + # Extract chunk_idx from task_id: "{dataset_type}_{chunk_idx:03d}_{index:03d}" + _parts = task_id.rsplit("_", 2) + chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0 + problem_idx = i + task_state = TaskState( task_id=task_id, prompt=prompt, expected=expected, question_text=question_text, - server_name=server_config.name + server_name=server_config.name, + chunk_idx=chunk_idx, + problem_idx=problem_idx, ) try: @@ -1085,7 +1202,8 @@ class Processor: eval_state.add_result( task_id, prompt, expected, result, None, {"finish_reason": finish_reason}, False, task_state.status, - tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name, + chunk_idx, problem_idx, ) eval_state.dump() return task_state @@ -1108,7 +1226,8 @@ class Processor: eval_state.add_result( task_id, prompt, expected, result, answer, grader_log, is_correct, "ok", - tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name + tokens, tps_gen, t_gen_ms, reasoning_content, server_config.name, + chunk_idx, problem_idx, ) eval_state.dump() diff --git a/examples/llama-eval/llama-server-simulator.py b/examples/llama-eval/llama-server-simulator.py index 2f9cdc545..e64ba8933 100755 --- a/examples/llama-eval/llama-server-simulator.py +++ b/examples/llama-eval/llama-server-simulator.py @@ -65,34 +65,70 @@ def normalize_number(s: str) -> Optional[int]: return int(match.group(0)) class AimeDataset: - def __init__(self, split: str = "train"): + def __init__(self, split: str = "train", dataset_type: str = "aime"): self.split = split + self.dataset_type = dataset_type self.questions: List[Dict] = [] self._load_dataset() - def _load_dataset(self): - print(f"Loading AIME dataset (split: {self.split})...") + def _get_question_text(self, question: Dict) -> str: + """Get question text, handling different dataset field names.""" + return question.get("problem", question.get("question", "")) - cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" - if cache_path.exists(): - print(f"Using cached dataset from {cache_path}") - ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + def _load_dataset(self): + if self.dataset_type == "aime": + print(f"Loading AIME dataset (split: {self.split})...") + cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path)) + else: + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) + elif self.dataset_type == "aime2025": + print(f"Loading AIME2025 dataset...") + ds_list = [] + for config_name in ["AIME2025-I", "AIME2025-II"]: + cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "opencompass___AIME2025" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path)) + else: + ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test") + ds_list.extend(ds) + ds = ds_list else: - ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) + raise ValueError(f"Unknown dataset type: {self.dataset_type}") self.questions = list(ds) - print(f"AIME dataset loaded: {len(self.questions)} questions") + print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions") def find_question(self, request_text: str) -> Optional[Dict]: + # Strip common template prefixes to get the actual question text + # Templates include things like "Solve the following math problem step by step..." + # The actual question usually follows a blank line or after the template instruction + cleaned = request_text + # Split on double newline and take the part that looks like the problem + parts = cleaned.split('\n\n') + if len(parts) > 1: + # Find the part that's longest (likely the actual problem text) + problem_parts = [p for p in parts if len(p.strip()) > 100] + if problem_parts: + cleaned = max(problem_parts, key=lambda x: len(x)) + best_match = None best_distance = -1 best_index = -1 for i, question in enumerate(self.questions): - question_text = question["problem"] - request_lower = request_text.lower() + question_text = self._get_question_text(question) + request_lower = cleaned.lower() question_lower = question_text.lower() + # Check if question text is contained in the cleaned request + if question_lower in request_lower or request_lower in question_lower: + debug_log(f"DEBUG: Found substring match at index {i}") + return question + # Exact match if question_lower == request_lower: debug_log(f"DEBUG: Found exact match at index {i}") @@ -118,7 +154,7 @@ class AimeDataset: debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}") return best_match - debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...") + debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...") return None def get_answer(self, question: Dict) -> str: @@ -134,15 +170,16 @@ class Simulator: port: int = 8033, host: str = "localhost", success_rate: float = 0.8, - dataset_split: str = "train" + dataset_split: str = "train", + dataset_type: str = "aime" ): self.port = port self.host = host self.success_rate = success_rate - self.dataset = AimeDataset(dataset_split) + self.dataset = AimeDataset(dataset_split, dataset_type) self.eval_state = EvalState( - id="aime-2025", - tasks=["aime"], + id=dataset_type, + tasks=[dataset_type], task_states={}, sampling_config={"temperature": 0, "max_tokens": 2048} ) @@ -159,6 +196,10 @@ class Simulator: else: response_text = self._generate_wrong_answer(question) + comp_tokens = random.randint(10000, 60000) + tps_gen = random.uniform(90.0, 110.0) + t_gen_ms = comp_tokens / tps_gen * 1000 + return { "id": f"chatcmpl-{int(time.time())}", "object": "chat.completion", @@ -176,8 +217,12 @@ class Simulator: ], "usage": { "prompt_tokens": 100, - "completion_tokens": 50, - "total_tokens": 150 + "completion_tokens": comp_tokens, + "total_tokens": 100 + comp_tokens + }, + "timings": { + "predicted_ms": t_gen_ms, + "predicted_per_second": tps_gen } } @@ -218,6 +263,12 @@ class Simulator: return response class RequestHandler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == "/v1/models": + self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200) + return + self._send_json({"error": "Not found"}, 404) + def do_POST(self): if self.path != "/v1/chat/completions": self._send_json({"error": "Not found"}, 404) @@ -280,6 +331,13 @@ def main(): default=0.8, help="Success rate 0-1 (default: 0.8)" ) + parser.add_argument( + "--dataset", + type=str, + default="aime", + choices=["aime", "aime2025"], + help="Dataset type (default: aime)" + ) parser.add_argument( "--dataset-split", type=str, @@ -294,7 +352,8 @@ def main(): port=args.port, host=args.host, success_rate=args.success_rate, - dataset_split=args.dataset_split + dataset_split=args.dataset_split, + dataset_type=args.dataset ) server = HTTPServer((args.host, args.port), RequestHandler) @@ -304,7 +363,7 @@ def main(): print("\n=== llama-server-simulator ===") print(f"Server running on http://{args.host}:{args.port}") print(f"Success rate: {args.success_rate}") - print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions") + print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions") print("\nPress Ctrl+C to stop\n") try: