Add memoized cache to llama_grammar_reject_candidates_for_stack (#1615)

* Add memoized cache to llama_grammar_reject_candidates_for_stack

* make size cutoff more aggressive and move to outer branch

* update comment

* add cache reset whenever grammar is reloaded

* remove explicit reference types for compiler portability
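
The scheme in outline: results of llama_grammar_reject_candidates_for_stack are memoized in a two-level map, keyed first by a hash of the stack's raw bytes and then by a hash of the candidate list's raw bytes. Below is a minimal standalone sketch of that lookup flow, with llama_grammar_candidates simplified to a std::vector<int> stand-in; note the patch computes the candidates hash lazily, only once the stack hash has hit, whereas here it is passed in directly:

#include <cstddef>
#include <unordered_map>
#include <vector>

// Stand-in for llama_grammar_candidates; the real type is in the diff below.
using candidates_t = std::vector<int>;

static std::unordered_map<std::size_t, std::unordered_map<std::size_t, candidates_t>> memo_cache;

// Two-stage memo lookup: the first miss for a stack only registers the stack
// hash; the candidate hash is checked, and the result cached, on later calls.
static const candidates_t * memo_lookup(std::size_t stack_hash, std::size_t candidates_hash,
                                        candidates_t *& cache_target) {
    cache_target = nullptr;
    auto it = memo_cache.find(stack_hash);
    if (it == memo_cache.end()) {
        memo_cache[stack_hash]; // register the stack, cache nothing yet
        return nullptr;
    }
    auto it2 = it->second.find(candidates_hash);
    if (it2 != it->second.end()) {
        return &it2->second;                     // hit: reuse the memoized rejects
    }
    cache_target = &it->second[candidates_hash]; // miss: caller fills this slot
    return nullptr;
}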
Reithan 2025-06-25 04:22:19 -07:00 committed by GitHub
parent b884a7f058
commit 54dde5e565
2 changed files with 60 additions and 0 deletions


@@ -1773,6 +1773,7 @@ static void load_grammar(const std::string & gammarstr)
{
if(grammar!=nullptr) //on demand free when next grammar is loaded
{
llama_grammar_reset_memos();
llama_grammar_free_impl(grammar);
grammar = nullptr;
}


@@ -8,10 +8,34 @@
#include <algorithm>
#include <stdexcept>
#include <iostream>
#include <string_view>
#include <unordered_set>
//
// helpers
//
using bytes = std::pair<const char*, size_t>;
using hash_entry_size = std::pair<size_t, size_t>;
template <>
struct std::hash<bytes>
{
std::size_t operator()(const bytes& x) const noexcept
{
return std::hash<std::string_view>{}({x.first, x.second});
}
};
using candidates_memos = std::unordered_map<size_t, llama_grammar_candidates>;
using stack_memos = std::unordered_map<size_t, candidates_memos>;
static stack_memos memo_cache;
static void llama_grammar_reset_memos() {
memo_cache.clear();
}
// NOTE: assumes valid utf8 (but checks for overrun)
static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -864,6 +888,38 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
}
return rejects;
}
auto stack_hash_start = reinterpret_cast<const char *>(stack.data());
auto stack_hash_size = sizeof(stack[0]) * stack.size();
auto stack_hash = std::hash<bytes>{}({ stack_hash_start, stack_hash_size });
llama_grammar_candidates * cache_target = nullptr;
// Tests show that >75% of candidate lists are under 1280b and 50% are under 640b.
// Most 'problem' loops are under 24b. However, candidate lists can be over 72k,
// so we need to limit our checks.
// We'll only attempt to memoize candidate lists under 80b.
// Applying the aggressive size cutoff first, before any other processing, spares
// easy cases the extra work while still rescuing 'hard' cases from slowdowns or hangs.
// This speeds up both the easy and the hard cases.
const size_t hash_cutoff = 80;
auto candidates_hash_size = sizeof(candidates[0]) * candidates.size();
if (candidates_hash_size < hash_cutoff) {
// Only check the stack hash first - these are usually ~24b, and almost always under 64b
if (auto cache_hit = memo_cache.find(stack_hash); cache_hit != memo_cache.end()) {
auto & candidates_memos = cache_hit->second;
auto candidates_hash_start = reinterpret_cast<const char *>(candidates.data());
auto candidates_hash = std::hash<bytes>{}({ candidates_hash_start, candidates_hash_size });
if (auto cache_hit2 = candidates_memos.find(candidates_hash); cache_hit2 != candidates_memos.end()) {
return cache_hit2->second;
} else {
cache_target = &(candidates_memos[candidates_hash]);
}
} else {
memo_cache[stack_hash]; // first miss for this stack: register the stack hash only; results are cached on later calls
}
}
const llama_grammar_element * stack_pos = stack.back();
@@ -900,6 +956,9 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
}
if (cache_target) {
*cache_target = rejects;
}
return rejects;
}
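
For reference, the hashing trick above can be exercised standalone: an arbitrary byte range is wrapped in a std::string_view and hashed with the standard library's string hash. A sketch, assuming C++17; the bytes alias mirrors the one in the patch:

#include <cstddef>
#include <functional>
#include <string_view>
#include <utility>
#include <vector>

using bytes = std::pair<const char *, std::size_t>;

// Same specialization as the patch: reuse std::hash<std::string_view> to
// hash an arbitrary range of raw bytes.
template <>
struct std::hash<bytes> {
    std::size_t operator()(const bytes & x) const noexcept {
        return std::hash<std::string_view>{}({x.first, x.second});
    }
};

int main() {
    std::vector<int> stack = {1, 2, 3};
    // Hash the vector's contents as raw bytes, the way the patch hashes the
    // grammar stack and the candidate list.
    std::size_t h = std::hash<bytes>{}({
        reinterpret_cast<const char *>(stack.data()),
        sizeof(stack[0]) * stack.size()
    });
    (void) h;
    return 0;
}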