fixed noscript image gen

Concedo 2025-07-13 11:37:52 +08:00
parent 6f4f1b7389
commit 0938af7c83

@@ -2531,32 +2531,33 @@ ws ::= | " " | "\n" [ \t]{0,20}
         user_message_end = adapter_obj.get("user_end", "")
         assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
         assistant_message_end = adapter_obj.get("assistant_end", "")
-        if "{{[INPUT_END]}}" in prompt or "{{[OUTPUT_END]}}" in prompt:
-            prompt = prompt.replace("{{[INPUT]}}", user_message_start)
-            prompt = prompt.replace("{{[OUTPUT]}}", assistant_message_start)
-            prompt = prompt.replace("{{[SYSTEM]}}", system_message_start)
-            prompt = prompt.replace("{{[INPUT_END]}}", user_message_end)
-            prompt = prompt.replace("{{[OUTPUT_END]}}", assistant_message_end)
-            prompt = prompt.replace("{{[SYSTEM_END]}}", system_message_end)
-            memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
-            memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
-            memory = memory.replace("{{[SYSTEM]}}", system_message_start)
-            memory = memory.replace("{{[INPUT_END]}}", user_message_end)
-            memory = memory.replace("{{[OUTPUT_END]}}", assistant_message_end)
-            memory = memory.replace("{{[SYSTEM_END]}}", system_message_end)
-        else:
-            prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
-            prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
-            prompt = prompt.replace("{{[SYSTEM]}}", system_message_start)
-            prompt = prompt.replace("{{[INPUT_END]}}", "")
-            prompt = prompt.replace("{{[OUTPUT_END]}}", "")
-            prompt = prompt.replace("{{[SYSTEM_END]}}", "")
-            memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
-            memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
-            memory = memory.replace("{{[SYSTEM]}}", system_message_start)
-            memory = memory.replace("{{[INPUT_END]}}", "")
-            memory = memory.replace("{{[OUTPUT_END]}}", "")
-            memory = memory.replace("{{[SYSTEM_END]}}", "")
+        if isinstance(prompt, str): #needed because comfy SD uses same field name
+            if "{{[INPUT_END]}}" in prompt or "{{[OUTPUT_END]}}" in prompt:
+                prompt = prompt.replace("{{[INPUT]}}", user_message_start)
+                prompt = prompt.replace("{{[OUTPUT]}}", assistant_message_start)
+                prompt = prompt.replace("{{[SYSTEM]}}", system_message_start)
+                prompt = prompt.replace("{{[INPUT_END]}}", user_message_end)
+                prompt = prompt.replace("{{[OUTPUT_END]}}", assistant_message_end)
+                prompt = prompt.replace("{{[SYSTEM_END]}}", system_message_end)
+                memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
+                memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
+                memory = memory.replace("{{[SYSTEM]}}", system_message_start)
+                memory = memory.replace("{{[INPUT_END]}}", user_message_end)
+                memory = memory.replace("{{[OUTPUT_END]}}", assistant_message_end)
+                memory = memory.replace("{{[SYSTEM_END]}}", system_message_end)
+            else:
+                prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
+                prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
+                prompt = prompt.replace("{{[SYSTEM]}}", system_message_start)
+                prompt = prompt.replace("{{[INPUT_END]}}", "")
+                prompt = prompt.replace("{{[OUTPUT_END]}}", "")
+                prompt = prompt.replace("{{[SYSTEM_END]}}", "")
+                memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
+                memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
+                memory = memory.replace("{{[SYSTEM]}}", system_message_start)
+                memory = memory.replace("{{[INPUT_END]}}", "")
+                memory = memory.replace("{{[OUTPUT_END]}}", "")
+                memory = memory.replace("{{[SYSTEM_END]}}", "")
         for i in range(len(stop_sequence)):
             if stop_sequence[i] == "{{[INPUT]}}":
                 stop_sequence[i] = user_message_start
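
The guard added in this hunk exists because the noscript web UI reuses the same "prompt" field for ComfyUI-style image requests, where it holds a dict of workflow nodes rather than a string, so running the instruct-placeholder substitution on it would fail. Below is a minimal sketch of that idea; the helper name and the default adapter strings are illustrative assumptions, not the actual koboldcpp code path.

def apply_instruct_placeholders(prompt, user_start="\n### Instruction:\n",
                                assistant_start="\n### Response:\n"):
    # Illustrative helper only: skip substitution when "prompt" is not a string,
    # e.g. a ComfyUI-style workflow dict submitted through the same field name.
    if not isinstance(prompt, str):
        return prompt
    prompt = prompt.replace("{{[INPUT]}}", user_start)
    prompt = prompt.replace("{{[OUTPUT]}}", assistant_start)
    return prompt

text_req = {"prompt": "{{[INPUT]}}Write a haiku about rivers.{{[OUTPUT]}}"}
comfy_req = {"prompt": {"6": {"class_type": "CLIPTextEncode", "inputs": {"text": "a cat"}}}}

print(apply_instruct_placeholders(text_req["prompt"]))   # placeholders replaced
print(apply_instruct_placeholders(comfy_req["prompt"]))  # dict passes through untouched
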
@@ -2955,7 +2956,7 @@ class KcppServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if args.host!="":
                 epurl = f"{httpsaffix}://{args.host}:{args.port}"
             if imgmode and imgprompt:
-                gen_payload = {"prompt":{"3":{"inputs":{"cfg":cfg,"steps":steps}},"6":{"inputs":{"text":imgprompt}}}}
+                gen_payload = {"prompt":{"3":{"class_type": "KSampler","inputs":{"cfg":cfg,"steps":steps,"latent_image":["5", 0],"positive": ["6", 0]}},"5":{"class_type": "EmptyLatentImage","inputs":{"height":512,"width":512}},"6":{"class_type": "CLIPTextEncode","inputs":{"text":imgprompt}}}}
                 respjson = make_url_request(f'{epurl}/prompt', gen_payload)
             else:
                 gen_payload = {"prompt": prefix+prompt,"max_length": max_length,"temperature": temperature,"top_k": top_k,"top_p": top_p,"rep_pen": rep_pen,"ban_eos_token":ban_eos_token, "stop_sequence":stops}
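
This hunk replaces the bare cfg/steps/text payload with a self-contained ComfyUI-style graph: node "3" (KSampler) now declares its class_type and wires its latent_image input to node "5" (a 512x512 EmptyLatentImage) and its positive conditioning to node "6" (the CLIPTextEncode carrying the user's image prompt). The sketch below shows how such a graph could be posted to the emulated /prompt endpoint; the urllib-based helper, the cfg/steps defaults, and the example URL are assumptions for illustration, since the commit's own code routes through its make_url_request() helper.

import json
import urllib.request

def post_prompt(epurl, imgprompt, cfg=7.0, steps=20):
    # Build the same node graph as the payload in the hunk above.
    gen_payload = {"prompt": {
        "3": {"class_type": "KSampler",
              "inputs": {"cfg": cfg, "steps": steps,
                         "latent_image": ["5", 0],   # output 0 of the EmptyLatentImage node
                         "positive": ["6", 0]}},     # output 0 of the CLIPTextEncode node
        "5": {"class_type": "EmptyLatentImage",
              "inputs": {"height": 512, "width": 512}},
        "6": {"class_type": "CLIPTextEncode",
              "inputs": {"text": imgprompt}},
    }}
    req = urllib.request.Request(f"{epurl}/prompt",
                                 data=json.dumps(gen_payload).encode("utf-8"),
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())

# Example (assumes a local instance with image generation enabled):
# result = post_prompt("http://localhost:5001", "a watercolor fox in a forest")
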