fixed savestates with drafting

This commit is contained in:
Concedo 2025-06-27 20:35:38 +08:00
parent df47b51bd1
commit 39b0699c71
2 changed files with 52 additions and 5 deletions

View file

@ -4341,13 +4341,19 @@ size_t gpttype_calc_new_state_kv()
}
if(file_format == FileFormat::GGUF_GENERIC)
{
return llama_state_get_size(llama_ctx_v4);
size_t s1 = llama_state_get_size(llama_ctx_v4);
if(draft_ctx)
{
size_t s2 = llama_state_get_size(draft_ctx);
s1 += s2;
}
return s1;
}
return 0;
}
// Returns the number of bytes of saved state currently recorded for save slot
// `slot`: the main-context state size plus the draft-context state size.
// A slot with nothing saved contributes 0 for either part.
// `slot` is used directly as an index into `savestates`; no range check is done here.
size_t gpttype_calc_old_state_kv(int slot)
{
    // The draft state is stored per slot next to the main state, so its bytes
    // belong in the total. Returning only current_savestate_size (the previous
    // behaviour) left this combined return unreachable and under-reported
    // slots that also hold a draft state.
    return savestates[slot].current_savestate_size + savestates[slot].current_draft_savestate_size;
}
size_t gpttype_calc_old_state_tokencount(int slot)
{
@ -4365,30 +4371,54 @@ size_t gpttype_save_state_kv(int slot)
}
if(file_format == FileFormat::GGUF_GENERIC)
{
size_t totalbytes = 0;
if (!savestates[slot].current_savestate_buffer.empty()) { //JIT free
savestates[slot].current_savestate_buffer.clear();
savestates[slot].current_draft_savestate_buffer.clear();
savestates[slot].savestate_context_tokens.clear();
savestates[slot].current_savestate_size = 0;
savestates[slot].current_draft_savestate_size = 0;
}
size_t newsize = llama_state_get_size(llama_ctx_v4);
try {
if (savestates[slot].current_savestate_buffer.capacity() < newsize + 512) {
savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512);
savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512); // add some padding. May throw std::bad_alloc
} else {
savestates[slot].current_savestate_buffer.resize(newsize + 512);
}
savestates[slot].current_savestate_buffer.resize(newsize + 512); // add some padding. May throw std::bad_alloc
} catch (const std::bad_alloc&) {
fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize + 512);
return 0;
}
auto res = llama_state_get_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), newsize);
if (res > 0) {
totalbytes += res;
savestates[slot].current_savestate_size = newsize;
savestates[slot].savestate_context_tokens = current_context_tokens;
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
}
return res;
if(draft_ctx)
{
size_t newsize2 = llama_state_get_size(draft_ctx);
try {
if (savestates[slot].current_draft_savestate_buffer.capacity() < newsize2 + 512) {
savestates[slot].current_draft_savestate_buffer = std::vector<uint8_t>(newsize2 + 512);
} else {
savestates[slot].current_draft_savestate_buffer.resize(newsize2 + 512);
}
} catch (const std::bad_alloc&) {
fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize2 + 512);
return 0;
}
auto res2 = llama_state_get_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), newsize2);
if (res2 > 0) {
totalbytes += res2;
savestates[slot].current_draft_savestate_size = newsize2;
printf("\nKV Save State %d: Created DraftSaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_draft_savestate_size/(1024*1024));
}
}
return totalbytes;
}
return 0;
}
@ -4408,6 +4438,12 @@ bool gpttype_load_state_kv(int slot)
{
current_context_tokens = savestates[slot].savestate_context_tokens;
printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
{
llama_memory_clear(llama_get_memory(draft_ctx),true);
auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size);
printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
}
}
return (res > 0);
}
@ -4432,6 +4468,15 @@ bool gpttype_clear_state_kv(bool shrink)
}
savestates[slot].savestate_context_tokens.clear();
savestates[slot].current_savestate_size = 0;
if(draft_ctx && savestates[slot].current_draft_savestate_size>0)
{
savestates[slot].current_draft_savestate_buffer.clear();
if(shrink)
{
savestates[slot].current_draft_savestate_buffer.shrink_to_fit();
}
savestates[slot].current_draft_savestate_size = 0;
}
}
}
return true;

View file

@ -521,6 +521,8 @@ struct savestate_data
{
// Per-slot snapshot of saved context state. One entry holds the main context's
// serialized state, optionally the draft context's serialized state, and the
// token list that was active when the snapshot was taken.

// Size in bytes of the main-context state held in current_savestate_buffer.
// Set from llama_state_get_size() after a successful save; 0 means the slot is empty.
size_t current_savestate_size = 0;
// Main-context state bytes written by llama_state_get_data(). The save path
// sizes this buffer to the state size plus 512 bytes of padding.
std::vector<uint8_t> current_savestate_buffer;
// Size in bytes of the draft-context state held in current_draft_savestate_buffer.
// Stays 0 when no draft state was captured; the load path only restores the
// draft context when this is greater than 0.
size_t current_draft_savestate_size = 0;
// Draft-context state bytes, filled the same way as the main buffer
// (state size plus 512 bytes of padding) when a draft context exists at save time.
std::vector<uint8_t> current_draft_savestate_buffer;
// Copy of current_context_tokens taken at save time and assigned back on load.
std::vector<gpt_vocab::id> savestate_context_tokens; //for context clones
};