mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2026-04-28 03:30:20 +00:00
added support for RNN models in smartcache
This commit is contained in:
parent
cde4791e36
commit
e9ae0cb2dd
2 changed files with 158 additions and 20 deletions
|
|
@ -1887,11 +1887,25 @@ float ComputePrefixMatchPercent(std::vector<int> ¤t_context_tokens, std::v
|
|||
}
|
||||
// Handle case where both sequences are empty to avoid division by zero
|
||||
if (min_length == 0) {
|
||||
return 0.0f; // Both empty sequences are considered 100% matched
|
||||
return 0.0f; // Both empty sequences are considered not matched
|
||||
}
|
||||
return static_cast<float>(match_count) / static_cast<float>(min_length);
|
||||
}
|
||||
|
||||
//returns true if and only if sequence 1 is fully contained within the starting of sequence 2
|
||||
bool FullyContainedPrefix(std::vector<int> &sequence1, std::vector<int> &sequence2)
|
||||
{
|
||||
if (sequence1.size() > sequence2.size() || sequence1.size()==0 || sequence2.size()==0) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < sequence1.size(); ++i) {
|
||||
if (sequence1[i] != sequence2[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//given an old GGUF context and a new context that has some middle portion removed,
|
||||
//find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
|
||||
//returns true if contextshift is doable, executes it if dryrun is false
|
||||
|
|
@ -3921,13 +3935,75 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
{
|
||||
shiftable = false;
|
||||
}
|
||||
const float similarity_threshold = 0.7f;
|
||||
//If CanBeShifted is true, do nothing. Allow shift as normal.
|
||||
if(!(shiftable && CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx)))
|
||||
|
||||
//we handle recurrent models differently since they require a full subset match
|
||||
if(is_recurrent)
|
||||
{
|
||||
bool curr_usable = FullyContainedPrefix(current_context_tokens,embd_inp);
|
||||
if(!curr_usable)
|
||||
{
|
||||
//see if we have any other usable contexts out there
|
||||
int bestslot = -1;
|
||||
int bestlen = 0;
|
||||
int identical_slot = get_identical_existing_slot(); //see if the slot already exists
|
||||
for(int i=0;i<savestate_limit;++i)
|
||||
{
|
||||
bool target_usable = FullyContainedPrefix(savestates[i].savestate_context_tokens,embd_inp);
|
||||
int target_len = savestates[i].savestate_context_tokens.size();
|
||||
if(target_usable && target_len>bestlen)
|
||||
{
|
||||
bestlen = target_len;
|
||||
bestslot = i;
|
||||
}
|
||||
}
|
||||
if(bestslot!=-1) //found a good slot to load
|
||||
{
|
||||
int oldest_slot = get_oldest_slot(bestslot);
|
||||
if(oldest_slot!=bestslot)
|
||||
{
|
||||
if(current_context_tokens.size() > 32) //do not save tiny contexts
|
||||
{
|
||||
if(identical_slot==-1)
|
||||
{
|
||||
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Saving into slot %d and switching...]\n",bestlen,bestslot,oldest_slot);
|
||||
gpttype_save_state_kv(oldest_slot);
|
||||
} else {
|
||||
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Already saved in slot %d, switching...]\n",bestlen,bestslot,identical_slot);
|
||||
touch_slot(identical_slot);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\n[SmartCache RNN Match of %d tokens in slot %d. Switching...]\n",bestlen,bestslot);
|
||||
}
|
||||
gpttype_load_state_kv(bestslot);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(current_context_tokens.size() > 32) //do not save tiny contexts
|
||||
{
|
||||
if(identical_slot==-1)
|
||||
{
|
||||
int oldest_slot = get_oldest_slot(-1);
|
||||
printf("\n[SmartCache RNN No Match, Saving into slot %d...]\n",oldest_slot);
|
||||
gpttype_save_state_kv(oldest_slot);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\n[SmartCache RNN No Match, Already saved in slot %d]\n",identical_slot);
|
||||
touch_slot(identical_slot);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(!(shiftable && CanContextShift(current_context_tokens, embd_inp, inputs.max_length, nctx))) //If CanBeShifted is true, do nothing. Allow shift as normal.
|
||||
{
|
||||
// If CanBeShifted is false, calculate prefix similarity with current_context_tokens of current context
|
||||
// If similarity > similarity_threshold, do nothing. Allow fast forward as normal.
|
||||
float similarity = ComputePrefixMatchPercent(current_context_tokens,embd_inp);
|
||||
const float similarity_threshold = 0.7f;
|
||||
if(similarity < similarity_threshold)
|
||||
{
|
||||
// Otherwise, for each of the currently used kv state slots, calculate ComputePrefixMatch and CanBeShifted
|
||||
|
|
@ -3935,6 +4011,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
// Whenever loading or saving current slot, simply tag the slot with a timestamp. When running out of slots after all 3 are used, delete the oldest timestamped slot.
|
||||
// Slot loading and saving completely reuses gpttype_load_state_kv and gpttype_save_state_kv, nothing else is needed.
|
||||
bool foundswap = false;
|
||||
int identical_slot = get_identical_existing_slot(); //see if a slot already exists with identical data to current
|
||||
for(int i=0;i<savestate_limit;++i)
|
||||
{
|
||||
float similaritybeat = ComputePrefixMatchPercent(savestates[i].savestate_context_tokens,embd_inp);
|
||||
|
|
@ -3944,15 +4021,20 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
int oldest_slot = get_oldest_slot(i);
|
||||
if(oldest_slot!=i)
|
||||
{
|
||||
if(current_context_tokens.size()>32) //do not save tiny contexts
|
||||
if(current_context_tokens.size() > 32) //do not save tiny contexts
|
||||
{
|
||||
printf("\n[SmartCache Match of %.2f in slot %d. Saving into slot %d and switching...]",similaritybeat,i,oldest_slot);
|
||||
gpttype_save_state_kv(oldest_slot);
|
||||
if(identical_slot==-1)
|
||||
{
|
||||
printf("\n[SmartCache Match of %.2f in slot %d. Saving into slot %d and switching...]\n",similaritybeat,i,oldest_slot);
|
||||
gpttype_save_state_kv(oldest_slot);
|
||||
} else {
|
||||
printf("\n[SmartCache Match of %.2f in slot %d. Already saved in slot %d, switching...]\n",similaritybeat,i,identical_slot);
|
||||
touch_slot(identical_slot);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\n[SmartCache Match of %.2f in slot %d. Switching...]",similaritybeat,i);
|
||||
|
||||
printf("\n[SmartCache Match of %.2f in slot %d. Switching...]\n",similaritybeat,i);
|
||||
}
|
||||
gpttype_load_state_kv(i);
|
||||
foundswap = true;
|
||||
|
|
@ -3962,11 +4044,19 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
}
|
||||
if(!foundswap) //could not match anything, just save kv and continue
|
||||
{
|
||||
if(current_context_tokens.size()>32) //do not save tiny contexts
|
||||
if(current_context_tokens.size() > 32) //do not save tiny contexts
|
||||
{
|
||||
int oldest_slot = get_oldest_slot(-1);
|
||||
printf("\n[SmartCache No Match, Saving into slot %d...]",oldest_slot);
|
||||
gpttype_save_state_kv(oldest_slot);
|
||||
if(identical_slot==-1)
|
||||
{
|
||||
int oldest_slot = get_oldest_slot(-1);
|
||||
printf("\n[SmartCache No Match, Saving into slot %d...]\n",oldest_slot);
|
||||
gpttype_save_state_kv(oldest_slot);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\n[SmartCache No Match, Already saved in slot %d]\n",identical_slot);
|
||||
touch_slot(identical_slot);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4790,6 +4880,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
|
|||
delayed_generated_tokens.pop_front();
|
||||
}
|
||||
|
||||
//if running rnn model in smartcache mode, save progress after each gen
|
||||
if(kcpp_data->smartcache && is_recurrent && file_format==FileFormat::GGUF_GENERIC && current_context_tokens.size() > 32)
|
||||
{
|
||||
int identical_slot = get_identical_existing_slot();
|
||||
if(identical_slot==-1)
|
||||
{
|
||||
int oldest_slot = get_oldest_slot(-1);
|
||||
gpttype_save_state_kv(oldest_slot);
|
||||
}
|
||||
else
|
||||
{
|
||||
touch_slot(identical_slot);
|
||||
}
|
||||
}
|
||||
|
||||
if(debugmode==1 && !is_quiet && file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
printf("\n");
|
||||
|
|
@ -4907,9 +5012,7 @@ size_t gpttype_save_state_kv(int slot)
|
|||
totalbytes += res;
|
||||
savestates[slot].current_savestate_size = newsize;
|
||||
savestates[slot].savestate_context_tokens = current_context_tokens;
|
||||
auto timenow = std::chrono::system_clock::now();
|
||||
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
|
||||
savestates[slot].last_used = timestamp;
|
||||
touch_slot(slot);
|
||||
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
|
||||
}
|
||||
|
||||
|
|
@ -4959,9 +5062,7 @@ bool gpttype_load_state_kv(int slot)
|
|||
auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
|
||||
printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
|
||||
}
|
||||
auto timenow = std::chrono::system_clock::now();
|
||||
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
|
||||
savestates[slot].last_used = timestamp;
|
||||
touch_slot(slot);
|
||||
}
|
||||
return (res > 0);
|
||||
}
|
||||
|
|
@ -5002,6 +5103,41 @@ bool gpttype_clear_state_kv(bool shrink)
|
|||
}
|
||||
return false;
|
||||
}
|
||||
void touch_slot(int slot) //update the slot's last used time and nothing else
|
||||
{
|
||||
auto timenow = std::chrono::system_clock::now();
|
||||
auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(timenow.time_since_epoch()).count();
|
||||
savestates[slot].last_used = timestamp;
|
||||
}
|
||||
int get_identical_existing_slot() //returns slot number of slot containing exactly the same data, or -1 if nothing
|
||||
{
|
||||
int64_t slotage = INT64_MAX; // Initialize with maximum possible value
|
||||
int slotid = -1;
|
||||
int currctxsize = current_context_tokens.size();
|
||||
for(int i=0;i<savestate_limit;++i)
|
||||
{
|
||||
if(savestates[i].savestate_context_tokens.size() == currctxsize)
|
||||
{
|
||||
bool is_identical = true;
|
||||
const auto& slot_tokens = savestates[i].savestate_context_tokens;
|
||||
for (size_t j = 0; j < currctxsize; ++j)
|
||||
{
|
||||
if (slot_tokens[j] != current_context_tokens[j])
|
||||
{
|
||||
is_identical = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_identical)
|
||||
{
|
||||
slotid = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return slotid;
|
||||
}
|
||||
|
||||
int get_oldest_slot(int excludeSlotId)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -140,4 +140,6 @@ size_t gpttype_calc_old_state_tokencount(int slot);
|
|||
size_t gpttype_save_state_kv(int slot);
|
||||
bool gpttype_load_state_kv(int slot);
|
||||
bool gpttype_clear_state_kv(bool shrink);
|
||||
int get_oldest_slot(int excludeSlotId);
|
||||
int get_oldest_slot(int excludeSlotId);
|
||||
void touch_slot(int slot);
|
||||
int get_identical_existing_slot();
|
||||
Loading…
Add table
Add a link
Reference in a new issue