logprobs feature completed

Concedo 2024-11-01 15:24:07 +08:00
parent f7406dfdb1
commit 6a27003a06
3 changed files with 396 additions and 43 deletions

View file

@@ -257,6 +257,11 @@
"minimum": 0,
"type": "number"
},
"logprobs": {
"default": false,
"description": "If true, return up to 5 top logprobs for generated tokens. Incurs performance overhead.",
"type": "boolean"
}
},
"required": [
"prompt"
@@ -808,6 +813,215 @@
]
}
},
"/api/extra/last_logprobs": {
"post": {
"description": "Obtains the token logprobs of the most recent request. A unique genkey previously submitted is required in multiuser mode.",
"requestBody": {
"content": {
"application/json": {
"example": {
"genkey": "KCPP2342"
},
"schema": {
"properties": {
"genkey": {
"type": "string",
"description": "A unique key used to identify the previous generation."
}
},
"type": "object"
}
}
},
"required": false
},
"responses": {
"200": {
"content": {
"application/json": {
"example": {
"logprobs": {
"content": [
{
"token": "Hello",
"logprob": -0.31725305,
"bytes": [72, 101, 108, 108, 111],
"top_logprobs": [
{
"token": "Hello",
"logprob": -0.31725305,
"bytes": [72, 101, 108, 108, 111]
},
{
"token": "Hi",
"logprob": -1.3190403,
"bytes": [72, 105]
}
]
},
{
"token": "!",
"logprob": -0.02380986,
"bytes": [
33
],
"top_logprobs": [
{
"token": "!",
"logprob": -0.02380986,
"bytes": [33]
},
{
"token": " there",
"logprob": -3.787621,
"bytes": [32, 116, 104, 101, 114, 101]
}
]
},
{
"token": " How",
"logprob": -0.000054669687,
"bytes": [32, 72, 111, 119],
"top_logprobs": [
{
"token": " How",
"logprob": -0.000054669687,
"bytes": [32, 72, 111, 119]
},
{
"token": "<|end|>",
"logprob": -10.953937,
"bytes": null
}
]
},
{
"token": " can",
"logprob": -0.015801601,
"bytes": [32, 99, 97, 110],
"top_logprobs": [
{
"token": " can",
"logprob": -0.015801601,
"bytes": [32, 99, 97, 110]
},
{
"token": " may",
"logprob": -4.161023,
"bytes": [32, 109, 97, 121]
}
]
},
{
"token": " I",
"logprob": -3.7697225e-6,
"bytes": [
32,
73
],
"top_logprobs": [
{
"token": " I",
"logprob": -3.7697225e-6,
"bytes": [32, 73]
},
{
"token": " assist",
"logprob": -13.596657,
"bytes": [32, 97, 115, 115, 105, 115, 116]
}
]
},
{
"token": " assist",
"logprob": -0.04571125,
"bytes": [32, 97, 115, 115, 105, 115, 116],
"top_logprobs": [
{
"token": " assist",
"logprob": -0.04571125,
"bytes": [32, 97, 115, 115, 105, 115, 116]
},
{
"token": " help",
"logprob": -3.1089056,
"bytes": [32, 104, 101, 108, 112]
}
]
},
{
"token": " you",
"logprob": -5.4385737e-6,
"bytes": [32, 121, 111, 117],
"top_logprobs": [
{
"token": " you",
"logprob": -5.4385737e-6,
"bytes": [32, 121, 111, 117]
},
{
"token": " today",
"logprob": -12.807695,
"bytes": [32, 116, 111, 100, 97, 121]
}
]
},
{
"token": " today",
"logprob": -0.0040071653,
"bytes": [32, 116, 111, 100, 97, 121],
"top_logprobs": [
{
"token": " today",
"logprob": -0.0040071653,
"bytes": [32, 116, 111, 100, 97, 121]
},
{
"token": "?",
"logprob": -5.5247097,
"bytes": [63]
}
]
},
{
"token": "?",
"logprob": -0.0008108172,
"bytes": [63],
"top_logprobs": [
{
"token": "?",
"logprob": -0.0008108172,
"bytes": [63]
},
{
"token": "?\n",
"logprob": -7.184561,
"bytes": [63, 10]
}
]
}
]
}
},
"schema": {
"properties": {
"logprobs": {
"type": "object",
"description": "A logprobs object in the same format as OpenAI API."
}
}
}
}
},
"description": "Successful request"
}
},
"summary": "Obtains the token logprobs of the most recent request.",
"tags": [
"api/extra"
]
}
},
"/api/extra/tokencount": {
"post": {
"description": "Counts the number of tokens in a string.",

View file

@@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
-->
<script>
const LITEVER = 183;
const LITEVER = 184;
const urlParams = new URLSearchParams(window.location.search);
var localflag = true;
const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";
@@ -3105,6 +3105,15 @@ Current version indicated by LITEVER below.
.replace(/&#039;/g, "\'");
}
function unescapeRegexNewlines(input)
{
return input.replace(/\\\\/g, "[temp_rr_seq]")
.replace(/\\n/g, "\n")
.replace(/\\t/g, "\t")
.replace(/\\r/g, "\r")
.replace(/\[temp_rr_seq\]/g, "\\\\");
}
function isNumeric(n)
{
return !isNaN(parseFloat(n)) && isFinite(n);
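The unescapeRegexNewlines helper above turns user-typed escape sequences into real control characters, stashing doubled backslashes in a placeholder first so an escaped sequence like \\n is not misread as a newline. The same idea as a Python sketch:

def unescape_regex_newlines(s: str) -> str:
    s = s.replace("\\\\", "[temp_rr_seq]")  # protect literal double-backslashes first
    s = s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
    return s.replace("[temp_rr_seq]", "\\\\")  # restore them untouched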
@@ -3685,9 +3694,11 @@ Current version indicated by LITEVER below.
let dch = data.choices[0];
if (dch.text) {
synchro_polled_response = dch.text;
last_response_obj = JSON.parse(JSON.stringify(data));
}
else if (dch.message) {
synchro_polled_response = dch.message.content;
last_response_obj = JSON.parse(JSON.stringify(data));
if(localsettings.opmode==1 && gametext_arr.length>0 && synchro_polled_response!="")
{
@@ -3919,14 +3930,55 @@ Current version indicated by LITEVER below.
}
},
close() { //end of stream
let finish_actions = function()
{
synchro_polled_response = synchro_pending_stream;
synchro_pending_stream = "";
poll_pending_response();
};
//handle gen failures
if(resp.status==503)
{
finish_actions();
msgbox("Error while submitting prompt: Server appears to be busy.");
}
else
{
//if we wanted logprobs, try fetching them manually
if(localsettings.request_logprobs && last_response_obj==null)
{
fetch(custom_kobold_endpoint + koboldcpp_logprobs_endpoint, {
method: 'POST',
headers: get_kobold_header(),
body: JSON.stringify({
"genkey": lastcheckgenkey
}),
})
.then((response) => response.json())
.then((data) => {
//make sure a delayed logprobs reply doesn't arrive late and overwrite a real response
if (data && data.logprobs != null && last_response_obj==null) {
//fake a last response obj
let fakedresponse = {
"artificial_response": true,
"results":[{"logprobs":data.logprobs}]
};
last_response_obj = fakedresponse;
}
finish_actions();
})
.catch((error) => {
console.error('Error:', error);
finish_actions();
});
}
else
{
finish_actions();
}
}
},
abort(error) {
console.error('Error:', error);
@@ -4227,6 +4279,7 @@ Current version indicated by LITEVER below.
const koboldcpp_version_endpoint = "/api/extra/version";
const koboldcpp_abort_endpoint = "/api/extra/abort";
const koboldcpp_check_endpoint = "/api/extra/generate/check";
const koboldcpp_logprobs_endpoint = "/api/extra/last_logprobs";
const koboldcpp_truemaxctxlen_endpoint = "/api/extra/true_max_context_length";
const koboldcpp_preloadstory_endpoint = "/api/extra/preloadstory";
const koboldcpp_transcribe_endpoint = "/api/extra/transcribe";
@@ -7927,7 +7980,48 @@ Current version indicated by LITEVER below.
if(kcpp_has_logprobs || oai_has_logprobs)
{
let lpc = (kcpp_has_logprobs?last_response_obj.results[0].logprobs.content:last_response_obj.choices[0].logprobs.content);
if(lpc)
if(!lpc)
{
if(oai_has_logprobs)
{
//try legacy logprobs api
let seltokarr = last_response_obj.choices[0].logprobs.tokens;
let sellogarr = last_response_obj.choices[0].logprobs.token_logprobs;
let topdict = last_response_obj.choices[0].logprobs.top_logprobs;
if(seltokarr && sellogarr && topdict)
{
lastlogprobsstr += `<table class="logprobstable">`;
for(let i=0;i<seltokarr.length;++i)
{
lastlogprobsstr += "<tr>";
lastlogprobsstr += `<td style="color:lime">${escapeHtml(seltokarr[i])}<br>(${(Math.exp(sellogarr[i])*100).toFixed(2)}%)</td>`;
let addspace = false;
let dictkeys = Object.keys(topdict[i]);
for(let j=0;j<5;++j)
{
if(j>=dictkeys.length)
{
lastlogprobsstr += `<td></td>`;
continue;
}
if(dictkeys[j]==seltokarr[i])
{
addspace = true;
continue;
}
lastlogprobsstr += `<td>${escapeHtml(dictkeys[j])}<br>(${(Math.exp(topdict[i][dictkeys[j]])*100).toFixed(2)}%)</td>`;
}
if(addspace)
{
lastlogprobsstr += `<td></td>`;
}
lastlogprobsstr += "</tr>";
}
lastlogprobsstr += "</table>";
}
}
}
else
{
lastlogprobsstr += `<table class="logprobstable">`;
for(let i=0;i<lpc.length;++i)
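Both branches of the table builder render each candidate as a probability, converting via exp(logprob) * 100. Checking the first token of the example response in the API docs above:

import math
print(math.exp(-0.31725305) * 100)  # ~72.82, rendered as 'Hello (72.82%)'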
@@ -11407,11 +11501,7 @@ Current version indicated by LITEVER below.
let escapedpat = escapeHtml(regexreplace_data[i].p);
let pat = new RegExp(escapedpat, "gm");
let rep = regexreplace_data[i].r;
rep = rep.replace(/\\\\/g, "[temp_rr_seq]")
.replace(/\\n/g, "\n")
.replace(/\\t/g, "\t")
.replace(/\\r/g, "\r")
.replace(/\[temp_rr_seq\]/g, "\\\\");
rep = unescapeRegexNewlines(rep);
inputtxt = inputtxt.replace(pat, rep);
}
}
@@ -12170,7 +12260,9 @@ Current version indicated by LITEVER below.
if(regexreplace_data[i].b && !(regexreplace_data[i].d) && regexreplace_data[i].p!="")
{
let pat = new RegExp(regexreplace_data[i].p, "gm");
newgen = newgen.replace(pat, regexreplace_data[i].r);
let rep = regexreplace_data[i].r;
rep = unescapeRegexNewlines(rep);
newgen = newgen.replace(pat, rep);
}
}
}
@@ -13030,6 +13122,16 @@ Current version indicated by LITEVER below.
"temperature": submit_payload.params.temperature,
"top_p": submit_payload.params.top_p,
}
if(localsettings.request_logprobs && !targetep.toLowerCase().includes("api.mistral.ai"))
{
if(document.getElementById("useoaichatcompl").checked)
{
oai_payload.logprobs = true;
oai_payload.top_logprobs = 5;
}else{
oai_payload.logprobs = 5;
}
}
if(!targetep.toLowerCase().includes("api.mistral.ai"))
{
//mistral api does not support presence pen
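The branch above mirrors how the two OpenAI-style APIs spell the same request: chat completions take a boolean plus a separate top_logprobs count, while the legacy completions API takes a single integer. As a sketch, with the counts matching the code above:

# chat completions style: boolean flag plus alternatives count
oai_chat_fields = {"logprobs": True, "top_logprobs": 5}
# legacy completions style: one integer, the number of alternatives to return
oai_legacy_fields = {"logprobs": 5}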
@@ -14150,7 +14252,9 @@ Current version indicated by LITEVER below.
if(regexreplace_data[i].p!="" && !(regexreplace_data[i].d))
{
let pat = new RegExp(regexreplace_data[i].p, "gm");
gentxt = gentxt.replace(pat, regexreplace_data[i].r);
let rep = regexreplace_data[i].r;
rep = unescapeRegexNewlines(rep);
gentxt = gentxt.replace(pat, rep);
}
}
}
@@ -14373,8 +14477,10 @@ Current version indicated by LITEVER below.
{
shownotify();
}
let kcpp_has_logprobs = (last_response_obj!=null && last_response_obj.results && last_response_obj.results.length > 0 && last_response_obj.results[0].logprobs!=null);
let oai_has_logprobs = (last_response_obj!=null && last_response_obj.choices && last_response_obj.choices.length > 0 && last_response_obj.choices[0].logprobs!=null);
let lastresp = ` <a href="#" class="color_blueurl" onclick="show_last_logprobs()">(View Logprobs)</a>`;
let lastreq = `<a href="#" onclick="show_last_req()">Last request</a> served by <a href="#" onclick="get_and_show_workers()">${genworker}</a> using <span class="color_darkgreen">${genmdl}</span>${(genkudos>0?` for ${genkudos} kudos`:``)} in ${getTimeTaken()} seconds.${(last_response_obj!=null && (last_response_obj.results && last_response_obj.results.length > 0 && last_response_obj.results[0].logprobs!=null))?lastresp:""}`;
let lastreq = `<a href="#" onclick="show_last_req()">Last request</a> served by <a href="#" onclick="get_and_show_workers()">${genworker}</a> using <span class="color_darkgreen">${genmdl}</span>${(genkudos>0?` for ${genkudos} kudos`:``)} in ${getTimeTaken()} seconds.${(last_response_obj!=null && (kcpp_has_logprobs || oai_has_logprobs)?lastresp:"")}`;
document.getElementById("lastreq1").innerHTML = lastreq;
document.getElementById("lastreq2").innerHTML = lastreq;
document.getElementById("lastreq3").innerHTML = lastreq;

View file

@@ -1258,6 +1258,41 @@ def extract_json_from_string(input_string):
pass
return []
def parse_last_logprobs(lastlogprobs):
if not lastlogprobs:
return None
logprobsdict = {}
logprobsdict['content'] = []
logprobsdict['tokens'] = []
logprobsdict['token_logprobs'] = []
logprobsdict['top_logprobs'] = []
logprobsdict['text_offset'] = []
text_offset_counter = 0
for i in range(lastlogprobs.count):
lp_content_item = {}
logprob_item = lastlogprobs.logprob_items[i]
toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
logprobsdict['tokens'].append(toptoken)
lp_content_item['token'] = toptoken
logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
lp_content_item['logprob'] = logprob_item.selected_logprob
lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
lp_content_item['top_logprobs'] = []
logprobsdict['text_offset'].append(text_offset_counter)
text_offset_counter += len(toptoken)
tops = {}
for j in range(min(logprob_item.option_count,logprobs_max)):
tl_item = {}
tl_item['logprob'] = logprob_item.logprobs[j]
tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
tops[tokstr] = logprob_item.logprobs[j]
tl_item['token'] = tokstr
tl_item['bytes'] = list(tokstr.encode('utf-8'))
lp_content_item['top_logprobs'].append(tl_item)
logprobsdict['top_logprobs'].append(tops)
logprobsdict['content'].append(lp_content_item)
return logprobsdict
def transform_genparams(genparams, api_format):
global chatcompl_adapter
#api format 1=basic,2=kai,3=oai,4=oai-chat,5=interrogate
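Note that parse_last_logprobs above emits both conventions at once: the chat-style content list and the legacy parallel arrays (tokens, token_logprobs, top_logprobs, text_offset). For a hypothetical single-token generation, the return value would look roughly like:

{
    "content": [{"token": "Hi", "logprob": -0.12, "bytes": [72, 105],
                 "top_logprobs": [{"token": "Hi", "logprob": -0.12, "bytes": [72, 105]},
                                  {"token": "Hey", "logprob": -2.3, "bytes": [72, 101, 121]}]}],
    "tokens": ["Hi"],
    "token_logprobs": [-0.12],
    "top_logprobs": [{"Hi": -0.12, "Hey": -2.3}],
    "text_offset": [0],
}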
@@ -1484,36 +1519,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
logprobsdict = None
if not stream_flag and ("logprobs" in genparams and genparams["logprobs"]):
lastlogprobs = handle.last_logprobs()
logprobsdict = {}
logprobsdict['content'] = []
logprobsdict['tokens'] = []
logprobsdict['token_logprobs'] = []
logprobsdict['top_logprobs'] = []
logprobsdict['text_offset'] = []
text_offset_counter = 0
for i in range(lastlogprobs.count):
lp_content_item = {}
logprob_item = lastlogprobs.logprob_items[i]
toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
logprobsdict['tokens'].append(toptoken)
lp_content_item['token'] = toptoken
logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
lp_content_item['logprob'] = logprob_item.selected_logprob
lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
lp_content_item['top_logprobs'] = []
logprobsdict['text_offset'].append(text_offset_counter)
text_offset_counter += len(toptoken)
tops = {}
for j in range(min(logprob_item.option_count,logprobs_max)):
tl_item = {}
tl_item['logprob'] = logprob_item.logprobs[j]
tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
tops[tokstr] = logprob_item.logprobs[j]
tl_item['token'] = tokstr
tl_item['bytes'] = list(tokstr.encode('utf-8'))
lp_content_item['top_logprobs'].append(tl_item)
logprobsdict['top_logprobs'].append(tops)
logprobsdict['content'].append(lp_content_item)
logprobsdict = parse_last_logprobs(lastlogprobs)
# flag instance as non-idle for a while
washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
@@ -1860,6 +1866,15 @@ Enter Prompt:<br>
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
elif self.path.endswith('/api/extra/last_logprobs'):
if not self.secure_endpoint():
return
logprobsdict = None
if requestsinqueue==0 and totalgens>0 and currentusergenkey=="":
lastlogprobs = handle.last_logprobs()
logprobsdict = parse_last_logprobs(lastlogprobs)
response_body = (json.dumps({"logprobs":logprobsdict}).encode())
elif self.path.endswith('/v1/models'):
response_body = (json.dumps({"object":"list","data":[{"id":friendlymodelname,"object":"model","created":int(time.time()),"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode())
@@ -2004,6 +2019,24 @@ Enter Prompt:<br>
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode())
elif self.path.endswith('/api/extra/last_logprobs'):
if not self.secure_endpoint():
return
logprobsdict = None
multiuserkey = ""
try:
tempbody = json.loads(body)
if isinstance(tempbody, dict):
multiuserkey = tempbody.get('genkey', "")
except Exception as e:
multiuserkey = ""
if totalgens>0:
if (multiuserkey=="" and multiuserkey==currentusergenkey and requestsinqueue==0) or (multiuserkey!="" and multiuserkey==currentusergenkey): #avoid leaking prompts in multiuser
lastlogprobs = handle.last_logprobs()
logprobsdict = parse_last_logprobs(lastlogprobs)
response_body = (json.dumps({"logprobs":logprobsdict}).encode())
if response_body is not None:
self.send_response(response_code)
self.send_header('content-length', str(len(response_body)))
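The genkey check in the POST handler above reduces to a small predicate: with no key supplied, only an idle single-user instance may answer; with a key supplied, it must match the most recent generation. A hedged restatement:

def may_serve_logprobs(genkey, currentusergenkey, requestsinqueue, totalgens):
    if totalgens == 0:
        return False                       # nothing has been generated yet
    if genkey == "":                       # no key: single-user, idle instance only
        return currentusergenkey == "" and requestsinqueue == 0
    return genkey == currentusergenkey     # key must match the last generation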