koboldcpp/examples/llama-eval/llama-server-simulator.py
Georgi Gerganov d2e179a477
llama-eval : add per-task summary stats (#23151)
* llama-eval : add per-problem summary table to HTML reports

- Add chunk_idx and problem_idx to TaskState and saved case dicts
- Group completed cases by problem_idx in dump_html()
- Render per-problem summary table before individual task table
  - Columns: Problem (zero-padded), Runs, Correct (n/r),
    Tokens (min/avg/max), T/s (min/avg/max), Gen s (min/avg/max)
  - Sorted by problem index, monospace font, right-aligned numbers
  - Colspan headers for grouped stats, auto width
- Simulator: add /v1/models endpoint, timings in response,
  template-aware question matching, --dataset arg (aime/aime2025)

Assisted-by: llama.cpp:local pi

* llama-eval : add tabs for Detailed and Summary tables, apply monospace font globally

- Wrap Detailed and Summary tables in switchable tabs (Detailed active by default)
- Remove summary-section wrapper, use tab labels instead
- Apply monospace font to all tables and the top bar

Assisted-by: llama.cpp:local pi

* llama-eval : redesign top bar as CSS grid label/value pairs

- Replace flat span list with 4-column grid layout (2 pairs per row)
- Labels in muted color (#888), values in dark (#222)
- Bold dataset name and model name
- Removed media query, always uses 4 columns

Assisted-by: llama.cpp:local pi

* llama-eval : use realistic token counts and throughput in simulator

- comp_tokens: [30, 80] → [10000, 60000]
- tps_gen: derived → uniform [90.0, 110.0]
- t_gen_ms: now computed from tokens/tps

Assisted-by: llama.cpp:local pi

* llama-eval : color Answer column green/red based on correctness

Use the same .correct/.incorrect CSS classes on the Answer column
to make correct answers green and incorrect answers red.

Assisted-by: llama.cpp:local pi

* llama-eval : fix pyright errors from max(..., key=len) type inference

Use key=lambda x: len(x) instead of key=len so the type checker
infers the return type as str instead of Sized, fixing:
  - unresolved-attribute: Object of type Sized has no attribute lower
  - not-subscriptable: Cannot subscript object of type Sized

Assisted-by: llama.cpp:local pi
2026-05-19 09:46:05 +03:00

376 lines
13 KiB
Python
Executable file

#!/usr/bin/env python3
import argparse
import json
import random
import re
import time
import sys
import os
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
from typing import Dict, List, Optional
from dataclasses import dataclass
from pathlib import Path
import datasets
# Set cache directory for HuggingFace datasets
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
def dice(s1: str, s2: str) -> float:
"""Calculate Dice coefficient between two strings based on bigram overlap."""
if not s1 and not s2:
return 1.0
def _bigrams(s: str):
return [s[i : i + 2] for i in range(len(s) - 1)]
bigrams1 = _bigrams(s1)
bigrams2 = _bigrams(s2)
if not bigrams1 and not bigrams2:
return 1.0
from collections import Counter
freq1 = Counter(bigrams1)
freq2 = Counter(bigrams2)
intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1)
dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2))
return dice_coeff
def debug_log(message: str):
"""Log debug messages to both stdout and a file"""
print(message, file=sys.stderr)
with open("/tmp/simulator-debug.log", "a") as f:
f.write(message + "\n")
simulator: Optional["Simulator"] = None
@dataclass
class EvalState:
id: str
tasks: List[str]
task_states: Dict[str, Dict]
sampling_config: Dict
def normalize_number(s: str) -> Optional[int]:
match = re.match(r"\d+", s) # match digits from the start
if not match:
return None
return int(match.group(0))
class AimeDataset:
def __init__(self, split: str = "train", dataset_type: str = "aime"):
self.split = split
self.dataset_type = dataset_type
self.questions: List[Dict] = []
self._load_dataset()
def _get_question_text(self, question: Dict) -> str:
"""Get question text, handling different dataset field names."""
return question.get("problem", question.get("question", ""))
def _load_dataset(self):
if self.dataset_type == "aime":
print(f"Loading AIME dataset (split: {self.split})...")
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
else:
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
elif self.dataset_type == "aime2025":
print(f"Loading AIME2025 dataset...")
ds_list = []
for config_name in ["AIME2025-I", "AIME2025-II"]:
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "opencompass___AIME2025" / "default" / "0.0.0"
if cache_path.exists():
print(f"Using cached dataset from {cache_path}")
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path))
else:
ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test")
ds_list.extend(ds)
ds = ds_list
else:
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
self.questions = list(ds)
print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions")
def find_question(self, request_text: str) -> Optional[Dict]:
# Strip common template prefixes to get the actual question text
# Templates include things like "Solve the following math problem step by step..."
# The actual question usually follows a blank line or after the template instruction
cleaned = request_text
# Split on double newline and take the part that looks like the problem
parts = cleaned.split('\n\n')
if len(parts) > 1:
# Find the part that's longest (likely the actual problem text)
problem_parts = [p for p in parts if len(p.strip()) > 100]
if problem_parts:
cleaned = max(problem_parts, key=lambda x: len(x))
best_match = None
best_distance = -1
best_index = -1
for i, question in enumerate(self.questions):
question_text = self._get_question_text(question)
request_lower = cleaned.lower()
question_lower = question_text.lower()
# Check if question text is contained in the cleaned request
if question_lower in request_lower or request_lower in question_lower:
debug_log(f"DEBUG: Found substring match at index {i}")
return question
# Exact match
if question_lower == request_lower:
debug_log(f"DEBUG: Found exact match at index {i}")
return question
# Remove LaTeX formatting for more flexible matching
question_no_latex = re.sub(r'\$[^$]+\$', '', question_text)
if question_no_latex.lower() == request_lower:
debug_log(f"DEBUG: Found match (no LaTeX) at index {i}")
return question
# Calculate Dice coefficient for partial matches
# Only consider if request is at least 50% of question length
if len(request_lower) >= len(question_lower) * 0.5:
distance = dice(question_lower, request_lower)
if distance > best_distance:
best_distance = distance
best_match = question
best_index = i
if best_match and best_distance > 0.3: # Threshold for partial match
debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
return best_match
debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...")
return None
def get_answer(self, question: Dict) -> str:
answer = question["answer"]
if isinstance(answer, str):
normalized = normalize_number(answer)
return str(normalized) if normalized is not None else answer
return str(answer)
class Simulator:
def __init__(
self,
port: int = 8033,
host: str = "localhost",
success_rate: float = 0.8,
dataset_split: str = "train",
dataset_type: str = "aime"
):
self.port = port
self.host = host
self.success_rate = success_rate
self.dataset = AimeDataset(dataset_split, dataset_type)
self.eval_state = EvalState(
id=dataset_type,
tasks=[dataset_type],
task_states={},
sampling_config={"temperature": 0, "max_tokens": 2048}
)
def _generate_response(
self,
question: Dict,
should_be_correct: bool
) -> Dict:
expected_answer = self.dataset.get_answer(question)
if should_be_correct:
response_text = expected_answer
else:
response_text = self._generate_wrong_answer(question)
comp_tokens = random.randint(10000, 60000)
tps_gen = random.uniform(90.0, 110.0)
t_gen_ms = comp_tokens / tps_gen * 1000
return {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": "llama",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 100,
"completion_tokens": comp_tokens,
"total_tokens": 100 + comp_tokens
},
"timings": {
"predicted_ms": t_gen_ms,
"predicted_per_second": tps_gen
}
}
def _generate_wrong_answer(self, question: Dict) -> str:
expected_answer = self.dataset.get_answer(question)
if expected_answer.isdigit():
wrong_answer = str(int(expected_answer) + 1)
else:
wrong_answer = expected_answer + " (wrong)"
return wrong_answer
def _process_request(self, request_data: Dict) -> Dict:
messages = request_data.get("messages", [])
if not messages:
return {"error": "No messages in request"}
request_text = messages[0].get("content", "")
debug_log(f"DEBUG: Received request with content: {request_text[:150]}...")
question = self.dataset.find_question(request_text)
if not question:
debug_log(f"DEBUG: find_question returned None")
return {"error": "No matching question found"}
should_be_correct = random.random() < self.success_rate
response = self._generate_response(question, should_be_correct)
task_id = "aime"
self.eval_state.task_states[task_id] = {
"correct": should_be_correct,
"expected": self.dataset.get_answer(question),
"predicted": response["choices"][0]["message"]["content"]
}
return response
class RequestHandler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == "/v1/models":
self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200)
return
self._send_json({"error": "Not found"}, 404)
def do_POST(self):
if self.path != "/v1/chat/completions":
self._send_json({"error": "Not found"}, 404)
return
try:
content_length = int(self.headers.get("Content-Length", 0))
body = self.rfile.read(content_length)
request_data = json.loads(body) if body else None
if not request_data:
self._send_json({"error": "Invalid JSON"}, 400)
return
if simulator is None:
self._send_json({"error": "Simulator not initialized"}, 500)
return
response = simulator._process_request(request_data)
self._send_json(response, 200)
except json.JSONDecodeError:
self._send_json({"error": "Invalid JSON"}, 400)
except Exception as e:
print(f"Error processing request: {e}")
self._send_json({"error": str(e)}, 500)
def _send_json(self, data: dict, status: int = 200):
body = json.dumps(data).encode("utf-8")
self.send_response(status)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(body)))
self.end_headers()
self.wfile.write(body)
def log_message(self, format, *args):
# Suppress default request logging
pass
def main():
parser = argparse.ArgumentParser(
description="llama-server simulator for testing eval scripts"
)
parser.add_argument(
"--port",
type=int,
default=8033,
help="Server port (default: 8033)"
)
parser.add_argument(
"--host",
type=str,
default="localhost",
help="Server host (default: localhost)"
)
parser.add_argument(
"--success-rate",
type=float,
default=0.8,
help="Success rate 0-1 (default: 0.8)"
)
parser.add_argument(
"--dataset",
type=str,
default="aime",
choices=["aime", "aime2025"],
help="Dataset type (default: aime)"
)
parser.add_argument(
"--dataset-split",
type=str,
default="train",
help="AIME dataset split to use (default: train)"
)
args = parser.parse_args()
global simulator
simulator = Simulator(
port=args.port,
host=args.host,
success_rate=args.success_rate,
dataset_split=args.dataset_split,
dataset_type=args.dataset
)
server = HTTPServer((args.host, args.port), RequestHandler)
server_thread = threading.Thread(target=server.serve_forever, daemon=True)
server_thread.start()
print("\n=== llama-server-simulator ===")
print(f"Server running on http://{args.host}:{args.port}")
print(f"Success rate: {args.success_rate}")
print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions")
print("\nPress Ctrl+C to stop\n")
try:
server_thread.join()
except KeyboardInterrupt:
print("\nShutting down...")
server.shutdown()
if __name__ == "__main__":
main()