Various enhancements and integration of pygmalion.cpp

This commit is contained in:
Concedo 2023-04-03 00:04:43 +08:00
parent 3f4967b827
commit 8dd8ab1659
20 changed files with 2362 additions and 526 deletions

View file

@ -24,17 +24,38 @@ extern "C"
{
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
static FileFormat file_format = FAIL;
static FileFormat file_format = FileFormat::BADFORMAT;
bool load_model(const load_model_inputs inputs)
{
std::string model = inputs.model_filename;
file_format = check_file_format(model.c_str());
if(file_format==GPTJ1 || file_format==GPTJ2)
if(file_format==FileFormat::GPTJ1 || file_format==FileFormat::GPTJ2 || file_format==FileFormat::GPTJ3)
{
printf("\n---\nIdentified as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
return gptj_load_model(inputs, file_format);
printf("\n---\nIdentified as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
ModelLoadResult lr = gptj_load_model(inputs, file_format);
if (lr == ModelLoadResult::RETRY_LOAD)
{
file_format = FileFormat::GPTJ2;
printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
lr = gptj_load_model(inputs, file_format);
}
if (lr == ModelLoadResult::RETRY_LOAD)
{
file_format = FileFormat::GPTJ3;
printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
lr = gptj_load_model(inputs, file_format);
}
if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
{
return false;
}
else
{
return true;
}
}
else
{
@ -45,7 +66,7 @@ extern "C"
generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
{
if (file_format == GPTJ1 || file_format == GPTJ2)
if (file_format == FileFormat::GPTJ1 || file_format == FileFormat::GPTJ2 || file_format==FileFormat::GPTJ3)
{
return gptj_generate(inputs, output);
}