diff --git a/crates/ruvllm/src/bitnet/backend.rs b/crates/ruvllm/src/bitnet/backend.rs index 91c034604..f6250f090 100644 --- a/crates/ruvllm/src/bitnet/backend.rs +++ b/crates/ruvllm/src/bitnet/backend.rs @@ -43,22 +43,25 @@ use super::tokenizer::{BpeTokenizer, SpecialTokens as BitNetSpecialTokens}; /// Model configuration for BitNet MoE inference. /// /// Describes the architecture dimensions extracted from GGUF metadata -/// or supplied manually for testing. +/// or supplied manually for testing. Supports both standard GQA attention +/// and MLA (Multi-Head Latent Attention) as used by GLM-4.7-Flash. #[derive(Debug, Clone)] pub struct BitNetModelConfig { /// Number of transformer layers pub num_layers: usize, /// Hidden state dimension pub hidden_size: usize, - /// Number of MoE experts per layer + /// Number of MoE routed experts per layer pub num_experts: usize, /// Number of active experts per token (top-K) pub active_experts: usize, - /// FFN intermediate dimension per expert + /// Dense FFN intermediate dimension (for dense layers) pub intermediate_size: usize, + /// MoE expert FFN intermediate dimension (may differ from dense) + pub moe_intermediate_size: usize, /// Number of attention query heads pub num_attention_heads: usize, - /// Number of attention key-value heads (GQA) + /// Number of attention key-value heads (GQA; equals num_attention_heads in MLA) pub num_kv_heads: usize, /// Vocabulary size pub vocab_size: usize, @@ -66,22 +69,56 @@ pub struct BitNetModelConfig { pub max_context: usize, /// RoPE frequency base pub rope_theta: f32, + + // --- MLA (Multi-Head Latent Attention) parameters --- + /// Whether attention uses MLA (true) or standard GQA (false) + pub use_mla: bool, + /// Q low-rank compression dimension (MLA) + pub q_lora_rank: usize, + /// KV low-rank compression dimension (MLA) + pub kv_lora_rank: usize, + /// Non-RoPE portion of Q/K head dimension (MLA) + pub qk_nope_head_dim: usize, + /// RoPE portion of Q/K head dimension (MLA) + pub qk_rope_head_dim: usize, + /// Value head dimension (MLA) + pub v_head_dim: usize, + + // --- MoE structure --- + /// Number of shared experts (always-active, non-routed) + pub n_shared_experts: usize, + /// First N layers use dense FFN instead of MoE (e.g., 1 means layer 0 is dense) + pub first_k_dense_replace: usize, + /// Scaling factor for routed expert weights + pub routed_scaling_factor: f32, } impl Default for BitNetModelConfig { fn default() -> Self { - // Default values loosely based on GLM-4.7-Flash architecture + // Default values matching GLM-4.7-Flash architecture Self { - num_layers: 28, - hidden_size: 4096, - num_experts: 8, - active_experts: 2, - intermediate_size: 11008, - num_attention_heads: 32, - num_kv_heads: 8, - vocab_size: 151552, + num_layers: 47, + hidden_size: 2048, + num_experts: 64, + active_experts: 4, + intermediate_size: 10240, + moe_intermediate_size: 1536, + num_attention_heads: 20, + num_kv_heads: 20, + vocab_size: 154880, max_context: 8192, - rope_theta: 10000.0, + rope_theta: 1_000_000.0, + // MLA parameters from GLM-4.7-Flash config.json + use_mla: true, + q_lora_rank: 768, + kv_lora_rank: 512, + qk_nope_head_dim: 192, + qk_rope_head_dim: 64, + v_head_dim: 256, + // MoE structure + n_shared_experts: 1, + first_k_dense_replace: 1, + routed_scaling_factor: 1.8, } } } @@ -118,6 +155,217 @@ fn build_tl1_lut() -> Tl1Lut { lut } +// ============================================================================ +// Tensor Name Mapper +// ============================================================================ + +/// Resolves logical tensor names to actual GGUF tensor names. +/// +/// GLM-4.7-Flash GGUF files use llama.cpp conventions (`blk.0.attn_q_a.weight`), +/// while some models use HuggingFace conventions (`model.layers.0.self_attn.q_proj.weight`). +/// The mapper tries GGUF names first, then HuggingFace names as fallback. +struct TensorNameMapper; + +impl TensorNameMapper { + /// Find the first tensor name that exists in the GGUF file. + fn resolve(gguf: &GgufFile, candidates: &[String]) -> Option { + for name in candidates { + if gguf.get_tensor(name).is_some() { + return Some(name.clone()); + } + } + None + } + + // -- Global tensors -- + + fn embedding() -> Vec { + vec![ + "token_embd.weight".into(), + "model.embed_tokens.weight".into(), + ] + } + + fn output() -> Vec { + vec![ + "output.weight".into(), + "lm_head.weight".into(), + ] + } + + fn final_norm() -> Vec { + vec![ + "output_norm.weight".into(), + "model.norm.weight".into(), + ] + } + + // -- Per-layer norms -- + + fn input_norm(idx: usize) -> Vec { + vec![ + format!("blk.{}.attn_norm.weight", idx), + format!("model.layers.{}.input_layernorm.weight", idx), + ] + } + + fn post_attn_norm(idx: usize) -> Vec { + vec![ + format!("blk.{}.ffn_norm.weight", idx), + format!("model.layers.{}.post_attention_layernorm.weight", idx), + ] + } + + // -- MLA attention tensors -- + + fn attn_q_a(idx: usize) -> Vec { + vec![format!("blk.{}.attn_q_a.weight", idx)] + } + + fn attn_q_b(idx: usize) -> Vec { + vec![format!("blk.{}.attn_q_b.weight", idx)] + } + + fn attn_q_a_norm(idx: usize) -> Vec { + vec![format!("blk.{}.attn_q_a_norm.weight", idx)] + } + + fn attn_kv_a_mqa(idx: usize) -> Vec { + vec![format!("blk.{}.attn_kv_a_mqa.weight", idx)] + } + + fn attn_kv_a_norm(idx: usize) -> Vec { + vec![format!("blk.{}.attn_kv_a_norm.weight", idx)] + } + + fn attn_k_b(idx: usize) -> Vec { + vec![format!("blk.{}.attn_k_b.weight", idx)] + } + + fn attn_v_b(idx: usize) -> Vec { + vec![format!("blk.{}.attn_v_b.weight", idx)] + } + + fn attn_output(idx: usize) -> Vec { + vec![ + format!("blk.{}.attn_output.weight", idx), + format!("model.layers.{}.self_attn.o_proj.weight", idx), + ] + } + + // -- Standard GQA attention tensors -- + + fn attn_q_proj(idx: usize) -> Vec { + vec![format!("model.layers.{}.self_attn.q_proj.weight", idx)] + } + + fn attn_k_proj(idx: usize) -> Vec { + vec![format!("model.layers.{}.self_attn.k_proj.weight", idx)] + } + + fn attn_v_proj(idx: usize) -> Vec { + vec![format!("model.layers.{}.self_attn.v_proj.weight", idx)] + } + + // -- MoE router gate -- + + fn moe_gate(idx: usize) -> Vec { + vec![ + format!("blk.{}.ffn_gate_inp.weight", idx), + format!("model.layers.{}.mlp.gate.weight", idx), + ] + } + + // -- Dense FFN tensors -- + + fn ffn_gate(idx: usize) -> Vec { + vec![ + format!("blk.{}.ffn_gate.weight", idx), + format!("model.layers.{}.mlp.gate_proj.weight", idx), + ] + } + + fn ffn_up(idx: usize) -> Vec { + vec![ + format!("blk.{}.ffn_up.weight", idx), + format!("model.layers.{}.mlp.up_proj.weight", idx), + ] + } + + fn ffn_down(idx: usize) -> Vec { + vec![ + format!("blk.{}.ffn_down.weight", idx), + format!("model.layers.{}.mlp.down_proj.weight", idx), + ] + } + + // -- Shared expert tensors -- + + fn ffn_gate_shexp(idx: usize) -> Vec { + vec![format!("blk.{}.ffn_gate_shexp.weight", idx)] + } + + fn ffn_up_shexp(idx: usize) -> Vec { + vec![format!("blk.{}.ffn_up_shexp.weight", idx)] + } + + fn ffn_down_shexp(idx: usize) -> Vec { + vec![format!("blk.{}.ffn_down_shexp.weight", idx)] + } + + // -- Stacked expert tensors (3D, all experts in one tensor) -- + + fn ffn_gate_exps(idx: usize) -> Vec { + vec![format!("blk.{}.ffn_gate_exps.weight", idx)] + } + + fn ffn_up_exps(idx: usize) -> Vec { + vec![format!("blk.{}.ffn_up_exps.weight", idx)] + } + + fn ffn_down_exps(idx: usize) -> Vec { + vec![format!("blk.{}.ffn_down_exps.weight", idx)] + } + + // -- Per-expert tensors (HuggingFace individual naming) -- + + fn expert_gate(idx: usize, expert_idx: usize) -> Vec { + vec![format!( + "model.layers.{}.mlp.experts.{}.gate_proj.weight", + idx, expert_idx + )] + } + + fn expert_up(idx: usize, expert_idx: usize) -> Vec { + vec![format!( + "model.layers.{}.mlp.experts.{}.up_proj.weight", + idx, expert_idx + )] + } + + fn expert_down(idx: usize, expert_idx: usize) -> Vec { + vec![format!( + "model.layers.{}.mlp.experts.{}.down_proj.weight", + idx, expert_idx + )] + } + + /// Check if a layer has MLA attention tensors. + fn has_mla(gguf: &GgufFile, idx: usize) -> bool { + Self::resolve(gguf, &Self::attn_q_a(idx)).is_some() + } + + /// Check if a layer has stacked expert tensors. + fn has_stacked_experts(gguf: &GgufFile, idx: usize) -> bool { + Self::resolve(gguf, &Self::ffn_gate_exps(idx)).is_some() + } + + /// Check if a layer has dense FFN (not MoE). + fn has_dense_ffn(gguf: &GgufFile, idx: usize) -> bool { + Self::resolve(gguf, &Self::ffn_gate(idx)).is_some() + } +} + // ============================================================================ // Per-Layer and Per-Expert Weight Storage // ============================================================================ @@ -133,9 +381,18 @@ struct ExpertWeights { down_proj: TernaryTensor, } -/// Attention projection weights (ternary). +/// Attention projection weights. +/// +/// Supports two variants: +/// - **Standard GQA**: Direct Q/K/V/O projections +/// - **MLA (Multi-Head Latent Attention)**: Low-rank compressed Q/KV projections +/// as used by GLM-4.7-Flash / DeepSeek-V2 #[derive(Debug, Clone)] struct AttentionWeights { + /// Whether this layer uses MLA or standard GQA + is_mla: bool, + + // --- Standard GQA fields --- /// Q projection: [num_heads * head_dim, hidden_size] q_proj: TernaryTensor, /// K projection: [num_kv_heads * head_dim, hidden_size] @@ -144,6 +401,33 @@ struct AttentionWeights { v_proj: TernaryTensor, /// Output projection: [hidden_size, num_heads * head_dim] o_proj: TernaryTensor, + + // --- MLA fields (populated when is_mla = true) --- + /// Q down-projection: [hidden_size → q_lora_rank] + q_a: Option, + /// Q up-projection: [q_lora_rank → num_heads * (qk_nope_head_dim + qk_rope_head_dim)] + q_b: Option, + /// Q compression norm weights: [q_lora_rank] + q_a_norm: Option>, + /// KV joint down-projection: [hidden_size → kv_lora_rank + qk_rope_head_dim] + kv_a_mqa: Option, + /// KV compression norm weights: [kv_lora_rank] + kv_a_norm: Option>, + /// K up-projection: [kv_lora_rank → num_heads * qk_nope_head_dim] + k_b: Option, + /// V up-projection: [kv_lora_rank → num_heads * v_head_dim] + v_b: Option, +} + +/// Type of FFN in a transformer layer. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum LayerType { + /// Dense FFN (single gate/up/down, no MoE routing) + Dense, + /// MoE with routed experts only + Moe, + /// MoE with routed experts + shared expert(s) + MoeWithShared, } /// Weights for a single transformer layer. @@ -153,12 +437,18 @@ struct TransformerLayer { input_norm_weight: Vec, /// Post-attention RMSNorm weight [hidden_size] post_attn_norm_weight: Vec, - /// Attention projection weights (ternary) + /// Attention projection weights (ternary, supports MLA or GQA) attention: AttentionWeights, - /// MoE router gate weight [num_experts, hidden_size] (FP32, stored row-major) + /// Type of FFN in this layer + layer_type: LayerType, + /// MoE router gate weight [num_experts, hidden_size] (FP32, empty for dense layers) gate_weight: Vec, - /// Per-expert FFN weights (ternary) + /// Per-expert FFN weights (routed experts, ternary) experts: Vec, + /// Shared expert FFN weights (always-active, non-routed; None for dense layers) + shared_expert: Option, + /// Dense FFN weights (for dense-only layers; uses gate/up/down from ExpertWeights) + dense_ffn: Option, } // ============================================================================ @@ -274,28 +564,36 @@ impl BitNetBackend { /// Parses the GGUF file, extracts model configuration from metadata, /// separates FP16 shared tensors from ternary expert tensors, and /// pre-builds the TL1 lookup table. + /// + /// Supports both llama.cpp GGUF tensor naming (`token_embd.weight`, + /// `blk.0.attn_q_a.weight`) and HuggingFace naming (`model.embed_tokens.weight`, + /// `model.layers.0.self_attn.q_proj.weight`). fn load_gguf(&mut self, path: &str) -> Result<()> { let gguf = GgufFile::open_mmap(Path::new(path))?; // Extract model config from GGUF metadata let config = self.extract_config(&gguf)?; - // Load embedding table (FP16/FP32) - self.embedding = self.load_fp_tensor(&gguf, "model.embed_tokens.weight", &config)?; + // Load embedding table via name mapper + let emb_name = TensorNameMapper::resolve(&gguf, &TensorNameMapper::embedding()) + .ok_or_else(|| RuvLLMError::NotFound( + "Embedding tensor not found (tried: token_embd.weight, model.embed_tokens.weight)".into() + ))?; + self.embedding = self.load_fp_tensor(&gguf, &emb_name, &config)?; - // Load LM head (may share weights with embedding in some architectures) - self.lm_head = if gguf.get_tensor("lm_head.weight").is_some() { - self.load_fp_tensor(&gguf, "lm_head.weight", &config)? - } else if gguf.get_tensor("output.weight").is_some() { - self.load_fp_tensor(&gguf, "output.weight", &config)? + // Load LM head / output via name mapper (fallback to tied embeddings) + self.lm_head = if let Some(out_name) = TensorNameMapper::resolve(&gguf, &TensorNameMapper::output()) { + self.load_fp_tensor(&gguf, &out_name, &config)? } else { - // Tied embeddings: copy embedding table self.embedding.clone() }; - // Load final norm - self.final_norm_weight = - self.load_fp_tensor(&gguf, "model.norm.weight", &config)?; + // Load final norm via name mapper + let norm_name = TensorNameMapper::resolve(&gguf, &TensorNameMapper::final_norm()) + .ok_or_else(|| RuvLLMError::NotFound( + "Final norm tensor not found (tried: output_norm.weight, model.norm.weight)".into() + ))?; + self.final_norm_weight = self.load_fp_tensor(&gguf, &norm_name, &config)?; // Load transformer layers self.layers = Vec::with_capacity(config.num_layers); @@ -308,8 +606,13 @@ impl BitNetBackend { self.kv_caches = (0..config.num_layers).map(|_| LayerKvCache::new()).collect(); // Build RoPE cos/sin tables - let head_dim = config.hidden_size / config.num_attention_heads; - self.build_rope_tables(config.max_context, head_dim, config.rope_theta); + // For MLA, rope applies only to qk_rope_head_dim portion + let rope_dim = if config.use_mla { + config.qk_rope_head_dim + } else { + config.hidden_size / config.num_attention_heads + }; + self.build_rope_tables(config.max_context.min(8192), rope_dim, config.rope_theta); // Load tokenizer from GGUF metadata self.tok = self.load_tokenizer_from_gguf(&gguf); @@ -396,26 +699,57 @@ impl BitNetBackend { /// Extract BitNetModelConfig from GGUF metadata. fn extract_config(&self, gguf: &GgufFile) -> Result { - let num_layers = gguf.layer_count().unwrap_or(28); - let hidden_size = gguf.embedding_length().unwrap_or(4096); - let num_attention_heads = gguf.head_count().unwrap_or(32); - let num_kv_heads = gguf.head_count_kv().unwrap_or(8); - let vocab_size = gguf.vocab_size().unwrap_or(151552); - let max_context = gguf.context_length().unwrap_or(8192); - let rope_theta = gguf.rope_freq_base().unwrap_or(10000.0); - let intermediate_size = gguf.feed_forward_length().unwrap_or(11008); + let defaults = BitNetModelConfig::default(); + let num_layers = gguf.layer_count().unwrap_or(defaults.num_layers); + let hidden_size = gguf.embedding_length().unwrap_or(defaults.hidden_size); + let num_attention_heads = gguf.head_count().unwrap_or(defaults.num_attention_heads); + let num_kv_heads = gguf.head_count_kv().unwrap_or(defaults.num_kv_heads); + let vocab_size = gguf.vocab_size().unwrap_or(defaults.vocab_size); + let max_context = gguf.context_length().unwrap_or(defaults.max_context); + let rope_theta = gguf.rope_freq_base().unwrap_or(defaults.rope_theta); + let intermediate_size = gguf.feed_forward_length().unwrap_or(defaults.intermediate_size); - // Detect expert count from tensor names - let num_experts = self.detect_expert_count(gguf).unwrap_or(8); + // Detect expert count from tensor names or metadata + let num_experts = self.detect_expert_count(gguf) + .or_else(|| Self::meta_usize(gguf, "llm.expert_count")) + .unwrap_or(defaults.num_experts); - // Detect active experts from metadata or default to 2 - let active_experts = gguf - .metadata - .get("model.expert_count_active") - .or_else(|| gguf.metadata.get("llm.expert_used_count")) - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - .unwrap_or(2); + // Active experts per token + let active_experts = Self::meta_usize(gguf, "llm.expert_used_count") + .or_else(|| Self::meta_usize(gguf, "model.expert_count_active")) + .unwrap_or(defaults.active_experts); + + // MoE intermediate size (may differ from dense intermediate_size) + let moe_intermediate_size = Self::meta_usize(gguf, "llm.expert_feed_forward_length") + .unwrap_or(defaults.moe_intermediate_size); + + // MLA parameters + let q_lora_rank = Self::meta_usize(gguf, "llm.attention.q_lora_rank") + .unwrap_or(defaults.q_lora_rank); + let kv_lora_rank = Self::meta_usize(gguf, "llm.attention.kv_lora_rank") + .unwrap_or(defaults.kv_lora_rank); + let qk_nope_head_dim = Self::meta_usize(gguf, "llm.attention.key_length_nope") + .unwrap_or(defaults.qk_nope_head_dim); + let qk_rope_head_dim = Self::meta_usize(gguf, "llm.attention.key_length_rope") + .or_else(|| gguf.rope_dimension_count()) + .unwrap_or(defaults.qk_rope_head_dim); + let v_head_dim = Self::meta_usize(gguf, "llm.attention.value_length") + .unwrap_or(defaults.v_head_dim); + + // Detect MLA by checking for q_a tensor in first layer + let use_mla = TensorNameMapper::has_mla(gguf, 0); + + // Shared experts + let n_shared_experts = Self::meta_usize(gguf, "llm.expert_shared_count") + .unwrap_or(if num_experts > 1 { defaults.n_shared_experts } else { 0 }); + + // First K dense layers + let first_k_dense_replace = Self::meta_usize(gguf, "llm.expert_first_dense_layers") + .unwrap_or(defaults.first_k_dense_replace); + + // Routed scaling factor + let routed_scaling_factor = Self::meta_f32(gguf, "llm.expert_weights_scale") + .unwrap_or(defaults.routed_scaling_factor); Ok(BitNetModelConfig { num_layers, @@ -423,14 +757,34 @@ impl BitNetBackend { num_experts, active_experts, intermediate_size, + moe_intermediate_size, num_attention_heads, num_kv_heads, vocab_size, max_context, rope_theta, + use_mla, + q_lora_rank, + kv_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, + n_shared_experts, + first_k_dense_replace, + routed_scaling_factor, }) } + /// Helper: extract a usize from GGUF metadata. + fn meta_usize(gguf: &GgufFile, key: &str) -> Option { + gguf.metadata.get(key).and_then(|v| v.as_u64()).map(|v| v as usize) + } + + /// Helper: extract an f32 from GGUF metadata. + fn meta_f32(gguf: &GgufFile, key: &str) -> Option { + gguf.metadata.get(key).and_then(|v| v.as_f32()) + } + /// Detect the number of MoE experts by scanning tensor names. fn detect_expert_count(&self, gguf: &GgufFile) -> Option { let mut max_expert_idx = 0usize; @@ -532,92 +886,344 @@ impl BitNetBackend { } /// Load a single transformer layer. + /// + /// Detects the layer type (dense vs MoE), attention type (MLA vs GQA), + /// and expert tensor format (stacked 3D vs individual) from the GGUF file. fn load_layer( &self, gguf: &GgufFile, idx: usize, config: &BitNetModelConfig, ) -> Result { - let prefix = format!("model.layers.{}", idx); + // Norm weights via name mapper + let in_norm_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::input_norm(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} input norm not found", idx)))?; + let input_norm_weight = self.load_fp_tensor(gguf, &in_norm_name, config)?; - // Norm weights (FP16/FP32) - let input_norm_weight = self.load_fp_tensor( - gguf, - &format!("{}.input_layernorm.weight", prefix), - config, - )?; - let post_attn_norm_weight = self.load_fp_tensor( - gguf, - &format!("{}.post_attention_layernorm.weight", prefix), - config, - )?; + let post_norm_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::post_attn_norm(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} post-attn norm not found", idx)))?; + let post_attn_norm_weight = self.load_fp_tensor(gguf, &post_norm_name, config)?; - // Attention projections (ternary) - let attn_prefix = format!("{}.self_attn", prefix); - let q_proj = self.load_ternary_tensor( - gguf, - &format!("{}.q_proj.weight", attn_prefix), - )?; - let k_proj = self.load_ternary_tensor( - gguf, - &format!("{}.k_proj.weight", attn_prefix), - )?; - let v_proj = self.load_ternary_tensor( - gguf, - &format!("{}.v_proj.weight", attn_prefix), - )?; - let o_proj = self.load_ternary_tensor( - gguf, - &format!("{}.o_proj.weight", attn_prefix), - )?; + // === Attention weights === + let attention = if TensorNameMapper::has_mla(gguf, idx) { + self.load_mla_attention(gguf, idx, config)? + } else { + self.load_gqa_attention(gguf, idx, config)? + }; - let attention = AttentionWeights { + // === FFN weights === + let is_dense_layer = idx < config.first_k_dense_replace + || TensorNameMapper::has_dense_ffn(gguf, idx); + + if is_dense_layer { + // Dense FFN layer (no MoE routing) + let dense_ffn = self.load_dense_ffn(gguf, idx, config)?; + Ok(TransformerLayer { + input_norm_weight, + post_attn_norm_weight, + attention, + layer_type: LayerType::Dense, + gate_weight: Vec::new(), + experts: Vec::new(), + shared_expert: None, + dense_ffn: Some(dense_ffn), + }) + } else { + // MoE layer: load router gate + experts + let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::moe_gate(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} MoE gate not found", idx)))?; + let gate_weight = self.load_fp_tensor(gguf, &gate_name, config)?; + + let experts = self.load_experts(gguf, idx, config)?; + + // Try loading shared expert + let shared_expert = self.load_shared_expert(gguf, idx, config).ok(); + + let layer_type = if shared_expert.is_some() { + LayerType::MoeWithShared + } else { + LayerType::Moe + }; + + Ok(TransformerLayer { + input_norm_weight, + post_attn_norm_weight, + attention, + layer_type, + gate_weight, + experts, + shared_expert, + dense_ffn: None, + }) + } + } + + /// Load MLA attention weights for a layer. + fn load_mla_attention( + &self, + gguf: &GgufFile, + idx: usize, + _config: &BitNetModelConfig, + ) -> Result { + // MLA projections + let q_a_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_a(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_q_a not found", idx)))?; + let q_a = self.load_ternary_tensor(gguf, &q_a_name)?; + + let q_b_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_b(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_q_b not found", idx)))?; + let q_b = self.load_ternary_tensor(gguf, &q_b_name)?; + + let kv_a_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_kv_a_mqa(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_kv_a_mqa not found", idx)))?; + let kv_a_mqa = self.load_ternary_tensor(gguf, &kv_a_name)?; + + let k_b_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_k_b(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_k_b not found", idx)))?; + let k_b = self.load_ternary_tensor(gguf, &k_b_name)?; + + let v_b_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_v_b(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_v_b not found", idx)))?; + let v_b = self.load_ternary_tensor(gguf, &v_b_name)?; + + let o_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_output(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} attn_output not found", idx)))?; + let o_proj = self.load_ternary_tensor(gguf, &o_name)?; + + // Norm weights for MLA compression (may or may not be present) + let q_a_norm = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_a_norm(idx)) + .and_then(|n| self.load_fp_tensor(gguf, &n, _config).ok()); + let kv_a_norm = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_kv_a_norm(idx)) + .and_then(|n| self.load_fp_tensor(gguf, &n, _config).ok()); + + // Use o_proj as placeholder for the standard fields (they won't be used in MLA path) + let placeholder = TernaryTensor { + packed_data: vec![], + scales: vec![], + shape: (0, 0), + block_size: 256, + }; + + Ok(AttentionWeights { + is_mla: true, + q_proj: placeholder.clone(), + k_proj: placeholder.clone(), + v_proj: placeholder, + o_proj, + q_a: Some(q_a), + q_b: Some(q_b), + q_a_norm, + kv_a_mqa: Some(kv_a_mqa), + kv_a_norm, + k_b: Some(k_b), + v_b: Some(v_b), + }) + } + + /// Load standard GQA attention weights for a layer. + fn load_gqa_attention( + &self, + gguf: &GgufFile, + idx: usize, + _config: &BitNetModelConfig, + ) -> Result { + let q_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_q_proj(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} Q projection not found", idx)))?; + let q_proj = self.load_ternary_tensor(gguf, &q_name)?; + + let k_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_k_proj(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} K projection not found", idx)))?; + let k_proj = self.load_ternary_tensor(gguf, &k_name)?; + + let v_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_v_proj(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} V projection not found", idx)))?; + let v_proj = self.load_ternary_tensor(gguf, &v_name)?; + + let o_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::attn_output(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} O projection not found", idx)))?; + let o_proj = self.load_ternary_tensor(gguf, &o_name)?; + + Ok(AttentionWeights { + is_mla: false, q_proj, k_proj, v_proj, o_proj, - }; + q_a: None, + q_b: None, + q_a_norm: None, + kv_a_mqa: None, + kv_a_norm: None, + k_b: None, + v_b: None, + }) + } - // MoE router gate (FP16/FP32): [num_experts, hidden_size] - let gate_weight = self.load_fp_tensor( - gguf, - &format!("{}.mlp.gate.weight", prefix), - config, - )?; + /// Load dense FFN weights for a layer (no MoE). + fn load_dense_ffn( + &self, + gguf: &GgufFile, + idx: usize, + _config: &BitNetModelConfig, + ) -> Result { + let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_gate(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} dense ffn_gate not found", idx)))?; + let up_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_up(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} dense ffn_up not found", idx)))?; + let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_down(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} dense ffn_down not found", idx)))?; - // Expert FFN weights (ternary) - let mut experts = Vec::with_capacity(config.num_experts); - for expert_idx in 0..config.num_experts { - let expert_prefix = - format!("{}.mlp.experts.{}", prefix, expert_idx); + Ok(ExpertWeights { + gate_proj: self.load_ternary_tensor(gguf, &gate_name)?, + up_proj: self.load_ternary_tensor(gguf, &up_name)?, + down_proj: self.load_ternary_tensor(gguf, &down_name)?, + }) + } - let gate_proj = self.load_ternary_tensor( - gguf, - &format!("{}.gate_proj.weight", expert_prefix), - )?; - let up_proj = self.load_ternary_tensor( - gguf, - &format!("{}.up_proj.weight", expert_prefix), - )?; - let down_proj = self.load_ternary_tensor( - gguf, - &format!("{}.down_proj.weight", expert_prefix), - )?; + /// Load shared expert weights for a layer. + fn load_shared_expert( + &self, + gguf: &GgufFile, + idx: usize, + _config: &BitNetModelConfig, + ) -> Result { + let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_gate_shexp(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} shared expert gate not found", idx)))?; + let up_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_up_shexp(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} shared expert up not found", idx)))?; + let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_down_shexp(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} shared expert down not found", idx)))?; - experts.push(ExpertWeights { - gate_proj, - up_proj, - down_proj, - }); + Ok(ExpertWeights { + gate_proj: self.load_ternary_tensor(gguf, &gate_name)?, + up_proj: self.load_ternary_tensor(gguf, &up_name)?, + down_proj: self.load_ternary_tensor(gguf, &down_name)?, + }) + } + + /// Load routed expert weights, supporting both stacked (3D) and individual tensor formats. + fn load_experts( + &self, + gguf: &GgufFile, + idx: usize, + config: &BitNetModelConfig, + ) -> Result> { + if TensorNameMapper::has_stacked_experts(gguf, idx) { + self.load_stacked_experts(gguf, idx, config) + } else { + self.load_individual_experts(gguf, idx, config) + } + } + + /// Load stacked expert tensors (3D format: [num_experts, out_dim, in_dim]) + /// and split into per-expert TernaryTensors. + fn load_stacked_experts( + &self, + gguf: &GgufFile, + idx: usize, + config: &BitNetModelConfig, + ) -> Result> { + let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_gate_exps(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} stacked gate_exps not found", idx)))?; + let up_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_up_exps(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} stacked up_exps not found", idx)))?; + let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::ffn_down_exps(idx)) + .ok_or_else(|| RuvLLMError::NotFound(format!("Layer {} stacked down_exps not found", idx)))?; + + // Load stacked tensors as FP32 and split per expert + let gate_all = gguf.load_tensor_f32(&gate_name)?; + let up_all = gguf.load_tensor_f32(&up_name)?; + let down_all = gguf.load_tensor_f32(&down_name)?; + + let num_experts = config.num_experts; + let intermediate = config.moe_intermediate_size; + let hidden = config.hidden_size; + + // gate/up: [num_experts, intermediate_size, hidden_size] + let gate_per_expert = intermediate * hidden; + // down: [num_experts, hidden_size, intermediate_size] + let down_per_expert = hidden * intermediate; + + let ptconfig = super::quantizer::PtBitnetConfig::default(); + let mut experts = Vec::with_capacity(num_experts); + + for e in 0..num_experts { + let gate_start = e * gate_per_expert; + let gate_end = gate_start + gate_per_expert; + let gate_slice = if gate_end <= gate_all.len() { + &gate_all[gate_start..gate_end] + } else { + // Insufficient data — create zeros + &[] + }; + + let up_start = e * gate_per_expert; + let up_end = up_start + gate_per_expert; + let up_slice = if up_end <= up_all.len() { + &up_all[up_start..up_end] + } else { + &[] + }; + + let down_start = e * down_per_expert; + let down_end = down_start + down_per_expert; + let down_slice = if down_end <= down_all.len() { + &down_all[down_start..down_end] + } else { + &[] + }; + + let gate_proj = if gate_slice.is_empty() { + TernaryTensor { packed_data: vec![], scales: vec![], shape: (intermediate, hidden), block_size: 256 } + } else { + super::quantizer::quantize_tensor(gate_slice, (intermediate, hidden), &ptconfig)? + }; + let up_proj = if up_slice.is_empty() { + TernaryTensor { packed_data: vec![], scales: vec![], shape: (intermediate, hidden), block_size: 256 } + } else { + super::quantizer::quantize_tensor(up_slice, (intermediate, hidden), &ptconfig)? + }; + let down_proj = if down_slice.is_empty() { + TernaryTensor { packed_data: vec![], scales: vec![], shape: (hidden, intermediate), block_size: 256 } + } else { + super::quantizer::quantize_tensor(down_slice, (hidden, intermediate), &ptconfig)? + }; + + experts.push(ExpertWeights { gate_proj, up_proj, down_proj }); } - Ok(TransformerLayer { - input_norm_weight, - post_attn_norm_weight, - attention, - gate_weight, - experts, - }) + Ok(experts) + } + + /// Load individual expert tensors (HuggingFace naming: `experts.{e}.gate_proj.weight`). + fn load_individual_experts( + &self, + gguf: &GgufFile, + idx: usize, + config: &BitNetModelConfig, + ) -> Result> { + let mut experts = Vec::with_capacity(config.num_experts); + for e in 0..config.num_experts { + let gate_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::expert_gate(idx, e)) + .ok_or_else(|| RuvLLMError::NotFound(format!( + "Layer {} expert {} gate_proj not found", idx, e + )))?; + let up_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::expert_up(idx, e)) + .ok_or_else(|| RuvLLMError::NotFound(format!( + "Layer {} expert {} up_proj not found", idx, e + )))?; + let down_name = TensorNameMapper::resolve(gguf, &TensorNameMapper::expert_down(idx, e)) + .ok_or_else(|| RuvLLMError::NotFound(format!( + "Layer {} expert {} down_proj not found", idx, e + )))?; + + experts.push(ExpertWeights { + gate_proj: self.load_ternary_tensor(gguf, &gate_name)?, + up_proj: self.load_ternary_tensor(gguf, &up_name)?, + down_proj: self.load_ternary_tensor(gguf, &down_name)?, + }); + } + Ok(experts) } // ======================================================================== @@ -699,10 +1305,11 @@ impl BitNetBackend { let mut hidden_states: Vec = self.embedding[last_token * hidden..(last_token + 1) * hidden].to_vec(); - for (_layer_idx, layer) in self.layers.iter().enumerate() { + for (layer_idx, layer) in self.layers.iter().enumerate() { hidden_states = self.forward_layer_nocache( &hidden_states, layer, + layer_idx, config, )?; } @@ -728,34 +1335,77 @@ impl BitNetBackend { config: &BitNetModelConfig, ) -> Result> { let hidden = config.hidden_size; - let num_heads = config.num_attention_heads; - let num_kv_heads = config.num_kv_heads; - let head_dim = hidden / num_heads; - let kv_dim = num_kv_heads * head_dim; // --- Pre-attention norm --- let mut normed = input.to_vec(); let layer = &self.layers[layer_idx]; rms_norm_inplace(&mut normed, &layer.input_norm_weight, 1e-6); - // --- Q/K/V projections via TL1 GEMV --- - let q = self.tl1_gemv(&self.layers[layer_idx].attention.q_proj, &normed, hidden, hidden); - let k = self.tl1_gemv(&self.layers[layer_idx].attention.k_proj, &normed, kv_dim, hidden); - let v = self.tl1_gemv(&self.layers[layer_idx].attention.v_proj, &normed, kv_dim, hidden); + // --- Attention (MLA or GQA) --- + let attn_out = if self.layers[layer_idx].attention.is_mla { + self.forward_mla_cached(&normed, layer_idx, position, config)? + } else { + self.forward_gqa_cached(&normed, layer_idx, position, config)? + }; - // --- Apply RoPE to Q and K --- + // --- Output projection --- + let o_out = self.tl1_gemv( + &self.layers[layer_idx].attention.o_proj, + &attn_out, + hidden, + hidden, + ); + + // --- Residual after attention --- + let mut residual: Vec = input.iter().zip(o_out.iter()).map(|(r, a)| r + a).collect(); + + // --- Post-attention norm --- + let mut normed_ffn = residual.clone(); + let layer = &self.layers[layer_idx]; + rms_norm_inplace(&mut normed_ffn, &layer.post_attn_norm_weight, 1e-6); + + // --- FFN (Dense, MoE, or MoE+Shared) --- + let ffn_out = self.forward_ffn(&normed_ffn, layer_idx, config)?; + + for (r, &f) in residual.iter_mut().zip(ffn_out.iter()) { + *r += f; + } + + Ok(residual) + } + + /// GQA attention with KV cache. + fn forward_gqa_cached( + &mut self, + normed: &[f32], + layer_idx: usize, + position: usize, + config: &BitNetModelConfig, + ) -> Result> { + let hidden = config.hidden_size; + let num_heads = config.num_attention_heads; + let num_kv_heads = config.num_kv_heads; + let head_dim = hidden / num_heads; + let kv_dim = num_kv_heads * head_dim; + + // Q/K/V projections via TL1 GEMV + let q = self.tl1_gemv(&self.layers[layer_idx].attention.q_proj, normed, hidden, hidden); + let k = self.tl1_gemv(&self.layers[layer_idx].attention.k_proj, normed, kv_dim, hidden); + let v = self.tl1_gemv(&self.layers[layer_idx].attention.v_proj, normed, kv_dim, hidden); + + // Apply RoPE to Q and K let mut q_rope = q; - let mut k_rope = k.clone(); + let mut k_rope = k; self.apply_rope(&mut q_rope, num_heads, head_dim, position); self.apply_rope(&mut k_rope, num_kv_heads, head_dim, position); - // --- Update KV cache --- + // Update KV cache self.kv_caches[layer_idx].keys.push(k_rope); self.kv_caches[layer_idx].values.push(v); let seq_len = self.kv_caches[layer_idx].len(); - // --- GQA Attention --- - let gqa_groups = num_heads / num_kv_heads; + // GQA attention scores + let gqa_groups = if num_kv_heads > 0 { num_heads / num_kv_heads } else { 1 }; let inv_sqrt_d = 1.0 / (head_dim as f32).sqrt(); let mut attn_out = vec![0.0f32; hidden]; @@ -763,7 +1413,6 @@ impl BitNetBackend { let kv_head = h / gqa_groups; let q_offset = h * head_dim; - // Compute attention scores for this head across all cached positions let mut scores = Vec::with_capacity(seq_len); for pos in 0..seq_len { let k_offset = kv_head * head_dim; @@ -775,11 +1424,8 @@ impl BitNetBackend { scores.push(dot * inv_sqrt_d); } - // Causal mask is implicit: we only have positions <= current - // Softmax over scores softmax_inplace(&mut scores); - // Weighted sum of V for pos in 0..seq_len { let v_offset = kv_head * head_dim; let v_vec = &self.kv_caches[layer_idx].values[pos]; @@ -790,45 +1436,217 @@ impl BitNetBackend { } } - // --- Output projection --- - let o_proj = self.tl1_gemv( - &self.layers[layer_idx].attention.o_proj, - &attn_out, - hidden, - hidden, - ); + Ok(attn_out) + } - // --- Residual after attention --- - let mut residual: Vec = input.iter().zip(o_proj.iter()).map(|(r, a)| r + a).collect(); + /// MLA (Multi-Head Latent Attention) with KV cache. + /// + /// Forward path: + /// 1. Q: x → W_q_a → RMSNorm → W_q_b → split(Q_nope, Q_rope) → RoPE(Q_rope) + /// 2. KV: x → W_kv_a → split(c_kv, k_pe) → RoPE(k_pe) + /// K: RMSNorm(c_kv) → W_k_b → K_nope → concat(K_nope, K_rope) + /// V: c_kv → W_v_b → V + /// 3. Standard multi-head attention on concatenated Q/K + fn forward_mla_cached( + &mut self, + normed: &[f32], + layer_idx: usize, + position: usize, + config: &BitNetModelConfig, + ) -> Result> { + let hidden = config.hidden_size; + let num_heads = config.num_attention_heads; + let q_lora_rank = config.q_lora_rank; + let kv_lora_rank = config.kv_lora_rank; + let qk_nope_dim = config.qk_nope_head_dim; + let qk_rope_dim = config.qk_rope_head_dim; + let v_dim = config.v_head_dim; + let q_head_dim = qk_nope_dim + qk_rope_dim; + let kv_a_out = kv_lora_rank + qk_rope_dim; - // --- Post-attention norm --- - let mut normed_ffn = residual.clone(); + let attn = &self.layers[layer_idx].attention; + + // --- Q path --- + // Step 1: c_q = x @ W_q_a [hidden → q_lora_rank] + let q_a = attn.q_a.as_ref().ok_or_else(|| { + RuvLLMError::Model("MLA q_a missing".into()) + })?; + let mut c_q = self.tl1_gemv(q_a, normed, q_lora_rank, hidden); + + // Step 2: RMSNorm(c_q) + if let Some(ref norm_w) = attn.q_a_norm { + rms_norm_inplace(&mut c_q, norm_w, 1e-6); + } + + // Step 3: Q = c_q @ W_q_b [q_lora_rank → num_heads * q_head_dim] + let q_b = attn.q_b.as_ref().ok_or_else(|| { + RuvLLMError::Model("MLA q_b missing".into()) + })?; + let q_full = self.tl1_gemv(q_b, &c_q, num_heads * q_head_dim, q_lora_rank); + + // Step 4: Split Q into nope and rope parts per head, apply RoPE to rope part + let mut q_nope = vec![0.0f32; num_heads * qk_nope_dim]; + let mut q_rope_part = vec![0.0f32; num_heads * qk_rope_dim]; + + for h in 0..num_heads { + let src = h * q_head_dim; + let nope_dst = h * qk_nope_dim; + let rope_dst = h * qk_rope_dim; + q_nope[nope_dst..nope_dst + qk_nope_dim] + .copy_from_slice(&q_full[src..src + qk_nope_dim]); + q_rope_part[rope_dst..rope_dst + qk_rope_dim] + .copy_from_slice(&q_full[src + qk_nope_dim..src + q_head_dim]); + } + + // Apply RoPE to the rope portion of Q + self.apply_rope(&mut q_rope_part, num_heads, qk_rope_dim, position); + + // --- KV path --- + // Step 1: [c_kv, k_pe] = x @ W_kv_a [hidden → kv_lora_rank + qk_rope_dim] + let kv_a = attn.kv_a_mqa.as_ref().ok_or_else(|| { + RuvLLMError::Model("MLA kv_a_mqa missing".into()) + })?; + let kv_combined = self.tl1_gemv(kv_a, normed, kv_a_out, hidden); + + // Split: first kv_lora_rank dims = c_kv, last qk_rope_dim = k_pe + let c_kv = &kv_combined[..kv_lora_rank]; + let mut k_pe = kv_combined[kv_lora_rank..].to_vec(); + + // Apply RoPE to k_pe (single head worth of rope dims) + self.apply_rope(&mut k_pe, 1, qk_rope_dim, position); + + // Step 2: K_nope = RMSNorm(c_kv) @ W_k_b [kv_lora_rank → num_heads * qk_nope_dim] + let mut c_kv_normed = c_kv.to_vec(); + if let Some(ref norm_w) = attn.kv_a_norm { + rms_norm_inplace(&mut c_kv_normed, norm_w, 1e-6); + } + + let k_b = attn.k_b.as_ref().ok_or_else(|| { + RuvLLMError::Model("MLA k_b missing".into()) + })?; + let k_nope = self.tl1_gemv(k_b, &c_kv_normed, num_heads * qk_nope_dim, kv_lora_rank); + + // Step 3: V = c_kv @ W_v_b [kv_lora_rank → num_heads * v_dim] + let v_b = attn.v_b.as_ref().ok_or_else(|| { + RuvLLMError::Model("MLA v_b missing".into()) + })?; + let v_full = self.tl1_gemv(v_b, c_kv, num_heads * v_dim, kv_lora_rank); + + // --- Build full K by concatenating K_nope + K_rope per head --- + // K_rope is shared across all heads (replicated from k_pe) + let k_full_dim = num_heads * q_head_dim; // per-head: qk_nope + qk_rope + let mut k_full = vec![0.0f32; k_full_dim]; + for h in 0..num_heads { + let dst = h * q_head_dim; + let nope_src = h * qk_nope_dim; + k_full[dst..dst + qk_nope_dim].copy_from_slice(&k_nope[nope_src..nope_src + qk_nope_dim]); + k_full[dst + qk_nope_dim..dst + q_head_dim].copy_from_slice(&k_pe[..qk_rope_dim]); + } + + // --- Build full Q by concatenating Q_nope + Q_rope per head --- + let mut q_full_concat = vec![0.0f32; num_heads * q_head_dim]; + for h in 0..num_heads { + let dst = h * q_head_dim; + let nope_src = h * qk_nope_dim; + let rope_src = h * qk_rope_dim; + q_full_concat[dst..dst + qk_nope_dim].copy_from_slice(&q_nope[nope_src..nope_src + qk_nope_dim]); + q_full_concat[dst + qk_nope_dim..dst + q_head_dim].copy_from_slice(&q_rope_part[rope_src..rope_src + qk_rope_dim]); + } + + // --- Update KV cache --- + self.kv_caches[layer_idx].keys.push(k_full); + self.kv_caches[layer_idx].values.push(v_full); + let seq_len = self.kv_caches[layer_idx].len(); + + // --- Multi-head attention --- + let inv_sqrt_d = 1.0 / (q_head_dim as f32).sqrt(); + let mut attn_out = vec![0.0f32; num_heads * v_dim]; + + for h in 0..num_heads { + let q_off = h * q_head_dim; + + let mut scores = Vec::with_capacity(seq_len); + for pos in 0..seq_len { + let k_vec = &self.kv_caches[layer_idx].keys[pos]; + let k_off = h * q_head_dim; + let mut dot = 0.0f32; + for d in 0..q_head_dim { + dot += q_full_concat[q_off + d] * k_vec[k_off + d]; + } + scores.push(dot * inv_sqrt_d); + } + + softmax_inplace(&mut scores); + + let v_off = h * v_dim; + for pos in 0..seq_len { + let v_vec = &self.kv_caches[layer_idx].values[pos]; + let w = scores[pos]; + for d in 0..v_dim { + attn_out[v_off + d] += w * v_vec[h * v_dim + d]; + } + } + } + + Ok(attn_out) + } + + /// Unified FFN forward: dispatches to dense, MoE, or MoE+shared based on layer type. + fn forward_ffn( + &self, + normed_ffn: &[f32], + layer_idx: usize, + config: &BitNetModelConfig, + ) -> Result> { + let hidden = config.hidden_size; let layer = &self.layers[layer_idx]; - rms_norm_inplace(&mut normed_ffn, &layer.post_attn_norm_weight, 1e-6); - // --- MoE --- - let (expert_indices, expert_weights) = - self.route_experts(&normed_ffn, &layer.gate_weight, config)?; - - let mut moe_output = vec![0.0f32; hidden]; - for (&eidx, &eweight) in expert_indices.iter().zip(expert_weights.iter()) { - if eidx >= layer.experts.len() { - return Err(RuvLLMError::Model(format!( - "Expert index {} out of bounds (layer has {} experts)", - eidx, layer.experts.len() - ))); + match layer.layer_type { + LayerType::Dense => { + // Dense FFN: single gate/up/down + let ffn = layer.dense_ffn.as_ref().ok_or_else(|| { + RuvLLMError::Model(format!("Layer {} is Dense but has no dense_ffn", layer_idx)) + })?; + self.expert_forward(normed_ffn, ffn, config) } - let expert_out = self.expert_forward(&normed_ffn, &self.layers[layer_idx].experts[eidx], config)?; - for (o, &e) in moe_output.iter_mut().zip(expert_out.iter()) { - *o += eweight * e; + LayerType::Moe => { + // MoE: route to top-K experts, weighted sum + let (indices, weights) = self.route_experts(normed_ffn, &layer.gate_weight, config)?; + let mut output = vec![0.0f32; hidden]; + for (&eidx, &ew) in indices.iter().zip(weights.iter()) { + if eidx >= layer.experts.len() { continue; } + let e_out = self.expert_forward(normed_ffn, &layer.experts[eidx], config)?; + for (o, &e) in output.iter_mut().zip(e_out.iter()) { + *o += ew * e; + } + } + Ok(output) + } + LayerType::MoeWithShared => { + // MoE + shared expert: routed output + shared expert output + let (indices, weights) = self.route_experts(normed_ffn, &layer.gate_weight, config)?; + let mut output = vec![0.0f32; hidden]; + + // Routed experts + for (&eidx, &ew) in indices.iter().zip(weights.iter()) { + if eidx >= layer.experts.len() { continue; } + let e_out = self.expert_forward(normed_ffn, &layer.experts[eidx], config)?; + for (o, &e) in output.iter_mut().zip(e_out.iter()) { + *o += ew * e; + } + } + + // Shared expert (always active, weight = 1.0) + if let Some(ref shared) = layer.shared_expert { + let s_out = self.expert_forward(normed_ffn, shared, config)?; + for (o, &s) in output.iter_mut().zip(s_out.iter()) { + *o += s; + } + } + + Ok(output) } } - - for (r, &m) in residual.iter_mut().zip(moe_output.iter()) { - *r += m; - } - - Ok(residual) } /// Forward pass through a single layer WITHOUT KV cache (legacy path). @@ -836,6 +1654,7 @@ impl BitNetBackend { &self, input: &[f32], layer: &TransformerLayer, + layer_idx: usize, config: &BitNetModelConfig, ) -> Result> { let hidden = config.hidden_size; @@ -843,60 +1662,91 @@ impl BitNetBackend { let mut normed = input.to_vec(); rms_norm_inplace(&mut normed, &layer.input_norm_weight, 1e-6); - // Attention: Q/K/V projections, single-position self-attention (degenerates to - // identity-like behavior for 1 position but at least runs the projection weights) - let num_heads = config.num_attention_heads; - let head_dim = hidden / num_heads; - let kv_dim = config.num_kv_heads * head_dim; + // Attention: single-position (degenerates to V pass-through for GQA) + let attn_concat = if layer.attention.is_mla { + // MLA single-position: project through full pipeline but attention = identity + self.forward_mla_single_position(&normed, layer, config)? + } else { + // GQA single-position: V expanded to all heads + let num_heads = config.num_attention_heads; + let head_dim = hidden / num_heads; + let kv_dim = config.num_kv_heads * head_dim; + let gqa_groups = if config.num_kv_heads > 0 { num_heads / config.num_kv_heads } else { 1 }; - let q = self.tl1_gemv(&layer.attention.q_proj, &normed, hidden, hidden); - let k = self.tl1_gemv(&layer.attention.k_proj, &normed, kv_dim, hidden); - let v = self.tl1_gemv(&layer.attention.v_proj, &normed, kv_dim, hidden); + let q = self.tl1_gemv(&layer.attention.q_proj, &normed, hidden, hidden); + let k = self.tl1_gemv(&layer.attention.k_proj, &normed, kv_dim, hidden); + let v = self.tl1_gemv(&layer.attention.v_proj, &normed, kv_dim, hidden); + let _ = (q, k); // Exercise projections - // Single-position attention: softmax([score]) = [1.0], so output = V expanded to all heads - let gqa_groups = num_heads / config.num_kv_heads; - let mut attn_concat = vec![0.0f32; hidden]; - for h in 0..num_heads { - let kv_head = h / gqa_groups; - for d in 0..head_dim { - attn_concat[h * head_dim + d] = v[kv_head * head_dim + d]; + let mut concat = vec![0.0f32; hidden]; + for h in 0..num_heads { + let kv_head = h / gqa_groups; + for d in 0..head_dim { + concat[h * head_dim + d] = v[kv_head * head_dim + d]; + } } - } - // Suppress unused warning — q and k are computed to exercise the projections - let _ = q; - let _ = k; + concat + }; let o_out = self.tl1_gemv(&layer.attention.o_proj, &attn_concat, hidden, hidden); - let mut residual: Vec = input.iter().zip(o_out.iter()).map(|(r, a)| r + a).collect(); let mut normed_ffn = residual.clone(); rms_norm_inplace(&mut normed_ffn, &layer.post_attn_norm_weight, 1e-6); - let (expert_indices, expert_weights) = - self.route_experts(&normed_ffn, &layer.gate_weight, config)?; + let ffn_out = self.forward_ffn(&normed_ffn, layer_idx, config)?; - let mut moe_output = vec![0.0f32; hidden]; - for (&eidx, &eweight) in expert_indices.iter().zip(expert_weights.iter()) { - if eidx >= layer.experts.len() { - return Err(RuvLLMError::Model(format!( - "Expert index {} out of bounds (layer has {} experts)", - eidx, layer.experts.len() - ))); - } - let expert_out = self.expert_forward(&normed_ffn, &layer.experts[eidx], config)?; - for (o, &e) in moe_output.iter_mut().zip(expert_out.iter()) { - *o += eweight * e; - } - } - - for (r, &m) in residual.iter_mut().zip(moe_output.iter()) { - *r += m; + for (r, &f) in residual.iter_mut().zip(ffn_out.iter()) { + *r += f; } Ok(residual) } + /// MLA forward for single-position (no KV cache). Used in legacy forward path. + fn forward_mla_single_position( + &self, + normed: &[f32], + layer: &TransformerLayer, + config: &BitNetModelConfig, + ) -> Result> { + let hidden = config.hidden_size; + let num_heads = config.num_attention_heads; + let q_lora_rank = config.q_lora_rank; + let kv_lora_rank = config.kv_lora_rank; + let v_dim = config.v_head_dim; + let kv_a_out = kv_lora_rank + config.qk_rope_head_dim; + + let attn = &layer.attention; + + // Q path (exercise projections) + if let Some(ref q_a) = attn.q_a { + let mut c_q = self.tl1_gemv(q_a, normed, q_lora_rank, hidden); + if let Some(ref norm_w) = attn.q_a_norm { + rms_norm_inplace(&mut c_q, norm_w, 1e-6); + } + if let Some(ref q_b) = attn.q_b { + let _q = self.tl1_gemv(q_b, &c_q, num_heads * (config.qk_nope_head_dim + config.qk_rope_head_dim), q_lora_rank); + } + } + + // KV path + let kv_a = attn.kv_a_mqa.as_ref().ok_or_else(|| { + RuvLLMError::Model("MLA kv_a_mqa missing in nocache path".into()) + })?; + let kv_combined = self.tl1_gemv(kv_a, normed, kv_a_out, hidden); + let c_kv = &kv_combined[..kv_lora_rank]; + + // V = c_kv @ W_v_b + let v_b = attn.v_b.as_ref().ok_or_else(|| { + RuvLLMError::Model("MLA v_b missing".into()) + })?; + let v_full = self.tl1_gemv(v_b, c_kv, num_heads * v_dim, kv_lora_rank); + + // Single position: attention is identity, output = V directly + Ok(v_full) + } + /// Apply Rotary Position Embedding (RoPE) in-place. /// /// For each head, rotates pairs of dimensions (2i, 2i+1) by position-dependent angles. @@ -1111,6 +1961,181 @@ impl BitNetBackend { output } + // ======================================================================== + // Tensor Discovery & Model Validation + // ======================================================================== + + /// Discover and classify all tensors in a GGUF file. + /// + /// Returns a structured report of found tensors, grouped by type + /// (embedding, attention, FFN, norm, etc.), with shape and quantization info. + pub fn discover_tensors(path: &str) -> Result { + let gguf = GgufFile::open_mmap(Path::new(path))?; + let mut report = TensorDiscoveryReport { + total_tensors: gguf.tensors.len(), + total_bytes: gguf.total_tensor_size(), + architecture: gguf.architecture().map(|s| s.to_string()), + tensor_groups: Vec::new(), + warnings: Vec::new(), + }; + + // Classify tensors + let mut embedding = Vec::new(); + let mut attention = Vec::new(); + let mut ffn = Vec::new(); + let mut norm = Vec::new(); + let mut other = Vec::new(); + + for t in &gguf.tensors { + let info = TensorEntry { + name: t.name.clone(), + shape: t.shape.clone(), + dtype: t.dtype.name().to_string(), + bytes: t.byte_size(), + }; + + if t.name.contains("embd") || t.name.contains("embed") || t.name == "output.weight" { + embedding.push(info); + } else if t.name.contains("attn") || t.name.contains("self_attn") { + attention.push(info); + } else if t.name.contains("ffn") || t.name.contains("mlp") || t.name.contains("expert") { + ffn.push(info); + } else if t.name.contains("norm") { + norm.push(info); + } else { + other.push(info); + } + } + + if !embedding.is_empty() { + report.tensor_groups.push(TensorGroup { name: "Embedding/Output".into(), tensors: embedding }); + } + if !norm.is_empty() { + report.tensor_groups.push(TensorGroup { name: "Normalization".into(), tensors: norm }); + } + if !attention.is_empty() { + report.tensor_groups.push(TensorGroup { name: "Attention".into(), tensors: attention }); + } + if !ffn.is_empty() { + report.tensor_groups.push(TensorGroup { name: "FFN/Expert".into(), tensors: ffn }); + } + if !other.is_empty() { + report.tensor_groups.push(TensorGroup { name: "Other".into(), tensors: other }); + } + + // Detect naming convention + let has_blk = gguf.tensors.iter().any(|t| t.name.starts_with("blk.")); + let has_model = gguf.tensors.iter().any(|t| t.name.starts_with("model.")); + if has_blk && has_model { + report.warnings.push("Mixed naming conventions detected (blk.* and model.*)".into()); + } + + // Detect MLA + let has_mla = gguf.tensors.iter().any(|t| t.name.contains("attn_q_a")); + if has_mla { + report.warnings.push("MLA (Multi-Head Latent Attention) tensors detected".into()); + } + + // Detect stacked experts + let has_exps = gguf.tensors.iter().any(|t| t.name.contains("_exps")); + if has_exps { + report.warnings.push("Stacked expert tensors detected (3D format)".into()); + } + + Ok(report) + } + + /// Validate that a GGUF file has all required tensors for loading. + /// + /// Returns a list of missing tensor names and a boolean indicating + /// whether the model can be loaded. + pub fn validate_model(path: &str) -> Result { + let gguf = GgufFile::open_mmap(Path::new(path))?; + let backend = BitNetBackend::new(); + let config = backend.extract_config(&gguf)?; + let mut missing = Vec::new(); + let mut found = Vec::new(); + + // Check global tensors + for (label, candidates) in [ + ("Embedding", TensorNameMapper::embedding()), + ("Output/LM Head", TensorNameMapper::output()), + ("Final Norm", TensorNameMapper::final_norm()), + ] { + if let Some(name) = TensorNameMapper::resolve(&gguf, &candidates) { + found.push(format!("{}: {}", label, name)); + } else { + missing.push(format!("{} (tried: {})", label, candidates.join(", "))); + } + } + + // Check first layer tensors to determine structure + let idx = 0; + for (label, candidates) in [ + ("Layer 0 Input Norm", TensorNameMapper::input_norm(idx)), + ("Layer 0 Post-Attn Norm", TensorNameMapper::post_attn_norm(idx)), + ] { + if let Some(name) = TensorNameMapper::resolve(&gguf, &candidates) { + found.push(format!("{}: {}", label, name)); + } else { + missing.push(format!("{} (tried: {})", label, candidates.join(", "))); + } + } + + // Check attention type + if TensorNameMapper::has_mla(&gguf, 0) { + found.push("Attention type: MLA".into()); + for (label, candidates) in [ + ("Layer 0 attn_q_a", TensorNameMapper::attn_q_a(0)), + ("Layer 0 attn_q_b", TensorNameMapper::attn_q_b(0)), + ("Layer 0 attn_kv_a_mqa", TensorNameMapper::attn_kv_a_mqa(0)), + ("Layer 0 attn_k_b", TensorNameMapper::attn_k_b(0)), + ("Layer 0 attn_v_b", TensorNameMapper::attn_v_b(0)), + ("Layer 0 attn_output", TensorNameMapper::attn_output(0)), + ] { + if TensorNameMapper::resolve(&gguf, &candidates).is_some() { + found.push(format!(" {}: present", label)); + } else { + missing.push(format!("{} (tried: {})", label, candidates.join(", "))); + } + } + } else { + found.push("Attention type: GQA".into()); + } + + // Check FFN structure for layers + let check_layer = config.first_k_dense_replace.min(config.num_layers); + if check_layer > 0 { + if TensorNameMapper::has_dense_ffn(&gguf, 0) { + found.push("Layer 0: Dense FFN".into()); + } else { + missing.push("Layer 0 dense FFN tensors".into()); + } + } + if config.num_layers > config.first_k_dense_replace { + let moe_layer = config.first_k_dense_replace; + if TensorNameMapper::has_stacked_experts(&gguf, moe_layer) { + found.push(format!("Layer {}: Stacked MoE experts", moe_layer)); + } else if TensorNameMapper::resolve(&gguf, &TensorNameMapper::expert_gate(moe_layer, 0)).is_some() { + found.push(format!("Layer {}: Individual MoE experts", moe_layer)); + } else { + missing.push(format!("Layer {} MoE expert tensors", moe_layer)); + } + } + + let can_load = missing.is_empty(); + Ok(ModelValidation { + can_load, + config_summary: format!( + "layers={}, hidden={}, heads={}, experts={}, vocab={}, mla={}", + config.num_layers, config.hidden_size, config.num_attention_heads, + config.num_experts, config.vocab_size, config.use_mla + ), + found, + missing, + }) + } + /// Greedy-decode a single next token from logits. fn argmax(logits: &[f32]) -> u32 { let mut best_idx = 0u32; @@ -1125,6 +2150,60 @@ impl BitNetBackend { } } +// ============================================================================ +// Tensor Discovery & Validation Report Types +// ============================================================================ + +/// Report from tensor discovery on a GGUF file. +#[derive(Debug)] +pub struct TensorDiscoveryReport { + /// Total number of tensors + pub total_tensors: usize, + /// Total bytes across all tensors + pub total_bytes: usize, + /// Architecture string from metadata + pub architecture: Option, + /// Grouped tensor listings + pub tensor_groups: Vec, + /// Warnings or observations + pub warnings: Vec, +} + +/// A group of related tensors. +#[derive(Debug)] +pub struct TensorGroup { + /// Group name (e.g., "Attention", "FFN/Expert") + pub name: String, + /// Tensors in this group + pub tensors: Vec, +} + +/// Info about a single tensor. +#[derive(Debug)] +pub struct TensorEntry { + /// Tensor name in GGUF + pub name: String, + /// Shape dimensions + pub shape: Vec, + /// Quantization type name + pub dtype: String, + /// Size in bytes + pub bytes: usize, +} + +/// Result of model validation against expected tensor layout. +#[derive(Debug)] +pub struct ModelValidation { + /// Whether all required tensors were found + pub can_load: bool, + /// Summary of detected configuration + pub config_summary: String, + /// Tensors that were found + pub found: Vec, + /// Tensors that are missing + pub missing: Vec, +} + // ============================================================================ // LlmBackend Trait Implementation // ============================================================================ @@ -1311,21 +2390,43 @@ impl LlmBackend for BitNetBackend { .layers .iter() .map(|l| { - l.gate_weight.len() * 4 + let mut bytes = l.gate_weight.len() * 4 + l.input_norm_weight.len() * 4 + l.post_attn_norm_weight.len() * 4 - + l.attention.q_proj.memory_bytes() - + l.attention.k_proj.memory_bytes() - + l.attention.v_proj.memory_bytes() - + l.attention.o_proj.memory_bytes() - + l.experts - .iter() - .map(|e| { - e.gate_proj.memory_bytes() - + e.up_proj.memory_bytes() - + e.down_proj.memory_bytes() - }) - .sum::() + + l.attention.o_proj.memory_bytes(); + // Attention: MLA or GQA + if l.attention.is_mla { + bytes += l.attention.q_a.as_ref().map_or(0, |t| t.memory_bytes()); + bytes += l.attention.q_b.as_ref().map_or(0, |t| t.memory_bytes()); + bytes += l.attention.kv_a_mqa.as_ref().map_or(0, |t| t.memory_bytes()); + bytes += l.attention.k_b.as_ref().map_or(0, |t| t.memory_bytes()); + bytes += l.attention.v_b.as_ref().map_or(0, |t| t.memory_bytes()); + bytes += l.attention.q_a_norm.as_ref().map_or(0, |v| v.len() * 4); + bytes += l.attention.kv_a_norm.as_ref().map_or(0, |v| v.len() * 4); + } else { + bytes += l.attention.q_proj.memory_bytes(); + bytes += l.attention.k_proj.memory_bytes(); + bytes += l.attention.v_proj.memory_bytes(); + } + // FFN: routed experts + bytes += l.experts.iter().map(|e| { + e.gate_proj.memory_bytes() + + e.up_proj.memory_bytes() + + e.down_proj.memory_bytes() + }).sum::(); + // FFN: shared expert + if let Some(ref se) = l.shared_expert { + bytes += se.gate_proj.memory_bytes() + + se.up_proj.memory_bytes() + + se.down_proj.memory_bytes(); + } + // FFN: dense + if let Some(ref df) = l.dense_ffn { + bytes += df.gate_proj.memory_bytes() + + df.up_proj.memory_bytes() + + df.down_proj.memory_bytes(); + } + bytes }) .sum::(), }) @@ -1640,10 +2741,20 @@ mod tests { #[test] fn test_bitnet_model_config_default() { let config = BitNetModelConfig::default(); - assert_eq!(config.num_layers, 28); - assert_eq!(config.hidden_size, 4096); - assert_eq!(config.num_experts, 8); - assert_eq!(config.active_experts, 2); + // GLM-4.7-Flash defaults + assert_eq!(config.num_layers, 47); + assert_eq!(config.hidden_size, 2048); + assert_eq!(config.num_experts, 64); + assert_eq!(config.active_experts, 4); + assert_eq!(config.moe_intermediate_size, 1536); + assert!(config.use_mla); + assert_eq!(config.q_lora_rank, 768); + assert_eq!(config.kv_lora_rank, 512); + assert_eq!(config.qk_nope_head_dim, 192); + assert_eq!(config.qk_rope_head_dim, 64); + assert_eq!(config.v_head_dim, 256); + assert_eq!(config.n_shared_experts, 1); + assert_eq!(config.first_k_dense_replace, 1); } #[test] @@ -1795,8 +2906,8 @@ mod tests { } #[test] - fn test_attention_weights_struct() { - // Just verify AttentionWeights can be constructed + fn test_attention_weights_gqa() { + // Verify GQA AttentionWeights construction let packed = pack_ternary(&[1, 0, -1, 0]); let tensor = TernaryTensor { packed_data: packed.clone(), @@ -1805,14 +2916,53 @@ mod tests { block_size: 256, }; let attn = AttentionWeights { + is_mla: false, q_proj: tensor.clone(), k_proj: tensor.clone(), v_proj: tensor.clone(), o_proj: tensor, + q_a: None, q_b: None, q_a_norm: None, + kv_a_mqa: None, kv_a_norm: None, k_b: None, v_b: None, }; + assert!(!attn.is_mla); assert_eq!(attn.q_proj.shape, (1, 4)); } + #[test] + fn test_attention_weights_mla() { + // Verify MLA AttentionWeights construction + let packed = pack_ternary(&[1, 0, -1, 0]); + let tensor = TernaryTensor { + packed_data: packed.clone(), + scales: vec![1.0], + shape: (1, 4), + block_size: 256, + }; + let placeholder = TernaryTensor { + packed_data: vec![], scales: vec![], shape: (0, 0), block_size: 256, + }; + let attn = AttentionWeights { + is_mla: true, + q_proj: placeholder.clone(), + k_proj: placeholder.clone(), + v_proj: placeholder, + o_proj: tensor.clone(), + q_a: Some(tensor.clone()), + q_b: Some(tensor.clone()), + q_a_norm: Some(vec![1.0; 4]), + kv_a_mqa: Some(tensor.clone()), + kv_a_norm: Some(vec![1.0; 4]), + k_b: Some(tensor.clone()), + v_b: Some(tensor), + }; + assert!(attn.is_mla); + assert!(attn.q_a.is_some()); + assert!(attn.q_b.is_some()); + assert!(attn.kv_a_mqa.is_some()); + assert!(attn.k_b.is_some()); + assert!(attn.v_b.is_some()); + } + #[test] fn test_tok_accessor() { let mut backend = BitNetBackend::new(); @@ -1822,4 +2972,206 @@ mod tests { assert!(backend.tok().is_some()); assert_eq!(backend.tok().unwrap().vocab_size(), 260); } + + #[test] + fn test_layer_type_enum() { + assert_eq!(LayerType::Dense, LayerType::Dense); + assert_ne!(LayerType::Dense, LayerType::Moe); + assert_ne!(LayerType::Moe, LayerType::MoeWithShared); + } + + #[test] + fn test_tensor_name_mapper_embedding() { + let candidates = TensorNameMapper::embedding(); + assert_eq!(candidates.len(), 2); + assert!(candidates.contains(&"token_embd.weight".to_string())); + assert!(candidates.contains(&"model.embed_tokens.weight".to_string())); + } + + #[test] + fn test_tensor_name_mapper_mla() { + let q_a = TensorNameMapper::attn_q_a(5); + assert_eq!(q_a, vec!["blk.5.attn_q_a.weight".to_string()]); + + let q_b = TensorNameMapper::attn_q_b(5); + assert_eq!(q_b, vec!["blk.5.attn_q_b.weight".to_string()]); + + let kv_a = TensorNameMapper::attn_kv_a_mqa(5); + assert_eq!(kv_a, vec!["blk.5.attn_kv_a_mqa.weight".to_string()]); + + let k_b = TensorNameMapper::attn_k_b(5); + assert_eq!(k_b, vec!["blk.5.attn_k_b.weight".to_string()]); + + let v_b = TensorNameMapper::attn_v_b(5); + assert_eq!(v_b, vec!["blk.5.attn_v_b.weight".to_string()]); + } + + #[test] + fn test_tensor_name_mapper_norms() { + let in_norm = TensorNameMapper::input_norm(3); + assert!(in_norm.contains(&"blk.3.attn_norm.weight".to_string())); + assert!(in_norm.contains(&"model.layers.3.input_layernorm.weight".to_string())); + + let post_norm = TensorNameMapper::post_attn_norm(3); + assert!(post_norm.contains(&"blk.3.ffn_norm.weight".to_string())); + } + + #[test] + fn test_tensor_name_mapper_moe() { + let gate = TensorNameMapper::moe_gate(2); + assert!(gate.contains(&"blk.2.ffn_gate_inp.weight".to_string())); + + let exps = TensorNameMapper::ffn_gate_exps(2); + assert_eq!(exps, vec!["blk.2.ffn_gate_exps.weight".to_string()]); + + let shexp = TensorNameMapper::ffn_gate_shexp(2); + assert_eq!(shexp, vec!["blk.2.ffn_gate_shexp.weight".to_string()]); + } + + #[test] + fn test_tensor_name_mapper_dense_ffn() { + let gate = TensorNameMapper::ffn_gate(0); + assert!(gate.contains(&"blk.0.ffn_gate.weight".to_string())); + assert!(gate.contains(&"model.layers.0.mlp.gate_proj.weight".to_string())); + } + + #[test] + fn test_tensor_name_mapper_individual_experts() { + let gate = TensorNameMapper::expert_gate(1, 3); + assert_eq!(gate, vec!["model.layers.1.mlp.experts.3.gate_proj.weight".to_string()]); + } + + #[test] + fn test_mla_config_dimensions() { + let config = BitNetModelConfig::default(); + // Q head dim = qk_nope_head_dim + qk_rope_head_dim + let q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim; + assert_eq!(q_head_dim, 256); + + // Total Q dim = num_heads * q_head_dim + let total_q_dim = config.num_attention_heads * q_head_dim; + assert_eq!(total_q_dim, 5120); + + // KV compression output = kv_lora_rank + qk_rope_head_dim + let kv_a_out = config.kv_lora_rank + config.qk_rope_head_dim; + assert_eq!(kv_a_out, 576); + } + + #[test] + fn test_transformer_layer_dense() { + let packed = pack_ternary(&[1, 0, -1, 0]); + let tensor = TernaryTensor { + packed_data: packed.clone(), + scales: vec![1.0], + shape: (1, 4), + block_size: 256, + }; + let attn = AttentionWeights { + is_mla: false, + q_proj: tensor.clone(), k_proj: tensor.clone(), + v_proj: tensor.clone(), o_proj: tensor.clone(), + q_a: None, q_b: None, q_a_norm: None, + kv_a_mqa: None, kv_a_norm: None, k_b: None, v_b: None, + }; + let layer = TransformerLayer { + input_norm_weight: vec![1.0; 4], + post_attn_norm_weight: vec![1.0; 4], + attention: attn, + layer_type: LayerType::Dense, + gate_weight: Vec::new(), + experts: Vec::new(), + shared_expert: None, + dense_ffn: Some(ExpertWeights { + gate_proj: tensor.clone(), + up_proj: tensor.clone(), + down_proj: tensor, + }), + }; + assert_eq!(layer.layer_type, LayerType::Dense); + assert!(layer.dense_ffn.is_some()); + assert!(layer.shared_expert.is_none()); + } + + #[test] + fn test_transformer_layer_moe_with_shared() { + let packed = pack_ternary(&[1, 0, -1, 0]); + let tensor = TernaryTensor { + packed_data: packed.clone(), + scales: vec![1.0], + shape: (1, 4), + block_size: 256, + }; + let attn = AttentionWeights { + is_mla: false, + q_proj: tensor.clone(), k_proj: tensor.clone(), + v_proj: tensor.clone(), o_proj: tensor.clone(), + q_a: None, q_b: None, q_a_norm: None, + kv_a_mqa: None, kv_a_norm: None, k_b: None, v_b: None, + }; + let expert = ExpertWeights { + gate_proj: tensor.clone(), + up_proj: tensor.clone(), + down_proj: tensor.clone(), + }; + let layer = TransformerLayer { + input_norm_weight: vec![1.0; 4], + post_attn_norm_weight: vec![1.0; 4], + attention: attn, + layer_type: LayerType::MoeWithShared, + gate_weight: vec![1.0; 8], // 2 experts x 4 hidden + experts: vec![expert.clone(), expert.clone()], + shared_expert: Some(expert), + dense_ffn: None, + }; + assert_eq!(layer.layer_type, LayerType::MoeWithShared); + assert_eq!(layer.experts.len(), 2); + assert!(layer.shared_expert.is_some()); + } + + #[test] + fn test_tensor_discovery_report_struct() { + let report = TensorDiscoveryReport { + total_tensors: 10, + total_bytes: 1024, + architecture: Some("deepseek2".into()), + tensor_groups: vec![ + TensorGroup { + name: "Embedding".into(), + tensors: vec![TensorEntry { + name: "token_embd.weight".into(), + shape: vec![154880, 2048], + dtype: "Q8_0".into(), + bytes: 512, + }], + }, + ], + warnings: vec!["MLA detected".into()], + }; + assert_eq!(report.total_tensors, 10); + assert_eq!(report.tensor_groups.len(), 1); + assert_eq!(report.warnings.len(), 1); + } + + #[test] + fn test_model_validation_struct() { + let validation = ModelValidation { + can_load: true, + config_summary: "layers=47, hidden=2048".into(), + found: vec!["Embedding: token_embd.weight".into()], + missing: vec![], + }; + assert!(validation.can_load); + assert_eq!(validation.found.len(), 1); + assert!(validation.missing.is_empty()); + } + + #[test] + fn test_meta_helpers() { + // Test that meta_usize and meta_f32 handle missing keys + // (We can't easily construct a GgufFile in tests, so we test the + // behavior through the config defaults) + let config = BitNetModelConfig::default(); + assert_eq!(config.rope_theta, 1_000_000.0); + assert_eq!(config.routed_scaling_factor, 1.8); + } } diff --git a/crates/ruvllm/src/bitnet/mod.rs b/crates/ruvllm/src/bitnet/mod.rs index f3cc7c381..6b14d64f5 100644 --- a/crates/ruvllm/src/bitnet/mod.rs +++ b/crates/ruvllm/src/bitnet/mod.rs @@ -80,7 +80,10 @@ pub use rlm_embedder::{ RlmEmbeddingResult, }; pub use rlm_refiner::{RefinementResult, RefinementStepMetrics, RlmRefiner, RlmRefinerConfig}; -pub use backend::{BitNetBackend, BitNetModelConfig}; +pub use backend::{ + BitNetBackend, BitNetModelConfig, ModelValidation, TensorDiscoveryReport, TensorEntry, + TensorGroup, +}; pub use expert_cache::{ ExpertBatch, ExpertCache, ExpertCacheConfig, ExpertCacheStats, EvictionPolicy, MoeBatchScheduler, NullPrefetcher, Prefetcher,