From 4850b52aedceeb70bb4fe49f2d7cd1df6ee98682 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Fri, 8 Aug 2025 23:04:36 +0200
Subject: [PATCH] server-bench: external OAI servers, sqlite (#15179)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* server-bench: external OAI servers, sqlite

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret

* raise_for_status

---------

Co-authored-by: Sigbjørn Skjæret
---
 scripts/server-bench.py | 70 ++++++++++++++++++++++++++++-------------
 1 file changed, 48 insertions(+), 22 deletions(-)

diff --git a/scripts/server-bench.py b/scripts/server-bench.py
index 9326be8d5..a71602017 100755
--- a/scripts/server-bench.py
+++ b/scripts/server-bench.py
@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import random
+import sqlite3
 import subprocess
 from time import sleep, time
 from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
+    if path_server.startswith("http://") or path_server.startswith("https://"):
+        return {"process": None, "address": path_server, "fout": None}
     if os.environ.get("LLAMA_ARG_HOST") is None:
         logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
         os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
     response = session.post(
         f"{server_address}/apply-template", json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     prompt: str = json.loads(response.text)["prompt"]
 
     response = session.post(
         f"{server_address}/tokenize", json={"content": prompt, "add_special": True}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     tokens: list[str] = json.loads(response.text)["tokens"]
     return len(tokens)
 
@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     server_address: str = data["server_address"]
 
     t_submit = time()
-    if data["synthetic_prompt"]:
+    if data["external_server"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True,
+            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
+    elif data["synthetic_prompt"]:
         json_data: dict = {
             "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
             "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
         response = session.post(
             f"{server_address}/apply-template", json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
         )
-        if response.status_code != 200:
-            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        response.raise_for_status()
         prompt: str = json.loads(response.text)["prompt"]
 
         json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
         response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    response.raise_for_status()
 
+    lines = []
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=False):
         if not line.startswith(b"data: "):
             continue
+        lines.append(line)
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]
-
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
+        token_arrival_times = token_arrival_times[:-1]
 
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
+def benchmark(
+        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
+        n_predict: int, n_predict_min: int, seed_offset: int):
+    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
-    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
         logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
         os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
-    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     else:
         n_predict_min = n_predict
 
-    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
         context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
         context_total: int = context_per_slot * parallel
         os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     try:
         server = get_server(path_server, path_log)
         server_address: str = server["address"]
+        assert external_server == (server["process"] is None)
 
         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
             if seed_offset >= 0:
                 random.seed(3 * (seed_offset + 1000 * i) + 1)
             data.append({
-                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-                "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
+                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
+                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
+                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
 
         if not synthetic_prompts:
             logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         t0 = time()
         results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
-        if server is not None:
+        if server is not None and server["process"] is not None:
             server["process"].terminate()
             server["process"].wait()
         if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
-    logger.info("")
-    logger.info(
-        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
-        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
+
+    if path_db is not None:
+        con = sqlite3.connect(path_db)
+        cursor = con.cursor()
+        cursor.execute(
+            "CREATE TABLE IF NOT EXISTS server_bench"
+            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
+            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
+        cursor.execute(
+            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
+            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
+        con.commit()
 
     plt.figure()
     plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
     plt.xlim(0, 1.05e0 * np.max(prompt_n))
     plt.ylim(0, 1.05e3 * np.max(prompt_t))
+    plt.title(name or "")
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
+    plt.title(name or "")
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
         "Results are printed to console and visualized as plots (saved to current working directory). "
-        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
+        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
     parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
+    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048", help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
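
With this change, a --path_server value starting with http:// or https:// is treated as an external OpenAI-compatible server: no binary is launched, the local-only environment defaults (LLAMA_ARG_N_GPU_LAYERS, LLAMA_ARG_FLASH_ATTN, LLAMA_ARG_CTX_SIZE) are left untouched, and requests go to /v1/completions with max_tokens instead of /completion with n_predict. A minimal pre-flight check built from the same request shape as send_prompt(), not part of the patch; the address http://localhost:8080 is a placeholder, and ignore_eos is a llama.cpp extension that other OpenAI-compatible servers may silently ignore:

# Sanity-check an external OpenAI-compatible server with the same request
# shape that send_prompt() uses in external-server mode (illustrative sketch,
# not part of the patch; the server address is a placeholder).
import requests

server_address = "http://localhost:8080"  # value you would pass as --path_server
json_data = {
    "prompt": "Hello", "ignore_eos": True,
    "seed": 42, "max_tokens": 8, "stream": True}
response = requests.post(f"{server_address}/v1/completions", json=json_data, stream=True)
response.raise_for_status()  # same error handling the patch standardizes on
for line in response.iter_lines(decode_unicode=False):
    if line.startswith(b"data: "):  # SSE payload lines, as filtered by the script
        print(line[6:])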
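With --path_db, each run appends one row to the server_bench table, so results from different servers or configurations can be compared later. A minimal read-back sketch, assuming runs were recorded with --path_db results.sqlite (the file name is arbitrary) and using the column layout from the CREATE TABLE statement above; note that the runtime column stores token_t_last, i.e. the arrival time of the last generated token in seconds:

# Read back benchmark results written by server-bench.py --path_db
# (illustrative sketch, not part of the patch; "results.sqlite" is a placeholder).
import sqlite3

con = sqlite3.connect("results.sqlite")
try:
    rows = con.execute(
        "SELECT name, n_parallel, prompt_source, n_prompts, runtime "
        "FROM server_bench ORDER BY runtime;").fetchall()
    for name, n_parallel, prompt_source, n_prompts, runtime in rows:
        # runtime holds token_t_last: arrival time of the last token, in seconds
        print(f"{name or '(unnamed)'}: {n_prompts} prompts ({prompt_source}), "
              f"{n_parallel} slots, {runtime:.2f} s")
finally:
    con.close()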