diff --git a/koboldcpp.py b/koboldcpp.py
index 96e47ba87..fe7aedf87 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -478,15 +478,28 @@ def sd_load_model(model_filename):
     ret = handle.sd_load_model(inputs)
     return ret

-def sd_generate(prompt, negative_prompt="", cfg_scale=5, sample_steps=20, seed=-1, sample_method="euler a"):
+def sd_generate(genparams):
     global maxctx, args, currentusergenkey, totalgens, pendingabortkey
+    prompt = genparams.get("prompt", "high quality")
+    negative_prompt = genparams.get("negative_prompt", "")
+    cfg_scale = genparams.get("cfg_scale", 5)
+    sample_steps = genparams.get("steps", 20)
+    seed = genparams.get("seed", -1)
+    sample_method = genparams.get("sampler_name", "euler a")
+
+    # quick mode: --sdconfig "quick" overrides the request with fast, low-step settings
+    if args.sdconfig and len(args.sdconfig)>1 and args.sdconfig[1]=="quick":
+        cfg_scale = 1
+        sample_steps = 7
+        sample_method = "dpm++ 2m karras"
+
     inputs = sd_generation_inputs()
     inputs.prompt = prompt.encode("UTF-8")
     inputs.negative_prompt = negative_prompt.encode("UTF-8")
     inputs.cfg_scale = cfg_scale
     inputs.sample_steps = sample_steps
     inputs.seed = seed
-    inputs.sample_method = sample_method.encode("UTF-8")
+    inputs.sample_method = sample_method.lower().encode("UTF-8")
     ret = handle.sd_generate(inputs)
     outstr = ""
     if ret.status==1:
@@ -512,7 +525,9 @@ def bring_terminal_to_foreground():
 ### A hacky simple HTTP server simulating a kobold api by Concedo
 ### we are intentionally NOT using flask, because we want MINIMAL dependencies
 #################################################################
-friendlymodelname = "concedo/koboldcpp" # local kobold api apparently needs a hardcoded known HF model name
+friendlymodelname = "inactive"
+friendlysdmodelname = "inactive"
+fullsdmodelpath = "" # if empty, it's not initialized
 maxctx = 2048
 maxhordectx = 2048
 maxhordelen = 256
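Since sd_generate now takes the parsed JSON body directly, missing keys quietly fall back to the genparams.get(...) defaults, and a "quick" entry in --sdconfig overrides whatever the client sent. A minimal sketch of that behavior (the payload values here are invented):

    # Invented A1111-style request body; any omitted key falls back to a default.
    genparams = {"prompt": "a watercolor fox", "steps": 24}

    assert genparams.get("steps", 20) == 24          # present -> client value wins
    assert genparams.get("cfg_scale", 5) == 5        # absent  -> default
    assert genparams.get("sampler_name", "euler a") == "euler a"
    # With `--sdconfig <model> quick`, the server then forces cfg_scale=1,
    # sample_steps=7 and the "dpm++ 2m karras" sampler regardless of these values.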
@@ -860,7 +875,7 @@ Enter Prompt:
         self.wfile.write(finalhtml)

     def do_GET(self):
-        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey
+        global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath
         self.path = self.path.rstrip('/')
         response_body = None
         content_type = 'application/json'
@@ -920,6 +935,11 @@ Enter Prompt:
         elif self.path.endswith('/v1/models'):
             response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())

+        elif self.path.endswith('/sdapi/v1/sd-models'):
+            response_body = (json.dumps([{"title":friendlysdmodelname,"model_name":friendlysdmodelname,"hash":"8888888888","sha256":"8888888888888888888888888888888888888888888888888888888888888888","filename":fullsdmodelpath,"config": None}]).encode())
+        elif self.path.endswith('/sdapi/v1/options'):
+            response_body = (json.dumps({"samples_format":"png","sd_model_checkpoint":friendlysdmodelname}).encode())
+
         elif self.path=="/api":
             content_type = 'text/html'
             if self.embedded_kcpp_docs is None:
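These two read-only endpoints are enough for A1111-style frontends to discover the loaded checkpoint. A quick sketch of what a client sees, assuming KoboldCpp's default port 5001 (adjust if launched with --port):

    import json, urllib.request

    base = "http://localhost:5001"
    models = json.load(urllib.request.urlopen(base + "/sdapi/v1/sd-models"))
    print(models[0]["model_name"])        # the sanitized image model filename
    opts = json.load(urllib.request.urlopen(base + "/sdapi/v1/options"))
    print(opts["sd_model_checkpoint"])    # same friendly name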
@@ -1044,6 +1064,7 @@ Enter Prompt:
        sse_stream_flag = False
        api_format = 0 # 1=basic, 2=kai, 3=oai, 4=oai-chat
+       is_txt2img = False

        if self.path.endswith('/request'):
            api_format = 1
@@ -1061,7 +1082,10 @@ Enter Prompt:
        if self.path.endswith('/v1/chat/completions'):
            api_format = 4

-       if api_format > 0:
+       if self.path.endswith('/sdapi/v1/txt2img'):
+           is_txt2img = True
+
+       if is_txt2img or api_format > 0:
            genparams = None
            try:
                genparams = json.loads(body)
@@ -1076,25 +1100,43 @@ Enter Prompt:
            if args.foreground:
                bring_terminal_to_foreground()

-           # Check if streaming chat completions, if so, set stream mode to true
-           if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
-               sse_stream_flag = True
+           if api_format > 0: # text gen
+               # Check if streaming chat completions, if so, set stream mode to true
+               if (api_format == 4 or api_format == 3) and "stream" in genparams and genparams["stream"]:
+                   sse_stream_flag = True

-           gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))
+               gen = asyncio.run(self.handle_request(genparams, api_format, sse_stream_flag))

-           try:
-               # Headers are already sent when streaming
-               if not sse_stream_flag:
+               try:
+                   # Headers are already sent when streaming
+                   if not sse_stream_flag:
+                       self.send_response(200)
+                       genresp = (json.dumps(gen).encode())
+                       self.send_header('content-length', str(len(genresp)))
+                       self.end_headers(content_type='application/json')
+                       self.wfile.write(genresp)
+               except Exception as ex:
+                   if args.debugmode:
+                       print(ex)
+                   print("Generate: The response could not be sent, maybe connection was terminated?")
+                   handle.abort_generate()
+                   time.sleep(0.2) # short delay
+                   return
+
+           elif is_txt2img: # image gen
+               try:
+                   gen = sd_generate(genparams)
+                   genresp = (json.dumps({"images":[gen],"parameters":{},"info":""}).encode())
                    self.send_response(200)
-                   genresp = (json.dumps(gen).encode())
                    self.send_header('content-length', str(len(genresp)))
                    self.end_headers(content_type='application/json')
                    self.wfile.write(genresp)
-           except Exception as ex:
-               print("Generate: The response could not be sent, maybe connection was terminated?")
-               handle.abort_generate()
-               time.sleep(0.2) # short delay
-               return
+               except Exception as ex:
+                   if args.debugmode:
+                       print(ex)
+                   print("Generate Image: The response could not be sent, maybe connection was terminated?")
+                   time.sleep(0.2) # short delay
+                   return

        finally:
            modelbusy.release()
@@ -2433,7 +2475,7 @@ def sanitize_string(input_string):
    return sanitized_string

 def main(launch_args,start_server=True):
-    global args, friendlymodelname
+    global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath
    args = launch_args
    embedded_kailite = None
    embedded_kcpp_docs = None
@@ -2583,6 +2625,10 @@ def main(launch_args,start_server=True):
            time.sleep(3)
            sys.exit(2)
        imgmodel = os.path.abspath(imgmodel)
+       fullsdmodelpath = imgmodel
+       friendlysdmodelname = os.path.basename(imgmodel)
+       friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
+       friendlysdmodelname = sanitize_string(friendlysdmodelname)
        loadok = sd_load_model(imgmodel)
        print("Load Image Model OK: " + str(loadok))
        if not loadok:
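Taken together, the Python side now speaks just enough of the A1111 txt2img protocol for third-party frontends to generate images. An end-to-end client sketch using only the standard library; the port is KoboldCpp's default and the field values are made up, while images[0] is the base64-encoded image as in the A1111 response shape assembled above:

    import base64, json, urllib.request

    payload = {
        "prompt": "a lighthouse at dusk, oil painting",
        "negative_prompt": "blurry",
        "cfg_scale": 5,
        "steps": 20,
        "seed": -1,
        "sampler_name": "Euler a",  # case does not matter: the server lowercases it
    }
    req = urllib.request.Request(
        "http://localhost:5001/sdapi/v1/txt2img",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    resp = json.load(urllib.request.urlopen(req))
    with open("out.png", "wb") as f:
        f.write(base64.b64decode(resp["images"][0]))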
diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp
index 69febaeae..a038a61c8 100644
--- a/otherarch/sdcpp/sdtype_adapter.cpp
+++ b/otherarch/sdcpp/sdtype_adapter.cpp
@@ -127,7 +127,7 @@ static void sd_logger_callback(enum sd_log_level_t level, const char* log, void*
 bool sdtype_load_model(const sd_load_model_inputs inputs)
 {
-    printf("\nSelected Image Model: %s\n",inputs.model_filename);
+    printf("\nImage Gen - Load Safetensors Image Model: %s\n",inputs.model_filename);

     sd_params = new SDParams();
     sd_params->model_path = inputs.model_filename;
@@ -187,10 +187,29 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
     sd_params->sample_steps = inputs.sample_steps;
     sd_params->seed = inputs.seed;

-    if(inputs.sample_method=="euler a") //all lowercase
+    printf("\nGenerating Image (%d steps)\n",inputs.sample_steps);
+    std::string sampler = inputs.sample_method;
+
+    if(sampler=="euler a") // sampler names arrive lowercased from the server
    {
        sd_params->sample_method = sample_method_t::EULER_A;
    }
+    else if(sampler=="euler")
+    {
+        sd_params->sample_method = sample_method_t::EULER;
+    }
+    else if(sampler=="heun")
+    {
+        sd_params->sample_method = sample_method_t::HEUN;
+    }
+    else if(sampler=="dpm2")
+    {
+        sd_params->sample_method = sample_method_t::DPM2;
+    }
+    else if(sampler=="dpm++ 2m karras" || sampler=="dpm++ 2m")
+    {
+        sd_params->sample_method = sample_method_t::DPMPP2M;
+    }
    else
    {
        sd_params->sample_method = sample_method_t::EULER_A;
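Note that the removed line compared a C string to a literal with ==, a pointer comparison that never matched; converting to std::string first is what makes the dispatch work. For reference, a Python-side mirror of the accepted sampler names and the fallback behavior (names taken from the diff; the enum values are shown as strings for illustration):

    # Mirror of the C++ mapping above: unknown names fall back to EULER_A.
    SAMPLER_MAP = {
        "euler a": "EULER_A",
        "euler": "EULER",
        "heun": "HEUN",
        "dpm2": "DPM2",
        "dpm++ 2m": "DPMPP2M",
        "dpm++ 2m karras": "DPMPP2M",
    }

    def resolve_sampler(name: str) -> str:
        # the Python server lowercases sampler_name before it reaches the C++ side
        return SAMPLER_MAP.get(name.lower(), "EULER_A")

    assert resolve_sampler("DPM++ 2M Karras") == "DPMPP2M"
    assert resolve_sampler("unknown sampler") == "EULER_A"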