koboldcpp/examples/llama-eval/llama-server-simulator.py
Concedo 7d987af23a Merge branch 'upstream' into concedo_experimental
# Conflicts:
#	.devops/cann.Dockerfile
#	.devops/cpu.Dockerfile
#	.devops/cuda.Dockerfile
#	.devops/intel.Dockerfile
#	.devops/llama-cli-cann.Dockerfile
#	.devops/musa.Dockerfile
#	.devops/openvino.Dockerfile
#	.devops/rocm.Dockerfile
#	.devops/s390x.Dockerfile
#	.devops/vulkan.Dockerfile
#	.github/ISSUE_TEMPLATE/011-bug-results.yml
#	.github/ISSUE_TEMPLATE/019-bug-misc.yml
#	.github/workflows/build-and-test-snapdragon.yml
#	.github/workflows/docker.yml
#	.github/workflows/server-self-hosted.yml
#	.github/workflows/ui-ci.yml
#	.pi/gg/SYSTEM.md
#	README.md
#	common/arg.cpp
#	docs/backend/SYCL.md
#	docs/backend/snapdragon/CMakeUserPresets.json
#	docs/backend/snapdragon/README.md
#	docs/speculative.md
#	examples/save-load-state/save-load-state.cpp
#	ggml/src/ggml-hexagon/ggml-hexagon.cpp
#	ggml/src/ggml-hexagon/htp/CMakeLists.txt
#	ggml/src/ggml-hexagon/htp/htp-ctx.h
#	ggml/src/ggml-hexagon/htp/htp-ops.h
#	ggml/src/ggml-hexagon/htp/main.c
#	ggml/src/ggml-hexagon/htp/rope-ops.c
#	ggml/src/ggml-hexagon/htp/unary-ops.c
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-opencl/kernels/cvt.cl
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	ggml/src/ggml-webgpu/ggml-webgpu.cpp
#	ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl
#	tools/cli/README.md
#	tools/server/README.md
2026-05-20 18:48:34 +08:00

376 lines
13 KiB
Python

#!/usr/bin/env python3
import argparse
import json
import random
import re
import time
import sys
import os
import threading
from http.server import HTTPServer, BaseHTTPRequestHandler
from typing import Dict, List, Optional
from dataclasses import dataclass
from pathlib import Path
import datasets
# Use the per-user HuggingFace datasets cache (~/.cache/huggingface/datasets),
# creating it on first run, and export its location for the `datasets` library.
# NOTE(review): `datasets` is already imported above this point; confirm the
# library still picks up HF_DATASETS_CACHE when it is set after import.
cache_dir = Path.home().joinpath(".cache", "huggingface", "datasets")
cache_dir.mkdir(exist_ok=True, parents=True)
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
def dice(s1: str, s2: str) -> float:
    """Return the Dice similarity of two strings, computed over character bigrams.

    The score is ``2 * shared / (n1 + n2)`` where ``n1`` and ``n2`` are the
    bigram counts of each string and ``shared`` counts bigrams common to both,
    respecting multiplicity. When neither string yields a bigram (both are
    empty or a single character) the result is 1.0.
    """
    from collections import Counter

    # Two empty strings are considered identical.
    if not (s1 or s2):
        return 1.0

    pairs_a = [s1[k:k + 2] for k in range(len(s1) - 1)]
    pairs_b = [s2[k:k + 2] for k in range(len(s2) - 1)]
    total = len(pairs_a) + len(pairs_b)
    if total == 0:
        return 1.0

    # Counter intersection keeps the minimum count of every shared bigram.
    shared = Counter(pairs_a) & Counter(pairs_b)
    return 2 * sum(shared.values()) / total
def debug_log(message: str):
    """Print a debug message to stderr and append it to the debug log file.

    Args:
        message: Text to record; a newline is appended in the log file.

    The log file lives at ``/tmp/simulator-debug.log`` and is opened in append
    mode on every call.
    """
    print(message, file=sys.stderr)
    # Explicit UTF-8: messages echo request/question text, which may contain
    # non-ASCII characters that the locale's default encoding cannot represent.
    with open("/tmp/simulator-debug.log", "a", encoding="utf-8") as f:
        f.write(message + "\n")
# Process-wide Simulator instance; assigned in main() and read by
# RequestHandler.do_POST. Stays None until main() has constructed it.
simulator: Optional["Simulator"] = None
@dataclass
class EvalState:
    """Bookkeeping record describing one simulated evaluation run."""

    # Identifier of the run; Simulator sets it to the dataset type.
    id: str

    # Names of the tasks covered by this run.
    tasks: List[str]

    # Per-task result dictionaries, keyed by task name.
    task_states: Dict[str, Dict]

    # Sampling parameters recorded for the run (e.g. temperature, max_tokens).
    sampling_config: Dict
def normalize_number(s: str) -> Optional[int]:
    """Parse the run of digits at the very start of ``s`` as an integer.

    Returns ``None`` when ``s`` does not begin with a digit. Anything after
    the leading digits is ignored, and leading zeros are dropped by the
    integer conversion.
    """
    leading_digits = re.match(r"\d+", s)
    return int(leading_digits.group(0)) if leading_digits else None
class AimeDataset:
    """AIME question set loaded via HuggingFace ``datasets`` with fuzzy prompt lookup.

    ``dataset_type`` selects the source: ``"aime"`` loads
    ``AI-MO/aimo-validation-aime`` for the requested split, ``"aime2025"``
    loads the ``test`` split of both ``opencompass/AIME2025`` configs and
    concatenates them. The questions are loaded during construction.
    """

    def __init__(self, split: str = "train", dataset_type: str = "aime"):
        """Remember the selection and load the questions immediately.

        Args:
            split: Dataset split to load (only used for ``"aime"``).
            dataset_type: ``"aime"`` or ``"aime2025"``.

        Raises:
            ValueError: If ``dataset_type`` is not one of the supported values.
        """
        self.split = split
        self.dataset_type = dataset_type
        self.questions: List[Dict] = []
        self._load_dataset()

    def _get_question_text(self, question: Dict) -> str:
        """Get question text, handling different dataset field names.

        Prefers the ``"problem"`` field, then ``"question"``; returns an empty
        string when neither is present.
        """
        return question.get("problem", question.get("question", ""))

    def _load_dataset(self):
        """Populate ``self.questions`` from the dataset selected by ``dataset_type``.

        Raises:
            ValueError: If ``dataset_type`` is unknown.
        """
        hf_cache_root = Path.home() / ".cache" / "huggingface" / "datasets"

        if self.dataset_type == "aime":
            print(f"Loading AIME dataset (split: {self.split})...")
            cache_path = hf_cache_root / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
            if cache_path.exists():
                print(f"Using cached dataset from {cache_path}")
                ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
            else:
                ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
        elif self.dataset_type == "aime2025":
            print("Loading AIME2025 dataset...")
            cache_path = hf_cache_root / "opencompass___AIME2025" / "default" / "0.0.0"
            ds_list = []
            # The 2025 set is published as two configs; load both and concatenate.
            for config_name in ["AIME2025-I", "AIME2025-II"]:
                if cache_path.exists():
                    print(f"Using cached dataset from {cache_path}")
                    ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test", cache_dir=str(cache_path))
                else:
                    ds = datasets.load_dataset("opencompass/AIME2025", config_name, split="test")
                ds_list.extend(ds)
            ds = ds_list
        else:
            raise ValueError(f"Unknown dataset type: {self.dataset_type}")

        self.questions = list(ds)
        print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions")

    def find_question(self, request_text: str) -> Optional[Dict]:
        """Return the dataset question that best matches a prompt, or ``None``.

        The prompt is first reduced to its longest paragraph of more than 100
        characters (when it contains blank-line separated paragraphs), which
        drops template instructions around the problem statement. Questions
        are then tried in dataset order, all comparisons case-insensitive:

        1. exact equality,
        2. containment in either direction,
        3. equality after removing ``$...$`` spans from the question,
        4. otherwise the highest bigram Dice score, accepted only above 0.3
           and only for questions at most twice as long as the prompt.

        A blank prompt never matches, and questions without text are skipped;
        without these guards the empty string would count as a substring of
        anything and the first row would always be returned.

        Args:
            request_text: Raw prompt text from the request.

        Returns:
            The matching question row, or ``None`` if nothing matched.
        """
        # Strip common template prefixes to get the actual question text.
        # Templates include things like "Solve the following math problem step by step..."
        # and the actual question usually follows a blank line.
        cleaned = request_text
        parts = cleaned.split('\n\n')
        if len(parts) > 1:
            # The longest substantial paragraph is most likely the problem itself.
            problem_parts = [p for p in parts if len(p.strip()) > 100]
            if problem_parts:
                cleaned = max(problem_parts, key=len)

        # A blank prompt carries no information; the caller reports the miss.
        if not cleaned.strip():
            return None

        request_lower = cleaned.lower()
        best_match = None
        best_score = -1
        best_index = -1

        for i, question in enumerate(self.questions):
            question_text = self._get_question_text(question)
            question_lower = question_text.lower()

            # Rows without text would otherwise match every request.
            if not question_lower.strip():
                continue

            if question_lower == request_lower:
                debug_log(f"DEBUG: Found exact match at index {i}")
                return question

            # Question embedded in the prompt, or prompt is a fragment of the question.
            if question_lower in request_lower or request_lower in question_lower:
                debug_log(f"DEBUG: Found substring match at index {i}")
                return question

            # Remove LaTeX formatting for more flexible matching.
            question_no_latex = re.sub(r'\$[^$]+\$', '', question_text)
            if question_no_latex.lower() == request_lower:
                debug_log(f"DEBUG: Found match (no LaTeX) at index {i}")
                return question

            # Dice similarity for partial matches; only consider the question
            # if the request is at least 50% of its length.
            if len(request_lower) >= len(question_lower) * 0.5:
                score = dice(question_lower, request_lower)
                if score > best_score:
                    best_score = score
                    best_match = question
                    best_index = i

        if best_match is not None and best_score > 0.3:  # Threshold for partial match
            debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_score:.3f}")
            return best_match

        debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...")
        return None

    def get_answer(self, question: Dict) -> str:
        """Return the question's answer as a string.

        String answers that start with digits are reduced to that integer
        (so ``"033"`` becomes ``"33"``); other strings are returned unchanged
        and non-string answers are converted with ``str``.
        """
        answer = question["answer"]
        if isinstance(answer, str):
            normalized = normalize_number(answer)
            return str(normalized) if normalized is not None else answer
        return str(answer)
class Simulator:
    """Fake chat-completion backend that answers AIME prompts at a chosen success rate.

    Each request is matched against the loaded dataset; with probability
    ``success_rate`` the reply contains the expected answer, otherwise a
    deliberately wrong one. The outcome of the most recent request is kept in
    ``eval_state.task_states``.
    """

    def __init__(
        self,
        port: int = 8033,
        host: str = "localhost",
        success_rate: float = 0.8,
        dataset_split: str = "train",
        dataset_type: str = "aime"
    ):
        """Load the dataset and initialise the evaluation state.

        Args:
            port: Port the server is expected to listen on (stored only).
            host: Host the server is expected to bind to (stored only).
            success_rate: Probability that a reply carries the correct answer.
            dataset_split: Split passed to the dataset loader.
            dataset_type: Dataset selector passed to the dataset loader.
        """
        self.port = port
        self.host = host
        self.success_rate = success_rate
        self.dataset = AimeDataset(dataset_split, dataset_type)
        self.eval_state = EvalState(
            id=dataset_type,
            tasks=[dataset_type],
            task_states={},
            sampling_config={"temperature": 0, "max_tokens": 2048}
        )

    def _generate_response(
        self,
        question: Dict,
        should_be_correct: bool
    ) -> Dict:
        """Build a chat-completion style reply for ``question``.

        Args:
            question: Dataset row the request was matched to.
            should_be_correct: Whether the reply should contain the expected answer.

        Returns:
            A dict with ``choices``, ``usage`` and ``timings`` sections. Token
            count and generation speed are randomised; the reported duration
            is derived from them.
        """
        expected_answer = self.dataset.get_answer(question)
        if should_be_correct:
            response_text = expected_answer
        else:
            response_text = self._generate_wrong_answer(question)

        # Fabricated generation statistics.
        comp_tokens = random.randint(10000, 60000)
        tps_gen = random.uniform(90.0, 110.0)
        t_gen_ms = comp_tokens / tps_gen * 1000

        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "llama",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_text
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": comp_tokens,
                "total_tokens": 100 + comp_tokens
            },
            "timings": {
                "predicted_ms": t_gen_ms,
                "predicted_per_second": tps_gen
            }
        }

    def _generate_wrong_answer(self, question: Dict) -> str:
        """Return an answer guaranteed to differ from the expected one.

        Purely numeric answers are incremented by one; anything else gets a
        ``" (wrong)"`` suffix.
        """
        expected_answer = self.dataset.get_answer(question)
        if expected_answer.isdigit():
            return str(int(expected_answer) + 1)
        return expected_answer + " (wrong)"

    def _process_request(self, request_data: Dict) -> Dict:
        """Answer one chat-completion request.

        Only the first entry of ``messages`` is inspected. Its ``content``
        must be a non-blank string; a missing, null, blank or non-string
        content is rejected up front instead of being passed to the matcher.

        Args:
            request_data: Parsed JSON body of the request.

        Returns:
            A chat-completion dict, or ``{"error": ...}`` when the request has
            no usable message or no question matches.
        """
        messages = request_data.get("messages", [])
        if not messages:
            return {"error": "No messages in request"}

        first_message = messages[0] if isinstance(messages, list) else None
        request_text = first_message.get("content") if isinstance(first_message, dict) else None
        if not isinstance(request_text, str) or not request_text.strip():
            return {"error": "No message content in request"}

        debug_log(f"DEBUG: Received request with content: {request_text[:150]}...")
        question = self.dataset.find_question(request_text)
        if not question:
            debug_log("DEBUG: find_question returned None")
            return {"error": "No matching question found"}

        should_be_correct = random.random() < self.success_rate
        response = self._generate_response(question, should_be_correct)

        # Record under the task this simulator was created for (the dataset
        # type), so the key always agrees with eval_state.tasks.
        task_id = self.eval_state.id
        self.eval_state.task_states[task_id] = {
            "correct": should_be_correct,
            "expected": self.dataset.get_answer(question),
            "predicted": response["choices"][0]["message"]["content"]
        }
        return response
class RequestHandler(BaseHTTPRequestHandler):
    """HTTP handler exposing ``GET /v1/models`` and ``POST /v1/chat/completions``.

    Every response is JSON. Chat requests are delegated to the module-level
    ``simulator``; whatever dict it returns is sent with status 200.
    """

    def do_GET(self):
        """Serve the static model list; any other path yields 404."""
        if self.path == "/v1/models":
            self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200)
            return
        self._send_json({"error": "Not found"}, 404)

    def do_POST(self):
        """Handle a chat-completion request.

        Responds with 404 for other paths, 400 for a malformed or negative
        ``Content-Length`` and for a body that is empty, not valid JSON or
        not a non-empty JSON object, and 500 when the simulator is missing or
        processing raises.
        """
        if self.path != "/v1/chat/completions":
            self._send_json({"error": "Not found"}, 404)
            return
        try:
            try:
                content_length = int(self.headers.get("Content-Length", 0))
            except ValueError:
                content_length = -1
            # A negative length would make read() block until the peer closes.
            if content_length < 0:
                self._send_json({"error": "Invalid Content-Length"}, 400)
                return

            body = self.rfile.read(content_length)
            try:
                request_data = json.loads(body) if body else None
            except ValueError:
                # JSONDecodeError and UnicodeDecodeError (undecodable bytes)
                # are both ValueError subclasses.
                request_data = None

            # The simulator expects a JSON object; lists, strings and numbers
            # are client errors, not server faults.
            if not request_data or not isinstance(request_data, dict):
                self._send_json({"error": "Invalid JSON"}, 400)
                return

            if simulator is None:
                self._send_json({"error": "Simulator not initialized"}, 500)
                return

            response = simulator._process_request(request_data)
            self._send_json(response, 200)
        except Exception as e:
            # Last-resort boundary: report the failure to the client.
            print(f"Error processing request: {e}")
            self._send_json({"error": str(e)}, 500)

    def _send_json(self, data: dict, status: int = 200):
        """Serialise ``data`` as UTF-8 JSON and send it with the given status."""
        body = json.dumps(data).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        """Suppress the default per-request logging."""
        pass
def main():
    """Parse command-line options, create the global simulator and serve requests.

    The HTTP server runs in a daemon thread while the main thread waits on
    it. Ctrl+C stops the serve loop, and the listening socket is released on
    the way out.
    """
    global simulator

    parser = argparse.ArgumentParser(
        description="llama-server simulator for testing eval scripts"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8033,
        help="Server port (default: 8033)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server host (default: localhost)"
    )
    parser.add_argument(
        "--success-rate",
        type=float,
        default=0.8,
        help="Success rate 0-1 (default: 0.8)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aime",
        choices=["aime", "aime2025"],
        help="Dataset type (default: aime)"
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="train",
        help="AIME dataset split to use (default: train)"
    )
    args = parser.parse_args()

    # Build the simulator before the server starts so handlers never see None.
    simulator = Simulator(
        port=args.port,
        host=args.host,
        success_rate=args.success_rate,
        dataset_split=args.dataset_split,
        dataset_type=args.dataset
    )

    server = HTTPServer((args.host, args.port), RequestHandler)
    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    print("\n=== llama-server-simulator ===")
    print(f"Server running on http://{args.host}:{args.port}")
    print(f"Success rate: {args.success_rate}")
    print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions")
    print("\nPress Ctrl+C to stop\n")

    try:
        server_thread.join()
    except KeyboardInterrupt:
        print("\nShutting down...")
        server.shutdown()
    finally:
        # shutdown() only stops the serve loop; server_close() releases the socket.
        server.server_close()
# Start the simulator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()