Cleanup and refactoring pass before supporting newer models for a different architecture

This commit is contained in:
Concedo 2023-05-17 11:23:29 +08:00
parent 60ee00428b
commit 90fe9096b4
12 changed files with 81 additions and 367 deletions

View file

@ -2,7 +2,6 @@
#include "otherarch.h"
#include "utils.h"
#include "common-ggml.h"
#include <cassert>
#include <cmath>
@ -17,7 +16,7 @@
// load the model's weights from a file
ModelLoadResult stablelm_model_load(const std::string & fname, stablelm_model & model, gpt_vocab & vocab, FileFormat file_format) {
ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt_vocab & vocab, FileFormat file_format) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
auto fin = std::ifstream(fname, std::ios::binary);
@ -340,8 +339,8 @@ ModelLoadResult stablelm_model_load(const std::string & fname, stablelm_model &
// - embd_inp: the embeddings of the tokens in the context
// - embd_w: the predicted logits for the next token
//
bool stablelm_eval(
const stablelm_model & model,
bool gpt_neox_eval(
const gpt_neox_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
@ -497,7 +496,7 @@ bool stablelm_eval(
}
}
if(file_format==FileFormat::NEOX_3)
if(file_format==FileFormat::NEOX_3||file_format==FileFormat::NEOX_5)
{
// layer input + Attn
cur = ggml_add(ctx0, cur, inpL);
@ -511,7 +510,7 @@ bool stablelm_eval(
// post attention layer norm
// note here we pass inpL instead of cur
{
cur = ggml_norm(ctx0, (file_format==FileFormat::NEOX_3?cur:inpL));
cur = ggml_norm(ctx0, ((file_format==FileFormat::NEOX_3||file_format==FileFormat::NEOX_5)?cur:inpL));
cur = ggml_add(ctx0,
ggml_mul(ctx0,
@ -542,7 +541,7 @@ bool stablelm_eval(
cur);
}
if (file_format == FileFormat::NEOX_3)
if (file_format==FileFormat::NEOX_3||file_format==FileFormat::NEOX_5)
{
// layer input + FF
inpL = ggml_add(ctx0, cur, inpFF);