new gpt2 format supported

This commit is contained in:
Concedo 2023-04-08 17:35:36 +08:00
parent 1369b46bb7
commit d8e37bfe75
12 changed files with 962 additions and 51 deletions

View file

@ -20,7 +20,7 @@
ModelLoadResult legacy_gptj_model_load(const std::string & fname, gptj_model_v1 & model, gpt_vocab & vocab, FileFormat file_format) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
bool super_old_format = (file_format==FileFormat::GPTJ1);
bool super_old_format = (file_format==FileFormat::GPTJ_1);
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
@ -372,7 +372,7 @@ bool legacy_gptj_eval(
size_t & mem_per_token,
FileFormat file_format) {
bool super_old_format = (file_format==FileFormat::GPTJ1);
bool super_old_format = (file_format==FileFormat::GPTJ_1);
const int N = embd_inp.size();
const auto & hparams = model.hparams;