the dark gods have been sated, and redpajama is integrated... but at what cost?

Concedo 2023-05-08 20:58:00 +08:00
parent b9904c3093
commit 2f2eff6e13
4 changed files with 57 additions and 23 deletions


@@ -369,7 +369,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         return ModelLoadResult::SUCCESS;
     }
-    else if(file_format==FileFormat::NEOX_1 || file_format==FileFormat::NEOX_2)
+    else if(file_format==FileFormat::NEOX_1 || file_format==FileFormat::NEOX_2 || file_format==FileFormat::NEOX_3)
     {
         ModelLoadResult res = stablelm_model_load(params.model, neox_ctx, vocab, file_format);
         if(res==ModelLoadResult::FAIL)
@@ -383,7 +383,23 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
             return res;
         }
         // determine the required inference memory per token:
-        stablelm_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+        stablelm_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
+
+        if(logits.size()>0 && file_format==FileFormat::NEOX_2 && !IsNanCheck(logits[0]))
+        {
+            //run the black magic eval to determine if it's redpajama. VERY UGLY HACK!
+            std::vector<int> test_embd = ::gpt_tokenize(vocab, "1 2 3 4 5 6 7");
+            stablelm_eval(neox_ctx, params.n_threads, 0, test_embd, logits, mem_per_token, FileFormat::NEOX_3);
+            int topid = std::max_element(logits.begin(),logits.end())-logits.begin();
+            std::string predicted = vocab.id_to_token[topid].c_str();
+            if(predicted.find("8") != std::string::npos)
+            {
+                printf("\n---\nRedPajama NeoX Detected! Switching to new format! (use_parallel_residual=False)\n");
+                ggml_free(neox_ctx.ctx);
+                return ModelLoadResult::RETRY_LOAD;
+            }
+        }
+
         return ModelLoadResult::SUCCESS;
     }
     else
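This hunk is the "dark gods" part of the commit: after the normal warm-up eval, any model that parsed as NEOX_2 is probed by tokenizing the literal string "1 2 3 4 5 6 7" and re-running stablelm_eval with NEOX_3 (use_parallel_residual=False) wiring. If the top-scoring next token contains "8", the non-parallel wiring evidently fits the weights, so the context is freed and RETRY_LOAD forces a reload as NEOX_3; otherwise the load falls through and succeeds as NEOX_2. A minimal sketch of the probe, where eval_fn and id_to_token are hypothetical stand-ins for the real stablelm_eval/vocab plumbing:

#include <algorithm>
#include <functional>
#include <map>
#include <string>
#include <vector>

// eval_fn (hypothetical): runs one forward pass under the candidate
// residual wiring and fills `logits` for the last position.
bool looks_like_redpajama(
    const std::function<void(const std::vector<int>&, std::vector<float>&)>& eval_fn,
    const std::map<int, std::string>& id_to_token,
    const std::vector<int>& probe_tokens) // tokenized "1 2 3 4 5 6 7"
{
    std::vector<float> logits;
    eval_fn(probe_tokens, logits);
    if (logits.empty()) return false;
    // argmax over the vocabulary = the model's top next-token prediction
    int topid = (int)(std::max_element(logits.begin(), logits.end()) - logits.begin());
    auto it = id_to_token.find(topid);
    // a correctly-wired model should continue the counting sequence with "8"
    return it != id_to_token.end() && it->second.find("8") != std::string::npos;
}

The probe works because a model evaluated with the wrong residual wiring produces essentially garbage logits, so only a genuine RedPajama checkpoint will continue the sequence coherently under NEOX_3 semantics.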
@@ -514,13 +530,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     }
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
-    bool approved_format = (file_format == FileFormat::GGML ||
-                            file_format == FileFormat::GGHF ||
-                            file_format == FileFormat::GGJT ||
-                            file_format == FileFormat::GPT2_2 ||
-                            file_format == FileFormat::GPTJ_3 ||
-                            file_format == FileFormat::NEOX_1 ||
-                            file_format == FileFormat::NEOX_2);
+    bool approved_format = !(file_format == FileFormat::BADFORMAT ||
+                             file_format == FileFormat::GPT2_1 ||
+                             file_format == FileFormat::GPTJ_1 ||
+                             file_format == FileFormat::GPTJ_2 ||
+                             file_format == FileFormat::RWKV_1);
    bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
    // bool blasmode = false;
    int original_batch = params.n_batch;
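Note that the gate flips from an allow-list to a deny-list: rather than enumerating every format that supports BLAS batching (which would need another entry for NEOX_3 and every future format), only the known-incompatible legacy formats now opt out. A sketch of the resulting gate, with a hypothetical Fmt enum and helper standing in for the real FileFormat comparisons:

#include <cstddef>

// Hypothetical stand-in for the real FileFormat enum.
enum class Fmt { BADFORMAT, GPT2_1, GPTJ_1, GPTJ_2, RWKV_1, GGJT, NEOX_3 };

// The deny-list from the diff: formats whose eval path can't take a big batch.
static bool is_legacy_format(Fmt f)
{
    return f == Fmt::BADFORMAT || f == Fmt::GPT2_1 || f == Fmt::GPTJ_1 ||
           f == Fmt::GPTJ_2 || f == Fmt::RWKV_1;
}

static bool should_use_blas_batch(Fmt f, std::size_t prompt_tokens, bool has_blas)
{
    // BLAS batching only pays off on prompts of 32+ tokens, and only for
    // formats whose eval path supports a large single-pass batch.
    return !is_legacy_format(f) && prompt_tokens >= 32 && has_blas;
}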
@@ -579,7 +593,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         n_vocab = gpt2_ctx_v2.hparams.n_vocab;
     }
-    else if(file_format == FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2)
+    else if(file_format == FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3)
     {
         n_vocab = neox_ctx.hparams.n_vocab;
     }
@@ -614,14 +628,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         {
             for (auto id : embd_inp)
             {
-                printf("'%s', ",llama_token_to_str(llama_ctx_v1, id));
+                printf("'%s (%d)', ",llama_token_to_str(llama_ctx_v1, id),id);
             }
         }
         else
         {
             for (auto id : embd_inp)
             {
-                printf("'%s', ",vocab.id_to_token[id].c_str());
+                printf("'%s (%d)', ",vocab.id_to_token[id].c_str(),id);
             }
         }
         printf("\n");
@@ -665,9 +679,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         evalres = gpt2_eval(gpt2_ctx_v2, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
     }
-    else if(file_format==FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2)
+    else if(file_format==FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3)
     {
-        evalres = stablelm_eval(neox_ctx, params.n_threads, n_past, embd, logits, mem_per_token);
+        evalres = stablelm_eval(neox_ctx, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
     }
     else if(file_format==FileFormat::GPTJ_1 || file_format==FileFormat::GPTJ_2)
     {
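Threading file_format into stablelm_eval is what lets one NeoX eval path select between the two residual wirings: classic GPT-NeoX (NEOX_1/NEOX_2) uses a parallel residual, while RedPajama's NEOX_3 applies attention and MLP sequentially, matching the use_parallel_residual=False note in the detection message. A toy sketch of the difference, with elementwise stand-ins for the real ggml layernorm/attention/MLP blocks:

#include <cstddef>
#include <vector>

using Vec = std::vector<float>;

// Toy stand-ins so the wiring below is executable; the real blocks are
// ggml graphs built inside stablelm_eval.
static Vec add(const Vec& a, const Vec& b)
{
    Vec out(a.size());
    for (std::size_t i = 0; i < a.size(); ++i) out[i] = a[i] + b[i];
    return out;
}
static Vec toy_block(const Vec& x)
{
    Vec out(x);
    for (float& v : out) v *= 0.5f;
    return out;
}

// use_parallel_residual=True (classic GPT-NeoX, NEOX_1/NEOX_2): both
// branches read the same input and are summed into one residual update.
static Vec layer_parallel(const Vec& x)
{
    return add(x, add(toy_block(x) /*attn(ln1(x))*/, toy_block(x) /*mlp(ln2(x))*/));
}

// use_parallel_residual=False (RedPajama, NEOX_3): the MLP branch reads
// the attention residual's output, so the two updates happen in sequence.
static Vec layer_sequential(const Vec& x)
{
    Vec h = add(x, toy_block(x) /*attn(ln1(x))*/);
    return add(h, toy_block(h) /*mlp(ln2(h))*/);
}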