Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)
expose stop reason in generation
parent 327682fb97
commit 4ec8a9c57b

4 changed files with 184 additions and 85 deletions
koboldcpp.py: 65 lines changed
@@ -95,6 +95,7 @@ class generation_inputs(ctypes.Structure):

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
+                ("stopreason", ctypes.c_int),
                 ("text", ctypes.c_char_p)]

 class sd_load_model_inputs(ctypes.Structure):
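
The struct above is the low-level surface for the new field. A rough caller-side sketch of how it can be read (the decode call and the meaning of the numeric values are assumptions inferred from the "length"/"stop" mapping that appears later in this commit, not something the struct itself defines):

    # Hypothetical sketch using the updated generation_outputs struct.
    # Assumption: stopreason == 1 means a stop/EOS condition fired; any other
    # value is treated as a length-limited finish, matching the mapping below.
    ret = handle.generate(inputs)               # returns a generation_outputs
    text = ret.text.decode("utf-8", "ignore")   # c_char_p -> str
    finish_reason = "stop" if ret.stopreason == 1 else "length"
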
@@ -493,7 +494,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
     if pendingabortkey!="" and pendingabortkey==genkey:
         print(f"\nDeferred Abort for GenKey: {pendingabortkey}")
         pendingabortkey = ""
-        return ""
+        return {"text":"","status":-1,"stopreason":-1}
     else:
         ret = handle.generate(inputs)
         outstr = ""
@@ -504,7 +505,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
                 sindex = outstr.find(trim_str)
                 if sindex != -1 and trim_str!="":
                     outstr = outstr[:sindex]
-        return outstr
+        return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}


 def sd_load_model(model_filename):
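
Since generate() now returns a dict rather than a bare string, direct callers have to unpack it, which is what the bench and queue hunks at the end of this diff do. A minimal sketch with illustrative placeholder arguments:

    # Hypothetical direct call into the wrapper after this change;
    # prompt and max_length are placeholder values.
    genout = generate("Once upon a time", max_length=64)
    recvtxt = genout["text"]           # generated text (previously the whole return value)
    status = genout["status"]          # raw status code from the backend
    stopreason = genout["stopreason"]  # raw stop reason; 1 is mapped to "stop" downstream
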
@@ -656,6 +657,7 @@ nocertify = False
 start_time = time.time()
 last_req_time = time.time()
 last_non_horde_req_time = time.time()
+currfinishreason = "null"

 def transform_genparams(genparams, api_format):
     #alias all nonstandard alternative names for rep pen.
@@ -765,8 +767,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):

     async def generate_text(self, genparams, api_format, stream_flag):
         from datetime import datetime
-        global friendlymodelname, chatcompl_adapter
+        global friendlymodelname, chatcompl_adapter, currfinishreason
         is_quiet = args.quiet
+        currfinishreason = "null"

         def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
@@ -812,13 +815,16 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 render_special=genparams.get('render_special', False),
                 )

-        recvtxt = ""
+        genout = {"text":"","status":-1,"stopreason":-1}
         if stream_flag:
             loop = asyncio.get_event_loop()
             executor = ThreadPoolExecutor()
-            recvtxt = await loop.run_in_executor(executor, run_blocking)
+            genout = await loop.run_in_executor(executor, run_blocking)
         else:
-            recvtxt = run_blocking()
+            genout = run_blocking()
+
+        recvtxt = genout['text']
+        currfinishreason = ("length" if (genout['stopreason']!=1) else "stop")

         #flag instance as non-idle for a while
         washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
@@ -834,15 +840,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         elif api_format==3:
             res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
                 "usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
-                "choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]}
+                "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
         elif api_format==4:
             res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
                 "usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
-                "choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": "length"}]}
+                "choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": currfinishreason}]}
         elif api_format==5:
             res = {"caption": end_trim_to_sentence(recvtxt)}
         else:
-            res = {"results": [{"text": recvtxt}]}
+            res = {"results": [{"text": recvtxt, "finish_reason":currfinishreason}]}

         try:
             return res
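
With all three response formats carrying currfinishreason, a non-streaming client can tell whether generation stopped naturally or hit its token budget. A hedged client-side sketch; the port and endpoint path are assumptions based on koboldcpp's usual KoboldAI-compatible API and are not part of this diff:

    import requests

    # Hypothetical client call; URL and payload values are illustrative only.
    resp = requests.post("http://localhost:5001/api/v1/generate",
                         json={"prompt": "Hello", "max_length": 32})
    result = resp.json()["results"][0]
    print(result["text"])
    print(result.get("finish_reason"))  # "stop" if a stop condition fired, else "length"
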
@@ -863,7 +869,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             self.wfile.flush()

     async def handle_sse_stream(self, genparams, api_format):
-        global friendlymodelname
+        global friendlymodelname, currfinishreason
         self.send_response(200)
         self.send_header("cache-control", "no-cache")
         self.send_header("connection", "keep-alive")
@@ -877,6 +883,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
         while True:
             streamDone = handle.has_finished() #exit next loop on done
+            if streamDone:
+                sr = handle.get_last_stop_reason()
+                currfinishreason = ("length" if (sr!=1) else "stop")
             tokenStr = ""
             streamcount = handle.get_stream_count()
             while current_token < streamcount:
@@ -893,32 +902,33 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                     incomplete_token_buffer.clear()
                 tokenStr += tokenSeg

-            if tokenStr!="":
+            if tokenStr!="" or streamDone:
                 sseq = genparams.get('stop_sequence', [])
                 trimstop = genparams.get('trim_stop', False)
                 if trimstop and not streamDone and string_contains_sequence_substring(tokenStr,sseq):
                     tokenReserve += tokenStr
                     await asyncio.sleep(async_sleep_short) #if a stop sequence could trigger soon, do not send output
                 else:
                     tokenStr = tokenReserve + tokenStr
                     tokenReserve = ""

                     #apply trimming if needed
                     if trimstop:
                         for trim_str in sseq:
                             sindex = tokenStr.find(trim_str)
                             if sindex != -1 and trim_str!="":
                                 tokenStr = tokenStr[:sindex]

-                    if tokenStr!="":
+                    if tokenStr!="" or streamDone:
                         if api_format == 4: # if oai chat, set format to expected openai streaming response
-                            event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]})
+                            event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":currfinishreason,"delta":{'role':'assistant','content':tokenStr}}]})
                             await self.send_oai_sse_event(event_str)
                         elif api_format == 3: # non chat completions
-                            event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]})
+                            event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":currfinishreason,"text":tokenStr}]})
                             await self.send_oai_sse_event(event_str)
                         else:
-                            event_str = json.dumps({"token": tokenStr})
+                            event_str = json.dumps({"token": tokenStr, "finish_reason":currfinishreason})
                             await self.send_kai_sse_event(event_str)
                         tokenStr = ""
             else:
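
On the streaming side, each SSE event now carries the finish reason too. A rough consumer sketch; the streaming endpoint path is an assumption (koboldcpp's usual SSE endpoint) and only the per-event fields shown in the hunk above are relied on:

    import json
    import requests

    # Hypothetical SSE consumer; it only parses the "data:" lines of each event.
    with requests.post("http://localhost:5001/api/extra/generate/stream",
                       json={"prompt": "Hello", "max_length": 32}, stream=True) as resp:
        for line in resp.iter_lines():
            if line.startswith(b"data: "):
                evt = json.loads(line[len(b"data: "):])
                print(evt.get("token", ""), evt.get("finish_reason"))
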
@@ -3159,7 +3169,8 @@ def main(launch_args,start_server=True):
             benchprompt = "11111111"
             for i in range(0,10): #generate massive prompt
                 benchprompt += benchprompt
-            result = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True)
+            genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True)
+            result = genout['text']
             result = (result[:5] if len(result)>5 else "")
             resultok = (result=="11111")
             t_pp = float(handle.get_last_process_time())*float(benchmaxctx-benchlen)*0.001
@@ -3212,7 +3223,9 @@ def run_in_queue(launch_args, input_queue, output_queue):
                 data = input_queue.get()
                 if data['command'] == 'generate':
                     (args, kwargs) = data['data']
-                    output_queue.put({'command': 'generated text', 'data': generate(*args, **kwargs)})
+                    genout = generate(*args, **kwargs)
+                    result = genout['text']
+                    output_queue.put({'command': 'generated text', 'data': result})
         time.sleep(0.2)

 def start_in_seperate_process(launch_args):