neox is updated

This commit is contained in:
Concedo 2023-05-17 14:56:54 +08:00
parent 90fe9096b4
commit 00da2a5f4e
5 changed files with 111 additions and 74 deletions

View file

@ -393,18 +393,22 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
SetQuantsUnshuffled(file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5);
// determine the required inference memory per token:
gpt_neox_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
gpt_neox_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
if(logits.size()>0 && (file_format==FileFormat::NEOX_2 || file_format==FileFormat::NEOX_4) && !IsNanCheck(logits[0]))
if(logits.size()>0 && file_format==FileFormat::NEOX_2 && !IsNanCheck(logits[0]))
{
//run the black magic eval to determine if it's redpajama. VERY UGLY HACK!
std::vector<int> test_embd = ::gpt_tokenize(vocab, "1 2 3 4 5 6 7");
gpt_neox_eval(neox_ctx, params.n_threads, 0, test_embd, logits, mem_per_token, (file_format==FileFormat::NEOX_2?FileFormat::NEOX_3:FileFormat::NEOX_5));
std::vector<int> test_embd = ::gpt_tokenize(vocab, "1 2 3 4 5 6 7");
auto orig_par_res = neox_ctx.hparams.par_res;
neox_ctx.hparams.par_res = 0; //test with residual false
gpt_neox_eval(neox_ctx, params.n_threads, 0, test_embd, logits, mem_per_token);
neox_ctx.hparams.par_res = orig_par_res;
int topid = std::max_element(logits.begin(),logits.end())-logits.begin();
std::string predicted = vocab.id_to_token[topid].c_str();
if(predicted.find("8") != std::string::npos)
auto findresult = predicted.find("8");
if(findresult != std::string::npos && findresult<2)
{
printf("\n---\nRedPajama NeoX Detected! Switching to new format! (use_parallel_residual=False)\n");
printf("\n---\nOld RedPajama NeoX Detected! Switching to new format! (use_parallel_residual=False)\n");
ggml_free(neox_ctx.ctx);
return ModelLoadResult::RETRY_LOAD;
}
@ -694,7 +698,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
else if(file_format==FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5)
{
evalres = gpt_neox_eval(neox_ctx, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
evalres = gpt_neox_eval(neox_ctx, params.n_threads, n_past, embd, logits, mem_per_token);
}
else if(file_format==FileFormat::GPTJ_1 || file_format==FileFormat::GPTJ_2)
{