added logprobs api and logprobs viewer

Concedo 2024-11-01 00:22:15 +08:00
parent 6731dd64f1
commit aa26a58085
5 changed files with 229 additions and 29 deletions
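For orientation before the diffs: this commit adds a boolean `logprobs` field to the KoboldCpp generate request and an OpenAI-style `logprobs` object to each non-streamed result. A minimal usage sketch, not part of the commit, assuming a stock local KoboldCpp instance; the endpoint path and port 5001 are the usual KoboldCpp defaults rather than anything introduced here, and the top-level placement of `"logprobs"` follows the Kobold generate payload:

    # a minimal sketch: query a local KoboldCpp server with the new
    # "logprobs" flag and walk the OpenAI-style response object
    import json, urllib.request

    payload = {"prompt": "The capital of France is", "max_length": 8, "logprobs": True}
    req = urllib.request.Request("http://localhost:5001/api/v1/generate",
                                 data=json.dumps(payload).encode("utf-8"),
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        result = json.load(resp)["results"][0]

    # per-token alternatives, mirroring the OpenAI logprobs shape
    for tok in result["logprobs"]["content"]:
        print(tok["token"], tok["logprob"], [t["token"] for t in tok["top_logprobs"]])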

View file

@@ -294,11 +294,29 @@ extern "C"
         return output;
     }

+    static std::vector<TopPicksData> last_logprob_toppicks;
+    static std::vector<logprob_item> last_logprob_items;
     last_logprobs_outputs last_logprobs()
     {
         last_logprobs_outputs output;
-        std::vector<TopPicksData> toppicks = gpttype_get_top_picks_data(); //copy top picks
-        output.count = 0;
+        last_logprob_items.clear();
+        last_logprob_toppicks.clear();
+        last_logprob_toppicks = gpttype_get_top_picks_data(); //copy top picks
+        for(int i=0;i<last_logprob_toppicks.size();++i)
+        {
+            logprob_item itm;
+            itm.option_count = last_logprob_toppicks[i].tokenid.size();
+            itm.selected_token = last_logprob_toppicks[i].selected_token.c_str();
+            itm.selected_logprob = last_logprob_toppicks[i].selected_logprob;
+            itm.logprobs = last_logprob_toppicks[i].logprobs.data();
+            for(int j=0;j<itm.option_count && j<logprobs_max;++j)
+            {
+                itm.tokens[j] = last_logprob_toppicks[i].tokens[j].c_str();
+            }
+            last_logprob_items.push_back(itm);
+        }
+        output.count = last_logprob_items.size();
+        output.logprob_items = last_logprob_items.data();
         return output;
     }

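A note on the lifetime design above: `last_logprobs()` hands back raw pointers into the two file-scope static vectors, so the returned struct stays valid only until the next call clears and refills them; callers must copy the data out before generating again. A consumption sketch on the ctypes side, assuming `handle` is the loaded library with `last_logprobs.restype` already set to `last_logprobs_outputs` as in koboldcpp.py:

    import ctypes

    # a minimal sketch: copy everything into plain Python objects straight away,
    # because the C++ static buffers are reused on the next last_logprobs() call
    def copy_last_logprobs(handle):
        out = handle.last_logprobs()
        items = []
        for i in range(out.count):
            itm = out.logprob_items[i]
            items.append({
                # c_char_p fields surface as Python bytes, so decoding copies them
                "token": (itm.selected_token or b"").decode("utf-8", "ignore"),
                "logprob": itm.selected_logprob,
                "top_logprobs": [itm.logprobs[j] for j in range(min(itm.option_count, 5))],
            })
        return items  # safe to keep after the C++ side reuses its buffers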
View file

@@ -3,6 +3,7 @@
 const int tensor_split_max = 16;
 const int images_max = 4;
+const int logprobs_max = 5;

 // match kobold's sampler list and order
 enum samplers
@@ -111,6 +112,8 @@ struct generation_outputs
 {
     int status = -1;
     int stopreason = stop_reason::INVALID;
+    int prompt_tokens = 0;
+    int completion_tokens = 0;
     const char * text; //response will now be stored in c++ allocated memory
 };
 struct token_count_outputs
@@ -118,12 +121,17 @@ struct token_count_outputs
     int count = 0;
     int * ids; //we'll just use shared memory for this one, bit of a hack
 };
+struct logprob_item {
+    int option_count;
+    const char * selected_token;
+    float selected_logprob;
+    const char * tokens[logprobs_max];
+    float * logprobs = nullptr;
+};
 struct last_logprobs_outputs {
     int count = 0;
-    char ** selected_token;
-    float * selected_logprob;
-    char * tokens[5];
-    float * logprobs[5];
+    logprob_item * logprob_items = nullptr;
 };
 struct sd_load_model_inputs
 {

View file

@@ -597,13 +597,13 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng
     int idx = dist(rng);
     newpick.selected_token = FileFormatTokenizeID(candidates->data[idx].id, file_format, true);
-    newpick.selected_logprob = candidates->data[idx].logit;
+    newpick.selected_logprob = logf(candidates->data[idx].p);
     newpick.selected_probability = candidates->data[idx].p;
     newpick.selected_tokenid = candidates->data[idx].id;
-    for (size_t i = 0; (i < candidates->size && i<5); ++i)
+    for (size_t i = 0; (i < candidates->size && i<logprobs_max); ++i)
     {
         newpick.tokens.push_back(FileFormatTokenizeID(candidates->data[i].id, file_format, true));
-        newpick.logprobs.push_back(candidates->data[i].logit);
+        newpick.logprobs.push_back(logf(candidates->data[i].p));
         newpick.p.push_back(candidates->data[i].p);
         newpick.tokenid.push_back(candidates->data[i].id);
     }
@@ -2467,6 +2467,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("\nWarning: KCPP text generation not initialized!\n");
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3142,6 +3143,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past);
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3471,6 +3473,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3482,6 +3485,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         fprintf(stderr, "\nLLAVA image tokens mismatch at %d! (%d vs %d tokens)\n",n_past,llavatokenscounted,llavatokensevaled);
         output.text = nullptr;
         output.status = 0;
+        output.prompt_tokens = output.completion_tokens = 0;
         output.stopreason = stop_reason::INVALID;
         generation_finished = true;
         return output;
@@ -3534,6 +3538,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
         printf("\nCtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
         fflush(stdout);
         output.status = 1;
+        int finaltokcount = (int)current_context_tokens.size()-realnpredict;
+        output.prompt_tokens = (finaltokcount<0?0:finaltokcount);
+        output.completion_tokens = realnpredict;
         output.stopreason = last_stop_reason;
         last_eval_time = pt2;
         last_process_time = pt1;

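Worth noting in the sample_token changes above: the stored values switch from raw logits to logf(p), i.e. true log probabilities of the sampling probability p, which is what lets the Lite viewer later render Math.exp(logprob)*100 as a percentage. A worked micro-example with an illustrative number, in Python for brevity:

    # illustrative only: exp() recovers the probability from the stored logf(p)
    import math

    logprob = -0.105  # e.g. a value arriving in top_logprobs[j].logprob
    print(f"{math.exp(logprob)*100:.2f}%")  # prints 90.03%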
View file

@@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
 -->
 <script>
-const LITEVER = 182;
+const LITEVER = 183;
 const urlParams = new URLSearchParams(window.location.search);
 var localflag = true;
 const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";
@@ -1267,6 +1267,20 @@ Current version indicated by LITEVER below.
     padding: min(0.4vw, 5px);
 }
+.logprobstable
+{
+    font-size: 11px;
+    width: 100%;
+    border-spacing: 2px;
+}
+.logprobstable table, th, td {
+    border: 2px solid #5d5d5d;
+}
+.logprobstable>tbody>tr>td
+{
+    width: 16.4%;
+}
 .tablelines
 {
     border: 1px solid;
@@ -3557,6 +3571,8 @@ Current version indicated by LITEVER below.
         last_stop_reason = "stop";
     }

+    last_response_obj = JSON.parse(JSON.stringify(data));
+
     //handle some early stopping criterias
     if (localsettings.opmode == 3) //stop on selfname found
     {
@@ -4347,6 +4363,7 @@ Current version indicated by LITEVER below.
 var koboldcpp_version_obj = {};
 var koboldcpp_has_vision = false;
 var last_request_str = "No Requests Available"; //full context of last submitted request
+var last_response_obj = null;
 var lastcheckgenkey = ""; //for checking polled-streaming unique id when generating in kcpp
 var globalabortcontroller = null;
 var passed_ai_warning_local = false;
@@ -4419,6 +4436,7 @@ Current version indicated by LITEVER below.
     instruct_has_markdown: true,
     placeholder_tags: true,
     render_special_tags: false,
+    request_logprobs: false,
     persist_session: true,
     speech_synth: 0, //0 is disabled, 1000 is xtts
     xtts_voice: "female_calm",
@@ -6395,6 +6413,10 @@ Current version indicated by LITEVER below.
     {
         current_wi = load_nai_wi(new_loaded_storyobj);
     }
+    else if(new_loaded_storyobj.length>=1 && new_loaded_storyobj[0].keys!="" && new_loaded_storyobj[0].value!="" && (new_loaded_storyobj[0].useForCharacterCreation === true || new_loaded_storyobj[0].useForCharacterCreation === false))
+    {
+        current_wi = load_aid_wi(new_loaded_storyobj);
+    }
     else {
         msgbox("Could not load selected json file. Does not appear to be a KoboldAI story or compatible format.");
     }
@@ -7015,6 +7037,34 @@ Current version indicated by LITEVER below.
     return loadedwi;
 }

+function load_aid_wi(obj)
+{
+    console.log("Loading aid wi");
+    let loadedwi = [];
+    for (let i=0;i<obj.length;++i) {
+        var itm = obj[i];
+        var key = "";
+        if(itm.keys && itm.keys!="")
+        {
+            key = itm.keys;
+        }
+        let nwi = {
+            "key": key,
+            "keysecondary": "",
+            "keyanti": "",
+            "content": itm.value,
+            "comment": "",
+            "folder": null,
+            "selective": false,
+            "constant": false,
+            "probability":100
+        };
+        loadedwi.push(nwi);
+    }
+    return loadedwi;
+}
+
 function get_aetherroom_scenario()
 {
     inputBox("Enter aetherroom.club prompt URL, or 4-digit prompt number","Import from aetherroom.club","","https://aetherroom.club/1234", ()=>{
@@ -7862,7 +7912,56 @@ Current version indicated by LITEVER below.
 function show_last_req()
 {
-    msgbox(last_request_str,"Last Request Sent",false);
+    let lr = "Request:\n" + last_request_str;
+    if(last_response_obj!=null)
+    {
+        lr += "\n\nResponse:\n" + JSON.stringify(last_response_obj);
+    }
+    msgbox(lr,"Last Request Info",false);
+}
+function show_last_logprobs()
+{
+    let lastlogprobsstr = "";
+    let kcpp_has_logprobs = (last_response_obj!=null && last_response_obj.results && last_response_obj.results.length > 0 && last_response_obj.results[0].logprobs!=null);
+    let oai_has_logprobs = (last_response_obj!=null && last_response_obj.choices && last_response_obj.choices.length > 0 && last_response_obj.choices[0].logprobs!=null);
+    if(kcpp_has_logprobs || oai_has_logprobs)
+    {
+        let lpc = (kcpp_has_logprobs?last_response_obj.results[0].logprobs.content:last_response_obj.choices[0].logprobs.content);
+        if(lpc)
+        {
+            lastlogprobsstr += `<table class="logprobstable">`;
+            for(let i=0;i<lpc.length;++i)
+            {
+                lastlogprobsstr += "<tr>";
+                let cn = lpc[i];
+                lastlogprobsstr += `<td style="color:lime">${escapeHtml(cn.token)}<br>(${(Math.exp(cn.logprob)*100).toFixed(2)}%)</td>`;
+                let addspace = false;
+                for(let j=0;j<5;++j)
+                {
+                    if(j>=cn.top_logprobs.length)
+                    {
+                        lastlogprobsstr += `<td></td>`;
+                        continue;
+                    }
+                    if(cn.top_logprobs[j].token==cn.token)
+                    {
+                        addspace = true;
+                        continue;
+                    }
+                    lastlogprobsstr += `<td>${escapeHtml(cn.top_logprobs[j].token)}<br>(${(Math.exp(cn.top_logprobs[j].logprob)*100).toFixed(2)}%)</td>`;
+                }
+                if(addspace)
+                {
+                    lastlogprobsstr += `<td></td>`;
+                }
+                lastlogprobsstr += "</tr>";
+            }
+            lastlogprobsstr += "</table>";
+        }
+    } else {
+        lastlogprobsstr = "Not Available";
+    }
+    msgbox(lastlogprobsstr,"Logit Probability Viewer",true);
 }

 var worker_data_showonly = []; //only for table display, dont mix
@@ -10079,6 +10178,7 @@ Current version indicated by LITEVER below.
     document.getElementById("trimwhitespace").checked = localsettings.trimwhitespace;
     document.getElementById("compressnewlines").checked = localsettings.compressnewlines;
     document.getElementById("render_special_tags").checked = localsettings.render_special_tags;
+    document.getElementById("request_logprobs").checked = localsettings.request_logprobs;
     document.getElementById("eos_ban_mode").value = localsettings.eos_ban_mode;
     document.getElementById("persist_session").checked = localsettings.persist_session;
     document.getElementById("opmode").value = localsettings.opmode;
@@ -10463,6 +10563,7 @@ Current version indicated by LITEVER below.
     localsettings.trimwhitespace = (document.getElementById("trimwhitespace").checked ? true : false);
     localsettings.compressnewlines = (document.getElementById("compressnewlines").checked ? true : false);
     localsettings.render_special_tags = (document.getElementById("render_special_tags").checked ? true : false);
+    localsettings.request_logprobs = (document.getElementById("request_logprobs").checked ? true : false);
     localsettings.eos_ban_mode = document.getElementById("eos_ban_mode").value;
     localsettings.persist_session = (document.getElementById("persist_session").checked ? true : false);
     if(document.getElementById("opmode").value==1)
@@ -11146,6 +11247,7 @@ Current version indicated by LITEVER below.
     gametext_arr = [];
     redo_arr = [];
     last_request_str = "No Requests Available";
+    last_response_obj = null;
     retry_prev_text = [];
     retry_preserve_last = false;
     redo_prev_text = [];
@@ -11304,7 +11406,13 @@ Current version indicated by LITEVER below.
             {
                 let escapedpat = escapeHtml(regexreplace_data[i].p);
                 let pat = new RegExp(escapedpat, "gm");
-                inputtxt = inputtxt.replace(pat, regexreplace_data[i].r);
+                let rep = regexreplace_data[i].r;
+                rep = rep.replace(/\\\\/g, "[temp_rr_seq]")
+                .replace(/\\n/g, "\n")
+                .replace(/\\t/g, "\t")
+                .replace(/\\r/g, "\r")
+                .replace(/\[temp_rr_seq\]/g, "\\\\");
+                inputtxt = inputtxt.replace(pat, rep);
             }
         }
     }
@@ -12821,6 +12929,7 @@ Current version indicated by LITEVER below.
         submit_payload.params.smoothing_factor = localsettings.smoothing_factor;
         submit_payload.params.banned_tokens = get_token_bans();
         submit_payload.params.render_special = localsettings.render_special_tags;
+        submit_payload.params.logprobs = localsettings.request_logprobs;
     }
     if(custom_kobold_endpoint != "" && is_using_kcpp_with_dry() && localsettings.dry_multiplier > 0)
     {
@@ -12889,6 +12998,7 @@ Current version indicated by LITEVER below.
         streamchunk = ((pstreamamount != null && pstreamamount > 0) ? pstreamamount:8); //8 tokens per stream tick by default
     }
     last_request_str = JSON.stringify(submit_payload);
+    last_response_obj = null;
     if (localsettings.tokenstreammode==2 && is_using_kcpp_with_sse()) {
         let sub_endpt = apply_proxy_url(custom_kobold_endpoint + kobold_custom_gen_stream_endpoint);
         kobold_api_stream_sse(sub_endpt, submit_payload);
@@ -13011,6 +13121,7 @@ Current version indicated by LITEVER below.
     }
     last_request_str = JSON.stringify(oai_payload);
+    last_response_obj = null;

     let oaiheaders = {
         'Content-Type': 'application/json',
         'Authorization': 'Bearer ' + custom_oai_key
@@ -13119,6 +13230,7 @@ Current version indicated by LITEVER below.
     last_request_str = JSON.stringify(claude_payload);
+    last_response_obj = null;

     let claudeheaders = {
         'Content-Type': 'application/json',
let targetep = urlbase + custom_palm_key; let targetep = urlbase + custom_palm_key;
last_request_str = JSON.stringify(payload); last_request_str = JSON.stringify(payload);
last_response_obj = null;
fetch(targetep, { fetch(targetep, {
method: 'POST', method: 'POST',
@ -13338,6 +13451,7 @@ Current version indicated by LITEVER below.
} }
last_request_str = JSON.stringify(cohere_payload); last_request_str = JSON.stringify(cohere_payload);
last_response_obj = null;
let cohere_headers = { let cohere_headers = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Authorization': 'Bearer ' + custom_cohere_key 'Authorization': 'Bearer ' + custom_cohere_key
@@ -13438,6 +13552,7 @@ Current version indicated by LITEVER below.
     }
     last_request_str = JSON.stringify(submit_payload);
+    last_response_obj = null;

     fetch(selectedhorde.submit_endpoint, {
         method: 'POST', // or 'PUT'
@@ -14258,7 +14373,8 @@ Current version indicated by LITEVER below.
     {
         shownotify();
     }
-    let lastreq = `<a href="#" onclick="show_last_req()">Last request</a> served by <a href="#" onclick="get_and_show_workers()">${genworker}</a> using <span class="color_darkgreen">${genmdl}</span>${(genkudos>0?` for ${genkudos} kudos`:``)} in ${getTimeTaken()} seconds.`;
+    let lastresp = ` <a href="#" class="color_blueurl" onclick="show_last_logprobs()">(View Logprobs)</a>`;
+    let lastreq = `<a href="#" onclick="show_last_req()">Last request</a> served by <a href="#" onclick="get_and_show_workers()">${genworker}</a> using <span class="color_darkgreen">${genmdl}</span>${(genkudos>0?` for ${genkudos} kudos`:``)} in ${getTimeTaken()} seconds.${(last_response_obj!=null && (last_response_obj.results && last_response_obj.results.length > 0 && last_response_obj.results[0].logprobs!=null))?lastresp:""}`;
     document.getElementById("lastreq1").innerHTML = lastreq;
     document.getElementById("lastreq2").innerHTML = lastreq;
     document.getElementById("lastreq3").innerHTML = lastreq;
@@ -18473,7 +18589,9 @@ Current version indicated by LITEVER below.
     <option value="claude-3-opus-20240229">claude-3-opus</option>
     <option value="claude-3-sonnet-20240229">claude-3-sonnet</option>
     <option value="claude-3-haiku-20240307">claude-3-haiku</option>
-    <option value="claude-3-5-sonnet-20240620">claude-3.5-sonnet</option>
+    <option value="claude-3-5-sonnet-20240620">claude-3-5-sonnet-20240620</option>
+    <option value="claude-3-5-sonnet-20241022">claude-3-5-sonnet-20241022</option>
+    <option value="claude-3-5-sonnet-latest">claude-3-5-sonnet-latest</option>
     </select>
     <input type="checkbox" id="claudeaddversion" onchange="" checked>
     <div class="box-label" title="Add endpoint version">Add Endpoint Version</div>
@@ -19333,6 +19451,11 @@ Current version indicated by LITEVER below.
     class="helptext">If enabled, renders special tags like EOS and padding tokens. Not recommended.</span></span></div>
     <input title="Render Special Tags" type="checkbox" id="render_special_tags" style="margin:0px 0px 0px auto;">
     </div>
+    <div class="settinglabel">
+    <div class="justifyleft settingsmall">Request Logprobs <span class="helpicon">?<span
+    class="helptext">If enabled, request top 5 alternative token logit probabilities for each generated token. Incurs an overhead.</span></span></div>
+    <input title="Request Logprobs" type="checkbox" id="request_logprobs" style="margin:0px 0px 0px auto;">
+    </div>
     </div>

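One subtlety in the regex-replacement hunk above: replacement strings may now contain \n, \t and \r escapes, and literal double backslashes are protected by round-tripping them through a [temp_rr_seq] placeholder so they are not misread as escapes. The same dance sketched in Python, with a hypothetical helper name:

    # a minimal sketch of the same escape handling: shield literal backslash
    # pairs behind a placeholder, expand the escapes, then restore the pairs
    def expand_replacement(rep: str) -> str:
        rep = rep.replace("\\\\", "[temp_rr_seq]")
        rep = (rep.replace("\\n", "\n")
                  .replace("\\t", "\t")
                  .replace("\\r", "\r"))
        return rep.replace("[temp_rr_seq]", "\\\\")

    assert expand_replacement(r"a\nb") == "a\nb"      # \n becomes a real newline
    assert expand_replacement(r"c\\nd") == r"c\\nd"   # literal \\ survives intact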
View file

@@ -23,6 +23,7 @@ tensor_split_max = 16
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
+logprobs_max = 5

 # abuse prevention
 stop_token_max = 512
@@ -102,12 +103,15 @@ class token_count_outputs(ctypes.Structure):
                 ("ids", ctypes.POINTER(ctypes.c_int))]

 # returns top 5 logprobs per token
+class logprob_item(ctypes.Structure):
+    _fields_ = [("option_count", ctypes.c_int),
+                ("selected_token", ctypes.c_char_p),
+                ("selected_logprob", ctypes.c_float),
+                ("tokens", ctypes.c_char_p * logprobs_max),
+                ("logprobs", ctypes.POINTER(ctypes.c_float))]
 class last_logprobs_outputs(ctypes.Structure):
     _fields_ = [("count", ctypes.c_int),
-                ("selected_token", ctypes.POINTER(ctypes.c_char_p)),
-                ("selected_logprob", ctypes.POINTER(ctypes.c_float)),
-                ("tokens", ctypes.POINTER(5 * ctypes.c_char_p)),
-                ("logprobs", ctypes.POINTER(5 * ctypes.c_float))]
+                ("logprob_items", ctypes.POINTER(logprob_item))]

 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
@@ -190,6 +194,8 @@ class generation_inputs(ctypes.Structure):
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("stopreason", ctypes.c_int),
+                ("prompt_tokens", ctypes.c_int),
+                ("completion_tokens", ctypes.c_int),
                 ("text", ctypes.c_char_p)]

 class sd_load_model_inputs(ctypes.Structure):
@@ -896,7 +902,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     memory = genparams.get('memory', "")
     images = genparams.get('images', [])
     max_context_length = genparams.get('max_context_length', maxctx)
-    max_length = genparams.get('max_length', 180)
+    max_length = genparams.get('max_length', 200)
     temperature = genparams.get('temperature', 0.7)
     top_k = genparams.get('top_k', 100)
     top_a = genparams.get('top_a', 0.0)
@@ -1078,7 +1084,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     if pendingabortkey!="" and pendingabortkey==genkey:
         print(f"\nDeferred Abort for GenKey: {pendingabortkey}")
         pendingabortkey = ""
-        return {"text":"","status":-1,"stopreason":-1}
+        return {"text":"","status":-1,"stopreason":-1, "prompt_tokens":0, "completion_tokens": 0}
     else:
         ret = handle.generate(inputs)
         outstr = ""
@@ -1089,7 +1095,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         sindex = outstr.find(trim_str)
         if sindex != -1 and trim_str!="":
             outstr = outstr[:sindex]
-        return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}
+        return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}

 def sd_load_model(model_filename,vae_filename,lora_filename):
@@ -1267,13 +1273,14 @@ def transform_genparams(genparams, api_format):
     if api_format==1:
         genparams["prompt"] = genparams.get('text', "")
         genparams["top_k"] = int(genparams.get('top_k', 120))
-        genparams["max_length"] = genparams.get('max', 180)
+        genparams["max_length"] = genparams.get('max', 200)
     elif api_format==2:
         pass
     elif api_format==3 or api_format==4:
-        genparams["max_length"] = genparams.get('max_tokens', (400 if api_format==4 else 180))
+        default_max_tok = (400 if api_format==4 else 200)
+        genparams["max_length"] = genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok))
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
         genparams["presence_penalty"] = presence_penalty
         # openai allows either a string or a list as a stop sequence
@@ -1460,7 +1467,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             return generate(genparams=genparams,is_quiet=is_quiet,stream_flag=stream_flag)

-        genout = {"text": "", "status": -1, "stopreason": -1}
+        genout = {"text": "", "status": -1, "stopreason": -1, "prompt_tokens":0, "completion_tokens": 0}
         if stream_flag:
             loop = asyncio.get_event_loop()
             executor = ThreadPoolExecutor()
@@ -1469,8 +1476,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             genout = run_blocking()

         recvtxt = genout['text']
+        prompttokens = genout['prompt_tokens']
+        comptokens = genout['completion_tokens']
         currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop")

+        # grab logprobs if not streaming
+        logprobsdict = None
+        if not stream_flag and ("logprobs" in genparams and genparams["logprobs"]):
+            lastlogprobs = handle.last_logprobs()
+            logprobsdict = {}
+            logprobsdict['content'] = []
+            logprobsdict['tokens'] = []
+            logprobsdict['token_logprobs'] = []
+            logprobsdict['top_logprobs'] = []
+            logprobsdict['text_offset'] = []
+            text_offset_counter = 0
+            for i in range(lastlogprobs.count):
+                lp_content_item = {}
+                logprob_item = lastlogprobs.logprob_items[i]
+                toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
+                logprobsdict['tokens'].append(toptoken)
+                lp_content_item['token'] = toptoken
+                logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
+                lp_content_item['logprob'] = logprob_item.selected_logprob
+                lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
+                lp_content_item['top_logprobs'] = []
+                logprobsdict['text_offset'].append(text_offset_counter)
+                text_offset_counter += len(toptoken)
+                tops = {}
+                for j in range(min(logprob_item.option_count,logprobs_max)):
+                    tl_item = {}
+                    tl_item['logprob'] = logprob_item.logprobs[j]
+                    tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
+                    tops[tokstr] = logprob_item.logprobs[j]
+                    tl_item['token'] = tokstr
+                    tl_item['bytes'] = list(tokstr.encode('utf-8'))
+                    lp_content_item['top_logprobs'].append(tl_item)
+                logprobsdict['top_logprobs'].append(tops)
+                logprobsdict['content'].append(lp_content_item)
+
         # flag instance as non-idle for a while
         washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
         if not washordereq:
@@ -1484,8 +1528,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             res = {"data": {"seqs": [recvtxt]}}
         elif api_format == 3:
             res = {"id": "cmpl-A1", "object": "text_completion", "created": int(time.time()), "model": friendlymodelname,
-                   "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-                   "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
+                   "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+                   "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
         elif api_format == 4:
             using_openai_tools = genparams.get('using_openai_tools', False)
             tool_calls = []
@@ -1494,12 +1538,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if tool_calls and len(tool_calls)>0:
                 recvtxt = None
             res = {"id": "chatcmpl-A1", "object": "chat.completion", "created": int(time.time()), "model": friendlymodelname,
-                   "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-                   "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]}
+                   "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+                   "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
         elif api_format == 5:
             res = {"caption": end_trim_to_sentence(recvtxt)}
         else:
-            res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]}
+            res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason, "logprobs":logprobsdict, "prompt_tokens": prompttokens, "completion_tokens": comptokens}]}

         try:
             return res
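
Taken together, a non-streaming request with logprobs enabled now yields per result a logprobsdict of the shape built above. An illustrative value with made-up tokens and numbers, not from a real run:

    # illustrative shape only; tokens and numbers are invented
    example_logprobs = {
        "content": [{
            "token": " Paris",
            "logprob": -0.105,
            "bytes": [32, 80, 97, 114, 105, 115],
            "top_logprobs": [
                {"token": " Paris", "logprob": -0.105, "bytes": [32, 80, 97, 114, 105, 115]},
                {"token": " Lyon", "logprob": -2.9, "bytes": [32, 76, 121, 111, 110]},
            ],
        }],
        "tokens": [" Paris"],
        "token_logprobs": [-0.105],
        "top_logprobs": [{" Paris": -0.105, " Lyon": -2.9}],
        "text_offset": [0],
    }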