image generation is fully working over api (+1 squashed commits)

Squashed commits:

[c98ab0b4] single image generation is working now
This commit is contained in:
Concedo 2024-03-01 14:34:03 +08:00
parent e8f4d7b3da
commit 3463688a0e
5 changed files with 190 additions and 76 deletions

View file

@ -95,6 +95,10 @@ class generation_outputs(ctypes.Structure):
class sd_load_model_inputs(ctypes.Structure):
_fields_ = [("model_filename", ctypes.c_char_p),
("clblast_info", ctypes.c_int),
("cublas_info", ctypes.c_int),
("vulkan_info", ctypes.c_char_p),
("threads", ctypes.c_int),
("debugmode", ctypes.c_int)]
class sd_generation_inputs(ctypes.Structure):
@ -107,7 +111,6 @@ class sd_generation_inputs(ctypes.Structure):
class sd_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("data_length", ctypes.c_uint),
("data", ctypes.c_char_p)]
handle = None
@ -279,47 +282,12 @@ def init_library():
handle.sd_generate.argtypes = [sd_generation_inputs]
handle.sd_generate.restype = sd_generation_outputs
def load_model(model_filename):
global args
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
inputs.threads = args.threads
inputs.low_vram = (True if (args.usecublas and "lowvram" in args.usecublas) else False)
inputs.use_mmq = (True if (args.usecublas and "mmq" in args.usecublas) else False)
inputs.use_rowsplit = (True if (args.usecublas and "rowsplit" in args.usecublas) else False)
inputs.vulkan_info = "0".encode("UTF-8")
inputs.blasthreads = args.blasthreads
inputs.use_mmap = (not args.nommap)
inputs.use_mlock = args.usemlock
inputs.lora_filename = "".encode("UTF-8")
inputs.lora_base = "".encode("UTF-8")
if args.lora:
inputs.lora_filename = args.lora[0].encode("UTF-8")
inputs.use_mmap = False
if len(args.lora) > 1:
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
inputs.blasbatchsize = args.blasbatchsize
inputs.forceversion = args.forceversion
inputs.gpulayers = args.gpulayers
inputs.rope_freq_scale = args.ropeconfig[0]
if len(args.ropeconfig)>1:
inputs.rope_freq_base = args.ropeconfig[1]
else:
inputs.rope_freq_base = 10000
def set_backend_props(inputs):
clblastids = 0
if args.useclblast:
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
inputs.clblast_info = clblastids
for n in range(tensor_split_max):
if args.tensor_split and n < len(args.tensor_split):
inputs.tensor_split[n] = float(args.tensor_split[n])
else:
inputs.tensor_split[n] = 0
# we must force an explicit tensor split
# otherwise the default will divide equally and multigpu crap will slow it down badly
inputs.cublas_info = 0
@ -356,6 +324,46 @@ def load_model(model_filename):
inputs.vulkan_info = s.encode("UTF-8")
else:
inputs.vulkan_info = "0".encode("UTF-8")
return inputs
def load_model(model_filename):
global args
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
inputs.threads = args.threads
inputs.low_vram = (True if (args.usecublas and "lowvram" in args.usecublas) else False)
inputs.use_mmq = (True if (args.usecublas and "mmq" in args.usecublas) else False)
inputs.use_rowsplit = (True if (args.usecublas and "rowsplit" in args.usecublas) else False)
inputs.vulkan_info = "0".encode("UTF-8")
inputs.blasthreads = args.blasthreads
inputs.use_mmap = (not args.nommap)
inputs.use_mlock = args.usemlock
inputs.lora_filename = "".encode("UTF-8")
inputs.lora_base = "".encode("UTF-8")
if args.lora:
inputs.lora_filename = args.lora[0].encode("UTF-8")
inputs.use_mmap = False
if len(args.lora) > 1:
inputs.lora_base = args.lora[1].encode("UTF-8")
inputs.use_smartcontext = args.smartcontext
inputs.use_contextshift = (0 if args.noshift else 1)
inputs.blasbatchsize = args.blasbatchsize
inputs.forceversion = args.forceversion
inputs.gpulayers = args.gpulayers
inputs.rope_freq_scale = args.ropeconfig[0]
if len(args.ropeconfig)>1:
inputs.rope_freq_base = args.ropeconfig[1]
else:
inputs.rope_freq_base = 10000
for n in range(tensor_split_max):
if args.tensor_split and n < len(args.tensor_split):
inputs.tensor_split[n] = float(args.tensor_split[n])
else:
inputs.tensor_split[n] = 0
inputs = set_backend_props(inputs)
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
inputs.debugmode = args.debugmode
@ -475,6 +483,13 @@ def sd_load_model(model_filename):
inputs = sd_load_model_inputs()
inputs.debugmode = args.debugmode
inputs.model_filename = model_filename.encode("UTF-8")
thds = args.threads
if len(args.sdconfig) > 2:
sdt = int(args.sdconfig[2])
if sdt > 0:
thds = sdt
inputs.threads = thds
inputs = set_backend_props(inputs)
ret = handle.sd_load_model(inputs)
return ret
@ -1372,6 +1387,7 @@ def show_new_gui():
sd_model_var = ctk.StringVar()
sd_quick_var = ctk.IntVar(value=0)
sd_threads_var = ctk.StringVar()
def tabbuttonaction(name):
for t in tabcontent:
@ -1849,6 +1865,8 @@ def show_new_gui():
images_tab = tabcontent["Image Gen"]
makefileentry(images_tab, "Stable Diffusion Model (f16):", "Select Stable Diffusion Model File", sd_model_var, 1, filetypes=[("*.safetensors","*.safetensors")], tooltiptxt="Select a .safetensors Stable Diffusion model file on disk to be loaded.")
makecheckbox(images_tab, "Quick Mode (Low Quality)", sd_quick_var, 4,tooltiptxt="Force optimal generation settings for speed.")
makelabelentry(images_tab, "Image threads:" , sd_threads_var, 6, 50,"How many threads to use during image generation.\nIf left blank, uses same value as threads.")
# launch
def guilaunch():
@ -1936,7 +1954,7 @@ def show_new_gui():
else:
args.hordeconfig = None if usehorde_var.get() == 0 else [horde_name_var.get(), horde_gen_var.get(), horde_context_var.get(), horde_apikey_var.get(), horde_workername_var.get()]
args.sdconfig = None if sd_model_var.get() == "" else [sd_model_var.get(), ("quick" if sd_quick_var.get()==1 else "normal")]
args.sdconfig = None if sd_model_var.get() == "" else [sd_model_var.get(), ("quick" if sd_quick_var.get()==1 else "normal"),(int(threads_var.get()) if sd_threads_var.get()=="" else int(sd_threads_var.get()))]
def import_vars(dict):
if "threads" in dict:
@ -2065,10 +2083,12 @@ def show_new_gui():
horde_workername_var.set(dict["hordeconfig"][4])
usehorde_var.set("1")
if "sdconfig" in dict and dict["sdconfig"] and len(dict["sdconfig"]) > 1:
if "sdconfig" in dict and dict["sdconfig"] and len(dict["sdconfig"]) > 0:
sd_model_var.set(dict["sdconfig"][0])
if len(dict["sdconfig"]) > 2:
if len(dict["sdconfig"]) > 1:
sd_quick_var.set(1 if dict["sdconfig"][1]=="quick" else 0)
if len(dict["sdconfig"]) > 2:
sd_threads_var.set(str(dict["sdconfig"][2]))
def save_config():
file_type = [("KoboldCpp Settings", "*.kcpps")]
@ -2845,6 +2865,6 @@ if __name__ == '__main__':
parser.add_argument("--quiet", help="Enable quiet mode, which hides generation inputs and outputs in the terminal. Quiet mode is automatically enabled when running --hordeconfig.", action='store_true')
parser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+')
parser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick]'), nargs='+')
parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick] [sd_threads]'), nargs='+')
main(parser.parse_args(),start_server=True)