added logprobs api and logprobs viewer

This commit is contained in:
Concedo 2024-11-01 00:22:15 +08:00
parent 6731dd64f1
commit aa26a58085
5 changed files with 229 additions and 29 deletions


@@ -23,6 +23,7 @@ tensor_split_max = 16
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
+logprobs_max = 5
 
 # abuse prevention
 stop_token_max = 512
@@ -102,12 +103,15 @@ class token_count_outputs(ctypes.Structure):
                 ("ids", ctypes.POINTER(ctypes.c_int))]
 
+# returns top 5 logprobs per token
+class logprob_item(ctypes.Structure):
+    _fields_ = [("option_count", ctypes.c_int),
+                ("selected_token", ctypes.c_char_p),
+                ("selected_logprob", ctypes.c_float),
+                ("tokens", ctypes.c_char_p * logprobs_max),
+                ("logprobs", ctypes.POINTER(ctypes.c_float))]
 class last_logprobs_outputs(ctypes.Structure):
     _fields_ = [("count", ctypes.c_int),
-                ("selected_token", ctypes.POINTER(ctypes.c_char_p)),
-                ("selected_logprob", ctypes.POINTER(ctypes.c_float)),
-                ("tokens", ctypes.POINTER(5 * ctypes.c_char_p)),
-                ("logprobs", ctypes.POINTER(5 * ctypes.c_float))]
+                ("logprob_items", ctypes.POINTER(logprob_item))]
 
 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
@@ -190,6 +194,8 @@ class generation_inputs(ctypes.Structure):
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("stopreason", ctypes.c_int),
+                ("prompt_tokens", ctypes.c_int),
+                ("completion_tokens", ctypes.c_int),
                 ("text", ctypes.c_char_p)]
 
 class sd_load_model_inputs(ctypes.Structure):
@@ -896,7 +902,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     memory = genparams.get('memory', "")
     images = genparams.get('images', [])
     max_context_length = genparams.get('max_context_length', maxctx)
-    max_length = genparams.get('max_length', 180)
+    max_length = genparams.get('max_length', 200)
     temperature = genparams.get('temperature', 0.7)
     top_k = genparams.get('top_k', 100)
     top_a = genparams.get('top_a', 0.0)
@@ -1078,7 +1084,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     if pendingabortkey!="" and pendingabortkey==genkey:
         print(f"\nDeferred Abort for GenKey: {pendingabortkey}")
         pendingabortkey = ""
-        return {"text":"","status":-1,"stopreason":-1}
+        return {"text":"","status":-1,"stopreason":-1, "prompt_tokens":0, "completion_tokens": 0}
     else:
         ret = handle.generate(inputs)
         outstr = ""
@@ -1089,7 +1095,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
            sindex = outstr.find(trim_str)
            if sindex != -1 and trim_str!="":
                outstr = outstr[:sindex]
-        return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}
+        return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}
 
 def sd_load_model(model_filename,vae_filename,lora_filename):
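Callers of generate() can now read usage straight off the returned dict. A purely illustrative sketch (the literal below mimics the new return shape rather than invoking the real function):

# Mimics what generate() now returns; values are made up.
result = {"text": "Paris.", "status": 1, "stopreason": 1,
          "prompt_tokens": 7, "completion_tokens": 3}

# .get() keeps older call sites working if a code path still omits the new keys.
used = result.get("prompt_tokens", 0) + result.get("completion_tokens", 0)
print(f"{used} tokens used for: {result['text']!r}")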
@@ -1267,13 +1273,14 @@ def transform_genparams(genparams, api_format):
     if api_format==1:
         genparams["prompt"] = genparams.get('text', "")
         genparams["top_k"] = int(genparams.get('top_k', 120))
-        genparams["max_length"] = genparams.get('max', 180)
+        genparams["max_length"] = genparams.get('max', 200)
 
     elif api_format==2:
         pass
 
     elif api_format==3 or api_format==4:
-        genparams["max_length"] = genparams.get('max_tokens', (400 if api_format==4 else 180))
+        default_max_tok = (400 if api_format==4 else 200)
+        genparams["max_length"] = genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok))
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
         genparams["presence_penalty"] = presence_penalty
         # openai allows either a string or a list as a stop sequence
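The nested get() establishes a clear precedence: an explicit max_tokens wins, then max_completion_tokens (the newer OpenAI field name), then the per-format default. A standalone sketch of just that lookup, with the helper name being illustrative:

def resolve_max_length(genparams, api_format):
    # Same precedence as the diff: max_tokens, then max_completion_tokens,
    # then 400 for chat completions (api_format==4) or 200 otherwise.
    default_max_tok = (400 if api_format == 4 else 200)
    return genparams.get('max_tokens',
                         genparams.get('max_completion_tokens', default_max_tok))

print(resolve_max_length({"max_tokens": 64}, 4))               # 64
print(resolve_max_length({"max_completion_tokens": 128}, 4))   # 128
print(resolve_max_length({}, 4))                               # 400
print(resolve_max_length({}, 3))                               # 200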
@@ -1460,7 +1467,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             return generate(genparams=genparams,is_quiet=is_quiet,stream_flag=stream_flag)
 
-        genout = {"text": "", "status": -1, "stopreason": -1}
+        genout = {"text": "", "status": -1, "stopreason": -1, "prompt_tokens":0, "completion_tokens": 0}
         if stream_flag:
             loop = asyncio.get_event_loop()
             executor = ThreadPoolExecutor()
@@ -1469,8 +1476,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             genout = run_blocking()
 
         recvtxt = genout['text']
+        prompttokens = genout['prompt_tokens']
+        comptokens = genout['completion_tokens']
         currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop")
 
+        # grab logprobs if not streaming
+        logprobsdict = None
+        if not stream_flag and ("logprobs" in genparams and genparams["logprobs"]):
+            lastlogprobs = handle.last_logprobs()
+            logprobsdict = {}
+            logprobsdict['content'] = []
+            logprobsdict['tokens'] = []
+            logprobsdict['token_logprobs'] = []
+            logprobsdict['top_logprobs'] = []
+            logprobsdict['text_offset'] = []
+            text_offset_counter = 0
+            for i in range(lastlogprobs.count):
+                lp_content_item = {}
+                logprob_item = lastlogprobs.logprob_items[i]
+                toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
+                logprobsdict['tokens'].append(toptoken)
+                lp_content_item['token'] = toptoken
+                logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
+                lp_content_item['logprob'] = logprob_item.selected_logprob
+                lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
+                lp_content_item['top_logprobs'] = []
+                logprobsdict['text_offset'].append(text_offset_counter)
+                text_offset_counter += len(toptoken)
+                tops = {}
+                for j in range(min(logprob_item.option_count,logprobs_max)):
+                    tl_item = {}
+                    tl_item['logprob'] = logprob_item.logprobs[j]
+                    tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
+                    tops[tokstr] = logprob_item.logprobs[j]
+                    tl_item['token'] = tokstr
+                    tl_item['bytes'] = list(tokstr.encode('utf-8'))
+                    lp_content_item['top_logprobs'].append(tl_item)
+                logprobsdict['top_logprobs'].append(tops)
+                logprobsdict['content'].append(lp_content_item)
+
         # flag instance as non-idle for a while
         washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
         if not washordereq:
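For a single generated token the loop above emits both the legacy completions-style parallel lists (tokens, token_logprobs, top_logprobs, text_offset) and the chat-style content entries, so one dict can serve both API formats. Roughly this shape, with made-up values:

logprobsdict = {
    "content": [{"token": "Hello",
                 "logprob": -0.02,
                 "bytes": [72, 101, 108, 108, 111],
                 "top_logprobs": [{"token": "Hello", "logprob": -0.02, "bytes": [72, 101, 108, 108, 111]},
                                  {"token": "Hi", "logprob": -3.9, "bytes": [72, 105]}]}],
    "tokens": ["Hello"],
    "token_logprobs": [-0.02],
    "top_logprobs": [{"Hello": -0.02, "Hi": -3.9}],
    "text_offset": [0],
}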
@@ -1484,8 +1528,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             res = {"data": {"seqs": [recvtxt]}}
         elif api_format == 3:
             res = {"id": "cmpl-A1", "object": "text_completion", "created": int(time.time()), "model": friendlymodelname,
-            "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-            "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
+            "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+            "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
         elif api_format == 4:
             using_openai_tools = genparams.get('using_openai_tools', False)
             tool_calls = []
@@ -1494,12 +1538,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if tool_calls and len(tool_calls)>0:
                 recvtxt = None
             res = {"id": "chatcmpl-A1", "object": "chat.completion", "created": int(time.time()), "model": friendlymodelname,
-            "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-            "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]}
+            "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+            "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
         elif api_format == 5:
             res = {"caption": end_trim_to_sentence(recvtxt)}
         else:
-            res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]}
+            res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason, "logprobs":logprobsdict, "prompt_tokens": prompttokens, "completion_tokens": comptokens}]}
 
         try:
             return res
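An end-to-end smoke test against a locally running instance, assuming the default port 5001 and the OpenAI-compatible completions endpoint (api_format 3 above). The handler only checks that logprobs is truthy, so a plain true suffices; logprobs are only gathered for non-streaming requests:

import json
import urllib.request

# Assumes a local KoboldCpp server on the default port.
req = urllib.request.Request(
    "http://localhost:5001/v1/completions",
    data=json.dumps({"prompt": "The capital of France is",
                     "max_tokens": 4,
                     "logprobs": True}).encode("utf-8"),
    headers={"Content-Type": "application/json"})

with urllib.request.urlopen(req) as resp:
    choice = json.load(resp)["choices"][0]

print(choice["text"])
print(choice["logprobs"]["top_logprobs"])  # populated because logprobs was requested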