wip cfg scale

This commit is contained in:
Concedo 2025-05-06 23:06:25 +08:00
parent 13cee48740
commit 38a8778f24
3 changed files with 90 additions and 10 deletions

View file

@ -198,6 +198,8 @@ class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
("prompt", ctypes.c_char_p),
("memory", ctypes.c_char_p),
("negative_prompt", ctypes.c_char_p),
("guidance_scale", ctypes.c_float),
("images", ctypes.c_char_p * images_max),
("max_context_length", ctypes.c_int),
("max_length", ctypes.c_int),
@ -1247,6 +1249,8 @@ def generate(genparams, stream_flag=False):
prompt = genparams.get('prompt', "")
memory = genparams.get('memory', "")
negative_prompt = genparams.get('negative_prompt', "")
guidance_scale = tryparsefloat(genparams.get('guidance_scale', 1.0),1.0)
images = genparams.get('images', [])
max_context_length = tryparseint(genparams.get('max_context_length', maxctx),maxctx)
max_length = tryparseint(genparams.get('max_length', args.defaultgenamt),args.defaultgenamt)
@ -1327,6 +1331,8 @@ def generate(genparams, stream_flag=False):
inputs = generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
inputs.memory = memory.encode("UTF-8")
inputs.negative_prompt = negative_prompt.encode("UTF-8")
inputs.guidance_scale = guidance_scale
for n in range(images_max):
if not images or n >= len(images):
inputs.images[n] = "".encode("UTF-8")