fixes to stopper tokens, fixed BLAS mode for GPT2 and GPTJ, updated kobold lite

This commit is contained in:
Concedo 2023-04-16 21:54:18 +08:00
parent 6548d3b3fb
commit c757fbee1d
6 changed files with 17 additions and 14 deletions

View file

@ -157,9 +157,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
stop_sequence.clear();
for(int x=0;x<stop_token_max;++x)
{
if(inputs.stop_sequence[x]!="")
std::string stopper = inputs.stop_sequence[x];
if(stopper!="")
{
stop_sequence.push_back(inputs.stop_sequence[x]);
stop_sequence.push_back(stopper);
}
}
params.prompt = inputs.prompt;
@ -211,14 +212,16 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext);
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
// bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
// bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
bool blasmode = false;
bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
// bool blasmode = false;
int original_batch = params.n_batch;
int original_threads = params.n_threads;
if (blasmode)
{
params.n_batch = blasbatchsize; //received reports of 1024 and above crashing on some models
//for gpttype, GPT2 crashes above 256.
int bbs = (blasbatchsize>256?256:blasbatchsize);
params.n_batch = bbs; //received reports of 1024 and above crashing on some models
params.n_threads = 1;
}
@ -350,7 +353,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if (concat_output.find(matched) != std::string::npos)
{
remaining_tokens = 0;
printf("\n(Stop sequence triggered)");
printf("\n(Stop sequence triggered: %s)",matched.c_str());
break;
}
}