expose stop reason in generation

This commit is contained in:
Concedo 2024-04-27 01:12:12 +08:00
parent 327682fb97
commit 4ec8a9c57b
4 changed files with 184 additions and 85 deletions

View file

@ -97,6 +97,7 @@ struct generation_inputs
struct generation_outputs struct generation_outputs
{ {
int status = -1; int status = -1;
int stopreason = stop_reason::INVALID;
const char * text; //response will now be stored in c++ allocated memory const char * text; //response will now be stored in c++ allocated memory
}; };
struct token_count_outputs struct token_count_outputs

View file

@ -1584,6 +1584,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("\nWarning: KCPP text generation not initialized!\n"); printf("\nWarning: KCPP text generation not initialized!\n");
output.text = nullptr; output.text = nullptr;
output.status = 0; output.status = 0;
output.stopreason = stop_reason::INVALID;
generation_finished = true; generation_finished = true;
return output; return output;
} }
@ -2125,6 +2126,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past); fprintf(stderr, "\nFailed to predict at %d! Check your context buffer sizes!\n",n_past);
output.text = nullptr; output.text = nullptr;
output.status = 0; output.status = 0;
output.stopreason = stop_reason::INVALID;
generation_finished = true; generation_finished = true;
return output; return output;
} }
@ -2334,6 +2336,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past); fprintf(stderr, "\nFailed to eval llava image at %d!\n",n_past);
output.text = nullptr; output.text = nullptr;
output.status = 0; output.status = 0;
output.stopreason = stop_reason::INVALID;
generation_finished = true; generation_finished = true;
return output; return output;
} }
@ -2344,6 +2347,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
fprintf(stderr, "\nLLAVA image tokens mismatch at %d! (%d vs %d tokens)\n",n_past,llavatokenscounted,llavatokensevaled); fprintf(stderr, "\nLLAVA image tokens mismatch at %d! (%d vs %d tokens)\n",n_past,llavatokenscounted,llavatokensevaled);
output.text = nullptr; output.text = nullptr;
output.status = 0; output.status = 0;
output.stopreason = stop_reason::INVALID;
generation_finished = true; generation_finished = true;
return output; return output;
} }
@ -2381,6 +2385,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
printf("\nCtxLimit: %d/%d, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second); printf("\nCtxLimit: %d/%d, Process:%.2fs (%.1fms/T = %.2fT/s), Generate:%.2fs (%.1fms/T = %.2fT/s), Total:%.2fs (%.2fT/s)",(int)current_context_tokens.size(),(int)nctx, time1, pt1, ts1, time2, pt2, ts2, (time1 + time2), tokens_per_second);
fflush(stdout); fflush(stdout);
output.status = 1; output.status = 1;
output.stopreason = last_stop_reason;
generation_finished = true; generation_finished = true;
last_eval_time = pt2; last_eval_time = pt2;
last_process_time = pt1; last_process_time = pt1;

View file

@ -7,7 +7,7 @@ Just copy this single static HTML file anywhere and open it in a browser, or fro
Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite. Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
If you are submitting a pull request for Lite, PLEASE use the above repo, not the KoboldCpp one. If you are submitting a pull request for Lite, PLEASE use the above repo, not the KoboldCpp one.
Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line. Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
Current version: 134 Current version: 135
-Concedo -Concedo
--> -->
@ -1402,6 +1402,14 @@ Current version: 134
margin: 2px; margin: 2px;
font-weight: bolder; font-weight: bolder;
} }
.wiarrowbtn
{
font-size: 12px;
height: 18px;
padding: 2px;
margin: 0px 1px 0px 1px;
font-weight: bolder;
}
.wiinputkeycol .wiinputkeycol
{ {
min-width: 70px; min-width: 70px;
@ -3065,6 +3073,9 @@ Current version: 134
if (custom_kobold_endpoint != "" && data && data.results != null && data.results.length > 0) { if (custom_kobold_endpoint != "" && data && data.results != null && data.results.length > 0) {
synchro_streaming_response += data.results[0].text; synchro_streaming_response += data.results[0].text;
synchro_streaming_tokens_left -= tokens_per_tick; synchro_streaming_tokens_left -= tokens_per_tick;
if (data.results[0].finish_reason == "stop") {
last_stop_reason = "stop";
}
//handle some early stopping criteria //handle some early stopping criteria
if (localsettings.opmode == 3) //stop on selfname found if (localsettings.opmode == 3) //stop on selfname found
@ -3213,6 +3224,10 @@ Current version: 134
for (let event of chunk) { for (let event of chunk) {
if (event.event === 'message') { if (event.event === 'message') {
synchro_pending_stream += event.data.token; synchro_pending_stream += event.data.token;
if(event.data.finish_reason=="stop")
{
last_stop_reason = "stop";
}
} }
} }
} }
@ -3614,6 +3629,7 @@ Current version: 134
var custom_claude_model = ""; var custom_claude_model = "";
var uses_cors_proxy = false; //we start off attempting a direct connection. switch to proxy if that fails var uses_cors_proxy = false; //we start off attempting a direct connection. switch to proxy if that fails
var synchro_polled_response = null; var synchro_polled_response = null;
var last_stop_reason = ""; //update stop reason if known
var synchro_pending_stream = ""; //used for token pseduo streaming for kobold api only var synchro_pending_stream = ""; //used for token pseduo streaming for kobold api only
var waiting_for_autosummary = false; var waiting_for_autosummary = false;
var italics_regex = new RegExp(/\*(\S[^*]+\S)\*/g); //the fallback regex var italics_regex = new RegExp(/\*(\S[^*]+\S)\*/g); //the fallback regex
@ -3633,8 +3649,8 @@ Current version: 134
var welcome = ""; var welcome = "";
var personal_notes = ""; var personal_notes = "";
var logitbiasdict = {}; var logitbiasdict = {};
var regexreplace_pattern = []; var regexreplace_data = [];
var regexreplace_replacement = []; const num_regex_rows = 4;
var localsettings = { var localsettings = {
my_api_key: "0000000000", //put here so it can be saved and loaded in persistent mode my_api_key: "0000000000", //put here so it can be saved and loaded in persistent mode
@ -5084,8 +5100,7 @@ Current version: 134
new_save_storyobj.wiinsertlocation = wi_insertlocation; new_save_storyobj.wiinsertlocation = wi_insertlocation;
new_save_storyobj.personal_notes = personal_notes; new_save_storyobj.personal_notes = personal_notes;
new_save_storyobj.logitbiasdict = JSON.parse(JSON.stringify(logitbiasdict)); new_save_storyobj.logitbiasdict = JSON.parse(JSON.stringify(logitbiasdict));
new_save_storyobj.regexreplace_pattern = JSON.parse(JSON.stringify(regexreplace_pattern)); new_save_storyobj.regexreplace_data = JSON.parse(JSON.stringify(regexreplace_data));
new_save_storyobj.regexreplace_replacement = JSON.parse(JSON.stringify(regexreplace_replacement));
if (export_settings) { if (export_settings) {
new_save_storyobj.savedsettings = JSON.parse(JSON.stringify(localsettings)); new_save_storyobj.savedsettings = JSON.parse(JSON.stringify(localsettings));
@ -5257,8 +5272,7 @@ Current version: 134
let old_current_wi = current_wi; let old_current_wi = current_wi;
let old_extrastopseq = extrastopseq; let old_extrastopseq = extrastopseq;
let old_notes = personal_notes; let old_notes = personal_notes;
let old_regexreplace_pattern = regexreplace_pattern; let old_regexreplace_data = regexreplace_data;
let old_regexreplace_replacement = regexreplace_replacement;
//determine if oldui file or newui file format //determine if oldui file or newui file format
restart_new_game(false); restart_new_game(false);
@ -5329,11 +5343,18 @@ Current version: 134
if (storyobj.personal_notes) { if (storyobj.personal_notes) {
personal_notes = storyobj.personal_notes; personal_notes = storyobj.personal_notes;
} }
if (storyobj.regexreplace_pattern) { //todo: remove temporary backwards compatibility for regex
regexreplace_pattern = storyobj.regexreplace_pattern; if (storyobj.regexreplace_pattern && storyobj.regexreplace_replacement) {
let pat = storyobj.regexreplace_pattern;
let rep = storyobj.regexreplace_replacement;
let ll = Math.min(pat.length,rep.length)
for(let i=0;i<ll;++i)
{
regexreplace_data.push({"p":pat[i],"r":rep[i],"b":false});
} }
if (storyobj.regexreplace_replacement) { }
regexreplace_replacement = storyobj.regexreplace_replacement; if (storyobj.regexreplace_data) {
regexreplace_data = storyobj.regexreplace_data;
} }
} else { } else {
//v2 load //v2 load
@ -5394,8 +5415,7 @@ Current version: 134
if(!loadstopseq) if(!loadstopseq)
{ {
extrastopseq = old_extrastopseq; extrastopseq = old_extrastopseq;
regexreplace_pattern = old_regexreplace_pattern; regexreplace_data = old_regexreplace_data;
regexreplace_replacement = old_regexreplace_replacement;
} }
if (storyobj.savedsettings && storyobj.savedsettings != "") if (storyobj.savedsettings && storyobj.savedsettings != "")
@ -8901,14 +8921,15 @@ Current version: 134
extrastopseq = document.getElementById("extrastopseq").value; extrastopseq = document.getElementById("extrastopseq").value;
newlineaftermemory = (document.getElementById("newlineaftermemory").checked?true:false); newlineaftermemory = (document.getElementById("newlineaftermemory").checked?true:false);
logitbiasdict = pendinglogitbias; logitbiasdict = pendinglogitbias;
regexreplace_pattern = []; regexreplace_data = [];
regexreplace_replacement = []; for(let i=0;i<num_regex_rows;++i)
for(let i=0;i<20;++i)
{ {
let v1 = ""; let v1 = "";
let v2 = ""; let v2 = "";
let bothways = false;
let box1 = document.getElementById("regexreplace_pattern"+i); let box1 = document.getElementById("regexreplace_pattern"+i);
let box2 = document.getElementById("regexreplace_replacement"+i); let box2 = document.getElementById("regexreplace_replacement"+i);
let bw = document.getElementById("regexreplace_bothways"+i).checked;
if(!box1 || !box2) if(!box1 || !box2)
{ {
break; break;
@ -8924,8 +8945,7 @@ Current version: 134
if(v1) if(v1)
{ {
regexreplace_pattern.push(v1); regexreplace_data.push({"p":v1,"r":v2,"b":bw});
regexreplace_replacement.push(v2);
} }
} }
@ -9081,6 +9101,7 @@ Current version: 134
pending_response_id = ""; pending_response_id = "";
poll_in_progress = false; poll_in_progress = false;
synchro_polled_response = null; synchro_polled_response = null;
last_stop_reason = "";
synchro_pending_stream = ""; synchro_pending_stream = "";
waiting_for_autosummary = false; waiting_for_autosummary = false;
horde_poll_nearly_completed = false; horde_poll_nearly_completed = false;
@ -9097,6 +9118,7 @@ Current version: 134
nextgeneratedimagemilestone = generateimagesinterval; nextgeneratedimagemilestone = generateimagesinterval;
pending_response_id = ""; pending_response_id = "";
synchro_polled_response = null; synchro_polled_response = null;
last_stop_reason = "";
synchro_pending_stream = ""; synchro_pending_stream = "";
waiting_for_autosummary = false; waiting_for_autosummary = false;
last_reply_was_empty = false; last_reply_was_empty = false;
@ -9127,8 +9149,7 @@ Current version: 134
wi_searchdepth = 0; wi_searchdepth = 0;
wi_insertlocation = 0; wi_insertlocation = 0;
current_anotetemplate = "[Author's note: <|>]"; current_anotetemplate = "[Author's note: <|>]";
regexreplace_pattern = []; regexreplace_data = [];
regexreplace_replacement = [];
} }
render_gametext(save); //necessary to trigger an autosave to wipe out current story in case they exit browser after newgame. render_gametext(save); //necessary to trigger an autosave to wipe out current story in case they exit browser after newgame.
} }
@ -9718,6 +9739,20 @@ Current version: 134
function submit_generation() { function submit_generation() {
let newgen = document.getElementById("input_text").value; let newgen = document.getElementById("input_text").value;
//apply regex transforms
if(regexreplace_data && regexreplace_data.length>0)
{
for(let i=0;i<regexreplace_data.length;++i)
{
if(regexreplace_data[i].b && regexreplace_data[i].p!="")
{
let pat = new RegExp(regexreplace_data[i].p, "gm");
newgen = newgen.replace(pat, regexreplace_data[i].r);
}
}
}
const user_input_empty = (newgen.trim()==""); const user_input_empty = (newgen.trim()=="");
let doNotGenerate = false; let doNotGenerate = false;
pending_context_postinjection = ""; pending_context_postinjection = "";
@ -10349,6 +10384,7 @@ Current version: 134
poll_ticks_passed = 0; poll_ticks_passed = 0;
poll_in_progress = false; poll_in_progress = false;
synchro_polled_response = null; synchro_polled_response = null;
last_stop_reason = "";
synchro_pending_stream = ""; synchro_pending_stream = "";
//if this is set, we don't use horde, use the custom endpoint instead //if this is set, we don't use horde, use the custom endpoint instead
@ -11420,7 +11456,11 @@ Current version: 134
gentxt = trim_extra_stop_seqs(gentxt,true); gentxt = trim_extra_stop_seqs(gentxt,true);
//always trim incomplete sentences for adventure and chat (if not multiline) //always trim incomplete sentences for adventure and chat (if not multiline)
if (localsettings.opmode == 2 || (localsettings.opmode == 3 && !localsettings.allow_continue_chat) || localsettings.trimsentences == true) { //do not trim if instruct mode AND stop token reached
let donottrim = ((localsettings.opmode == 4||localsettings.opmode == 3) && last_stop_reason=="stop");
if (!donottrim && (localsettings.opmode == 2
|| (localsettings.opmode == 3 && !localsettings.allow_continue_chat)
|| localsettings.trimsentences == true)) {
gentxt = end_trim_to_sentence(gentxt,true); gentxt = end_trim_to_sentence(gentxt,true);
} }
@ -11428,14 +11468,14 @@ Current version: 134
gentxt = trim_extra_stop_seqs(gentxt,false); gentxt = trim_extra_stop_seqs(gentxt,false);
//apply regex transform //apply regex transform
if(regexreplace_pattern && regexreplace_pattern.length>0) if(regexreplace_data && regexreplace_data.length>0)
{ {
for(let i=0;i<regexreplace_pattern.length;++i) for(let i=0;i<regexreplace_data.length;++i)
{ {
if(regexreplace_pattern[i]!="") if(regexreplace_data[i].p!="")
{ {
let pat = new RegExp(regexreplace_pattern[i], "gm"); let pat = new RegExp(regexreplace_data[i].p, "gm");
gentxt = gentxt.replace(pat, regexreplace_replacement[i]); gentxt = gentxt.replace(pat, regexreplace_data[i].r);
} }
} }
} }
@ -12001,6 +12041,7 @@ Current version: 134
} }
} }
synchro_polled_response = null; synchro_polled_response = null;
last_stop_reason = "";
synchro_pending_stream = ""; synchro_pending_stream = "";
show_abort_button(false); show_abort_button(false);
render_gametext(); render_gametext();
@ -13139,26 +13180,56 @@ Current version: 134
backup_wi(); backup_wi();
update_wi(); update_wi();
for(let i=0;i<20;++i) //setup regex replacers
populate_regex_replacers();
document.getElementById("btnlogitbias").disabled = !is_using_custom_ep();
}
function populate_regex_replacers()
{
let regextablehtml = `
<tr>
<th>Pattern <span class="helpicon">?<span class="helptext">The regex pattern to match against any incoming text. Leave blank to disable.</span></span></th>
<th>Replacement <span class="helpicon">?<span class="helptext">The string to replace matches with. Capture groups are allowed (e.g. $1). To remove all matches, leave this blank.</span></span></th>
<th>Both Ways <span class="helpicon">?<span class="helptext">If enabled, regex applies for both inputs and outputs, otherwise output only.</span></span></th>
</tr>`;
let regextable = document.getElementById("regex_replace_table");
for(let i=0;i<num_regex_rows;++i)
{
regextablehtml += `
<tr>
<td><input class="settinglabel miniinput" type="text" placeholder="(Inactive)" value="" id="regexreplace_pattern${i}"></td>
<td><input class="settinglabel miniinput" type="text" placeholder="(Remove)" value="" id="regexreplace_replacement${i}"></td>
<td><input type="checkbox" id="regexreplace_bothways${i}" style="margin:0px 0 0;"></td>
</tr>
`;
}
regextable.innerHTML = regextablehtml;
for(let i=0;i<num_regex_rows;++i)
{ {
let a1 = document.getElementById("regexreplace_pattern"+i); let a1 = document.getElementById("regexreplace_pattern"+i);
let a2 = document.getElementById("regexreplace_replacement"+i); let a2 = document.getElementById("regexreplace_replacement"+i);
if(a1 && a2) let a3 = document.getElementById("regexreplace_bothways"+i);
if(a1 && a2 && a3)
{ {
if(i<regexreplace_pattern.length) if(i<regexreplace_data.length)
{ {
a1.value = regexreplace_pattern[i]; a1.value = regexreplace_data[i].p;
a2.value = regexreplace_replacement[i]; a2.value = regexreplace_data[i].r;
a3.checked = (regexreplace_data[i].b?true:false);
} }
else else
{ {
a1.value = a2.value = ""; a1.value = a2.value = "";
a3.checked = false;
} }
} }
} }
document.getElementById("btnlogitbias").disabled = !is_using_custom_ep();
} }
function toggle_wi_sk(idx) { function toggle_wi_sk(idx) {
@ -13197,6 +13268,28 @@ Current version: 134
update_wi(); update_wi();
} }
// Move the world-info entry at index `idx` one position up (towards the
// front of current_wi) by swapping it with its predecessor.
// No-op when idx is 0 or out of range. Saves pending edits first and
// re-renders the list afterwards.
function up_wi(idx) {
    save_wi(); // persist any in-progress WI edits before reordering
    if (idx > 0 && idx < current_wi.length) {
        const temp = current_wi[idx - 1];
        current_wi[idx - 1] = current_wi[idx];
        current_wi[idx] = temp;
    }
    update_wi(); // redraw the WI table with the new ordering
}
// Move the world-info entry at index `idx` one position down (towards the
// end of current_wi) by swapping it with its successor.
// No-op when idx is the last entry or out of range. Saves pending edits
// first and re-renders the list afterwards.
function down_wi(idx) {
    save_wi(); // persist any in-progress WI edits before reordering
    if (idx >= 0 && idx + 1 < current_wi.length) {
        const temp = current_wi[idx + 1];
        current_wi[idx + 1] = current_wi[idx];
        current_wi[idx] = temp;
    }
    update_wi(); // redraw the WI table with the new ordering
}
function add_wi() { function add_wi() {
save_wi(); save_wi();
var ne = { var ne = {
@ -13263,8 +13356,10 @@ Current version: 134
let probarr = [100,90,75,50,25,10,5,1]; let probarr = [100,90,75,50,25,10,5,1];
selectionhtml += `<tr class='`+ (ishidden?"hidden":"") +`' id="wirow` + i + `"><td class="col-8" style="font-size: 10px;">` + selectionhtml += `<tr class='`+ (ishidden?"hidden":"") +`' id="wirow` + i + `"><td class="col-8" style="font-size: 10px;">`
`<button type="button" class="btn btn-danger widelbtn" id="widel` + i + `" onclick="return del_wi(` + i + `)">X</button></td>` + +`<button type="button" class="btn btn-danger widelbtn" id="widel` + i + `" onclick="return del_wi(` + i + `)">X</button></td>`
+`<td><button type="button" class="btn btn-primary wiarrowbtn" id="wiup` + i + `" onclick="return up_wi(` + i + `)">▲</button>`
+`<button type="button" class="btn btn-primary wiarrowbtn" id="widown` + i + `" onclick="return down_wi(` + i + `)">▼</button></td>` +
`<td class="col-6 wiinputkeycol"> `<td class="col-6 wiinputkeycol">
<input class="form-control wiinputkey" id="wikey`+ i + `" placeholder="Key(s)" value="` + winame + `"> <input class="form-control wiinputkey" id="wikey`+ i + `" placeholder="Key(s)" value="` + winame + `">
<input class="form-control wiinputkey `+ (curr.selective ? `` : `hidden`) + `" id="wikeysec` + i + `" placeholder="Sec. Key(s)" value="` + wisec + `">` + `</td> <input class="form-control wiinputkey `+ (curr.selective ? `` : `hidden`) + `" id="wikeysec` + i + `" placeholder="Sec. Key(s)" value="` + wisec + `">` + `</td>
@ -15256,25 +15351,10 @@ Current version: 134
<div><button type="button" class="btn btn-primary" style="width:134px;padding:6px 6px;" id="btnlogitbias" onclick="set_logit_bias()">Edit Logit Biases</button></div> <div><button type="button" class="btn btn-primary" style="width:134px;padding:6px 6px;" id="btnlogitbias" onclick="set_logit_bias()">Edit Logit Biases</button></div>
<div class="settinglabel"> <div class="settinglabel">
<div class="justifyleft"><br>Custom Regex Replace <span class="helpicon">?<span <div class="justifyleft"><br>Custom Regex Replace <span class="helpicon">?<span
class="helptext">Allows transforming incoming text with up to 3 regex patterns, modifying all matches. Replacements will be applied in sequence.</span></span></div> class="helptext">Allows transforming incoming text with regex patterns, modifying all matches. Replacements will be applied in sequence.</span></span></div>
</div> </div>
<table class="settinglabel text-center" style="border-spacing: 3px 2px; border-collapse: separate;"> <table id="regex_replace_table" class="settinglabel text-center" style="border-spacing: 3px 2px; border-collapse: separate;">
<tr>
<th>Pattern <span class="helpicon">?<span class="helptext">The regex pattern to match against any incoming text. Leave blank to disable.</span></span></th>
<th>Replacement <span class="helpicon">?<span class="helptext">The string to replace matches with. Capture groups are allowed (e.g. $1). To remove all matches, leave this blank.</span></span></th>
</tr>
<tr>
<td><input class="settinglabel miniinput" type="text" placeholder="(Inactive)" value="" id="regexreplace_pattern0"></td>
<td><input class="settinglabel miniinput" type="text" placeholder="(Remove)" value="" id="regexreplace_replacement0"></td>
</tr>
<tr>
<td><input class="settinglabel miniinput" type="text" placeholder="(Inactive)" value="" id="regexreplace_pattern1"></td>
<td><input class="settinglabel miniinput" type="text" placeholder="(Remove)" value="" id="regexreplace_replacement1"></td>
</tr>
<tr>
<td><input class="settinglabel miniinput" type="text" placeholder="(Inactive)" value="" id="regexreplace_pattern2"></td>
<td><input class="settinglabel miniinput" type="text" placeholder="(Remove)" value="" id="regexreplace_replacement2"></td>
</tr>
</table> </table>
</div> </div>

View file

@ -95,6 +95,7 @@ class generation_inputs(ctypes.Structure):
class generation_outputs(ctypes.Structure): class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int), _fields_ = [("status", ctypes.c_int),
("stopreason", ctypes.c_int),
("text", ctypes.c_char_p)] ("text", ctypes.c_char_p)]
class sd_load_model_inputs(ctypes.Structure): class sd_load_model_inputs(ctypes.Structure):
@ -493,7 +494,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
if pendingabortkey!="" and pendingabortkey==genkey: if pendingabortkey!="" and pendingabortkey==genkey:
print(f"\nDeferred Abort for GenKey: {pendingabortkey}") print(f"\nDeferred Abort for GenKey: {pendingabortkey}")
pendingabortkey = "" pendingabortkey = ""
return "" return {"text":"","status":-1,"stopreason":-1}
else: else:
ret = handle.generate(inputs) ret = handle.generate(inputs)
outstr = "" outstr = ""
@ -504,7 +505,7 @@ def generate(prompt, memory="", images=[], max_length=32, max_context_length=512
sindex = outstr.find(trim_str) sindex = outstr.find(trim_str)
if sindex != -1 and trim_str!="": if sindex != -1 and trim_str!="":
outstr = outstr[:sindex] outstr = outstr[:sindex]
return outstr return {"text":outstr,"status":ret.status,"stopreason":ret.stopreason}
def sd_load_model(model_filename): def sd_load_model(model_filename):
@ -656,6 +657,7 @@ nocertify = False
start_time = time.time() start_time = time.time()
last_req_time = time.time() last_req_time = time.time()
last_non_horde_req_time = time.time() last_non_horde_req_time = time.time()
currfinishreason = "null"
def transform_genparams(genparams, api_format): def transform_genparams(genparams, api_format):
#alias all nonstandard alternative names for rep pen. #alias all nonstandard alternative names for rep pen.
@ -765,8 +767,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
async def generate_text(self, genparams, api_format, stream_flag): async def generate_text(self, genparams, api_format, stream_flag):
from datetime import datetime from datetime import datetime
global friendlymodelname, chatcompl_adapter global friendlymodelname, chatcompl_adapter, currfinishreason
is_quiet = args.quiet is_quiet = args.quiet
currfinishreason = "null"
def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
@ -812,13 +815,16 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
render_special=genparams.get('render_special', False), render_special=genparams.get('render_special', False),
) )
recvtxt = "" genout = {"text":"","status":-1,"stopreason":-1}
if stream_flag: if stream_flag:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
executor = ThreadPoolExecutor() executor = ThreadPoolExecutor()
recvtxt = await loop.run_in_executor(executor, run_blocking) genout = await loop.run_in_executor(executor, run_blocking)
else: else:
recvtxt = run_blocking() genout = run_blocking()
recvtxt = genout['text']
currfinishreason = ("length" if (genout['stopreason']!=1) else "stop")
#flag instance as non-idle for a while #flag instance as non-idle for a while
washordereq = genparams.get('genkey', '').startswith('HORDEREQ_') washordereq = genparams.get('genkey', '').startswith('HORDEREQ_')
@ -834,15 +840,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
elif api_format==3: elif api_format==3:
res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname, res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": friendlymodelname,
"usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200}, "usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
"choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]} "choices": [{"text": recvtxt, "index": 0, "finish_reason": currfinishreason}]}
elif api_format==4: elif api_format==4:
res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname, res = {"id": "chatcmpl-1", "object": "chat.completion", "created": 1, "model": friendlymodelname,
"usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200}, "usage": {"prompt_tokens": 100,"completion_tokens": 100,"total_tokens": 200},
"choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": "length"}]} "choices": [{"index": 0, "message":{"role": "assistant", "content": recvtxt,}, "finish_reason": currfinishreason}]}
elif api_format==5: elif api_format==5:
res = {"caption": end_trim_to_sentence(recvtxt)} res = {"caption": end_trim_to_sentence(recvtxt)}
else: else:
res = {"results": [{"text": recvtxt}]} res = {"results": [{"text": recvtxt, "finish_reason":currfinishreason}]}
try: try:
return res return res
@ -863,7 +869,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
self.wfile.flush() self.wfile.flush()
async def handle_sse_stream(self, genparams, api_format): async def handle_sse_stream(self, genparams, api_format):
global friendlymodelname global friendlymodelname, currfinishreason
self.send_response(200) self.send_response(200)
self.send_header("cache-control", "no-cache") self.send_header("cache-control", "no-cache")
self.send_header("connection", "keep-alive") self.send_header("connection", "keep-alive")
@ -877,6 +883,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
tokenReserve = "" #keeps fully formed tokens that we cannot send out yet tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
while True: while True:
streamDone = handle.has_finished() #exit next loop on done streamDone = handle.has_finished() #exit next loop on done
if streamDone:
sr = handle.get_last_stop_reason()
currfinishreason = ("length" if (sr!=1) else "stop")
tokenStr = "" tokenStr = ""
streamcount = handle.get_stream_count() streamcount = handle.get_stream_count()
while current_token < streamcount: while current_token < streamcount:
@ -893,13 +902,14 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
incomplete_token_buffer.clear() incomplete_token_buffer.clear()
tokenStr += tokenSeg tokenStr += tokenSeg
if tokenStr!="": if tokenStr!="" or streamDone:
sseq = genparams.get('stop_sequence', []) sseq = genparams.get('stop_sequence', [])
trimstop = genparams.get('trim_stop', False) trimstop = genparams.get('trim_stop', False)
if trimstop and not streamDone and string_contains_sequence_substring(tokenStr,sseq): if trimstop and not streamDone and string_contains_sequence_substring(tokenStr,sseq):
tokenReserve += tokenStr tokenReserve += tokenStr
await asyncio.sleep(async_sleep_short) #if a stop sequence could trigger soon, do not send output await asyncio.sleep(async_sleep_short) #if a stop sequence could trigger soon, do not send output
else: else:
if tokenStr!="":
tokenStr = tokenReserve + tokenStr tokenStr = tokenReserve + tokenStr
tokenReserve = "" tokenReserve = ""
@ -910,15 +920,15 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
if sindex != -1 and trim_str!="": if sindex != -1 and trim_str!="":
tokenStr = tokenStr[:sindex] tokenStr = tokenStr[:sindex]
if tokenStr!="": if tokenStr!="" or streamDone:
if api_format == 4: # if oai chat, set format to expected openai streaming response if api_format == 4: # if oai chat, set format to expected openai streaming response
event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","delta":{'role':'assistant','content':tokenStr}}]}) event_str = json.dumps({"id":"koboldcpp","object":"chat.completion.chunk","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":currfinishreason,"delta":{'role':'assistant','content':tokenStr}}]})
await self.send_oai_sse_event(event_str) await self.send_oai_sse_event(event_str)
elif api_format == 3: # non chat completions elif api_format == 3: # non chat completions
event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":"length","text":tokenStr}]}) event_str = json.dumps({"id":"koboldcpp","object":"text_completion","created":1,"model":friendlymodelname,"choices":[{"index":0,"finish_reason":currfinishreason,"text":tokenStr}]})
await self.send_oai_sse_event(event_str) await self.send_oai_sse_event(event_str)
else: else:
event_str = json.dumps({"token": tokenStr}) event_str = json.dumps({"token": tokenStr, "finish_reason":currfinishreason})
await self.send_kai_sse_event(event_str) await self.send_kai_sse_event(event_str)
tokenStr = "" tokenStr = ""
else: else:
@ -3159,7 +3169,8 @@ def main(launch_args,start_server=True):
benchprompt = "11111111" benchprompt = "11111111"
for i in range(0,10): #generate massive prompt for i in range(0,10): #generate massive prompt
benchprompt += benchprompt benchprompt += benchprompt
result = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True) genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True)
result = genout['text']
result = (result[:5] if len(result)>5 else "") result = (result[:5] if len(result)>5 else "")
resultok = (result=="11111") resultok = (result=="11111")
t_pp = float(handle.get_last_process_time())*float(benchmaxctx-benchlen)*0.001 t_pp = float(handle.get_last_process_time())*float(benchmaxctx-benchlen)*0.001
@ -3212,7 +3223,9 @@ def run_in_queue(launch_args, input_queue, output_queue):
data = input_queue.get() data = input_queue.get()
if data['command'] == 'generate': if data['command'] == 'generate':
(args, kwargs) = data['data'] (args, kwargs) = data['data']
output_queue.put({'command': 'generated text', 'data': generate(*args, **kwargs)}) genout = generate(*args, **kwargs)
result = genout['text']
output_queue.put({'command': 'generated text', 'data': result})
time.sleep(0.2) time.sleep(0.2)
def start_in_seperate_process(launch_args): def start_in_seperate_process(launch_args):