diff --git a/koboldcpp.py b/koboldcpp.py index 74ad883e1..fa914e732 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -2298,21 +2298,47 @@ ws ::= | " " | "\n" [ \t]{0,20} memory = genparams.get('memory', "") adapter_obj = {} if chatcompl_adapter is None else chatcompl_adapter system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n") + system_message_end = adapter_obj.get("system_end", "") user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n") user_message_end = adapter_obj.get("user_end", "") assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n") assistant_message_end = adapter_obj.get("assistant_end", "") - prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start) - prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) - prompt = prompt.replace("{{[SYSTEM]}}", system_message_start) - memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start) - memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) - memory = memory.replace("{{[SYSTEM]}}", system_message_start) + if "{{[INPUT_END]}}" in prompt or "{{[OUTPUT_END]}}" in prompt: + prompt = prompt.replace("{{[INPUT]}}", user_message_start) + prompt = prompt.replace("{{[OUTPUT]}}", assistant_message_start) + prompt = prompt.replace("{{[SYSTEM]}}", system_message_start) + prompt = prompt.replace("{{[INPUT_END]}}", user_message_end) + prompt = prompt.replace("{{[OUTPUT_END]}}", assistant_message_end) + prompt = prompt.replace("{{[SYSTEM_END]}}", system_message_end) + memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start) + memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) + memory = memory.replace("{{[SYSTEM]}}", system_message_start) + memory = memory.replace("{{[INPUT_END]}}", user_message_end) + memory = memory.replace("{{[OUTPUT_END]}}", assistant_message_end) + memory = memory.replace("{{[SYSTEM_END]}}", system_message_end) + else: + prompt = prompt.replace("{{[INPUT]}}", assistant_message_end + user_message_start) + prompt = prompt.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) + prompt = prompt.replace("{{[SYSTEM]}}", system_message_start) + prompt = prompt.replace("{{[INPUT_END]}}", "") + prompt = prompt.replace("{{[OUTPUT_END]}}", "") + prompt = prompt.replace("{{[SYSTEM_END]}}", "") + memory = memory.replace("{{[INPUT]}}", assistant_message_end + user_message_start) + memory = memory.replace("{{[OUTPUT]}}", user_message_end + assistant_message_start) + memory = memory.replace("{{[SYSTEM]}}", system_message_start) + memory = memory.replace("{{[INPUT_END]}}", "") + memory = memory.replace("{{[OUTPUT_END]}}", "") + memory = memory.replace("{{[SYSTEM_END]}}", "") for i in range(len(stop_sequence)): if stop_sequence[i] == "{{[INPUT]}}": stop_sequence[i] = user_message_start elif stop_sequence[i] == "{{[OUTPUT]}}": stop_sequence[i] = assistant_message_start + elif stop_sequence[i] == "{{[INPUT_END]}}": + stop_sequence[i] = (user_message_end if user_message_end.strip()!="" else "") + elif stop_sequence[i] == "{{[OUTPUT_END]}}": + stop_sequence[i] = (assistant_message_end if assistant_message_end.strip()!="" else "") + stop_sequence = list(filter(None, stop_sequence)) genparams["prompt"] = prompt genparams["memory"] = memory genparams["stop_sequence"] = stop_sequence