lora : improve compat with mergekit-extract-lora (#11131)

* (wip) support mergekit-extracted lora

* support mergekit-extract-lora

* use lora->get_scale

* correct comment

* correct norm name & condition

* add some hints
Xuan Son Nguyen 2025-01-08 15:59:53 +01:00 committed by GitHub
parent c07d437bbd
commit 4d2b3d8804
4 changed files with 74 additions and 12 deletions

@@ -2545,6 +2545,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+
+        // apply lora for embedding tokens if needed
+        for (auto & it : lctx.lora_adapters) {
+            struct llama_lora_weight * lora = it.first->get_weight(tok_embd);
+            if (lora == nullptr) {
+                continue;
+            }
+            const float adapter_scale = it.second;
+            const float scale = lora->get_scale(it.first->alpha, adapter_scale);
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
+                ctx, lora->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lora->a, lctx.inp_tokens)
+            ), scale);
+            inpL = ggml_add(ctx, inpL, inpL_delta);
+        }
     } else {
         lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
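
Note on the block added above: for every input token, the graph builds the low-rank update delta = scale * lora_b * a_row, where a_row is the row of lora->a selected by the token id (via ggml_get_rows) and lora->b is kept non-transposed, as mergekit-extract-lora exports it; the result is added onto the base embedding row. The following is a minimal plain-C++ sketch of that arithmetic only, not of the ggml graph; the container types, shapes and the names lora_embd_delta / r / n_embd are illustrative assumptions.

#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal sketch, assuming row-major storage:
//   lora_a: n_vocab rows of length r (one low-rank row per token id)
//   lora_b: n_embd  rows of length r (non-transposed, as in the diff above)
// For a token id t, the embedding delta is scale * lora_b * lora_a[t].
static std::vector<float> lora_embd_delta(
        const std::vector<std::vector<float>> & lora_a,
        const std::vector<std::vector<float>> & lora_b,
        int32_t token,
        float scale) {
    const std::size_t r      = lora_a[token].size();
    const std::size_t n_embd = lora_b.size();
    std::vector<float> delta(n_embd, 0.0f);
    for (std::size_t i = 0; i < n_embd; ++i) {
        float acc = 0.0f;
        for (std::size_t k = 0; k < r; ++k) {
            acc += lora_b[i][k] * lora_a[token][k];
        }
        delta[i] = scale * acc; // added onto the base embedding row by the graph
    }
    return delta;
}

Selecting the rows of lora->a first keeps the extra work proportional to the number of tokens in the ubatch rather than to the vocabulary size.
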
@@ -2617,9 +2632,8 @@ static struct ggml_tensor * llm_build_lora_mm(
         if (lora == nullptr) {
             continue;
         }
-        const float alpha = it.first->alpha;
-        const float rank = (float) lora->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
+        const float adapter_scale = it.second;
+        const float scale = lora->get_scale(it.first->alpha, adapter_scale);
         struct ggml_tensor * ab_cur = ggml_mul_mat(
             ctx0, lora->b,
             ggml_mul_mat(ctx0, lora->a, cur)
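
Note on the replaced lines above: the old code computed the per-adapter scale inline as adapter_scale * alpha / rank (with rank taken from lora->b->ne[0]), falling back to plain adapter_scale when alpha is zero; the new llama_lora_weight::get_scale presumably centralizes that same formula so the embedding path and the matmul path cannot drift apart. A minimal sketch of the equivalent computation, with the rank passed explicitly only to keep the sketch self-contained:

// Sketch of the scaling the removed lines computed inline. The real
// llama_lora_weight::get_scale takes (alpha, adapter_scale) and reads the
// rank from its own lora_b tensor; rank is a parameter here by assumption.
static float lora_scale(float alpha, float adapter_scale, float rank) {
    // classic LoRA scaling: adapter_scale * alpha / r; fall back to
    // adapter_scale alone when the adapter stores no alpha
    return alpha != 0.0f ? adapter_scale * alpha / rank : adapter_scale;
}

This is the usual LoRA convention: dividing alpha by the rank keeps the effective magnitude of the update roughly independent of the rank chosen at extraction time.
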
@@ -3967,6 +3981,7 @@ struct llm_build_context {
// feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
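
Note on the trailing context: the FFN input is normalized by llm_build_norm with LLM_NORM_RMS and the per-layer ffn_norm weight. A minimal sketch of what RMS normalization does to one row, with eps and the variable names chosen for illustration only:

#include <cmath>
#include <cstddef>
#include <vector>

// Minimal RMS-norm sketch: divide each element by the root mean square of the
// row (plus a small eps) and scale element-wise by the norm weight w
// (ffn_norm in the hunk above).
static void rms_norm(std::vector<float> & x, const std::vector<float> & w, float eps = 1e-5f) {
    float ss = 0.0f;
    for (float v : x) {
        ss += v * v;
    }
    const float inv_rms = 1.0f / std::sqrt(ss / (float) x.size() + eps);
    for (std::size_t i = 0; i < x.size(); ++i) {
        x[i] = x[i] * inv_rms * w[i];
    }
}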