mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
model : support LiquidAI LFM2 hybrid family (#14620)
**Important** LFM2 was [merged ](https://github.com/huggingface/transformers/pull/39340)into transformers, but has not yet been released. To convert into gguf, install transformers from source ```shell pip install "transformers @ git+https://github.com/huggingface/transformers.git@main" ```
This commit is contained in:
parent
756aa1020a
commit
f5e96b368f
14 changed files with 373 additions and 3 deletions
|
@ -43,15 +43,18 @@ const char * llm_type_name(llm_type type) {
|
|||
case LLM_TYPE_256M: return "256M";
|
||||
case LLM_TYPE_270M: return "270M";
|
||||
case LLM_TYPE_335M: return "335M";
|
||||
case LLM_TYPE_350M: return "350M";
|
||||
case LLM_TYPE_410M: return "410M";
|
||||
case LLM_TYPE_450M: return "450M";
|
||||
case LLM_TYPE_475M: return "475M";
|
||||
case LLM_TYPE_700M: return "700M";
|
||||
case LLM_TYPE_770M: return "770M";
|
||||
case LLM_TYPE_780M: return "780M";
|
||||
case LLM_TYPE_0_3B: return "0.3B";
|
||||
case LLM_TYPE_0_5B: return "0.5B";
|
||||
case LLM_TYPE_0_6B: return "0.6B";
|
||||
case LLM_TYPE_1B: return "1B";
|
||||
case LLM_TYPE_1_2B: return "1.2B";
|
||||
case LLM_TYPE_1_3B: return "1.3B";
|
||||
case LLM_TYPE_1_4B: return "1.4B";
|
||||
case LLM_TYPE_1_5B: return "1.5B";
|
||||
|
@ -1663,6 +1666,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
{
|
||||
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
|
||||
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
|
||||
}
|
||||
switch (hparams.n_embd) {
|
||||
case 1024: type = LLM_TYPE_350M; break;
|
||||
case 1536: type = LLM_TYPE_700M; break;
|
||||
case 2048: type = LLM_TYPE_1_2B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
@ -4906,6 +4923,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
// ffn is same for transformer and conv layers
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
|
||||
// for operator_norm
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
if (!hparams.is_recurrent(i)) {
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0);
|
||||
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
} else {
|
||||
layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0);
|
||||
layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0);
|
||||
layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
@ -15859,6 +15909,163 @@ struct llm_build_smollm3 : public llm_graph_context {
|
|||
}
|
||||
};
|
||||
|
||||
struct llm_build_lfm2 : public llm_graph_context {
|
||||
const llama_model & model;
|
||||
|
||||
llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
|
||||
|
||||
ggml_tensor * cur = build_inp_embd(model.tok_embd);
|
||||
cb(cur, "model.embed_tokens", -1);
|
||||
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
auto * inp_hybrid = build_inp_mem_hybrid();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
auto * prev_cur = cur;
|
||||
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "model.layers.{}.operator_norm", il);
|
||||
|
||||
cur = hparams.is_recurrent(il) ?
|
||||
build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) :
|
||||
build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ;
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
prev_cur = ggml_get_rows(ctx0, prev_cur, inp_out_ids);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, prev_cur, cur);
|
||||
cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
|
||||
}
|
||||
|
||||
cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "model.embedding_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head is tied with embeddings
|
||||
cur = build_lora_mm(model.tok_embd, cur);
|
||||
cb(cur, "lm_head", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
ggml_tensor * build_feed_forward(ggml_tensor * cur,
|
||||
int il) const {
|
||||
cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(cur, "model.layers.{}.ffn_norm", il);
|
||||
|
||||
GGML_ASSERT(!model.layers[il].ffn_up_b);
|
||||
GGML_ASSERT(!model.layers[il].ffn_gate_b);
|
||||
GGML_ASSERT(!model.layers[il].ffn_down_b);
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "model.layers.{}.feed_forward.w2", il);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * build_attn_block(ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * inp_pos,
|
||||
llm_graph_input_attn_kv_unified * inp_attn,
|
||||
int il) const {
|
||||
GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
|
||||
auto const n_embd_head = hparams.n_embd_head_v;
|
||||
auto const n_head_kv = hparams.n_head_kv(il);
|
||||
|
||||
auto * q = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(q, "model.layers.{}.self_attn.q_proj", il);
|
||||
auto * k = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(k, "model.layers.{}.self_attn.k_proj", il);
|
||||
auto * v = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(v, "model.layers.{}.self_attn.v_proj", il);
|
||||
|
||||
q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
|
||||
k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
|
||||
v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// qk norm
|
||||
q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(q, "model.layers.{}.self_attn.q_layernorm", il);
|
||||
k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(k, "model.layers.{}.self_attn.k_layernorm", il);
|
||||
|
||||
// RoPE
|
||||
q = ggml_rope_ext(
|
||||
ctx0, q, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
k = ggml_rope_ext(
|
||||
ctx0, k, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL,
|
||||
q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
|
||||
cb(cur, "model.layers.{}.self_attn.out_proj", il);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * build_shortconv_block(ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
llm_graph_input_rs * inp_recr,
|
||||
int il) {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
|
||||
|
||||
auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
|
||||
cb(bcx, "model.layers.{}.conv.in_proj", il);
|
||||
|
||||
constexpr auto n_chunks = 3;
|
||||
GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
|
||||
auto const chunk_size = bcx->ne[0] / n_chunks;
|
||||
auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
|
||||
auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
|
||||
auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
|
||||
|
||||
auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
|
||||
|
||||
// read conv state directly, with build_rs generation is slower
|
||||
ggml_tensor * conv_state = mctx_cur->get_r_l(il);
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
|
||||
conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
|
||||
|
||||
bx = ggml_concat(ctx0, conv, bx, 0);
|
||||
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
|
||||
|
||||
auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
|
||||
GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
|
||||
|
||||
// write conv state
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
|
||||
|
||||
auto * conv_kernel = model.layers[il].shortconv.conv;
|
||||
GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
|
||||
|
||||
// construct ssm_conv op
|
||||
ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
|
||||
cb(conv_out, "model.layers.{}.conv.conv", il);
|
||||
|
||||
auto * y = ggml_mul(ctx0, c, conv_out);
|
||||
|
||||
y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
|
||||
cb(y, "model.layers.{}.conv.out_proj", il);
|
||||
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
||||
llama_memory_i * res;
|
||||
|
||||
|
@ -16261,6 +16468,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
|||
{
|
||||
llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
{
|
||||
llm = std::make_unique<llm_build_lfm2>(*this, params, gf);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
@ -16454,6 +16665,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_MINICPM3:
|
||||
case LLM_ARCH_DOTS1:
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
case LLM_ARCH_LFM2:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue