mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-05-22 19:47:49 +00:00
# Conflicts: # .devops/cann.Dockerfile # .devops/cpu.Dockerfile # .devops/cuda.Dockerfile # .devops/intel.Dockerfile # .devops/llama-cli-cann.Dockerfile # .devops/musa.Dockerfile # .devops/openvino.Dockerfile # .devops/rocm.Dockerfile # .devops/s390x.Dockerfile # .devops/vulkan.Dockerfile # .github/ISSUE_TEMPLATE/011-bug-results.yml # .github/ISSUE_TEMPLATE/019-bug-misc.yml # .github/workflows/build-and-test-snapdragon.yml # .github/workflows/docker.yml # .github/workflows/server-self-hosted.yml # .github/workflows/ui-ci.yml # .pi/gg/SYSTEM.md # README.md # common/arg.cpp # docs/backend/SYCL.md # docs/backend/snapdragon/CMakeUserPresets.json # docs/backend/snapdragon/README.md # docs/speculative.md # examples/save-load-state/save-load-state.cpp # ggml/src/ggml-hexagon/ggml-hexagon.cpp # ggml/src/ggml-hexagon/htp/CMakeLists.txt # ggml/src/ggml-hexagon/htp/htp-ctx.h # ggml/src/ggml-hexagon/htp/htp-ops.h # ggml/src/ggml-hexagon/htp/main.c # ggml/src/ggml-hexagon/htp/rope-ops.c # ggml/src/ggml-hexagon/htp/unary-ops.c # ggml/src/ggml-opencl/CMakeLists.txt # ggml/src/ggml-opencl/ggml-opencl.cpp # ggml/src/ggml-opencl/kernels/cvt.cl # ggml/src/ggml-sycl/ggml-sycl.cpp # ggml/src/ggml-webgpu/ggml-webgpu.cpp # ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl # tools/cli/README.md # tools/server/README.md
376 lines
13 KiB
Python
376 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
import json
import os
import random
import re
import sys
import threading
import time
from collections import Counter
from dataclasses import dataclass
from http.server import HTTPServer, BaseHTTPRequestHandler
from pathlib import Path
from typing import Dict, List, Optional

import datasets
|
|
|
|
# Per-user HuggingFace datasets cache (~/.cache/huggingface/datasets): create it
# if missing and publish its location through HF_DATASETS_CACHE.
# NOTE(review): this assignment runs after `import datasets`; confirm the
# library still honours the variable when it is set post-import.
cache_dir = Path.home().joinpath(".cache", "huggingface", "datasets")
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
|
|
|
|
def dice(s1: str, s2: str) -> float:
    """Calculate the Dice coefficient of two strings based on bigram overlap.

    Each string is broken into overlapping two-character windows; the score is
    ``2 * shared / (count1 + count2)``, where repeated bigrams are counted with
    multiplicity.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        A value in ``[0.0, 1.0]``; ``1.0`` means the bigram multisets are
        identical. When neither string is long enough to form a bigram the
        strings are compared directly: ``1.0`` if equal, ``0.0`` otherwise.
    """
    pairs1 = Counter(s1[i : i + 2] for i in range(len(s1) - 1))
    pairs2 = Counter(s2[i : i + 2] for i in range(len(s2) - 1))

    total = sum(pairs1.values()) + sum(pairs2.values())
    if total == 0:
        # No bigrams on either side (both strings have at most one character).
        # Reporting a perfect score here would make "a" and "b" look identical,
        # so fall back to plain equality.
        return 1.0 if s1 == s2 else 0.0

    # Counter intersection keeps min(count1, count2) for every shared bigram.
    shared = sum((pairs1 & pairs2).values())
    return 2 * shared / total
|
|
|
|
def debug_log(message: str):
    """Log a debug message to stderr and append it to /tmp/simulator-debug.log.

    The file copy is best effort: if the log file cannot be opened or written
    (for example because /tmp does not exist on this platform) the message is
    still printed to stderr and the failure is ignored, so that debug logging
    can never break request handling.

    Args:
        message: Text to log; a trailing newline is added in the file.
    """
    print(message, file=sys.stderr)
    try:
        # Explicit UTF-8 so prompts containing non-ASCII characters cannot fail
        # with UnicodeEncodeError under a non-UTF-8 locale.
        with open("/tmp/simulator-debug.log", "a", encoding="utf-8") as log_file:
            log_file.write(message + "\n")
    except OSError:
        # Deliberately ignored: the stderr copy above is the primary channel.
        pass
|
|
|
|
# Module-wide Simulator instance. It stays None until main() constructs one;
# RequestHandler.do_POST reads it to answer chat-completion requests.
simulator: Optional["Simulator"] = None
|
|
|
|
@dataclass
class EvalState:
    """Bookkeeping record for one simulated evaluation run."""

    # Identifier of the evaluation.
    id: str
    # Names of the tasks that belong to this evaluation.
    tasks: List[str]
    # Per-task result records, keyed by task id.
    task_states: Dict[str, Dict]
    # Sampling parameters associated with the run (free-form mapping).
    sampling_config: Dict
|
|
|
|
def normalize_number(s: str) -> Optional[int]:
    """Parse the unsigned integer at the start of ``s``.

    Leading whitespace is skipped so that answers such as ``" 42"`` or
    ``"\\n042"`` are still recognised; anything after the first run of digits
    is ignored (``"12abc"`` -> ``12``).

    Args:
        s: Text expected to begin with a number.

    Returns:
        The parsed integer, or ``None`` when ``s`` does not start with digits.
    """
    match = re.match(r"\s*(\d+)", s)
    if match is None:
        return None
    return int(match.group(1))
|
|
|
|
class AimeDataset:
    """AIME questions loaded through HuggingFace ``datasets`` plus prompt lookup.

    ``dataset_type`` selects the source: ``"aime"`` loads
    ``AI-MO/aimo-validation-aime`` using ``split``; ``"aime2025"`` loads the
    ``test`` split of both ``opencompass/AIME2025`` configs. Any other value
    raises ``ValueError`` during construction. The loaded rows are kept in
    ``self.questions`` as a list of dicts.
    """

    def __init__(self, split: str = "train", dataset_type: str = "aime"):
        self.split = split
        self.dataset_type = dataset_type
        self.questions: List[Dict] = []
        self._load_dataset()

    def _get_question_text(self, question: Dict) -> str:
        """Get question text, handling different dataset field names.

        Looks at ``"problem"`` first, then ``"question"``; returns ``""`` when
        neither key is present.
        """
        return question.get("problem", question.get("question", ""))

    @staticmethod
    def _load_hf_dataset(cache_path: Path, *args, **kwargs):
        """Call ``datasets.load_dataset``, adding ``cache_dir`` when ``cache_path`` exists."""
        if cache_path.exists():
            print(f"Using cached dataset from {cache_path}")
            # NOTE(review): cache_dir is pointed at the dataset's own
            # sub-folder rather than the cache root -- confirm this is what
            # datasets.load_dataset expects.
            kwargs["cache_dir"] = str(cache_path)
        return datasets.load_dataset(*args, **kwargs)

    def _load_dataset(self):
        """Populate ``self.questions`` for the configured ``dataset_type``.

        Raises:
            ValueError: If ``dataset_type`` is not ``"aime"`` or ``"aime2025"``.
        """
        hf_cache = Path.home() / ".cache" / "huggingface" / "datasets"

        if self.dataset_type == "aime":
            print(f"Loading AIME dataset (split: {self.split})...")
            rows = self._load_hf_dataset(
                hf_cache / "AI-MO___aimo-validation-aime" / "default" / "0.0.0",
                "AI-MO/aimo-validation-aime",
                split=self.split,
            )
        elif self.dataset_type == "aime2025":
            print("Loading AIME2025 dataset...")
            cache_path = hf_cache / "opencompass___AIME2025" / "default" / "0.0.0"
            rows = []
            # The 2025 set is published as two configs; concatenate them.
            for config_name in ["AIME2025-I", "AIME2025-II"]:
                rows.extend(
                    self._load_hf_dataset(
                        cache_path,
                        "opencompass/AIME2025",
                        config_name,
                        split="test",
                    )
                )
        else:
            raise ValueError(f"Unknown dataset type: {self.dataset_type}")

        self.questions = list(rows)
        print(f"{self.dataset_type} dataset loaded: {len(self.questions)} questions")

    def find_question(self, request_text: str) -> Optional[Dict]:
        """Find the dataset question that a prompt is asking about.

        The prompt is first reduced to its likely problem statement: when it
        contains blank-line separated parts, the longest part of more than 100
        characters is used. Matching is case-insensitive and tried in order:

        1. substring containment in either direction,
        2. equality with the question text after removing ``$...$`` spans,
        3. best bigram Dice score above 0.3, considering only questions that
           are at most twice as long as the request.

        Args:
            request_text: Raw prompt text from the request.

        Returns:
            The matching question dict, or ``None`` when nothing matches or the
            request contains no text.
        """
        # Strip common template prefixes ("Solve the following math problem
        # step by step...") by keeping the longest substantial paragraph.
        cleaned = request_text
        parts = cleaned.split('\n\n')
        if len(parts) > 1:
            problem_parts = [p for p in parts if len(p.strip()) > 100]
            if problem_parts:
                cleaned = max(problem_parts, key=len)

        # Normalise once, outside the loop. Surrounding whitespace would only
        # get in the way of the substring checks below.
        request_lower = cleaned.strip().lower()
        if not request_lower:
            # An empty string is a substring of everything, so without this
            # guard a blank prompt would "match" the first question.
            return None

        best_match = None
        best_distance = -1
        best_index = -1

        for i, question in enumerate(self.questions):
            question_text = self._get_question_text(question)
            question_lower = question_text.strip().lower()
            if not question_lower:
                # Same reasoning as above: a question without text would be a
                # substring of every request.
                continue

            # Containment in either direction (this also covers exact equality).
            if question_lower in request_lower or request_lower in question_lower:
                debug_log(f"DEBUG: Found substring match at index {i}")
                return question

            # Remove LaTeX formatting for more flexible matching
            question_no_latex = re.sub(r'\$[^$]+\$', '', question_text)
            if question_no_latex.lower() == request_lower:
                debug_log(f"DEBUG: Found match (no LaTeX) at index {i}")
                return question

            # Dice coefficient for partial matches; only considered when the
            # request is at least 50% of the question length.
            if len(request_lower) >= len(question_lower) * 0.5:
                distance = dice(question_lower, request_lower)
                if distance > best_distance:
                    best_distance = distance
                    best_match = question
                    best_index = i

        if best_match is not None and best_distance > 0.3:  # Threshold for partial match
            debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
            return best_match

        debug_log(f"DEBUG: No matching question found for cleaned: {cleaned[:100]}...")
        return None

    def get_answer(self, question: Dict) -> str:
        """Return the question's ``"answer"`` field as a string.

        String answers are passed through ``normalize_number``; when that
        yields a number its decimal form is returned, otherwise the original
        string is returned unchanged. Non-string answers are converted with
        ``str``.
        """
        answer = question["answer"]
        if isinstance(answer, str):
            normalized = normalize_number(answer)
            return str(normalized) if normalized is not None else answer
        return str(answer)
|
|
|
|
class Simulator:
    """Fake chat-completion backend that answers AIME prompts from the dataset.

    For every request the prompt is matched against the loaded questions and
    an OpenAI-style ``chat.completion`` payload is produced whose content is
    the expected answer with probability ``success_rate`` and a deliberately
    wrong answer otherwise.

    Args:
        port: Port number, stored on the instance (this class opens no sockets).
        host: Host name, stored on the instance.
        success_rate: Probability that a response carries the correct answer.
        dataset_split: Split passed to ``AimeDataset``.
        dataset_type: Dataset selector passed to ``AimeDataset``; also used as
            the evaluation id and its single task name.
    """

    def __init__(
        self,
        port: int = 8033,
        host: str = "localhost",
        success_rate: float = 0.8,
        dataset_split: str = "train",
        dataset_type: str = "aime"
    ):
        self.port = port
        self.host = host
        self.success_rate = success_rate
        self.dataset = AimeDataset(dataset_split, dataset_type)
        self.eval_state = EvalState(
            id=dataset_type,
            tasks=[dataset_type],
            task_states={},
            sampling_config={"temperature": 0, "max_tokens": 2048}
        )

    def _generate_response(
        self,
        question: Dict,
        should_be_correct: bool
    ) -> Dict:
        """Build a ``chat.completion`` payload for ``question``.

        Args:
            question: Dataset row to answer.
            should_be_correct: When true the content is the expected answer,
                otherwise the output of ``_generate_wrong_answer``.

        Returns:
            A dict with ``choices``, ``usage`` and ``timings``. Completion token
            count and generation speed are random, so only the message content
            is deterministic.
        """
        if should_be_correct:
            response_text = self.dataset.get_answer(question)
        else:
            response_text = self._generate_wrong_answer(question)

        # Synthetic usage/timing figures: 10k-60k tokens at 90-110 tokens/s.
        comp_tokens = random.randint(10000, 60000)
        tps_gen = random.uniform(90.0, 110.0)
        t_gen_ms = comp_tokens / tps_gen * 1000

        # Sample the clock once so "id" and "created" always agree.
        created = int(time.time())

        return {
            "id": f"chatcmpl-{created}",
            "object": "chat.completion",
            "created": created,
            "model": "llama",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_text
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": comp_tokens,
                "total_tokens": 100 + comp_tokens
            },
            "timings": {
                "predicted_ms": t_gen_ms,
                "predicted_per_second": tps_gen
            }
        }

    def _generate_wrong_answer(self, question: Dict) -> str:
        """Return an answer guaranteed to differ from the expected one.

        Purely numeric answers are incremented by one; anything else gets a
        ``" (wrong)"`` suffix.
        """
        expected_answer = self.dataset.get_answer(question)
        if expected_answer.isdigit():
            return str(int(expected_answer) + 1)
        return expected_answer + " (wrong)"

    def _process_request(self, request_data: Dict) -> Dict:
        """Answer one chat-completion request body.

        The prompt is taken from the most recent message whose role is
        ``"user"``; if no message declares that role the first message is used.

        Args:
            request_data: Decoded JSON request body.

        Returns:
            A ``chat.completion`` payload, or ``{"error": ...}`` when the
            request has no messages or no dataset question matches the prompt.
        """
        messages = request_data.get("messages", [])
        if not messages:
            return {"error": "No messages in request"}

        # A leading system message must not be mistaken for the question.
        user_messages = [
            m for m in messages
            if isinstance(m, dict) and m.get("role") == "user"
        ]
        prompt_message = user_messages[-1] if user_messages else messages[0]
        # "content" may be present but null; treat that as an empty prompt.
        request_text = prompt_message.get("content") or ""
        debug_log(f"DEBUG: Received request with content: {request_text[:150]}...")

        question = self.dataset.find_question(request_text)
        if not question:
            debug_log("DEBUG: find_question returned None")
            return {"error": "No matching question found"}

        should_be_correct = random.random() < self.success_rate
        response = self._generate_response(question, should_be_correct)

        # Record the latest outcome under the evaluation's own id so the key
        # agrees with EvalState.id / tasks for every dataset type.
        task_id = self.eval_state.id
        self.eval_state.task_states[task_id] = {
            "correct": should_be_correct,
            "expected": self.dataset.get_answer(question),
            "predicted": response["choices"][0]["message"]["content"]
        }

        return response
|
|
|
|
class RequestHandler(BaseHTTPRequestHandler):
    """HTTP front end exposing the two endpoints the simulator supports.

    * ``GET /v1/models`` -- static model list.
    * ``POST /v1/chat/completions`` -- forwarded to the module-level
      ``simulator``.

    Every response body is JSON; any other path yields 404.
    """

    def do_GET(self):
        """Serve the model list, or 404 for any other path."""
        if self.path == "/v1/models":
            self._send_json({"data": [{"id": "llama", "object": "model"}]}, 200)
            return
        self._send_json({"error": "Not found"}, 404)

    def do_POST(self):
        """Handle a chat-completion request.

        Status codes:
            404 -- path is not ``/v1/chat/completions``.
            400 -- body is missing, is not a JSON object, cannot be decoded,
                   has an unusable Content-Length, or the simulator reports an
                   error for it (for example no matching question).
            500 -- the simulator is not initialized or raised unexpectedly.
            200 -- a completion payload was produced.
        """
        if self.path != "/v1/chat/completions":
            self._send_json({"error": "Not found"}, 404)
            return

        # Parsing is kept separate from processing so that client mistakes map
        # to 400 while simulator failures map to 500.
        try:
            content_length = int(self.headers.get("Content-Length", 0))
            if content_length < 0:
                # read(-1) would wait for end-of-stream on the socket.
                raise ValueError("negative Content-Length")
            body = self.rfile.read(content_length)
            request_data = json.loads(body) if body else None
        except ValueError:
            # Covers json.JSONDecodeError, undecodable bytes and a malformed
            # Content-Length header (all ValueError subclasses).
            self._send_json({"error": "Invalid JSON"}, 400)
            return

        if not request_data or not isinstance(request_data, dict):
            self._send_json({"error": "Invalid JSON"}, 400)
            return

        if simulator is None:
            self._send_json({"error": "Simulator not initialized"}, 500)
            return

        try:
            response = simulator._process_request(request_data)
        except Exception as e:
            print(f"Error processing request: {e}", file=sys.stderr)
            self._send_json({"error": str(e)}, 500)
            return

        # The simulator signals failure with an {"error": ...} payload; do not
        # present that to clients as a successful completion.
        status = 400 if isinstance(response, dict) and "error" in response else 200
        self._send_json(response, status)

    def _send_json(self, data: dict, status: int = 200):
        """Serialize ``data`` as UTF-8 JSON and send it with ``status``."""
        body = json.dumps(data).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        """Suppress the default per-request logging."""
        pass
|
|
|
|
|
|
def main():
    """Parse command-line options, start the simulator server and block until Ctrl+C.

    The dataset is loaded first (building the module-level ``simulator``), then
    an ``HTTPServer`` is started on a daemon thread. On ``KeyboardInterrupt``
    the server is shut down, and the listening socket is always closed before
    returning.
    """
    parser = argparse.ArgumentParser(
        description="llama-server simulator for testing eval scripts"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8033,
        help="Server port (default: 8033)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server host (default: localhost)"
    )
    parser.add_argument(
        "--success-rate",
        type=float,
        default=0.8,
        help="Success rate 0-1 (default: 0.8)"
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="aime",
        choices=["aime", "aime2025"],
        help="Dataset type (default: aime)"
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="train",
        help="AIME dataset split to use (default: train)"
    )

    args = parser.parse_args()

    # RequestHandler looks the simulator up through this module-level name.
    global simulator
    simulator = Simulator(
        port=args.port,
        host=args.host,
        success_rate=args.success_rate,
        dataset_split=args.dataset_split,
        dataset_type=args.dataset
    )

    server = HTTPServer((args.host, args.port), RequestHandler)
    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    print("\n=== llama-server-simulator ===")
    print(f"Server running on http://{args.host}:{args.port}")
    print(f"Success rate: {args.success_rate}")
    print(f"{args.dataset} dataset loaded: {len(simulator.dataset.questions)} questions")
    print("\nPress Ctrl+C to stop\n")

    try:
        server_thread.join()
    except KeyboardInterrupt:
        print("\nShutting down...")
        server.shutdown()
    finally:
        # shutdown() only stops the serve loop; server_close() releases the
        # listening socket.
        server.server_close()
|
|
|
|
# Start the simulator only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|