Mirror of https://github.com/LostRuins/koboldcpp.git (synced 2025-09-10 00:54:41 +00:00)
fossilize ggml library ver 3, to support ggjtv3
This commit is contained in: parent 1804238e3f, commit db14de5c32
18 changed files with 44315 additions and 1591 deletions
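Context for the change: "fossilizing" here means vendoring the entire ggml v3 source under a renamed `ggml_v3_` symbol prefix, so the frozen library can be linked into the same binary as the current ggml and old GGJTv3-era models keep loading unchanged. Below is a minimal sketch of that coexistence pattern; the stand-in functions and the `use_gpu_scratch` helper are hypothetical illustrations, not code from this commit.

```cpp
// Sketch of the fossilization pattern (hypothetical stand-ins, not the real
// koboldcpp sources): two versions of one library coexist in a single binary
// because every public symbol of the frozen copy carries a _v3 infix.
#include <cstdio>

enum class FileFormat { GGUF_LLAMA, GGUF_FALCON, GGJT_3 /* legacy format */ };

// Stand-ins for the live and fossilized library entry points.
static bool ggml_cpu_has_gpublas()    { return true; }  // current ggml
static bool ggml_v3_cpu_has_gpublas() { return true; }  // frozen ggml v3

// Callers select a library by file format -- the same shape the diff below
// introduces around the CUBLAS main-device setup.
static bool use_gpu_scratch(FileFormat fmt)
{
    if (fmt == FileFormat::GGUF_LLAMA || fmt == FileFormat::GGUF_FALCON)
        return ggml_cpu_has_gpublas();   // new models: live library
    return ggml_v3_cpu_has_gpublas();    // legacy models: fossilized library
}

int main()
{
    std::printf("gpu scratch for legacy model: %d\n",
                use_gpu_scratch(FileFormat::GGJT_3));
}
```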
```diff
@@ -774,17 +774,29 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     }

     //this is used for the mem_per_token eval, openblas needs more RAM
-    bool use_scratch = ggml_cpu_has_gpublas();
+    bool v3_use_scratch = ggml_v3_cpu_has_gpublas();

     int cu_parseinfo_maindevice = inputs.cublas_info<=0?0:inputs.cublas_info;

     printf("System Info: %s\n", llama_print_system_info());
     #if defined(GGML_USE_CUBLAS)
-    if(ggml_cpu_has_gpublas() && cu_parseinfo_maindevice>0)
+    if(file_format==FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON)
     {
-        printf("CUBLAS: Set main device to %d\n",cu_parseinfo_maindevice);
-        ggml_cuda_set_main_device(cu_parseinfo_maindevice);
+        if(ggml_cpu_has_gpublas() && cu_parseinfo_maindevice>0)
+        {
+            printf("CUBLAS: Set main device to %d\n",cu_parseinfo_maindevice);
+            ggml_cuda_set_main_device(cu_parseinfo_maindevice);
+        }
+    }
+    else
+    {
+        if(ggml_v3_cpu_has_gpublas() && cu_parseinfo_maindevice>0)
+        {
+            printf("CUBLAS v3: Set main device to %d\n",cu_parseinfo_maindevice);
+            ggml_v3_cuda_set_main_device(cu_parseinfo_maindevice);
+        }
     }
+
     #endif
     SetQuantsUnshuffled(false);
     if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
```
```diff
@@ -1187,7 +1199,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         n_vocab = gpt2_ctx_v3.hparams.n_vocab;

         // determine the required inference memory per token:
-        gpt2_eval(gpt2_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
+        gpt2_eval(gpt2_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, v3_use_scratch);
         return ModelLoadResult::SUCCESS;
     }
     else
```
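Each v3 loader in these hunks runs one throwaway evaluation over the fixed dummy tokens `{ 0, 1, 2, 3 }` purely so the evaluator reports `mem_per_token` before any real prompt is processed. Here is a rough, self-contained sketch of that probing idiom; `probe_eval` is a hypothetical stand-in modeled on the call sites above, not the real `gpt2_eval` signature.

```cpp
// Sketch of the mem_per_token warm-up probe (hypothetical eval signature).
#include <cstddef>
#include <cstdio>
#include <vector>

static bool probe_eval(const std::vector<int>& tokens, std::vector<float>& logits,
                       std::size_t& mem_per_token, bool use_scratch)
{
    // A real eval sizes its scratch buffers here and reports the per-token
    // memory requirement back through mem_per_token.
    (void)tokens; (void)use_scratch;
    mem_per_token = 64 * 1024;     // placeholder figure
    logits.assign(32000, 0.0f);    // placeholder vocab-sized logits
    return true;
}

int main()
{
    std::vector<float> logits;
    std::size_t mem_per_token = 0;
    // Warm-up pass: the logits are discarded, only mem_per_token is kept
    // and used to size allocations for the real generation later.
    probe_eval({0, 1, 2, 3}, logits, mem_per_token, /*use_scratch=*/true);
    std::printf("mem per token: %zu bytes\n", mem_per_token);
}
```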
```diff
@@ -1262,19 +1274,19 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         n_vocab = gptj_ctx_v3.hparams.n_vocab;

         // determine the required inference memory per token:
-        gptj_eval(gptj_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
+        gptj_eval(gptj_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, v3_use_scratch);

         //if the logits are NAN or duplicated, it means the model is incompatible
         std::vector<float> oldlogits(logits);

         //this is another hack because they change the library - we run the eval through the model
         //twice and compare logits. if they give the same logits for different inputs, model is broken
-        gptj_eval(gptj_ctx_v3, kcpp_params->n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token, use_scratch);
+        gptj_eval(gptj_ctx_v3, kcpp_params->n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token, v3_use_scratch);

         if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits)))
         {
             printf("\nBad Logits detected! Retrying GPT-J model loading...");
-            ggml_free(gptj_ctx_v3.ctx);
+            ggml_v3_free(gptj_ctx_v3.ctx);
             return ModelLoadResult::RETRY_LOAD;
         }

```
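The retry path above depends on two helpers whose definitions are not part of this hunk: a NaN probe on the first logit, and a comparison of the logits produced by two different dummy inputs (a model decoded with the wrong library version tends to emit output that ignores its input). The sketch below only illustrates the assumed semantics; the real `IsNanCheck`/`LogitsDuplicated` live elsewhere in the koboldcpp sources and may well compare with a tolerance.

```cpp
// Assumed semantics of the two checks (illustration only).
#include <cmath>
#include <vector>

static bool IsNanCheck(float f)
{
    return std::isnan(f);  // a NaN logit means the evaluation blew up
}

static bool LogitsDuplicated(const std::vector<float>& a,
                             const std::vector<float>& b)
{
    // Distinct inputs should never yield identical logits; if they do,
    // the weights were almost certainly decoded with the wrong library.
    return a == b;
}
```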
```diff
@@ -1338,7 +1350,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         n_vocab = neox_ctx_v3.hparams.n_vocab;

         // determine the required inference memory per token:
-        gpt_neox_eval(neox_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, use_scratch);
+        gpt_neox_eval(neox_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, v3_use_scratch);

         return ModelLoadResult::SUCCESS;
     }
```
```diff
@@ -1399,7 +1411,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         n_vocab = mpt_ctx_v3.hparams.n_vocab;

         // determine the required inference memory per token:
-        mpt_eval(mpt_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token, use_scratch);
+        mpt_eval(mpt_ctx_v3, kcpp_params->n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token, v3_use_scratch);
         return ModelLoadResult::SUCCESS;
     }
     else
```
```diff
@@ -1709,7 +1721,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     }

     bool startedsampling = false;
-    bool use_scratch = true; //for normal inference always use scratch
+    bool v3_use_scratch = true; //for normal inference always use scratch

     timer_start();
     double time1 = 0, time2 = 0;
```
```diff
@@ -1849,7 +1861,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else if(file_format==FileFormat::GPT2_4)
         {
-            evalres = gpt2_eval(gpt2_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, use_scratch);
+            evalres = gpt2_eval(gpt2_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, v3_use_scratch);
         }
         else if(file_format==FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5)
         {
```
```diff
@@ -1857,7 +1869,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else if(file_format==FileFormat::NEOX_6|| file_format==FileFormat::NEOX_7)
         {
-            evalres = gpt_neox_eval(neox_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, use_scratch);
+            evalres = gpt_neox_eval(neox_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, v3_use_scratch);
         }
         else if(file_format==FileFormat::GPTJ_1 || file_format==FileFormat::GPTJ_2)
         {
```
```diff
@@ -1869,11 +1881,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         else if(file_format==FileFormat::GPTJ_5)
         {
-            evalres = gptj_eval(gptj_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, use_scratch);
+            evalres = gptj_eval(gptj_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, mem_per_token, v3_use_scratch);
         }
         else if(file_format==FileFormat::MPT_1)
         {
-            evalres = mpt_eval(mpt_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, false, mem_per_token, use_scratch);
+            evalres = mpt_eval(mpt_ctx_v3, kcpp_params->n_threads, n_past, embd, logits, false, mem_per_token, v3_use_scratch);
         }
         else
         {
```
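Taken together, the generation-time hunks show one if/else chain on `file_format` routing each batch to the matching architecture's evaluator, with every fossilized v3 path now receiving `v3_use_scratch`. A condensed, self-contained sketch of that dispatch shape follows; the stub evaluator and the four-format enum are simplifications, not the real chain.

```cpp
// Condensed sketch of the generate-time dispatch (stubs, not real evals).
#include <cstdio>

enum class FileFormat { GPT2_4, NEOX_6, GPTJ_5, MPT_1 };

static bool stub_eval(const char* name, bool v3_use_scratch)
{
    std::printf("%s(v3_use_scratch=%d)\n", name, v3_use_scratch);
    return true;  // a real eval returns whether the forward pass succeeded
}

static bool eval_dispatch(FileFormat ff)
{
    const bool v3_use_scratch = true;  // normal inference always uses scratch
    switch (ff) {
        case FileFormat::GPT2_4: return stub_eval("gpt2_eval", v3_use_scratch);
        case FileFormat::NEOX_6: return stub_eval("gpt_neox_eval", v3_use_scratch);
        case FileFormat::GPTJ_5: return stub_eval("gptj_eval", v3_use_scratch);
        case FileFormat::MPT_1:  return stub_eval("mpt_eval", v3_use_scratch);
    }
    return false;
}

int main() { return eval_dispatch(FileFormat::GPTJ_5) ? 0 : 1; }
```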