mirror of https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00

added logprobs api and logprobs viewer

This commit is contained in:
parent 6731dd64f1
commit aa26a58085

5 changed files with 229 additions and 29 deletions

expose.cpp (22 changed lines)

@@ -294,11 +294,29 @@ extern "C"
     return output;
 }

+static std::vector<TopPicksData> last_logprob_toppicks;
+static std::vector<logprob_item> last_logprob_items;
 last_logprobs_outputs last_logprobs()
 {
     last_logprobs_outputs output;
-    std::vector<TopPicksData> toppicks = gpttype_get_top_picks_data(); //copy top picks
     output.count = 0;
+    last_logprob_items.clear();
+    last_logprob_toppicks.clear();
+    last_logprob_toppicks = gpttype_get_top_picks_data(); //copy top picks
+    for(int i=0;i<last_logprob_toppicks.size();++i)
+    {
+        logprob_item itm;
+        itm.option_count = last_logprob_toppicks[i].tokenid.size();
+        itm.selected_token = last_logprob_toppicks[i].selected_token.c_str();
+        itm.selected_logprob = last_logprob_toppicks[i].selected_logprob;
+        itm.logprobs = last_logprob_toppicks[i].logprobs.data();
+        for(int j=0;j<itm.option_count && j<logprobs_max;++j)
+        {
+            itm.tokens[j] = last_logprob_toppicks[i].tokens[j].c_str();
+        }
+        last_logprob_items.push_back(itm);
+    }
+    output.count = last_logprob_items.size();
+    output.logprob_items = last_logprob_items.data();
     return output;
 }
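
Note: last_logprobs() returns raw pointers, not copies. Each logprob_item points at c_str() buffers and float storage owned by the static vectors last_logprob_toppicks and last_logprob_items, which is why those vectors are static: the data must outlive the function call, and it stays valid only until the next call repopulates them. A minimal sketch of how a ctypes caller might walk the result (assumes the shared library is already loaded as `handle` and the Structure classes from koboldcpp.py below are in scope; the restype declaration is an assumption about how the binding is wired up elsewhere in koboldcpp.py):

    import ctypes

    # assumed: handle = ctypes.CDLL(...) plus the logprob_item /
    # last_logprobs_outputs Structures defined in koboldcpp.py
    handle.last_logprobs.restype = last_logprobs_outputs

    lp = handle.last_logprobs()   # struct returned by value; the pointers
    for i in range(lp.count):     # inside it are valid until the next call
        itm = lp.logprob_items[i]
        tok = ctypes.string_at(itm.selected_token).decode("utf-8", "ignore")
        print(tok, itm.selected_logprob)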

expose.h (16 changed lines)

@@ -3,6 +3,7 @@

 const int tensor_split_max = 16;
 const int images_max = 4;
+const int logprobs_max = 5;

 // match kobold's sampler list and order
 enum samplers

@@ -111,6 +112,8 @@ struct generation_outputs
 {
     int status = -1;
     int stopreason = stop_reason::INVALID;
+    int prompt_tokens = 0;
+    int completion_tokens = 0;
     const char * text; //response will now be stored in c++ allocated memory
 };
 struct token_count_outputs

@@ -118,12 +121,17 @@ struct token_count_outputs
     int count = 0;
     int * ids; //we'll just use shared memory for this one, bit of a hack
 };
+
+struct logprob_item {
+    int option_count;
+    const char * selected_token;
+    float selected_logprob;
+    const char * tokens[logprobs_max];
+    float * logprobs = nullptr;
+};
 struct last_logprobs_outputs {
     int count = 0;
-    char ** selected_token;
-    float * selected_logprob;
-    char * tokens[5];
-    float * logprobs[5];
+    logprob_item * logprob_items = nullptr;
 };
 struct sd_load_model_inputs
 {

gpttype_adapter.cpp

@@ -597,13 +597,13 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng
     int idx = dist(rng);

     newpick.selected_token = FileFormatTokenizeID(candidates->data[idx].id, file_format, true);
-    newpick.selected_logprob = candidates->data[idx].logit;
+    newpick.selected_logprob = logf(candidates->data[idx].p);
     newpick.selected_probability = candidates->data[idx].p;
     newpick.selected_tokenid = candidates->data[idx].id;
-    for (size_t i = 0; (i < candidates->size && i<5); ++i)
+    for (size_t i = 0; (i < candidates->size && i<logprobs_max); ++i)
     {
         newpick.tokens.push_back(FileFormatTokenizeID(candidates->data[i].id, file_format, true));
-        newpick.logprobs.push_back(candidates->data[i].logit);
+        newpick.logprobs.push_back(logf(candidates->data[i].p));
         newpick.p.push_back(candidates->data[i].p);
         newpick.tokenid.push_back(candidates->data[i].id);
     }
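
Note: this hunk changes what gets recorded per candidate: logf(candidates->data[i].p), the natural log of the post-sampling probability, instead of the raw unnormalized logit. Stored values are therefore true log-probabilities, and consumers can recover a percentage with exp(). A quick round-trip check in plain Python, mirroring the Math.exp(cn.logprob)*100 display used by the Lite viewer below (the probability value is invented for illustration):

    import math

    p = 0.4235              # candidates->data[i].p after softmax and samplers
    logprob = math.log(p)   # what sample_token now stores via logf(p)
    assert abs(math.exp(logprob) - p) < 1e-12
    print(f"{math.exp(logprob) * 100:.2f}%")   # "42.35%", as shown in the table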

@@ -2467,6 +2467,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     printf("\nWarning: KCPP text generation not initialized!\n");
     output.text = nullptr;
     output.status = 0;
+    output.prompt_tokens = output.completion_tokens = 0;
     output.stopreason = stop_reason::INVALID;
     generation_finished = true;
     return output;

@@ -3142,6 +3143,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past);
     output.text = nullptr;
     output.status = 0;
+    output.prompt_tokens = output.completion_tokens = 0;
     output.stopreason = stop_reason::INVALID;
     generation_finished = true;
     return output;

@@ -3471,6 +3473,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
     output.text = nullptr;
     output.status = 0;
+    output.prompt_tokens = output.completion_tokens = 0;
     output.stopreason = stop_reason::INVALID;
     generation_finished = true;
     return output;

@@ -3482,6 +3485,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     fprintf(stderr, "\nLLAVA image tokens mismatch at %d! (%d vs %d tokens)\n",n_past,llavatokenscounted,llavatokensevaled);
     output.text = nullptr;
     output.status = 0;
+    output.prompt_tokens = output.completion_tokens = 0;
     output.stopreason = stop_reason::INVALID;
     generation_finished = true;
     return output;

@@ -3534,6 +3538,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
     printf("\nCtxLimit:%d/%d, Amt:%d/%d, Init:%.2fs, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, realnpredict, kcpp_data->n_predict, time0, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
     fflush(stdout);
     output.status = 1;
+    int finaltokcount = (int)current_context_tokens.size()-realnpredict;
+    output.prompt_tokens = (finaltokcount<0?0:finaltokcount);
+    output.completion_tokens = realnpredict;
     output.stopreason = last_stop_reason;
     last_eval_time = pt2;
     last_process_time = pt1;

klite.embd (133 changed lines)

@@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
 -->

 <script>
-const LITEVER = 182;
+const LITEVER = 183;
 const urlParams = new URLSearchParams(window.location.search);
 var localflag = true;
 const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";

@@ -1267,6 +1267,20 @@ Current version indicated by LITEVER below.
     padding: min(0.4vw, 5px);
 }

+.logprobstable
+{
+    font-size: 11px;
+    width: 100%;
+    border-spacing: 2px;
+}
+.logprobstable table, th, td {
+    border: 2px solid #5d5d5d;
+}
+.logprobstable>tbody>tr>td
+{
+    width: 16.4%;
+}
+
 .tablelines
 {
     border: 1px solid;

@@ -3557,6 +3571,8 @@ Current version indicated by LITEVER below.
     last_stop_reason = "stop";
 }

+last_response_obj = JSON.parse(JSON.stringify(data));
+
 //handle some early stopping criterias
 if (localsettings.opmode == 3) //stop on selfname found
 {

@@ -4347,6 +4363,7 @@ Current version indicated by LITEVER below.
 var koboldcpp_version_obj = {};
 var koboldcpp_has_vision = false;
 var last_request_str = "No Requests Available"; //full context of last submitted request
+var last_response_obj = null;
 var lastcheckgenkey = ""; //for checking polled-streaming unique id when generating in kcpp
 var globalabortcontroller = null;
 var passed_ai_warning_local = false;

@@ -4419,6 +4436,7 @@ Current version indicated by LITEVER below.
 instruct_has_markdown: true,
 placeholder_tags: true,
 render_special_tags: false,
+request_logprobs: false,
 persist_session: true,
 speech_synth: 0, //0 is disabled, 1000 is xtts
 xtts_voice: "female_calm",

@@ -6395,6 +6413,10 @@ Current version indicated by LITEVER below.
 {
     current_wi = load_nai_wi(new_loaded_storyobj);
 }
+else if(new_loaded_storyobj.length>=1 && new_loaded_storyobj[0].keys!="" && new_loaded_storyobj[0].value!="" && (new_loaded_storyobj[0].useForCharacterCreation === true || new_loaded_storyobj[0].useForCharacterCreation === false))
+{
+    current_wi = load_aid_wi(new_loaded_storyobj);
+}
 else {
     msgbox("Could not load selected json file. Does not appear to be a KoboldAI story or compatible format.");
 }

@@ -7015,6 +7037,34 @@ Current version indicated by LITEVER below.
     return loadedwi;
 }

+function load_aid_wi(obj)
+{
+    console.log("Loading aid wi");
+    let loadedwi = [];
+    for (let i=0;i<obj.length;++i) {
+        var itm = obj[i];
+        var key = "";
+        if(itm.keys && itm.keys!="")
+        {
+            key = itm.keys;
+        }
+
+        let nwi = {
+            "key": key,
+            "keysecondary": "",
+            "keyanti": "",
+            "content": itm.value,
+            "comment": "",
+            "folder": null,
+            "selective": false,
+            "constant": false,
+            "probability":100
+        };
+        loadedwi.push(nwi);
+    }
+    return loadedwi;
+}
+
 function get_aetherroom_scenario()
 {
     inputBox("Enter aetherroom.club prompt URL, or 4-digit prompt number","Import from aetherroom.club","","https://aetherroom.club/1234", ()=>{

@@ -7862,7 +7912,56 @@ Current version indicated by LITEVER below.

 function show_last_req()
 {
-    msgbox(last_request_str,"Last Request Sent",false);
+    let lr = "Request:\n" + last_request_str;
+    if(last_response_obj!=null)
+    {
+        lr += "\n\nResponse:\n" + JSON.stringify(last_response_obj);
+    }
+    msgbox(lr,"Last Request Info",false);
 }
+function show_last_logprobs()
+{
+    let lastlogprobsstr = "";
+    let kcpp_has_logprobs = (last_response_obj!=null && last_response_obj.results && last_response_obj.results.length > 0 && last_response_obj.results[0].logprobs!=null);
+    let oai_has_logprobs = (last_response_obj!=null && last_response_obj.choices && last_response_obj.choices.length > 0 && last_response_obj.choices[0].logprobs!=null);
+    if(kcpp_has_logprobs || oai_has_logprobs)
+    {
+        let lpc = (kcpp_has_logprobs?last_response_obj.results[0].logprobs.content:last_response_obj.choices[0].logprobs.content);
+        if(lpc)
+        {
+            lastlogprobsstr += `<table class="logprobstable">`;
+            for(let i=0;i<lpc.length;++i)
+            {
+                lastlogprobsstr += "<tr>";
+                let cn = lpc[i];
+                lastlogprobsstr += `<td style="color:lime">${escapeHtml(cn.token)}<br>(${(Math.exp(cn.logprob)*100).toFixed(2)}%)</td>`;
+                let addspace = false;
+                for(let j=0;j<5;++j)
+                {
+                    if(j>=cn.top_logprobs.length)
+                    {
+                        lastlogprobsstr += `<td></td>`;
+                        continue;
+                    }
+                    if(cn.top_logprobs[j].token==cn.token)
+                    {
+                        addspace = true;
+                        continue;
+                    }
+                    lastlogprobsstr += `<td>${escapeHtml(cn.top_logprobs[j].token)}<br>(${(Math.exp(cn.top_logprobs[j].logprob)*100).toFixed(2)}%)</td>`
+                }
+                if(addspace)
+                {
+                    lastlogprobsstr += `<td></td>`;
+                }
+                lastlogprobsstr += "</tr>";
+            }
+            lastlogprobsstr += "</table>";
+        }
+    } else {
+        lastlogprobsstr = "Not Available";
+    }
+    msgbox(lastlogprobsstr,"Logit Probability Viewer",true);
+}

 var worker_data_showonly = []; //only for table display, dont mix
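
Note: show_last_logprobs() accepts either the KoboldCpp-native shape (results[0].logprobs) or the OpenAI-style shape (choices[0].logprobs); both carry a content array of per-token entries. A minimal illustration of the native shape the table renderer walks (Python-style literal; token and logprob values are invented):

    last_response_obj = {
        "results": [{
            "text": " world",
            "logprobs": {
                "content": [{
                    "token": " world",
                    "logprob": -0.31,   # ln(p), so exp(-0.31) is about 73.3%
                    "top_logprobs": [
                        {"token": " world", "logprob": -0.31},
                        {"token": " there", "logprob": -1.92},
                    ],
                }],
            },
        }],
    }

Each content entry becomes one table row: the selected token in the first cell, then up to five alternatives, skipping the duplicate of the selected token (the addspace flag pads the row so column widths stay aligned).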
@@ -10079,6 +10178,7 @@ Current version indicated by LITEVER below.
 document.getElementById("trimwhitespace").checked = localsettings.trimwhitespace;
 document.getElementById("compressnewlines").checked = localsettings.compressnewlines;
 document.getElementById("render_special_tags").checked = localsettings.render_special_tags;
+document.getElementById("request_logprobs").checked = localsettings.request_logprobs;
 document.getElementById("eos_ban_mode").value = localsettings.eos_ban_mode;
 document.getElementById("persist_session").checked = localsettings.persist_session;
 document.getElementById("opmode").value = localsettings.opmode;

@@ -10463,6 +10563,7 @@ Current version indicated by LITEVER below.
 localsettings.trimwhitespace = (document.getElementById("trimwhitespace").checked ? true : false);
 localsettings.compressnewlines = (document.getElementById("compressnewlines").checked ? true : false);
 localsettings.render_special_tags = (document.getElementById("render_special_tags").checked ? true : false);
+localsettings.request_logprobs = (document.getElementById("request_logprobs").checked ? true : false);
 localsettings.eos_ban_mode = document.getElementById("eos_ban_mode").value;
 localsettings.persist_session = (document.getElementById("persist_session").checked ? true : false);
 if(document.getElementById("opmode").value==1)

@@ -11146,6 +11247,7 @@ Current version indicated by LITEVER below.
 gametext_arr = [];
 redo_arr = [];
 last_request_str = "No Requests Available";
+last_response_obj = null;
 retry_prev_text = [];
 retry_preserve_last = false;
 redo_prev_text = [];

@@ -11304,7 +11406,13 @@ Current version indicated by LITEVER below.
 {
     let escapedpat = escapeHtml(regexreplace_data[i].p);
     let pat = new RegExp(escapedpat, "gm");
-    inputtxt = inputtxt.replace(pat, regexreplace_data[i].r);
+    let rep = regexreplace_data[i].r;
+    rep = rep.replace(/\\\\/g, "[temp_rr_seq]")
+    .replace(/\\n/g, "\n")
+    .replace(/\\t/g, "\t")
+    .replace(/\\r/g, "\r")
+    .replace(/\[temp_rr_seq\]/g, "\\\\");
+    inputtxt = inputtxt.replace(pat, rep);
 }
 }
 }
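
Note: the placeholder swap exists because the replacement string is user-supplied text in which \n, \t and \r should become real control characters, while an escaped backslash pair \\ must survive untouched. Parking \\ as [temp_rr_seq] first prevents its second backslash from fusing with a following n/t/r. A Python analogue of the same sequence (the helper name is invented):

    def unescape_replacement(rep: str) -> str:
        rep = rep.replace("\\\\", "[temp_rr_seq]")   # park escaped backslashes first
        rep = (rep.replace("\\n", "\n")              # now expand control escapes
                  .replace("\\t", "\t")
                  .replace("\\r", "\r"))
        return rep.replace("[temp_rr_seq]", "\\\\")  # restore the parked pairs

    print(unescape_replacement(r"a\nb"))    # 'a', real newline, 'b'
    print(unescape_replacement(r"a\\nb"))   # backslashes kept, 'n' stays literal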
@@ -12821,6 +12929,7 @@ Current version indicated by LITEVER below.
 submit_payload.params.smoothing_factor = localsettings.smoothing_factor;
 submit_payload.params.banned_tokens = get_token_bans();
 submit_payload.params.render_special = localsettings.render_special_tags;
+submit_payload.params.logprobs = localsettings.request_logprobs;
 }
 if(custom_kobold_endpoint != "" && is_using_kcpp_with_dry() && localsettings.dry_multiplier > 0)
 {

@@ -12889,6 +12998,7 @@ Current version indicated by LITEVER below.
 streamchunk = ((pstreamamount != null && pstreamamount > 0) ? pstreamamount:8); //8 tokens per stream tick by default
 }
 last_request_str = JSON.stringify(submit_payload);
+last_response_obj = null;
 if (localsettings.tokenstreammode==2 && is_using_kcpp_with_sse()) {
     let sub_endpt = apply_proxy_url(custom_kobold_endpoint + kobold_custom_gen_stream_endpoint);
     kobold_api_stream_sse(sub_endpt, submit_payload);

@@ -13011,6 +13121,7 @@ Current version indicated by LITEVER below.
 }

 last_request_str = JSON.stringify(oai_payload);
+last_response_obj = null;
 let oaiheaders = {
     'Content-Type': 'application/json',
     'Authorization': 'Bearer ' + custom_oai_key

@@ -13119,6 +13230,7 @@ Current version indicated by LITEVER below.

 last_request_str = JSON.stringify(claude_payload);
+last_response_obj = null;

 let claudeheaders = {
     'Content-Type': 'application/json',

@@ -13268,6 +13380,7 @@ Current version indicated by LITEVER below.

 let targetep = urlbase + custom_palm_key;
 last_request_str = JSON.stringify(payload);
+last_response_obj = null;

 fetch(targetep, {
     method: 'POST',

@@ -13338,6 +13451,7 @@ Current version indicated by LITEVER below.
 }

 last_request_str = JSON.stringify(cohere_payload);
+last_response_obj = null;
 let cohere_headers = {
     'Content-Type': 'application/json',
     'Authorization': 'Bearer ' + custom_cohere_key

@@ -13438,6 +13552,7 @@ Current version indicated by LITEVER below.
 }

 last_request_str = JSON.stringify(submit_payload);
+last_response_obj = null;

 fetch(selectedhorde.submit_endpoint, {
     method: 'POST', // or 'PUT'

@@ -14258,7 +14373,8 @@ Current version indicated by LITEVER below.
 {
     shownotify();
 }
-let lastreq = `<a href="#" onclick="show_last_req()">Last request</a> served by <a href="#" onclick="get_and_show_workers()">${genworker}</a> using <span class="color_darkgreen">${genmdl}</span>${(genkudos>0?` for ${genkudos} kudos`:``)} in ${getTimeTaken()} seconds.`;
+let lastresp = ` <a href="#" class="color_blueurl" onclick="show_last_logprobs()">(View Logprobs)</a>`;
+let lastreq = `<a href="#" onclick="show_last_req()">Last request</a> served by <a href="#" onclick="get_and_show_workers()">${genworker}</a> using <span class="color_darkgreen">${genmdl}</span>${(genkudos>0?` for ${genkudos} kudos`:``)} in ${getTimeTaken()} seconds.${(last_response_obj!=null && (last_response_obj.results && last_response_obj.results.length > 0 && last_response_obj.results[0].logprobs!=null))?lastresp:""}`;
 document.getElementById("lastreq1").innerHTML = lastreq;
 document.getElementById("lastreq2").innerHTML = lastreq;
 document.getElementById("lastreq3").innerHTML = lastreq;

@@ -18473,7 +18589,9 @@ Current version indicated by LITEVER below.
 <option value="claude-3-opus-20240229">claude-3-opus</option>
 <option value="claude-3-sonnet-20240229">claude-3-sonnet</option>
 <option value="claude-3-haiku-20240307">claude-3-haiku</option>
-<option value="claude-3-5-sonnet-20240620">claude-3.5-sonnet</option>
+<option value="claude-3-5-sonnet-20240620">claude-3-5-sonnet-20240620</option>
+<option value="claude-3-5-sonnet-20241022">claude-3-5-sonnet-20241022</option>
+<option value="claude-3-5-sonnet-latest">claude-3-5-sonnet-latest</option>
 </select>
 <input type="checkbox" id="claudeaddversion" onchange="" checked>
 <div class="box-label" title="Add endpoint version">Add Endpoint Version</div>

@@ -19333,6 +19451,11 @@ Current version indicated by LITEVER below.
 class="helptext">If enabled, renders special tags like EOS and padding tokens. Not recommended.</span></span></div>
 <input title="Render Special Tags" type="checkbox" id="render_special_tags" style="margin:0px 0px 0px auto;">
 </div>
+<div class="settinglabel">
+    <div class="justifyleft settingsmall">Request Logprobs <span class="helpicon">?<span
+    class="helptext">If enabled, request top 5 alternative token logit probabilities for each generated token. Incurs an overhead.</span></span></div>
+    <input title="Request Logprobs" type="checkbox" id="request_logprobs" style="margin:0px 0px 0px auto;">
+</div>
 </div>

koboldcpp.py (74 changed lines)

@@ -23,6 +23,7 @@ tensor_split_max = 16
 images_max = 4
 bias_min_value = -100.0
 bias_max_value = 100.0
+logprobs_max = 5

 # abuse prevention
 stop_token_max = 512

@@ -102,12 +103,15 @@ class token_count_outputs(ctypes.Structure):
                 ("ids", ctypes.POINTER(ctypes.c_int))]

+# returns top 5 logprobs per token
+class logprob_item(ctypes.Structure):
+    _fields_ = [("option_count", ctypes.c_int),
+                ("selected_token", ctypes.c_char_p),
+                ("selected_logprob", ctypes.c_float),
+                ("tokens", ctypes.c_char_p * logprobs_max),
+                ("logprobs", ctypes.POINTER(ctypes.c_float))]
 class last_logprobs_outputs(ctypes.Structure):
     _fields_ = [("count", ctypes.c_int),
-                ("selected_token", ctypes.POINTER(ctypes.c_char_p)),
-                ("selected_logprob", ctypes.POINTER(ctypes.c_float)),
-                ("tokens", ctypes.POINTER(5 * ctypes.c_char_p)),
-                ("logprobs", ctypes.POINTER(5 * ctypes.c_float))]
+                ("logprob_items", ctypes.POINTER(logprob_item))]

 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
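
Note: the ctypes Structure must mirror expose.h field for field: c_char_p * logprobs_max corresponds to const char * tokens[logprobs_max], and POINTER(logprob_item) lets lastlogprobs.logprob_items[i] index the C array up to count. A standalone layout sanity check, assuming both sides are built with default alignment on a 64-bit target (the size comment is illustrative):

    import ctypes

    logprobs_max = 5

    class logprob_item(ctypes.Structure):
        _fields_ = [("option_count", ctypes.c_int),               # int
                    ("selected_token", ctypes.c_char_p),          # const char *
                    ("selected_logprob", ctypes.c_float),         # float
                    ("tokens", ctypes.c_char_p * logprobs_max),   # const char *[5]
                    ("logprobs", ctypes.POINTER(ctypes.c_float))] # float *

    # 4 (int) + 4 (pad) + 8 (ptr) + 4 (float) + 4 (pad) + 40 (5 ptrs) + 8 (ptr)
    print(ctypes.sizeof(logprob_item))   # 72 on a typical x86-64 build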
@@ -190,6 +194,8 @@ class generation_inputs(ctypes.Structure):
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("stopreason", ctypes.c_int),
+                ("prompt_tokens", ctypes.c_int),
+                ("completion_tokens", ctypes.c_int),
                 ("text", ctypes.c_char_p)]

 class sd_load_model_inputs(ctypes.Structure):

@@ -896,7 +902,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     memory = genparams.get('memory', "")
     images = genparams.get('images', [])
     max_context_length = genparams.get('max_context_length', maxctx)
-    max_length = genparams.get('max_length', 180)
+    max_length = genparams.get('max_length', 200)
     temperature = genparams.get('temperature', 0.7)
     top_k = genparams.get('top_k', 100)
     top_a = genparams.get('top_a', 0.0)

@@ -1078,7 +1084,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
     if pendingabortkey!="" and pendingabortkey==genkey:
         print(f"\nDeferred Abort for GenKey: {pendingabortkey}")
         pendingabortkey = ""
-        return {"text":"","status":-1,"stopreason":-1}
+        return {"text":"","status":-1,"stopreason":-1, "prompt_tokens":0, "completion_tokens": 0}
     else:
         ret = handle.generate(inputs)
         outstr = ""

@@ -1089,7 +1095,7 @@ def generate(genparams, is_quiet=False, stream_flag=False):
         sindex = outstr.find(trim_str)
         if sindex != -1 and trim_str!="":
             outstr = outstr[:sindex]
-        return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}
+        return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason,"prompt_tokens":ret.prompt_tokens, "completion_tokens": ret.completion_tokens}

 def sd_load_model(model_filename,vae_filename,lora_filename):

@@ -1267,13 +1273,14 @@ def transform_genparams(genparams, api_format):
     if api_format==1:
         genparams["prompt"] = genparams.get('text', "")
         genparams["top_k"] = int(genparams.get('top_k', 120))
-        genparams["max_length"] = genparams.get('max', 180)
+        genparams["max_length"] = genparams.get('max', 200)

     elif api_format==2:
         pass

     elif api_format==3 or api_format==4:
-        genparams["max_length"] = genparams.get('max_tokens', (400 if api_format==4 else 180))
+        default_max_tok = (400 if api_format==4 else 200)
+        genparams["max_length"] = genparams.get('max_tokens', genparams.get('max_completion_tokens', default_max_tok))
         presence_penalty = genparams.get('presence_penalty', genparams.get('frequency_penalty', 0.0))
         genparams["presence_penalty"] = presence_penalty
         # openai allows either a string or a list as a stop sequence

@@ -1460,7 +1467,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):

             return generate(genparams=genparams,is_quiet=is_quiet,stream_flag=stream_flag)

-        genout = {"text": "", "status": -1, "stopreason": -1}
+        genout = {"text": "", "status": -1, "stopreason": -1, "prompt_tokens":0, "completion_tokens": 0}
         if stream_flag:
             loop = asyncio.get_event_loop()
             executor = ThreadPoolExecutor()

@@ -1469,8 +1476,45 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             genout = run_blocking()

         recvtxt = genout['text']
+        prompttokens = genout['prompt_tokens']
+        comptokens = genout['completion_tokens']
         currfinishreason = ("length" if (genout['stopreason'] != 1) else "stop")

+        # grab logprobs if not streaming
+        logprobsdict = None
+        if not stream_flag and ("logprobs" in genparams and genparams["logprobs"]):
+            lastlogprobs = handle.last_logprobs()
+            logprobsdict = {}
+            logprobsdict['content'] = []
+            logprobsdict['tokens'] = []
+            logprobsdict['token_logprobs'] = []
+            logprobsdict['top_logprobs'] = []
+            logprobsdict['text_offset'] = []
+            text_offset_counter = 0
+            for i in range(lastlogprobs.count):
+                lp_content_item = {}
+                logprob_item = lastlogprobs.logprob_items[i]
+                toptoken = ctypes.string_at(logprob_item.selected_token).decode("UTF-8","ignore")
+                logprobsdict['tokens'].append(toptoken)
+                lp_content_item['token'] = toptoken
+                logprobsdict['token_logprobs'].append(logprob_item.selected_logprob)
+                lp_content_item['logprob'] = logprob_item.selected_logprob
+                lp_content_item['bytes'] = list(toptoken.encode('utf-8'))
+                lp_content_item['top_logprobs'] = []
+                logprobsdict['text_offset'].append(text_offset_counter)
+                text_offset_counter += len(toptoken)
+                tops = {}
+                for j in range(min(logprob_item.option_count,logprobs_max)):
+                    tl_item = {}
+                    tl_item['logprob'] = logprob_item.logprobs[j]
+                    tokstr = ctypes.string_at(logprob_item.tokens[j]).decode("UTF-8","ignore")
+                    tops[tokstr] = logprob_item.logprobs[j]
+                    tl_item['token'] = tokstr
+                    tl_item['bytes'] = list(tokstr.encode('utf-8'))
+                    lp_content_item['top_logprobs'].append(tl_item)
+                logprobsdict['top_logprobs'].append(tops)
+                logprobsdict['content'].append(lp_content_item)
+
         # flag instance as non-idle for a while
         washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
         if not washordereq:

@@ -1484,8 +1528,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             res = {"data": {"seqs": [recvtxt]}}
         elif api_format == 3:
             res = {"id": "cmpl-A1", "object": "text_completion", "created": int(time.time()), "model": friendlymodelname,
-                "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-                "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
+                "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+                "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
         elif api_format == 4:
             using_openai_tools = genparams.get('using_openai_tools', False)
             tool_calls = []

@@ -1494,12 +1538,12 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             if tool_calls and len(tool_calls)>0:
                 recvtxt = None
             res = {"id": "chatcmpl-A1", "object": "chat.completion", "created": int(time.time()), "model": friendlymodelname,
-                "usage": {"prompt_tokens": 100, "completion_tokens": 100, "total_tokens": 200},
-                "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason}]}
+                "usage": {"prompt_tokens": prompttokens, "completion_tokens": comptokens, "total_tokens": (prompttokens+comptokens)},
+                "choices": [{"index": 0, "message": {"role": "assistant", "content": recvtxt, "tool_calls": tool_calls}, "finish_reason": currfinishreason, "logprobs":logprobsdict}]}
         elif api_format == 5:
             res = {"caption": end_trim_to_sentence(recvtxt)}
         else:
-            res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason}]}
+            res = {"results": [{"text": recvtxt, "finish_reason": currfinishreason, "logprobs":logprobsdict, "prompt_tokens": prompttokens, "completion_tokens": comptokens}]}

         try:
             return res
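
Note: end to end, enabling Request Logprobs in Lite sets submit_payload.params.logprobs, the server reads genparams["logprobs"], pulls handle.last_logprobs() after a non-streamed generation, and attaches the OpenAI-style dict to the response. The same path can be exercised directly; a sketch against a local instance (assumes the default port 5001 and that the server accepts the field at the top level of the request body, as the handler's genparams lookup suggests):

    import json, urllib.request

    payload = {
        "prompt": "The capital of France is",
        "max_length": 32,
        "logprobs": True,   # what Lite sets via request_logprobs
    }
    req = urllib.request.Request(
        "http://localhost:5001/api/v1/generate",   # assumed local endpoint
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        result = json.load(resp)["results"][0]

    print(result["text"])
    for item in (result["logprobs"] or {}).get("content", []):
        alts = ", ".join(f"{t['token']!r}: {t['logprob']:.2f}"
                         for t in item["top_logprobs"])
        print(f"{item['token']!r} ({item['logprob']:.2f}) <- {alts}")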