Added LoRA adapter support

This commit is contained in:
Concedo 2023-04-22 12:29:38 +08:00
parent c454f8b848
commit 6e908c1792
4 changed files with 45 additions and 19 deletions

View file

@ -76,7 +76,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
{
llama_ctx_params = llama_context_default_params();
llama_ctx_params.n_ctx = inputs.max_context_length;
llama_ctx_params.n_parts = -1;//inputs.n_parts_overwrite;
llama_ctx_params.n_parts = -1;
llama_ctx_params.seed = -1;
llama_ctx_params.f16_kv = inputs.f16_kv;
llama_ctx_params.logits_all = false;
@ -95,6 +95,21 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
}
if (lora_filename != "")
{
printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());
int err = llama_apply_lora_from_file(llama_ctx_v1,
lora_filename.c_str(),
NULL,
n_threads);
if (err != 0)
{
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return ModelLoadResult::FAIL;
}
}
//determine mem per token
const std::vector<int> tmp = {0, 1, 2, 3};
llama_eval(llama_ctx_v1, tmp.data(), tmp.size(), 0, params.n_threads);