handle memory separately for kcpp

This commit is contained in:
Concedo 2023-11-07 17:15:14 +08:00
parent f277ed0e8c
commit fb3bcac368
4 changed files with 105 additions and 22 deletions

View file

@ -1388,6 +1388,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
stop_sequence.push_back(stopper);
}
}
// Keep the caller-supplied memory text apart from the prompt; it is
// tokenized on its own and merged with the prompt tokens further down.
std::string addedmemory = inputs.memory;
// Copy the remaining per-request settings into the generation params.
params.prompt = inputs.prompt;
params.seed = inputs.seed;
params.n_predict = inputs.max_length;
@ -1442,7 +1443,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// Turn the prompt text into tokens. The memory text, when present, gets
// its own token buffer so the two can be length-limited independently.
std::vector<int> embd_inp;
std::vector<int> embd_inp_mem; // tokens of the added memory; stays empty when no memory was supplied
TokenizeString(params.prompt, embd_inp, file_format);
if (!addedmemory.empty())
{
    TokenizeString(addedmemory, embd_inp_mem, file_format);
}
//truncate to front of the prompt if its too long
int32_t nctx = params.n_ctx;
@ -1461,6 +1467,46 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
//added special memory, overwrite if needed
// Merges the separately tokenized memory in front of the prompt tokens so
// that memory + prompt + params.n_predict fits inside nctx. The memory is
// trimmed from its front first, then the prompt is trimmed from its front.
if(!addedmemory.empty())
{
    // Do all length arithmetic in signed ints: mixing size_t with int hides
    // negative intermediate values behind unsigned wrap-around.
    const int ctx_limit = nctx;
    const int predict_len = params.n_predict;

    //remove bos token from prompt, it'll be taken from memory
    std::vector<int> bos;
    TokenizeString("", bos, file_format);
    if (!bos.empty() && !embd_inp.empty() && bos[0]==embd_inp[0]) {
        embd_inp.erase(embd_inp.begin());
    }

    //shorten memory if needed
    // Memory may use at most nctx - n_predict - 4 tokens; when it fills that
    // budget exactly, the prompt trim below leaves up to 4 prompt tokens.
    // The budget is clamped at 0: if n_predict + 4 exceeds nctx the unclamped
    // value is negative and the number of tokens to drop would be larger than
    // the memory itself, producing an out-of-range iterator.
    const int raw_budget = ctx_limit - predict_len - 4;
    const int mem_budget = (raw_budget > 0) ? raw_budget : 0;
    const int mem_tokens = static_cast<int>(embd_inp_mem.size());
    if (mem_tokens > mem_budget)
    {
        // keep only the most recent mem_budget tokens
        embd_inp_mem.erase(embd_inp_mem.begin(), embd_inp_mem.begin() + (mem_tokens - mem_budget));
        //replace bos into front if exists
        if(!bos.empty() && !embd_inp_mem.empty())
        {
            embd_inp_mem[0] = bos[0];
        }
    }

    //shorten main prompt by trimming the front if needed
    const int prompt_tokens = static_cast<int>(embd_inp.size());
    const int excess = static_cast<int>(embd_inp_mem.size()) + prompt_tokens + predict_len - ctx_limit;
    if (excess > 0)
    {
        if (excess <= prompt_tokens) {
            embd_inp.erase(embd_inp.begin(), embd_inp.begin() + excess);
        } else {
            // even dropping the whole prompt is not enough; keep only memory
            embd_inp.clear();
        }
    }

    //stick memory to front of prompt
    embd_inp.insert(embd_inp.begin(), embd_inp_mem.begin(), embd_inp_mem.end());
}
//determine how much npast we have to rewind from the current state
std::vector<gpt_vocab::id> embd;