Streamline grammar sampler to speed up generation while using heavy grammar (#1606)

This commit is contained in:
Reithan 2025-06-17 08:04:59 -07:00 committed by GitHub
parent ab29be54c4
commit f07434f4c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1572,6 +1572,10 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar; std::vector<llama_grammar_candidate> candidates_grammar;
std::vector<uint8_t> rejects;
candidates_decoded.reserve(candidates->size);
candidates_grammar.reserve(candidates->size);
rejects.assign(candidates->size, false);
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id; const llama_token id = candidates->data[i].id;
@ -1579,25 +1583,24 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end(); bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
if (found_eog) { if (found_eog) {
if (!allow_eos) { if (!allow_eos) {
candidates->data[i].logit = -INFINITY; rejects[i] = true;
} }
} else if (piece.empty() || piece[0] == 0) { } else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY; rejects[i] = true;
} else { } else {
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8)); candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
} }
} }
const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); for (auto reject: llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar)) {
for (const auto & reject : rejects) { rejects[reject.index] = true;
candidates->data[reject.index].logit = -INFINITY;
} }
auto first = candidates->data; auto first = candidates->data;
auto last = first + candidates->size; auto last = first + candidates->size;
last = std::remove_if(first, last, last = std::remove_if(first, last,
[&](const llama_token_data & tk){ return tk.logit == -INFINITY; }); [&](const llama_token_data & tk){ return rejects[&tk - first]; }); // tk.logit == -INFINITY; });
candidates->size = last - first; candidates->size = last - first;
} }