From 54dde5e565b16ce311b5077a124f07c0f28c81db Mon Sep 17 00:00:00 2001
From: Reithan
Date: Wed, 25 Jun 2025 04:22:19 -0700
Subject: [PATCH] Add memoized cache to `llama_grammar_reject_candidates_for_stack` (#1615)

* Add memoized cache to llama_grammar_reject_candidates_for_stack

* make size cutoff more aggressive and move to outer branch

* update comment

* add cache reset whenever grammar is reloaded

* remove explicit reference types for compiler transportability
---
 gpttype_adapter.cpp   |  1 +
 src/llama-grammar.cpp | 59 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index ad9d7f922..26cf04818 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -1773,6 +1773,7 @@ static void load_grammar(const std::string & gammarstr)
 {
     if(grammar!=nullptr) //on demand free when next grammar is loaded
     {
+        llama_grammar_reset_memos();
         llama_grammar_free_impl(grammar);
         grammar = nullptr;
     }
diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index bed706bb2..ead5d3f3a 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -8,10 +8,34 @@
 #include <algorithm>
 #include <stdexcept>
+#include <unordered_map>
+#include <string_view>
+#include <utility>
+
 //
 // helpers
 //
 
+using bytes = std::pair<const char *, size_t>;
+using hash_entry_size = std::pair<size_t, size_t>;
+
+template <>
+struct std::hash<bytes>
+{
+    std::size_t operator()(const bytes& x) const noexcept
+    {
+        return std::hash<std::string_view>{}({x.first, x.second});
+    }
+};
+
+using candidates_memos = std::unordered_map<size_t, llama_grammar_candidates>;
+using stack_memos = std::unordered_map<size_t, candidates_memos>;
+static stack_memos memo_cache;
+
+static void llama_grammar_reset_memos() {
+    memo_cache.clear();
+}
+
 // NOTE: assumes valid utf8 (but checks for overrun)
 static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -864,6 +888,38 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
         }
         return rejects;
     }
 
+    auto stack_hash_start = reinterpret_cast<const char *>(stack.data());
+    auto stack_hash_size = sizeof(stack[0]) * stack.size();
+    auto stack_hash = std::hash<bytes>{}({ stack_hash_start, stack_hash_size });
+
+    llama_grammar_candidates * cache_target = nullptr;
+
+    // Tests show that >75% of candidate lists are under 1280b and 50% are under 640b.
+    // Most 'problem' loops are under 24b. However, candidate lists can be over 72kb,
+    // so we need to limit our checks.
+
+    // We'll only attempt to memoize candidate lists under 80b.
+    // Doing an over-aggressive size cutoff first, before any other processing, 'saves'
+    // easy cases the extra work but still rescues 'hard' cases from slowdowns or hangs.
+    // This leads to a speed-up in both easy and hard cases.
+    const size_t hash_cutoff = 80;
+    auto candidates_hash_size = sizeof(candidates[0]) * candidates.size();
+    if (candidates_hash_size < hash_cutoff) {
+        // Only check the stack hash first - stacks are usually ~24b, and almost always under 64b
+        if (auto cache_hit = memo_cache.find(stack_hash); cache_hit != memo_cache.end()) {
+            auto & candidates_memos = cache_hit->second;
+            auto candidates_hash_start = reinterpret_cast<const char *>(candidates.data());
+            auto candidates_hash = std::hash<bytes>{}({ candidates_hash_start, candidates_hash_size });
+            if (auto cache_hit2 = candidates_memos.find(candidates_hash); cache_hit2 != candidates_memos.end()) {
+                return cache_hit2->second;
+            } else {
+                cache_target = &(candidates_memos[candidates_hash]);
+            }
+        } else {
+            memo_cache[stack_hash];
+        }
+    }
 
     const llama_grammar_element * stack_pos = stack.back();
@@ -900,6 +956,9 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
         rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
     }
 
+    if (cache_target) {
+        *cache_target = rejects;
+    }
     return rejects;
 }
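
As background on the technique, the sketch below is a minimal, standalone rendering of the same two-level memoization idea: hash the raw bytes of each input through std::hash<std::string_view>, key an outer map on the stack hash and an inner map on the candidates hash, and bypass the cache entirely above a size cutoff. All names here (byte_span, hash_bytes, filter_memoized, expensive_filter) are illustrative stand-ins, not identifiers from the patch, and a plain helper function replaces the patch's std::hash specialization.

#include <cstddef>
#include <cstdio>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

// A raw byte range: pointer plus length.
using byte_span = std::pair<const char *, size_t>;

// Hash a byte range by viewing it as a std::string_view (C++17).
static size_t hash_bytes(const byte_span & b) {
    return std::hash<std::string_view>{}({b.first, b.second});
}

// Two-level memo: outer key = stack-bytes hash, inner key = candidate-bytes hash.
static std::unordered_map<size_t, std::unordered_map<size_t, std::vector<int>>> memo;

// Placeholder for the expensive per-(stack, candidates) computation.
static std::vector<int> expensive_filter(const std::vector<int> & cand) {
    std::vector<int> out;
    for (int v : cand) {
        if (v % 2 == 0) {
            out.push_back(v);
        }
    }
    return out;
}

static std::vector<int> filter_memoized(const std::vector<int> & stack,
                                        const std::vector<int> & cand) {
    const size_t cutoff = 80; // only memoize small inputs, as in the patch
    const size_t cand_bytes = sizeof(cand[0]) * cand.size();
    std::vector<int> * slot = nullptr;

    if (cand_bytes < cutoff) {
        const size_t skey = hash_bytes({reinterpret_cast<const char *>(stack.data()),
                                        sizeof(stack[0]) * stack.size()});
        if (auto it = memo.find(skey); it != memo.end()) {
            // Known stack: now it is worth hashing the candidate bytes too.
            auto & inner = it->second;
            const size_t ckey = hash_bytes({reinterpret_cast<const char *>(cand.data()),
                                            cand_bytes});
            if (auto hit = inner.find(ckey); hit != inner.end()) {
                return hit->second; // cache hit: skip the expensive work
            }
            slot = &inner[ckey];    // miss: remember where to store the result
        } else {
            // First sighting of this stack: create the inner map but do not pay
            // for hashing the candidates yet (mirrors `memo_cache[stack_hash];`).
            memo[skey];
        }
    }

    auto result = expensive_filter(cand);
    if (slot) {
        *slot = result;
    }
    return result;
}

int main() {
    const std::vector<int> stack = {1, 2, 3};
    const std::vector<int> cand  = {4, 5, 6, 7, 8};
    filter_memoized(stack, cand);          // 1st call: registers the stack only
    filter_memoized(stack, cand);          // 2nd call: computes and stores the result
    auto r = filter_memoized(stack, cand); // 3rd call: served from the memo
    std::printf("filtered size: %zu\n", r.size());
    return 0;
}

Two properties of the real change are worth noting. Hashing raw bytes is sound here because a grammar stack is a vector of pointers into the grammar's own rule storage, so identical byte patterns imply identical stacks for the lifetime of one grammar; accordingly, the patch clears the cache in load_grammar whenever a new grammar replaces the old one. And the cache keys on the 64-bit hashes alone, never re-checking the original stack or candidate contents, which trades a small hash-collision risk for not having to store and compare the full keys.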