Mirror of https://github.com/LostRuins/koboldcpp.git

Commit 6f82e17b7a (parent 9839259b63): added MPT support

10 changed files with 983 additions and 48 deletions
@@ -26,6 +26,7 @@
 #include "rwkv_v3.cpp"
 #include "neox_v2.cpp"
 #include "neox_v3.cpp"
+#include "mpt_v3.cpp"
 
 
 //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
@@ -44,6 +45,8 @@ static gpt2_model gpt2_ctx_v3;
 static gpt_neox_v2_model neox_ctx_v2;
 static gpt_neox_model neox_ctx_v3;
 
+static mpt_model mpt_ctx_v3;
+
 static rwkv_v2_context * rwkv_ctx_v2;
 static rwkv_context * rwkv_ctx_v3;
 static llama_v2_context_params llama_ctx_params_v2;
@@ -298,7 +301,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     params.n_ctx = inputs.max_context_length;
 
     neox_ctx_v2.hparams.n_ctx = gptj_ctx_v1.hparams.n_ctx = gptj_ctx_v2.hparams.n_ctx = gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx
-    = neox_ctx_v3.hparams.n_ctx = gptj_ctx_v3.hparams.n_ctx = gptj_ctx_v3.hparams.n_ctx = params.n_ctx;
+    = neox_ctx_v3.hparams.n_ctx = gptj_ctx_v3.hparams.n_ctx = gptj_ctx_v3.hparams.n_ctx = mpt_ctx_v3.hparams.n_ctx = params.n_ctx;
 
     printf("System Info: %s\n", llama_print_system_info());
     SetQuantsUnshuffled(false);
@@ -682,6 +685,19 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
 
     }
+    else if(file_format==FileFormat::MPT_1)
+    {
+        bool res = mpt_model_load(params.model, mpt_ctx_v3, vocab);
+        if(res==false)
+        {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return ModelLoadResult::FAIL;
+        }
+
+        // determine the required inference memory per token:
+        mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token);
+        return ModelLoadResult::SUCCESS;
+    }
     else
     {
         printf("\nUnknown Model, cannot load.\n");
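
The new MPT_1 branch follows the same pattern as the other non-llama backends in gpttype_load_model: load the weights, and if that succeeds, run one throw-away eval on a few dummy tokens ({ 0, 1, 2, 3 }) purely to measure mem_per_token before real generation starts. Below is a minimal, self-contained sketch of that load-then-warmup pattern; toy_model, toy_model_load and toy_eval are hypothetical stand-ins, not koboldcpp APIs.

    // Sketch of the load-then-warmup pattern used in the hunk above.
    // All names here are illustrative stand-ins, not koboldcpp functions.
    #include <cstddef>
    #include <cstdio>
    #include <string>
    #include <vector>

    struct toy_model { int n_vocab = 8; };           // stand-in for mpt_model
    enum class LoadResult { FAIL, SUCCESS };

    // Pretend loader; in the diff this role is played by mpt_model_load().
    static bool toy_model_load(const std::string &path, toy_model &model) {
        model.n_vocab = 8;
        return !path.empty();                        // trivially "succeeds" for the demo
    }

    // Pretend evaluator; in the diff this role is played by mpt_eval(). It also
    // reports the scratch memory needed per token so later calls can pre-allocate.
    static bool toy_eval(const toy_model &model, const std::vector<int> &tokens,
                         std::vector<float> &logits, size_t &mem_per_token) {
        logits.assign(model.n_vocab, 0.0f);
        mem_per_token = tokens.size() * 1024;        // fake estimate for the demo
        return true;
    }

    static LoadResult load_model(const std::string &path) {
        toy_model model;
        if (!toy_model_load(path, model)) {
            std::fprintf(stderr, "failed to load model from '%s'\n", path.c_str());
            return LoadResult::FAIL;
        }
        // Warm-up eval on dummy tokens just to learn mem_per_token,
        // mirroring the mpt_eval(..., { 0, 1, 2, 3 }, ...) call in the diff.
        std::vector<float> logits;
        size_t mem_per_token = 0;
        toy_eval(model, {0, 1, 2, 3}, logits, mem_per_token);
        std::printf("mem per token: %zu bytes\n", mem_per_token);
        return LoadResult::SUCCESS;
    }

    int main() {
        return load_model("dummy.bin") == LoadResult::SUCCESS ? 0 : 1;
    }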
@@ -869,6 +885,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         n_vocab = neox_ctx_v3.hparams.n_vocab;
     }
+    else if( file_format==FileFormat::MPT_1)
+    {
+        n_vocab = mpt_ctx_v3.hparams.n_vocab;
+    }
     else if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2)
     {
         n_vocab = vocab.id_to_token.size(); //handled seperately
@@ -1006,6 +1026,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         {
             evalres = gptj_eval(gptj_ctx_v3, params.n_threads, n_past, embd, logits, mem_per_token);
         }
+        else if(file_format==FileFormat::MPT_1)
+        {
+            evalres = mpt_eval(mpt_ctx_v3, params.n_threads, n_past, embd, logits, false, mem_per_token);
+        }
         else
         {
             printf("\nCannot find eval function\n");
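
In gpttype_generate the per-token evaluation is routed by file_format, and this hunk simply adds an MPT_1 arm that calls mpt_eval with the same threading, past-context, token and logit arguments as the other backends (plus the extra boolean mpt_eval takes). A rough sketch of that dispatch idea follows; it uses an illustrative switch with stub eval functions rather than koboldcpp's real if/else chain, and the enumerator set is abbreviated.

    // Illustrative per-format eval dispatch; FileFormat values and the eval_*
    // stubs are stand-ins, not koboldcpp's actual definitions.
    #include <cstdio>
    #include <vector>

    enum class FileFormat { GPTJ_5, NEOX_7, MPT_1, RWKV_2 };

    static bool eval_gptj(const std::vector<int>&, std::vector<float> &logits) { logits.assign(4, 0.0f); return true; }
    static bool eval_mpt (const std::vector<int>&, std::vector<float> &logits) { logits.assign(4, 0.0f); return true; }

    static bool eval_dispatch(FileFormat fmt, const std::vector<int> &embd,
                              std::vector<float> &logits) {
        switch (fmt) {
            case FileFormat::GPTJ_5: return eval_gptj(embd, logits);
            case FileFormat::MPT_1:  return eval_mpt(embd, logits);   // new MPT path
            default:
                std::printf("\nCannot find eval function\n");         // unknown format
                return false;
        }
    }

    int main() {
        std::vector<float> logits;
        bool ok = eval_dispatch(FileFormat::MPT_1, {0, 1, 2, 3}, logits);
        std::printf("eval ok: %d, logits: %zu\n", ok ? 1 : 0, logits.size());
        return ok ? 0 : 1;
    }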
@@ -1098,7 +1122,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             file_format == FileFormat::NEOX_4 ||
             file_format == FileFormat::NEOX_5 ||
             file_format == FileFormat::NEOX_6 ||
-            file_format == FileFormat::NEOX_7)
+            file_format == FileFormat::NEOX_7 ||
+            file_format == FileFormat::MPT_1)
         {
             eosID = 0;
             int topid = std::min_element(logits.begin(),logits.end())-logits.begin();
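
The last hunk extends the stop-token handling: for the NEOX family, and now MPT, the end-of-stream token is taken to be id 0, and the index of the lowest logit is located. The short sketch below reproduces those two assignments; the follow-up use of topid (demoting the EOS logit so it is not sampled while EOS is banned) is an assumption about the surrounding code, not something shown in this hunk.

    // Sketch of the EOS handling hinted at above. Only the eosID and topid
    // assignments are taken from the diff; the rest is assumed context.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> logits = {1.5f, -3.2f, 0.7f, 2.1f};

        // For NEOX_* and MPT_1 the diff treats the end-of-stream token as id 0.
        int eosID = 0;
        // Index of the *smallest* logit, exactly as computed in the diff.
        int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();

        // Assumed follow-up: demote the EOS logit to the least likely value so
        // it is effectively never sampled while EOS is banned.
        logits[eosID] = logits[topid];

        std::printf("eosID=%d topid=%d logits[eosID]=%.2f\n", eosID, topid, logits[eosID]);
        return 0;
    }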