diff --git a/koboldcpp.py b/koboldcpp.py index a0164a061..f2e6b328f 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -764,6 +764,10 @@ def show_new_gui(): context_var = ctk.IntVar() + customrope_var = ctk.IntVar() + customrope_scale = ctk.StringVar(value="1.0") + customrope_base = ctk.StringVar(value="10000") + model_var = ctk.StringVar() lora_var = ctk.StringVar() lora_base_var = ctk.StringVar() @@ -904,6 +908,19 @@ def show_new_gui(): # context size makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, 4, 20, set=2) + + customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale) + customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base) + def togglerope(a,b,c): + items = [customrope_scale_label, customrope_scale_entry,customrope_base_label, customrope_base_entry] + for idx, item in enumerate(items): + if customrope_var.get() == 1: + item.grid(row=23 + int(idx/2), column=idx%2, padx=8, stick="nw") + else: + item.grid_forget() + makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope) + togglerope(1,1,1) + # Model Tab model_tab = tabcontent["Model"] @@ -996,6 +1013,9 @@ def show_new_gui(): args.mirostat = [int(mirostat_var.get()), float(mirostat_tau.get()), float(mirostat_eta.get())] if usemirostat.get()==1 else None args.contextsize = int(contextsize_text[context_var.get()]) + if customrope_var.get()==1: + args.ropeconfig = [float(customrope_scale.get()),float(customrope_base.get())] + args.model_param = None if model_var.get() == "" else model_var.get() args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()]) @@ -1046,6 +1066,15 @@ def show_new_gui(): if dict["contextsize"]: context_var.set(contextsize_text.index(str(dict["contextsize"]))) + + if dict["ropeconfig"] and len(dict["ropeconfig"])>1: + if dict["ropeconfig"][0]>0: + customrope_var.set(1) + customrope_scale.set(str(dict["ropeconfig"][0])) + customrope_base.set(str(dict["ropeconfig"][1])) + else: + customrope_var.set(0) + if dict["blasbatchsize"]: blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"]))) if dict["forceversion"]: