diff --git a/expose.cpp b/expose.cpp index 5e065a16b..17b8d82d1 100644 --- a/expose.cpp +++ b/expose.cpp @@ -372,5 +372,24 @@ extern "C" return output; } - + size_t calc_new_state_kv() // returns how much memory a new savestate will cost + { + return gpttype_calc_new_state_kv(); + } + size_t calc_old_state_kv() //returns how much memory current savestate is using + { + return gpttype_calc_old_state_kv(); + } + bool save_state_kv() //triggers the save kv state of current ctx to memory + { + return gpttype_save_state_kv(); + } + bool load_state_kv() //triggers the load kv state of current ctx to memory + { + return gpttype_load_state_kv(); + } + bool clear_state_kv() + { + return gpttype_clear_state_kv(); + } } diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index 31bbf5468..f3167f252 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -142,6 +142,10 @@ static int delayed_generated_tokens_limit = 0; std::deque delayed_generated_tokens; //for use with antislop sampling static std::map> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index +static size_t current_savestate_size = 0; +uint8_t * current_savestate_ptr = nullptr; +static std::vector savestate_context_tokens; //for context clones + inline int kcpp_cpu_has_blas(void) { #if defined(GGML_USE_BLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_SYCL) return 1; @@ -4310,3 +4314,87 @@ generation_outputs gpttype_generate(const generation_inputs inputs) generation_finished = true; return output; } + +size_t gpttype_calc_new_state_kv() +{ + if(kcpp_data==nullptr) + { + return 0; + } + if(file_format == FileFormat::GGUF_GENERIC) + { + return llama_state_get_size(llama_ctx_v4); + } + return 0; +} +size_t gpttype_calc_old_state_kv() +{ + return current_savestate_size; +} +bool gpttype_save_state_kv() +{ + if(kcpp_data==nullptr) + { + return false; + } + if(file_format == FileFormat::GGUF_GENERIC) + { + gpttype_clear_state_kv(); //JIT free + size_t newsize = llama_state_get_size(llama_ctx_v4); + current_savestate_ptr = (uint8_t *) malloc(newsize + 512); //add some padding + if(!current_savestate_ptr) + { + return false; + } + auto res = llama_state_get_data(llama_ctx_v4, current_savestate_ptr, newsize); + if (res > 0) { + current_savestate_size = newsize; + savestate_context_tokens = current_context_tokens; + printf("\nKV Save State: Created SaveState of %zu tokens, costing %zu MB.\n",current_context_tokens.size(),current_savestate_size/(1024*1024)); + } + return (res > 0); + } + return false; +} +bool gpttype_load_state_kv() +{ + if(kcpp_data==nullptr) + { + return false; + } + if(file_format == FileFormat::GGUF_GENERIC) + { + if (current_savestate_ptr == nullptr || current_savestate_size == 0) { + return false; + } + auto res = llama_state_set_data(llama_ctx_v4, current_savestate_ptr, current_savestate_size); + if(res > 0) + { + current_context_tokens = savestate_context_tokens; + printf("\nKV Load SaveState: Restored KV with %zu tokens.\n",current_context_tokens.size()); + } + return (res > 0); + } + return false; +} +bool gpttype_clear_state_kv() +{ + if(kcpp_data==nullptr) + { + return false; + } + if(file_format == FileFormat::GGUF_GENERIC) + { + if (current_savestate_ptr != nullptr) { + //JIT free + printf("\nKV Clear SaveState: Freed %zu MB.\n",current_savestate_size/(1024*1024)); + free(current_savestate_ptr); + current_savestate_ptr = nullptr; + savestate_context_tokens.clear(); + current_savestate_size = 0; + return true; + } + return false; + } + return false; +} diff --git a/kcpp_docs.embd b/kcpp_docs.embd index 39d3cc46b..0e3e8464c 100644 --- a/kcpp_docs.embd +++ b/kcpp_docs.embd @@ -1900,6 +1900,132 @@ ] } }, + "/api/admin/check_state": { + "post": { + "description": "Gets the number of bytes taken for existing save state, and predicts the bytes required for a new save state.", + "responses": { + "200": { + "content": { + "application/json": { + "example": { + "success": true, + "old_state": 0, + "new_state": 0 + }, + "schema": { + "properties": { + "success": { + "type": "boolean", + "description": "Whether the operation was successful." + }, + "old_state": { + "type": "number", + "description": "Bytes currently in used for existing save state." + }, + "new_state": { + "type": "number", + "description": "Bytes a new save state is estimated to consume." + } + } + } + } + }, + "description": "Successful request" + } + }, + "summary": "Gets the number of bytes taken for existing save state, and predicts the bytes required for a new save state.", + "tags": [ + "api/admin" + ] + } + }, + "/api/admin/save_state": { + "post": { + "description": "Creates a new KV cache save state in memory. Overwrites any existing saved state.", + "responses": { + "200": { + "content": { + "application/json": { + "example": { + "success": true + }, + "schema": { + "properties": { + "success": { + "type": "boolean", + "description": "Whether the operation was successful." + } + } + } + } + }, + "description": "Successful request" + } + }, + "summary": "Creates a new KV cache save state in memory. Overwrites any existing saved state.", + "tags": [ + "api/admin" + ] + } + }, + "/api/admin/load_state": { + "post": { + "description": "Reloads a previous KV cache save state into context.", + "responses": { + "200": { + "content": { + "application/json": { + "example": { + "success": true + }, + "schema": { + "properties": { + "success": { + "type": "boolean", + "description": "Whether the operation was successful." + } + } + } + } + }, + "description": "Successful request" + } + }, + "summary": "Reloads a previous KV cache save state into context.", + "tags": [ + "api/admin" + ] + } + }, + "/api/admin/clear_state": { + "post": { + "description": "Frees any previous KV cache save state.", + "responses": { + "200": { + "content": { + "application/json": { + "example": { + "success": true + }, + "schema": { + "properties": { + "success": { + "type": "boolean", + "description": "Whether the operation was successful." + } + } + } + } + }, + "description": "Successful request" + } + }, + "summary": "Frees any previous KV cache save state.", + "tags": [ + "api/admin" + ] + } + }, "/api/extra/shutdown": { "post": { "description": "Shuts down the server and exits koboldcpp. Only usable from localhost! Both old and new KoboldCpp Server must have been launched with the --singleinstance flag for this to work.", diff --git a/koboldcpp.py b/koboldcpp.py index 573522951..f26d7b81d 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -521,6 +521,11 @@ def init_library(): handle.token_count.restype = token_count_outputs handle.get_pending_output.restype = ctypes.c_char_p handle.get_chat_template.restype = ctypes.c_char_p + handle.calc_new_state_kv.restype = ctypes.c_size_t + handle.calc_old_state_kv.restype = ctypes.c_size_t + handle.save_state_kv.restype = ctypes.c_bool + handle.load_state_kv.restype = ctypes.c_bool + handle.clear_state_kv.restype = ctypes.c_bool handle.sd_load_model.argtypes = [sd_load_model_inputs] handle.sd_load_model.restype = ctypes.c_bool handle.sd_generate.argtypes = [sd_generation_inputs] @@ -3452,6 +3457,32 @@ Change Mode
resp = {"success": True} response_body = (json.dumps(resp).encode()) + elif self.path.endswith('/api/admin/check_state'): + if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword): + newstate = handle.calc_new_state_kv() + oldstate = handle.calc_old_state_kv() + response_body = (json.dumps({"success": True, "old_state":oldstate, "new_state":newstate}).encode()) + else: + response_body = (json.dumps({"success": False}).encode()) + elif self.path.endswith('/api/admin/load_state'): + if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword): + result = handle.load_state_kv() + response_body = (json.dumps({"success": result}).encode()) + else: + response_body = (json.dumps({"success": False}).encode()) + elif self.path.endswith('/api/admin/save_state'): + if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword): + result = handle.save_state_kv() + response_body = (json.dumps({"success": result}).encode()) + else: + response_body = (json.dumps({"success": False}).encode()) + elif self.path.endswith('/api/admin/clear_state'): + if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword): + result = handle.clear_state_kv() + response_body = (json.dumps({"success": result}).encode()) + else: + response_body = (json.dumps({"success": False}).encode()) + elif self.path.endswith('/set_tts_settings'): #return dummy response response_body = (json.dumps({"message": "Settings successfully applied"}).encode()) diff --git a/model_adapter.h b/model_adapter.h index dc3cde2a1..75a253db8 100644 --- a/model_adapter.h +++ b/model_adapter.h @@ -128,3 +128,9 @@ FileFormat check_file_format(const std::string & fname, FileFormatExtraMeta * fi void ContextFastForward(std::vector ¤t_context_tokens, std::vector &embd_inp, int &n_past, std::vector &last_n_tokens, const int nctx, std::vector &smartcontext, const bool useSmartContext, const bool requireFullSubset); + +size_t gpttype_calc_new_state_kv(); +size_t gpttype_calc_old_state_kv(); +bool gpttype_save_state_kv(); +bool gpttype_load_state_kv(); +bool gpttype_clear_state_kv(); \ No newline at end of file