backend sampling: support returning post-sampling probs (#22622)

* server: Never return 0.0 post-sampling probabilities

* backend sampling: support returning post-sampling probs
This commit is contained in:
Tim Neumann 2026-05-10 19:12:02 +02:00 committed by GitHub
parent 5d5d2e15d2
commit 2e97c5f96f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 80 additions and 16 deletions

View file

@ -547,6 +547,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
gsmpl->set_logits(ctx, idx);
// Check if a backend sampler has already sampled a token in which case we
// return that token id directly.
{
@ -558,17 +560,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
// TODO: simplify
gsmpl->cur.resize(1);
gsmpl->cur[0] = { id, 0.0f, 1.0f };
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
for (size_t i = 0; i < cur_p.size; ++i) {
if (cur_p.data[i].id == id) {
cur_p.selected = i;
break;
}
}
return id;
}
}
gsmpl->set_logits(ctx, idx);
// apply reasoning budget first
llama_sampler_apply(rbudget, &cur_p);

View file

@ -1317,7 +1317,7 @@ private:
return false;
}
const bool need_logits = task.params.sampling.n_probs > 0;
const bool need_pre_sample_logits = task.params.sampling.n_probs > 0 && !task.params.post_sampling_probs;
bool backend_sampling = true;
@ -1326,8 +1326,8 @@ private:
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
// TODO: getting pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_pre_sample_logits;
// TODO: tmp until backend sampling is fully implemented
if (backend_sampling) {
@ -1504,6 +1504,12 @@ private:
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < n_probs; i++) {
// Some samplers do return 0.0 probabilities, others don't.
// Filter 0.0 probailities, to ensure the behavior is consistent.
if (cur_p->data[i].p == 0.0) {
break;
}
result.probs.push_back({
cur_p->data[i].id,
common_token_to_piece(ctx, cur_p->data[i].id, special),

View file

@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"prompt": "Today was the day. Today I would finally become a",
"n_probs": 10,
"temperature": 0.0,
"temperature": 1.0,
"n_predict": 5,
"post_sampling_probs": True,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
for (i, tok) in enumerate(res.body["completion_probabilities"]):
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_probs"]) == 10
assert "top_probs" in tok and type(tok["top_probs"]) == list
for prob in tok["top_probs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
# 0.0 probability tokens should never be returned by the server
assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
assert "bytes" in prob and type(prob["bytes"]) == list
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
if i == 0:
# The prompt is vague enough that we should get at least 10 possibilities
# for the first token.
assert len(tok["top_probs"]) == 10
if len(tok["top_probs"]) < 10:
# Getting less than the requested number of probabilities should only happen
# if the ones we did get already sum to 1.0.
assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)
def test_n_probs_post_backend_sampling():
"""Verify that the same probabilities are returned with and without backend sampling."""
global server
server.backend_sampling = True
server.start()
def make_request(backend_sampling):
n_predict = 20
res = server.make_request("POST", "/completion", data={
"prompt": "The countries of Europe, in random order, are:",
"n_probs": 10,
"n_predict": n_predict,
"post_sampling_probs": True,
"seed": 4242,
"backend_sampling": backend_sampling,
})
assert res.status_code == 200
total_probs = 0
completions = res.body["completion_probabilities"]
assert len(completions) == n_predict
for tok in completions:
# Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the
# data.
tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0]
total_probs += len(tok["top_probs"])
# Verify that we got at least two top probs on average, to ensure the effectiveness of the test.
assert total_probs >= 2 * n_predict
return completions
def verify_token(a, b):
assert a["id"] == b["id"]
assert a["token"] == b["token"]
assert a["bytes"] == b["bytes"]
assert a["prob"] == pytest.approx(b["prob"], abs=0.01)
for (a, b) in zip(make_request(True), make_request(False)):
verify_token(a, b)
assert len(a["top_probs"]) == len(b["top_probs"])
for (aa, bb) in zip(a["top_probs"], b["top_probs"]):
verify_token(aa, bb)
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):

View file

@ -108,6 +108,7 @@ class ServerProcess:
no_cache_idle_slots: bool = False
log_path: str | None = None
webui_mcp_proxy: bool = False
backend_sampling: bool = False
gcp_compat: bool = False
# session variables
@ -252,6 +253,8 @@ class ServerProcess:
server_args.append("--no-cache-idle-slots")
if self.webui_mcp_proxy:
server_args.append("--webui-mcp-proxy")
if self.backend_sampling:
server_args.append("--backend_sampling")
if self.gcp_compat:
env["AIP_MODE"] = "PREDICTION"