diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index e833070ee..4bdd239c0 100755
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -149,6 +149,8 @@ class TaskState:
t_gen_ms: Optional[float] = None
reasoning_content: Optional[str] = None
server_name: Optional[str] = None
+ chunk_idx: int = 0
+ problem_idx: int = 0
class EvalState:
@@ -233,7 +235,9 @@ class EvalState:
tps_gen: Optional[float] = None,
t_gen_ms: Optional[float] = None,
reasoning_content: Optional[str] = None,
- server_name: Optional[str] = None
+ server_name: Optional[str] = None,
+ chunk_idx: int = 0,
+ problem_idx: int = 0,
):
with self._lock:
if "cases" not in self.task_states:
@@ -252,7 +256,9 @@ class EvalState:
"tps_gen": tps_gen,
"t_gen_ms": t_gen_ms,
"reasoning_content": reasoning_content,
- "server_name": server_name
+ "server_name": server_name,
+ "chunk_idx": chunk_idx,
+ "problem_idx": problem_idx,
}
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
@@ -289,6 +295,9 @@ class EvalState:
all_cases = {}
for i, task_id in tasks_to_save:
question_text, prompt, expected = self.get_case(i)
+ # Extract chunk_idx from task_id for pending cases
+ _parts = task_id.rsplit("_", 2)
+ _chunk_idx = int(_parts[-2]) if len(_parts) >= 3 else 0
if task_id in self.task_states.get("cases", {}):
all_cases[task_id] = self.task_states["cases"][task_id]
else:
@@ -306,7 +315,9 @@ class EvalState:
"tps_gen": None,
"t_gen_ms": None,
"reasoning_content": None,
- "server_name": None
+ "server_name": None,
+ "chunk_idx": _chunk_idx,
+ "problem_idx": i,
}
ci_lower, ci_upper = self.accuracy_ci()
@@ -382,11 +393,12 @@ class EvalState:
grader_log_str = self._escape_html(json.dumps(grader_log, indent=2))
escaped_server = self._escape_html(server_name)
+ answer_class = status_class if status == "ok" else ""
rows.append(f"""
{task_id}
{status_text}
{self._escape_html(expected)}
-
{self._escape_html(answer)}
+
{self._escape_html(answer)}
{tokens_str}
{tps_str}
{t_gen_str}
@@ -405,6 +417,53 @@ class EvalState:
rows_html = "\n".join(rows)
+ # ---- per-problem summary table ----
+ problem_groups: Dict[int, List[Dict[str, Any]]] = {}
+ for _tid, _case in cases.items():
+ if _case.get("status") != "ok":
+ continue
+ _pidx = _case.get("problem_idx")
+ if _pidx is None:
+ _p_parts = _tid.rsplit("_", 2)
+ _pidx = int(_p_parts[-1]) if len(_p_parts) >= 3 else 0
+ problem_groups.setdefault(_pidx, []).append(_case)
+
+ summary_rows_html = ""
+ if problem_groups:
+ def _stat(v, fmt=".1f", avg_fmt=None):
+ if not v:
+ return ("–", "–", "–")
+ af = fmt if avg_fmt is None else avg_fmt
+ return (f"{min(v):{fmt}}", f"{sum(v)/len(v):{af}}", f"{max(v):{fmt}}")
+
+ summary_data = []
+ for pidx, g in problem_groups.items():
+ runs = len(g)
+ n_ok = sum(1 for c in g if c.get("correct", False))
+ toks = [c["tokens"] for c in g if c.get("tokens") is not None]
+ tps = [c["tps_gen"] for c in g if c.get("tps_gen") is not None]
+ tg = [c["t_gen_ms"] / 1000 for c in g if c.get("t_gen_ms") is not None]
+ summary_data.append((
+ pidx, runs, n_ok,
+ _stat(toks, "d", ".0f"),
+ _stat(tps),
+ _stat(tg),
+ ))
+
+ summary_data.sort(key=lambda r: r[0]) # sort by problem index ascending
+
+ summary_rows_html = "\n".join(
+ f"""
+
{p:03d}
+
{r}
+
{n}/{r}
+
{tk[0]}
{tk[1]}
{tk[2]}
+
{tp[0]}
{tp[1]}
{tp[2]}
+
{tg[0]}
{tg[1]}
{tg[2]}
+
"""
+ for p, r, n, tk, tp, tg in summary_data
+ )
+
html_content = f"""
@@ -412,10 +471,10 @@ class EvalState:
{self.dataset_type.upper()} Eval