diff --git a/klite.embd b/klite.embd index 9bd54e664..50fead6b4 100644 --- a/klite.embd +++ b/klite.embd @@ -16398,7 +16398,7 @@ Current version indicated by LITEVER below. gentxt = trim_extra_stop_seqs(gentxt,false); //fix alpaca leakage - if(localsettings.fix_alpaca_leak && (localsettings.opmode == 2 || localsettings.opmode == 3 || localsettings.opmode == 4) && get_instruct_starttag(true).toLowerCase().includes("### instruction")) + if(localsettings.fix_alpaca_leak && (localsettings.opmode == 2 || localsettings.opmode == 3 || localsettings.opmode == 4) && (get_instruct_starttag(true)=="{{[INPUT]}}" || get_instruct_starttag(true).toLowerCase().includes("### instruction"))) { let matches = gentxt.match(/\n### (instruction|response)\n|\n### ([^\s]+?):\n/gi); for(let m in matches) @@ -16540,6 +16540,11 @@ Current version indicated by LITEVER below. { let st = get_instruct_starttag(true); let et = get_instruct_endtag(true); + if(st=="{{[INPUT]}}" || et=="{{[OUTPUT]}}") + { + st = "### Instruction:"; + et = "### Response:"; + } let stet_et = ""; if(localsettings.separate_end_tags && get_instruct_endtag_end(true)) { diff --git a/koboldcpp.py b/koboldcpp.py index d789d5846..adb81db51 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -1282,8 +1282,7 @@ def generate(genparams, stream_flag=False): xtc_probability = tryparsefloat(genparams.get('xtc_probability', 0),0) sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5]) seed = tryparseint(genparams.get('sampler_seed', -1),-1) - stop_sequence = (genparams.get('stop_sequence', []) if genparams.get('stop_sequence', []) is not None else []) - stop_sequence = stop_sequence[:stop_token_max] + stop_sequence = genparams.get('stop_sequence', []) ban_eos_token = genparams.get('ban_eos_token', False) stream_sse = stream_flag grammar = genparams.get('grammar', '') @@ -1307,25 +1306,6 @@ def generate(genparams, stream_flag=False): banned_tokens = genparams.get('banned_tokens', banned_strings) bypass_eos_token = genparams.get('bypass_eos', False) custom_token_bans = genparams.get('custom_token_bans', '') - replace_instruct_placeholders = genparams.get('replace_instruct_placeholders', False) - if replace_instruct_placeholders: - adapter_obj = {} if chatcompl_adapter is None else chatcompl_adapter - system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n") - user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n") - user_message_end = adapter_obj.get("user_end", "") - assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n") - assistant_message_end = adapter_obj.get("assistant_end", "") - prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start) - prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) - prompt = prompt.replace("{{[SYSTEM]}}", system_message_start) - memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start) - memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) - memory = memory.replace("{{[SYSTEM]}}", system_message_start) - for i in range(len(stop_sequence)): - if stop_sequence[i] == "{{[INPUT]}}": - stop_sequence[i] = user_message_start - elif stop_sequence[i] == "{{[OUTPUT]}}": - stop_sequence[i] = assistant_message_start for tok in custom_token_bans.split(','): tok = tok.strip() # Remove leading/trailing whitespace @@ -2298,6 +2278,34 @@ ws ::= | " " | "\n" [ \t]{0,20} genparams["ollamasysprompt"] = ollamasysprompt genparams["ollamabodyprompt"] = ollamabodyprompt genparams["prompt"] = ollamasysprompt + ollamabodyprompt + + #final transformations (universal template replace) + replace_instruct_placeholders = genparams.get('replace_instruct_placeholders', False) + stop_sequence = (genparams.get('stop_sequence', []) if genparams.get('stop_sequence', []) is not None else []) + stop_sequence = stop_sequence[:stop_token_max] + if replace_instruct_placeholders: + prompt = genparams.get('prompt', "") + memory = genparams.get('memory', "") + adapter_obj = {} if chatcompl_adapter is None else chatcompl_adapter + system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n") + user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n") + user_message_end = adapter_obj.get("user_end", "") + assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n") + assistant_message_end = adapter_obj.get("assistant_end", "") + prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start) + prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) + prompt = prompt.replace("{{[SYSTEM]}}", system_message_start) + memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start) + memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) + memory = memory.replace("{{[SYSTEM]}}", system_message_start) + for i in range(len(stop_sequence)): + if stop_sequence[i] == "{{[INPUT]}}": + stop_sequence[i] = user_message_start + elif stop_sequence[i] == "{{[OUTPUT]}}": + stop_sequence[i] = assistant_message_start + genparams["prompt"] = prompt + genparams["memory"] = memory + genparams["stop_sequence"] = stop_sequence return genparams def LaunchWebbrowser(target_url, failedmsg):