added version label, improved file type checks

This commit is contained in:
Concedo 2023-04-10 01:03:09 +08:00
parent 1543c700d8
commit 18a154715e
5 changed files with 69 additions and 4 deletions

View file

@ -86,10 +86,32 @@ void print_tok_vec(std::vector<float> &embd)
if(vocabsiz==50400) //know GPT-J vocab size
{
fileformat = FileFormat::GPTJ_1;
uint32_t temp;
fin.read((char *)&temp, sizeof(temp)); //ctx
fin.read((char *)&temp, sizeof(temp)); //n_embd
fin.read((char *)&temp, sizeof(temp)); //n_head
fin.read((char *)&temp, sizeof(temp)); //n_layer
fin.read((char *)&temp, sizeof(temp)); //n_rot
fin.read((char *)&temp, sizeof(temp)); //f16
if(temp!=0 && temp!=1)
{
fileformat = FileFormat::GPTJ_3; //quantized format cannot be legacy type
}
}
if(vocabsiz==50257)
{
fileformat = FileFormat::GPT2_1;
uint32_t temp;
fin.read((char *)&temp, sizeof(temp)); //ctx
fin.read((char *)&temp, sizeof(temp)); //n_embd
fin.read((char *)&temp, sizeof(temp)); //n_head
fin.read((char *)&temp, sizeof(temp)); //n_layer
fin.read((char *)&temp, sizeof(temp)); //f16
if(temp!=0 && temp!=1)
{
fileformat = FileFormat::GPT2_2; //quantized format cannot be legacy type
}
}
}
else if(magic == 0x67676d66) //v2 format ggmf