partially working, but the blas matmul is broken

This commit is contained in:
Concedo 2023-05-13 11:35:38 +08:00
parent b335f73a60
commit 05cf5f7d6e
8 changed files with 53 additions and 21 deletions

View file

@ -225,8 +225,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("System Info: %s\n", llama_print_system_info());
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
//newer format has bit unshuffling
SetQuantsUnshuffled(file_format== FileFormat::GGJT_2);
llama_ctx_params = llama_context_default_params();
llama_ctx_params.n_ctx = inputs.max_context_length;
llama_ctx_params.n_parts = -1;
@ -243,7 +246,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, modelname.c_str());
return ModelLoadResult::FAIL;
}
if (file_format < FileFormat::GGJT)
if (file_format < FileFormat::GGJT_2)
{
printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
}
@ -484,7 +487,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// tokenize the prompt
std::vector<int> embd_inp;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
params.prompt.insert(0, 1, ' ');
if (file_format == FileFormat::GGML)
@ -543,7 +546,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
{
//for non llama, limit to 256
int bbs = blasbatchsize;
if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT)
if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT && file_format != FileFormat::GGJT_2)
{
bbs = (blasbatchsize > 256 ? 256 : blasbatchsize);
}
@ -573,7 +576,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
double time1 = 0, time2 = 0;
int32_t n_vocab = 0;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
n_vocab = llama_n_vocab(llama_ctx_v1);
}
@ -624,7 +627,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
if(debugmode)
{
printf("\n[Debug: Dump Input Tokens]\n");
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
for (auto id : embd_inp)
{
@ -661,7 +664,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
bool evalres = false;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
evalres = (llama_eval(llama_ctx_v1, embd.data(), embdsize, n_past, params.n_threads)==0);
}
@ -722,7 +725,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
printf("\n");
}
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
auto logits = llama_get_logits(llama_ctx_v1);
@ -772,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// decrement remaining sampling budget
--remaining_tokens;
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
{
concat_output += llama_token_to_str(llama_ctx_v1, id);
if(unbanTokens && id==llama_token_eos())