From a7b79ed2d74aee4d1806204e5850b32035db9cfe Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Mon, 22 Apr 2024 15:22:44 +0800
Subject: [PATCH] smart buffered stop sequence workaround for SSE streaming
 mode.

---
 koboldcpp.py | 229 +++++++++++++++++++++++++++++----------------------
 1 file changed, 132 insertions(+), 97 deletions(-)

diff --git a/koboldcpp.py b/koboldcpp.py
index 6c632d516..94d401596 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -606,6 +606,15 @@ def bring_terminal_to_foreground():
         ctypes.windll.user32.ShowWindow(ctypes.windll.kernel32.GetConsoleWindow(), 9)
         ctypes.windll.user32.SetForegroundWindow(ctypes.windll.kernel32.GetConsoleWindow())

+def string_contains_sequence_substring(inputstr,sequences):
+    if inputstr.strip()=="":
+        return False
+    for s in sequences:
+        if s.strip()=="":
+            continue
+        if s.strip() in inputstr.strip() or inputstr.strip() in s.strip():
+            return True
+    return False

 #################################################################
 ### A hacky simple HTTP server simulating a kobold api by Concedo
@@ -646,6 +655,93 @@ start_time = time.time()
 last_req_time = time.time()
 last_non_horde_req_time = time.time()

+def transform_genparams(genparams, api_format):
+    #alias all nonstandard alternative names for rep pen.
+    rp1 = genparams.get('repeat_penalty', 1.0)
+    rp2 = genparams.get('repetition_penalty', 1.0)
+    rp3 = genparams.get('rep_pen', 1.0)
+    rp_max = max(rp1,rp2,rp3)
+    genparams["rep_pen"] = rp_max
+
+    if api_format==1:
+        genparams["prompt"] = genparams.get('text', "")
+        genparams["top_k"] = int(genparams.get('top_k', 120))
+        genparams["max_length"] = genparams.get('max', 100)
+
+    elif api_format==3 or api_format==4:
+        genparams["max_length"] = genparams.get('max_tokens', 100)
+        presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
+        genparams["presence_penalty"] = presence_penalty
+        # openai allows either a string or a list as a stop sequence
+        if isinstance(genparams.get('stop',[]), list):
+            genparams["stop_sequence"] = genparams.get('stop', [])
+        else:
+            genparams["stop_sequence"] = [genparams.get('stop')]
+
+        genparams["sampler_seed"] = genparams.get('seed', -1)
+        genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+        genparams["mirostat"] = genparams.get('mirostat_mode', 0)
+
+        if api_format==4:
+            # translate openai chat completion messages format into one big string.
+            messages_array = genparams.get('messages', [])
+            default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
+            adapter_obj = genparams.get('adapter', default_adapter)
+            messages_string = ""
+            system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
+            system_message_end = adapter_obj.get("system_end", "")
+            user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
+            user_message_end = adapter_obj.get("user_end", "")
+            assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
+            assistant_message_end = adapter_obj.get("assistant_end", "")
+            images_added = []
+
+            for message in messages_array:
+                if message['role'] == "system":
+                    messages_string += system_message_start
+                elif message['role'] == "user":
+                    messages_string += user_message_start
+                elif message['role'] == "assistant":
+                    messages_string += assistant_message_start
+
+                # content can be a string or an array of objects
+                curr_content = message['content']
+                if isinstance(curr_content, str):
+                    messages_string += curr_content
+                elif isinstance(curr_content, list): #is an array
+                    for item in curr_content:
+                        if item['type']=="text":
+                            messages_string += item['text']
+                        elif item['type']=="image_url":
+                            if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"):
+                                images_added.append(item['image_url']['url'].split(",", 1)[1])
+
+                if message['role'] == "system":
+                    messages_string += system_message_end
+                elif message['role'] == "user":
+                    messages_string += user_message_end
+                elif message['role'] == "assistant":
+                    messages_string += assistant_message_end
+
+            messages_string += assistant_message_start
+            genparams["prompt"] = messages_string
+            if len(images_added)>0:
+                genparams["images"] = images_added
+            if len(genparams.get('stop_sequence', []))==0: #only set stop seq if it wont overwrite existing
+                genparams["stop_sequence"] = [user_message_start.strip(),assistant_message_start.strip()]
+            else:
+                genparams["stop_sequence"].append(user_message_start.strip())
+                genparams["stop_sequence"].append(assistant_message_start.strip())
+            genparams["trim_stop"] = True
+
+    elif api_format==5:
+        firstimg = genparams.get('image', "")
+        genparams["images"] = [firstimg]
+        genparams["max_length"] = 32
+        genparams["prompt"] = "### Instruction: In one sentence, write a descriptive caption for this image.\n### Response:"
+
+    return genparams
+
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
     server_version = "ConcedoLlamaForKoboldServer"
@@ -669,92 +765,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         from datetime import datetime
         global friendlymodelname, chatcompl_adapter
         is_quiet = args.quiet
+
         def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
-            #alias all nonstandard alternative names for rep pen.
-            rp1 = genparams.get('repeat_penalty', 1.0)
-            rp2 = genparams.get('repetition_penalty', 1.0)
-            rp3 = genparams.get('rep_pen', 1.0)
-            rp_max = max(rp1,rp2,rp3)
-            genparams["rep_pen"] = rp_max
-
-            if api_format==1:
-                genparams["prompt"] = genparams.get('text', "")
-                genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"] = genparams.get('max', 100)
-
-            elif api_format==3 or api_format==4:
-                genparams["max_length"] = genparams.get('max_tokens', 100)
-                presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
-                genparams["presence_penalty"] = presence_penalty
-                # openai allows either a string or a list as a stop sequence
-                if isinstance(genparams.get('stop',[]), list):
-                    genparams["stop_sequence"] = genparams.get('stop', [])
-                else:
-                    genparams["stop_sequence"] = [genparams.get('stop')]
-
-                genparams["sampler_seed"] = genparams.get('seed', -1)
-                genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
-                genparams["mirostat"] = genparams.get('mirostat_mode', 0)
-
-                if api_format==4:
-                    # translate openai chat completion messages format into one big string.
-                    messages_array = genparams.get('messages', [])
-                    default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
-                    adapter_obj = genparams.get('adapter', default_adapter)
-                    messages_string = ""
-                    system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
-                    system_message_end = adapter_obj.get("system_end", "")
-                    user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
-                    user_message_end = adapter_obj.get("user_end", "")
-                    assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
-                    assistant_message_end = adapter_obj.get("assistant_end", "")
-                    images_added = []
-
-                    for message in messages_array:
-                        if message['role'] == "system":
-                            messages_string += system_message_start
-                        elif message['role'] == "user":
-                            messages_string += user_message_start
-                        elif message['role'] == "assistant":
-                            messages_string += assistant_message_start
-
-                        # content can be a string or an array of objects
-                        curr_content = message['content']
-                        if isinstance(curr_content, str):
-                            messages_string += curr_content
-                        elif isinstance(curr_content, list): #is an array
-                            for item in curr_content:
-                                if item['type']=="text":
-                                    messages_string += item['text']
-                                elif item['type']=="image_url":
-                                    if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"):
-                                        images_added.append(item['image_url']['url'].split(",", 1)[1])
-
-                        if message['role'] == "system":
-                            messages_string += system_message_end
-                        elif message['role'] == "user":
-                            messages_string += user_message_end
-                        elif message['role'] == "assistant":
-                            messages_string += assistant_message_end
-
-                    messages_string += assistant_message_start
-                    genparams["prompt"] = messages_string
-                    if len(images_added)>0:
-                        genparams["images"] = images_added
-                    if len(genparams.get('stop_sequence', []))==0: #only set stop seq if it wont overwrite existing
-                        genparams["stop_sequence"] = [user_message_start.strip(),assistant_message_start.strip()]
-                    else:
-                        genparams["stop_sequence"].append(user_message_start.strip())
-                        genparams["stop_sequence"].append(assistant_message_start.strip())
-                    genparams["trim_stop"] = True
-
-            elif api_format==5:
-                firstimg = genparams.get('image', "")
-                genparams["images"] = [firstimg]
-                genparams["max_length"] = 32
-                genparams["prompt"] = "### Instruction: In one sentence, write a descriptive caption for this image.\n### Response:"
-
             #flag instance as non-idle for a while
             washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
             if not washordereq:
@@ -846,7 +859,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.wfile.write(f'data: {data}\n\n'.encode())
             self.wfile.flush()

-    async def handle_sse_stream(self, api_format):
+    async def handle_sse_stream(self, genparams, api_format):
         global friendlymodelname
         self.send_response(200)
         self.send_header("cache-control", "no-cache")
@@ -855,8 +868,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):

         current_token = 0
         incomplete_token_buffer = bytearray()
+        async_sleep_short = 0.02
         await asyncio.sleep(0.25) #anti race condition, prevent check from overtaking generate
         try:
+            tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
             while True:
                 streamDone = handle.has_finished() #exit next loop on done
                 tokenStr = ""
@@ -876,19 +891,37 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                         tokenStr += tokenSeg

                 if tokenStr!="":
-                    if api_format == 4:  # if oai chat, set format to expected openai streaming response
-                        event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
-                        await self.send_oai_sse_event(event_str)
-                    elif api_format == 3:  # non chat completions
-                        event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
-                        await self.send_oai_sse_event(event_str)
+                    sseq = genparams.get('stop_sequence', [])
+                    trimstop = genparams.get('trim_stop', False)
+                    if trimstop and not streamDone and string_contains_sequence_substring(tokenStr,sseq):
+                        tokenReserve += tokenStr
+                        await asyncio.sleep(async_sleep_short) #if a stop sequence could trigger soon, do not send output
                     else:
-                        event_str = json.dumps({"token": tokenStr})
-                        await self.send_kai_sse_event(event_str)
-                    tokenStr = ""
+                        tokenStr = tokenReserve + tokenStr
+                        tokenReserve = ""
+                        #apply trimming if needed
+                        if trimstop:
+                            for trim_str in sseq:
+                                sindex = tokenStr.find(trim_str)
+                                if sindex != -1 and trim_str!="":
+                                    tokenStr = tokenStr[:sindex]
+
+                        if tokenStr!="":
+                            if api_format == 4:  # if oai chat, set format to expected openai streaming response
+                                event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
+                                await self.send_oai_sse_event(event_str)
+                            elif api_format == 3:  # non chat completions
+                                event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
+                                await self.send_oai_sse_event(event_str)
+                            else:
+                                event_str = json.dumps({"token": tokenStr})
+                                await self.send_kai_sse_event(event_str)
+                            tokenStr = ""
+                        else:
+                            await asyncio.sleep(async_sleep_short)
                 else:
-                    await asyncio.sleep(0.02) #this should keep things responsive
+                    await asyncio.sleep(async_sleep_short) #this should keep things responsive

                 if streamDone:
                     if api_format == 4 or api_format == 3:  # if oai chat, send last [DONE] message consistent with openai format
@@ -907,12 +940,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):

         await asyncio.sleep(0.05)

-    async def handle_request(self, genparams, api_format, stream_flag):
+    async def handle_request(self, raw_genparams, api_format, stream_flag):
         tasks = []

+        genparams = transform_genparams(raw_genparams, api_format)
+
         try:
             if stream_flag:
-                tasks.append(self.handle_sse_stream(api_format))
+                tasks.append(self.handle_sse_stream(genparams, api_format))
             generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
             tasks.append(generate_task)
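The core of the workaround is string_contains_sequence_substring() plus the tokenReserve buffer: a chunk is withheld whenever it already contains a stop sequence, or is itself a fragment that some stop sequence contains, and the reserve is flushed (and trimmed) as soon as a later chunk proves the match cannot complete. Below is a minimal standalone re-trace of that logic; the chunk stream and stop sequences are invented for illustration, and the streamDone escape hatch from the real loop is omitted for brevity.

def string_contains_sequence_substring(inputstr, sequences):
    # True if the text already contains a stop sequence, or could be the
    # start of one (i.e. it is a substring of some stop sequence).
    if inputstr.strip() == "":
        return False
    for s in sequences:
        if s.strip() == "":
            continue
        if s.strip() in inputstr.strip() or inputstr.strip() in s.strip():
            return True
    return False

stop_sequences = ["### Instruction:", "### Response:"]
chunks = ["Hello", " world!", "\n###", " Instruction:", "\nthis part must never be sent"]

token_reserve = ""  # fully formed tokens held back until proven safe
sent = ""
for chunk in chunks:
    if string_contains_sequence_substring(chunk, stop_sequences):
        token_reserve += chunk  # suspicious: hold instead of emitting
        continue
    pending = token_reserve + chunk  # safe: flush the reserve plus this chunk
    token_reserve = ""
    for stop in stop_sequences:  # trim at the first completed stop sequence
        idx = pending.find(stop)
        if idx != -1 and stop != "":
            pending = pending[:idx]
    sent += pending

print(repr(sent))  # => 'Hello world!\n', the stop sequence never reaches the client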
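Moving the parameter munging out of run_blocking() into the new top-level transform_genparams() also makes the api_format==4 prompt flattening easy to check in isolation. Here is a condensed re-trace of that branch with the default instruct tags from the patch, handling only string message content; the messages themselves are made up for the example.

adapter = {
    "system_start": "\n### Instruction:\n", "system_end": "",
    "user_start": "\n### Instruction:\n", "user_end": "",
    "assistant_start": "\n### Response:\n", "assistant_end": "",
}
messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Say hi."},
]

prompt = ""
for m in messages:
    prompt += adapter[m["role"] + "_start"] + m["content"] + adapter[m["role"] + "_end"]
prompt += adapter["assistant_start"]  # cue the model to answer as the assistant

stop_sequence = [adapter["user_start"].strip(), adapter["assistant_start"].strip()]
print(repr(prompt))     # '\n### Instruction:\nYou are helpful.\n### Instruction:\nSay hi.\n### Response:\n'
print(stop_sequence)    # ['### Instruction:', '### Response:']

The same branch sets trim_stop=True and seeds stop_sequence with the stripped user/assistant tags, which is exactly the pair the reserve buffer above has to watch for while streaming.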
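End to end, the change is observable from any SSE client. Below is a sketch assuming a local koboldcpp instance on its default port 5001 serving its OpenAI-compatible chat route; adjust the URL and model name to your setup.

import json
import requests

resp = requests.post(
    "http://localhost:5001/v1/chat/completions",
    json={
        "model": "koboldcpp",
        "messages": [{"role": "user", "content": "Write one line of text."}],
        "max_tokens": 64,
        "stop": ["### Instruction:"],  # becomes stop_sequence via transform_genparams
        "stream": True,                # handled by handle_sse_stream()
    },
    stream=True,
)

for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":           # final marker sent for api_format 3/4
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)

With trim_stop in effect, the printed text should end cleanly before any "### Instruction:" tag instead of leaking a partial stop sequence mid-stream.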