Mirror of https://github.com/LostRuins/koboldcpp.git

Commit 80ce8a50b3 (parent f273fd35b9): allow token bans and eos handling in

2 changed files with 37 additions and 11 deletions

gpttype_adapter.cpp (37 changes: 29 additions, 8 deletions)
@@ -3417,6 +3417,7 @@ struct BatchGenerateRequest
     std::string prompt;
     std::string prompt_added_memory;
     std::vector<std::string> stop_sequences;
+    std::vector<llama_logit_bias> logit_biases;
     int max_context_length = 0;
     int max_length = 0;
     int seed = 0;
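The new logit_biases field holds llama.cpp llama_logit_bias entries, i.e. (token id, bias) pairs that get added to the raw logits before sampling. A minimal sketch of how such a list can express a token ban; make_ban_list is a hypothetical helper, not part of the diff, and the -999.0f convention matches the one used later in this commit:

#include <vector>
#include "llama.h"

// Build a bias list that effectively bans every token in `banned`:
// adding a large negative value to a logit drives its probability to ~0.
static std::vector<llama_logit_bias> make_ban_list(const std::vector<llama_token> & banned)
{
    std::vector<llama_logit_bias> biases;
    for (llama_token t : banned)
    {
        biases.push_back({t, -999.0f});
    }
    return biases;
}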
@@ -3566,7 +3567,7 @@ static bool batch_inputs_eligible(const generation_inputs & inputs)
     {
         return false;
     }
-    if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f || !inputs.allow_eos_token)
+    if(inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f)
     {
         return false;
     }
@@ -3574,7 +3575,7 @@ static bool batch_inputs_eligible(const generation_inputs & inputs)
     {
         return false;
     }
-    if(inputs.logit_biases_len > 0 || inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
+    if(inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
     {
         return false;
     }
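Both eligibility gates loosen for the same reason: a request that disallows EOS or carries logit biases can now stay on the batched path, because both conditions are translated into per-request bias entries further down. Only features the batch sampler still cannot express (media, guidance, explicit token bans, DRY) keep forcing the single-stream fallback. A hypothetical condensation of the new predicate, for orientation only; the real function checks many more fields:

// Sketch of the loosened eligibility test after this commit.
static bool batch_inputs_eligible_sketch(const generation_inputs & inputs)
{
    if (inputs.images_len > 0 || inputs.audio_len > 0 || inputs.guidance_scale != 1.0f)
    {
        return false; // media and guidance remain unsupported in batch mode
    }
    if (inputs.banned_tokens_len > 0 || inputs.dry_multiplier > 0.0f)
    {
        return false; // token bans and DRY sampling still force the single-stream path
    }
    // logit biases and !allow_eos_token no longer disqualify a request:
    // both are rewritten as per-request llama_logit_bias entries.
    return true;
}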
@@ -3741,6 +3742,11 @@ static llama_sampler * batch_build_sampler(const BatchGenerateRequest & req)
         req.rep_pen,
         req.rep_pen_slope,
         req.presence_penalty));
+    if(req.logit_biases.size()>0)
+    {
+        int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(llama_get_model(llama_ctx_v4)));
+        llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(n_vocab, req.logit_biases.size(), req.logit_biases.data()));
+    }
     if(req.top_k > 0)
     {
         llama_sampler_chain_add(chain, llama_sampler_init_top_k(req.top_k));
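llama_sampler_init_logit_bias is stock llama.cpp: it returns a sampler stage that adds each bias to its token's logit on every step. Placement matters; because the stage is added to the chain before top-k and the other truncation samplers, a banned token is suppressed before any candidate filtering happens. A self-contained sketch of the same construction, where build_chain_sketch and its parameters are illustrative rather than taken from the diff:

#include <vector>
#include "llama.h"

llama_sampler * build_chain_sketch(const llama_model * model,
                                   const std::vector<llama_logit_bias> & biases,
                                   int top_k, uint32_t seed)
{
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (!biases.empty())
    {
        const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
        // Bias stage first, so a -999 bias removes a token before truncation.
        llama_sampler_chain_add(chain,
            llama_sampler_init_logit_bias(n_vocab, (int32_t)biases.size(), biases.data()));
    }
    if (top_k > 0)
    {
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(top_k));
    }
    llama_sampler_chain_add(chain, llama_sampler_init_dist(seed)); // final sampling step
    return chain;
}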
@@ -3974,7 +3980,7 @@ static void batch_worker_loop()
     }

     const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(llama_ctx_v4));
-    llama_token eos = llama_vocab_eos(vocab);
+    const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
     for(int request_id : decode_ids)
     {
         BatchGenerateRequest * req = batch_find_request_locked(request_id);
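The worker previously compared against the single llama_vocab_eos id; current models often define several end-of-generation tokens (EOS, EOT, and friends), which is what koboldcpp's GetEogIDs helper returns. A sketch of the membership test the loop below relies on; is_eog_token is an illustrative helper, not from the diff, and llama.cpp also exposes an equivalent per-token check, llama_vocab_is_eog:

#include <algorithm>
#include <vector>
#include "llama.h"

// True if `t` is any end-of-generation token. For the small EOG sets typical
// here, linear search is fine; a hash set would only pay off for large lists.
static bool is_eog_token(const std::vector<llama_token> & eog_tokens, llama_token t)
{
    return std::find(eog_tokens.begin(), eog_tokens.end(), t) != eog_tokens.end();
}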
@@ -3983,12 +3989,9 @@ static void batch_worker_loop()
             continue;
         }
         llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
-        if(!req->allow_eos_token && !req->bypass_eos_token && sampled == eos)
-        {
-            sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
-        }
         req->completion_token_count++;
-        if(sampled == eos && !req->bypass_eos_token)
+        bool is_eog = std::find(eog_tokens.begin(), eog_tokens.end(), sampled) != eog_tokens.end();
+        if(is_eog && !req->bypass_eos_token)
         {
             batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
             continue;
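The deleted branch handled a disallowed EOS by simply sampling a second time, which skews the distribution and can still return EOS on the retry. With the ban expressed as a logit bias, a disallowed end token can never be sampled in the first place, so the decode step reduces to one sample plus the stop test. Roughly, with locking and batching omitted and is_eog_token being the illustrative helper sketched above:

// Per-request decode step after this commit, simplified.
llama_token sampled = llama_sampler_sample(req->sampler, llama_ctx_v4, req->i_batch);
req->completion_token_count++;
// No re-sampling: if EOS was disallowed, it was biased to -999 upstream.
if (is_eog_token(eog_tokens, sampled) && !req->bypass_eos_token)
{
    batch_finish_request_locked(*req, stop_reason::EOS_TOKEN_HIT);
}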
@@ -4061,6 +4064,24 @@ int gpttype_batch_generate_submit(const generation_inputs inputs)
     req->allow_eos_token = inputs.allow_eos_token;
     req->bypass_eos_token = inputs.bypass_eos_token;
     req->render_special = inputs.render_special;
+    req->logit_biases = {};
+    for(int i = 0; i < inputs.logit_biases_len; ++i)
+    {
+        int32_t t_id = inputs.logit_biases[i].token_id;
+        float bias = inputs.logit_biases[i].bias;
+        if(t_id >= 0 && t_id < n_vocab && bias!=0)
+        {
+            req->logit_biases.push_back({t_id, bias});
+        }
+    }
+    if(!req->allow_eos_token && !req->bypass_eos_token) //eos token bans
+    {
+        const std::vector<llama_token> eog_tokens = GetEogIDs(file_format,n_vocab);
+        for(int x = 0; x < eog_tokens.size(); ++x)
+        {
+            req->logit_biases.push_back({eog_tokens[x], -999.0f});
+        }
+    }
     for(int i = 0; i < inputs.stop_sequence_len; ++i)
     {
         if(inputs.stop_sequence[i])
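At submit time the two bias sources are merged into a single list: user-supplied biases are range-checked and copied, and, when EOS is disallowed and not bypassed, every EOG token gets a -999.0f entry appended. Folding the EOS ban into the same list means batch_build_sampler needs only the one logit-bias stage shown earlier. A standalone sketch of that translation, where collect_biases is a hypothetical helper and the field names follow the diff:

#include <vector>
#include "llama.h"

std::vector<llama_logit_bias> collect_biases(const generation_inputs & inputs,
                                             int32_t n_vocab,
                                             const std::vector<llama_token> & eog_tokens)
{
    std::vector<llama_logit_bias> biases;
    // Copy user-supplied biases, dropping out-of-range tokens and zero biases.
    for (int i = 0; i < inputs.logit_biases_len; ++i)
    {
        int32_t t_id = inputs.logit_biases[i].token_id;
        float bias   = inputs.logit_biases[i].bias;
        if (t_id >= 0 && t_id < n_vocab && bias != 0.0f)
        {
            biases.push_back({t_id, bias});
        }
    }
    // EOS disallowed: ban every end-of-generation token via a large negative bias.
    if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
    {
        for (llama_token t : eog_tokens)
        {
            biases.push_back({t, -999.0f});
        }
    }
    return biases;
}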
koboldcpp.py (11 changes: 8 additions, 3 deletions)
@@ -2292,20 +2292,25 @@ def continuous_batching_python_eligible(genparams, api_format):
     if not getattr(args, "noshift", False) or getattr(args, "smartcontext", False) or getattr(args, "draftmodel", "") or getattr(args, "enableguidance", False):
         return False
     if genparams.get("negative_prompt") or genparams.get("images") or genparams.get("audio"):
+        utfprint("Batching disabled due to media",0)
         return False
-    if genparams.get("ban_eos_token", False):
-        return False
-    if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("logit_bias") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
+    if genparams.get("grammar") or genparams.get("grammar_retain_state") or genparams.get("banned_tokens") or genparams.get("banned_strings"):
+        utfprint("Batching disabled due to grammar or bans",0)
         return False
     if tryparsefloat(genparams.get("dry_multiplier", 0), 0) or tryparseint(genparams.get("mirostat", 0), 0) or tryparsefloat(genparams.get("xtc_probability", 0), 0) or tryparsefloat(genparams.get("nsigma", 0), 0):
+        utfprint("Batching disabled due to samplers set 1",0)
         return False
     if tryparsefloat(genparams.get("smoothing_factor", 0), 0) or tryparsefloat(genparams.get("adaptive_target", -1), -1) > 0 or genparams.get("using_openai_tools", False):
+        utfprint("Batching disabled due to samplers set 2",0)
         return False
     if tryparsefloat(genparams.get("top_a", 0), 0) or tryparsefloat(genparams.get("tfs", 1), 1) != 1 or tryparsefloat(genparams.get("dynatemp_range", 0), 0):
+        utfprint("Batching disabled due to samplers set 3",0)
         return False
     if genparams.get("sampler_order") and genparams.get("sampler_order") != [6, 0, 1, 3, 4, 2, 5]:
+        utfprint("Batching disabled due to sampler order",0)
         return False
     if genparams.get("reasoning_effort"):
+        utfprint("Batching disabled due to reasoning",0)
         return False
     return True