Fixed a few out-of-memory (OOM) errors with larger contexts. I cannot figure out why they happen, so I am forced to increase the buffer size.

This commit is contained in:
Concedo 2023-04-11 00:14:57 +08:00
parent f53238f570
commit 69b85f5b61
5 changed files with 25 additions and 27 deletions

View file

@ -49,6 +49,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
n_threads = params.n_threads = inputs.threads;
n_batch = params.n_batch = inputs.batch_size;
modelname = params.model = inputs.model_filename;
params.memory_f16 = inputs.f16_kv;
params.n_ctx = inputs.max_context_length;
model_v1.hparams.n_ctx = model_v2.hparams.n_ctx = model_gpt2_v1.hparams.n_ctx = model_gpt2_v2.hparams.n_ctx = params.n_ctx;
if (file_format == FileFormat::GPT2_1)
{
@ -153,6 +156,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
params.temp = inputs.temperature;
params.repeat_last_n = inputs.rep_pen_range;
params.repeat_penalty = inputs.rep_pen;
params.n_ctx = inputs.max_context_length;
params.n_batch = n_batch;
params.n_threads = n_threads;
@ -173,23 +177,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
//truncate to front of the prompt if it's too long
int32_t nctx = 512;
if(file_format == FileFormat::GPTJ_1||file_format == FileFormat::GPTJ_2)
{
nctx = model_v1.hparams.n_ctx;
}
else if(file_format==FileFormat::GPTJ_3)
{
nctx = model_v2.hparams.n_ctx;
}
else if(file_format==FileFormat::GPT2_1)
{
nctx = model_gpt2_v1.hparams.n_ctx;
}
else if(file_format==FileFormat::GPT2_2)
{
nctx = model_gpt2_v2.hparams.n_ctx;
}
int32_t nctx = params.n_ctx;
if (embd_inp.size() + params.n_predict > nctx)
{