From f07434f4c19eeb80160c86325ec73a4575c93470 Mon Sep 17 00:00:00 2001 From: Reithan Date: Tue, 17 Jun 2025 08:04:59 -0700 Subject: [PATCH] streamline grammar sampler to speed up generation while using heavy grammar (#1606) --- gpttype_adapter.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 1e427f276..bc0437056 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1572,6 +1572,10 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar std::vector, llama_partial_utf8>> candidates_decoded; std::vector candidates_grammar; + std::vector rejects; + candidates_decoded.reserve(candidates->size); + candidates_grammar.reserve(candidates->size); + rejects.assign(candidates->size, false); for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; @@ -1579,25 +1583,24 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end(); if (found_eog) { if (!allow_eos) { - candidates->data[i].logit = -INFINITY; + rejects[i] = true; } } else if (piece.empty() || piece[0] == 0) { - candidates->data[i].logit = -INFINITY; + rejects[i] = true; } else { candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); - for (const auto & reject : rejects) { - candidates->data[reject.index].logit = -INFINITY; + for (auto reject: llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar)) { + rejects[reject.index] = true; } - + auto first = candidates->data; auto last = first + candidates->size; last = std::remove_if(first, last, - [&](const llama_token_data & tk){ return tk.logit == -INFINITY; }); + [&](const llama_token_data & tk){ return rejects[&tk - first]; }); // tk.logit == -INFINITY; }); candidates->size = last - first; }