better passthrough handling

This commit is contained in:
Concedo 2025-05-10 19:11:09 +08:00
parent a62e1dfea1
commit 50e1064ffe
2 changed files with 35 additions and 22 deletions

View file

@ -1282,8 +1282,7 @@ def generate(genparams, stream_flag=False):
xtc_probability = tryparsefloat(genparams.get('xtc_probability', 0),0)
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
seed = tryparseint(genparams.get('sampler_seed', -1),-1)
stop_sequence = (genparams.get('stop_sequence', []) if genparams.get('stop_sequence', []) is not None else [])
stop_sequence = stop_sequence[:stop_token_max]
stop_sequence = genparams.get('stop_sequence', [])
ban_eos_token = genparams.get('ban_eos_token', False)
stream_sse = stream_flag
grammar = genparams.get('grammar', '')
@ -1307,25 +1306,6 @@ def generate(genparams, stream_flag=False):
banned_tokens = genparams.get('banned_tokens', banned_strings)
bypass_eos_token = genparams.get('bypass_eos', False)
custom_token_bans = genparams.get('custom_token_bans', '')
replace_instruct_placeholders = genparams.get('replace_instruct_placeholders', False)
if replace_instruct_placeholders:
adapter_obj = {} if chatcompl_adapter is None else chatcompl_adapter
system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
user_message_end = adapter_obj.get("user_end", "")
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
assistant_message_end = adapter_obj.get("assistant_end", "")
prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
prompt = prompt.replace("{{[SYSTEM]}}", system_message_start)
memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
memory = memory.replace("{{[SYSTEM]}}", system_message_start)
for i in range(len(stop_sequence)):
if stop_sequence[i] == "{{[INPUT]}}":
stop_sequence[i] = user_message_start
elif stop_sequence[i] == "{{[OUTPUT]}}":
stop_sequence[i] = assistant_message_start
for tok in custom_token_bans.split(','):
tok = tok.strip() # Remove leading/trailing whitespace
@ -2298,6 +2278,34 @@ ws ::= | " " | "\n" [ \t]{0,20}
genparams["ollamasysprompt"] = ollamasysprompt
genparams["ollamabodyprompt"] = ollamabodyprompt
genparams["prompt"] = ollamasysprompt + ollamabodyprompt
#final transformations (universal template replace)
replace_instruct_placeholders = genparams.get('replace_instruct_placeholders', False)
stop_sequence = (genparams.get('stop_sequence', []) if genparams.get('stop_sequence', []) is not None else [])
stop_sequence = stop_sequence[:stop_token_max]
if replace_instruct_placeholders:
prompt = genparams.get('prompt', "")
memory = genparams.get('memory', "")
adapter_obj = {} if chatcompl_adapter is None else chatcompl_adapter
system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
user_message_end = adapter_obj.get("user_end", "")
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
assistant_message_end = adapter_obj.get("assistant_end", "")
prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
prompt = prompt.replace("{{[SYSTEM]}}", system_message_start)
memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start)
memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start)
memory = memory.replace("{{[SYSTEM]}}", system_message_start)
for i in range(len(stop_sequence)):
if stop_sequence[i] == "{{[INPUT]}}":
stop_sequence[i] = user_message_start
elif stop_sequence[i] == "{{[OUTPUT]}}":
stop_sequence[i] = assistant_message_start
genparams["prompt"] = prompt
genparams["memory"] = memory
genparams["stop_sequence"] = stop_sequence
return genparams
def LaunchWebbrowser(target_url, failedmsg):