use a static buffer for kv reloads instead. also, added into lite ui

This commit is contained in:
Concedo 2025-06-03 22:32:46 +08:00
parent 4b57108508
commit 53f1511396
6 changed files with 239 additions and 99 deletions

View file

@ -143,7 +143,7 @@ std::deque<std::string> delayed_generated_tokens; //for use with antislop sampli
static std::map<int,std::vector<int>> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index
static size_t current_savestate_size = 0;
uint8_t * current_savestate_ptr = nullptr;
static std::vector<uint8_t> current_savestate_buffer;
static std::vector<gpt_vocab::id> savestate_context_tokens; //for context clones
inline int kcpp_cpu_has_blas(void) {
@ -4331,30 +4331,44 @@ size_t gpttype_calc_old_state_kv()
{
return current_savestate_size;
}
bool gpttype_save_state_kv()
size_t gpttype_calc_old_state_tokencount()
{
return savestate_context_tokens.size();
}
size_t gpttype_calc_new_state_tokencount()
{
return current_context_tokens.size();
}
size_t gpttype_save_state_kv()
{
if(kcpp_data==nullptr)
{
return false;
return 0;
}
if(file_format == FileFormat::GGUF_GENERIC)
{
gpttype_clear_state_kv(); //JIT free
gpttype_clear_state_kv(false); //JIT free
size_t newsize = llama_state_get_size(llama_ctx_v4);
current_savestate_ptr = (uint8_t *) malloc(newsize + 512); //add some padding
if(!current_savestate_ptr)
{
return false;
try {
if (current_savestate_buffer.capacity() < newsize + 512) {
current_savestate_buffer = std::vector<uint8_t>(newsize + 512);
} else {
current_savestate_buffer.resize(newsize + 512);
}
current_savestate_buffer.resize(newsize + 512); // add some padding. May throw std::bad_alloc
} catch (const std::bad_alloc&) {
fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize + 512);
return 0;
}
auto res = llama_state_get_data(llama_ctx_v4, current_savestate_ptr, newsize);
auto res = llama_state_get_data(llama_ctx_v4, current_savestate_buffer.data(), newsize);
if (res > 0) {
current_savestate_size = newsize;
savestate_context_tokens = current_context_tokens;
printf("\nKV Save State: Created SaveState of %zu tokens, costing %zu MB.\n",current_context_tokens.size(),current_savestate_size/(1024*1024));
}
return (res > 0);
return res;
}
return false;
return 0;
}
bool gpttype_load_state_kv()
{
@ -4364,10 +4378,10 @@ bool gpttype_load_state_kv()
}
if(file_format == FileFormat::GGUF_GENERIC)
{
if (current_savestate_ptr == nullptr || current_savestate_size == 0) {
if (current_savestate_buffer.empty()) {
return false;
}
auto res = llama_state_set_data(llama_ctx_v4, current_savestate_ptr, current_savestate_size);
auto res = llama_state_set_data(llama_ctx_v4, current_savestate_buffer.data(), current_savestate_size);
if(res > 0)
{
current_context_tokens = savestate_context_tokens;
@ -4377,7 +4391,7 @@ bool gpttype_load_state_kv()
}
return false;
}
bool gpttype_clear_state_kv()
bool gpttype_clear_state_kv(bool shrink)
{
if(kcpp_data==nullptr)
{
@ -4385,11 +4399,13 @@ bool gpttype_clear_state_kv()
}
if(file_format == FileFormat::GGUF_GENERIC)
{
if (current_savestate_ptr != nullptr) {
//JIT free
printf("\nKV Clear SaveState: Freed %zu MB.\n",current_savestate_size/(1024*1024));
free(current_savestate_ptr);
current_savestate_ptr = nullptr;
if (!current_savestate_buffer.empty()) {
printf("\nKV Clear SaveState: Freed %zu MB.\n", current_savestate_size / (1024 * 1024));
current_savestate_buffer.clear();
if(shrink)
{
current_savestate_buffer.shrink_to_fit();
}
savestate_context_tokens.clear();
current_savestate_size = 0;
return true;