server-bench: external OAI servers, sqlite (#15179)

* server-bench: external OAI servers, sqlite

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* raise_for_status

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Author: Johannes Gäßler
Date: 2025-08-08 23:04:36 +02:00 (committed by GitHub)
Parent: cd6983d56d
Commit: 4850b52aed
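
For context, a hedged sketch of how the new pieces fit together when the benchmark is pointed at an already-running, OpenAI-compatible server and asked to record results. The module import, server URL, database path, and run name below are placeholders; the script is normally driven through its command-line flags (--path_server, --path_db, --name) rather than imported.

```python
# Sketch only: assumes scripts/server-bench.py has been made importable under the
# hypothetical module name server_bench; all values are placeholders.
from server_bench import benchmark

benchmark(
    path_server="http://127.0.0.1:8080",  # http(s) URL => external OAI-compatible server, no local llama-server is launched
    path_log=None,                         # the server log is only relevant when a local server is launched
    path_db="server-bench.sqlite",         # results are appended to the server_bench table in this file
    name="baseline",                       # label used for plots and database rows
    prompt_source="rng-1024-2048",         # default prompt source (see the argparse section below)
    n_prompts=100,                         # placeholder benchmark parameters
    n_predict=2048,
    n_predict_min=1024,
    seed_offset=0,
)
```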

scripts/server-bench.py

@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import random
+import sqlite3
 import subprocess
 from time import sleep, time
 from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
+    if path_server.startswith("http://") or path_server.startswith("https://"):
+        return {"process": None, "address": path_server, "fout": None}
     if os.environ.get("LLAMA_ARG_HOST") is None:
         logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
         os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
         f"{server_address}/apply-template",
         json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     prompt: str = json.loads(response.text)["prompt"]
     response = session.post(
         f"{server_address}/tokenize",
         json={"content": prompt, "add_special": True}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     tokens: list[str] = json.loads(response.text)["tokens"]
     return len(tokens)
@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     server_address: str = data["server_address"]
     t_submit = time()
-    if data["synthetic_prompt"]:
+    if data["external_server"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True,
+            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
+    elif data["synthetic_prompt"]:
         json_data: dict = {
             "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
             "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
             f"{server_address}/apply-template",
             json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
         )
-        if response.status_code != 200:
-            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        response.raise_for_status()
         prompt: str = json.loads(response.text)["prompt"]
         json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
         response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    response.raise_for_status()
+    lines = []
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=False):
         if not line.startswith(b"data: "):
             continue
+        lines.append(line)
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
+        token_arrival_times = token_arrival_times[:-1]
     return (t_submit, token_arrival_times)
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
+def benchmark(
+        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
+        n_predict: int, n_predict_min: int, seed_offset: int):
+    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
-    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
         logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
         os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
-    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     else:
         n_predict_min = n_predict
-    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
         context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
         context_total: int = context_per_slot * parallel
         os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     try:
         server = get_server(path_server, path_log)
         server_address: str = server["address"]
+        assert external_server == (server["process"] is None)
         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
             if seed_offset >= 0:
                 random.seed(3 * (seed_offset + 1000 * i) + 1)
             data.append({
-                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-                "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
+                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
+                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
+                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
         if not synthetic_prompts:
             logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         t0 = time()
         results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
-        if server is not None:
+        if server is not None and server["process"] is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
     logger.info("")
-    logger.info(
-        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
-        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
+    if path_db is not None:
+        con = sqlite3.connect(path_db)
+        cursor = con.cursor()
+        cursor.execute(
+            "CREATE TABLE IF NOT EXISTS server_bench"
+            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
+            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
+        cursor.execute(
+            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
+            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
+        con.commit()
     plt.figure()
     plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
     plt.xlim(0, 1.05e0 * np.max(prompt_n))
     plt.ylim(0, 1.05e3 * np.max(prompt_t))
+    plt.title(name or "")
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
+    plt.title(name or "")
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
         "Results are printed to console and visualized as plots (saved to current working directory). "
-        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
+        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
     parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
+    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "