mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 09:04:36 +00:00
neox is updated
This commit is contained in:
parent
90fe9096b4
commit
00da2a5f4e
5 changed files with 111 additions and 74 deletions
|
@ -393,18 +393,22 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
SetQuantsUnshuffled(file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5);
|
||||
|
||||
// determine the required inference memory per token:
|
||||
gpt_neox_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
|
||||
gpt_neox_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||
|
||||
if(logits.size()>0 && (file_format==FileFormat::NEOX_2 || file_format==FileFormat::NEOX_4) && !IsNanCheck(logits[0]))
|
||||
if(logits.size()>0 && file_format==FileFormat::NEOX_2 && !IsNanCheck(logits[0]))
|
||||
{
|
||||
//run the black magic eval to determine if it's redpajama. VERY UGLY HACK!
|
||||
std::vector<int> test_embd = ::gpt_tokenize(vocab, "1 2 3 4 5 6 7");
|
||||
gpt_neox_eval(neox_ctx, params.n_threads, 0, test_embd, logits, mem_per_token, (file_format==FileFormat::NEOX_2?FileFormat::NEOX_3:FileFormat::NEOX_5));
|
||||
std::vector<int> test_embd = ::gpt_tokenize(vocab, "1 2 3 4 5 6 7");
|
||||
auto orig_par_res = neox_ctx.hparams.par_res;
|
||||
neox_ctx.hparams.par_res = 0; //test with residual false
|
||||
gpt_neox_eval(neox_ctx, params.n_threads, 0, test_embd, logits, mem_per_token);
|
||||
neox_ctx.hparams.par_res = orig_par_res;
|
||||
int topid = std::max_element(logits.begin(),logits.end())-logits.begin();
|
||||
std::string predicted = vocab.id_to_token[topid].c_str();
|
||||
if(predicted.find("8") != std::string::npos)
|
||||
auto findresult = predicted.find("8");
|
||||
if(findresult != std::string::npos && findresult<2)
|
||||
{
|
||||
printf("\n---\nRedPajama NeoX Detected! Switching to new format! (use_parallel_residual=False)\n");
|
||||
printf("\n---\nOld RedPajama NeoX Detected! Switching to new format! (use_parallel_residual=False)\n");
|
||||
ggml_free(neox_ctx.ctx);
|
||||
return ModelLoadResult::RETRY_LOAD;
|
||||
}
|
||||
|
@ -694,7 +698,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
}
|
||||
else if(file_format==FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5)
|
||||
{
|
||||
evalres = gpt_neox_eval(neox_ctx, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
|
||||
evalres = gpt_neox_eval(neox_ctx, params.n_threads, n_past, embd, logits, mem_per_token);
|
||||
}
|
||||
else if(file_format==FileFormat::GPTJ_1 || file_format==FileFormat::GPTJ_2)
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue