mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00

commit a7b79ed2d7 (parent 7f54f5580b)
smart buffered stop sequence workaround for SSE streaming mode.

1 changed file with 132 additions and 97 deletions: koboldcpp.py
@@ -606,6 +606,15 @@ def bring_terminal_to_foreground():
         ctypes.windll.user32.ShowWindow(ctypes.windll.kernel32.GetConsoleWindow(), 9)
         ctypes.windll.user32.SetForegroundWindow(ctypes.windll.kernel32.GetConsoleWindow())
 
+def string_contains_sequence_substring(inputstr,sequences):
+    if inputstr.strip()=="":
+        return False
+    for s in sequences:
+        if s.strip()=="":
+            continue
+        if s.strip() in inputstr.strip() or inputstr.strip() in s.strip():
+            return True
+    return False
 
 #################################################################
 ### A hacky simple HTTP server simulating a kobold api by Concedo
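The new helper checks containment in both directions: a hit means either that a stop sequence already occurs in the streamed text, or that the streamed text could itself be a fragment of a stop sequence that is still arriving. A standalone illustration (not part of the commit; the function body is copied from the hunk above so the snippet runs on its own, and the stop list is made up):

def string_contains_sequence_substring(inputstr,sequences):
    if inputstr.strip()=="":
        return False
    for s in sequences:
        if s.strip()=="":
            continue
        if s.strip() in inputstr.strip() or inputstr.strip() in s.strip():
            return True
    return False

stops = ["### Instruction:"]
print(string_contains_sequence_substring("blah ### Instruction: blah", stops))  # True: stop sequence already present
print(string_contains_sequence_substring("### Instr", stops))                   # True: could be the start of a stop sequence
print(string_contains_sequence_substring("hello world", stops))                 # False: safe to flush immediately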
@@ -646,31 +655,7 @@ start_time = time.time()
 last_req_time = time.time()
 last_non_horde_req_time = time.time()
 
-class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
-    sys_version = ""
-    server_version = "ConcedoLlamaForKoboldServer"
-
-    def __init__(self, addr, port, embedded_kailite, embedded_kcpp_docs):
-        self.addr = addr
-        self.port = port
-        self.embedded_kailite = embedded_kailite
-        self.embedded_kcpp_docs = embedded_kcpp_docs
-
-    def __call__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def log_message(self, format, *args):
-        global showdebug
-        if showdebug:
-            super().log_message(format, *args)
-        pass
-
-    async def generate_text(self, genparams, api_format, stream_flag):
-        from datetime import datetime
-        global friendlymodelname, chatcompl_adapter
-        is_quiet = args.quiet
-        def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
-
+def transform_genparams(genparams, api_format):
     #alias all nonstandard alternative names for rep pen.
     rp1 = genparams.get('repeat_penalty', 1.0)
     rp2 = genparams.get('repetition_penalty', 1.0)
@@ -755,6 +740,34 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         genparams["max_length"] = 32
         genparams["prompt"] = "### Instruction: In one sentence, write a descriptive caption for this image.\n### Response:"
 
+    return genparams
+
+
+class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
+    sys_version = ""
+    server_version = "ConcedoLlamaForKoboldServer"
+
+    def __init__(self, addr, port, embedded_kailite, embedded_kcpp_docs):
+        self.addr = addr
+        self.port = port
+        self.embedded_kailite = embedded_kailite
+        self.embedded_kcpp_docs = embedded_kcpp_docs
+
+    def __call__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def log_message(self, format, *args):
+        global showdebug
+        if showdebug:
+            super().log_message(format, *args)
+        pass
+
+    async def generate_text(self, genparams, api_format, stream_flag):
+        from datetime import datetime
+        global friendlymodelname, chatcompl_adapter
+        is_quiet = args.quiet
+
+        def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
             #flag instance as non-idle for a while
             washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
             if not washordereq:
@@ -846,7 +859,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.wfile.write(f'data: {data}\n\n'.encode())
         self.wfile.flush()
 
-    async def handle_sse_stream(self, api_format):
+    async def handle_sse_stream(self, genparams, api_format):
         global friendlymodelname
         self.send_response(200)
         self.send_header("cache-control", "no-cache")
@@ -855,8 +868,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
         current_token = 0
         incomplete_token_buffer = bytearray()
+        async_sleep_short = 0.02
         await asyncio.sleep(0.25) #anti race condition, prevent check from overtaking generate
         try:
+            tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
             while True:
                 streamDone = handle.has_finished() #exit next loop on done
                 tokenStr = ""
@@ -875,6 +890,23 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                         incomplete_token_buffer.clear()
                         tokenStr += tokenSeg
 
+                if tokenStr!="":
+                    sseq = genparams.get('stop_sequence', [])
+                    trimstop = genparams.get('trim_stop', False)
+                    if trimstop and not streamDone and string_contains_sequence_substring(tokenStr,sseq):
+                        tokenReserve += tokenStr
+                        await asyncio.sleep(async_sleep_short) #if a stop sequence could trigger soon, do not send output
+                    else:
+                        tokenStr = tokenReserve + tokenStr
+                        tokenReserve = ""
+
+                        #apply trimming if needed
+                        if trimstop:
+                            for trim_str in sseq:
+                                sindex = tokenStr.find(trim_str)
+                                if sindex != -1 and trim_str!="":
+                                    tokenStr = tokenStr[:sindex]
+
                         if tokenStr!="":
                             if api_format == 4: # if oai chat, set format to expected openai streaming response
                                 event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
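Taken together, the added branch parks any freshly decoded text that might be the start of a stop sequence in tokenReserve instead of sending it, and only flushes once the text is known to be safe or the stream is finishing; when trim_stop is set, the stop sequence and anything after it are cut off before the flush. A rough standalone sketch of that flow, with a made-up token feed and stop list standing in for the real stream (illustrative only, not koboldcpp code):

def stream_with_buffered_stops(tokens, stop_sequences, trim_stop=True):
    def could_be_stop(text):
        t = text.strip()
        if t == "":
            return False
        return any(s.strip() and (s.strip() in t or t in s.strip()) for s in stop_sequences)

    reserve = ""   # held-back text, mirrors tokenReserve in the diff
    sent = []      # what would have been flushed to the SSE client
    for i, tok in enumerate(tokens):
        done = (i == len(tokens) - 1)              # stands in for handle.has_finished()
        if trim_stop and not done and could_be_stop(tok):
            reserve += tok                         # a stop sequence may be forming, hold the text back
            continue
        out = reserve + tok
        reserve = ""
        if trim_stop:
            for stop in stop_sequences:
                idx = out.find(stop)
                if idx != -1 and stop != "":
                    out = out[:idx]                # drop the stop sequence and everything after it
        if out != "":
            sent.append(out)
    return "".join(sent)

print(stream_with_buffered_stops(["Hello", " wor", "ld.", " ###", " Instruction:", " ignored"],
                                 ["### Instruction:"]))   # prints "Hello world. " with the stop trimmed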
@@ -886,9 +918,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                                 event_str = json.dumps({"token": tokenStr})
                                 await self.send_kai_sse_event(event_str)
                             tokenStr = ""
 
                         else:
-                            await asyncio.sleep(0.02) #this should keep things responsive
+                            await asyncio.sleep(async_sleep_short)
+                else:
+                    await asyncio.sleep(async_sleep_short) #this should keep things responsive
 
                 if streamDone:
                     if api_format == 4 or api_format == 3: # if oai chat, send last [DONE] message consistent with openai format
@@ -907,12 +940,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         await asyncio.sleep(0.05)
 
 
-    async def handle_request(self, genparams, api_format, stream_flag):
+    async def handle_request(self, raw_genparams, api_format, stream_flag):
         tasks = []
 
+        genparams = transform_genparams(raw_genparams, api_format)
+
         try:
             if stream_flag:
-                tasks.append(self.handle_sse_stream(api_format))
+                tasks.append(self.handle_sse_stream(genparams, api_format))
 
             generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag))
             tasks.append(generate_task)
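With transform_genparams factored out, handle_request normalizes the incoming parameters once, and for streaming requests it runs the SSE writer and the text generator concurrently as asyncio tasks. The real stream loop polls the generation handle and sleeps rather than consuming a queue, but the overall task layout resembles this minimal sketch (stand-in coroutines and names, not koboldcpp APIs):

import asyncio

async def fake_generate(queue):
    # stands in for generate_text: produce a few tokens, then signal completion
    for tok in ["Hel", "lo", " world"]:
        await asyncio.sleep(0.01)
        await queue.put(tok)
    await queue.put(None)

async def fake_sse_writer(queue):
    # stands in for handle_sse_stream: emit each token as an SSE data event
    while True:
        tok = await queue.get()
        if tok is None:
            print("data: [DONE]")
            break
        print(f"data: {tok}")

async def handle_request_like():
    queue = asyncio.Queue()
    # same shape as the diff: start both coroutines and wait for them together
    await asyncio.gather(fake_sse_writer(queue), fake_generate(queue))

asyncio.run(handle_request_like())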