diff --git a/expose.h b/expose.h index 2242f58a0..38ab57740 100644 --- a/expose.h +++ b/expose.h @@ -114,6 +114,10 @@ struct sd_load_model_inputs const char * vulkan_info; const int threads; const int quant = 0; + const bool taesd = false; + const char * vae_filename; + const char * lora_filename; + const float lora_multiplier = 1.0f; const int debugmode = 0; }; struct sd_generation_inputs @@ -128,6 +132,7 @@ struct sd_generation_inputs const int height; const int seed; const char * sample_method; + const int clip_skip = -1; const bool quiet = false; }; struct sd_generation_outputs diff --git a/koboldcpp.py b/koboldcpp.py index c87d7304f..f1a1328b9 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -107,6 +107,10 @@ class sd_load_model_inputs(ctypes.Structure): ("vulkan_info", ctypes.c_char_p), ("threads", ctypes.c_int), ("quant", ctypes.c_int), + ("taesd", ctypes.c_bool), + ("vae_filename", ctypes.c_char_p), + ("lora_filename", ctypes.c_char_p), + ("lora_multiplier", ctypes.c_float), ("debugmode", ctypes.c_int)] class sd_generation_inputs(ctypes.Structure): @@ -120,6 +124,7 @@ class sd_generation_inputs(ctypes.Structure): ("height", ctypes.c_int), ("seed", ctypes.c_int), ("sample_method", ctypes.c_char_p), + ("clip_skip", ctypes.c_int), ("quiet", ctypes.c_bool)] class sd_generation_outputs(ctypes.Structure): @@ -512,7 +517,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512 return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason} -def sd_load_model(model_filename): +def sd_load_model(model_filename,vae_filename,lora_filename): global args inputs = sd_load_model_inputs() inputs.debugmode = args.debugmode @@ -529,6 +534,10 @@ def sd_load_model(model_filename): inputs.threads = thds inputs.quant = quant + inputs.taesd = True if args.sdvaeauto else False + inputs.vae_filename = vae_filename.encode("UTF-8") + inputs.lora_filename = lora_filename.encode("UTF-8") + inputs.lora_multiplier = args.sdloramult inputs = set_backend_props(inputs) ret = handle.sd_load_model(inputs) return ret @@ -547,6 +556,7 @@ def sd_generate(genparams): seed = genparams.get("seed", -1) sample_method = genparams.get("sampler_name", "k_euler_a") is_quiet = True if args.quiet else False + clip_skip = genparams.get("clip_skip", -1) #clean vars width = width - (width%64) @@ -582,6 +592,7 @@ def sd_generate(genparams): inputs.seed = seed inputs.sample_method = sample_method.lower().encode("UTF-8") inputs.quiet = is_quiet + inputs.clip_skip = clip_skip ret = handle.sd_generate(inputs) outstr = "" if ret.status==1: @@ -3154,12 +3165,25 @@ def main(launch_args,start_server=True): time.sleep(3) sys.exit(2) else: + imglora = "" + imgvae = "" + if args.sdlora: + if os.path.exists(args.sdlora): + imglora = os.path.abspath(args.sdlora) + else: + print(f"Missing SD LORA model file...") + if args.sdvae: + if os.path.exists(args.sdvae): + imgvae = os.path.abspath(args.sdvae) + else: + print(f"Missing SD VAE model file...") + imgmodel = os.path.abspath(imgmodel) fullsdmodelpath = imgmodel friendlysdmodelname = os.path.basename(imgmodel) friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0] friendlysdmodelname = sanitize_string(friendlysdmodelname) - loadok = sd_load_model(imgmodel) + loadok = sd_load_model(imgmodel,imgvae,imglora) print("Load Image Model OK: " + str(loadok)) if not loadok: exitcounter = 999 @@ -3414,6 +3438,11 @@ if __name__ == '__main__': sdparsergroup.add_argument("--sdthreads", metavar=('[threads]'), help="Use a different number of threads for image generation if specified. Otherwise, has the same value as --threads.", type=int, default=0) sdparsergroup.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true') sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use.", action='store_true') + sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group() + sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify a stable diffusion safetensors VAE which replaces the one in the model.", default="") + sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast.", action='store_true') + sdparsergroup.add_argument("--sdlora", metavar=('[filename]'), help="Specify a stable diffusion LORA safetensors model to be applied.", default="") + sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the LORA model to be applied.", type=float, default=1.0) deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!') deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+') diff --git a/otherarch/sdcpp/sdtype_adapter.cpp b/otherarch/sdcpp/sdtype_adapter.cpp index daa4ece48..a2f408689 100644 --- a/otherarch/sdcpp/sdtype_adapter.cpp +++ b/otherarch/sdcpp/sdtype_adapter.cpp @@ -279,6 +279,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs) sd_params->width = inputs.width; sd_params->height = inputs.height; sd_params->strength = inputs.denoising_strength; + sd_params->clip_skip = inputs.clip_skip; sd_params->mode = (img2img_data==""?SDMode::TXT2IMG:SDMode::IMG2IMG); //for img2img