From 39b0699c7167fb86d6c14dee27f9609e2de018ef Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Fri, 27 Jun 2025 20:35:38 +0800 Subject: [PATCH] fixed savestates with drafting --- gpttype_adapter.cpp | 55 +++++++++++++++++++++++++++++++++++++++---- otherarch/otherarch.h | 2 ++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 26cf04818..1e5ebccc5 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -4341,13 +4341,19 @@ size_t gpttype_calc_new_state_kv() } if(file_format == FileFormat::GGUF_GENERIC) { - return llama_state_get_size(llama_ctx_v4); + size_t s1 = llama_state_get_size(llama_ctx_v4); + if(draft_ctx) + { + size_t s2 = llama_state_get_size(draft_ctx); + s1 += s2; + } + return s1; } return 0; } size_t gpttype_calc_old_state_kv(int slot) { - return savestates[slot].current_savestate_size; + return savestates[slot].current_savestate_size + savestates[slot].current_draft_savestate_size; } size_t gpttype_calc_old_state_tokencount(int slot) { @@ -4365,30 +4371,54 @@ size_t gpttype_save_state_kv(int slot) } if(file_format == FileFormat::GGUF_GENERIC) { + size_t totalbytes = 0; if (!savestates[slot].current_savestate_buffer.empty()) { //JIT free savestates[slot].current_savestate_buffer.clear(); + savestates[slot].current_draft_savestate_buffer.clear(); savestates[slot].savestate_context_tokens.clear(); savestates[slot].current_savestate_size = 0; + savestates[slot].current_draft_savestate_size = 0; } size_t newsize = llama_state_get_size(llama_ctx_v4); try { if (savestates[slot].current_savestate_buffer.capacity() < newsize + 512) { savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512); // add some padding. 
May throw std::bad_alloc } else { savestates[slot].current_savestate_buffer.resize(newsize + 512); } - savestates[slot].current_savestate_buffer.resize(newsize + 512); // add some padding. May throw std::bad_alloc } catch (const std::bad_alloc&) { fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize + 512); return 0; } auto res = llama_state_get_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), newsize); if (res > 0) { + totalbytes += res; savestates[slot].current_savestate_size = newsize; savestates[slot].savestate_context_tokens = current_context_tokens; printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024)); } - return res; + + if(draft_ctx) + { + size_t newsize2 = llama_state_get_size(draft_ctx); + try { + if (savestates[slot].current_draft_savestate_buffer.capacity() < newsize2 + 512) { + savestates[slot].current_draft_savestate_buffer = std::vector<uint8_t>(newsize2 + 512); + } else { + savestates[slot].current_draft_savestate_buffer.resize(newsize2 + 512); + } + } catch (const std::bad_alloc&) { + fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize2 + 512); + return 0; + } + auto res2 = llama_state_get_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), newsize2); + if (res2 > 0) { + totalbytes += res2; + savestates[slot].current_draft_savestate_size = newsize2; + printf("\nKV Save State %d: Created DraftSaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_draft_savestate_size/(1024*1024)); + } + } + return totalbytes; } return 0; } @@ -4408,6 +4438,12 @@ bool gpttype_load_state_kv(int slot) { current_context_tokens = savestates[slot].savestate_context_tokens; printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size()); + if(draft_ctx && savestates[slot].current_draft_savestate_size>0) 
+ { + llama_memory_clear(llama_get_memory(draft_ctx),true); + auto res2 = llama_state_set_data(draft_ctx, savestates[slot].current_draft_savestate_buffer.data(), savestates[slot].current_draft_savestate_size); + if (res2 > 0) { printf("\nKV Load DraftSaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size()); } + } } return (res > 0); } @@ -4432,6 +4468,15 @@ } savestates[slot].savestate_context_tokens.clear(); savestates[slot].current_savestate_size = 0; + if(draft_ctx && savestates[slot].current_draft_savestate_size>0) + { + savestates[slot].current_draft_savestate_buffer.clear(); + if(shrink) + { + savestates[slot].current_draft_savestate_buffer.shrink_to_fit(); + } + savestates[slot].current_draft_savestate_size = 0; + } } } return true; diff --git a/otherarch/otherarch.h b/otherarch/otherarch.h index 8f69765b7..be7d04da6 100644 --- a/otherarch/otherarch.h +++ b/otherarch/otherarch.h @@ -521,6 +521,8 @@ struct savestate_data { size_t current_savestate_size = 0; std::vector<uint8_t> current_savestate_buffer; + size_t current_draft_savestate_size = 0; + std::vector<uint8_t> current_draft_savestate_buffer; std::vector<gpt_vocab::id> savestate_context_tokens; //for context clones };