diff --git a/class.py b/class.py
index c1cd4a81a..0478bef5b 100644
--- a/class.py
+++ b/class.py
@@ -273,7 +273,8 @@ class model_backend(InferenceModel):
             unbantokens=False, bantokens=None, usemirostat=None, forceversion=0, nommap=self.kcpp_nommap,
             usemlock=False, noavx2=self.kcpp_noavx2, debugmode=self.kcpp_debugmode, skiplauncher=True, hordeconfig=None, noblas=self.kcpp_noblas,
             useclblast=self.kcpp_useclblast, usecublas=self.kcpp_usecublas, usevulkan=self.kcpp_usevulkan, gpulayers=self.kcpp_gpulayers, tensor_split=self.kcpp_tensor_split, config=None,
-            onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=None, nocertify=False, sdconfig=None, mmproj=None, password=None)
+            onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=None, nocertify=False, sdconfig=None, mmproj=None,
+            password=None, chatcompletionsadapter=None)
 
         #koboldcpp.main(kcppargs,False) #initialize library without enabling Lite http server
diff --git a/koboldcpp.py b/koboldcpp.py
index b3907502b..5771abc6e 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -639,6 +639,7 @@ args = None #global args
 gui_layers_untouched = True
 runmode_untouched = True
 preloaded_story = None
+chatcompl_adapter = None
 sslvalid = False
 nocertify = False
 start_time = time.time()
@@ -664,7 +665,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         pass
 
     async def generate_text(self, genparams, api_format, stream_flag):
-        global friendlymodelname
+        global friendlymodelname, chatcompl_adapter
         is_quiet = args.quiet
         def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
@@ -697,7 +698,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if api_format==4:
                 # translate openai chat completion messages format into one big string.
                 messages_array = genparams.get('messages', [])
-                adapter_obj = genparams.get('adapter', {})
+                default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
+                adapter_obj = genparams.get('adapter', default_adapter)
                 messages_string = ""
                 system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
                 system_message_end = adapter_obj.get("system_end", "")
@@ -738,6 +740,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 genparams["prompt"] = messages_string
                 if len(images_added)>0:
                     genparams["images"] = images_added
+                genparams["stop_sequence"] = [user_message_start.strip(),assistant_message_start.strip()]
+                genparams["trim_stop"] = True
 
             elif api_format==5:
                 firstimg = genparams.get('image', "")
@@ -1533,6 +1537,7 @@ def show_new_gui():
     customrope_var = ctk.IntVar()
     customrope_scale = ctk.StringVar(value="1.0")
     customrope_base = ctk.StringVar(value="10000")
+    chatcompletionsadapter_var = ctk.StringVar()
 
     model_var = ctk.StringVar()
     lora_var = ctk.StringVar()
@@ -1984,6 +1989,7 @@ def show_new_gui():
                 item.grid_forget()
     makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
     togglerope(1,1,1)
+    makefileentry(tokens_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 30,tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
 
     # Model Tab
     model_tab = tabcontent["Model Files"]
@@ -2115,6 +2121,8 @@ def show_new_gui():
         if customrope_var.get()==1:
             args.ropeconfig = [float(customrope_scale.get()),float(customrope_base.get())]
 
+        args.chatcompletionsadapter = None if chatcompletionsadapter_var.get() == "" else chatcompletionsadapter_var.get()
+
         args.model_param = None if model_var.get() == "" else model_var.get()
         args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
         args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get()
@@ -2249,6 +2257,9 @@ def show_new_gui():
         if "preloadstory" in dict and dict["preloadstory"]:
             preloadstory_var.set(dict["preloadstory"])
 
+        if "chatcompletionsadapter" in dict and dict["chatcompletionsadapter"]:
+            chatcompletionsadapter_var.set(dict["chatcompletionsadapter"])
+
         if "port_param" in dict and dict["port_param"]:
             port_var.set(dict["port_param"])
 
@@ -2739,6 +2750,18 @@ def main(launch_args,start_server=True):
         else:
             print(f"Warning: Saved story file {args.preloadstory} invalid or not found. No story will be preloaded into server.")
 
+    # try to read chat completions adapter
+    if args.chatcompletionsadapter:
+        if isinstance(args.chatcompletionsadapter, str) and os.path.exists(args.chatcompletionsadapter):
+            print(f"Loading Chat Completions Adapter...")
+            with open(args.chatcompletionsadapter, 'r') as f:
+                global chatcompl_adapter
+                chatcompl_adapter = json.load(f)
+            print(f"Chat Completions Adapter Loaded")
+        else:
+            print(f"Warning: Chat Completions Adapter {args.chatcompletionsadapter} invalid or not found.")
+
+
     # sanitize and replace the default vanity name. remember me....
     if args.model_param and args.model_param!="":
         newmdldisplayname = os.path.basename(args.model_param)
@@ -3104,5 +3127,6 @@ if __name__ == '__main__':
     parser.add_argument("--mmproj", help="Select a multimodal projector file for LLaVA.", default="")
     parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
     parser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
+    parser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
     main(parser.parse_args(),start_server=True)
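
For context, a minimal sketch of what an adapter file for the new --chatcompletionsadapter flag could look like. Only the "system_start" and "system_end" keys are confirmed by the hunk above; the user/assistant key names and the ChatML-style tag values below are illustrative assumptions inferred from the user_message_start/assistant_message_start variables used for the new stop_sequence.

# Sketch of an adapter file for the new --chatcompletionsadapter flag (assumptions noted below).
import json

example_adapter = {
    "system_start": "<|im_start|>system\n",        # key name confirmed by the diff
    "system_end": "<|im_end|>\n",                  # key name confirmed by the diff
    "user_start": "<|im_start|>user\n",            # assumed key name
    "user_end": "<|im_end|>\n",                    # assumed key name
    "assistant_start": "<|im_start|>assistant\n",  # assumed key name
    "assistant_end": "<|im_end|>\n",               # assumed key name
}

# Write the adapter to disk so it can be passed on the command line, e.g.:
#   python koboldcpp.py --model model.gguf --chatcompletionsadapter chatml_adapter.json
with open("chatml_adapter.json", "w") as f:
    json.dump(example_adapter, f, indent=2)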
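
To show the fallback behaviour the patch adds to generate_text(): a chat completions request that omits its own "adapter" object now picks up the server-side chatcompl_adapter loaded from that JSON file. A hypothetical client call, assuming the default port 5001 and the requests package:

# Hypothetical request; no "adapter" field is sent, so the loaded adapter supplies the instruct tags.
import requests

resp = requests.post(
    "http://localhost:5001/v1/chat/completions",
    json={
        "model": "koboldcpp",  # placeholder model name
        "max_tokens": 64,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])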