Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-10 09:04:36 +00:00.
fixed savestates with drafting
This commit is contained in: parent df47b51bd1, commit 39b0699c71 — 2 changed files with 52 additions and 5 deletions.
|
@ -4341,13 +4341,19 @@ size_t gpttype_calc_new_state_kv()
|
|||
}
|
||||
if(file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
return llama_state_get_size(llama_ctx_v4);
|
||||
size_t s1 = llama_state_get_size(llama_ctx_v4);
|
||||
if(draft_ctx)
|
||||
{
|
||||
size_t s2 = llama_state_get_size(draft_ctx);
|
||||
s1 += s2;
|
||||
}
|
||||
return s1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
size_t gpttype_calc_old_state_kv(int slot)
|
||||
{
|
||||
return savestates[slot].current_savestate_size;
|
||||
return savestates[slot].current_savestate_size + savestates[slot].current_draft_savestate_size;
|
||||
}
|
||||
size_t gpttype_calc_old_state_tokencount(int slot)
|
||||
{
|
||||
|
@ -4365,30 +4371,54 @@ size_t gpttype_save_state_kv(int slot)
|
|||
}
|
||||
if(file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
size_t totalbytes = 0;
|
||||
if (!savestates[slot].current_savestate_buffer.empty()) { //JIT free
|
||||
savestates[slot].current_savestate_buffer.clear();
|
||||
savestates[slot].current_draft_savestate_buffer.clear();
|
||||
savestates[slot].savestate_context_tokens.clear();
|
||||
savestates[slot].current_savestate_size = 0;
|
||||
savestates[slot].current_draft_savestate_size = 0;
|
||||
}
|
||||
size_t newsize = llama_state_get_size(llama_ctx_v4);
|
||||
try {
|
||||
if (savestates[slot].current_savestate_buffer.capacity() < newsize + 512) {
|
||||
savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512);
|
||||
savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512); // add some padding. May throw std::bad_alloc
|
||||
} else {
|
||||
savestates[slot].current_savestate_buffer.resize(newsize + 512);
|
||||
}
|
||||
savestates[slot].current_savestate_buffer.resize(newsize + 512); // add some padding. May throw std::bad_alloc
|
||||
} catch (const std::bad_alloc&) {
|
||||
fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize + 512);
|
||||
return 0;
|
||||
}
|
||||
auto res = llama_state_get_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), newsize);
|
||||
if (res > 0) {
|
||||
totalbytes += res;
|
||||
savestates[slot].current_savestate_size = newsize;
|
||||
savestates[slot].savestate_context_tokens = current_context_tokens;
|
||||
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
|
||||
}
|
||||
return res;
|
||||
|
||||
if(draft_ctx)
|
||||
{
|
||||
size_t newsize2 = llama_state_get_size(draft_ctx);
|
||||
try {
|
||||
if (savestates[slot].current_draft_savestate_buffer.capacity() < newsize2 + 512) {
|
||||
savestates[slot].current_draft_savestate_buffer = std::vector<uint8_t>(newsize2 + 512);
|
||||
} else {
|
||||
savestates[slot].current_draft_savestate_buffer.resize(newsize2 + 512);
|
||||
}
|
||||
} catch (const std::bad_alloc&) {
|
||||
fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize2 + 512);
|
||||
return 0;
|
||||
}
|
||||
auto res2 = llama_state_get_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), newsize2);
|
||||
if (res2 > 0) {
|
||||
totalbytes += res2;
|
||||
savestates[slot].current_draft_savestate_size = newsize2;
|
||||
printf("\nKV Save State %d: Created DraftSaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_draft_savestate_size/(1024*1024));
|
||||
}
|
||||
}
|
||||
return totalbytes;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -4408,6 +4438,12 @@ bool gpttype_load_state_kv(int slot)
|
|||
{
|
||||
current_context_tokens = savestates[slot].savestate_context_tokens;
|
||||
printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
|
||||
if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
|
||||
{
|
||||
llama_memory_clear(llama_get_memory(draft_ctx),true);
|
||||
auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
|
||||
printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
|
||||
}
|
||||
}
|
||||
return (res > 0);
|
||||
}
|
||||
|
@ -4432,6 +4468,15 @@ bool gpttype_clear_state_kv(bool shrink)
|
|||
}
|
||||
savestates[slot].savestate_context_tokens.clear();
|
||||
savestates[slot].current_savestate_size = 0;
|
||||
if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
|
||||
{
|
||||
savestates[slot].current_draft_savestate_buffer.clear();
|
||||
if(shrink)
|
||||
{
|
||||
savestates[slot].current_draft_savestate_buffer.shrink_to_fit();
|
||||
}
|
||||
savestates[slot].current_draft_savestate_size = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue