Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	CMakePresets.json
#	README.md
#	common/CMakeLists.txt
#	ggml/src/ggml-cann/ggml-cann.cpp
#	ggml/src/ggml-opencl/CMakeLists.txt
#	ggml/src/ggml-opencl/ggml-opencl.cpp
#	ggml/src/ggml-sycl/ggml-sycl.cpp
#	scripts/sync-ggml.last
#	tests/test-backend-ops.cpp
#	tools/run/CMakeLists.txt
This commit is contained in:
Concedo 2025-07-13 23:39:41 +08:00
commit 8cebec5128
41 changed files with 28682 additions and 366 deletions

View file

@ -45,17 +45,21 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_190M: return "190M";
case LLM_TYPE_220M: return "220M";
case LLM_TYPE_250M: return "250M";
case LLM_TYPE_256M: return "256M";
case LLM_TYPE_270M: return "270M";
case LLM_TYPE_335M: return "335M";
case LLM_TYPE_350M: return "350M";
case LLM_TYPE_410M: return "410M";
case LLM_TYPE_450M: return "450M";
case LLM_TYPE_475M: return "475M";
case LLM_TYPE_700M: return "700M";
case LLM_TYPE_770M: return "770M";
case LLM_TYPE_780M: return "780M";
case LLM_TYPE_0_3B: return "0.3B";
case LLM_TYPE_0_5B: return "0.5B";
case LLM_TYPE_0_6B: return "0.6B";
case LLM_TYPE_1B: return "1B";
case LLM_TYPE_1_2B: return "1.2B";
case LLM_TYPE_1_3B: return "1.3B";
case LLM_TYPE_1_4B: return "1.4B";
case LLM_TYPE_1_5B: return "1.5B";
@ -586,6 +590,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case 22: type = LLM_TYPE_1B; break;
case 26: type = LLM_TYPE_3B; break;
case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B
case 30: type = LLM_TYPE_256M; break; // smoldocling 256M
// granite uses a vocab with len 49152
case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break;
case 36: type = LLM_TYPE_8B; break; // granite
@ -1509,6 +1514,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
// Granite uses rope_finetuned as a switch for rope, so default to true
bool rope_finetuned = true;
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
hparams.rope_finetuned = rope_finetuned;
switch (hparams.n_layer) {
case 32: type = LLM_TYPE_3B; break;
case 40: type = LLM_TYPE_3B; break;
@ -1516,6 +1526,40 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
// For Granite MoE Shared
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
} break;
case LLM_ARCH_GRANITE_HYBRID:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false);
ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false);
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false);
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false);
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Granite uses rope_finetuned as a switch for rope, so default to true
bool rope_finetuned = true;
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
hparams.rope_finetuned = rope_finetuned;
// A layer is recurrent IFF the n_head_kv value is set to 0
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0;
}
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
// TODO: Add llm type label (not sure this is useful)
default: type = LLM_TYPE_UNKNOWN;
}
// For Granite MoE Shared
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false);
} break;
@ -1627,6 +1671,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_LFM2:
{
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}
switch (hparams.n_embd) {
case 1024: type = LLM_TYPE_350M; break;
case 1536: type = LLM_TYPE_700M; break;
case 2048: type = LLM_TYPE_1_2B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}
@ -3458,6 +3516,99 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
}
} break;
case LLM_ARCH_GRANITE_HYBRID:
{
// mamba2 Mixer SSM params
// NOTE: int64_t for tensor dimensions
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_ssm_head = hparams.ssm_dt_rank;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
// only an expansion factor of 2 is supported for now
GGML_ASSERT(2 * n_embd == d_inner);
// embeddings
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
{
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed, duplicated to allow offloading
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
if (hparams.is_recurrent(i)) {
// ssm layers
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED);
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
// no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
// out_proj
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
} else {
// attention layers (with optional bias)
const int64_t n_head_i = hparams.n_head(i);
const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
}
// feed forward (w/ optional biases)
if (n_expert > 0) {
// MoE FFN
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
// For Granite MoE Shared
if (hparams.n_ff_shexp > 0) {
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
}
} else {
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
}
}
} break;
case LLM_ARCH_XVERSE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -4868,6 +5019,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_LFM2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
// ffn is same for transformer and conv layers
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
// for operator_norm
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
if (!hparams.is_recurrent(i)) {
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
} else {
layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0);
layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0);
layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0);
}
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -5121,7 +5305,8 @@ void llama_model::print_info() const {
if (arch == LLM_ARCH_MAMBA ||
arch == LLM_ARCH_MAMBA2 ||
arch == LLM_ARCH_JAMBA ||
arch == LLM_ARCH_FALCON_H1) {
arch == LLM_ARCH_FALCON_H1 ||
arch == LLM_ARCH_GRANITE_HYBRID) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@ -5176,7 +5361,8 @@ void llama_model::print_info() const {
if (arch == LLM_ARCH_MINICPM ||
arch == LLM_ARCH_GRANITE ||
arch == LLM_ARCH_GRANITE_MOE) {
arch == LLM_ARCH_GRANITE_MOE ||
arch == LLM_ARCH_GRANITE_HYBRID) {
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
@ -13895,13 +14081,11 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
}
};
struct llm_build_granite : public llm_graph_context {
llm_build_granite(
const llama_model & model,
const llm_graph_params & params,
ggml_cgraph * gf,
const bool use_rope = true)
ggml_cgraph * gf)
: llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@ -13916,14 +14100,12 @@ struct llm_build_granite : public llm_graph_context {
// inp_pos - built only if rope enabled
ggml_tensor * inp_pos = nullptr;
if (use_rope) {
if (hparams.rope_finetuned) {
inp_pos = build_inp_pos();
}
auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
@ -13936,128 +14118,17 @@ struct llm_build_granite : public llm_graph_context {
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
if (use_rope) {
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
cur = build_attention_layer(
gf, cur, inp_pos, inp_attn,
model, n_embd_head, il);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// For Granite architectures - scale residual
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(moe_out, "ffn_moe_out", il);
// For Granite MoE Shared
if (hparams.n_ff_shexp > 0) {
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
} else {
cur = moe_out;
}
}
// For Granite architectures - scale residual
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// ffn
cur = build_layer_ffn(cur, inpSA, model, il);
// input for next layer
inpL = cur;
@ -14082,6 +14153,370 @@ struct llm_build_granite : public llm_graph_context {
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * build_attention_layer(
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
const bool use_rope = hparams.rope_finetuned;
if (use_rope) {
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
return cur;
}
ggml_tensor * build_layer_ffn(
ggml_tensor * cur,
ggml_tensor * inpSA,
const llama_model & model,
const int il) {
// For Granite architectures - scale residual
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(moe_out, "ffn_moe_out", il);
// For Granite MoE Shared
if (hparams.n_ff_shexp > 0) {
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
} else {
cur = moe_out;
}
}
// For Granite architectures - scale residual
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
return cur;
}
};
struct llm_build_granite_hybrid : public llm_graph_context_mamba {
llm_build_granite_hybrid(
const llama_model & model,
const llm_graph_params & params,
ggml_cgraph * gf) :
llm_graph_context_mamba(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
auto * inp = build_inp_mem_hybrid();
ggml_tensor * inp_out_ids = build_inp_out_ids();
// Positional embeddings populated if rope enabled
ggml_tensor * inp_pos = nullptr;
if (hparams.rope_finetuned) {
inp_pos = build_inp_pos();
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
if (hparams.is_recurrent(il)) {
// ssm layer //
cur = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il);
} else {
// attention layer //
cur = build_attention_layer(
gf, cur, inp_pos, inp->get_attn(), model,
n_embd_head, il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// ffn
cur = build_layer_ffn(cur, inpSA, model, il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
// For Granite architectures - scale logits
if (hparams.f_logit_scale) {
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
}
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * build_attention_layer(
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
const llama_model & model,
const int64_t n_embd_head,
const int il) {
// compute Q and K and (optionally) RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
const bool use_rope = hparams.rope_finetuned;
if (use_rope) {
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
return cur;
}
ggml_tensor * build_layer_ffn(
ggml_tensor * cur,
ggml_tensor * inpSA,
const llama_model & model,
const int il) {
// For Granite architectures - scale residual
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network (non-MoE)
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// MoE branch
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
false, 0.0,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
il);
cb(moe_out, "ffn_moe_out", il);
// For Granite MoE Shared
if (hparams.n_ff_shexp > 0) {
ggml_tensor * ffn_shexp = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_shexp, "ffn_shexp", il);
cur = ggml_add(ctx0, moe_out, ffn_shexp);
cb(cur, "ffn_out", il);
} else {
cur = moe_out;
}
}
// For Granite architectures - scale residual
if (hparams.f_residual_scale) {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
return cur;
}
};
// ref: https://github.com/facebookresearch/chameleon
@ -15574,6 +16009,163 @@ struct llm_build_smollm3 : public llm_graph_context {
}
};
struct llm_build_lfm2 : public llm_graph_context {
const llama_model & model;
llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
ggml_tensor * cur = build_inp_embd(model.tok_embd);
cb(cur, "model.embed_tokens", -1);
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_hybrid = build_inp_mem_hybrid();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
auto * prev_cur = cur;
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.operator_norm", il);
cur = hparams.is_recurrent(il) ?
build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) :
build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ;
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
prev_cur = ggml_get_rows(ctx0, prev_cur, inp_out_ids);
}
cur = ggml_add(ctx0, prev_cur, cur);
cur = ggml_add(ctx0, cur, build_feed_forward(cur, il));
}
cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "model.embedding_norm", -1);
res->t_embd = cur;
// lm_head is tied with embeddings
cur = build_lora_mm(model.tok_embd, cur);
cb(cur, "lm_head", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * build_feed_forward(ggml_tensor * cur,
int il) const {
cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "model.layers.{}.ffn_norm", il);
GGML_ASSERT(!model.layers[il].ffn_up_b);
GGML_ASSERT(!model.layers[il].ffn_gate_b);
GGML_ASSERT(!model.layers[il].ffn_down_b);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "model.layers.{}.feed_forward.w2", il);
return cur;
}
ggml_tensor * build_attn_block(ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * inp_pos,
llm_graph_input_attn_kv_unified * inp_attn,
int il) const {
GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
auto const n_embd_head = hparams.n_embd_head_v;
auto const n_head_kv = hparams.n_head_kv(il);
auto * q = build_lora_mm(model.layers[il].wq, cur);
cb(q, "model.layers.{}.self_attn.q_proj", il);
auto * k = build_lora_mm(model.layers[il].wk, cur);
cb(k, "model.layers.{}.self_attn.k_proj", il);
auto * v = build_lora_mm(model.layers[il].wv, cur);
cb(v, "model.layers.{}.self_attn.v_proj", il);
q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);
// qk norm
q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(q, "model.layers.{}.self_attn.q_layernorm", il);
k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(k, "model.layers.{}.self_attn.k_layernorm", il);
// RoPE
q = ggml_rope_ext(
ctx0, q, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
k = ggml_rope_ext(
ctx0, k, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL,
q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
cb(cur, "model.layers.{}.self_attn.out_proj", il);
return cur;
}
ggml_tensor * build_shortconv_block(ggml_cgraph * gf,
ggml_tensor * cur,
llm_graph_input_rs * inp_recr,
int il) {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
cb(bcx, "model.layers.{}.conv.in_proj", il);
constexpr auto n_chunks = 3;
GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
auto const chunk_size = bcx->ne[0] / n_chunks;
auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx));
auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx));
auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx));
auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
// read conv state directly, with build_rs generation is slower
ggml_tensor * conv_state = mctx_cur->get_r_l(il);
const int64_t n_seqs = ubatch.n_seqs;
ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs);
bx = ggml_concat(ctx0, conv, bx, 0);
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
// write conv state
ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state));
auto * conv_kernel = model.layers[il].shortconv.conv;
GGML_ASSERT(hparams.n_shortconv_l_cache > 0);
// construct ssm_conv op
ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
cb(conv_out, "model.layers.{}.conv.conv", il);
auto * y = ggml_mul(ctx0, c, conv_out);
y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
cb(y, "model.layers.{}.conv.out_proj", il);
return y;
}
};
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;
@ -15932,6 +16524,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_granite>(*this, params, gf);
} break;
case LLM_ARCH_GRANITE_HYBRID:
{
llm = std::make_unique<llm_build_granite_hybrid>(*this, params, gf);
} break;
case LLM_ARCH_CHAMELEON:
{
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
@ -15972,6 +16568,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
} break;
case LLM_ARCH_LFM2:
{
llm = std::make_unique<llm_build_lfm2>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@ -16121,6 +16721,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GLM4:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_BAILINGMOE:
case LLM_ARCH_NEO_BERT:
@ -16164,6 +16765,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_MINICPM3:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_LFM2:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL: