context shift feature done

Concedo 2023-10-29 18:21:39 +08:00
parent 338d6c265d
commit 7924592a83
4 changed files with 41 additions and 18 deletions

@@ -78,6 +78,7 @@ static int n_threads = 4;
 static int n_blasthreads = 4;
 static int n_batch = 8;
 static bool useSmartContext = false;
+static bool useContextShift = false;
 static int blasbatchsize = 512;
 static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
 static std::string modelname;
@@ -647,7 +648,7 @@ void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_t
         current_context_tokens[i - diff] = current_context_tokens[i];
     }
-    printf("\n[Smart Context Pro: Erased %d tokens at position %d]", diff, trimstart+1);
+    printf("\n[Context Shifting: Erased %d tokens at position %d]", diff, trimstart+1);
     current_context_tokens.resize(current_context_tokens.size() - diff - 1);
 }
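
The hunk above renames the log line inside PurgeMissingTokens, the routine that performs the shift: once a span of tokens is detected as removed from the new prompt, the remaining cached tokens are slid down over the gap and the vector is shrunk. A minimal self-contained sketch of that compaction, using a hypothetical erase_token_span helper (not the project's code; the real routine, as the resize above shows, also drops one extra trailing token, presumably so the final token is re-submitted for evaluation to produce fresh logits):

    #include <cstdio>
    #include <vector>

    // Sketch only: remove `diff` tokens starting at `trimstart` from the
    // cached token list by shifting the tail down, mirroring the loop above.
    static void erase_token_span(std::vector<int> &tokens, size_t trimstart, size_t diff)
    {
        for (size_t i = trimstart + diff; i < tokens.size(); ++i)
        {
            tokens[i - diff] = tokens[i]; // slide each later token left by `diff`
        }
        printf("\n[Context Shifting: Erased %d tokens at position %d]",
               (int)diff, (int)(trimstart + 1));
        tokens.resize(tokens.size() - diff); // discard the duplicated tail
    }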
@@ -665,6 +666,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     n_batch = params.n_batch = inputs.batch_size;
     modelname = params.model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
+    useContextShift = inputs.use_contextshift;
     debugmode = inputs.debugmode;
     blasbatchsize = inputs.blasbatchsize;
     if(blasbatchsize<=0)
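
The use_contextshift read here implies a matching field on load_model_inputs, added in one of the other changed files not shown in this view. A hedged sketch of the assumed struct shape (only fields visible in this hunk are listed; types are approximated and the real struct has many more members):

    // Sketch of the assumed plumbing, not the project's actual definition.
    struct load_model_inputs
    {
        int batch_size;
        const char * model_filename;
        bool use_smartcontext;
        bool use_contextshift; // new: enables the context shift path
        int debugmode;
        int blasbatchsize;
    };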
@@ -1464,13 +1466,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     else
     {
         bool triggersc = useSmartContext;
-        if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
+        if(useContextShift && (file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON))
         {
-            if(useSmartContext)
-            {
-                PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
-                triggersc = false;
-            }
+            PurgeMissingTokens(llama_ctx_v4, current_context_tokens, embd_inp, inputs.max_length, nctx);
+            triggersc = false;
         }
         ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false);
         if(file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
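
The rewritten condition gates the purge on the new useContextShift flag instead of useSmartContext, and it still fires only for the GGUF llama/falcon formats, the ones backed by llama_ctx_v4 whose KV cache can be edited in place; for every other format triggersc stays true and ContextFastForward falls back to the old smart context behavior. At the KV-cache level, a purge of this kind amounts to two sequence operations; a sketch assuming the llama.cpp API of this period (llama_kv_cache_seq_rm / llama_kv_cache_seq_shift), not code from this commit:

    // Sketch only: evict positions [trimstart, trimstart + diff) from
    // sequence 0, then slide the remaining positions down by `diff` so the
    // cache stays contiguous and no re-evaluation of the tail is needed.
    static void shift_kv_cache(llama_context * ctx, int trimstart, int diff, int n_tokens)
    {
        llama_kv_cache_seq_rm   (ctx, 0, trimstart, trimstart + diff);
        llama_kv_cache_seq_shift(ctx, 0, trimstart + diff, n_tokens, -diff);
    }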
@@ -1717,7 +1716,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     if (!evalres)
     {
-        fprintf(stderr, "Failed to predict\n");
+        fprintf(stderr, "\nFailed to predict! Check your context buffer sizes!\n");
         snprintf(output.text, sizeof(output.text), "%s", "");
         output.status = 0;
         generation_finished = true;