mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
refactor and clean identifiers for sd, fix cmake
This commit is contained in:
parent
66134bb36e
commit
5a44d4de2b
9 changed files with 69 additions and 134 deletions
46
koboldcpp.py
46
koboldcpp.py
|
@ -22,6 +22,14 @@ logit_bias_max = 16
|
|||
bias_min_value = -100.0
|
||||
bias_max_value = 100.0
|
||||
|
||||
class logit_bias(ctypes.Structure):
|
||||
_fields_ = [("token_id", ctypes.c_int32),
|
||||
("bias", ctypes.c_float)]
|
||||
|
||||
class token_count_outputs(ctypes.Structure):
|
||||
_fields_ = [("count", ctypes.c_int),
|
||||
("ids", ctypes.POINTER(ctypes.c_int))]
|
||||
|
||||
class load_model_inputs(ctypes.Structure):
|
||||
_fields_ = [("threads", ctypes.c_int),
|
||||
("blasthreads", ctypes.c_int),
|
||||
|
@ -49,10 +57,6 @@ class load_model_inputs(ctypes.Structure):
|
|||
("banned_tokens", ctypes.c_char_p * ban_token_max),
|
||||
("tensor_split", ctypes.c_float * tensor_split_max)]
|
||||
|
||||
class logit_bias(ctypes.Structure):
|
||||
_fields_ = [("token_id", ctypes.c_int32),
|
||||
("bias", ctypes.c_float)]
|
||||
|
||||
class generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("seed", ctypes.c_int),
|
||||
("prompt", ctypes.c_char_p),
|
||||
|
@ -103,12 +107,9 @@ class sd_generation_inputs(ctypes.Structure):
|
|||
|
||||
class sd_generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
("data_length", ctypes.c_uint),
|
||||
("data", ctypes.c_char_p)]
|
||||
|
||||
class token_count_outputs(ctypes.Structure):
|
||||
_fields_ = [("count", ctypes.c_int),
|
||||
("ids", ctypes.POINTER(ctypes.c_int))]
|
||||
|
||||
handle = None
|
||||
|
||||
def getdirpath():
|
||||
|
@ -273,10 +274,10 @@ def init_library():
|
|||
handle.abort_generate.restype = ctypes.c_bool
|
||||
handle.token_count.restype = token_count_outputs
|
||||
handle.get_pending_output.restype = ctypes.c_char_p
|
||||
handle.load_model_sd.argtypes = [sd_load_model_inputs]
|
||||
handle.load_model_sd.restype = ctypes.c_bool
|
||||
handle.generate_sd.argtypes = [sd_generation_inputs]
|
||||
handle.generate_sd.restype = sd_generation_outputs
|
||||
handle.sd_load_model.argtypes = [sd_load_model_inputs]
|
||||
handle.sd_load_model.restype = ctypes.c_bool
|
||||
handle.sd_generate.argtypes = [sd_generation_inputs]
|
||||
handle.sd_generate.restype = sd_generation_outputs
|
||||
|
||||
def load_model(model_filename):
|
||||
global args
|
||||
|
@ -469,14 +470,29 @@ def generate(prompt, memory="", max_length=32, max_context_length=512, temperatu
|
|||
return outstr
|
||||
|
||||
|
||||
def load_model_sd(model_filename):
|
||||
def sd_load_model(model_filename):
|
||||
global args
|
||||
inputs = sd_load_model_inputs()
|
||||
inputs.debugmode = args.debugmode
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
ret = handle.load_model_sd(inputs)
|
||||
ret = handle.sd_load_model(inputs)
|
||||
return ret
|
||||
|
||||
def sd_generate(prompt, negative_prompt="", cfg_scale=5, sample_steps=20, seed=-1, sample_method="euler a"):
|
||||
global maxctx, args, currentusergenkey, totalgens, pendingabortkey
|
||||
inputs = sd_generation_inputs()
|
||||
inputs.prompt = prompt.encode("UTF-8")
|
||||
inputs.negative_prompt = negative_prompt.encode("UTF-8")
|
||||
inputs.cfg_scale = cfg_scale
|
||||
inputs.sample_steps = sample_steps
|
||||
inputs.seed = seed
|
||||
inputs.sample_method = sample_method.encode("UTF-8")
|
||||
ret = handle.sd_generate(inputs)
|
||||
outstr = ""
|
||||
if ret.status==1:
|
||||
outstr = ret.data.decode("UTF-8","ignore")
|
||||
return outstr
|
||||
|
||||
def utfprint(str):
|
||||
try:
|
||||
print(str)
|
||||
|
@ -2567,7 +2583,7 @@ def main(launch_args,start_server=True):
|
|||
time.sleep(3)
|
||||
sys.exit(2)
|
||||
imgmodel = os.path.abspath(imgmodel)
|
||||
loadok = load_model_sd(imgmodel)
|
||||
loadok = sd_load_model(imgmodel)
|
||||
print("Load Image Model OK: " + str(loadok))
|
||||
if not loadok:
|
||||
exitcounter = 999
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue