added support for OAI chat completions adapter file, added default stop sequences to prevent chat completion leakage

Concedo 2024-04-07 10:35:20 +08:00
parent 0061299cce
commit 6166fdfde4
2 changed files with 28 additions and 3 deletions

@@ -273,7 +273,8 @@ class model_backend(InferenceModel):
             unbantokens=False, bantokens=None, usemirostat=None, forceversion=0, nommap=self.kcpp_nommap,
             usemlock=False, noavx2=self.kcpp_noavx2, debugmode=self.kcpp_debugmode, skiplauncher=True, hordeconfig=None, noblas=self.kcpp_noblas,
             useclblast=self.kcpp_useclblast, usecublas=self.kcpp_usecublas, usevulkan=self.kcpp_usevulkan, gpulayers=self.kcpp_gpulayers, tensor_split=self.kcpp_tensor_split, config=None,
-            onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=None, nocertify=False, sdconfig=None, mmproj=None, password=None)
+            onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=None, nocertify=False, sdconfig=None, mmproj=None,
+            password=None, chatcompletionsadapter=None)
         #koboldcpp.main(kcppargs,False) #initialize library without enabling Lite http server
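Note: this wrapper builds koboldcpp's argument set as a keyword-constructed namespace instead of parsing a real command line, so any new flag added to koboldcpp's parser must be mirrored here with a safe default. A minimal sketch of the pattern (field values are hypothetical, and this tiny namespace stands in for the full argument set above):

```python
# Sketch only: mirror a newly added CLI flag into an embedded-launch Namespace.
from argparse import Namespace

kcppargs = Namespace(
    model_param="model.gguf",     # hypothetical model path
    password=None,
    chatcompletionsadapter=None,  # new flag mirrored with its CLI default
)

# Downstream code can then probe the flag the same way for CLI and embedded launches:
if getattr(kcppargs, "chatcompletionsadapter", None):
    print(f"adapter configured: {kcppargs.chatcompletionsadapter}")
```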

@@ -639,6 +639,7 @@ args = None #global args
 gui_layers_untouched = True
 runmode_untouched = True
 preloaded_story = None
+chatcompl_adapter = None
 sslvalid = False
 nocertify = False
 start_time = time.time()
@@ -664,7 +665,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         pass
     async def generate_text(self, genparams, api_format, stream_flag):
-        global friendlymodelname
+        global friendlymodelname, chatcompl_adapter
         is_quiet = args.quiet
         def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
@@ -697,7 +698,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         if api_format==4:
             # translate openai chat completion messages format into one big string.
             messages_array = genparams.get('messages', [])
-            adapter_obj = genparams.get('adapter', {})
+            default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
+            adapter_obj = genparams.get('adapter', default_adapter)
             messages_string = ""
             system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
             system_message_end = adapter_obj.get("system_end", "")
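The hunk above establishes a three-level fallback: a per-request `adapter` object still wins, the server-wide file-loaded adapter is used when the request omits one, and the hardcoded Alpaca-style tags remain the last resort. A minimal sketch of that resolution, assuming a ChatML-style adapter was loaded (the tag values are hypothetical examples):

```python
# Sketch of the three-level adapter fallback: request > server file > built-in.
chatcompl_adapter = {"system_start": "<|im_start|>system\n"}  # e.g. parsed from the adapter JSON

def resolve_adapter(genparams: dict) -> dict:
    default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
    return genparams.get('adapter', default_adapter)

# A request without an adapter field falls back to the server-wide one:
adapter_obj = resolve_adapter({"messages": []})
print(adapter_obj.get("system_start", "\n### Instruction:\n"))  # ChatML tag

# A request that ships its own adapter overrides the server default:
adapter_obj = resolve_adapter({"adapter": {"system_start": "SYS:"}})
print(adapter_obj.get("system_start", "\n### Instruction:\n"))  # "SYS:"
```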
@@ -738,6 +740,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             genparams["prompt"] = messages_string
             if len(images_added)>0:
                 genparams["images"] = images_added
+            genparams["stop_sequence"] = [user_message_start.strip(),assistant_message_start.strip()]
+            genparams["trim_stop"] = True
         elif api_format==5:
             firstimg = genparams.get('image', "")
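This is the anti-leakage half of the commit: without stop sequences, the model can keep going and generate the next "### Instruction:" turn itself. Registering the stripped turn-start tags as stop strings, with `trim_stop` set so the backend cuts them off the reply, prevents that. A standalone sketch of the effect (the tag values are the assumed built-in defaults, and the trimming loop below only approximates what the backend does):

```python
# Sketch: why stripped turn-start tags stop chat-completion leakage.
user_message_start = "\n### Instruction:\n"    # assumed built-in default tag
assistant_message_start = "\n### Response:\n"  # assumed built-in default tag
stop_sequence = [user_message_start.strip(), assistant_message_start.strip()]

raw_output = "The capital of France is Paris.\n### Instruction:\nNow tell me..."
for stop in stop_sequence:                     # approximates trim_stop behaviour
    cut = raw_output.find(stop)
    if cut != -1:
        raw_output = raw_output[:cut]
print(raw_output)  # the leaked follow-up turn is gone
```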
@@ -1533,6 +1537,7 @@ def show_new_gui():
     customrope_var = ctk.IntVar()
     customrope_scale = ctk.StringVar(value="1.0")
     customrope_base = ctk.StringVar(value="10000")
+    chatcompletionsadapter_var = ctk.StringVar()
     model_var = ctk.StringVar()
     lora_var = ctk.StringVar()
@@ -1984,6 +1989,7 @@ def show_new_gui():
             item.grid_forget()
     makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope,tooltiptxt="Override the default RoPE configuration with custom RoPE scaling.")
     togglerope(1,1,1)
+    makefileentry(tokens_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 30,tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
     # Model Tab
     model_tab = tabcontent["Model Files"]
@@ -2115,6 +2121,8 @@ def show_new_gui():
         if customrope_var.get()==1:
             args.ropeconfig = [float(customrope_scale.get()),float(customrope_base.get())]
+        args.chatcompletionsadapter = None if chatcompletionsadapter_var.get() == "" else chatcompletionsadapter_var.get()
         args.model_param = None if model_var.get() == "" else model_var.get()
         args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
         args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get()
@@ -2249,6 +2257,9 @@ def show_new_gui():
         if "preloadstory" in dict and dict["preloadstory"]:
             preloadstory_var.set(dict["preloadstory"])
+        if "chatcompletionsadapter" in dict and dict["chatcompletionsadapter"]:
+            chatcompletionsadapter_var.set(dict["chatcompletionsadapter"])
         if "port_param" in dict and dict["port_param"]:
             port_var.set(dict["port_param"])
@@ -2739,6 +2750,18 @@ def main(launch_args,start_server=True):
         else:
             print(f"Warning: Saved story file {args.preloadstory} invalid or not found. No story will be preloaded into server.")
+    # try to read chat completions adapter
+    if args.chatcompletionsadapter:
+        if isinstance(args.chatcompletionsadapter, str) and os.path.exists(args.chatcompletionsadapter):
+            print(f"Loading Chat Completions Adapter...")
+            with open(args.chatcompletionsadapter, 'r') as f:
+                global chatcompl_adapter
+                chatcompl_adapter = json.load(f)
+            print(f"Chat Completions Adapter Loaded")
+        else:
+            print(f"Warning: Chat Completions Adapter {args.chatcompletionsadapter} invalid or not found.")
     # sanitize and replace the default vanity name. remember me....
     if args.model_param and args.model_param!="":
         newmdldisplayname = os.path.basename(args.model_param)
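For reference, a minimal adapter file this loader would accept. Only `system_start`/`system_end` are confirmed by this diff; the user/assistant key names below are inferred from the stop-sequence hunk and should be treated as assumptions:

```python
# Sketch: write and load a hypothetical ChatML-style adapter file.
import json

sample_adapter = {
    "system_start": "<|im_start|>system\n",        # key confirmed by this diff
    "system_end": "<|im_end|>\n",                  # key confirmed by this diff
    "user_start": "<|im_start|>user\n",            # assumed key name
    "user_end": "<|im_end|>\n",                    # assumed key name
    "assistant_start": "<|im_start|>assistant\n",  # assumed key name
    "assistant_end": "<|im_end|>\n",               # assumed key name
}

with open("chatml_adapter.json", "w") as f:
    json.dump(sample_adapter, f, indent=2)

with open("chatml_adapter.json") as f:  # mirrors the loader above
    chatcompl_adapter = json.load(f)
print(chatcompl_adapter["system_start"])
```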
@@ -3104,5 +3127,6 @@ if __name__ == '__main__':
     parser.add_argument("--mmproj", help="Select a multimodal projector file for LLaVA.", default="")
     parser.add_argument("--password", help="Enter a password required to use this instance. This key will be required for all text endpoints. Image endpoints are not secured.", default=None)
     parser.add_argument("--ignoremissing", help="Ignores all missing non-essential files, just skipping them instead.", action='store_true')
+    parser.add_argument("--chatcompletionsadapter", help="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.", default="")
     main(parser.parse_args(),start_server=True)
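Putting it together, a hypothetical launch that forces the loaded adapter's instruct tags for all chat completion requests (the model path and adapter filename are placeholders; only the `--chatcompletionsadapter` flag itself is confirmed by this diff):

```
python koboldcpp.py --model model.gguf --chatcompletionsadapter chatml_adapter.json
```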