wip pythia integration

This commit is contained in:
Concedo 2023-04-22 01:08:23 +08:00
parent 68898046c2
commit ef13443047
4 changed files with 34 additions and 3 deletions

View file

@ -201,6 +201,18 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
return ModelLoadResult::SUCCESS;
}
else if(file_format==FileFormat::NEOX_1)
{
bool res = stablelm_model_load(params.model, neox_ctx, vocab);
if(!res)
{
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return ModelLoadResult::FAIL;
}
// determine the required inference memory per token:
stablelm_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
return ModelLoadResult::SUCCESS;
}
else
{
ModelLoadResult loadresult = gptj_model_load(params.model, gptj_ctx_v2, vocab);