allow token bans and eos handling in continuous batching mode

This commit is contained in:
Concedo 2026-05-16 15:20:46 +08:00
parent f273fd35b9
commit 80ce8a50b3
2 changed files with 37 additions and 11 deletions

View file

@ -3417,6 +3417,7 @@ struct BatchGenerateRequest
std::string prompt;
std::string prompt_added_memory;
std::vector<std::string> stop_sequences;
std::vector<llama_logit_bias> logit_biases;
int max_context_length = 0;
int max_length = 0;
int seed = 0;
@ -3566,7 +3567,7 @@ static bool batch_inputs_eligible(const generation_inputs & inputs)
{
return false;
}
if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token)
if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f)
{
return false;
}
@ -3574,7 +3575,7 @@ static bool batch_inputs_eligible(const generation_inputs & inputs)
{
return false;
}
if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
if(inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
{
return false;
}
@ -3741,6 +3742,11 @@ static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
req.rep_pen,
req.rep_pen_slope,
req.presence_penalty));
if(req.logit_biases.size()>0)
{
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(llama_ctx_v4)));
llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(n_vocab, req.logit_biases.size(), req.logit_biases.data()));
}
if(req.top_k > 0)
{
llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
@ -3974,7 +3980,7 @@ static void batch_worker_loop()
}
const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
llama_token eos = llama_vocab_eos(vocab);
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
for(int request_id : decode_ids)
{
BatchGenerateRequest * req = batch_find_request_locked(request_id);
@ -3983,12 +3989,9 @@ static void batch_worker_loop()
continue;
}
llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
{
sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
}
req->completion_token_count++;
if(sampled == eos && !req->bypass_eos_token)
bool is_eog = std::find(eog_tokens.begin(), eog_tokens.end(), sampled) != eog_tokens.end();
if(is_eog && !req->bypass_eos_token)
{
batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
continue;
@ -4061,6 +4064,24 @@ int gpttype_batch_generate_submit(const generation_inputs inputs)
req->allow_eos_token = inputs.allow_eos_token;
req->bypass_eos_token = inputs.bypass_eos_token;
req->render_special = inputs.render_special;
req->logit_biases = {};
for(int i = 0; i < inputs.logit_biases_len; ++i)
{
int32_t t_id = inputs.logit_biases[i].token_id;
float bias = inputs.logit_biases[i].bias;
if(t_id >= 0 && t_id < n_vocab && bias!=0)
{
req->logit_biases.push_back({t_id, bias});
}
}
if(!req->allow_eos_token && !req->bypass_eos_token) //eos token bans
{
const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
for(int x = 0; x < eog_tokens.size(); ++x)
{
req->logit_biases.push_back({eog_tokens[x], -999.0f});
}
}
for(int i = 0; i < inputs.stop_sequence_len; ++i)
{
if(inputs.stop_sequence[i])

View file

@ -2292,20 +2292,25 @@ def continuous_batching_python_eligible(genparams, api_format):
if not getattr(args, "noshift", False) or getattr(args, "smartcontext", False) or getattr(args, "draftmodel", "") or getattr(args, "enableguidance", False):
return False
if genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"):
utfprint("Batching disabled due to media",0)
return False
if genparams.get("ban_eos_token", False):
return False
if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("logit_bias") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
utfprint("Batching disabled due to grammar or bans",0)
return False
if tryparsefloat(genparams.get("dry_multiplier", 0), 0) or tryparseint(genparams.get("mirostat", 0), 0) or tryparsefloat(genparams.get("xtc_probability", 0), 0) or tryparsefloat(genparams.get("nsigma", 0), 0):
utfprint("Batching disabled due to samplers set 1",0)
return False
if tryparsefloat(genparams.get("smoothing_factor", 0), 0) or tryparsefloat(genparams.get("adaptive_target", -1), -1) > 0 or genparams.get("using_openai_tools", False):
utfprint("Batching disabled due to samplers set 2",0)
return False
if tryparsefloat(genparams.get("top_a", 0), 0) or tryparsefloat(genparams.get("tfs", 1), 1) != 1 or tryparsefloat(genparams.get("dynatemp_range", 0), 0):
utfprint("Batching disabled due to samplers set 3",0)
return False
if genparams.get("sampler_order") and genparams.get("sampler_order") != [6, 0, 1, 3, 4, 2, 5]:
utfprint("Batching disabled due to sampler order",0)
return False
if genparams.get("reasoning_effort"):
utfprint("Batching disabled due to reasoning",0)
return False
return True