refactor and clean identifiers for sd, fix cmake

This commit is contained in:
Concedo 2024-02-29 18:28:45 +08:00
parent 66134bb36e
commit 5a44d4de2b
9 changed files with 69 additions and 134 deletions

View file

@ -22,6 +22,14 @@ logit_bias_max = 16
bias_min_value = -100.0
bias_max_value = 100.0
class logit_bias(ctypes.Structure):
_fields_ = [("token_id", ctypes.c_int32),
("bias", ctypes.c_float)]
class token_count_outputs(ctypes.Structure):
_fields_ = [("count", ctypes.c_int),
("ids", ctypes.POINTER(ctypes.c_int))]
class load_model_inputs(ctypes.Structure):
_fields_ = [("threads", ctypes.c_int),
("blasthreads", ctypes.c_int),
@ -49,10 +57,6 @@ class load_model_inputs(ctypes.Structure):
("banned_tokens", ctypes.c_char_p * ban_token_max),
("tensor_split", ctypes.c_float * tensor_split_max)]
class logit_bias(ctypes.Structure):
_fields_ = [("token_id", ctypes.c_int32),
("bias", ctypes.c_float)]
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
("prompt", ctypes.c_char_p),
@ -103,12 +107,9 @@ class sd_generation_inputs(ctypes.Structure):
class sd_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("data_length", ctypes.c_uint),
("data", ctypes.c_char_p)]
class token_count_outputs(ctypes.Structure):
_fields_ = [("count", ctypes.c_int),
("ids", ctypes.POINTER(ctypes.c_int))]
handle = None
def getdirpath():
@ -273,10 +274,10 @@ def init_library():
handle.abort_generate.restype = ctypes.c_bool
handle.token_count.restype = token_count_outputs
handle.get_pending_output.restype = ctypes.c_char_p
handle.load_model_sd.argtypes = [sd_load_model_inputs]
handle.load_model_sd.restype = ctypes.c_bool
handle.generate_sd.argtypes = [sd_generation_inputs]
handle.generate_sd.restype = sd_generation_outputs
handle.sd_load_model.argtypes = [sd_load_model_inputs]
handle.sd_load_model.restype = ctypes.c_bool
handle.sd_generate.argtypes = [sd_generation_inputs]
handle.sd_generate.restype = sd_generation_outputs
def load_model(model_filename):
global args
@ -469,14 +470,29 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
return outstr
def load_model_sd(model_filename):
def sd_load_model(model_filename):
global args
inputs = sd_load_model_inputs()
inputs.debugmode = args.debugmode
inputs.model_filename = model_filename.encode("UTF-8")
ret = handle.load_model_sd(inputs)
ret = handle.sd_load_model(inputs)
return ret
def sd_generate(prompt, negative_prompt="", cfg_scale=5, sample_steps=20, seed=-1, sample_method="euler a"):
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
inputs = sd_generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
inputs.negative_prompt = negative_prompt.encode("UTF-8")
inputs.cfg_scale = cfg_scale
inputs.sample_steps = sample_steps
inputs.seed = seed
inputs.sample_method = sample_method.encode("UTF-8")
ret = handle.sd_generate(inputs)
outstr = ""
if ret.status==1:
outstr = ret.data.decode("UTF-8","ignore")
return outstr
def utfprint(str):
try:
print(str)
@ -2567,7 +2583,7 @@ def main(launch_args,start_server=True):
time.sleep(3)
sys.exit(2)
imgmodel = os.path.abspath(imgmodel)
loadok = load_model_sd(imgmodel)
loadok = sd_load_model(imgmodel)
print("Load Image Model OK: " + str(loadok))
if not loadok:
exitcounter = 999