diff --git a/expose.cpp b/expose.cpp index 8b3a1bc7f..1a9b2437e 100644 --- a/expose.cpp +++ b/expose.cpp @@ -380,21 +380,21 @@ extern "C" { return gpttype_calc_new_state_tokencount(); } - size_t calc_old_state_kv() //returns how much memory current savestate is using + size_t calc_old_state_kv(int slot) //returns how much memory current savestate is using { - return gpttype_calc_old_state_kv(); + return gpttype_calc_old_state_kv(slot); } - size_t calc_old_state_tokencount() + size_t calc_old_state_tokencount(int slot) { - return gpttype_calc_old_state_tokencount(); + return gpttype_calc_old_state_tokencount(slot); } - size_t save_state_kv() //triggers the save kv state of current ctx to memory + size_t save_state_kv(int slot) //triggers the save kv state of current ctx to memory { - return gpttype_save_state_kv(); + return gpttype_save_state_kv(slot); } - bool load_state_kv() //triggers the load kv state of current ctx to memory + bool load_state_kv(int slot) //triggers the load kv state of current ctx to memory { - return gpttype_load_state_kv(); + return gpttype_load_state_kv(slot); } bool clear_state_kv() { diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 020104940..03e0a08be 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -142,9 +142,8 @@ static int delayed_generated_tokens_limit = 0; std::deque<std::string> delayed_generated_tokens; //for use with antislop sampling static std::map<int, std::vector<int>> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index -static size_t current_savestate_size = 0; -static std::vector<uint8_t> current_savestate_buffer; -static std::vector<int> savestate_context_tokens; //for context clones +const int savestate_limit = 3; +static savestate_data savestates[savestate_limit]; inline int kcpp_cpu_has_blas(void) { #if defined(GGML_USE_BLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_SYCL) @@ -4327,19 +4326,19 @@ size_t 
gpttype_calc_new_state_kv() } return 0; } -size_t gpttype_calc_old_state_kv() +size_t gpttype_calc_old_state_kv(int slot) { - return current_savestate_size; + return savestates[slot].current_savestate_size; } -size_t gpttype_calc_old_state_tokencount() +size_t gpttype_calc_old_state_tokencount(int slot) { - return savestate_context_tokens.size(); + return savestates[slot].savestate_context_tokens.size(); } size_t gpttype_calc_new_state_tokencount() { return current_context_tokens.size(); } -size_t gpttype_save_state_kv() +size_t gpttype_save_state_kv(int slot) { if(kcpp_data==nullptr) { @@ -4347,30 +4346,34 @@ size_t gpttype_save_state_kv() } if(file_format == FileFormat::GGUF_GENERIC) { - gpttype_clear_state_kv(false); //JIT free + if (!savestates[slot].current_savestate_buffer.empty()) { //JIT free + savestates[slot].current_savestate_buffer.clear(); + savestates[slot].savestate_context_tokens.clear(); + savestates[slot].current_savestate_size = 0; + } size_t newsize = llama_state_get_size(llama_ctx_v4); try { - if (current_savestate_buffer.capacity() < newsize + 512) { - current_savestate_buffer = std::vector<uint8_t>(newsize + 512); + if (savestates[slot].current_savestate_buffer.capacity() < newsize + 512) { + savestates[slot].current_savestate_buffer = std::vector<uint8_t>(newsize + 512); } else { - current_savestate_buffer.resize(newsize + 512); + savestates[slot].current_savestate_buffer.resize(newsize + 512); } - current_savestate_buffer.resize(newsize + 512); // add some padding. 
May throw std::bad_alloc } catch (const std::bad_alloc&) { fprintf(stderr, "KV Save State: Failed to allocate %zu bytes.\n", newsize + 512); return 0; } - auto res = llama_state_get_data(llama_ctx_v4, current_savestate_buffer.data(), newsize); + auto res = llama_state_get_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), newsize); if (res > 0) { - current_savestate_size = newsize; - savestate_context_tokens = current_context_tokens; - printf("\nKV Save State: Created SaveState of %zu tokens, costing %zu MB.\n",current_context_tokens.size(),current_savestate_size/(1024*1024)); + savestates[slot].current_savestate_size = newsize; + savestates[slot].savestate_context_tokens = current_context_tokens; + printf("\nKV Save State %d: Created SaveState of %zu tokens, costing %zu MB.\n",slot,current_context_tokens.size(),savestates[slot].current_savestate_size/(1024*1024)); } return res; } return 0; } -bool gpttype_load_state_kv() +bool gpttype_load_state_kv(int slot) { if(kcpp_data==nullptr) { @@ -4378,14 +4381,14 @@ bool gpttype_load_state_kv() } if(file_format == FileFormat::GGUF_GENERIC) { - if (current_savestate_buffer.empty()) { + if (savestates[slot].current_savestate_buffer.empty()) { return false; } - auto res = llama_state_set_data(llama_ctx_v4, current_savestate_buffer.data(), current_savestate_size); + auto res = llama_state_set_data(llama_ctx_v4, savestates[slot].current_savestate_buffer.data(), savestates[slot].current_savestate_size); if(res > 0) { - current_context_tokens = savestate_context_tokens; - printf("\nKV Load SaveState: Restored KV with %zu tokens.\n",current_context_tokens.size()); + current_context_tokens = savestates[slot].savestate_context_tokens; + printf("\nKV Load SaveState %d: Restored KV with %zu tokens.\n", slot,current_context_tokens.size()); } return (res > 0); } @@ -4399,18 +4402,20 @@ bool gpttype_clear_state_kv(bool shrink) } if(file_format == FileFormat::GGUF_GENERIC) { - if (!current_savestate_buffer.empty()) { - 
printf("\nKV Clear SaveState: Freed %zu MB.\n", current_savestate_size / (1024 * 1024)); - current_savestate_buffer.clear(); - if(shrink) - { - current_savestate_buffer.shrink_to_fit(); + for(int slot=0;slot<savestate_limit;++slot) + { + if (!savestates[slot].current_savestate_buffer.empty()) { + printf("\nKV Clear SaveState %d: Freed %zu MB.\n", slot, savestates[slot].current_savestate_size / (1024 * 1024)); + savestates[slot].current_savestate_buffer.clear(); + if(shrink) + { + savestates[slot].current_savestate_buffer.shrink_to_fit(); + } + savestates[slot].savestate_context_tokens.clear(); + savestates[slot].current_savestate_size = 0; + } - } - current_savestate_size = 0; - savestate_context_tokens.clear(); } return true; } return false; } diff --git a/klite.embd b/klite.embd --- a/klite.embd +++ b/klite.embd @@ -11221,20 +11223,23 @@ function trigger_admin_savestate() { - document.getElementById("loadstatetxt").innerText = "Saving State..."; + let slot = parseInt(document.getElementById("savestate_selection").value); + document.getElementById("loadstatetxt").innerText = `Saving State ${slot}...`; let header = {'Content-Type': 'application/json'}; if(last_admin_key!="") { header['Authorization'] = 'Bearer ' + last_admin_key; } fetch(custom_kobold_endpoint + koboldcpp_admin_savestate_endpoint, { method: 'POST', - headers: header + headers: header, + body: JSON.stringify({ + "slot": slot + }) }) .then(x => x.json()) .then(values => { console.log(values); if(values.success) { - document.getElementById("loadstatetxt").innerText = `State Saved (${values.new_tokens} tokens in ${parseInt(values.new_state_size/(1024*1024))} MB)`; + document.getElementById("loadstatetxt").innerText = `State ${slot} Saved (${values.new_tokens} tokens in ${parseInt(values.new_state_size/(1024*1024))} MB)`; }else{ - document.getElementById("loadstatetxt").innerText = `Save State Failed!`; + document.getElementById("loadstatetxt").innerText = `Save State ${slot} Failed!`; } }).catch((error) => { console.log("Error: " + error); - document.getElementById("loadstatetxt").innerText = `Save State Failed!`; + document.getElementById("loadstatetxt").innerText = `Save State ${slot} Failed!`; msgbox(error,"Error"); }); } function trigger_admin_loadstate() { - document.getElementById("loadstatetxt").innerText = "Loading State..."; + let slot = parseInt(document.getElementById("savestate_selection").value); + document.getElementById("loadstatetxt").innerText = `Loading State ${slot}...`; let header = {'Content-Type': 'application/json'}; if(last_admin_key!="") { @@ -11249,20 +11254,23 @@ 
} fetch(custom_kobold_endpoint + koboldcpp_admin_loadstate_endpoint, { method: 'POST', - headers: header + headers: header, + body: JSON.stringify({ + "slot": slot + }) }) .then(x => x.json()) .then(values => { console.log(values); if(values.success) { - document.getElementById("loadstatetxt").innerText = `State Loaded (${values.new_tokens} tokens)`; + document.getElementById("loadstatetxt").innerText = `State ${slot} Loaded (${values.new_tokens} tokens)`; }else{ - document.getElementById("loadstatetxt").innerText = `Load State Failed!`; + document.getElementById("loadstatetxt").innerText = `Load State ${slot} Failed!`; } }).catch((error) => { console.log("Error: " + error); - document.getElementById("loadstatetxt").innerText = `Load State Failed!`; + document.getElementById("loadstatetxt").innerText = `Load State ${slot} Failed!`; msgbox(error,"Error"); }); } @@ -17649,6 +17657,14 @@ Current version indicated by LITEVER below. let pat = new RegExp(localsettings.thinking_pattern, "gmi"); gentxtspeak = gentxtspeak.replace(pat, ''); } + //remove t2i + if (localsettings.img_autogen_type == 2) + { + const pat = /<t2i>(.*?)<\/t2i>/g; + gentxtspeak = gentxtspeak.replace(pat, ""); + const pat2 = /{{\[IMG_.{1,8}_REF\]}}/g; + gentxtspeak = gentxtspeak.replace(pat2, ""); + } tts_speak(gentxtspeak); } @@ -21185,16 +21201,18 @@ Current version indicated by LITEVER below. 
let userinput = getInputBoxValue().trim(); try { - if(userinput!="") + if(userinput=="") { - let newjson = JSON.parse(userinput); - pending_wi_obj = pending_wi_obj.filter(item => !currwis.includes(item)); - for (var i = 0; i < newjson.length; ++i) { - newjson[i].wigroup = curr_wi_tab; - pending_wi_obj.push(newjson[i]); - } - update_wi(); + userinput = "[]"; } + let newjson = JSON.parse(userinput); + pending_wi_obj = pending_wi_obj.filter(item => !currwis.includes(item)); + for (var i = 0; i < newjson.length; ++i) { + newjson[i].wigroup = curr_wi_tab; + pending_wi_obj.push(newjson[i]); + } + update_wi(); + } catch (e) { console.log("WI JSON not correctly formatted!"); } @@ -21222,6 +21240,13 @@ Current version indicated by LITEVER below. if(has_tav_wi_check) { wiToAdd = load_tavern_wi(wiToAdd); + if(wiToAdd && wiToAdd.length > 0) + { + for(let i=0;i<wiToAdd.length;++i) + { + wiToAdd[i].wigroup = curr_wi_tab; + } + } if(wiToAdd.length > 0) { @@ -24789,8 +24814,13 @@
Save / Load Context State:
- - + + +
diff --git a/koboldcpp.py b/koboldcpp.py index 0e9d5b3f4..097aa5a89 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -44,6 +44,7 @@ default_draft_amount = 8 default_ttsmaxlen = 4096 default_visionmaxres = 1024 net_save_slots = 10 +savestate_limit = 3 #3 savestate slots # abuse prevention stop_token_max = 256 @@ -522,10 +523,14 @@ def init_library(): handle.get_pending_output.restype = ctypes.c_char_p handle.get_chat_template.restype = ctypes.c_char_p handle.calc_new_state_kv.restype = ctypes.c_size_t - handle.calc_old_state_kv.restype = ctypes.c_size_t handle.calc_new_state_tokencount.restype = ctypes.c_size_t + handle.calc_old_state_kv.argtypes = [ctypes.c_int] + handle.calc_old_state_kv.restype = ctypes.c_size_t + handle.calc_old_state_tokencount.argtypes = [ctypes.c_int] handle.calc_old_state_tokencount.restype = ctypes.c_size_t + handle.save_state_kv.argtypes = [ctypes.c_int] handle.save_state_kv.restype = ctypes.c_size_t + handle.load_state_kv.argtypes = [ctypes.c_int] handle.load_state_kv.restype = ctypes.c_bool handle.clear_state_kv.restype = ctypes.c_bool handle.sd_load_model.argtypes = [sd_load_model_inputs] @@ -3524,23 +3529,42 @@ Change Mode
if self.path.endswith('/api/admin/check_state'): if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword): + cur_states = [] + for sl in range(savestate_limit): #0,1,2 + oldstate = handle.calc_old_state_kv(sl) + oldtokencnt = handle.calc_old_state_tokencount(sl) + cur_states.append({"tokens":oldtokencnt,"size":oldstate}) newstate = handle.calc_new_state_kv() - oldstate = handle.calc_old_state_kv() newtokencnt = handle.calc_new_state_tokencount() - oldtokencnt = handle.calc_old_state_tokencount() - response_body = (json.dumps({"success": True, "old_state_size":oldstate, "old_tokens":oldtokencnt, "new_state_size":newstate, "new_tokens":newtokencnt}).encode()) + response_body = (json.dumps({"success": True, "old_states":cur_states, "new_state_size":newstate, "new_tokens":newtokencnt}).encode()) else: - response_body = (json.dumps({"success": False, "old_state_size":0, "old_tokens":0, "new_state_size":0, "new_tokens":0}).encode()) + response_body = (json.dumps({"success": False, "old_states":[], "new_state_size":0, "new_tokens":0}).encode()) elif self.path.endswith('/api/admin/load_state'): if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword): - result = handle.load_state_kv() + targetslot = 0 + try: + tempbody = json.loads(body) + if isinstance(tempbody, dict): + targetslot = tempbody.get('slot', 0) + except Exception: + pass + targetslot = (targetslot if targetslot<savestate_limit else 0) + result = handle.load_state_kv(targetslot) tokencnt = handle.calc_new_state_tokencount() response_body = (json.dumps({"success": result, "new_tokens":tokencnt}).encode()) else: response_body = (json.dumps({"success": False, "new_tokens":0}).encode()) elif self.path.endswith('/api/admin/save_state'): if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword): - result = handle.save_state_kv() + targetslot = 0 + try: + tempbody = json.loads(body) + if isinstance(tempbody, dict): + targetslot = tempbody.get('slot', 0) + except Exception: + pass + targetslot = (targetslot if targetslot<savestate_limit else 0) + result = handle.save_state_kv(targetslot) tokencnt = handle.calc_new_state_tokencount() response_body = (json.dumps({"success": (result>0), "new_state_size":result, "new_tokens":tokencnt}).encode()) else: diff --git a/model_adapter.h index ef9fadabf..d5602cd40 100644 --- a/model_adapter.h +++ b/model_adapter.h @@ -131,8 +131,8 @@ void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp, size_t gpttype_calc_new_state_kv(); -size_t gpttype_calc_old_state_kv(); -size_t gpttype_calc_old_state_tokencount(); size_t gpttype_calc_new_state_tokencount(); -size_t gpttype_save_state_kv(); -bool gpttype_load_state_kv(); +size_t gpttype_calc_old_state_kv(int slot); +size_t gpttype_calc_old_state_tokencount(int slot); +size_t gpttype_save_state_kv(int slot); +bool gpttype_load_state_kv(int slot); bool gpttype_clear_state_kv(bool shrink); @@ -152,6 +152,13 @@ +struct savestate_data +{ + size_t current_savestate_size = 0; + std::vector<uint8_t> current_savestate_buffer; + std::vector<int> savestate_context_tokens; //for context clones +}; + const float default_norm_eps = 1e-5f;