New GPT-2 model format supported

Concedo 2023-04-08 17:35:36 +08:00
parent 1369b46bb7
commit d8e37bfe75
12 changed files with 962 additions and 51 deletions


@@ -17,7 +17,7 @@
 // load the model's weights from a file
-ModelLoadResult gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, FileFormat file_format) {
+ModelLoadResult legacy_gpt2_model_load(const std::string & fname, gpt2_v1_model & model, gpt_vocab & vocab, FileFormat file_format) {
     printf("%s: loading model from '%s'\n", __func__, fname.c_str());
     auto fin = std::ifstream(fname, std::ios::binary);
@@ -267,9 +267,19 @@ ModelLoadResult gpt2_model_load(const std::string & fname, gpt2_model & model, g
         }
         if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
-            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                    __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
-            return ModelLoadResult::FAIL;
+            // test for transposition and retry with the older loader
+            if (tensor->ne[0] == ne[1] && tensor->ne[1] == ne[0] && should_transpose_layer(name))
+            {
+                printf("\nFound a transposed tensor. This could be an older or newer model. Retrying load...");
+                ggml_v1_free(ctx);
+                return ModelLoadResult::RETRY_LOAD;
+            }
+            else
+            {
+                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
+                        __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
+                return ModelLoadResult::FAIL;
+            }
         }
         const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_v1_fp16_t);
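
Note on should_transpose_layer(): the new branch calls it, but its body is not part of this hunk. A minimal sketch of what such a check could look like, with a purely hypothetical suffix list (older GPT-2 exports historically stored these projection weights transposed):

#include <string>
#include <vector>

// Illustrative sketch only; the real implementation lives elsewhere in
// the repository and its suffix list may differ.
static bool should_transpose_layer(const std::string & name) {
    // weight names whose tensors are assumed to be stored transposed
    static const std::vector<std::string> suffixes = {
        "/attn/c_attn/w", "/attn/c_proj/w", "/mlp/c_fc/w", "/mlp/c_proj/w",
    };
    for (const auto & s : suffixes) {
        if (name.size() >= s.size() &&
            name.compare(name.size() - s.size(), s.size(), s) == 0) {
            return true;
        }
    }
    return false;
}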
@@ -302,8 +312,8 @@ ModelLoadResult gpt2_model_load(const std::string & fname, gpt2_model & model, g
 //   - embd_inp: the embeddings of the tokens in the context
 //   - embd_w:   the predicted logits for the next token
 //
-bool gpt2_eval(
-        const gpt2_model & model,
+bool legacy_gpt2_eval(
+        const gpt2_v1_model & model,
         const int n_threads,
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
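
The rename to legacy_gpt2_eval() implies the caller now picks an eval routine per file format. A sketch of that dispatch, assuming a non-legacy gpt2_eval() and a FileFormat::GPT2_2 value for the new format (only GPT2_1 appears in this diff):

// Hypothetical call-site dispatch; the surrounding types and the
// gpt2_eval/GPT2_2 names are assumptions, not code from this commit.
bool ok;
if (file_format == FileFormat::GPT2_1) {
    ok = legacy_gpt2_eval(model_v1, n_threads, n_past, embd, logits, mem_per_token, file_format);
} else {
    ok = gpt2_eval(model_v2, n_threads, n_past, embd, logits, mem_per_token, file_format);
}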
@@ -641,13 +651,13 @@ bool gpt2_eval(
 //     int64_t t_load_us = 0;
 //     gpt_vocab vocab;
-//     gpt2_model model;
+//     gpt2_v1_model model;
 //     // load the model
 //     {
 //         const int64_t t_start_us = ggml_v1_time_us();
-//         if (!gpt2_model_load(params.model, model, vocab, FileFormat::GPT2)) {
+//         if (!legacy_gpt2_model_load(params.model, model, vocab, FileFormat::GPT2_1)) {
 //             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
 //             return 1;
 //         }
@@ -676,14 +686,14 @@ bool gpt2_eval(
 //     // determine the required inference memory per token:
 //     size_t mem_per_token = 0;
-//     gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, FileFormat::GPT2);
+//     legacy_gpt2_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, FileFormat::GPT2_1);
 //     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
 //         // predict
 //         if (embd.size() > 0) {
 //             const int64_t t_start_us = ggml_v1_time_us();
-//             if (!gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, FileFormat::GPT2)) {
+//             if (!legacy_gpt2_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, FileFormat::GPT2_1)) {
 //                 printf("Failed to predict\n");
 //                 return 1;
 //             }
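
The RETRY_LOAD result only helps if the call site reacts to it. A minimal sketch of such a retry, assuming a ModelLoadResult::SUCCESS value and a newer gpt2_model_load() entry point, neither of which is shown in this diff:

// Hypothetical call-site retry: try the legacy loader first, then fall
// back to the newer one when the tensor shapes look transposed.
ModelLoadResult res = legacy_gpt2_model_load(fname, v1_model, vocab, FileFormat::GPT2_1);
if (res == ModelLoadResult::RETRY_LOAD) {
    res = gpt2_model_load(fname, v2_model, vocab, FileFormat::GPT2_2);  // assumed new-format loader
}
if (res != ModelLoadResult::SUCCESS) {
    fprintf(stderr, "failed to load model from '%s'\n", fname.c_str());
}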