backend sampling: support returning post-sampling probs (#22622)
* server: Never return 0.0 post-sampling probabilities
* backend sampling: support returning post-sampling probs
This commit is contained in: parent 5d5d2e15d2, commit 2e97c5f96f
4 changed files with 80 additions and 16 deletions
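"Post-sampling" probabilities are the probabilities left after the sampler chain has run (temperature, top-k, and so on), as opposed to the raw softmax of the logits; the commit makes the server report them from the backend sampling path and drop entries whose probability is exactly 0.0. A minimal sketch of the difference, assuming a simple top-k-then-renormalize chain (an illustration only, not the llama.cpp sampler implementation):

    import math

    # Pre-sampling probabilities: plain softmax over the logits.
    def softmax(logits):
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        s = sum(exps)
        return [e / s for e in exps]

    # Post-sampling probabilities under a hypothetical top-k chain: everything
    # outside the top k gets probability 0.0 and the survivors are renormalized.
    # The server-side change in this commit means those 0.0 entries are never
    # reported back to the client.
    def post_sampling_probs_topk(logits, k):
        pre = softmax(logits)
        order = sorted(range(len(pre)), key=lambda i: pre[i], reverse=True)
        keep = set(order[:k])
        kept_mass = sum(pre[i] for i in keep)
        return [pre[i] / kept_mass if i in keep else 0.0 for i in range(len(pre))]

    logits = [2.0, 1.0, 0.5, -1.0]
    print([round(p, 3) for p in softmax(logits)])
    print([round(p, 3) for p in post_sampling_probs_topk(logits, k=2)])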
@@ -547,6 +547,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    gsmpl->set_logits(ctx, idx);

    // Check if a backend sampler has already sampled a token in which case we
    // return that token id directly.
    {
@@ -558,17 +560,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
            GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
            GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");

            // TODO: simplify
            gsmpl->cur.resize(1);
            gsmpl->cur[0] = { id, 0.0f, 1.0f };
            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
            for (size_t i = 0; i < cur_p.size; ++i) {
                if (cur_p.data[i].id == id) {
                    cur_p.selected = i;
                    break;
                }
            }

            return id;
        }
    }

    gsmpl->set_logits(ctx, idx);

    // apply reasoning budget first
    llama_sampler_apply(rbudget, &cur_p);
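When a backend sampler has already chosen a token, the CPU-side sampler only needs to mark which entry of the candidate array corresponds to that token id, so downstream code can still read per-token probabilities. A rough Python analogue of the selection loop above (the real code works on a llama_token_data_array; the names here are illustrative):

    # Rough Python analogue of marking the backend-sampled token in the candidate
    # array. `candidates` stands in for cur_p: a list of (token_id, probability)
    # pairs that has already been filled in for this position.
    def mark_selected(candidates, sampled_id):
        for i, (token_id, _prob) in enumerate(candidates):
            if token_id == sampled_id:
                return i  # index of the backend-sampled token
        raise ValueError("sampled token id not present in the candidate array")

    candidates = [(42, 0.7), (7, 0.2), (99, 0.1)]
    selected = mark_selected(candidates, sampled_id=7)
    print(selected, candidates[selected])  # -> 1 (7, 0.2)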
@@ -1317,7 +1317,7 @@ private:
        return false;
    }

    const bool need_logits = task.params.sampling.n_probs > 0;
    const bool need_pre_sample_logits = task.params.sampling.n_probs > 0 && !task.params.post_sampling_probs;

    bool backend_sampling = true;
@@ -1326,8 +1326,8 @@ private:
    // TODO: speculative decoding requires multiple samples per batch - not supported yet
    backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);

    // TODO: getting post/pre sampling logits is not yet supported with backend sampling
    backend_sampling &= !need_logits;
    // TODO: getting pre sampling logits is not yet supported with backend sampling
    backend_sampling &= !need_pre_sample_logits;

    // TODO: tmp until backend sampling is fully implemented
    if (backend_sampling) {
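The effect of the change above is that the server only falls back to CPU-side sampling when the request still requires it: post-sampling probabilities can now come from the backend, while pre-sampling probabilities (n_probs without post_sampling_probs) cannot. A small sketch of just that boolean, with names mirroring the server variables; it is an illustration of the condition, not the server implementation:

    # Backend sampling stays enabled unless the request needs pre-sampling
    # probabilities (n_probs > 0 without post_sampling_probs).
    def allows_backend_sampling(n_probs: int, post_sampling_probs: bool) -> bool:
        need_pre_sample_logits = n_probs > 0 and not post_sampling_probs
        return not need_pre_sample_logits

    for n_probs, post in [(0, False), (10, True), (10, False)]:
        print(n_probs, post, allows_backend_sampling(n_probs, post))
    # 0  False True   (no probabilities requested)
    # 10 True  True   (post-sampling probabilities are now supported)
    # 10 False False  (pre-sampling probabilities still need CPU sampling)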
@@ -1504,6 +1504,12 @@ private:
    // set probability for top n_probs tokens
    result.probs.reserve(n_probs);
    for (size_t i = 0; i < n_probs; i++) {
        // Some samplers do return 0.0 probabilities, others don't.
        // Filter 0.0 probabilities, to ensure the behavior is consistent.
        if (cur_p->data[i].p == 0.0) {
            break;
        }

        result.probs.push_back({
            cur_p->data[i].id,
            common_token_to_piece(ctx, cur_p->data[i].id, special),
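The break in the loop above only filters all 0.0 entries if the candidates are already sorted by descending probability; that ordering is assumed here. In Python terms, the same rule looks like this (field names are illustrative):

    # Keep at most n_probs candidates, stopping at the first 0.0 probability.
    # Assumes `candidates` is sorted by descending probability, so everything
    # after the first 0.0 entry is also 0.0.
    def filter_top_probs(candidates, n_probs):
        kept = []
        for cand in candidates[:n_probs]:
            if cand["p"] == 0.0:
                break
            kept.append(cand)
        return kept

    print(filter_top_probs(
        [{"id": 1, "p": 0.8}, {"id": 2, "p": 0.2}, {"id": 3, "p": 0.0}], n_probs=10))
    # -> [{'id': 1, 'p': 0.8}, {'id': 2, 'p': 0.2}]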
@@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "prompt": "Today was the day. Today I would finally become a",
        "n_probs": 10,
        "temperature": 0.0,
        "temperature": 1.0,
        "n_predict": 5,
        "post_sampling_probs": True,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
    for (i, tok) in enumerate(res.body["completion_probabilities"]):
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_probs"]) == 10
        assert "top_probs" in tok and type(tok["top_probs"]) == list

        for prob in tok["top_probs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
            # 0.0 probability tokens should never be returned by the server
            assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
            assert "bytes" in prob and type(prob["bytes"]) == list
        # because the test model usually outputs tokens with either 100% or 0% probability, we need to check all the top_probs
        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])

        if i == 0:
            # The prompt is vague enough that we should get at least 10 possibilities
            # for the first token.
            assert len(tok["top_probs"]) == 10

        if len(tok["top_probs"]) < 10:
            # Getting less than the requested number of probabilities should only happen
            # if the ones we did get already sum to 1.0.
            assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)

def test_n_probs_post_backend_sampling():
    """Verify that the same probabilities are returned with and without backend sampling."""
    global server
    server.backend_sampling = True
    server.start()

    def make_request(backend_sampling):
        n_predict = 20

        res = server.make_request("POST", "/completion", data={
            "prompt": "The countries of Europe, in random order, are:",
            "n_probs": 10,
            "n_predict": n_predict,
            "post_sampling_probs": True,
            "seed": 4242,
            "backend_sampling": backend_sampling,
        })
        assert res.status_code == 200

        total_probs = 0
        completions = res.body["completion_probabilities"]
        assert len(completions) == n_predict
        for tok in completions:
            # Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the
            # data.
            tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0]
            total_probs += len(tok["top_probs"])
        # Verify that we got at least two top probs on average, to ensure the effectiveness of the test.
        assert total_probs >= 2 * n_predict
        return completions

    def verify_token(a, b):
        assert a["id"] == b["id"]
        assert a["token"] == b["token"]
        assert a["bytes"] == b["bytes"]
        assert a["prob"] == pytest.approx(b["prob"], abs=0.01)

    for (a, b) in zip(make_request(True), make_request(False)):
        verify_token(a, b)
        assert len(a["top_probs"]) == len(b["top_probs"])

        for (aa, bb) in zip(a["top_probs"], b["top_probs"]):
            verify_token(aa, bb)

@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):
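For reference, the fields these tests assert on describe the shape of each entry in completion_probabilities. A hypothetical request/response pair with invented values, matching only the structure checked above; the actual token ids, strings and probabilities depend on the model being served:

    request = {
        "prompt": "Today was the day. Today I would finally become a",
        "n_probs": 10,
        "n_predict": 5,
        "post_sampling_probs": True,
    }
    response_entry = {                       # one element of completion_probabilities
        "id": 12345,                         # sampled token id (> 0), value invented here
        "token": " knight",                  # decoded piece
        "prob": 0.42,                        # post-sampling probability, 0.0 < p <= 1.0
        "bytes": [32, 107, 110, 105, 103, 104, 116],
        "top_probs": [                       # at most n_probs entries, no 0.0 probabilities
            {"id": 12345, "token": " knight", "prob": 0.42,
             "bytes": [32, 107, 110, 105, 103, 104, 116]},
            # ...
        ],
    }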
@@ -108,6 +108,7 @@ class ServerProcess:
    no_cache_idle_slots: bool = False
    log_path: str | None = None
    webui_mcp_proxy: bool = False
    backend_sampling: bool = False
    gcp_compat: bool = False

    # session variables
@@ -252,6 +253,8 @@ class ServerProcess:
            server_args.append("--no-cache-idle-slots")
        if self.webui_mcp_proxy:
            server_args.append("--webui-mcp-proxy")
        if self.backend_sampling:
            server_args.append("--backend-sampling")
        if self.gcp_compat:
            env["AIP_MODE"] = "PREDICTION"
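With the new ServerProcess field, a test enables backend sampling with a single assignment before start(), as the backend sampling test above does. A minimal usage sketch; ServerPreset and the fixture layout are assumptions about the existing server test harness:

    from utils import ServerPreset  # test-harness import, assumed

    server = ServerPreset.tinyllama2()
    server.backend_sampling = True   # appends --backend-sampling to the server args
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "Hello",
        "n_probs": 5,
        "post_sampling_probs": True,
        "backend_sampling": True,
    })
    assert res.status_code == 200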