Improve GBNF performance by attempting culled grammar search first (#1597)

* cull tokens with top-k 3000 first before running the grammar; fall back to the unculled set if no valid candidate is found (sketched below)

* fix errors

* fix improvement and test against concedo's GBNF

* revert non-culling changes
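
The idea, as a minimal standalone sketch: the types and the two samplers below are simplified stand-ins for the koboldcpp versions, not the real implementations, and the toy grammar is deliberately chosen so that the fallback path actually triggers.

// Minimal sketch of the cull-then-fallback pattern, assuming simplified
// stand-ins for koboldcpp's llama_token_data_array and samplers.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct llama_token_data { int id; float logit; };
struct llama_token_data_array { llama_token_data * data; size_t size; bool sorted; };

// Keep the k highest logits. Only .size shrinks; the tail entries stay in the
// buffer, which is what lets the fallback restore the pre-cull count cheaply.
static void sample_top_k(llama_token_data_array * c, size_t k) {
    if (k >= c->size) return;
    std::partial_sort(c->data, c->data + k, c->data + c->size,
        [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; });
    c->size = k;
    c->sorted = true;
}

// Toy stand-in for the grammar pass: it only accepts token ids >= 30000, so
// the top-3000 cull finds nothing and the fallback path below runs. Like the
// real sample_grammar after this commit, it compacts rejected entries out.
static void sample_grammar(llama_token_data_array * c) {
    for (size_t i = 0; i < c->size; ++i) {
        if (c->data[i].id < 30000) { c->data[i].logit = -INFINITY; }
    }
    auto last = std::remove_if(c->data, c->data + c->size,
        [](const llama_token_data & tk){ return tk.logit == -INFINITY; });
    c->size = last - c->data;
}

int main() {
    // Fake vocab of 32000 tokens where lower ids have higher logits.
    std::vector<llama_token_data> tokens;
    for (int id = 0; id < 32000; ++id) { tokens.push_back({ id, -(float)id }); }
    llama_token_data_array candidates = { tokens.data(), tokens.size(), false };

    size_t n_pre_cull = candidates.size;
    sample_top_k(&candidates, 3000);     // cheap prefilter before the costly grammar pass
    sample_grammar(&candidates);         // usually finds a valid token among the top 3000
    if (candidates.size == 0) {          // rare: grammar rejected the whole culled set
        candidates.size = n_pre_cull;    // safe: nothing was kept, so nothing was overwritten
        sample_grammar(&candidates);     // re-run the grammar over the full vocabulary
        sample_top_k(&candidates, 3000); // then re-apply the prefilter
    }
    std::printf("valid candidates: %zu\n", candidates.size); // prints 2000
}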
Reithan 2025-06-13 00:57:27 -07:00 committed by GitHub
parent 1cbe716e45
commit 5af9138ebe


@@ -1594,6 +1594,11 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
candidates->data[reject.index].logit = -INFINITY;
}
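// new in this commit: compact the rejected (-INFINITY) entries out of the
// array so candidates->size counts only grammar-valid tokens; the culled
// fallback added below keys off this size reaching zero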
auto first = candidates->data;
auto last = first + candidates->size;
last = std::remove_if(first, last,
[&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
candidates->size = last - first;
}
void sample_guidance(struct llama_context * ctx, struct llama_context * guidance_ctx, int n_vocab, float scale)
@@ -1643,16 +1648,31 @@ const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dyna
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
if (grammar != nullptr) {
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
}
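// (the if-block above is the pre-existing early grammar pass that this
// commit removes; the grammar now runs after the top-k cull below)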
//dry always first as logits cannot be resorted
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);
//prefilter to top 3k tokens for improved speed
bool use_grammar = grammar != nullptr;
size_t n_pre_cull = candidates_p.size;
sample_top_k(&candidates_p, 3000);
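// note: sample_top_k only lowers .size; the culled entries stay in the
// buffer, so the fallback below can restore n_pre_cull without a copy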
if (use_grammar) {
(debugmode == 1 && printf("\nGrammar sampling %zu candidates.\n", candidates_p.size));
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
(debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
// if top_k 3000 doesn't contain a valid candidate for this grammar, try again pre-cull
if (candidates_p.size <= 0) {
candidates_p.size = n_pre_cull;
(debugmode == 1 && printf("\nRe-sampling grammar with %zu pre-cull tokens.\n", candidates_p.size));
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
(debugmode == 1 && printf("\nGrammar returned %zu candidates.\n", candidates_p.size));
sample_top_k(&candidates_p, 3000);
}
}
if (mirostat == 1 || mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
@@ -1745,7 +1765,6 @@ static void grammar_accept_token(FileFormat file_format, int32_t n_vocab, struct
const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
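// (removed above: prev_stacks was never read afterwards, and copying
// grammar->stacks for every decoded code point was pure overhead)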
llama_grammar_accept(grammar, *it);
}
grammar->partial_utf8 = decoded.second;
@@ -3941,6 +3960,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
if (grammar != nullptr) {
(debugmode == 1 && printf("\nGrammar attempting to accept token...\n"));
grammar_accept_token(file_format, n_vocab, grammar, id);
}