streamline grammar sampler to speed up generation while using heavy grammar (#1606)

Author: Reithan (committed by GitHub)
Date: 2025-06-17 08:04:59 -07:00
Commit: f07434f4c1 (parent: ab29be54c4)

@@ -1572,6 +1572,10 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     std::vector<llama_grammar_candidate> candidates_grammar;
+    std::vector<uint8_t> rejects;
+    candidates_decoded.reserve(candidates->size);
+    candidates_grammar.reserve(candidates->size);
+    rejects.assign(candidates->size, false);
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
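As a side note on the first hunk, here is a hedged, self-contained sketch (not koboldcpp code; n_candidates, decoded, and rejects are illustrative names) of what the added reserve()/assign() calls buy: reserve() sets capacity up front so the per-token push_back calls never reallocate, while assign() both sizes the flag vector and clears every flag before the loop runs.

#include <cstdint>
#include <vector>

int main()
{
    std::size_t n_candidates = 1024;        // illustrative stand-in for candidates->size

    std::vector<int>          decoded;      // stands in for candidates_decoded / candidates_grammar
    std::vector<std::uint8_t> rejects;      // one flag per candidate

    decoded.reserve(n_candidates);          // capacity up front: no reallocation while pushing
    rejects.assign(n_candidates, 0);        // size n_candidates, every flag cleared

    for (std::size_t i = 0; i < n_candidates; ++i) {
        decoded.push_back(static_cast<int>(i));   // stays within the reserved capacity
        if (i % 7 == 0) {
            rejects[i] = 1;                       // record the decision; elements stay untouched
        }
    }
    return 0;
}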
@@ -1579,25 +1583,24 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
         bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
         if (found_eog) {
             if (!allow_eos) {
-                candidates->data[i].logit = -INFINITY;
+                rejects[i] = true;
             }
         } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
+            rejects[i] = true;
         } else {
             candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
-    for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
+    for (auto reject: llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar)) {
+        rejects[reject.index] = true;
     }
     auto first = candidates->data;
     auto last = first + candidates->size;
     last = std::remove_if(first, last,
-        [&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
+        [&](const llama_token_data & tk){ return rejects[&tk - first]; }); // tk.logit == -INFINITY; });
     candidates->size = last - first;
 }
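For context, a minimal sketch of the pattern the new code adopts: rejections are recorded in a parallel std::vector<uint8_t> of flags instead of overwriting logits with -INFINITY, and the candidate array is compacted in a single std::remove_if pass whose predicate recovers each element's index from its address (&tk - first). TokenData, TokenDataArray, and filter_rejected below are hypothetical stand-ins for llama_token_data / llama_token_data_array, not the project's actual code.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct TokenData {            // hypothetical stand-in for llama_token_data
    int   id;
    float logit;
};

struct TokenDataArray {       // hypothetical stand-in for llama_token_data_array
    TokenData * data;
    std::size_t size;
};

// Compact the array in place, dropping every candidate whose reject flag is set.
// The logits themselves are never modified; the flag vector carries the decision.
static void filter_rejected(TokenDataArray * candidates,
                            const std::vector<std::uint8_t> & rejects)
{
    TokenData * first = candidates->data;
    TokenData * last  = first + candidates->size;
    last = std::remove_if(first, last,
        // &tk - first is the element's original index into the rejects vector
        [&](const TokenData & tk){ return rejects[&tk - first] != 0; });
    candidates->size = last - first;
}

int main()
{
    std::vector<TokenData> buf = { {0, 1.5f}, {1, 0.2f}, {2, 3.0f}, {3, -0.7f} };
    TokenDataArray candidates{ buf.data(), buf.size() };

    // Flags would normally be set by the grammar / EOG checks in the loop above.
    std::vector<std::uint8_t> rejects(candidates.size, 0);
    rejects[1] = 1;
    rejects[3] = 1;

    filter_rejected(&candidates, rejects);
    for (std::size_t i = 0; i < candidates.size; ++i) {
        std::printf("kept id=%d logit=%.1f\n", candidates.data[i].id, candidates.data[i].logit);
    }
    return 0;   // prints tokens 0 and 2
}

Because std::remove_if evaluates the predicate on each element at its original position before any element is overwritten, &tk - first is still that element's original index, so the flag lookup stays aligned with the indices used when the flags were set.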