server : pre-calculate EOG logit biases (#14721)

ggml-ci
Georgi Gerganov 2025-07-16 14:04:12 +03:00 committed by GitHub
parent e4841d24d3
commit 6ffd4e9c44
3 changed files with 17 additions and 15 deletions

common/common.cpp

@@ -1005,13 +1005,19 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }
 
-    if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-            if (llama_vocab_is_eog(vocab, i)) {
-                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-                params.sampling.logit_bias.push_back({i, -INFINITY});
-            }
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
         }
     }
 
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
     if (params.sampling.penalty_last_n == -1) {
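The mechanism the hunk above relies on: the sampler adds each {token, bias} entry to the matching logit before softmax, so a bias of -INFINITY drives that token's probability to zero and an EOG token can never be sampled. A minimal, self-contained sketch of that mechanism, with a toy struct standing in for llama_logit_bias (this is not llama.cpp's actual sampler code):

#include <cmath>
#include <cstdio>
#include <vector>

// toy stand-in for llama_logit_bias { llama_token token; float bias; }
struct logit_bias { int token; float bias; };

// add each bias to its token's logit; -INFINITY zeroes the token's
// probability once the logits go through softmax
static void apply_logit_biases(std::vector<float> & logits, const std::vector<logit_bias> & biases) {
    for (const auto & lb : biases) {
        logits[lb.token] += lb.bias;
    }
}

int main() {
    std::vector<float> logits = { 1.0f, 2.5f, 0.3f, 4.1f }; // toy 4-token vocab
    std::vector<logit_bias> eog = { { 3, -INFINITY } };     // pretend token 3 is EOG

    apply_logit_biases(logits, eog);

    printf("logit[3] = %f\n", logits[3]); // -inf -> token 3 is never sampled
}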

common/common.h

@@ -178,6 +178,7 @@ struct common_params_sampling {
     std::set<llama_token> preserved_tokens;
 
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
     // print the parameters into a string
     std::string print() const;
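llama_logit_bias is the {token, bias} pair type from llama.h that samplers consume. As a hedged sketch (assuming llama.h's llama_sampler_init_logit_bias with its (n_vocab, n_logit_bias, data) signature), this is roughly how a finished bias vector, e.g. logit_bias after the EOG entries are appended, would reach a sampler:

#include <vector>

#include "llama.h"

// sketch: wrap a finished bias list in a logit-bias sampler
static llama_sampler * make_bias_sampler(const llama_vocab * vocab,
                                         const std::vector<llama_logit_bias> & biases) {
    return llama_sampler_init_logit_bias(
            llama_vocab_n_tokens(vocab), // n_vocab
            (int32_t) biases.size(),     // number of {token, bias} entries
            biases.data());              // applied to the logits before sampling
}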

tools/server/server.cpp

@@ -473,12 +473,9 @@ struct server_task {
         params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
         if (params.sampling.ignore_eos) {
-            for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-                if (llama_vocab_is_eog(vocab, i)) {
-                    //SRV_DBG("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(ctx, i).c_str(), -INFINITY);
-                    params.sampling.logit_bias.push_back({i, -INFINITY});
-                }
-            }
+            params.sampling.logit_bias.insert(
+                    params.sampling.logit_bias.end(),
+                    defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
         }
     }
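Clients opt in exactly as before; only the per-request server-side work changes, from a full vocab scan to a vector insert. A typical request against a locally running llama-server (default host and port assumed) that exercises this path:

curl http://localhost:8080/completion -d '{
    "prompt": "Once upon a time",
    "n_predict": 128,
    "ignore_eos": true
}'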
@@ -1906,7 +1903,6 @@ struct server_context {
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
-    bool has_eos_token  = false;
 
     int32_t n_ctx; // total context for all clients / slots
@@ -1965,7 +1961,6 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_vocab_get_add_bos(vocab);
-        has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
 
         if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
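The two hunks above drop has_eos_token because nothing reads it anymore once EOG handling goes through the pre-calculated bias list. If equivalent information is ever needed again, it can be computed on demand with the same call the removed line used; a one-line sketch:

// on-demand equivalent of the removed cached flag
static bool vocab_has_eos(const llama_vocab * vocab) {
    return llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
}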