mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-09 16:44:35 +00:00
streamline grammar sampler to speed up generation while using heavy grammar (#1606)
This commit is contained in:
parent
ab29be54c4
commit
f07434f4c1
1 changed files with 10 additions and 7 deletions
|
@ -1572,6 +1572,10 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
|
|||
|
||||
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
|
||||
std::vector<llama_grammar_candidate> candidates_grammar;
|
||||
std::vector<uint8_t> rejects;
|
||||
candidates_decoded.reserve(candidates->size);
|
||||
candidates_grammar.reserve(candidates->size);
|
||||
rejects.assign(candidates->size, false);
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
const llama_token id = candidates->data[i].id;
|
||||
|
@ -1579,25 +1583,24 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
|
|||
bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
|
||||
if (found_eog) {
|
||||
if (!allow_eos) {
|
||||
candidates->data[i].logit = -INFINITY;
|
||||
rejects[i] = true;
|
||||
}
|
||||
} else if (piece.empty() || piece[0] == 0) {
|
||||
candidates->data[i].logit = -INFINITY;
|
||||
rejects[i] = true;
|
||||
} else {
|
||||
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
|
||||
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
||||
}
|
||||
}
|
||||
|
||||
const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
|
||||
for (const auto & reject : rejects) {
|
||||
candidates->data[reject.index].logit = -INFINITY;
|
||||
for (auto reject: llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar)) {
|
||||
rejects[reject.index] = true;
|
||||
}
|
||||
|
||||
|
||||
auto first = candidates->data;
|
||||
auto last = first + candidates->size;
|
||||
last = std::remove_if(first, last,
|
||||
[&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
|
||||
[&](const llama_token_data & tk){ return rejects[&tk - first]; }); // tk.logit == -INFINITY; });
|
||||
candidates->size = last - first;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue