mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
wip sd
This commit is contained in:
parent
bd95ee7d9a
commit
4807b66907
3 changed files with 37 additions and 2 deletions
5
expose.h
5
expose.h
|
@ -114,6 +114,10 @@ struct sd_load_model_inputs
|
||||||
const char * vulkan_info;
|
const char * vulkan_info;
|
||||||
const int threads;
|
const int threads;
|
||||||
const int quant = 0;
|
const int quant = 0;
|
||||||
|
const bool taesd = false;
|
||||||
|
const char * vae_filename;
|
||||||
|
const char * lora_filename;
|
||||||
|
const float lora_multiplier = 1.0f;
|
||||||
const int debugmode = 0;
|
const int debugmode = 0;
|
||||||
};
|
};
|
||||||
struct sd_generation_inputs
|
struct sd_generation_inputs
|
||||||
|
@ -128,6 +132,7 @@ struct sd_generation_inputs
|
||||||
const int height;
|
const int height;
|
||||||
const int seed;
|
const int seed;
|
||||||
const char * sample_method;
|
const char * sample_method;
|
||||||
|
const int clip_skip = -1;
|
||||||
const bool quiet = false;
|
const bool quiet = false;
|
||||||
};
|
};
|
||||||
struct sd_generation_outputs
|
struct sd_generation_outputs
|
||||||
|
|
33
koboldcpp.py
33
koboldcpp.py
|
@ -107,6 +107,10 @@ class sd_load_model_inputs(ctypes.Structure):
|
||||||
("vulkan_info", ctypes.c_char_p),
|
("vulkan_info", ctypes.c_char_p),
|
||||||
("threads", ctypes.c_int),
|
("threads", ctypes.c_int),
|
||||||
("quant", ctypes.c_int),
|
("quant", ctypes.c_int),
|
||||||
|
("taesd", ctypes.c_bool),
|
||||||
|
("vae_filename", ctypes.c_char_p),
|
||||||
|
("lora_filename", ctypes.c_char_p),
|
||||||
|
("lora_multiplier", ctypes.c_float),
|
||||||
("debugmode", ctypes.c_int)]
|
("debugmode", ctypes.c_int)]
|
||||||
|
|
||||||
class sd_generation_inputs(ctypes.Structure):
|
class sd_generation_inputs(ctypes.Structure):
|
||||||
|
@ -120,6 +124,7 @@ class sd_generation_inputs(ctypes.Structure):
|
||||||
("height", ctypes.c_int),
|
("height", ctypes.c_int),
|
||||||
("seed", ctypes.c_int),
|
("seed", ctypes.c_int),
|
||||||
("sample_method", ctypes.c_char_p),
|
("sample_method", ctypes.c_char_p),
|
||||||
|
("clip_skip", ctypes.c_int),
|
||||||
("quiet", ctypes.c_bool)]
|
("quiet", ctypes.c_bool)]
|
||||||
|
|
||||||
class sd_generation_outputs(ctypes.Structure):
|
class sd_generation_outputs(ctypes.Structure):
|
||||||
|
@ -512,7 +517,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
|
||||||
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}
|
return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}
|
||||||
|
|
||||||
|
|
||||||
def sd_load_model(model_filename):
|
def sd_load_model(model_filename,vae_filename,lora_filename):
|
||||||
global args
|
global args
|
||||||
inputs = sd_load_model_inputs()
|
inputs = sd_load_model_inputs()
|
||||||
inputs.debugmode = args.debugmode
|
inputs.debugmode = args.debugmode
|
||||||
|
@ -529,6 +534,10 @@ def sd_load_model(model_filename):
|
||||||
|
|
||||||
inputs.threads = thds
|
inputs.threads = thds
|
||||||
inputs.quant = quant
|
inputs.quant = quant
|
||||||
|
inputs.taesd = True if args.sdvaeauto else False
|
||||||
|
inputs.vae_filename = vae_filename.encode("UTF-8")
|
||||||
|
inputs.lora_filename = lora_filename.encode("UTF-8")
|
||||||
|
inputs.lora_multiplier = args.sdloramult
|
||||||
inputs = set_backend_props(inputs)
|
inputs = set_backend_props(inputs)
|
||||||
ret = handle.sd_load_model(inputs)
|
ret = handle.sd_load_model(inputs)
|
||||||
return ret
|
return ret
|
||||||
|
@ -547,6 +556,7 @@ def sd_generate(genparams):
|
||||||
seed = genparams.get("seed", -1)
|
seed = genparams.get("seed", -1)
|
||||||
sample_method = genparams.get("sampler_name", "k_euler_a")
|
sample_method = genparams.get("sampler_name", "k_euler_a")
|
||||||
is_quiet = True if args.quiet else False
|
is_quiet = True if args.quiet else False
|
||||||
|
clip_skip = genparams.get("clip_skip", -1)
|
||||||
|
|
||||||
#clean vars
|
#clean vars
|
||||||
width = width - (width%64)
|
width = width - (width%64)
|
||||||
|
@ -582,6 +592,7 @@ def sd_generate(genparams):
|
||||||
inputs.seed = seed
|
inputs.seed = seed
|
||||||
inputs.sample_method = sample_method.lower().encode("UTF-8")
|
inputs.sample_method = sample_method.lower().encode("UTF-8")
|
||||||
inputs.quiet = is_quiet
|
inputs.quiet = is_quiet
|
||||||
|
inputs.clip_skip = clip_skip
|
||||||
ret = handle.sd_generate(inputs)
|
ret = handle.sd_generate(inputs)
|
||||||
outstr = ""
|
outstr = ""
|
||||||
if ret.status==1:
|
if ret.status==1:
|
||||||
|
@ -3154,12 +3165,25 @@ def main(launch_args,start_server=True):
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
else:
|
else:
|
||||||
|
imglora = ""
|
||||||
|
imgvae = ""
|
||||||
|
if args.sdlora:
|
||||||
|
if os.path.exists(args.sdlora):
|
||||||
|
imglora = os.path.abspath(args.sdlora)
|
||||||
|
else:
|
||||||
|
print(f"Missing SD LORA model file...")
|
||||||
|
if args.sdvae:
|
||||||
|
if os.path.exists(args.sdvae):
|
||||||
|
imgvae = os.path.abspath(args.sdvae)
|
||||||
|
else:
|
||||||
|
print(f"Missing SD VAE model file...")
|
||||||
|
|
||||||
imgmodel = os.path.abspath(imgmodel)
|
imgmodel = os.path.abspath(imgmodel)
|
||||||
fullsdmodelpath = imgmodel
|
fullsdmodelpath = imgmodel
|
||||||
friendlysdmodelname = os.path.basename(imgmodel)
|
friendlysdmodelname = os.path.basename(imgmodel)
|
||||||
friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
|
friendlysdmodelname = os.path.splitext(friendlysdmodelname)[0]
|
||||||
friendlysdmodelname = sanitize_string(friendlysdmodelname)
|
friendlysdmodelname = sanitize_string(friendlysdmodelname)
|
||||||
loadok = sd_load_model(imgmodel)
|
loadok = sd_load_model(imgmodel,imgvae,imglora)
|
||||||
print("Load Image Model OK: " + str(loadok))
|
print("Load Image Model OK: " + str(loadok))
|
||||||
if not loadok:
|
if not loadok:
|
||||||
exitcounter = 999
|
exitcounter = 999
|
||||||
|
@ -3414,6 +3438,11 @@ if __name__ == '__main__':
|
||||||
sdparsergroup.add_argument("--sdthreads", metavar=('[threads]'), help="Use a different number of threads for image generation if specified. Otherwise, has the same value as --threads.", type=int, default=0)
|
sdparsergroup.add_argument("--sdthreads", metavar=('[threads]'), help="Use a different number of threads for image generation if specified. Otherwise, has the same value as --threads.", type=int, default=0)
|
||||||
sdparsergroup.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
|
sdparsergroup.add_argument("--sdquant", help="If specified, loads the model quantized to save memory.", action='store_true')
|
||||||
sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use.", action='store_true')
|
sdparsergroup.add_argument("--sdclamped", help="If specified, limit generation steps and resolution settings for shared use.", action='store_true')
|
||||||
|
sdparsergroupvae = sdparsergroup.add_mutually_exclusive_group()
|
||||||
|
sdparsergroupvae.add_argument("--sdvae", metavar=('[filename]'), help="Specify a stable diffusion safetensors VAE which replaces the one in the model.", default="")
|
||||||
|
sdparsergroupvae.add_argument("--sdvaeauto", help="Uses a built-in VAE via TAE SD, which is very fast.", action='store_true')
|
||||||
|
sdparsergroup.add_argument("--sdlora", metavar=('[filename]'), help="Specify a stable diffusion LORA safetensors model to be applied.", default="")
|
||||||
|
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the LORA model to be applied.", type=float, default=1.0)
|
||||||
|
|
||||||
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
|
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
|
||||||
deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
|
deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
|
||||||
|
|
|
@ -279,6 +279,7 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
|
||||||
sd_params->width = inputs.width;
|
sd_params->width = inputs.width;
|
||||||
sd_params->height = inputs.height;
|
sd_params->height = inputs.height;
|
||||||
sd_params->strength = inputs.denoising_strength;
|
sd_params->strength = inputs.denoising_strength;
|
||||||
|
sd_params->clip_skip = inputs.clip_skip;
|
||||||
sd_params->mode = (img2img_data==""?SDMode::TXT2IMG:SDMode::IMG2IMG);
|
sd_params->mode = (img2img_data==""?SDMode::TXT2IMG:SDMode::IMG2IMG);
|
||||||
|
|
||||||
//for img2img
|
//for img2img
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue