Mirror of https://github.com/LostRuins/koboldcpp.git
partially working, but the blas matmul is broken
commit 05cf5f7d6e (parent b335f73a60)
8 changed files with 53 additions and 21 deletions
@@ -225,8 +225,11 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     printf("System Info: %s\n", llama_print_system_info());
 
-    if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+    if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
     {
+        //newer format has bit unshuffling
+        SetQuantsUnshuffled(file_format== FileFormat::GGJT_2);
+
         llama_ctx_params = llama_context_default_params();
         llama_ctx_params.n_ctx = inputs.max_context_length;
         llama_ctx_params.n_parts = -1;
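
Note: this hunk gates the dequantization mode on the detected file version, so only GGJT_2 files take the unshuffled path. A minimal sketch of that gating pattern; the enum values and the global flag are illustrative stand-ins, not the project's actual definitions:

#include <cstdio>

// Hypothetical stand-ins for koboldcpp's FileFormat enum and quant toggle.
enum class FileFormat { GGML = 1, GGHF = 2, GGJT = 3, GGJT_2 = 4 };

static bool quants_unshuffled = false; // illustrative global, not the real one

void SetQuantsUnshuffled(bool unshuffled)
{
    quants_unshuffled = unshuffled; // newer GGJT_2 files store quants unshuffled
}

int main()
{
    FileFormat file_format = FileFormat::GGJT_2; // pretend this was detected
    // Only the newest format needs the unshuffled dequant path.
    SetQuantsUnshuffled(file_format == FileFormat::GGJT_2);
    printf("unshuffled: %d\n", quants_unshuffled);
}
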
@@ -243,7 +246,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, modelname.c_str());
         return ModelLoadResult::FAIL;
     }
-    if (file_format < FileFormat::GGJT)
+    if (file_format < FileFormat::GGJT_2)
     {
         printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
     }
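
Note: the outdated-format warning relies on the FileFormat enumerators being declared in version order, so `<` reads as "older than". A small sketch of that idea, assuming invented enum values rather than the project's real numbering:

#include <cstdio>

// Assumed version ordering; the real project defines its own values.
enum class FileFormat : int { GGML = 1, GGHF = 2, GGJT = 3, GGJT_2 = 4 };

int main()
{
    FileFormat file_format = FileFormat::GGJT; // plain GGJT is now "outdated"
    if (file_format < FileFormat::GGJT_2)      // ordered comparison of versions
    {
        printf("Warning: format ver %d is outdated, please reconvert.\n",
               static_cast<int>(file_format));
    }
}
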
@@ -484,7 +487,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     // tokenize the prompt
     std::vector<int> embd_inp;
 
-    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
     {
         params.prompt.insert(0, 1, ' ');
         if (file_format == FileFormat::GGML)
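
Note: the `params.prompt.insert(0, 1, ' ')` call prepends a single space before tokenizing, matching the SentencePiece-style convention the llama.cpp tokenizer expects. A quick illustration of that string operation:

#include <cassert>
#include <string>

int main()
{
    std::string prompt = "Hello world";
    // insert(pos, count, ch): one space at position 0,
    // so the tokenizer sees " Hello world".
    prompt.insert(0, 1, ' ');
    assert(prompt == " Hello world");
}
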
@@ -543,7 +546,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     {
         //for non llama, limit to 256
         int bbs = blasbatchsize;
-        if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT)
+        if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT && file_format != FileFormat::GGJT_2)
         {
             bbs = (blasbatchsize > 256 ? 256 : blasbatchsize);
         }
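
Note: this hunk caps the BLAS batch size at 256 for non-llama formats; the ternary is equivalent to std::min. A sketch of the clamp, with the is_llama_format flag invented for illustration:

#include <algorithm>
#include <cstdio>

int main()
{
    int blasbatchsize = 512;       // user-requested batch size
    bool is_llama_format = false;  // assumed flag derived from file_format
    int bbs = blasbatchsize;
    if (!is_llama_format)
    {
        // same effect as (blasbatchsize > 256 ? 256 : blasbatchsize)
        bbs = std::min(blasbatchsize, 256);
    }
    printf("effective batch size: %d\n", bbs); // prints 256
}
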
@@ -573,7 +576,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     double time1 = 0, time2 = 0;
     int32_t n_vocab = 0;
 
-    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+    if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
     {
         n_vocab = llama_n_vocab(llama_ctx_v1);
     }
@@ -624,7 +627,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     if(debugmode)
     {
         printf("\n[Debug: Dump Input Tokens]\n");
-        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
        {
            for (auto id : embd_inp)
            {
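
Note: the debug branch iterates the tokenized prompt and prints each id; for llama-family formats each id can also be detokenized for readability. A hedged sketch of such a dump loop, using a stub in place of llama_token_to_str:

#include <cstdio>
#include <string>
#include <vector>

// Stub detokenizer for illustration; the real code calls
// llama_token_to_str(ctx, id) from the llama.cpp API.
static std::string token_to_str(int id)
{
    return "<tok" + std::to_string(id) + ">";
}

int main()
{
    std::vector<int> embd_inp = {1, 15043, 3186}; // pretend tokenized prompt
    printf("\n[Debug: Dump Input Tokens]\n");
    for (auto id : embd_inp)
    {
        printf("%d -> '%s'\n", id, token_to_str(id).c_str());
    }
}
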
@@ -661,7 +664,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
         bool evalres = false;
 
-        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
         {
             evalres = (llama_eval(llama_ctx_v1, embd.data(), embdsize, n_past, params.n_threads)==0);
         }
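
Note: llama_eval returns 0 on success, so the code folds the status into a bool that later gates error handling. A minimal sketch of that pattern with a stand-in eval function (the real call and its arguments are shown in the hunk above):

#include <cstdio>

// Stand-in for llama_eval: returns 0 on success, nonzero on failure.
int fake_eval(const int* tokens, int n_tokens, int n_past, int n_threads)
{
    return (tokens != nullptr && n_tokens > 0) ? 0 : 1;
}

int main()
{
    int embd[] = {1, 2, 3};
    bool evalres = (fake_eval(embd, 3, /*n_past=*/0, /*n_threads=*/4) == 0);
    if (!evalres)
    {
        fprintf(stderr, "Failed to eval batch\n");
        return 1;
    }
    printf("eval ok\n");
}
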
@@ -722,7 +725,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             printf("\n");
         }
 
-        if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+        if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
         {
             auto logits = llama_get_logits(llama_ctx_v1);
 
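
Note: llama_get_logits returns a pointer to the last-evaluated token's logits (n_vocab floats), which the sampler then reads. A minimal greedy-pick sketch over such a buffer; the vocabulary size and values here are invented:

#include <cstdio>
#include <vector>

int main()
{
    // Pretend logits for a tiny 5-token vocabulary; the real buffer
    // comes from llama_get_logits(ctx) and holds n_vocab floats.
    std::vector<float> logits = {0.1f, 2.3f, -1.0f, 0.7f, 1.9f};
    int best = 0;
    for (int i = 1; i < (int)logits.size(); ++i)
    {
        if (logits[i] > logits[best]) { best = i; }
    }
    printf("greedy token id: %d\n", best); // prints 1
}
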
@@ -772,7 +775,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         // decrement remaining sampling budget
         --remaining_tokens;
 
-        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
+        if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2)
         {
             concat_output += llama_token_to_str(llama_ctx_v1, id);
             if(unbanTokens && id==llama_token_eos())
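
Note: after sampling, the llama branch appends the detokenized piece to concat_output and, when unbanTokens is set, treats an EOS token as a stop signal rather than suppressing it. A rough sketch of that accumulation loop; the EOS id and detokenizer stub are assumptions standing in for llama_token_eos and llama_token_to_str:

#include <cstdio>
#include <string>
#include <vector>

const int EOS_ID = 2; // assumed EOS token id for illustration

static std::string tok_str(int id) // stand-in for llama_token_to_str
{
    if (id == EOS_ID) return "";
    return " word" + std::to_string(id);
}

int main()
{
    std::vector<int> sampled = {5, 9, EOS_ID};
    bool unbanTokens = true;
    std::string concat_output;
    for (int id : sampled)
    {
        concat_output += tok_str(id);
        if (unbanTokens && id == EOS_ID)
        {
            break; // EOS ends generation instead of being banned from sampling
        }
    }
    printf("output:%s\n", concat_output.c_str());
}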