Removed junk, fixed some bugs, and added support for a dynamic number of sharded files

Merge remote-tracking branch 'origin/master' into concedo

# Conflicts:
#	README.md
Concedo 2023-03-19 11:13:00 +08:00
commit f952b7c613
14 changed files with 40 additions and 312 deletions

@@ -86,7 +86,7 @@ struct llama_model {
 };
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, int n_parts_overwrite=-1) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
@@ -132,6 +132,10 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
         n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
         n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
+        if(n_parts_overwrite>0)
+        {
+            n_parts = n_parts_overwrite;
+        }
 
         fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
@@ -793,6 +797,11 @@ int main(int argc, char ** argv) {
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
 
+    if (params.n_ctx > 2048) {
+        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
+                "expect poor results\n", __func__, params.n_ctx);
+    }
+
     if (params.seed < 0) {
         params.seed = time(NULL);
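The two startup checks are independent: an oversized n_ctx only triggers a warning and the run continues (quality degrades past the 2048-token window the original LLaMA models were trained with), while a negative seed is replaced with the current time. A standalone sketch of that sequence, using a stand-in for the real gpt_params struct:

    #include <cstdio>
    #include <ctime>

    struct params_stub {      // minimal stand-in for the real gpt_params
        int n_ctx = 512;
        int seed  = -1;
    };

    int main() {
        params_stub params;
        params.n_ctx = 4096;  // pretend the user requested more context than the model supports

        if (params.n_ctx > 2048) {
            // warn but keep going; results beyond the trained context window degrade
            std::fprintf(stderr, "warning: model does not support context sizes greater than 2048 tokens "
                                 "(%d specified); expect poor results\n", params.n_ctx);
        }
        if (params.seed < 0) {
            params.seed = (int) std::time(nullptr);   // fall back to a time-based seed
        }
        std::printf("n_ctx=%d seed=%d\n", params.n_ctx, params.seed);
        return 0;
    }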