Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 17:14:36 +00:00)
logprobs feature completed
This commit is contained in:
parent f7406dfdb1
commit 6a27003a06
3 changed files with 396 additions and 43 deletions
koboldcpp.py (93 changed lines)
@@ -1258,6 +1258,41 @@ def extract_json_from_string(input_string):
         pass
     return []

+def parse_last_logprobs(lastlogprobs):
+    if not lastlogprobs:
+        return None
+    logprobsdict = {}
+    logprobsdict['content'] = []
+    logprobsdict['tokens'] = []
+    logprobsdict['token_logprobs'] = []
+    logprobsdict['top_logprobs'] = []
+    logprobsdict['text_offset'] = []
+    text_offset_counter = 0
+    for i in range(lastlogprobs.count):
+        lp_content_item = {}
+        logprob_item = lastlogprobs.logprob_items[i]
+        toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
+        logprobsdict['tokens'].append(toptoken)
+        lp_content_item['token'] = toptoken
+        logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
+        lp_content_item['logprob'] = logprob_item.selected_logprob
+        lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
+        lp_content_item['top_logprobs'] = []
+        logprobsdict['text_offset'].append(text_offset_counter)
+        text_offset_counter += len(toptoken)
+        tops = {}
+        for j in range(min(logprob_item.option_count,logprobs_max)):
+            tl_item = {}
+            tl_item['logprob'] = logprob_item.logprobs[j]
+            tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
+            tops[tokstr] = logprob_item.logprobs[j]
+            tl_item['token'] = tokstr
+            tl_item['bytes'] = list(tokstr.encode('utf-8'))
+            lp_content_item['top_logprobs'].append(tl_item)
+        logprobsdict['top_logprobs'].append(tops)
+        logprobsdict['content'].append(lp_content_item)
+    return logprobsdict
+
 def transform_genparams(genparams, api_format):
     global chatcompl_adapter
     #api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
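Note: parse_last_logprobs walks a ctypes structure returned by handle.last_logprobs(). The actual struct definitions live elsewhere in koboldcpp.py and are not part of this hunk; a minimal sketch consistent with the field accesses above (count, logprob_items, selected_token, selected_logprob, option_count, tokens, logprobs) might look like the following. The struct names, the C types, and the array bound are assumptions, not the real definitions.

import ctypes

logprobs_max = 5  # assumed cap on alternatives per token; the real value is defined elsewhere

class LogprobItem(ctypes.Structure):  # hypothetical name
    _fields_ = [("option_count", ctypes.c_int),                 # alternatives recorded for this step
                ("selected_token", ctypes.c_char_p),            # the token actually sampled
                ("selected_logprob", ctypes.c_float),           # its log-probability
                ("logprobs", ctypes.c_float * logprobs_max),    # logprobs of the top alternatives
                ("tokens", ctypes.c_char_p * logprobs_max)]     # the top alternative tokens

class LastLogprobs(ctypes.Structure):  # hypothetical name
    _fields_ = [("count", ctypes.c_int),                        # number of generated tokens
                ("logprob_items", ctypes.POINTER(LogprobItem))]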
@@ -1484,36 +1519,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         logprobsdict = None
         if not stream_flag and ("logprobs" in genparams and genparams["logprobs"]):
             lastlogprobs = handle.last_logprobs()
-            logprobsdict = {}
-            logprobsdict['content'] = []
-            logprobsdict['tokens'] = []
-            logprobsdict['token_logprobs'] = []
-            logprobsdict['top_logprobs'] = []
-            logprobsdict['text_offset'] = []
-            text_offset_counter = 0
-            for i in range(lastlogprobs.count):
-                lp_content_item = {}
-                logprob_item = lastlogprobs.logprob_items[i]
-                toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
-                logprobsdict['tokens'].append(toptoken)
-                lp_content_item['token'] = toptoken
-                logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
-                lp_content_item['logprob'] = logprob_item.selected_logprob
-                lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
-                lp_content_item['top_logprobs'] = []
-                logprobsdict['text_offset'].append(text_offset_counter)
-                text_offset_counter += len(toptoken)
-                tops = {}
-                for j in range(min(logprob_item.option_count,logprobs_max)):
-                    tl_item = {}
-                    tl_item['logprob'] = logprob_item.logprobs[j]
-                    tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
-                    tops[tokstr] = logprob_item.logprobs[j]
-                    tl_item['token'] = tokstr
-                    tl_item['bytes'] = list(tokstr.encode('utf-8'))
-                    lp_content_item['top_logprobs'].append(tl_item)
-                logprobsdict['top_logprobs'].append(tops)
-                logprobsdict['content'].append(lp_content_item)
+            logprobsdict = parse_last_logprobs(lastlogprobs)

         # flag instance as non-idle for a while
         washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
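This hunk replaces the duplicated inline dictionary construction with a call to the new parse_last_logprobs helper. The payload it produces follows the OpenAI completions logprobs shape, carrying both the legacy top-level arrays (tokens, token_logprobs, top_logprobs, text_offset) and the newer per-token content list. As an illustration only (the token text and probabilities below are invented), a single-token result would serialize as:

example_logprobsdict = {
    "content": [{
        "token": "Hello",
        "logprob": -0.12,
        "bytes": [72, 101, 108, 108, 111],          # UTF-8 bytes of the token
        "top_logprobs": [
            {"logprob": -0.12, "token": "Hello", "bytes": [72, 101, 108, 108, 111]},
            {"logprob": -2.41, "token": "Hi",    "bytes": [72, 105]},
        ],
    }],
    "tokens": ["Hello"],
    "token_logprobs": [-0.12],
    "top_logprobs": [{"Hello": -0.12, "Hi": -2.41}],  # dict form, keyed by token string
    "text_offset": [0],                               # character offset of each token
}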
@@ -1860,6 +1866,15 @@ Enter Prompt:<br>
             pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
             response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())

+        elif self.path.endswith('/api/extra/last_logprobs'):
+            if not self.secure_endpoint():
+                return
+            logprobsdict = None
+            if requestsinqueue==0 and totalgens>0 and currentusergenkey=="":
+                lastlogprobs = handle.last_logprobs()
+                logprobsdict = parse_last_logprobs(lastlogprobs)
+            response_body = (json.dumps({"logprobs":logprobsdict}).encode())
+
         elif self.path.endswith('/v1/models'):
             response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":int(time.time()),"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
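This handler reads no request body, so it presumably serves the GET form of the new endpoint; it only returns data when at least one generation has completed and the server is idle in single-user mode (requestsinqueue==0, totalgens>0, currentusergenkey==""), otherwise "logprobs" comes back null. A hedged client sketch, assuming a local server on koboldcpp's default port 5001:

import json
import urllib.request

# Fetch logprobs of the most recent completed generation (single-user case).
with urllib.request.urlopen("http://localhost:5001/api/extra/last_logprobs") as r:
    print(json.loads(r.read())["logprobs"])  # None if no generation has finished yet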
@@ -2004,6 +2019,24 @@ Enter Prompt:<br>
             pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
             response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())

+        elif self.path.endswith('/api/extra/last_logprobs'):
+            if not self.secure_endpoint():
+                return
+            logprobsdict = None
+            multiuserkey = ""
+            try:
+                tempbody = json.loads(body)
+                if isinstance(tempbody, dict):
+                    multiuserkey = tempbody.get('genkey', "")
+            except Exception as e:
+                multiuserkey = ""
+
+            if totalgens>0:
+                if (multiuserkey=="" and multiuserkey==currentusergenkey and requestsinqueue==0) or (multiuserkey!="" and multiuserkey==currentusergenkey): #avoid leaking prompts in multiuser
+                    lastlogprobs = handle.last_logprobs()
+                    logprobsdict = parse_last_logprobs(lastlogprobs)
+            response_body = (json.dumps({"logprobs":logprobsdict}).encode())
+
         if response_body is not None:
             self.send_response(response_code)
             self.send_header('content-length', str(len(response_body)))
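The POST form accepts a JSON body carrying a genkey, so in multiuser mode logprobs are only returned to the client whose key matches the last generation, mirroring the genkey check used elsewhere to avoid leaking prompts between users. A hedged client sketch (the genkey value below is invented; port 5001 assumed):

import json
import urllib.request

# Request the last generation's logprobs, proving ownership via genkey.
req = urllib.request.Request(
    "http://localhost:5001/api/extra/last_logprobs",
    data=json.dumps({"genkey": "KCPP_MY_KEY"}).encode(),  # hypothetical key from the earlier generate call
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as r:
    print(json.loads(r.read())["logprobs"])  # None if the key does not match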