ggml : add NVFP4 quantization type support (#19769)
* WIP: add NVFP4 quantization support
* tests
* improve NVFP4 dot product implementation performance and fix bad super call
* typo
* Use nvfp4 kvalues
* vulkan : fix NVFP4 shader compilation by including kvalues_mxfp4 lookup table
* vulkan and perf fixes
* wip
* Fix metal
* fix vulkan
* Rename threshold & fix wrong scale
* Fix MOE
* Shelf backend implementations (CUDA, Metal, Vulkan, arch-specific SIMD)

  Remove NVFP4 support from GPU backends and architecture-specific optimized dot products. These should be added in separate PRs so backend specialists can review them independently.

  Reverted files:
  - ggml-cuda: common.cuh, convert.cu, mmq.cu/cuh, mmvq.cu, vecdotq.cuh, quantize.cu/cuh, mma.cuh, ggml-cuda.cu, fattn-tile.cuh
  - ggml-metal: ggml-metal.metal, ggml-metal-device.cpp, ggml-metal-impl.h, ggml-metal-ops.cpp
  - ggml-vulkan: ggml-vulkan.cpp, all vulkan-shaders/*
  - ggml-cpu arch: arm/quants.c, x86/quants.c, powerpc/quants.c, s390/quants.c

  Core NVFP4 support (type definition, CPU fallback dot product, quantization, dequantization, conversion) is retained.
* Fix arch-fallback.h: add NVFP4 generic fallback for all platforms

  After shelving the backend-specific SIMD implementations, the generic CPU dot product needs to be aliased on ARM, x86, PowerPC, and s390 platforms that previously relied on arch-specific versions.
* quantize: add NVFP4 as a quantization type option
* Fix ggml_fp32_to_ue4m3: handle subnormal values (see the encode/decode sketch after this log)

  Previously, values with ue4m3_exp <= 0 were clamped to 0, causing all small scales to underflow. This made NVFP4 quantization via llama-quantize produce garbage (PPL = 5.8M), since typical transformer weights have amax/6.0 in the range 0.001-0.01, which falls in the UE4M3 subnormal range. Now subnormals are properly encoded as man * 2^-9 (exp = 0, man = 1..7), matching the decode path in ggml_ue4m3_to_fp32.

  Result: NVFP4 requantization now produces PPL = 15.25 (vs F16 = 14.33), comparable to Q4_1 (PPL = 15.81) at slightly lower BPW (4.70 vs 5.15).
* Restore ARM NEON NVFP4 dot product implementation

  Restores the optimized ggml_vec_dot_nvfp4_q8_0 for ARM NEON using vqtbl1q_s8 lookups and ggml_vdotq_s32 dot products.

  tg128 performance: 4.37 t/s (generic) -> 13.66 t/s (NEON) = 3.1x speedup
* Optimize ARM NEON NVFP4 dot product: LUT + vpaddq + vfmaq (see the block-dot sketch after this log)

  - Add ue4m3_scale_lut[128] to ggml-common.h, replacing the branch-heavy ggml_ue4m3_to_fp32() in the hot loop
  - Use vpaddq_s32 for pairwise int32 reduction instead of vaddvq_s32
  - Accumulate with vfmaq_f32 into float32x4_t vector accumulators

  tg128: 8.1 -> 31.0 t/s (3.8x speedup, 77% of Q4_1 speed)
* ARM NEON NVFP4: rearrange q8 to match nibble layout

  Alternative approach: rearrange the q8 data to match the NVFP4 lo/hi nibble layout instead of rearranging the looked-up NVFP4 values. Eliminates the vcombine_s8(vget_low, vget_low) shuffles. Performance is equivalent (~18.5 t/s): the bottleneck is the 2x block overhead from QK=16 vs QK=32, not the shuffle instructions.
* CPU-only backend 64 super-block layout
* cleanup
* Remove unused LUT
* int
* exclude NVFP4 from unsupported ops in metal build
* remove quantization for now
* store scales as native UE4M3, preserve original model bits when possible
* Update convert_hf_to_gguf.py

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* correct comment
* format
* reduce duplication and cleanup
* Address comments
* move detection to prepare_tensors
* Use math instead of const
* Move
* fix comment
* Shelf quantize tests
* Rebase and move check
* cleanup
* lint
* Update gguf-py/gguf/scripts/gguf_convert_endian.py

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Use fallback quant config
* Simplify

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* organize
* Refactor
* Update convert_hf_to_gguf.py

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update convert_hf_to_gguf.py

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* Update convert_hf_to_gguf.py

  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* add quantize_nvfp4 (required for test_quants.py)
* add quantize_nvfp4 (required for test_quants.py)
* add quantize_nvfp4 (required for test_quants.py)
* fix return type

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
parent 3ca19b0e9f
commit 5eae9cb1d9
31 changed files with 710 additions and 51 deletions
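For context on the ggml_fp32_to_ue4m3 fix above, here is a minimal sketch of the UE4M3 decode and encode paths, assuming unsigned E4M3 with bias 7, subnormals valued man * 2^-9, and simple saturation at the top of the range. It illustrates the behaviour described in the log; it is not the repository's actual ggml_ue4m3_to_fp32 / ggml_fp32_to_ue4m3.

    #include <math.h>
    #include <stdint.h>

    // decode a 7-bit UE4M3 code: exp == 0 is subnormal (man * 2^-9),
    // otherwise (1 + man/8) * 2^(exp-7)
    static float ue4m3_to_fp32_sketch(uint8_t v) {
        const int exp = (v >> 3) & 0x0f;
        const int man =  v       & 0x07;
        if (exp == 0) {
            return (float) man * 0x1p-9f;                  // subnormal
        }
        return ldexpf(1.0f + (float) man / 8.0f, exp - 7); // normal
    }

    // encode: the fix described above is the exp <= 0 branch, which used to
    // clamp to 0 and now emits a subnormal code instead
    static uint8_t fp32_to_ue4m3_sketch(float x) {
        if (!(x > 0.0f)) {
            return 0;
        }
        int e = 0;
        (void) frexpf(x, &e);                // x in [2^(e-1), 2^e)
        int exp = e + 6;                     // biased exponent: (e - 1) + 7
        if (exp <= 0) {
            // subnormal range: round to the nearest multiple of 2^-9
            long man = lroundf(ldexpf(x, 9));
            if (man <= 7) {
                return (uint8_t) man;        // exp = 0, man in 0..7
            }
            return (uint8_t) (1 << 3);       // rounded up to smallest normal, 2^-6
        }
        long man = lroundf(ldexpf(x, 4 - e)) - 8; // round (x / 2^(e-1) - 1) * 8
        if (man == 8) { man = 0; exp++; }         // mantissa rounded up a binade
        if (exp > 15) { exp = 15; man = 7; }      // saturate (assumption: no NaN code)
        return (uint8_t) ((exp << 3) | man);
    }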
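And a scalar sketch of the per-block dot-product structure that the ARM NEON work above optimizes. The block layout (QK = 16 FP4 codes packed as lo/hi nibbles plus one UE4M3 scale byte) and the kvalues_nvfp4 table are assumptions modeled on the log and on ggml's mxfp4 conventions; the names are illustrative, not the repo's structs.

    // hypothetical 128-entry LUT in the spirit of ue4m3_scale_lut from the log,
    // built once with the decode sketch above
    static float ue4m3_scale_lut[128];
    static void init_ue4m3_scale_lut(void) {
        for (int i = 0; i < 128; ++i) {
            ue4m3_scale_lut[i] = ue4m3_to_fp32_sketch((uint8_t) i);
        }
    }

    // one NVFP4 block (QK = 16) against 16 int8 values of a Q8 block:
    // integer dot first, then a single scale multiply per block
    static float nvfp4_block_dot_sketch(const uint8_t packed[8],  // 16 FP4 codes, lo/hi nibbles
                                        uint8_t scale,            // UE4M3 block scale
                                        const int8_t q8[16], float q8_scale) {
        // E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} stored x2 so they stay integral
        static const int8_t kvalues_nvfp4[16] = {
            0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
        };
        int32_t sumi = 0;
        for (int j = 0; j < 8; ++j) {
            sumi += kvalues_nvfp4[packed[j] & 0x0f] * q8[j];      // low nibbles
            sumi += kvalues_nvfp4[packed[j] >>   4] * q8[j + 8];  // high nibbles (assumed order)
        }
        // 0.5f undoes the x2 baked into the value table
        return 0.5f * ue4m3_scale_lut[scale & 0x7f] * q8_scale * (float) sumi;
    }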
@@ -1166,7 +1166,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         llama_expert_gating_func_type gating_op,
         int il,
         ggml_tensor * probs_in,
-        ggml_tensor * gate_up_exps) const {
+        ggml_tensor * gate_up_exps,
+        ggml_tensor * up_exps_s,
+        ggml_tensor * gate_exps_s,
+        ggml_tensor * down_exps_s) const {
     return build_moe_ffn(
         cur,
         gate_inp, /* gate_inp_b */ nullptr,
@@ -1182,7 +1185,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         gating_op,
         il,
         probs_in,
-        gate_up_exps
+        gate_up_exps,
+        /* gate_up_exps_b */ nullptr,
+        up_exps_s,
+        gate_exps_s,
+        down_exps_s
     );
 }
 
@@ -1206,7 +1213,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         int il,
         ggml_tensor * probs_in,
         ggml_tensor * gate_up_exps,
-        ggml_tensor * gate_up_exps_b) const {
+        ggml_tensor * gate_up_exps_b,
+        ggml_tensor * up_exps_s,
+        ggml_tensor * gate_exps_s,
+        ggml_tensor * down_exps_s) const {
     const int64_t n_embd   = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1358,6 +1368,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             cb(gate_up, "ffn_moe_gate_up_biased", il);
         }
 
+        // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused)
+        if (up_exps_s) {
+            ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
+            s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
+            s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
+            gate_up = ggml_mul(ctx0, gate_up, s);
+            cb(gate_up, "ffn_moe_gate_up_scaled", il);
+        }
+
         const int64_t n_ff = gate_up->ne[0] / 2;
         cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
         cb(cur, "ffn_moe_gate", il);
@@ -1373,6 +1392,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             cb(up, "ffn_moe_up_biased", il);
         }
 
+        // apply per-expert scale2 to up
+        if (up_exps_s) {
+            ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1);
+            s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
+            s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
+            up = ggml_mul(ctx0, up, s);
+            cb(up, "ffn_moe_up_scaled", il);
+        }
+
         if (gate_exps) {
             cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
             cb(cur, "ffn_moe_gate", il);
@@ -1384,6 +1412,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
                 cb(cur, "ffn_moe_gate_biased", il);
             }
+
+            // apply per-expert scale2 to gate
+            if (gate_exps_s) {
+                ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1);
+                s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
+                s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
+                cur = ggml_mul(ctx0, cur, s);
+                cb(cur, "ffn_moe_gate_scaled", il);
+            }
         }
 
         const bool has_gate = gate_exps || gate_up_exps;
@@ -1463,6 +1500,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(experts, "ffn_moe_down_biased", il);
     }
 
+    // apply per-expert scale2 to down
+    if (down_exps_s) {
+        ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1);
+        s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);
+        s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens]
+        experts = ggml_mul(ctx0, experts, s);
+        cb(experts, "ffn_moe_down_scaled", il);
+    }
+
     if (!weight_before_ffn) {
         experts = ggml_mul(ctx0, experts, weights);
         cb(cur, "ffn_moe_weighted", il);
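The four scale2 blocks above repeat one broadcast pattern. A minimal sketch of that pattern with the shapes written out, assuming selected_experts is the usual [n_expert_used, n_tokens] id tensor (the helper name is ours, not part of the change):

    #include "ggml.h"

    // per-expert scale broadcast, as used for gate_up, up, gate, and down above
    static ggml_tensor * apply_expert_scale(ggml_context * ctx0, ggml_tensor * x,
                                            ggml_tensor * exps_s,            // [n_expert]
                                            ggml_tensor * selected_experts,  // [n_expert_used, n_tokens]
                                            int64_t n_expert, int64_t n_tokens) {
        ggml_tensor * s = ggml_reshape_3d(ctx0, exps_s, 1, n_expert, 1);  // [1, n_expert, 1]
        s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1);            // [1, n_expert, n_tokens]
        s = ggml_get_rows(ctx0, s, selected_experts);                     // [1, n_expert_used, n_tokens]
        // ggml_mul broadcasts the length-1 leading dim of s over the n_ff rows of x
        return ggml_mul(ctx0, x, s);                                      // [n_ff, n_expert_used, n_tokens]
    }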
@@ -814,7 +814,10 @@ struct llm_graph_context {
             llama_expert_gating_func_type gating_op,
             int il,
             ggml_tensor * probs_in = nullptr,
-            ggml_tensor * gate_up_exps = nullptr) const;
+            ggml_tensor * gate_up_exps = nullptr,
+            ggml_tensor * up_exps_s = nullptr,
+            ggml_tensor * gate_exps_s = nullptr,
+            ggml_tensor * down_exps_s = nullptr) const;
 
     ggml_tensor * build_moe_ffn(
             ggml_tensor * cur,
@@ -836,7 +839,10 @@ struct llm_graph_context {
             int il,
             ggml_tensor * probs_in = nullptr,
             ggml_tensor * gate_up_exps = nullptr,
-            ggml_tensor * gate_up_exps_b = nullptr) const;
+            ggml_tensor * gate_up_exps_b = nullptr,
+            ggml_tensor * up_exps_s = nullptr,
+            ggml_tensor * gate_exps_s = nullptr,
+            ggml_tensor * down_exps_s = nullptr) const;
 
     //
     // inputs
@@ -42,6 +42,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_1:      return "Q5_1";
         case LLAMA_FTYPE_MOSTLY_Q8_0:      return "Q8_0";
         case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
+        case LLAMA_FTYPE_MOSTLY_NVFP4:     return "NVFP4";
         case LLAMA_FTYPE_MOSTLY_Q2_K:      return "Q2_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:    return "Q2_K - Small";
         case LLAMA_FTYPE_MOSTLY_Q3_K_S:    return "Q3_K - Small";
@@ -724,6 +725,7 @@ llama_model_loader::llama_model_loader(
             case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break;
             case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break;
             case GGML_TYPE_IQ3_S:  ftype = LLAMA_FTYPE_MOSTLY_IQ3_S;  break;
+            case GGML_TYPE_NVFP4:  ftype = LLAMA_FTYPE_MOSTLY_NVFP4;  break;
             default:
                 {
                     LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@@ -5010,23 +5010,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
 
                     layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
-                    layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+                    layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED);
                     layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                    layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+                    layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED);
                     layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
-                    layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+                    layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED);
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                    layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+                    layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
 
                     layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                     layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
 
                     layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                    layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+                    layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
                     layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                    layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+                    layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED);
                     layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                    layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+                    layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
                 }
             } break;
         case LLM_ARCH_T5:
@@ -7443,6 +7443,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            default:
                throw std::runtime_error("unknown architecture");
        }
+
+        // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2)
+        // this avoids having to add scale loading to every architecture
+        for (int i = 0; i < n_layer; ++i) {
+            auto & layer = layers[i];
+
+            // attention weight scales (per-tensor, shape {1})
+            if (!layer.wq_s && layer.wq) {
+                layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+            }
+            if (!layer.wk_s && layer.wk) {
+                layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+            }
+            if (!layer.wv_s && layer.wv) {
+                layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+            }
+            if (!layer.wo_s && layer.wo) {
+                layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+            }
+
+            // dense FFN weight scales (per-tensor, shape {1})
+            if (!layer.ffn_gate_s && layer.ffn_gate) {
+                layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+            }
+            if (!layer.ffn_down_s && layer.ffn_down) {
+                layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+            }
+            if (!layer.ffn_up_s && layer.ffn_up) {
+                layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+            }
+
+            // MoE expert weight scales (per-expert, shape {n_expert})
+            if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) {
+                layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
+            }
+            if (!layer.ffn_down_exps_s && layer.ffn_down_exps) {
+                layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
+            }
+            if (!layer.ffn_up_exps_s && layer.ffn_up_exps) {
+                layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
+            }
+        }
    }
 
    ml.done_getting_tensors();
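For illustration, the GGUF tensor names this generic pass probes for layer 0, assuming llama.cpp's usual blk.%d.<name>.<suffix> naming convention (names below are our reading of the tn() calls, not taken from the diff). All lookups are TENSOR_NOT_REQUIRED, so models without scale2 tensors load unchanged:

    // blk.0.attn_q.scale       -> layer.wq_s            shape {1}        per-tensor scale2
    // blk.0.ffn_down.scale     -> layer.ffn_down_s      shape {1}        per-tensor scale2
    // blk.0.ffn_up_exps.scale  -> layer.ffn_up_exps_s   shape {n_expert} per-expert scale2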
@@ -295,6 +295,11 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_exps_b = nullptr;
     struct ggml_tensor * ffn_gate_up_exps_b = nullptr;
 
+    // ff MoE per-expert scales (NVFP4 per-tensor scale2)
+    struct ggml_tensor * ffn_gate_exps_s = nullptr;
+    struct ggml_tensor * ffn_down_exps_s = nullptr;
+    struct ggml_tensor * ffn_up_exps_s = nullptr;
+
     // ff MoE latent proj
     struct ggml_tensor * ffn_latent_down = nullptr;
     struct ggml_tensor * ffn_latent_up = nullptr;
@@ -392,13 +397,13 @@ struct llama_layer {
     struct ggml_tensor * rope_freqs = nullptr;
 
     // bitnet scale
-    struct ggml_tensor * wq_scale = nullptr;
-    struct ggml_tensor * wk_scale = nullptr;
-    struct ggml_tensor * wv_scale = nullptr;
-    struct ggml_tensor * wo_scale = nullptr;
-    struct ggml_tensor * ffn_gate_scale = nullptr;
-    struct ggml_tensor * ffn_up_scale = nullptr;
-    struct ggml_tensor * ffn_down_scale = nullptr;
+    struct ggml_tensor * wq_s = nullptr;
+    struct ggml_tensor * wk_s = nullptr;
+    struct ggml_tensor * wv_s = nullptr;
+    struct ggml_tensor * wo_s = nullptr;
+    struct ggml_tensor * ffn_gate_s = nullptr;
+    struct ggml_tensor * ffn_up_s = nullptr;
+    struct ggml_tensor * ffn_down_s = nullptr;
 
     // altup & laurel
     struct ggml_tensor * per_layer_inp_gate = nullptr;
@@ -30,8 +30,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
         {
             // compute Q and K and RoPE them
             ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-            if (model.layers[il].wq_scale) {
-                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+            if (model.layers[il].wq_s) {
+                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_s);
             }
             cb(Qcur, "Qcur", il);
             if (model.layers[il].bq) {
@@ -41,8 +41,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
 
             // B1.K
             ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-            if (model.layers[il].wk_scale) {
-                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+            if (model.layers[il].wk_s) {
+                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_s);
             }
             cb(Kcur, "Kcur", il);
             if (model.layers[il].bk) {
@@ -52,8 +52,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
 
             // B1.V
             ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-            if (model.layers[il].wv_scale) {
-                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+            if (model.layers[il].wv_s) {
+                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_s);
             }
             cb(Vcur, "Vcur", il);
             if (model.layers[il].bv) {
@@ -91,8 +91,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
             cb(cur, "attn_sub_norm", il);
 
             cur = build_lora_mm(model.layers[il].wo, cur);
-            if (model.layers[il].wo_scale) {
-                cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+            if (model.layers[il].wo_s) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
             }
             if (model.layers[il].bo) {
                 cur = ggml_add(ctx0, cur, model.layers[il].bo);
@@ -115,8 +115,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
             cb(cur, "ffn_norm", il);
 
             cur = build_ffn(cur,
-                    model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale,
-                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
+                    model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s,
+                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
                     NULL, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, il);
@@ -128,8 +128,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
             cb(cur, "ffn_sub_norm", il);
 
             cur = build_lora_mm(model.layers[il].ffn_down, cur);
-            if (model.layers[il].ffn_down_scale) {
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            if (model.layers[il].ffn_down_s) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_s);
             }
             cb(cur, "ffn_down", il);
 
@@ -44,18 +44,27 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
 
         // compute Q and K and RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+        if (model.layers[il].wq_s) {
+            Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_s);
+        }
         cb(Qcur, "Qcur", il);
         if (model.layers[il].bq) {
             Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
             cb(Qcur, "Qcur", il);
         }
         ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+        if (model.layers[il].wk_s) {
+            Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_s);
+        }
         cb(Kcur, "Kcur", il);
         if (model.layers[il].bk) {
             Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
             cb(Kcur, "Kcur", il);
         }
         ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+        if (model.layers[il].wv_s) {
+            Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_s);
+        }
         cb(Vcur, "Vcur", il);
         if (model.layers[il].bv) {
             Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -91,6 +100,9 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
         cur = build_attn(inp_attn,
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+        if (model.layers[il].wo_s) {
+            cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
+        }
         cb(cur, "attn_out", il);
     }
     if (il == n_layer - 1 && inp_out_ids) {
@@ -109,9 +121,9 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
         cb(cur, "ffn_norm", il);
 
         cur = build_ffn(cur,
-                model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
-                model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s,
+                model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s,
+                model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s,
                 NULL,
                 LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
@@ -132,7 +144,11 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra
                 LLM_FFN_SILU, true,
                 hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                il);
+                il,
+                nullptr, nullptr,
+                model.layers[il].ffn_up_exps_s,
+                model.layers[il].ffn_gate_exps_s,
+                model.layers[il].ffn_down_exps_s);
         cb(cur, "ffn_moe_out", il);
     }
     cur = ggml_add(ctx0, cur, ffn_inp);
@@ -31,12 +31,21 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
         {
             // compute Q and K and RoPE them
             ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            if (model.layers[il].wq_s) {
+                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_s);
+            }
             cb(Qcur, "Qcur", il);
 
             ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            if (model.layers[il].wk_s) {
+                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_s);
+            }
             cb(Kcur, "Kcur", il);
 
             ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            if (model.layers[il].wv_s) {
+                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_s);
+            }
             cb(Vcur, "Vcur", il);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -68,6 +77,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
             cur = build_attn(inp_attn,
                     model.layers[il].wo, model.layers[il].bo,
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            if (model.layers[il].wo_s) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
+            }
         }
         if (il == n_layer - 1 && inp_out_ids) {
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -83,9 +95,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
         cb(cur, "ffn_norm", il);
 
         cur = build_ffn(cur,
-                model.layers[il].ffn_up, NULL, NULL,
-                model.layers[il].ffn_gate, NULL, NULL,
-                model.layers[il].ffn_down, NULL, NULL,
+                model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s,
+                model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
+                model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s,
                 NULL,
                 LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
@@ -31,12 +31,21 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
         {
             // compute Q and K and RoPE them
             ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            if (model.layers[il].wq_s) {
+                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_s);
+            }
             cb(Qcur, "Qcur", il);
 
             ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            if (model.layers[il].wk_s) {
+                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_s);
+            }
             cb(Kcur, "Kcur", il);
 
             ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            if (model.layers[il].wv_s) {
+                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_s);
+            }
             cb(Vcur, "Vcur", il);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -68,6 +77,9 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
             cur = build_attn(inp_attn,
                     model.layers[il].wo, model.layers[il].bo,
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            if (model.layers[il].wo_s) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
+            }
         }
         if (il == n_layer - 1 && inp_out_ids) {
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
@@ -93,7 +105,11 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
                     LLM_FFN_SILU, true,
                     hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    il);
+                    il,
+                    nullptr, nullptr,
+                    model.layers[il].ffn_up_exps_s,
+                    model.layers[il].ffn_gate_exps_s,
+                    model.layers[il].ffn_down_exps_s);
             cb(moe_out, "ffn_moe_out", il);
             cur = moe_out;
 