added support for RNN models in smartcache

This commit is contained in:
Concedo 2025-12-19 16:36:25 +08:00
parent cde4791e36
commit e9ae0cb2dd
2 changed files with 158 additions and 20 deletions

View file

@ -1887,11 +1887,25 @@ float ComputePrefixMatchPercent(std::vector<int> &current_context_tokens, std::v
}
// Handle case where both sequences are empty to avoid division by zero
if (min_length == 0) {
return 0.0f; // Both empty sequences are considered 100% matched
return 0.0f; // Both empty sequences are considered not matched
}
return static_cast<float>(match_count) / static_cast<float>(min_length);
}
//returns true if and only if sequence 1 is fully contained within the starting of sequence 2
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2)
{
if (sequence1.size() > sequence2.size() || sequence1.size()==0 || sequence2.size()==0) {
return false;
}
for (size_t i = 0; i < sequence1.size(); ++i) {
if (sequence1[i] != sequence2[i]) {
return false;
}
}
return true;
}
//given an old GGUF context and a new context that has some middle portion removed,
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
//returns true if contextshift is doable, executes it if dryrun is false
@ -3921,13 +3935,75 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
{
shiftable = false;
}
const float similarity_threshold = 0.7f;
//If CanBeShifted is true, do nothing. Allow shift as normal.
if(!(shiftable && CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx)))
//we handle recurrent models differently since they require a full subset match
if(is_recurrent)
{
bool curr_usable = FullyContainedPrefix(current_context_tokens,embd_inp);
if(!curr_usable)
{
//see if we have any other usable contexts out there
int bestslot = -1;
int bestlen = 0;
int identical_slot = get_identical_existing_slot(); //see if the slot already exists
for(int i=0;i<savestate_limit;++i)
{
bool target_usable = FullyContainedPrefix(savestates[i].savestate_context_tokens,embd_inp);
int target_len = savestates[i].savestate_context_tokens.size();
if(target_usable && target_len>bestlen)
{
bestlen = target_len;
bestslot = i;
}
}
if(bestslot!=-1) //found a good slot to load
{
int oldest_slot = get_oldest_slot(bestslot);
if(oldest_slot!=bestslot)
{
if(current_context_tokens.size() > 32) //do not save tiny contexts
{
if(identical_slot==-1)
{
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Saving into slot %d and switching...]\n",bestlen,bestslot,oldest_slot);
gpttype_save_state_kv(oldest_slot);
} else {
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Already saved in slot %d, switching...]\n",bestlen,bestslot,identical_slot);
touch_slot(identical_slot);
}
}
else
{
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Switching...]\n",bestlen,bestslot);
}
gpttype_load_state_kv(bestslot);
}
}
else
{
if(current_context_tokens.size() > 32) //do not save tiny contexts
{
if(identical_slot==-1)
{
int oldest_slot = get_oldest_slot(-1);
printf("\n[SmartCache RNN No Match, Saving into slot %d...]\n",oldest_slot);
gpttype_save_state_kv(oldest_slot);
}
else
{
printf("\n[SmartCache RNN No Match, Already saved in slot %d]\n",identical_slot);
touch_slot(identical_slot);
}
}
}
}
}
else if(!(shiftable && CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx))) //If CanBeShifted is true, do nothing. Allow shift as normal.
{
// If CanBeShifted is false, calculate prefix similarity with current_context_tokens of current context
// If similarity > similarity_threshold, do nothing. Allow fast forward as normal.
float similarity = ComputePrefixMatchPercent(current_context_tokens,embd_inp);
const float similarity_threshold = 0.7f;
if(similarity < similarity_threshold)
{
// Otherwise, for each of the currently used kv state slots, calculate ComputePrefixMatch and CanBeShifted
@ -3935,6 +4011,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
// Whenever loading or saving current slot, simply tag the slot with a timestamp. When running out of slots after all 3 are used, delete the oldest timestamped slot.
// Slot loading and saving completely reuses gpttype_load_state_kv and gpttype_save_state_kv, nothing else is needed.
bool foundswap = false;
int identical_slot = get_identical_existing_slot(); //see if a slot already exists with identical data to current
for(int i=0;i<savestate_limit;++i)
{
float similaritybeat = ComputePrefixMatchPercent(savestates[i].savestate_context_tokens,embd_inp);
@ -3944,15 +4021,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
int oldest_slot = get_oldest_slot(i);
if(oldest_slot!=i)
{
if(current_context_tokens.size()>32) //do not save tiny contexts
if(current_context_tokens.size() > 32) //do not save tiny contexts
{
printf("\n[SmartCache Match of %.2f in slot %d. Saving into slot %d and switching...]",similaritybeat,i,oldest_slot);
gpttype_save_state_kv(oldest_slot);
if(identical_slot==-1)
{
printf("\n[SmartCache Match of %.2f in slot %d. Saving into slot %d and switching...]\n",similaritybeat,i,oldest_slot);
gpttype_save_state_kv(oldest_slot);
} else {
printf("\n[SmartCache Match of %.2f in slot %d. Already saved in slot %d, switching...]\n",similaritybeat,i,identical_slot);
touch_slot(identical_slot);
}
}
else
{
printf("\n[SmartCache Match of %.2f in slot %d. Switching...]",similaritybeat,i);
printf("\n[SmartCache Match of %.2f in slot %d. Switching...]\n",similaritybeat,i);
}
gpttype_load_state_kv(i);
foundswap = true;
@ -3962,11 +4044,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
}
if(!foundswap) //could not match anything, just save kv and continue
{
if(current_context_tokens.size()>32) //do not save tiny contexts
if(current_context_tokens.size() > 32) //do not save tiny contexts
{
int oldest_slot = get_oldest_slot(-1);
printf("\n[SmartCache No Match, Saving into slot %d...]",oldest_slot);
gpttype_save_state_kv(oldest_slot);
if(identical_slot==-1)
{
int oldest_slot = get_oldest_slot(-1);
printf("\n[SmartCache No Match, Saving into slot %d...]\n",oldest_slot);
gpttype_save_state_kv(oldest_slot);
}
else
{
printf("\n[SmartCache No Match, Already saved in slot %d]\n",identical_slot);
touch_slot(identical_slot);
}
}
}
}
@ -4790,6 +4880,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
delayed_generated_tokens.pop_front();
}
//if running rnn model in smartcache mode, save progress after each gen
if(kcpp_data->smartcache && is_recurrent && file_format==FileFormat::GGUF_GENERIC && current_context_tokens.size() > 32)
{
int identical_slot = get_identical_existing_slot();
if(identical_slot==-1)
{
int oldest_slot = get_oldest_slot(-1);
gpttype_save_state_kv(oldest_slot);
}
else
{
touch_slot(identical_slot);
}
}
if(debugmode==1 && !is_quiet && file_format == FileFormat::GGUF_GENERIC)
{
printf("\n");
@ -4907,9 +5012,7 @@ size_t gpttype_save_state_kv(int slot)
totalbytes += res;
savestates[slot].current_savestate_size = newsize;
savestates[slot].savestate_context_tokens = current_context_tokens;
auto timenow = std::chrono::system_clock::now();
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
savestates[slot].last_used = timestamp;
touch_slot(slot);
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
}
@ -4959,9 +5062,7 @@ bool gpttype_load_state_kv(int slot)
auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
}
auto timenow = std::chrono::system_clock::now();
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
savestates[slot].last_used = timestamp;
touch_slot(slot);
}
return (res > 0);
}
@ -5002,6 +5103,41 @@ bool gpttype_clear_state_kv(bool shrink)
}
return false;
}
void touch_slot(int slot) //update the slot's last used time and nothing else
{
auto timenow = std::chrono::system_clock::now();
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
savestates[slot].last_used = timestamp;
}
int get_identical_existing_slot() //returns slot number of slot containing exactly the same data, or -1 if nothing
{
int64_t slotage = INT64_MAX; // Initialize with maximum possible value
int slotid = -1;
int currctxsize = current_context_tokens.size();
for(int i=0;i<savestate_limit;++i)
{
if(savestates[i].savestate_context_tokens.size() == currctxsize)
{
bool is_identical = true;
const auto& slot_tokens = savestates[i].savestate_context_tokens;
for (size_t j = 0; j < currctxsize; ++j)
{
if (slot_tokens[j] != current_context_tokens[j])
{
is_identical = false;
break;
}
}
if (is_identical)
{
slotid = i;
break;
}
}
}
return slotid;
}
int get_oldest_slot(int excludeSlotId)
{

View file

@ -140,4 +140,6 @@ size_t gpttype_calc_old_state_tokencount(int slot);
size_t gpttype_save_state_kv(int slot);
bool gpttype_load_state_kv(int slot);
bool gpttype_clear_state_kv(bool shrink);
int get_oldest_slot(int excludeSlotId);
int get_oldest_slot(int excludeSlotId);
void touch_slot(int slot);
int get_identical_existing_slot();