Cleanup and refactoring pass before supporting newer models for a different arch

Concedo 2023-05-17 11:23:29 +08:00
parent 60ee00428b
commit 90fe9096b4
12 changed files with 81 additions and 367 deletions


@@ -14,18 +14,19 @@
 #include <regex>
 
 // default hparams (StableLM 3B)
-struct stablelm_hparams {
+struct gpt_neox_hparams {
     int32_t n_vocab = 50257;
     int32_t n_ctx   = 4096;
     int32_t n_embd  = 4096;
     int32_t n_head  = 32;
     int32_t n_layer = 16;
     int32_t n_rot   = 32; // 0.25 * (n_embd / n_head)
+    int32_t par_res = 1; // 1 = true, 0 = false
     int32_t ftype   = 1;
 };
 
 // quantize a model
-bool stablelm_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
+bool gpt_neox_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) {
     gpt_vocab vocab;
 
     printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
@@ -54,7 +55,7 @@ bool stablelm_model_quantize(const std::string &
         fout.write((char *) &magic, sizeof(magic));
     }
 
-    stablelm_hparams hparams;
+    gpt_neox_hparams hparams;
 
     // load hparams
     {
@ -64,14 +65,22 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
finp.read((char *) &hparams.n_head, sizeof(hparams.n_head));
finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
finp.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
finp.read((char *) &hparams.par_res, sizeof(hparams.par_res));
finp.read((char *) &hparams.ftype, sizeof(hparams.ftype));
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: ftype = %d\n", __func__, hparams.ftype);
const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR;
const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype;
printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
printf("%s: n_head = %d\n", __func__, hparams.n_head);
printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
printf("%s: par_res = %d\n", __func__, hparams.par_res);
printf("%s: ftype (src) = %d\n", __func__, hparams.ftype);
printf("%s: qntvr (src) = %d\n", __func__, qntvr_src);
printf("%s: ftype (dst) = %d\n", __func__, ftype_dst);
printf("%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION);
fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
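The src/dst bookkeeping above packs the quantization format version and the tensor format into one stored int32: stored = qntvr * GGML_QNT_VERSION_FACTOR + ftype. A sketch of the round trip, assuming only that the factor (1000 in ggml.h) exceeds every valid ftype value:

```cpp
#include <cstdint>

// pack the quantization version into the high decimal digits of ftype
static int32_t pack_ftype(int32_t qntvr, int32_t ftype, int32_t factor) {
    return qntvr * factor + ftype;
}

// division and modulo recover the two halves
static int32_t unpack_qntvr(int32_t stored, int32_t factor) { return stored / factor; }
static int32_t unpack_ftype(int32_t stored, int32_t factor) { return stored % factor; }
```

For example, with a factor of 1000 and quantization version 2, a destination ftype of 2 is stored as 2002; 2002 / 1000 and 2002 % 1000 give back the two pieces.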
@ -79,7 +88,8 @@ bool stablelm_model_quantize(const std::string & fname_inp, const std::string &
fout.write((char *) &hparams.n_head, sizeof(hparams.n_head));
fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fout.write((char *) &hparams.n_rot, sizeof(hparams.n_rot));
fout.write((char *) &ftype, sizeof(hparams.ftype));
fout.write((char *) &hparams.par_res, sizeof(hparams.par_res));
fout.write((char *) &ftype_dst, sizeof(ftype_dst));
}
// load vocab
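On the load side, a reader of this format has to consume the header in the same order, including the new par_res field, and strip the version back out of ftype. A hypothetical mirror of the writes above (fin and gpt_neox_hparams assumed from the surrounding loader code); note that files written before this change lack the par_res field, so they would need to be re-converted:

```cpp
gpt_neox_hparams hparams;

// read fields in exactly the order they were written
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
fin.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
fin.read((char *) &hparams.par_res, sizeof(hparams.par_res));
fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));

// split the packed value: the high part is the quantization version
const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;
hparams.ftype %= GGML_QNT_VERSION_FACTOR;
```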
@@ -118,7 +128,7 @@ bool stablelm_model_quantize(const std::string &
 }
 
 // usage:
-//  ./stablelm2-quantize models/stablelm2-117M/ggml-model.bin models/stablelm2-117M/ggml-model-quant.bin type
+//  ./gpt-neox-quantize models/stablelm2-117M/ggml-model.bin models/stablelm2-117M/ggml-model-quant.bin type
 //
 int main(int argc, char ** argv) {
     ggml_time_init();
@@ -148,7 +158,7 @@ int main(int argc, char ** argv) {
     {
         const int64_t t_start_us = ggml_time_us();
 
-        if (!stablelm_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
+        if (!gpt_neox_model_quantize(fname_inp, fname_out, ggml_ftype(ftype))) {
             fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str());
             return 1;
         }
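A hypothetical invocation matching the updated usage comment (paths are placeholders; the trailing type argument is a ggml_ftype value, where 2 corresponds to q4_0 in this tree):

```
./gpt-neox-quantize models/stablelm2-117M/ggml-model.bin models/stablelm2-117M/ggml-model-quant.bin 2
```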