Mirror of https://github.com/LostRuins/koboldcpp.git
Add memoized cache to llama_grammar_reject_candidates_for_stack (#1615)

* Add memoized cache to llama_grammar_reject_candidates_for_stack
* make size cutoff more aggressive and move to outer branch
* update comment
* add cache reset whenever grammar is reloaded
* remove explicit reference types for compiler transportability
parent b884a7f058
commit 54dde5e565

2 changed files with 60 additions and 0 deletions
@@ -1773,6 +1773,7 @@ static void load_grammar(const std::string & gammarstr)
{
    if(grammar!=nullptr) //on demand free when next grammar is loaded
    {
        llama_grammar_reset_memos();
        llama_grammar_free_impl(grammar);
        grammar = nullptr;
    }

@@ -8,10 +8,34 @@
#include <algorithm>
#include <stdexcept>

#include <iostream>
#include <string_view>
#include <unordered_set>

//
// helpers
//

using bytes = std::pair<const char*, size_t>;
using hash_entry_size = std::pair<size_t, size_t>;

template <>
struct std::hash<bytes>
{
    std::size_t operator()(const bytes& x) const noexcept
    {
        return std::hash<std::string_view>{}({x.first, x.second});
    }
};

using candidates_memos = std::unordered_map<size_t, llama_grammar_candidates>;
using stack_memos = std::unordered_map<size_t, candidates_memos>;
static stack_memos memo_cache;

static void llama_grammar_reset_memos() {
    memo_cache.clear();
}

// NOTE: assumes valid utf8 (but checks for overrun)
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

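The std::hash<bytes> specialization above is the whole hashing story: an arbitrary (pointer, length) byte range is viewed as a std::string_view and fed to the standard library's string hasher. Below is a minimal standalone sketch of the same idea; the elem struct and the main function are illustrative only and not part of the patch.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <string_view>
    #include <utility>
    #include <vector>

    // Same shape as the helper in the patch: a (pointer, length) byte range.
    using bytes = std::pair<const char*, size_t>;

    template <>
    struct std::hash<bytes>
    {
        std::size_t operator()(const bytes& x) const noexcept
        {
            // Reuse the standard string_view hasher over the raw bytes.
            return std::hash<std::string_view>{}({x.first, x.second});
        }
    };

    // Illustrative stand-in for a small, trivially copyable element.
    struct elem { uint32_t type; uint32_t value; };

    int main() {
        std::vector<elem> stack = { {1, 42}, {2, 7} };

        // Hash the vector's contents by their raw bytes, the same way the patch
        // keys the memo cache on the grammar stack and the candidate list.
        auto key = std::hash<bytes>{}({ reinterpret_cast<const char *>(stack.data()),
                                        sizeof(stack[0]) * stack.size() });
        std::printf("key = %zu\n", key);
        return 0;
    }

Because the key is just the raw bytes of the elements (and the grammar stack holds pointers into the loaded grammar's rule storage), cached entries are only meaningful for the currently loaded grammar, which is why the load_grammar hunk above calls llama_grammar_reset_memos() before freeing and replacing it.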
@@ -865,6 +889,38 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
        return rejects;
    }

    auto stack_hash_start = reinterpret_cast<const char *>(stack.data());
    auto stack_hash_size = sizeof(stack[0]) * stack.size();
    auto stack_hash = std::hash<bytes>{}({ stack_hash_start, stack_hash_size });

    llama_grammar_candidates * cache_target = nullptr;

    // Tests show that >75% of candidate lists are under 1280b and 50% are under 640b.
    // Most 'problem' loops are under 24b. However, candidate lists can be over 72k,
    // so we need to limit our checks.

    // We'll only attempt to memoize candidate lists under 80b.
    // Doing an over-aggressive size cutoff first, before any other processing, 'saves' easy cases the
    // extra processing but still rescues 'hard' cases from slowdowns or hangs.
    // This leads to a speed up of both easy and hard cases.
    const size_t hash_cutoff = 80;
    auto candidates_hash_size = sizeof(candidates[0]) * candidates.size();
    if (candidates_hash_size < hash_cutoff) {
        // Only check the stack hash first - these are usually ~24b, and almost always under 64b
        if (auto cache_hit = memo_cache.find(stack_hash); cache_hit != memo_cache.end()) {
            auto & candidates_memos = cache_hit->second;
            auto candidates_hash_start = reinterpret_cast<const char *>(candidates.data());
            auto candidates_hash = std::hash<bytes>{}({ candidates_hash_start, candidates_hash_size });
            if (auto cache_hit2 = candidates_memos.find(candidates_hash); cache_hit2 != candidates_memos.end()) {
                return cache_hit2->second;
            } else {
                cache_target = &(candidates_memos[candidates_hash]);
            }
        } else {
            memo_cache[stack_hash];
        }
    }

    const llama_grammar_element * stack_pos = stack.back();

    llama_grammar_candidates next_candidates;

@@ -900,6 +956,9 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
    }

    if (cache_target) {
        *cache_target = rejects;
    }
    return rejects;
}

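Taken together, the hunks add a two-level memo: results are keyed first by the hash of the stack's bytes, then by the hash of the candidate list's bytes, and a result is only stored once a stack has been seen before and the candidate list passes the size cutoff. The following is a condensed standalone sketch of that control flow; the names reject_candidates_cached, reject_candidates_uncached, and the int candidate type are simplified stand-ins, not the patch's real types or functions.

    #include <cstddef>
    #include <unordered_map>
    #include <vector>

    // Simplified stand-ins for the real types in the patch.
    using candidate  = int;                            // stands in for llama_grammar_candidate
    using candidates = std::vector<candidate>;

    static std::unordered_map<size_t, std::unordered_map<size_t, candidates>> memo_cache;

    static candidates reject_candidates_uncached(const candidates & in) {
        return in;                                     // placeholder for the per-stack rejection work
    }

    static candidates reject_candidates_cached(size_t stack_hash, size_t cand_hash,
                                               const candidates & in) {
        candidates * cache_target = nullptr;

        if (auto hit = memo_cache.find(stack_hash); hit != memo_cache.end()) {
            auto & per_stack = hit->second;
            if (auto hit2 = per_stack.find(cand_hash); hit2 != per_stack.end()) {
                return hit2->second;                   // full hit: reuse the earlier result
            }
            cache_target = &per_stack[cand_hash];      // known stack, new candidates: reserve a slot
        } else {
            memo_cache[stack_hash];                    // first sighting of this stack: remember it,
                                                       // but don't cache this call's result yet
        }

        candidates rejects = reject_candidates_uncached(in);
        if (cache_target) {
            *cache_target = rejects;                   // memoize for the next identical call
        }
        return rejects;
    }

Note the asymmetry: the first call with a given stack only records the stack hash and returns an uncached result; caching begins on the next call with that stack, and only when the candidate list is small enough to pass the hash_cutoff check.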