Mirror of https://github.com/LostRuins/koboldcpp.git
Synced 2025-09-10 17:14:36 +00:00

smart buffered stop sequence workaround for SSE streaming mode.

This commit is contained in:
parent 7f54f5580b
commit a7b79ed2d7

1 changed file with 132 additions and 97 deletions
koboldcpp.py: 229 changed lines (+132, -97)
@@ -606,6 +606,15 @@ def bring_terminal_to_foreground():
     ctypes.windll.user32.ShowWindow(ctypes.windll.kernel32.GetConsoleWindow(), 9)
     ctypes.windll.user32.SetForegroundWindow(ctypes.windll.kernel32.GetConsoleWindow())
 
+def string_contains_sequence_substring(inputstr,sequences):
+    if inputstr.strip()=="":
+        return False
+    for s in sequences:
+        if s.strip()=="":
+            continue
+        if s.strip() in inputstr.strip() or inputstr.strip() in s.strip():
+            return True
+    return False
 
 #################################################################
 ### A hacky simple HTTP server simulating a kobold api by Concedo
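The matching rule here is deliberately loose: after stripping whitespace, either string containing the other counts as a hit, so a partially streamed stop sequence is flagged just like a completed one. A quick standalone check (the function body is copied verbatim from the hunk above):

def string_contains_sequence_substring(inputstr,sequences):
    if inputstr.strip()=="":
        return False
    for s in sequences:
        if s.strip()=="":
            continue
        if s.strip() in inputstr.strip() or inputstr.strip() in s.strip():
            return True
    return False

stops = ["### Instruction:"]
print(string_contains_sequence_substring("### Inst", stops))               # True: buffer is a prefix (substring) of a stop sequence
print(string_contains_sequence_substring("text ### Instruction:", stops))  # True: a stop sequence appears inside the buffer
print(string_contains_sequence_substring("hello", stops))                  # False: no containment either way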
@@ -646,6 +655,93 @@ start_time = time.time()
 last_req_time = time.time()
 last_non_horde_req_time = time.time()
 
+def transform_genparams(genparams, api_format):
+    #alias all nonstandard alternative names for rep pen.
+    rp1 = genparams.get('repeat_penalty', 1.0)
+    rp2 = genparams.get('repetition_penalty', 1.0)
+    rp3 = genparams.get('rep_pen', 1.0)
+    rp_max = max(rp1,rp2,rp3)
+    genparams["rep_pen"] = rp_max
+
+    if api_format==1:
+        genparams["prompt"] = genparams.get('text', "")
+        genparams["top_k"] = int(genparams.get('top_k', 120))
+        genparams["max_length"] = genparams.get('max', 100)
+
+    elif api_format==3 or api_format==4:
+        genparams["max_length"] = genparams.get('max_tokens', 100)
+        presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
+        genparams["presence_penalty"] = presence_penalty
+        # openai allows either a string or a list as a stop sequence
+        if isinstance(genparams.get('stop',[]), list):
+            genparams["stop_sequence"] = genparams.get('stop', [])
+        else:
+            genparams["stop_sequence"] = [genparams.get('stop')]
+
+        genparams["sampler_seed"] = genparams.get('seed', -1)
+        genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
+        genparams["mirostat"] = genparams.get('mirostat_mode', 0)
+
+        if api_format==4:
+            # translate openai chat completion messages format into one big string.
+            messages_array = genparams.get('messages', [])
+            default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
+            adapter_obj = genparams.get('adapter', default_adapter)
+            messages_string = ""
+            system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
+            system_message_end = adapter_obj.get("system_end", "")
+            user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
+            user_message_end = adapter_obj.get("user_end", "")
+            assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
+            assistant_message_end = adapter_obj.get("assistant_end", "")
+            images_added = []
+
+            for message in messages_array:
+                if message['role'] == "system":
+                    messages_string += system_message_start
+                elif message['role'] == "user":
+                    messages_string += user_message_start
+                elif message['role'] == "assistant":
+                    messages_string += assistant_message_start
+
+                # content can be a string or an array of objects
+                curr_content = message['content']
+                if isinstance(curr_content, str):
+                    messages_string += curr_content
+                elif isinstance(curr_content, list): #is an array
+                    for item in curr_content:
+                        if item['type']=="text":
+                            messages_string += item['text']
+                        elif item['type']=="image_url":
+                            if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"):
+                                images_added.append(item['image_url']['url'].split(",", 1)[1])
+
+                if message['role'] == "system":
+                    messages_string += system_message_end
+                elif message['role'] == "user":
+                    messages_string += user_message_end
+                elif message['role'] == "assistant":
+                    messages_string += assistant_message_end
+
+            messages_string += assistant_message_start
+            genparams["prompt"] = messages_string
+            if len(images_added)>0:
+                genparams["images"] = images_added
+            if len(genparams.get('stop_sequence', []))==0: #only set stop seq if it wont overwrite existing
+                genparams["stop_sequence"] = [user_message_start.strip(),assistant_message_start.strip()]
+            else:
+                genparams["stop_sequence"].append(user_message_start.strip())
+                genparams["stop_sequence"].append(assistant_message_start.strip())
+            genparams["trim_stop"] = True
+
+    elif api_format==5:
+        firstimg = genparams.get('image', "")
+        genparams["images"] = [firstimg]
+        genparams["max_length"] = 32
+        genparams["prompt"] = "### Instruction: In one sentence, write a descriptive caption for this image.\n### Response:"
+
+    return genparams
+
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
     server_version = "ConcedoLlamaForKoboldServer"
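To make the api_format==4 mapping concrete, here is a small hand-worked trace. It is illustrative only: the payload below is invented, and it assumes no chat-completions adapter is loaded, so the default "### Instruction:" / "### Response:" strings from the hunk above apply.

# Hypothetical OpenAI-style chat payload (api_format == 4):
request = {
    "max_tokens": 80,
    "stop": "###",  # a bare string gets wrapped into a one-element list
    "messages": [
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Name one prime number."},
    ],
}
# Expected results after transform_genparams(request, 4), derived by hand
# from the default adapter strings:
#   genparams["max_length"]    == 80
#   genparams["stop_sequence"] == ["###", "### Instruction:", "### Response:"]
#   genparams["trim_stop"]     == True
#   genparams["prompt"]        ==
#       "\n### Instruction:\nYou are terse."
#       "\n### Instruction:\nName one prime number."
#       "\n### Response:\n"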
@@ -669,92 +765,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         from datetime import datetime
         global friendlymodelname, chatcompl_adapter
         is_quiet = args.quiet
 
         def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
-
-            #alias all nonstandard alternative names for rep pen.
-            rp1 = genparams.get('repeat_penalty', 1.0)
-            rp2 = genparams.get('repetition_penalty', 1.0)
-            rp3 = genparams.get('rep_pen', 1.0)
-            rp_max = max(rp1,rp2,rp3)
-            genparams["rep_pen"] = rp_max
-
-            if api_format==1:
-                genparams["prompt"] = genparams.get('text', "")
-                genparams["top_k"] = int(genparams.get('top_k', 120))
-                genparams["max_length"] = genparams.get('max', 100)
-
-            elif api_format==3 or api_format==4:
-                genparams["max_length"] = genparams.get('max_tokens', 100)
-                presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
-                genparams["presence_penalty"] = presence_penalty
-                # openai allows either a string or a list as a stop sequence
-                if isinstance(genparams.get('stop',[]), list):
-                    genparams["stop_sequence"] = genparams.get('stop', [])
-                else:
-                    genparams["stop_sequence"] = [genparams.get('stop')]
-
-                genparams["sampler_seed"] = genparams.get('seed', -1)
-                genparams["use_default_badwordsids"] = genparams.get('ignore_eos', False)
-                genparams["mirostat"] = genparams.get('mirostat_mode', 0)
-
-                if api_format==4:
-                    # translate openai chat completion messages format into one big string.
-                    messages_array = genparams.get('messages', [])
-                    default_adapter = {} if chatcompl_adapter is None else chatcompl_adapter
-                    adapter_obj = genparams.get('adapter', default_adapter)
-                    messages_string = ""
-                    system_message_start = adapter_obj.get("system_start", "\n### Instruction:\n")
-                    system_message_end = adapter_obj.get("system_end", "")
-                    user_message_start = adapter_obj.get("user_start", "\n### Instruction:\n")
-                    user_message_end = adapter_obj.get("user_end", "")
-                    assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
-                    assistant_message_end = adapter_obj.get("assistant_end", "")
-                    images_added = []
-
-                    for message in messages_array:
-                        if message['role'] == "system":
-                            messages_string += system_message_start
-                        elif message['role'] == "user":
-                            messages_string += user_message_start
-                        elif message['role'] == "assistant":
-                            messages_string += assistant_message_start
-
-                        # content can be a string or an array of objects
-                        curr_content = message['content']
-                        if isinstance(curr_content, str):
-                            messages_string += curr_content
-                        elif isinstance(curr_content, list): #is an array
-                            for item in curr_content:
-                                if item['type']=="text":
-                                    messages_string += item['text']
-                                elif item['type']=="image_url":
-                                    if item['image_url'] and item['image_url']['url'] and item['image_url']['url'].startswith("data:image"):
-                                        images_added.append(item['image_url']['url'].split(",", 1)[1])
-
-                        if message['role'] == "system":
-                            messages_string += system_message_end
-                        elif message['role'] == "user":
-                            messages_string += user_message_end
-                        elif message['role'] == "assistant":
-                            messages_string += assistant_message_end
-
-                    messages_string += assistant_message_start
-                    genparams["prompt"] = messages_string
-                    if len(images_added)>0:
-                        genparams["images"] = images_added
-                    if len(genparams.get('stop_sequence', []))==0: #only set stop seq if it wont overwrite existing
-                        genparams["stop_sequence"] = [user_message_start.strip(),assistant_message_start.strip()]
-                    else:
-                        genparams["stop_sequence"].append(user_message_start.strip())
-                        genparams["stop_sequence"].append(assistant_message_start.strip())
-                    genparams["trim_stop"] = True
-
-            elif api_format==5:
-                firstimg = genparams.get('image', "")
-                genparams["images"] = [firstimg]
-                genparams["max_length"] = 32
-                genparams["prompt"] = "### Instruction: In one sentence, write a descriptive caption for this image.\n### Response:"
-
             #flag instance as non-idle for a while
             washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
             if not washordereq:
@@ -846,7 +859,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.wfile.write(f'data: {data}\n\n'.encode())
             self.wfile.flush()
 
-    async def handle_sse_stream(self, api_format):
+    async def handle_sse_stream(self, genparams, api_format):
         global friendlymodelname
         self.send_response(200)
         self.send_header("cache-control", "no-cache")
@@ -855,8 +868,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
         current_token = 0
         incomplete_token_buffer = bytearray()
+        async_sleep_short = 0.02
         await asyncio.sleep(0.25) #anti race condition, prevent check from overtaking generate
         try:
+            tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
             while True:
                 streamDone = handle.has_finished() #exit next loop on done
                 tokenStr = ""
@@ -876,19 +891,37 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                         tokenStr += tokenSeg
 
                 if tokenStr!="":
-                    if api_format == 4: # if oai chat, set format to expected openai streaming response
-                        event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
-                        await self.send_oai_sse_event(event_str)
-                    elif api_format == 3: # non chat completions
-                        event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
-                        await self.send_oai_sse_event(event_str)
+                    sseq = genparams.get('stop_sequence', [])
+                    trimstop = genparams.get('trim_stop', False)
+                    if trimstop and not streamDone and string_contains_sequence_substring(tokenStr,sseq):
+                        tokenReserve += tokenStr
+                        await asyncio.sleep(async_sleep_short) #if a stop sequence could trigger soon, do not send output
                     else:
-                        event_str = json.dumps({"token": tokenStr})
-                        await self.send_kai_sse_event(event_str)
-                    tokenStr = ""
+                        tokenStr = tokenReserve + tokenStr
+                        tokenReserve = ""
+
+                        #apply trimming if needed
+                        if trimstop:
+                            for trim_str in sseq:
+                                sindex = tokenStr.find(trim_str)
+                                if sindex != -1 and trim_str!="":
+                                    tokenStr = tokenStr[:sindex]
+
+                        if tokenStr!="":
+                            if api_format == 4: # if oai chat, set format to expected openai streaming response
+                                event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
+                                await self.send_oai_sse_event(event_str)
+                            elif api_format == 3: # non chat completions
+                                event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
+                                await self.send_oai_sse_event(event_str)
+                            else:
+                                event_str = json.dumps({"token": tokenStr})
+                                await self.send_kai_sse_event(event_str)
+                            tokenStr = ""
+                        else:
+                            await asyncio.sleep(async_sleep_short)
                 else:
-                    await asyncio.sleep(0.02) #this should keep things responsive
+                    await asyncio.sleep(async_sleep_short) #this should keep things responsive
 
                 if streamDone:
                     if api_format == 4 or api_format == 3: # if oai chat, send last [DONE] message consistent with openai format
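The buffering logic above boils down to three moves: hold back any chunk that might be part of a stop sequence, flush the reserve together with the next chunk once the suspicion clears (or the stream ends), and trim at the first stop sequence before sending. A minimal synchronous sketch of that idea follows; all names here (buffered_stream, might_be_stop) are illustrative, not koboldcpp's API, and the matching rule mirrors string_contains_sequence_substring from this commit.

def buffered_stream(chunks, stop_sequences):
    def might_be_stop(text):
        t = text.strip()
        if not t:
            return False
        return any(s.strip() and (s.strip() in t or t in s.strip())
                   for s in stop_sequences)

    reserve = ""  # fully formed chunks held back, like tokenReserve above
    for i, chunk in enumerate(chunks):
        done = (i == len(chunks) - 1)  # stands in for handle.has_finished()
        if not done and might_be_stop(chunk):
            reserve += chunk  # a stop sequence may still be forming: hold it
            continue
        out = reserve + chunk  # flush the reserve together with this chunk
        reserve = ""
        for s in stop_sequences:  # trim at the first stop sequence found
            idx = out.find(s)
            if idx != -1 and s != "":
                out = out[:idx]
        if out:
            yield out

# "### Instruction:" arrives split across two chunks; both are withheld and
# then trimmed away, so the client never sees any part of the stop sequence.
print(list(buffered_stream(["Hel", "lo ", "###", " Instruction:", "!"],
                           ["### Instruction:"])))  # -> ['Hel', 'lo ']

The trade-off, as in the real handler, is a short delay (async_sleep_short) whenever a chunk merely resembles a stop sequence, in exchange for never leaking a stop sequence into streamed output.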
@@ -907,12 +940,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         await asyncio.sleep(0.05)
 
 
-    async def handle_request(self, genparams, api_format, stream_flag):
+    async def handle_request(self, raw_genparams, api_format, stream_flag):
         tasks = []
 
+        genparams = transform_genparams(raw_genparams, api_format)
+
         try:
             if stream_flag:
-                tasks.append(self.handle_sse_stream(api_format))
+                tasks.append(self.handle_sse_stream(genparams, api_format))
 
             generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
             tasks.append(generate_task)
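Taken together, the refactor means the parameter normalization that previously ran inside run_blocking now happens exactly once, in handle_request, before any task is spawned. Both the SSE streamer and the generator then see the same transformed genparams, which is what lets handle_sse_stream read stop_sequence and trim_stop for the buffering workaround above.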