mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
save and load state upgraded to 3 available states
This commit is contained in:
parent
06d2bc3404
commit
736030bb9f
7 changed files with 206 additions and 82 deletions
|
@ -142,9 +142,8 @@ static int delayed_generated_tokens_limit = 0;
|
|||
std::deque<std::string> delayed_generated_tokens; //for use with antislop sampling
|
||||
static std::map<int,std::vector<int>> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index
|
||||
|
||||
static size_t current_savestate_size = 0;
|
||||
static std::vector<uint8_t> current_savestate_buffer;
|
||||
static std::vector<gpt_vocab::id> savestate_context_tokens; //for context clones
|
||||
const int savestate_limit = 3;
|
||||
static savestate_data savestates[savestate_limit];
|
||||
|
||||
inline int kcpp_cpu_has_blas(void) {
|
||||
#if defined(GGML_USE_BLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_SYCL)
|
||||
|
@ -4327,19 +4326,19 @@ size_t gpttype_calc_new_state_kv()
|
|||
}
|
||||
return 0;
|
||||
}
|
||||
size_t gpttype_calc_old_state_kv()
|
||||
size_t gpttype_calc_old_state_kv(int slot)
|
||||
{
|
||||
return current_savestate_size;
|
||||
return savestates[slot].current_savestate_size;
|
||||
}
|
||||
size_t gpttype_calc_old_state_tokencount()
|
||||
size_t gpttype_calc_old_state_tokencount(int slot)
|
||||
{
|
||||
return savestate_context_tokens.size();
|
||||
return savestates[slot].savestate_context_tokens.size();
|
||||
}
|
||||
size_t gpttype_calc_new_state_tokencount()
|
||||
{
|
||||
return current_context_tokens.size();
|
||||
}
|
||||
size_t gpttype_save_state_kv()
|
||||
size_t gpttype_save_state_kv(int slot)
|
||||
{
|
||||
if(kcpp_data==nullptr)
|
||||
{
|
||||
|
@ -4347,30 +4346,34 @@ size_t gpttype_save_state_kv()
|
|||
}
|
||||
if(file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
gpttype_clear_state_kv(false); //JIT free
|
||||
if (!savestates[slot].current_savestate_buffer.empty()) { //JIT free
|
||||
savestates[slot].current_savestate_buffer.clear();
|
||||
savestates[slot].savestate_context_tokens.clear();
|
||||
savestates[slot].current_savestate_size = 0;
|
||||
}
|
||||
size_t newsize = llama_state_get_size(llama_ctx_v4);
|
||||
try {
|
||||
if (current_savestate_buffer.capacity() < newsize + 512) {
|
||||
current_savestate_buffer = std::vector<uint8_t>(newsize + 512);
|
||||
if (savestates[slot].current_savestate_buffer.capacity() < newsize + 512) {
|
||||
savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512);
|
||||
} else {
|
||||
current_savestate_buffer.resize(newsize + 512);
|
||||
savestates[slot].current_savestate_buffer.resize(newsize + 512);
|
||||
}
|
||||
current_savestate_buffer.resize(newsize + 512); // add some padding. May throw std::bad_alloc
|
||||
savestates[slot].current_savestate_buffer.resize(newsize + 512); // add some padding. May throw std::bad_alloc
|
||||
} catch (const std::bad_alloc&) {
|
||||
fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize + 512);
|
||||
return 0;
|
||||
}
|
||||
auto res = llama_state_get_data(llama_ctx_v4, current_savestate_buffer.data(), newsize);
|
||||
auto res = llama_state_get_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), newsize);
|
||||
if (res > 0) {
|
||||
current_savestate_size = newsize;
|
||||
savestate_context_tokens = current_context_tokens;
|
||||
printf("\nKV Save State: Created SaveState of %zu tokens, costing %zu MB.\n",current_context_tokens.size(),current_savestate_size/(1024*1024));
|
||||
savestates[slot].current_savestate_size = newsize;
|
||||
savestates[slot].savestate_context_tokens = current_context_tokens;
|
||||
printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
bool gpttype_load_state_kv()
|
||||
bool gpttype_load_state_kv(int slot)
|
||||
{
|
||||
if(kcpp_data==nullptr)
|
||||
{
|
||||
|
@ -4378,14 +4381,14 @@ bool gpttype_load_state_kv()
|
|||
}
|
||||
if(file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
if (current_savestate_buffer.empty()) {
|
||||
if (savestates[slot].current_savestate_buffer.empty()) {
|
||||
return false;
|
||||
}
|
||||
auto res = llama_state_set_data(llama_ctx_v4, current_savestate_buffer.data(), current_savestate_size);
|
||||
auto res = llama_state_set_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), savestates[slot].current_savestate_size);
|
||||
if(res > 0)
|
||||
{
|
||||
current_context_tokens = savestate_context_tokens;
|
||||
printf("\nKV Load SaveState: Restored KV with %zu tokens.\n",current_context_tokens.size());
|
||||
current_context_tokens = savestates[slot].savestate_context_tokens;
|
||||
printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size());
|
||||
}
|
||||
return (res > 0);
|
||||
}
|
||||
|
@ -4399,18 +4402,20 @@ bool gpttype_clear_state_kv(bool shrink)
|
|||
}
|
||||
if(file_format == FileFormat::GGUF_GENERIC)
|
||||
{
|
||||
if (!current_savestate_buffer.empty()) {
|
||||
printf("\nKV Clear SaveState: Freed %zu MB.\n", current_savestate_size / (1024 * 1024));
|
||||
current_savestate_buffer.clear();
|
||||
if(shrink)
|
||||
{
|
||||
current_savestate_buffer.shrink_to_fit();
|
||||
for(int slot=0;slot<savestate_limit;++slot)
|
||||
{
|
||||
if (!savestates[slot].current_savestate_buffer.empty()) {
|
||||
printf("\nKV Clear SaveState %d: Freed %zu MB.\n",slot, savestates[slot].current_savestate_size / (1024 * 1024));
|
||||
savestates[slot].current_savestate_buffer.clear();
|
||||
if(shrink)
|
||||
{
|
||||
savestates[slot].current_savestate_buffer.shrink_to_fit();
|
||||
}
|
||||
savestates[slot].savestate_context_tokens.clear();
|
||||
savestates[slot].current_savestate_size = 0;
|
||||
}
|
||||
savestate_context_tokens.clear();
|
||||
current_savestate_size = 0;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue