Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 09:04:36 +00:00)

Commit: d8e37bfe75 - new gpt2 format supported
Parent: 1369b46bb7
12 changed files with 962 additions and 51 deletions
@@ -16,13 +16,15 @@
 #include "otherarch/gptj_v1.cpp"
 #include "otherarch/gptj_v2.cpp"
 #include "otherarch/gpt2_v1.cpp"
+#include "otherarch/gpt2_v2.cpp"
 
 //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
 static FileFormat file_format = FileFormat::BADFORMAT;
 static gpt_vocab vocab;
 static gptj_model_v1 model_v1;
 static gptj_model model_v2;
-static gpt2_model model_gpt2;
+static gpt2_v1_model model_gpt2_v1;
+static gpt2_model model_gpt2_v2;
 static gpt_params params;
 static int n_past = 0;
 static int n_threads = 4;
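The dispatch in the hunks below keys off FileFormat values (GPT2_1, GPT2_2, GPTJ_1, GPTJ_2, GPTJ_3, BADFORMAT). The enum itself is defined elsewhere in the repository and is not part of this diff; a minimal sketch consistent with the values used here, with ordering and numeric values assumed, would be:

    // Sketch only: the real enum lives in another header of this repository
    // and may have more members and a different ordering.
    enum class FileFormat
    {
        BADFORMAT = 0, // detection failed or not yet detected
        GPTJ_1,        // legacy GPT-J ggml format (may need tensor transposition)
        GPTJ_2,        // second legacy GPT-J revision
        GPTJ_3,        // current GPT-J format
        GPT2_1,        // legacy GPT-2 format, handled by legacy_gpt2_model_load / legacy_gpt2_eval
        GPT2_2,        // new GPT-2 format added by this commit, handled by gpt2_model_load / gpt2_eval
    };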
@@ -42,19 +44,41 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     n_batch = params.n_batch = inputs.batch_size;
     modelname = params.model = inputs.model_filename;
 
-    if (file_format == FileFormat::GPT2)
+    if (file_format == FileFormat::GPT2_1)
     {
-        ModelLoadResult res = gpt2_model_load(params.model, model_gpt2, vocab, file_format);
+        ModelLoadResult res = legacy_gpt2_model_load(params.model, model_gpt2_v1, vocab, file_format);
         if(res==ModelLoadResult::FAIL)
         {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return res;
         }
+        else if(res==ModelLoadResult::RETRY_LOAD)
+        {
+            printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
+            return res;
+        }
         // determine the required inference memory per token:
-        gpt2_eval(model_gpt2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
+        legacy_gpt2_eval(model_gpt2_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
         return ModelLoadResult::SUCCESS;
     }
-    else if (file_format == FileFormat::GPTJ1 || file_format == FileFormat::GPTJ2)
+    else if (file_format == FileFormat::GPT2_2)
+    {
+        ModelLoadResult res = gpt2_model_load(params.model, model_gpt2_v2, vocab, file_format);
+        if(res==ModelLoadResult::FAIL)
+        {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return res;
+        }
+        else if(res==ModelLoadResult::RETRY_LOAD)
+        {
+            printf("\nTensor Transposition Detected! Retrying GPT-2 model loading...");
+            return res;
+        }
+        // determine the required inference memory per token:
+        gpt2_eval(model_gpt2_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
+        return ModelLoadResult::SUCCESS;
+    }
+    else if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2)
     {
         ModelLoadResult res = legacy_gptj_model_load(params.model, model_v1, vocab, file_format);
         if(res==ModelLoadResult::FAIL)
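A RETRY_LOAD result is passed back to the caller of gpttype_load_model rather than handled here, so the caller is expected to invoke the loader again (the "Tensor Transposition Detected" case). A hedged sketch of that caller-side pattern, where the helper name and the choice of alternate format are assumptions rather than code from this commit:

    // Hypothetical caller-side retry wrapper (not part of this diff).
    static ModelLoadResult load_with_retry(const load_model_inputs & inputs, FileFormat fmt)
    {
        ModelLoadResult res = gpttype_load_model(inputs, fmt);
        if (res == ModelLoadResult::RETRY_LOAD)
        {
            // assumed fallback: try the other GPT-2 variant once
            FileFormat alt = (fmt == FileFormat::GPT2_2) ? FileFormat::GPT2_1 : FileFormat::GPT2_2;
            res = gpttype_load_model(inputs, alt);
        }
        return res;
    }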
@@ -125,17 +149,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
     //truncate to front of the prompt if its too long
     int32_t nctx = 512;
-    if(file_format == FileFormat::GPTJ1||file_format == FileFormat::GPTJ2)
+    if(file_format == FileFormat::GPTJ_1||file_format == FileFormat::GPTJ_2)
     {
         nctx = model_v1.hparams.n_ctx;
     }
-    else if(file_format==FileFormat::GPTJ3)
+    else if(file_format==FileFormat::GPTJ_3)
     {
         nctx = model_v2.hparams.n_ctx;
     }
-    else if(file_format==FileFormat::GPT2)
+    else if(file_format==FileFormat::GPT2_1)
     {
-        nctx = model_gpt2.hparams.n_ctx;
+        nctx = model_gpt2_v1.hparams.n_ctx;
     }
+    else if(file_format==FileFormat::GPT2_2)
+    {
+        nctx = model_gpt2_v2.hparams.n_ctx;
+    }
 
     if (embd_inp.size() + params.n_predict > nctx)
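The nctx chosen above bounds prompt length plus requested generation, and the comparison that follows guards against overflowing the model's context window. The truncation itself happens outside this hunk; a hedged sketch of what front-truncation of the prompt can look like, assuming the prompt is a std::vector of token ids:

    #include <cstdint>
    #include <vector>

    // Sketch: drop tokens from the front of the prompt so that the prompt plus
    // n_predict newly generated tokens still fits within nctx. Illustrative only.
    static void truncate_prompt_front(std::vector<int> & embd_inp, int n_predict, int32_t nctx)
    {
        if ((int)embd_inp.size() + n_predict > nctx)
        {
            int keep = nctx - n_predict;      // how many prompt tokens can stay
            if (keep < 0) keep = 0;
            embd_inp.erase(embd_inp.begin(), embd_inp.end() - keep);
        }
    }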
@@ -198,17 +226,21 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     double time1 = 0, time2 = 0;
     unsigned int embd_inp_size = embd_inp.size();
     int32_t n_vocab = 0;
-    if(file_format == FileFormat::GPTJ1||file_format == FileFormat::GPTJ2)
+    if(file_format == FileFormat::GPTJ_1||file_format == FileFormat::GPTJ_2)
     {
         n_vocab = model_v1.hparams.n_vocab;
     }
-    else if(file_format == FileFormat::GPTJ3)
+    else if(file_format == FileFormat::GPTJ_3)
     {
         n_vocab = model_v2.hparams.n_vocab;
     }
-    else if(file_format == FileFormat::GPT2)
+    else if(file_format == FileFormat::GPT2_1)
     {
-        n_vocab = model_gpt2.hparams.n_vocab;
+        n_vocab = model_gpt2_v1.hparams.n_vocab;
     }
+    else if(file_format == FileFormat::GPT2_2)
+    {
+        n_vocab = model_gpt2_v2.hparams.n_vocab;
+    }
     else
     {
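n_vocab is read from whichever model struct was loaded and gives the number of entries the logits buffer holds for each evaluated token. As a small illustration (assumed, not taken from this file), a greedy pick over such a buffer looks like:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Sketch: choose the highest-scoring token id among the first n_vocab logits.
    // The sampler in gpttype_generate is expected to apply the configured
    // sampling parameters instead of a plain argmax.
    static int pick_greedy_token(const std::vector<float> & logits, int32_t n_vocab)
    {
        auto first = logits.begin();
        auto last  = logits.begin() + n_vocab;
        return (int)std::distance(first, std::max_element(first, last));
    }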
@@ -236,11 +268,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         bool evalres = false;
 
         //print_tok_vec(logits);
-        if(file_format==FileFormat::GPT2)
+        if(file_format==FileFormat::GPT2_1)
         {
-            evalres = gpt2_eval(model_gpt2, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
+            evalres = legacy_gpt2_eval(model_gpt2_v1, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
         }
-        else if(file_format==FileFormat::GPTJ1 || file_format==FileFormat::GPTJ2)
+        else if(file_format==FileFormat::GPT2_2)
+        {
+            evalres = gpt2_eval(model_gpt2_v2, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
+        }
+        else if(file_format==FileFormat::GPTJ_1 || file_format==FileFormat::GPTJ_2)
         {
             evalres = legacy_gptj_eval(model_v1, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
         }
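The evalres dispatch above is the inner step of the generation loop: each iteration feeds the pending tokens to the eval function matching the detected format, advances n_past, and then samples from the refreshed logits. A condensed sketch of that loop shape, using the names from this diff but with the loop bookkeeping and the sampling helper assumed:

    // Simplified sketch of the surrounding generation loop; error handling,
    // the other back-ends and the real sampling code are omitted.
    while (remaining_tokens > 0)
    {
        bool evalres = false;
        if (file_format == FileFormat::GPT2_1)
            evalres = legacy_gpt2_eval(model_gpt2_v1, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
        else if (file_format == FileFormat::GPT2_2)
            evalres = gpt2_eval(model_gpt2_v2, params.n_threads, n_past, embd, logits, mem_per_token, file_format);

        if (!evalres)
            break;                                 // evaluation failed, stop generating
        n_past += embd.size();                     // tokens are now part of the model state
        embd.clear();
        embd.push_back(sample_next_token(logits)); // assumed sampling helper
        --remaining_tokens;
    }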