diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 1e427f276..bc0437056 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -1572,6 +1572,10 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     std::vector<llama_grammar_candidate> candidates_grammar;
+    std::vector<bool> rejects;
+    candidates_decoded.reserve(candidates->size);
+    candidates_grammar.reserve(candidates->size);
+    rejects.assign(candidates->size, false);
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
@@ -1579,25 +1583,24 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
         bool found_eog = std::find(eog_tokens.begin(), eog_tokens.end(), id) != eog_tokens.end();
         if (found_eog) {
             if (!allow_eos) {
-                candidates->data[i].logit = -INFINITY;
+                rejects[i] = true;
             }
         } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
+            rejects[i] = true;
         } else {
             candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
 
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
-    for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
+    for (auto reject: llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar)) {
+        rejects[reject.index] = true;
     }
-
+
     auto first = candidates->data;
     auto last = first + candidates->size;
     last = std::remove_if(first, last,
-        [&](const llama_token_data & tk){ return tk.logit == -INFINITY; });
+        [&](const llama_token_data & tk){ return rejects[&tk - first]; }); // tk.logit == -INFINITY; });
     candidates->size = last - first;
 }
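
Note for reviewers: the subtle part of this change is the predicate `rejects[&tk - first]`, which recovers each candidate's original index from its address. That is safe because `std::remove_if` applies the predicate to every element exactly once, at its original position, before any element is shifted forward. Below is a minimal self-contained sketch of the same mask-then-compact idiom; the `token_data` struct and the sample values are made up for illustration and are not part of the patch.

// Minimal sketch (not from the patch) of the mask-then-compact idiom:
// mark rejected entries in a parallel std::vector<bool>, then drop them
// all in a single std::remove_if pass over the raw array.
#include <algorithm>
#include <cstdio>
#include <vector>

struct token_data { int id; float logit; };   // stand-in for llama_token_data

int main() {
    std::vector<token_data> cand = { {0, 1.5f}, {1, 0.2f}, {2, 2.0f}, {3, 0.7f} };
    std::vector<bool> rejects(cand.size(), false);
    rejects[1] = rejects[3] = true;   // pretend the grammar rejected tokens 1 and 3

    auto first = cand.data();
    auto last  = first + cand.size();
    // &tk - first yields the element's original index, since the predicate
    // is invoked on each element in place before any moves happen.
    last = std::remove_if(first, last,
        [&](const token_data & tk){ return rejects[&tk - first]; });
    cand.resize(last - first);        // compacted: only surviving tokens remain

    for (const auto & tk : cand) {
        std::printf("id=%d logit=%.1f\n", tk.id, tk.logit);   // prints id=0 and id=2
    }
}

Compared with the previous approach of writing -INFINITY into the logits and filtering on that sentinel, the mask keeps the rejection test independent of the logit values; the patch also pairs this with reserve() calls on the two scratch vectors, so the per-token loop avoids reallocations.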