Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-11 01:24:36 +00:00
Commit aa26a58085 (parent 6731dd64f1): added logprobs api and logprobs viewer
5 changed files with 229 additions and 29 deletions
koboldcpp.py (74 changes):
```diff
@@ -23,6 +23,7 @@ tensor_split_max = 16
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
+logprobs_max = 5

 # abuse prevention
 stop_token_max = 512
@@ -102,12 +103,15 @@ class token_count_outputs(ctypes.Structure):
                 ("ids", ctypes.POINTER(ctypes.c_int))]

 # returns top 5 logprobs per token
+class logprob_item(ctypes.Structure):
+    _fields_ = [("option_count", ctypes.c_int),
+                ("selected_token", ctypes.c_char_p),
+                ("selected_logprob", ctypes.c_float),
+                ("tokens", ctypes.c_char_p * logprobs_max),
+                ("logprobs", ctypes.POINTER(ctypes.c_float))]
 class last_logprobs_outputs(ctypes.Structure):
     _fields_ = [("count", ctypes.c_int),
-                ("selected_token", ctypes.POINTER(ctypes.c_char_p)),
-                ("selected_logprob", ctypes.POINTER(ctypes.c_float)),
-                ("tokens", ctypes.POINTER(5 * ctypes.c_char_p)),
-                ("logprobs", ctypes.POINTER(5 * ctypes.c_float))]
+                ("logprob_items", ctypes.POINTER(logprob_item))]

 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
```
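The commit swaps four parallel-array fields for a single pointer to an array of `logprob_item` records. A minimal, self-contained sketch (not part of the commit; the token strings and logprob values are invented) shows how such a record is populated and then read back exactly the way the request handler later does:

```python
import ctypes

logprobs_max = 5  # mirrors the new module-level constant

class logprob_item(ctypes.Structure):
    _fields_ = [("option_count", ctypes.c_int),
                ("selected_token", ctypes.c_char_p),
                ("selected_logprob", ctypes.c_float),
                ("tokens", ctypes.c_char_p * logprobs_max),
                ("logprobs", ctypes.POINTER(ctypes.c_float))]

# Fabricate one item: the sampled token "Hello" plus one alternative "Hi".
probs = (ctypes.c_float * 2)(-0.105, -2.31)
item = logprob_item(option_count=2,
                    selected_token=b"Hello",
                    selected_logprob=-0.105,
                    logprobs=ctypes.cast(probs, ctypes.POINTER(ctypes.c_float)))
item.tokens[0] = b"Hello"
item.tokens[1] = b"Hi"

# Read it back the same way the HTTP handler in this commit does.
print(ctypes.string_at(item.selected_token).decode("UTF-8", "ignore"))
for j in range(min(item.option_count, logprobs_max)):
    tok = ctypes.string_at(item.tokens[j]).decode("UTF-8", "ignore")
    print(tok, item.logprobs[j])
```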
```diff
@@ -190,6 +194,8 @@ class generation_inputs(ctypes.Structure):
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("stopreason", ctypes.c_int),
+                ("prompt_tokens", ctypes.c_int),
+                ("completion_tokens", ctypes.c_int),
                 ("text", ctypes.c_char_p)]

 class sd_load_model_inputs(ctypes.Structure):
```
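These two fields are what later populate the OpenAI-style `usage` object. A hypothetical helper (not in the commit) spells out the mapping, with `SimpleNamespace` standing in for a filled `generation_outputs`:

```python
from types import SimpleNamespace

def usage_from(ret):
    # prompt_tokens / completion_tokens are the new generation_outputs fields
    return {"prompt_tokens": ret.prompt_tokens,
            "completion_tokens": ret.completion_tokens,
            "total_tokens": ret.prompt_tokens + ret.completion_tokens}

print(usage_from(SimpleNamespace(prompt_tokens=12, completion_tokens=34)))
# {'prompt_tokens': 12, 'completion_tokens': 34, 'total_tokens': 46}
```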
```diff
@@ -896,7 +902,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     memory = genparams.get('memory', "")
     images = genparams.get('images', [])
     max_context_length = genparams.get('max_context_length', maxctx)
-    max_length = genparams.get('max_length', 180)
+    max_length = genparams.get('max_length', 200)
     temperature = genparams.get('temperature', 0.7)
     top_k = genparams.get('top_k', 100)
     top_a = genparams.get('top_a', 0.0)
@@ -1078,7 +1084,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         if pendingabortkey!="" and pendingabortkey==genkey:
             print(f"\nDeferred Abort for GenKey: {pendingabortkey}")
             pendingabortkey = ""
-            return {"text":"","status":-1,"stopreason":-1}
+            return {"text":"","status":-1,"stopreason":-1, "prompt_tokens":0, "completion_tokens": 0}
         else:
             ret = handle.generate(inputs)
             outstr = ""
@@ -1089,7 +1095,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
             sindex = outstr.find(trim_str)
             if sindex != -1 and trim_str!="":
                 outstr = outstr[:sindex]
-            return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}
+            return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}


 def sd_load_model(model_filename,vae_filename,lora_filename):
@@ -1267,13 +1273,14 @@ def transform_genparams(genparams, api_format):
     if api_format==1:
         genparams["prompt"] = genparams.get('text', "")
         genparams["top_k"] = int(genparams.get('top_k', 120))
-        genparams["max_length"] = genparams.get('max', 180)
+        genparams["max_length"] = genparams.get('max', 200)

     elif api_format==2:
         pass

     elif api_format==3 or api_format==4:
-        genparams["max_length"] = genparams.get('max_tokens', (400 if api_format==4 else 180))
+        default_max_tok = (400 if api_format==4 else 200)
+        genparams["max_length"] = genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok))
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
         genparams["presence_penalty"] = presence_penalty
         # openai allows either a string or a list as a stop sequence
```
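The chained lookup accepts either OpenAI parameter name for the output budget. A quick sketch of the resolution order, with plain dicts and illustrative values:

```python
# max_tokens wins, then max_completion_tokens, then the per-API default.
api_format = 4
default_max_tok = (400 if api_format == 4 else 200)

genparams = {"max_completion_tokens": 64}
max_length = genparams.get('max_tokens',
                           genparams.get('max_completion_tokens', default_max_tok))
print(max_length)  # 64: no max_tokens present, so max_completion_tokens is used
```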
```diff
@@ -1460,7 +1467,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 return generate(genparams=genparams,is_quiet=is_quiet,stream_flag=stream_flag)

-            genout = {"text": "", "status": -1, "stopreason": -1}
+            genout = {"text": "", "status": -1, "stopreason": -1, "prompt_tokens":0, "completion_tokens": 0}
             if stream_flag:
                 loop = asyncio.get_event_loop()
                 executor = ThreadPoolExecutor()
@@ -1469,8 +1476,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 genout = run_blocking()

             recvtxt = genout['text']
+            prompttokens = genout['prompt_tokens']
+            comptokens = genout['completion_tokens']
             currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop")

+            # grab logprobs if not streaming
+            logprobsdict = None
+            if not stream_flag and ("logprobs" in genparams and genparams["logprobs"]):
+                lastlogprobs = handle.last_logprobs()
+                logprobsdict = {}
+                logprobsdict['content'] = []
+                logprobsdict['tokens'] = []
+                logprobsdict['token_logprobs'] = []
+                logprobsdict['top_logprobs'] = []
+                logprobsdict['text_offset'] = []
+                text_offset_counter = 0
+                for i in range(lastlogprobs.count):
+                    lp_content_item = {}
+                    logprob_item = lastlogprobs.logprob_items[i]
+                    toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
+                    logprobsdict['tokens'].append(toptoken)
+                    lp_content_item['token'] = toptoken
+                    logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
+                    lp_content_item['logprob'] = logprob_item.selected_logprob
+                    lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
+                    lp_content_item['top_logprobs'] = []
+                    logprobsdict['text_offset'].append(text_offset_counter)
+                    text_offset_counter += len(toptoken)
+                    tops = {}
+                    for j in range(min(logprob_item.option_count,logprobs_max)):
+                        tl_item = {}
+                        tl_item['logprob'] = logprob_item.logprobs[j]
+                        tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
+                        tops[tokstr] = logprob_item.logprobs[j]
+                        tl_item['token'] = tokstr
+                        tl_item['bytes'] = list(tokstr.encode('utf-8'))
+                        lp_content_item['top_logprobs'].append(tl_item)
+                    logprobsdict['top_logprobs'].append(tops)
+                    logprobsdict['content'].append(lp_content_item)
+
             # flag instance as non-idle for a while
             washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
             if not washordereq:
```
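For a two-token completion, the dictionary built by this loop ends up shaped roughly as follows (tokens, logprob values, and offsets are invented for illustration). `content` matches the OpenAI chat-completions logprobs layout, while the four flat lists mirror the older text-completions layout:

```python
logprobsdict = {
    "content": [
        {"token": "Hello", "logprob": -0.10,
         "bytes": [72, 101, 108, 108, 111],
         "top_logprobs": [
             {"token": "Hello", "logprob": -0.10, "bytes": [72, 101, 108, 108, 111]},
             {"token": "Hi", "logprob": -2.31, "bytes": [72, 105]}]},
        {"token": " world", "logprob": -0.05,
         "bytes": [32, 119, 111, 114, 108, 100],
         "top_logprobs": [
             {"token": " world", "logprob": -0.05, "bytes": [32, 119, 111, 114, 108, 100]}]},
    ],
    "tokens": ["Hello", " world"],
    "token_logprobs": [-0.10, -0.05],
    "top_logprobs": [{"Hello": -0.10, "Hi": -2.31}, {" world": -0.05}],
    "text_offset": [0, 5],  # cumulative lengths of the decoded tokens
}
```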
```diff
@@ -1484,8 +1528,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 res = {"data": {"seqs": [recvtxt]}}
             elif api_format == 3:
                 res = {"id": "cmpl-A1", "object": "text_completion", "created": int(time.time()), "model": friendlymodelname,
-                       "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-                       "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
+                       "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+                       "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
             elif api_format == 4:
                 using_openai_tools = genparams.get('using_openai_tools', False)
                 tool_calls = []
```
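From the client side, a hedged sketch that exercises the new fields over the OpenAI-compatible completions route; `localhost:5001` is assumed to be the usual koboldcpp default, so adjust for your setup:

```python
import json, urllib.request

payload = {"prompt": "The capital of France is",
           "max_tokens": 8,
           "logprobs": True}  # any truthy value triggers the logprobs path
req = urllib.request.Request("http://localhost:5001/v1/completions",
                             data=json.dumps(payload).encode("utf-8"),
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)
print(body["usage"])                             # real counts, no longer 100/100
print(body["choices"][0]["logprobs"]["tokens"])  # per-token strings
```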
```diff
@@ -1494,12 +1538,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
                 if tool_calls and len(tool_calls)>0:
                     recvtxt = None
                 res = {"id": "chatcmpl-A1", "object": "chat.completion", "created": int(time.time()), "model": friendlymodelname,
-                       "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-                       "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]}
+                       "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+                       "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
             elif api_format == 5:
                 res = {"caption": end_trim_to_sentence(recvtxt)}
             else:
-                res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]}
+                res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason, "logprobs":logprobsdict, "prompt_tokens": prompttokens, "completion_tokens": comptokens}]}

             try:
                 return res
```
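The same against the native KoboldAI-style route, whose `results` entries now carry `logprobs` plus the token counts (again assuming the default host and port):

```python
import json, urllib.request

payload = {"prompt": "Once upon a time", "max_length": 32, "logprobs": True}
req = urllib.request.Request("http://localhost:5001/api/v1/generate",
                             data=json.dumps(payload).encode("utf-8"),
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)["results"][0]
print(result["finish_reason"], result["prompt_tokens"], result["completion_tokens"])
print(result["logprobs"]["token_logprobs"])
```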