Added backwards compatibility for an earlier version of NeoX.

Concedo 2023-04-25 20:34:18 +08:00
parent bff998f871
commit 5eec5d6ed9
6 changed files with 49 additions and 21 deletions


@@ -218,13 +218,18 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         return ModelLoadResult::SUCCESS;
     }
-    else if(file_format==FileFormat::NEOX_1)
+    else if(file_format==FileFormat::NEOX_1 || file_format==FileFormat::NEOX_2)
     {
-        bool res = stablelm_model_load(params.model, neox_ctx, vocab);
-        if(!res)
+        ModelLoadResult res = stablelm_model_load(params.model, neox_ctx, vocab, file_format);
+        if(res==ModelLoadResult::FAIL)
        {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-           return ModelLoadResult::FAIL;
+           return res;
        }
+        else if(res==ModelLoadResult::RETRY_LOAD)
+        {
+            printf("\nIncorrect Tensor Size Detected! Retrying GPT-NeoX model loading...");
+            return res;
+        }

         // determine the required inference memory per token:
         stablelm_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
@@ -245,8 +250,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
        }

        // determine the required inference memory per token:
        gptj_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-       gptj_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
        //if the logits are NAN, it means the model is incompatible
        if(logits.size()>0 && IsNanCheck(logits[0]))
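The diff changes stablelm_model_load from returning a bool to returning a ModelLoadResult, where RETRY_LOAD marks a recoverable mismatch (an incorrect tensor size for the assumed NeoX variant) rather than a hard failure, so the caller can retry the load with the other file format. Below is a minimal standalone sketch of that fallback pattern, assuming a hypothetical loader try_load_neox and wrapper load_with_fallback plus a trimmed-down FileFormat enum; the project's actual retry path around gpttype_load_model may differ.

#include <cstdio>
#include <string>

enum class FileFormat { NEOX_1, NEOX_2 };
enum class ModelLoadResult { SUCCESS, FAIL, RETRY_LOAD };

// Hypothetical stand-in for a loader like stablelm_model_load: here it
// pretends the file on disk uses the NEOX_1 layout, so asking for NEOX_2
// reports a recoverable mismatch (RETRY_LOAD) instead of a hard failure.
static ModelLoadResult try_load_neox(const std::string &model_path, FileFormat fmt)
{
    (void)model_path;
    return (fmt == FileFormat::NEOX_1) ? ModelLoadResult::SUCCESS
                                       : ModelLoadResult::RETRY_LOAD;
}

// Try one format first; if the tensor sizes do not match that variant,
// the loader asks for a retry and we fall back to the other format.
static bool load_with_fallback(const std::string &model_path)
{
    ModelLoadResult res = try_load_neox(model_path, FileFormat::NEOX_2);
    if (res == ModelLoadResult::RETRY_LOAD)
    {
        res = try_load_neox(model_path, FileFormat::NEOX_1);
    }
    return res == ModelLoadResult::SUCCESS;
}

int main()
{
    printf("loaded: %s\n", load_with_fallback("model.bin") ? "yes" : "no");
    return 0;
}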