added version label, improved file type checks

This commit is contained in:
Concedo 2023-04-10 01:03:09 +08:00
parent 1543c700d8
commit 18a154715e
5 changed files with 69 additions and 4 deletions

View file

@ -35,6 +35,12 @@ static std::vector<gpt_vocab::id> current_context_tokens;
static size_t mem_per_token = 0;
static std::vector<float> logits;
inline bool IsNanCheck(float f)
{
const unsigned int u = *(unsigned int*)&f;
return (u&0x7F800000) == 0x7F800000 && (u&0x7FFFFF); // Both NaN and qNan.
}
ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format)
{
ggml_time_init();
@ -93,6 +99,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
// determine the required inference memory per token:
legacy_gptj_eval(model_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
//if the logits are NAN, it means the model is incompatible
if(logits.size()>0 && IsNanCheck(logits[0]))
{
printf("\nBad Logits detected! Retrying GPT-J model loading...");
ggml_v1_free(model_v1.ctx);
return ModelLoadResult::RETRY_LOAD;
}
return ModelLoadResult::SUCCESS;
}
else
@ -110,7 +125,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
// determine the required inference memory per token:
gptj_eval(model_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
gptj_eval(model_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
//if the logits are NAN, it means the model is incompatible
if(logits.size()>0 && IsNanCheck(logits[0]))
{
printf("\nBad Logits detected! Retrying GPT-J model loading...");
ggml_free(model_v2.ctx);
return ModelLoadResult::RETRY_LOAD;
}
return ModelLoadResult::SUCCESS;
}
@ -204,7 +229,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
//if using BLAS and prompt is big enough, switch to single thread and use a huge batch
bool blasmode = false; //(embd_inp.size() >= 32 && ggml_cpu_has_blas());
// bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
// bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas());
bool blasmode = false;
int original_batch = params.n_batch;
int original_threads = params.n_threads;
if (blasmode)
@ -355,7 +382,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
time2 = timer_check();
printf("\nTime Taken - Processing:%.1fs, Generation:%.1fs, Total:%.1fs", time1, time2, (time1 + time2));
float pt1 = (time1*1000.0/(embd_inp_size==0?1:embd_inp_size));
float pt2 = (time2*1000.0/(params.n_predict==0?1:params.n_predict));
printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
fflush(stdout);
output.status = 1;
snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());