diff --git a/expose.cpp b/expose.cpp
index a1b1612fd..d4c1877d7 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -219,6 +219,11 @@ extern "C"
     {
         return sdtype_generate(inputs);
     }
+    // Exported entry point: forwards to sdtype_get_info() and returns its result.
+    sd_info_outputs sd_get_info()
+    {
+        return sdtype_get_info();
+    }
 
     bool whisper_load_model(const whisper_load_model_inputs inputs)
     {
diff --git a/expose.h b/expose.h
index c40fad970..a7116aa42 100644
--- a/expose.h
+++ b/expose.h
@@ -223,6 +223,12 @@ struct sd_generation_outputs
     const char * data = "";
     const char * data_extra = "";
 };
+// Result of sd_get_info(): status is set to 0 on success (defaults to -1); data is a JSON string.
+struct sd_info_outputs
+{
+    int status = -1;
+    const char * data = "";
+};
 
 struct whisper_load_model_inputs
 {
diff --git a/koboldcpp.py b/koboldcpp.py
index 4d6b462f2..733efbd27 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -343,6 +343,11 @@ class sd_generation_outputs(ctypes.Structure):
                 ("data", ctypes.c_char_p),
                 ("data_extra", ctypes.c_char_p)]
 
+# ctypes mirror of the C struct sd_info_outputs; field order must match expose.h
+class sd_info_outputs(ctypes.Structure):
+    _fields_ = [("status", ctypes.c_int),
+                ("data", ctypes.c_char_p)]
+
 class whisper_load_model_inputs(ctypes.Structure):
     _fields_ = [("model_filename", ctypes.c_char_p),
                 ("executable_path", ctypes.c_char_p),
@@ -624,6 +629,8 @@ def init_library():
     handle.sd_load_model.restype = ctypes.c_bool
     handle.sd_generate.argtypes = [sd_generation_inputs]
     handle.sd_generate.restype = sd_generation_outputs
+    handle.sd_get_info.argtypes = []
+    handle.sd_get_info.restype = sd_info_outputs
     handle.whisper_load_model.argtypes = [whisper_load_model_inputs]
     handle.whisper_load_model.restype = ctypes.c_bool
     handle.whisper_generate.argtypes = [whisper_generation_inputs]
@@ -1769,6 +1776,23 @@ def generate(genparams, stream_flag=False):
                     outstr = outstr[:sindex]
         return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
 
+def sd_get_info():
+    """Query the native library for sd metadata and return the parsed JSON; returns {} if the call or the decode fails."""
+    info = handle.sd_get_info()
+    if info.status == 0:
+        try:
+            return json.loads(info.data)
+        except Exception:
+            print("An error occurred while decoding sd metadata info")
+    else:
+        print("An error occurred while getting sd metadata info")
+    return {}
+
+def sd_get_available_schedulers():
+    """Return the 'available_schedulers' entry from sd_get_info(), or [] when it is absent."""
+    info = sd_get_info()
+    return info.get('available_schedulers', [])
+
 sd_convdirect_choices = ['off', 'vaeonly', 'full']
 
 def sd_convdirect_option(value):
@@ -3773,7 +3797,7 @@ Change Mode
             if friendlysdmodelname=="inactive" or fullsdmodelpath=="":
                 response_body = (json.dumps([]).encode())
             else:
-                response_body = (json.dumps([{"name":name,"label":name} for name in ["default","discrete","karras","exponential","ays","gits","sgm_uniform","simple","smoothstep","kl_optimal","lcm"]]).encode())
+                response_body = (json.dumps([{"name":name,"label":name} for name in sd_get_available_schedulers()]).encode())
         elif clean_path.endswith('/sdapi/v1/latent-upscale-modes'):
             response_body = (json.dumps([]).encode())
         elif clean_path.endswith('/sdapi/v1/upscalers'):
diff --git a/model_adapter.h b/model_adapter.h
index e8a547acb..939c42cb6 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -107,6 +107,8 @@ const std::vector gpttype_get_top_picks_data();
 
 bool sdtype_load_model(const sd_load_model_inputs inputs);
 sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);
+// Returns sd backend metadata as a JSON string wrapped in sd_info_outputs.
+sd_info_outputs sdtype_get_info();
 
 bool whispertype_load_model(const whisper_load_model_inputs inputs);
 whisper_generation_outputs whispertype_generate(const whisper_generation_inputs inputs);
@@ -142,4 +144,4 @@ bool gpttype_load_state_kv(int slot);
 bool gpttype_clear_state_kv(bool shrink);
 int get_oldest_slot(int excludeSlotId);
 void touch_slot(int slot);
-int get_identical_existing_slot();
\ No newline at end of file
+int get_identical_existing_slot();
diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp
index 7cba5b045..1013a6862 100644
--- a/otherarch/sdcpp/sdtype_adapter.cpp
+++ b/otherarch/sdcpp/sdtype_adapter.cpp
@@ -8,6 +8,7 @@
 
 #include
 #include
+#include
 #include
 #include
 #include
@@ -1256,3 +1257,32 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
     total_img_gens += 1;
     return output;
 }
+
+// Reports sd backend metadata as a JSON object; the only key written here is
+// "available_schedulers". Always returns status 0.
+sd_info_outputs sdtype_get_info()
+{
+    using json = nlohmann::json;
+
+    // The list is derived purely from the scheduler_t enum, so serialize it once
+    // on first use. Static storage keeps the c_str() pointer valid after return.
+    static const std::string info_json = [] {
+        auto available_schedulers = json::array();
+        available_schedulers.push_back("default");
+        for (int i = 0; i < scheduler_t::SCHEDULER_COUNT; i++) {
+            const std::string name = sd_scheduler_name(static_cast<scheduler_t>(i));
+            if (name != "NONE") {
+                available_schedulers.push_back(name);
+            }
+        }
+        json j;
+        j["available_schedulers"] = available_schedulers;
+        return j.dump();
+    }();
+
+    sd_info_outputs output;
+    output.status = 0;
+    output.data = info_json.c_str();
+    return output;
+}
+