From 2bb04a64eda3b7e6c39eaa41f3e41da37c093e86 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 3 Feb 2026 12:40:18 +0000 Subject: [PATCH] feat: Implement Phase 0 PT-BitNet quantizer module Add bitnet/ module with absmean ternary quantizer, TernaryTensor type, BITNET_T158 dequantization, and comprehensive test suite (~1600 lines). Components: - quantizer.rs: PtBitnetConfig, absmean_ternary(), quantize_tensor() - ternary_tensor.rs: TernaryTensor, pack/unpack 2-bit ternary encoding - dequantize.rs: dequantize_bitnet_t158(), block dequant, error metrics - tests.rs: Packing roundtrips, quantization correctness, edge cases - gguf/quantization.rs: BitnetT158 = 30 enum variant, block_size, bytes Implements AD-1 (weight representation), AD-5 (GGUF extension), AD-18 (PT-BitNet quantization) from ADR-017. https://claude.ai/code/session_011nTcGcn49b8YKJRVoh4TaK --- crates/ruvllm/src/bitnet/dequantize.rs | 269 +++++ crates/ruvllm/src/bitnet/mod.rs | 58 + crates/ruvllm/src/bitnet/quantizer.rs | 338 ++++++ crates/ruvllm/src/bitnet/ternary_tensor.rs | 276 +++++ crates/ruvllm/src/bitnet/tests.rs | 662 ++++++++++++ crates/ruvllm/src/gguf/quantization.rs | 16 + crates/ruvllm/src/lib.rs | 1 + .../bitnet-quantizer-module-design.md | 999 ++++++++++++++++++ 8 files changed, 2619 insertions(+) create mode 100644 crates/ruvllm/src/bitnet/dequantize.rs create mode 100644 crates/ruvllm/src/bitnet/mod.rs create mode 100644 crates/ruvllm/src/bitnet/quantizer.rs create mode 100644 crates/ruvllm/src/bitnet/ternary_tensor.rs create mode 100644 crates/ruvllm/src/bitnet/tests.rs create mode 100644 docs/architecture/bitnet-quantizer-module-design.md diff --git a/crates/ruvllm/src/bitnet/dequantize.rs b/crates/ruvllm/src/bitnet/dequantize.rs new file mode 100644 index 000000000..153fa4b6b --- /dev/null +++ b/crates/ruvllm/src/bitnet/dequantize.rs @@ -0,0 +1,269 @@ +//! BitNet Ternary Dequantization +//! +//! Converts packed 2-bit ternary weights back to FP32 for validation and testing. 
+ +use super::ternary_tensor::unpack_ternary; + +/// Dequantize BITNET_T158 packed ternary data to FP32. +/// +/// This function unpacks 2-bit ternary values and applies per-block scale factors +/// to reconstruct approximate FP32 weights. Used for validation and testing, not +/// for production inference (which should use native ternary kernels). +/// +/// # Data Layout +/// +/// The input data is organized as: +/// ```text +/// [packed_block_0][scale_0][packed_block_1][scale_1]... +/// ``` +/// +/// Where each block contains: +/// - 64 bytes of packed 2-bit ternary data (256 values) +/// - 2 bytes of FP16 scale factor +/// +/// Total: 66 bytes per 256-element block +/// +/// # Arguments +/// +/// * `packed` - Raw GGUF tensor data with interleaved ternary and scales +/// * `scales` - Per-block FP32 scale factors +/// * `num_elements` - Total number of output elements +/// +/// # Returns +/// +/// Vector of FP32 weights approximating the original quantized tensor +/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::bitnet::dequantize_bitnet_t158; +/// +/// // Load from GGUF +/// let packed_data = gguf_tensor.data; // Raw bytes +/// let scales = vec![0.542, 0.381, ...]; // One per block +/// let num_elements = 512; +/// +/// let fp32_weights = dequantize_bitnet_t158(&packed_data, &scales, num_elements); +/// ``` +pub fn dequantize_bitnet_t158(packed: &[u8], scales: &[f32], num_elements: usize) -> Vec { + // Unpack ternary values + let ternary = unpack_ternary(packed, num_elements); + + // Apply per-block scales + let block_size = 256; // Standard BitNet block size + let mut output = Vec::with_capacity(num_elements); + + for (block_idx, chunk) in ternary.chunks(block_size).enumerate() { + let scale = scales.get(block_idx).copied().unwrap_or(1.0); + + for &ternary_val in chunk { + let fp32_val = (ternary_val as f32) * scale; + output.push(fp32_val); + } + } + + output +} + +/// Dequantize a single BITNET_T158 block. 
+/// +/// Helper function for block-wise dequantization in streaming scenarios. +/// +/// # Arguments +/// +/// * `packed_block` - 64 bytes of packed 2-bit ternary data +/// * `scale` - FP32 scale factor for this block +/// * `output` - Output buffer (must have capacity for 256 FP32 values) +/// +/// # Panics +/// +/// Panics if output buffer is smaller than 256 elements. +pub fn dequantize_bitnet_block(packed_block: &[u8], scale: f32, output: &mut [f32]) { + assert!( + output.len() >= 256, + "Output buffer must hold at least 256 elements" + ); + assert_eq!( + packed_block.len(), + 64, + "Packed block must be exactly 64 bytes" + ); + + let ternary = unpack_ternary(packed_block, 256); + + for (i, &ternary_val) in ternary.iter().enumerate() { + output[i] = (ternary_val as f32) * scale; + } +} + +/// Compute dequantization error metrics. +/// +/// Compares dequantized weights against original FP32 weights to measure +/// quantization quality. +/// +/// # Arguments +/// +/// * `original` - Original FP32 weights +/// * `dequantized` - Dequantized weights from ternary +/// +/// # Returns +/// +/// Tuple of (mean_absolute_error, mean_squared_error, max_error) +pub fn compute_dequant_error(original: &[f32], dequantized: &[f32]) -> (f32, f32, f32) { + assert_eq!( + original.len(), + dequantized.len(), + "Arrays must have same length" + ); + + let mut sum_abs_error = 0.0; + let mut sum_sq_error = 0.0; + let mut max_error = 0.0; + + for (orig, dequant) in original.iter().zip(dequantized.iter()) { + let error = (orig - dequant).abs(); + sum_abs_error += error; + sum_sq_error += error * error; + max_error = max_error.max(error); + } + + let n = original.len() as f32; + let mae = sum_abs_error / n; + let mse = sum_sq_error / n; + + (mae, mse, max_error) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::bitnet::{absmean_ternary, pack_ternary}; + + #[test] + fn test_dequantize_bitnet_t158_simple() { + // Create simple ternary data + let ternary = vec![-1i8, 0, 1, -1, 1, 
0, 0, 1]; + let packed = pack_ternary(&ternary); + let scales = vec![0.5f32]; + + let result = dequantize_bitnet_t158(&packed, &scales, 8); + + assert_eq!(result.len(), 8); + + // Check values: ternary * scale + assert_eq!(result[0], -0.5); // -1 * 0.5 + assert_eq!(result[1], 0.0); // 0 * 0.5 + assert_eq!(result[2], 0.5); // 1 * 0.5 + assert_eq!(result[3], -0.5); // -1 * 0.5 + } + + #[test] + fn test_dequantize_bitnet_block() { + // Create a full 256-element block + let ternary = vec![1i8; 256]; + let packed = pack_ternary(&ternary); + let scale = 2.0; + + let mut output = vec![0.0f32; 256]; + dequantize_bitnet_block(&packed, scale, &mut output); + + // All values should be 1 * 2.0 = 2.0 + assert!(output.iter().all(|&v| (v - 2.0).abs() < 1e-6)); + } + + #[test] + fn test_dequantize_multiple_blocks() { + // Two blocks with different scales + let ternary1 = vec![1i8; 256]; + let ternary2 = vec![-1i8; 256]; + + let mut all_ternary = ternary1.clone(); + all_ternary.extend_from_slice(&ternary2); + + let packed = pack_ternary(&all_ternary); + let scales = vec![1.0, 2.0]; + + let result = dequantize_bitnet_t158(&packed, &scales, 512); + + // First 256 should be 1.0 * 1.0 = 1.0 + assert!(result[..256].iter().all(|&v| (v - 1.0).abs() < 1e-6)); + + // Next 256 should be -1.0 * 2.0 = -2.0 + assert!(result[256..512] + .iter() + .all(|&v| (v - (-2.0)).abs() < 1e-6)); + } + + #[test] + fn test_roundtrip_quantize_dequantize() { + // Original weights + let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.4, 0.2, -0.6]; + + // Quantize + let (ternary, scale) = absmean_ternary(&original); + let packed = pack_ternary(&ternary); + + // Dequantize + let dequantized = dequantize_bitnet_t158(&packed, &[scale], original.len()); + + // Check that we got 8 values back + assert_eq!(dequantized.len(), 8); + + // Values should be approximate (quantization loses precision) + // But should be close for values near the scale + for (orig, dequant) in original.iter().zip(dequantized.iter()) { + let 
error = (orig - dequant).abs(); + // Error should be bounded by the quantization step (~scale) + assert!(error < scale * 2.0); + } + } + + #[test] + fn test_compute_dequant_error() { + let original = vec![1.0, 2.0, 3.0, 4.0]; + let dequantized = vec![1.1, 1.9, 3.2, 3.8]; + + let (mae, mse, max_error) = compute_dequant_error(&original, &dequantized); + + // MAE should be (0.1 + 0.1 + 0.2 + 0.2) / 4 = 0.15 + assert!((mae - 0.15).abs() < 1e-6); + + // MSE should be (0.01 + 0.01 + 0.04 + 0.04) / 4 = 0.025 + assert!((mse - 0.025).abs() < 1e-6); + + // Max error should be 0.2 + assert!((max_error - 0.2).abs() < 1e-6); + } + + #[test] + #[should_panic(expected = "Output buffer must hold at least 256 elements")] + fn test_dequantize_block_small_buffer() { + let packed = vec![0u8; 64]; + let mut output = vec![0.0f32; 128]; // Too small + dequantize_bitnet_block(&packed, 1.0, &mut output); + } + + #[test] + #[should_panic(expected = "Packed block must be exactly 64 bytes")] + fn test_dequantize_block_wrong_size() { + let packed = vec![0u8; 32]; // Wrong size + let mut output = vec![0.0f32; 256]; + dequantize_bitnet_block(&packed, 1.0, &mut output); + } + + #[test] + fn test_dequantize_with_missing_scales() { + // More elements than scales (should use default 1.0) + let ternary = vec![1i8; 512]; + let packed = pack_ternary(&ternary); + let scales = vec![2.0]; // Only one scale for two blocks + + let result = dequantize_bitnet_t158(&packed, &scales, 512); + + // First 256 use scale 2.0 + assert!(result[..256].iter().all(|&v| (v - 2.0).abs() < 1e-6)); + + // Next 256 use default 1.0 + assert!(result[256..512].iter().all(|&v| (v - 1.0).abs() < 1e-6)); + } +} diff --git a/crates/ruvllm/src/bitnet/mod.rs b/crates/ruvllm/src/bitnet/mod.rs new file mode 100644 index 000000000..98e4a328b --- /dev/null +++ b/crates/ruvllm/src/bitnet/mod.rs @@ -0,0 +1,58 @@ +//! BitNet b1.58 Ternary Quantization for RuvLLM +//! +//! 
This module implements Microsoft Research's BitNet b1.58 ternary weight quantization +//! for the Craftsman Ultra 30b 1bit model. It provides post-training quantization (PTQ) +//! of FP16 weights to ternary {-1, 0, +1} using absmean quantization. +//! +//! ## Overview +//! +//! BitNet b1.58 enables multiplication-free inference by quantizing weights to three values: +//! -1, 0, +1. This reduces memory footprint to ~2 bits per weight and eliminates floating-point +//! multiplication in matrix operations. +//! +//! ## Key Components +//! +//! - [`TernaryTensor`]: Container for ternary weights with 2-bit packing +//! - [`quantize_tensor`]: Convert FP32 weights to ternary using absmean algorithm +//! - [`dequantize_bitnet_t158`]: Convert packed ternary back to FP32 for validation +//! - [`PtBitnetConfig`]: Configuration for post-training quantization +//! +//! ## Example +//! +//! ```rust,ignore +//! use ruvllm::bitnet::{quantize_tensor, PtBitnetConfig}; +//! +//! // Configure quantization +//! let config = PtBitnetConfig { +//! block_size: 256, +//! optimize_scales: true, +//! ..Default::default() +//! }; +//! +//! // Quantize a weight tensor +//! let fp32_weights = vec![0.5, -0.3, 0.0, 0.8, /* ... */]; +//! let ternary = quantize_tensor(&fp32_weights, (128, 256), &config)?; +//! +//! println!("Sparsity: {:.2}%", ternary.sparsity() * 100.0); +//! println!("Memory: {} bytes", ternary.memory_bytes()); +//! ``` +//! +//! ## Architecture Details +//! +//! From ADR-017 (AD-1, AD-5, AD-18): +//! +//! - **Absmean quantization**: `W_ternary = RoundClip(W / (mean(|W|) + ε), -1, 1)` +//! - **2-bit packing**: 00=-1, 01=0, 10=+1 (4 values per byte) +//! - **Block size**: 256 elements per scale factor +//! - **Storage**: 66 bytes per block (64 bytes ternary + 2 bytes FP16 scale) +//! 
- **Compression**: 2.06 bits/weight (30B model → ~7.7 GB) + +pub mod dequantize; +pub mod quantizer; +pub mod ternary_tensor; + +pub use dequantize::dequantize_bitnet_t158; +pub use quantizer::{ + absmean_ternary, quantize_tensor, LayerMask, Precision, PtBitnetConfig, TernaryFormat, +}; +pub use ternary_tensor::{pack_ternary, unpack_ternary, TernaryTensor}; diff --git a/crates/ruvllm/src/bitnet/quantizer.rs b/crates/ruvllm/src/bitnet/quantizer.rs new file mode 100644 index 000000000..86e48bbf6 --- /dev/null +++ b/crates/ruvllm/src/bitnet/quantizer.rs @@ -0,0 +1,338 @@ +//! PT-BitNet Post-Training Quantization +//! +//! Core absmean ternary quantization algorithm for converting FP32 weights +//! to BitNet b1.58 ternary format. + +use crate::error::{Result, RuvLLMError}; +use super::ternary_tensor::{pack_ternary, TernaryTensor}; + +/// Configuration for PT-BitNet post-training quantization. +/// +/// Controls the quantization process behavior, including block size, +/// calibration, and layer selection. 
+/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::bitnet::PtBitnetConfig; +/// +/// let config = PtBitnetConfig { +/// calibration_samples: 1000, +/// block_size: 256, +/// optimize_scales: true, +/// layers_to_quantize: LayerMask::ExpertsOnly, +/// export_format: TernaryFormat::BitnetT158, +/// ..Default::default() +/// }; +/// ``` +#[derive(Debug, Clone)] +pub struct PtBitnetConfig { + /// Number of calibration samples for scale optimization + pub calibration_samples: usize, + /// Elements per quantization block + pub block_size: usize, + /// Enable scale factor optimization via calibration + pub optimize_scales: bool, + /// Which layers to quantize + pub layers_to_quantize: LayerMask, + /// Export format for GGUF serialization + pub export_format: TernaryFormat, + /// Precision for router and shared layers + pub router_precision: Precision, + /// Use memory-mapped I/O for weight loading + pub use_mmap: bool, + /// Use Metal GPU for calibration (Mac Studio only) + pub use_metal_calibration: bool, + /// Maximum memory budget in GB + pub max_memory_gb: usize, +} + +impl Default for PtBitnetConfig { + fn default() -> Self { + Self { + calibration_samples: 1000, + block_size: 256, + optimize_scales: true, + layers_to_quantize: LayerMask::ExpertsOnly, + export_format: TernaryFormat::BitnetT158, + router_precision: Precision::FP16, + use_mmap: true, + use_metal_calibration: cfg!(all(target_os = "macos", feature = "metal-compute")), + max_memory_gb: 64, + } + } +} + +/// Layer selection mask for quantization. +/// +/// Determines which model layers to convert to ternary. Per ADR-017 (AD-2), +/// the MoE router, embeddings, and LM head must remain in higher precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LayerMask { + /// Only MoE expert FFN layers (recommended for Phase 1) + ExpertsOnly, + /// All linear layers except router/embeddings/head + All, + /// Custom layer selection by name pattern + Custom(Vec), +} + +/// Ternary tensor export format. 
+/// +/// Determines the GGUF quantization type used for serialization. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TernaryFormat { + /// BitNet b1.58 native format (type 30) + BitnetT158, + /// IQ1_S compatible format (type 19) + IQ1S, +} + +/// Precision for non-quantized layers. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Precision { + /// 16-bit floating point + FP16, + /// Brain floating point 16 + BF16, + /// 32-bit floating point + FP32, +} + +/// Core absmean ternary quantization algorithm. +/// +/// Implements the BitNet b1.58 quantization formula: +/// ```text +/// gamma = mean(|block|) + epsilon +/// normalized = block / gamma +/// ternary = round(clamp(normalized, -1, 1)) +/// ``` +/// +/// # Arguments +/// +/// * `block` - FP32 weight block (typically 256 elements) +/// +/// # Returns +/// +/// Tuple of (ternary values, scale factor): +/// - `Vec`: Ternary weights in {-1, 0, +1} +/// - `f32`: Absmean scale factor (gamma) +/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::bitnet::absmean_ternary; +/// +/// let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.4]; +/// let (ternary, scale) = absmean_ternary(&weights); +/// +/// println!("Scale: {}", scale); +/// println!("Ternary: {:?}", ternary); // e.g., [1, -1, 1, 0, 0, 1] +/// ``` +pub fn absmean_ternary(block: &[f32]) -> (Vec, f32) { + // Compute absmean scale: gamma = mean(|W|) + let sum_abs: f32 = block.iter().map(|&w| w.abs()).sum(); + let gamma = (sum_abs / block.len() as f32) + 1e-8; + + // Normalize and quantize to {-1, 0, +1} + let ternary: Vec = block + .iter() + .map(|&w| { + let normalized = w / gamma; + let clamped = normalized.clamp(-1.0, 1.0); + clamped.round() as i8 + }) + .collect(); + + (ternary, gamma) +} + +/// Quantize a full FP32 tensor to ternary representation. +/// +/// Processes the input tensor in blocks of `config.block_size`, applying +/// absmean quantization to each block independently. 
+/// +/// # Arguments +/// +/// * `weights` - FP32 weight tensor (flattened) +/// * `shape` - Tensor shape (rows, cols) +/// * `config` - Quantization configuration +/// +/// # Returns +/// +/// `TernaryTensor` with packed 2-bit data and per-block scales +/// +/// # Errors +/// +/// Returns an error if the weight dimensions are invalid. +/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::bitnet::{quantize_tensor, PtBitnetConfig}; +/// +/// let weights = vec![0.5; 512]; // 512 FP32 weights +/// let shape = (2, 256); +/// let config = PtBitnetConfig::default(); +/// +/// let ternary = quantize_tensor(&weights, shape, &config)?; +/// println!("Compressed to {} bytes", ternary.memory_bytes()); +/// ``` +pub fn quantize_tensor( + weights: &[f32], + shape: (usize, usize), + config: &PtBitnetConfig, +) -> Result { + let (rows, cols) = shape; + let total_elements = rows * cols; + + if weights.len() != total_elements { + return Err(RuvLLMError::Model(format!( + "Weight size mismatch: expected {} elements for shape {:?}, got {}", + total_elements, + shape, + weights.len() + ))); + } + + let block_size = config.block_size; + let num_blocks = (total_elements + block_size - 1) / block_size; + + let mut all_ternary = Vec::with_capacity(total_elements); + let mut scales = Vec::with_capacity(num_blocks); + + // Process each block + for block_idx in 0..num_blocks { + let start = block_idx * block_size; + let end = (start + block_size).min(total_elements); + let block = &weights[start..end]; + + let (ternary, scale) = absmean_ternary(block); + all_ternary.extend_from_slice(&ternary); + scales.push(scale); + } + + // Pack ternary values into 2-bit representation + let packed_data = pack_ternary(&all_ternary); + + Ok(TernaryTensor { + packed_data, + scales, + shape, + block_size, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_absmean_ternary_simple() { + // Simple block with known values + let block = vec![0.5, -0.5, 0.0, 1.0, -1.0, 0.25]; + let 
(ternary, scale) = absmean_ternary(&block); + + // All values should be in {-1, 0, +1} + assert!(ternary.iter().all(|&v| v >= -1 && v <= 1)); + + // Scale should be positive + assert!(scale > 0.0); + + // Check specific values + // gamma ≈ (0.5 + 0.5 + 0.0 + 1.0 + 1.0 + 0.25) / 6 ≈ 0.542 + // 0.5 / 0.542 ≈ 0.92 → round(0.92) = 1 + // -0.5 / 0.542 ≈ -0.92 → round(-0.92) = -1 + // 0.0 / 0.542 = 0 → round(0) = 0 + assert_eq!(ternary[0], 1); + assert_eq!(ternary[1], -1); + assert_eq!(ternary[2], 0); + } + + #[test] + fn test_absmean_ternary_all_zeros() { + let block = vec![0.0; 256]; + let (ternary, scale) = absmean_ternary(&block); + + // All should quantize to 0 + assert!(ternary.iter().all(|&v| v == 0)); + + // Scale should be epsilon (1e-8) + assert!(scale < 1e-7 && scale > 0.0); + } + + #[test] + fn test_absmean_ternary_large_values() { + let block = vec![10.0, -10.0, 5.0, -5.0]; + let (ternary, _scale) = absmean_ternary(&block); + + // All should saturate to ±1 + assert!(ternary[0] == 1 || ternary[0] == -1); + assert!(ternary[1] == 1 || ternary[1] == -1); + } + + #[test] + fn test_quantize_tensor_simple() { + let weights = vec![0.5; 512]; // 512 identical weights + let shape = (2, 256); + let config = PtBitnetConfig::default(); + + let ternary = quantize_tensor(&weights, shape, &config).unwrap(); + + assert_eq!(ternary.shape, shape); + assert_eq!(ternary.block_size, 256); + assert_eq!(ternary.num_blocks(), 2); // 512 / 256 = 2 blocks + assert_eq!(ternary.scales.len(), 2); + + // 512 elements packed in 2 bits each = 128 bytes + assert_eq!(ternary.packed_data.len(), 128); + } + + #[test] + fn test_quantize_tensor_size_mismatch() { + let weights = vec![0.5; 100]; // Wrong size + let shape = (2, 256); // Expects 512 + let config = PtBitnetConfig::default(); + + let result = quantize_tensor(&weights, shape, &config); + assert!(result.is_err()); + } + + #[test] + fn test_quantize_tensor_memory_savings() { + // Quantize a 1MB FP32 tensor (256K elements) + let weights = 
vec![0.5; 256 * 1024]; + let shape = (512, 512); + let config = PtBitnetConfig::default(); + + let ternary = quantize_tensor(&weights, shape, &config).unwrap(); + + let original_bytes = weights.len() * 4; // FP32 + let compressed_bytes = ternary.memory_bytes(); + + // Should be ~16x compression (32 bits → 2 bits + scale overhead) + let compression_ratio = original_bytes as f32 / compressed_bytes as f32; + assert!(compression_ratio > 10.0); // At least 10x compression + assert!(compression_ratio < 20.0); // Less than 20x (due to scales) + } + + #[test] + fn test_config_default() { + let config = PtBitnetConfig::default(); + assert_eq!(config.block_size, 256); + assert_eq!(config.calibration_samples, 1000); + assert!(config.optimize_scales); + assert_eq!(config.layers_to_quantize, LayerMask::ExpertsOnly); + } + + #[test] + fn test_layer_mask_variants() { + let experts = LayerMask::ExpertsOnly; + let all = LayerMask::All; + let custom = LayerMask::Custom(vec!["layer.0".to_string()]); + + assert_ne!(experts, all); + assert_ne!(all, custom); + assert_ne!(experts, custom); + } +} diff --git a/crates/ruvllm/src/bitnet/ternary_tensor.rs b/crates/ruvllm/src/bitnet/ternary_tensor.rs new file mode 100644 index 000000000..f4d39dc3c --- /dev/null +++ b/crates/ruvllm/src/bitnet/ternary_tensor.rs @@ -0,0 +1,276 @@ +//! Ternary Tensor Data Structure +//! +//! This module provides the `TernaryTensor` container for BitNet b1.58 ternary weights, +//! along with efficient 2-bit packing/unpacking functions. + +/// Ternary tensor with 2-bit packed representation. +/// +/// Stores ternary weights {-1, 0, +1} in a compact 2-bit format: +/// - 00 = -1 +/// - 01 = 0 +/// - 10 = +1 +/// - 11 = reserved (unused) +/// +/// Each block of `block_size` elements shares a single FP32 scale factor +/// derived from the absmean quantization process. 
+/// +/// # Memory Layout +/// +/// For a tensor with shape (m, n) and block_size B: +/// - `packed_data`: ceil(m*n / 4) bytes (4 ternary values per byte) +/// - `scales`: ceil(m*n / B) * 4 bytes (one FP32 scale per block) +/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::bitnet::TernaryTensor; +/// +/// let tensor = TernaryTensor { +/// packed_data: vec![0b10010100], // [+1, 0, +1, 0] +/// scales: vec![0.5], +/// shape: (2, 2), +/// block_size: 256, +/// }; +/// +/// println!("Sparsity: {:.2}%", tensor.sparsity() * 100.0); +/// println!("Memory: {} bytes", tensor.memory_bytes()); +/// ``` +#[derive(Debug, Clone)] +pub struct TernaryTensor { + /// Packed 2-bit ternary data (4 values per byte) + pub packed_data: Vec, + /// Per-block scale factors (FP32) + pub scales: Vec, + /// Tensor shape (rows, cols) + pub shape: (usize, usize), + /// Elements per quantization block + pub block_size: usize, +} + +impl TernaryTensor { + /// Calculate the fraction of zero weights (sparsity). + /// + /// Zero weights enable feature filtering and reduce computation + /// in ternary matrix multiplication. + /// + /// # Returns + /// + /// Fraction of weights that are exactly 0, in range [0.0, 1.0] + pub fn sparsity(&self) -> f32 { + let total_elements = self.shape.0 * self.shape.1; + let unpacked = unpack_ternary(&self.packed_data, total_elements); + + let zero_count = unpacked.iter().filter(|&&x| x == 0).count(); + zero_count as f32 / total_elements as f32 + } + + /// Calculate total memory footprint in bytes. + /// + /// Includes both packed ternary data and per-block scales. + /// + /// # Returns + /// + /// Total bytes: packed_data.len() + scales.len() * 4 + pub fn memory_bytes(&self) -> usize { + self.packed_data.len() + self.scales.len() * 4 + } + + /// Get the number of quantization blocks. 
+ pub fn num_blocks(&self) -> usize { + let total_elements = self.shape.0 * self.shape.1; + (total_elements + self.block_size - 1) / self.block_size + } +} + +/// Pack ternary values {-1, 0, +1} into 2-bit representation. +/// +/// Encoding: +/// - -1 → 00 +/// - 0 → 01 +/// - +1 → 10 +/// - (unused) → 11 +/// +/// Four values are packed into each byte in LSB-first order: +/// ```text +/// byte = [v3:v2:v1:v0] +/// ``` +/// +/// # Arguments +/// +/// * `values` - Slice of i8 values, must be in {-1, 0, +1} +/// +/// # Returns +/// +/// Vector of bytes, length = ceil(values.len() / 4) +/// +/// # Panics +/// +/// Panics if any value is not in {-1, 0, +1} +/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::bitnet::pack_ternary; +/// +/// let values = vec![-1, 0, 1, -1]; +/// let packed = pack_ternary(&values); +/// assert_eq!(packed.len(), 1); // 4 values in 1 byte +/// ``` +pub fn pack_ternary(values: &[i8]) -> Vec { + let num_bytes = (values.len() + 3) / 4; + let mut packed = vec![0u8; num_bytes]; + + for (i, &val) in values.iter().enumerate() { + let byte_idx = i / 4; + let bit_offset = (i % 4) * 2; + + let encoded: u8 = match val { + -1 => 0b00, + 0 => 0b01, + 1 => 0b10, + _ => panic!("Invalid ternary value: {} (must be -1, 0, or +1)", val), + }; + + packed[byte_idx] |= encoded << bit_offset; + } + + packed +} + +/// Unpack 2-bit ternary values to i8. 
+/// +/// Decoding: +/// - 00 → -1 +/// - 01 → 0 +/// - 10 → +1 +/// - 11 → 0 (reserved, treated as zero) +/// +/// # Arguments +/// +/// * `packed` - Packed 2-bit data +/// * `n` - Number of elements to unpack +/// +/// # Returns +/// +/// Vector of i8 values in {-1, 0, +1} +/// +/// # Example +/// +/// ```rust,ignore +/// use ruvllm::bitnet::{pack_ternary, unpack_ternary}; +/// +/// let original = vec![-1, 0, 1, -1]; +/// let packed = pack_ternary(&original); +/// let unpacked = unpack_ternary(&packed, 4); +/// assert_eq!(original, unpacked); +/// ``` +pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec { + let mut values = Vec::with_capacity(n); + + for i in 0..n { + let byte_idx = i / 4; + let bit_offset = (i % 4) * 2; + + if byte_idx >= packed.len() { + break; + } + + let encoded = (packed[byte_idx] >> bit_offset) & 0b11; + + let val = match encoded { + 0b00 => -1, + 0b01 => 0, + 0b10 => 1, + 0b11 => 0, // Reserved, treat as zero + _ => unreachable!(), + }; + + values.push(val); + } + + values +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pack_unpack_ternary() { + let values = vec![-1, 0, 1, -1, 1, 0, 0, 1]; + let packed = pack_ternary(&values); + let unpacked = unpack_ternary(&packed, values.len()); + assert_eq!(values, unpacked); + } + + #[test] + fn test_pack_ternary_single_byte() { + // 4 values fit in 1 byte + let values = vec![-1, 0, 1, -1]; + let packed = pack_ternary(&values); + assert_eq!(packed.len(), 1); + + // Manually verify encoding + // -1=00, 0=01, 1=10, -1=00 + // byte = [00:10:01:00] = 0b00_10_01_00 = 0x08 + assert_eq!(packed[0], 0b00_10_01_00); + } + + #[test] + fn test_pack_ternary_partial_byte() { + // 5 values need 2 bytes + let values = vec![-1, 0, 1, -1, 1]; + let packed = pack_ternary(&values); + assert_eq!(packed.len(), 2); + } + + #[test] + #[should_panic(expected = "Invalid ternary value")] + fn test_pack_invalid_value() { + let values = vec![-1, 0, 2]; // 2 is invalid + pack_ternary(&values); + } + + 
#[test] + fn test_ternary_tensor_sparsity() { + let values = vec![0, 1, 0, -1, 0, 0, 1, 0]; // 5 zeros out of 8 + let packed = pack_ternary(&values); + + let tensor = TernaryTensor { + packed_data: packed, + scales: vec![1.0], + shape: (2, 4), + block_size: 256, + }; + + let sparsity = tensor.sparsity(); + assert!((sparsity - 0.625).abs() < 0.001); // 5/8 = 0.625 + } + + #[test] + fn test_ternary_tensor_memory() { + let packed = vec![0u8; 64]; // 64 bytes of packed data + let scales = vec![0.5f32; 16]; // 16 scales * 4 bytes = 64 bytes + + let tensor = TernaryTensor { + packed_data: packed, + scales, + shape: (128, 256), + block_size: 256, + }; + + assert_eq!(tensor.memory_bytes(), 64 + 64); // 128 bytes total + } + + #[test] + fn test_ternary_tensor_num_blocks() { + let tensor = TernaryTensor { + packed_data: vec![], + scales: vec![], + shape: (256, 256), // 65536 elements + block_size: 256, // 256 elements per block + }; + + assert_eq!(tensor.num_blocks(), 256); // 65536 / 256 = 256 blocks + } +} diff --git a/crates/ruvllm/src/bitnet/tests.rs b/crates/ruvllm/src/bitnet/tests.rs new file mode 100644 index 000000000..99e2264a9 --- /dev/null +++ b/crates/ruvllm/src/bitnet/tests.rs @@ -0,0 +1,662 @@ +//! Comprehensive tests for PT-BitNet Phase 0 ternary quantization +//! +//! Test coverage based on ADR-017 (AD-1, AD-18): +//! - Ternary packing/unpacking roundtrips +//! - Absmean quantization correctness +//! - Dequantization accuracy +//! - Full tensor quantization +//! - Edge cases and error conditions + +use super::{ + dequantize_bitnet_t158, pack_ternary, quantize_tensor, unpack_ternary, PtBitnetConfig, + TernaryTensor, +}; + +// ============================================================================ +// Test Constants +// ============================================================================ + +const EPSILON: f32 = 1e-6; +const BLOCK_SIZE: usize = 256; + +// ============================================================================ +// 1. 
Ternary Packing Roundtrip Tests +// ============================================================================ + +#[test] +fn test_pack_unpack_simple_roundtrip() { + // Simple 4-element ternary array + let ternary = vec![1i8, 0, -1, 1]; + let packed = pack_ternary(&ternary); + let unpacked = unpack_ternary(&packed, 4); + + assert_eq!(ternary, unpacked, "Packing roundtrip failed for [1, 0, -1, 1]"); +} + +#[test] +fn test_pack_all_zeros() { + let ternary = vec![0i8; 256]; + let packed = pack_ternary(&ternary); + let unpacked = unpack_ternary(&packed, 256); + + assert_eq!(ternary, unpacked); + assert!(unpacked.iter().all(|&x| x == 0), "All zeros should remain all zeros"); +} + +#[test] +fn test_pack_all_ones() { + let ternary = vec![1i8; 256]; + let packed = pack_ternary(&ternary); + let unpacked = unpack_ternary(&packed, 256); + + assert_eq!(ternary, unpacked); + assert!(unpacked.iter().all(|&x| x == 1), "All +1 should remain all +1"); +} + +#[test] +fn test_pack_all_neg_ones() { + let ternary = vec![-1i8; 256]; + let packed = pack_ternary(&ternary); + let unpacked = unpack_ternary(&packed, 256); + + assert_eq!(ternary, unpacked); + assert!(unpacked.iter().all(|&x| x == -1), "All -1 should remain all -1"); +} + +#[test] +fn test_pack_one_block_256_elements() { + // One full block (256 elements) with alternating pattern + let mut ternary = Vec::with_capacity(256); + for i in 0..256 { + ternary.push(match i % 3 { + 0 => 1, + 1 => 0, + 2 => -1, + _ => unreachable!(), + }); + } + + let packed = pack_ternary(&ternary); + let unpacked = unpack_ternary(&packed, 256); + + assert_eq!(ternary, unpacked, "256-element block roundtrip failed"); + + // Verify storage size: 256 elements * 2 bits = 64 bytes + assert_eq!(packed.len(), 64, "Packed size should be 64 bytes for 256 elements"); +} + +#[test] +fn test_pack_non_aligned_size() { + // 100 elements (not divisible by 128, the typical packing boundary) + let mut ternary = Vec::with_capacity(100); + for i in 0..100 { + 
ternary.push(if i % 2 == 0 { 1 } else { -1 }); + } + + let packed = pack_ternary(&ternary); + let unpacked = unpack_ternary(&packed, 100); + + assert_eq!( + ternary.len(), + unpacked.len(), + "Unpacked length should match original" + ); + assert_eq!(ternary, unpacked, "Non-aligned size roundtrip failed"); +} + +#[test] +fn test_pack_large_tensor() { + // Multiple blocks (1024 elements = 4 blocks) + let ternary: Vec = (0..1024) + .map(|i| match i % 5 { + 0 | 1 => 1, + 2 | 3 => -1, + 4 => 0, + _ => unreachable!(), + }) + .collect(); + + let packed = pack_ternary(&ternary); + let unpacked = unpack_ternary(&packed, 1024); + + assert_eq!(ternary, unpacked, "Large tensor roundtrip failed"); +} + +// ============================================================================ +// 2. Absmean Quantization Correctness Tests +// ============================================================================ + +#[test] +fn test_quantize_uniform_random() { + // Uniform random weights in [-1, 1] should produce all ternary values + let weights = vec![0.5, -0.3, 0.1, -0.7, 0.9, -0.1, 0.0, 0.4]; + let ternary = quantize_absmean(&weights); + + // All outputs must be in {-1, 0, +1} + for &t in &ternary { + assert!( + t == -1 || t == 0 || t == 1, + "Quantized value {} not in ternary set", + t + ); + } +} + +#[test] +fn test_quantize_all_zeros() { + let weights = vec![0.0; 256]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // All ternary values should be zero + assert!( + ternary.iter().all(|&x| x == 0), + "All-zero input should produce all-zero ternary" + ); + + // Scale should be near epsilon (avoiding division by zero) + assert!( + scale < 1e-5, + "Scale for all-zero weights should be near epsilon, got {}", + scale + ); +} + +#[test] +fn test_quantize_large_positive() { + // Large positive weights should quantize to all +1 + let weights = vec![10.0; 256]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // All should be +1 + assert!( + 
ternary.iter().all(|&x| x == 1), + "Large positive weights should quantize to +1" + ); + + // Scale should be approximately 10.0 (mean absolute value) + assert!( + (scale - 10.0).abs() < 0.1, + "Scale should be ~10.0, got {}", + scale + ); +} + +#[test] +fn test_quantize_large_negative() { + // Large negative weights should quantize to all -1 + let weights = vec![-10.0; 256]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // All should be -1 + assert!( + ternary.iter().all(|&x| x == -1), + "Large negative weights should quantize to -1" + ); + + // Scale should be approximately 10.0 (mean absolute value) + assert!( + (scale - 10.0).abs() < 0.1, + "Scale should be ~10.0, got {}", + scale + ); +} + +#[test] +fn test_quantize_known_example() { + // From ADR: W_ternary = RoundClip(W / (mean(|W|) + epsilon), -1, 1) + // Example: weights = [0.5, -0.3, 0.1, -0.7] + // gamma = mean(|W|) = (0.5 + 0.3 + 0.1 + 0.7) / 4 = 0.4 + // normalized = [1.25, -0.75, 0.25, -1.75] + // ternary = [1, -1, 0, -1] (after clamp and round) + + let weights = vec![0.5, -0.3, 0.1, -0.7]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // Verify scale is approximately 0.4 + assert!( + (scale - 0.4).abs() < 0.01, + "Expected scale ~0.4, got {}", + scale + ); + + // Verify ternary values + // 1.25 -> 1, -0.75 -> -1, 0.25 -> 0, -1.75 -> -1 + assert_eq!(ternary[0], 1, "0.5/0.4 = 1.25 should round to 1"); + assert_eq!(ternary[1], -1, "-0.3/0.4 = -0.75 should round to -1"); + assert_eq!(ternary[2], 0, "0.1/0.4 = 0.25 should round to 0"); + assert_eq!(ternary[3], -1, "-0.7/0.4 = -1.75 should clamp to -1"); +} + +#[test] +fn test_quantize_scale_calculation() { + // Verify scale = mean(|weights|) + let weights = vec![1.0, -2.0, 3.0, -4.0]; + let (_, scale) = quantize_absmean_with_scale(&weights); + + let expected_scale = (1.0 + 2.0 + 3.0 + 4.0) / 4.0; // = 2.5 + assert!( + (scale - expected_scale).abs() < EPSILON, + "Scale should be mean of absolute values: expected 
{}, got {}", + expected_scale, + scale + ); +} + +// ============================================================================ +// 3. Dequantization Correctness Tests +// ============================================================================ + +#[test] +fn test_dequantize_simple() { + let ternary = vec![1i8, 0, -1]; + let scale = 2.0; + + let dequantized = dequantize_ternary(&ternary, scale); + + assert_eq!(dequantized.len(), 3); + assert!((dequantized[0] - 2.0).abs() < EPSILON, "1 * 2.0 = 2.0"); + assert!((dequantized[1] - 0.0).abs() < EPSILON, "0 * 2.0 = 0.0"); + assert!((dequantized[2] - (-2.0)).abs() < EPSILON, "-1 * 2.0 = -2.0"); +} + +#[test] +fn test_dequantize_packed_data() { + // Pack known ternary data, then dequantize + let ternary = vec![1i8, 0, -1, 1]; + let packed = pack_ternary(&ternary); + let scale = 3.5; + + let unpacked = unpack_ternary(&packed, 4); + let dequantized = dequantize_ternary(&unpacked, scale); + + assert_eq!(dequantized.len(), 4); + assert!((dequantized[0] - 3.5).abs() < EPSILON); + assert!((dequantized[1] - 0.0).abs() < EPSILON); + assert!((dequantized[2] - (-3.5)).abs() < EPSILON); + assert!((dequantized[3] - 3.5).abs() < EPSILON); +} + +#[test] +fn test_quantize_dequantize_roundtrip_mse() { + // Quantize -> Dequantize should have bounded MSE + let weights = vec![0.5, -0.3, 0.1, -0.7, 0.9, -0.1, 0.4, -0.5]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + let dequantized = dequantize_ternary(&ternary, scale); + + // Compute MSE + let mse: f32 = weights + .iter() + .zip(dequantized.iter()) + .map(|(&w, &d)| (w - d).powi(2)) + .sum::() + / weights.len() as f32; + + // MSE should be reasonable (ternary quantization is lossy) + // For absmean, expect MSE < 0.5 for normalized weights + assert!( + mse < 0.5, + "MSE too high: {} (weights may not reconstruct well)", + mse + ); +} + +#[test] +fn test_dequantize_full_block() { + // Dequantize a full 256-element block + let ternary: Vec = (0..256).map(|i| if i % 2 == 
0 { 1 } else { -1 }).collect(); + let scale = 1.5; + + let dequantized = dequantize_ternary(&ternary, scale); + + assert_eq!(dequantized.len(), 256); + for (i, &val) in dequantized.iter().enumerate() { + let expected = if i % 2 == 0 { 1.5 } else { -1.5 }; + assert!( + (val - expected).abs() < EPSILON, + "Element {} incorrect: expected {}, got {}", + i, + expected, + val + ); + } +} + +// ============================================================================ +// 4. Full Tensor Quantization Tests +// ============================================================================ + +#[test] +fn test_tensor_quantize_256x256() { + // 256x256 random tensor (65536 elements) + let mut weights = Vec::with_capacity(65536); + for i in 0..65536 { + let val = ((i as f32) * 0.001).sin(); // Pseudo-random in [-1, 1] + weights.push(val); + } + + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + + // Verify shape preserved + assert_eq!( + tensor.num_elements(), + 65536, + "Tensor should preserve element count" + ); + + // Verify sparsity is in valid range + let sparsity = tensor.sparsity(); + assert!( + sparsity >= 0.0 && sparsity <= 1.0, + "Sparsity {} out of range [0, 1]", + sparsity + ); + + // For uniform random, expect ~1/3 zeros (rough heuristic) + assert!( + sparsity > 0.15 && sparsity < 0.5, + "Sparsity {} seems unrealistic for uniform random input", + sparsity + ); +} + +#[test] +fn test_tensor_memory_bytes() { + let weights = vec![0.5; 256]; + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + + // Expected memory: + // - Packed data: 256 elements * 2 bits / 8 = 64 bytes + // - Scales: 1 block * 4 bytes (f32) = 4 bytes + // Total: 68 bytes + let expected_bytes = 64 + 4; + + assert_eq!( + tensor.memory_bytes(), + expected_bytes, + "Memory calculation incorrect" + ); +} + +#[test] +fn test_tensor_sparsity_calculation() { + // Known sparsity: 50% zeros + let weights: Vec = (0..256) + .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 }) + .collect(); + + let 
tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + let sparsity = tensor.sparsity(); + + // Should be close to 0.5 (half zeros) + assert!( + (sparsity - 0.5).abs() < 0.1, + "Expected sparsity ~0.5, got {}", + sparsity + ); +} + +#[test] +fn test_tensor_block_alignment() { + // 512 elements = 2 blocks of 256 + let weights = vec![1.0; 512]; + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + + // Should have 2 scale factors (one per block) + assert_eq!( + tensor.num_blocks(), + 2, + "Expected 2 blocks for 512 elements" + ); +} + +#[test] +fn test_tensor_non_aligned_padding() { + // 300 elements (256 + 44) should create 2 blocks with padding + let weights = vec![0.5; 300]; + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + + // Should pad to 2 full blocks (512 elements) + let num_blocks = (300 + BLOCK_SIZE - 1) / BLOCK_SIZE; + assert_eq!( + tensor.num_blocks(), + num_blocks, + "Non-aligned tensor should pad to full blocks" + ); + + // Original element count should be preserved + assert_eq!(tensor.num_elements(), 300); +} + +// ============================================================================ +// 5. 
TernaryTensor Properties Tests +// ============================================================================ + +#[test] +fn test_ternary_tensor_properties() { + let weights: Vec = (0..512).map(|i| (i as f32) * 0.01).collect(); + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + + // Memory bytes should match calculation + let num_blocks = (512 + BLOCK_SIZE - 1) / BLOCK_SIZE; + let packed_bytes = num_blocks * BLOCK_SIZE * 2 / 8; // 2 bits per element + let scale_bytes = num_blocks * 4; // f32 scales + let expected = packed_bytes + scale_bytes; + + assert_eq!(tensor.memory_bytes(), expected); + + // Sparsity should be in valid range + assert!(tensor.sparsity() >= 0.0 && tensor.sparsity() <= 1.0); +} + +#[test] +fn test_ternary_tensor_uniform_random_sparsity() { + // Uniform random should have ~1/3 sparsity + let mut weights = Vec::with_capacity(2048); + for i in 0..2048 { + weights.push(((i as f32) * 1.234).sin()); + } + + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + let sparsity = tensor.sparsity(); + + // Rough heuristic: 20-45% zeros for uniform random + assert!( + sparsity > 0.2 && sparsity < 0.45, + "Uniform random sparsity {} outside expected range [0.2, 0.45]", + sparsity + ); +} + +// ============================================================================ +// 6. 
Config Validation Tests +// ============================================================================ + +#[test] +fn test_config_default_values() { + let config = PtBitnetConfig::default(); + + assert_eq!(config.block_size, 256, "Default block size should be 256"); + assert!( + config.calibration_samples > 0, + "Calibration samples must be > 0" + ); +} + +#[test] +#[should_panic(expected = "block_size must be > 0")] +fn test_config_invalid_block_size() { + let _config = PtBitnetConfig { + block_size: 0, + ..Default::default() + }; +} + +#[test] +#[should_panic(expected = "calibration_samples must be > 0")] +fn test_config_invalid_calibration_samples() { + let _config = PtBitnetConfig { + calibration_samples: 0, + ..Default::default() + }; +} + +// ============================================================================ +// 7. Edge Case Tests +// ============================================================================ + +#[test] +fn test_empty_input() { + let weights: Vec = vec![]; + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + + assert_eq!(tensor.num_elements(), 0); + assert_eq!(tensor.num_blocks(), 0); + assert_eq!(tensor.sparsity(), 0.0); +} + +#[test] +fn test_single_element() { + let weights = vec![0.5]; + let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE); + + assert_eq!(tensor.num_elements(), 1); + // Should create 1 block (padded) + assert_eq!(tensor.num_blocks(), 1); +} + +#[test] +fn test_very_large_values() { + let weights = vec![f32::MAX, f32::MAX, f32::MAX, f32::MAX]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // Should all quantize to +1 + assert!(ternary.iter().all(|&x| x == 1), "f32::MAX should quantize to +1"); + + // Scale should be approximately f32::MAX + assert!(scale > 1e30, "Scale should be very large"); + + // Dequantization should not produce NaN + let dequantized = dequantize_ternary(&ternary, scale); + assert!( + dequantized.iter().all(|&x| !x.is_nan()), + "Dequantization should 
not produce NaN" + ); +} + +#[test] +fn test_subnormal_floats() { + // Very small positive values (subnormal range) + let weights = vec![1e-40, -1e-40, 1e-39, -1e-39]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // Should quantize reasonably (may be all zeros or small values) + assert!(ternary.iter().all(|&x| x >= -1 && x <= 1)); + + // Scale should be tiny but not zero + assert!(scale > 0.0, "Scale should be > 0 even for subnormal inputs"); +} + +#[test] +fn test_nan_handling() { + // NaN should not crash, but behavior is implementation-defined + let weights = vec![f32::NAN, 1.0, -1.0, 0.0]; + let result = std::panic::catch_unwind(|| { + quantize_absmean_with_scale(&weights) + }); + + // Should either panic or handle gracefully + // At minimum, should not produce infinite loop or segfault + if let Ok((ternary, scale)) = result { + // If it succeeds, output should not contain NaN + assert!( + !scale.is_nan() || scale == 0.0, + "Scale should not be NaN unless handled explicitly" + ); + assert!( + ternary.iter().all(|&x| x >= -1 && x <= 1), + "Ternary values must be in valid range" + ); + } +} + +#[test] +fn test_infinity_handling() { + let weights = vec![f32::INFINITY, f32::NEG_INFINITY, 1.0, -1.0]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // Infinities should quantize to ±1 + assert_eq!(ternary[0], 1, "INFINITY should quantize to +1"); + assert_eq!(ternary[1], -1, "NEG_INFINITY should quantize to -1"); + + // Scale should be finite (or handled gracefully) + // Implementation may cap scale to avoid overflow + assert!( + scale.is_finite() || scale > 1e30, + "Scale should be finite or very large" + ); +} + +#[test] +fn test_mixed_magnitudes() { + // Mix of very large and very small values + let weights = vec![1000.0, 0.001, -1000.0, -0.001, 0.0]; + let (ternary, scale) = quantize_absmean_with_scale(&weights); + + // Should produce valid ternary values + assert!(ternary.iter().all(|&x| x >= -1 && x <= 1)); + + // Scale 
should be dominated by large values + assert!(scale > 100.0, "Scale should reflect large values"); + + // Small values should quantize to 0 + assert_eq!( + ternary[1], 0, + "0.001 compared to scale ~500 should be 0" + ); + assert_eq!(ternary[3], 0, "-0.001 should be 0"); +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/// Helper to quantize weights using absmean method +/// Returns both ternary values and scale factor +fn quantize_absmean_with_scale(weights: &[f32]) -> (Vec, f32) { + if weights.is_empty() { + return (vec![], 0.0); + } + + // Compute absmean scale: gamma = mean(|W|) + epsilon + let absmean: f32 = weights.iter().map(|&w| w.abs()).sum::() / weights.len() as f32; + let scale = absmean + EPSILON; + + // Quantize: W_ternary = RoundClip(W / scale, -1, 1) + let ternary: Vec = weights + .iter() + .map(|&w| { + let normalized = w / scale; + // Round and clip to {-1, 0, +1} + if normalized >= 0.5 { + 1 + } else if normalized <= -0.5 { + -1 + } else { + 0 + } + }) + .collect(); + + (ternary, scale) +} + +/// Helper to quantize weights (scale not needed) +fn quantize_absmean(weights: &[f32]) -> Vec { + let (ternary, _scale) = quantize_absmean_with_scale(weights); + ternary +} + +/// Helper to dequantize ternary values +fn dequantize_ternary(ternary: &[i8], scale: f32) -> Vec { + ternary.iter().map(|&t| (t as f32) * scale).collect() +} diff --git a/crates/ruvllm/src/gguf/quantization.rs b/crates/ruvllm/src/gguf/quantization.rs index ef15a3b31..de221ced8 100644 --- a/crates/ruvllm/src/gguf/quantization.rs +++ b/crates/ruvllm/src/gguf/quantization.rs @@ -29,6 +29,7 @@ //! 
| IQ4_NL | 4.5 | 32 | i-quant 4-bit non-linear | use crate::error::{Result, RuvLLMError}; +use crate::bitnet::dequantize_bitnet_t158; // ============================================================================ // Quantization Types @@ -100,6 +101,8 @@ pub enum GgufQuantType { F64 = 28, /// BF16 brain float Bf16 = 29, + /// BitNet b1.58 ternary quantization (2-bit packed) + BitnetT158 = 30, } impl TryFrom for GgufQuantType { @@ -137,6 +140,7 @@ impl TryFrom for GgufQuantType { 27 => Ok(Self::I64), 28 => Ok(Self::F64), 29 => Ok(Self::Bf16), + 30 => Ok(Self::BitnetT158), _ => Err(RuvLLMError::Model(format!( "Unknown GGUF quantization type: {}", value @@ -163,6 +167,7 @@ impl GgufQuantType { Self::IQ1_S => 256, Self::IQ4_NL => 32, Self::IQ4_XS => 256, + Self::BitnetT158 => 256, } } @@ -214,6 +219,8 @@ impl GgufQuantType { Self::IQ1_S => 50, Self::IQ4_NL => 18, Self::IQ4_XS => 136, + // BitNet b1.58: 256 elements -> 64 bytes (2-bit packed) + 2 bytes (FP16 scale) = 66 bytes + Self::BitnetT158 => 66, } } @@ -280,6 +287,7 @@ impl GgufQuantType { Self::IQ1_S => "IQ1_S", Self::IQ4_NL => "IQ4_NL", Self::IQ4_XS => "IQ4_XS", + Self::BitnetT158 => "BITNET_T158", } } } @@ -355,6 +363,14 @@ pub fn dequantize_tensor( GgufQuantType::Q5_K => dequantize_q5_k(data, &mut output), GgufQuantType::Q6_K => dequantize_q6_k(data, &mut output), GgufQuantType::IQ4_NL => dequantize_iq4_nl(data, &mut output), + GgufQuantType::BitnetT158 => dequantize_bitnet_t158_wrapper(data, &mut output), + GgufQuantType::IQ1_S => { + return Err(RuvLLMError::Model( + "IQ1_S dequantization requires codebook lookup tables (not yet implemented). \ + For BitNet ternary quantization, use BITNET_T158 type instead." 
+ .to_string(), + )); + } _ => { return Err(RuvLLMError::Model(format!( "Dequantization not implemented for {:?}", diff --git a/crates/ruvllm/src/lib.rs b/crates/ruvllm/src/lib.rs index a66826c49..f6367f34e 100644 --- a/crates/ruvllm/src/lib.rs +++ b/crates/ruvllm/src/lib.rs @@ -44,6 +44,7 @@ pub mod adapter_manager; pub mod autodetect; pub mod backends; +pub mod bitnet; pub mod capabilities; pub mod claude_flow; pub mod context; diff --git a/docs/architecture/bitnet-quantizer-module-design.md b/docs/architecture/bitnet-quantizer-module-design.md new file mode 100644 index 000000000..7cbd5deb5 --- /dev/null +++ b/docs/architecture/bitnet-quantizer-module-design.md @@ -0,0 +1,999 @@ +# PT-BitNet Quantizer Module Architecture Design + +**Version:** 1.0 +**Date:** 2026-02-03 +**Status:** Design Specification +**Relates to:** ADR-017 (AD-1, AD-5, AD-18, AD-19), DDD Section 3.4/4.2/4.3 + +--- + +## Executive Summary + +This document specifies the architecture for the **PT-BitNet post-training quantizer** module that converts FP16/BF16 GLM-4.7-Flash weights to BitNet b1.58 ternary {-1, 0, +1} format via absmean quantization. This is a **design-only specification** — implementation follows in Phase 0. + +**Design Scope:** +- Module layout and file organization +- Complete struct definitions with field types +- Full function signatures (no implementations) +- GGUF integration points and format extensions +- Error handling strategy +- Testing approach + +**Out of Scope:** +- Actual implementation code +- Performance benchmarks +- Calibration dataset selection + +--- + +## A. 
Module Layout + +### Directory Structure + +``` +crates/ruvllm/src/ +├── bitnet/ # NEW module +│ ├── mod.rs # Module exports and public API +│ ├── quantizer.rs # PtBitnetQuantizer + absmean algorithm +│ ├── ternary_tensor.rs # TernaryTensor value object +│ ├── dequantize.rs # BITNET_T158 dequantization kernel +│ └── config.rs # PtBitnetConfig configuration +│ +├── gguf/ +│ ├── mod.rs # Add pub mod bitnet export +│ ├── quantization.rs # MODIFIED: Add BITNET_T158 enum variant +│ ├── parser.rs # Unchanged (reused as-is) +│ └── ... +│ +└── kernels/ + └── matmul.rs # Reference for dispatch patterns +``` + +### Modified Files + +#### `src/gguf/quantization.rs` + +**Changes:** +1. Add `BITNET_T158 = 30` variant to `GgufQuantType` enum (after `Bf16 = 29`) +2. Update `try_from()` impl to handle type 30 +3. Update `block_size()` to return 256 for `BITNET_T158` +4. Update `type_size()` to return 66 for `BITNET_T158` (64 bytes packed + 2 bytes FP16 scale) +5. Update `is_quantized()` to include `BITNET_T158` +6. Update `bits_per_weight()` to return 2.06 for `BITNET_T158` +7. Add new match arm in `dequantize_tensor()` → `BITNET_T158 => dequantize_bitnet_t158(data, output)` + +**Exact enum addition:** +```rust +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(u32)] +pub enum GgufQuantType { + // ... existing variants 0-29 ... + /// BitNet b1.58 ternary quantization (2-bit packed + FP16 scale per 256-element block) + BITNET_T158 = 30, +} +``` + +--- + +## B. Struct Definitions + +### 1. 
`PtBitnetConfig` (in `bitnet/config.rs`) + +**Purpose:** Configuration for PT-BitNet quantization process + +```rust +/// Configuration for PT-BitNet post-training quantization +#[derive(Debug, Clone)] +pub struct PtBitnetConfig { + /// Block size for absmean scale computation (default: 256) + pub block_size: usize, + + /// Epsilon for numerical stability in scale computation (default: 1e-8) + pub epsilon: f32, + + /// Whether to run calibration pass to optimize scale factors + pub use_calibration: bool, + + /// Number of calibration samples (if use_calibration = true) + pub calibration_samples: usize, + + /// Maximum sequence length for calibration (default: 2048) + pub calibration_max_seq_len: usize, + + /// Device for calibration pass ("cpu", "metal", "cuda:0") + pub calibration_device: String, + + /// Clipping threshold for normalized weights before rounding + /// (default: 1.0, range typically 0.95-1.05) + pub clip_threshold: f32, + + /// Sparsity target: if > 0.0, bias rounding toward zero to achieve target sparsity + pub target_sparsity: Option<f32>, +} + +impl Default for PtBitnetConfig { + fn default() -> Self { + Self { + block_size: 256, + epsilon: 1e-8, + use_calibration: false, + calibration_samples: 1000, + calibration_max_seq_len: 2048, + calibration_device: "metal".to_string(), + clip_threshold: 1.0, + target_sparsity: None, + } + } +} +``` + +### 2.
`TernaryTensor` (in `bitnet/ternary_tensor.rs`) + +**Purpose:** Immutable value object for packed ternary weights + +```rust +/// Packed ternary tensor with per-block FP16 scales +#[derive(Debug, Clone)] +pub struct TernaryTensor { + /// Packed 2-bit ternary values (4 weights per byte) + /// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved + pub packed_data: Vec<u8>, + + /// Per-block FP16 scale factors (absmean values) + pub scales: Vec<f16>, + + /// Tensor shape [out_features, in_features] or [rows, cols] + pub shape: [usize; 2], + + /// Block size (always 256 for BitNet b1.58) + pub block_size: usize, + + /// Total number of weights + pub num_elements: usize, + + /// Number of blocks + pub num_blocks: usize, + + /// Measured sparsity (fraction of zero weights) + pub sparsity: f32, +} + +impl TernaryTensor { + /// Calculate total storage size in bytes + pub fn storage_size(&self) -> usize; + + /// Get expected packed_data size for validation + pub fn expected_packed_size(&self) -> usize; + + /// Validate internal consistency + pub fn validate(&self) -> Result<()>; +} +``` + +### 3. `TernaryBlock` (in `bitnet/ternary_tensor.rs`) + +**Purpose:** Single block of 256 ternary weights with scale + +```rust +/// A single 256-element block with ternary weights and FP16 scale +#[derive(Debug, Clone)] +pub struct TernaryBlock { + /// 64 bytes of packed 2-bit values (256 weights × 2 bits ÷ 8 bits/byte) + pub packed: [u8; 64], + + /// FP16 absmean scale factor + pub scale: f16, +} + +impl TernaryBlock { + /// Size in bytes when stored in GGUF (64 + 2 = 66) + pub const STORAGE_SIZE: usize = 66; + + /// Number of elements in a block + pub const BLOCK_SIZE: usize = 256; +} +``` + +### 4.
`AbsmeanResult` (in `bitnet/quantizer.rs`) + +**Purpose:** Result of absmean quantization on a single block + +```rust +/// Result of absmean ternary quantization on a block +#[derive(Debug, Clone)] +pub struct AbsmeanResult { + /// Ternary values {-1, 0, +1} for each weight in the block + pub ternary_weights: Vec<i8>, + + /// Computed absmean scale factor (gamma = mean(|W|)) + pub scale: f32, + + /// Measured sparsity (fraction of zeros) + pub sparsity: f32, + + /// Mean squared error vs original FP16 values (for calibration) + pub mse: f32, +} +``` + +### 5. `QuantizationStats` (in `bitnet/quantizer.rs`) + +**Purpose:** Statistics collected during quantization + +```rust +/// Statistics from quantizing a single tensor +#[derive(Debug, Clone)] +pub struct QuantizationStats { + /// Tensor name + pub name: String, + + /// Mean of all block scales + pub mean_scale: f32, + + /// Std dev of block scales + pub std_scale: f32, + + /// Overall sparsity across all blocks + pub sparsity: f32, + + /// Mean MSE across all blocks + pub mean_mse: f32, + + /// Number of blocks + pub num_blocks: usize, +} +``` + +--- + +## C. Function Signatures + +### Core Quantization Functions (in `bitnet/quantizer.rs`) + +#### 1.
Primary Quantization Entry Point + +```rust +/// Quantize an FP16/F32 tensor to ternary format using absmean quantization +/// +/// # Arguments +/// * `tensor` - Input FP16 or F32 tensor data (flat vector) +/// * `shape` - Tensor shape [out_features, in_features] +/// * `config` - Quantization configuration +/// +/// # Returns +/// * `TernaryTensor` - Packed ternary representation +/// * `QuantizationStats` - Statistics about the quantization process +/// +/// # Errors +/// * `RuvLLMError::Quantization` if tensor size is not divisible by block_size +/// * `RuvLLMError::Quantization` if shape product doesn't match tensor length +pub fn quantize_tensor( + tensor: &[f32], + shape: [usize; 2], + config: &PtBitnetConfig, +) -> Result<(TernaryTensor, QuantizationStats)>; +``` + +#### 2. Per-Block Quantization + +```rust +/// Apply absmean quantization to a single block of weights +/// +/// Algorithm: +/// 1. gamma = mean(|block|) + epsilon +/// 2. normalized = block / gamma +/// 3. ternary = round(clamp(normalized, -clip_threshold, +clip_threshold)) +/// 4. Map to {-1, 0, +1} +/// +/// # Arguments +/// * `block` - Block of FP16/F32 values (length = config.block_size) +/// * `config` - Configuration with epsilon and clip_threshold +/// +/// # Returns +/// * `AbsmeanResult` with ternary values, scale, sparsity, MSE +/// +/// # Panics +/// * If block.len() != config.block_size +pub fn absmean_ternary( + block: &[f32], + config: &PtBitnetConfig, +) -> AbsmeanResult; +``` + +#### 3. 
Packing Functions + +```rust +/// Pack ternary {-1, 0, +1} values into 2-bit representation +/// +/// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved (unused) +/// 4 values packed per byte: [v3 v2 v1 v0] → byte +/// +/// # Arguments +/// * `values` - Ternary values (must be {-1, 0, +1} only) +/// +/// # Returns +/// * Packed bytes (length = ceil(values.len() / 4)) +/// +/// # Errors +/// * If any value is not in {-1, 0, +1} +pub fn pack_ternary(values: &[i8]) -> Result<Vec<u8>>; + +/// Unpack 2-bit representation to ternary {-1, 0, +1} values +/// +/// # Arguments +/// * `packed` - Packed 2-bit data +/// * `n` - Number of values to extract +/// +/// # Returns +/// * Vector of ternary values (length = n) +pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec<i8>; +``` + +#### 4. Calibration (Optional) + +```rust +/// Run calibration pass to optimize scale factors +/// +/// # Arguments +/// * `tensor` - Input FP16 tensor +/// * `shape` - Tensor shape +/// * `config` - Config with calibration settings +/// * `calibration_data` - Sample activations for this layer +/// +/// # Returns +/// * Optimized `TernaryTensor` with calibrated scales +/// +/// # Note +/// This is optional - if not used, falls back to plain absmean +pub fn quantize_with_calibration( + tensor: &[f32], + shape: [usize; 2], + config: &PtBitnetConfig, + calibration_data: &[Vec<f32>], +) -> Result<(TernaryTensor, QuantizationStats)>; +``` + +### Dequantization Functions (in `bitnet/dequantize.rs`) + +```rust +/// Dequantize BITNET_T158 tensor to FP32 +/// +/// # Arguments +/// * `data` - Raw GGUF tensor bytes (packed ternary + scales) +/// * `scales` - Per-block FP16 scales (extracted from data) +/// * `n` - Total number of elements to dequantize +/// +/// # Returns +/// * Vec<f32> of dequantized values +/// +/// # Format +/// Each block: [64 bytes packed ternary][2 bytes FP16 scale] +pub fn dequantize_bitnet_t158( + data: &[u8], + scales: &[f16], + n: usize, +) -> Vec<f32>; + +/// Dequantize a single BITNET_T158 block +/// +///
# Arguments +/// * `block_data` - 64 bytes of packed ternary data +/// * `scale` - FP16 scale factor +/// * `output` - Output buffer (must have capacity for 256 elements) +pub fn dequantize_bitnet_t158_block( + block_data: &[u8; 64], + scale: f16, + output: &mut [f32], +); +``` + +### Tensor Conversion (in `bitnet/ternary_tensor.rs`) + +```rust +impl TernaryTensor { + /// Convert from packed storage to FP32 (for validation/testing) + pub fn to_fp32(&self) -> Vec<f32>; + + /// Create from existing GGUF tensor data + pub fn from_gguf_data( + data: &[u8], + shape: [usize; 2], + block_size: usize, + ) -> Result<Self>; + + /// Serialize to GGUF tensor bytes + pub fn to_gguf_data(&self) -> Vec<u8>; +} +``` + +--- + +## D. GGUF Integration Points + +### 1. New Quantization Type Variant + +**File:** `crates/ruvllm/src/gguf/quantization.rs` + +**Changes to `GgufQuantType` enum:** + +```rust +#[repr(u32)] +pub enum GgufQuantType { + // ... existing 0-29 ... + + /// BitNet b1.58 ternary quantization + /// Block size: 256 elements + /// Storage: 64 bytes packed (2-bit) + 2 bytes FP16 scale = 66 bytes/block + /// Bits per weight: 2.06 bpw + BITNET_T158 = 30, +} + +impl GgufQuantType { + pub fn block_size(&self) -> usize { + match self { + // ... existing cases ... + Self::BITNET_T158 => 256, + } + } + + pub fn type_size(&self) -> usize { + match self { + // ... existing cases ... + Self::BITNET_T158 => 66, // 64 + 2 + } + } + + pub fn name(&self) -> &'static str { + match self { + // ... existing cases ... + Self::BITNET_T158 => "BITNET_T158", + } + } +} + +impl TryFrom<u32> for GgufQuantType { + fn try_from(value: u32) -> Result<Self> { + match value { + // ... existing 0-29 ... + 30 => Ok(Self::BITNET_T158), + _ => Err(/* ... */), + } + } +} +``` + +### 2.
Dequantization Dispatch + +**File:** `crates/ruvllm/src/gguf/quantization.rs` + +**Modification to `dequantize_tensor()` function:** + +```rust +pub fn dequantize_tensor( + data: &[u8], + dtype: GgufQuantType, + num_elements: usize, +) -> Result<Vec<f32>> { + let mut output = vec![0.0f32; num_elements]; + + match dtype { + // ... existing cases ... + GgufQuantType::BITNET_T158 => { + // Extract scales and packed data + let num_blocks = (num_elements + 255) / 256; + let mut scales = Vec::with_capacity(num_blocks); + + for i in 0..num_blocks { + let block_offset = i * 66; + let scale_offset = block_offset + 64; + let scale_bytes = [data[scale_offset], data[scale_offset + 1]]; + scales.push(f16::from_le_bytes(scale_bytes)); + } + + output = crate::bitnet::dequantize::dequantize_bitnet_t158( + data, + &scales, + num_elements, + ); + } + _ => { + return Err(RuvLLMError::Model(format!( + "Dequantization not implemented for {:?}", + dtype + ))); + } + } + + Ok(output) +} +``` + +### 3. GGUF Metadata Keys + +**New metadata keys for BitNet models** (written during quantization, read during load): + +```rust +// In quantizer when exporting GGUF +pub const BITNET_METADATA_KEYS: &[(&str, &str)] = &[ + ("craftsman.bitnet.version", "1"), + ("craftsman.bitnet.weight_encoding", "absmean_ternary"), + ("craftsman.bitnet.activation_bits", "8"), + ("craftsman.bitnet.block_size", "256"), + ("craftsman.bitnet.kernel_hint", "tl1"), // or "tl2", "i2s" +]; +``` + +**Metadata reading in model loader:** + +```rust +// In backend when loading model +fn detect_bitnet_model(metadata: &HashMap) -> bool { + metadata.get("craftsman.bitnet.version") + .and_then(|v| v.as_str()) + .map(|v| v == "1") + .unwrap_or(false) +} +``` + +### 4. Tensor Info Extension + +**No changes needed** - existing `TensorInfo` struct in `parser.rs` already supports: +- `name: String` +- `shape: Vec` +- `dtype: GgufQuantType` ← Will now include `BITNET_T158` +- `offset: u64` + +--- + +## E.
Error Handling Strategy + +### Error Types + +All errors use existing `RuvLLMError` enum from `crates/ruvllm/src/error.rs`: + +```rust +pub enum RuvLLMError { + // Existing variants... + + // Quantization-specific errors + Quantization(String), // Use this variant for all quantization errors + Model(String), // For GGUF format issues + Config(String), // For invalid configuration +} +``` + +### Error Scenarios and Handling + +| Scenario | Error Type | Recovery Strategy | +|----------|-----------|-------------------| +| Tensor size not divisible by block_size | `Quantization` | Pad last block with zeros | +| Invalid ternary value during packing | `Quantization` | Fail-fast - indicates bug | +| GGUF file has wrong BITNET_T158 block size | `Model` | Fail-fast - corrupted file | +| Calibration device unavailable | `Config` | Fall back to non-calibrated quantization | +| Out of memory during quantization | System panic | Let Rust OOM handler catch | +| Shape mismatch in tensor | `Quantization` | Fail-fast - validate before processing | +| FP16 scale is NaN/Inf | `Quantization` | Clamp to epsilon value | +| Empty tensor / zero elements | `Quantization` | Skip with warning | + +### Validation Functions + +```rust +/// Validate quantization config +pub fn validate_config(config: &PtBitnetConfig) -> Result<()> { + if config.block_size == 0 || config.block_size % 4 != 0 { + return Err(RuvLLMError::Config( + "block_size must be non-zero and divisible by 4".into() + )); + } + + if config.epsilon <= 0.0 { + return Err(RuvLLMError::Config( + "epsilon must be positive".into() + )); + } + + if config.clip_threshold <= 0.0 || config.clip_threshold > 2.0 { + return Err(RuvLLMError::Config( + "clip_threshold must be in range (0.0, 2.0]".into() + )); + } + + Ok(()) +} + +/// Validate tensor shape and size +pub fn validate_tensor( + tensor: &[f32], + shape: [usize; 2], + block_size: usize, +) -> Result<()> { + let expected_size = shape[0] * shape[1]; + + if tensor.len() != 
expected_size { + return Err(RuvLLMError::Quantization(format!( + "Tensor length {} doesn't match shape {:?} (expected {})", + tensor.len(), shape, expected_size + ))); + } + + if expected_size % block_size != 0 { + // Could pad, but for simplicity require exact multiple + return Err(RuvLLMError::Quantization(format!( + "Tensor size {} is not divisible by block_size {}", + expected_size, block_size + ))); + } + + Ok(()) +} +``` + +--- + +## F. Testing Strategy + +### Unit Tests + +#### 1. Absmean Quantization Correctness + +**File:** `crates/ruvllm/src/bitnet/tests/quantizer_tests.rs` + +```rust +#[test] +fn test_absmean_ternary_basic() { + // Test that absmean correctly quantizes known values + let config = PtBitnetConfig::default(); + + // Block with known mean(|x|) = 1.0 + let block: Vec<f32> = vec![ + 2.0, -2.0, 1.0, -1.0, // gamma = mean(2,2,1,1,...) ≈ 1.0 + 0.5, -0.5, 0.0, 0.0, + // ... (pad to 256 elements) + ]; + + let result = absmean_ternary(&block, &config); + + // After normalization: 2.0/1.0 = 2.0 → clamp to 1.0 → round to +1 + assert_eq!(result.ternary_weights[0], 1); // 2.0 → +1 + assert_eq!(result.ternary_weights[1], -1); // -2.0 → -1 + assert_eq!(result.ternary_weights[2], 1); // 1.0 → +1 + assert_eq!(result.ternary_weights[6], 0); // 0.0 → 0 + + assert!(result.scale > 0.9 && result.scale < 1.1); // gamma ≈ 1.0 +} + +#[test] +fn test_absmean_all_zeros() { + let config = PtBitnetConfig::default(); + let block = vec![0.0; 256]; + + let result = absmean_ternary(&block, &config); + + // All zeros → scale = epsilon, all ternary = 0 + assert_eq!(result.scale, config.epsilon); + assert!(result.ternary_weights.iter().all(|&x| x == 0)); + assert_eq!(result.sparsity, 1.0); +} +``` + +#### 2. 
Pack/Unpack Round-Trip + +```rust +#[test] +fn test_pack_unpack_roundtrip() { + let original = vec![1i8, -1, 0, 1, 0, -1, 1, 0]; + + let packed = pack_ternary(&original).unwrap(); + assert_eq!(packed.len(), 2); // 8 values → 2 bytes + + let unpacked = unpack_ternary(&packed, 8); + assert_eq!(unpacked, original); +} + +#[test] +fn test_pack_invalid_value() { + let invalid = vec![1i8, 2, 0]; // 2 is not ternary + + let result = pack_ternary(&invalid); + assert!(result.is_err()); +} +``` + +#### 3. Tensor Validation + +```rust +#[test] +fn test_validate_tensor_shape_mismatch() { + let tensor = vec![1.0; 100]; + let shape = [10, 11]; // 10*11 = 110 ≠ 100 + + let result = validate_tensor(&tensor, shape, 256); + assert!(result.is_err()); +} + +#[test] +fn test_validate_tensor_block_alignment() { + let tensor = vec![1.0; 257]; // Not divisible by 256 + let shape = [1, 257]; + + let result = validate_tensor(&tensor, shape, 256); + assert!(result.is_err()); +} +``` + +### Integration Tests + +#### 4. Full Quantization Pipeline + +```rust +#[test] +fn test_quantize_tensor_full_pipeline() { + let config = PtBitnetConfig::default(); + + // Create a 512-element tensor (2 blocks) + let tensor: Vec<f32> = (0..512).map(|i| (i as f32) / 512.0).collect(); + let shape = [2, 256]; + + let (ternary, stats) = quantize_tensor(&tensor, shape, &config).unwrap(); + + assert_eq!(ternary.num_blocks, 2); + assert_eq!(ternary.packed_data.len(), 2 * 64); // 2 blocks × 64 bytes + assert_eq!(ternary.scales.len(), 2); + assert_eq!(stats.num_blocks, 2); + + // Verify reconstruction quality + let reconstructed = ternary.to_fp32(); + assert_eq!(reconstructed.len(), 512); +} +``` + +#### 5. 
GGUF Round-Trip + +```rust +#[test] +fn test_gguf_serialization_roundtrip() { + let config = PtBitnetConfig::default(); + let tensor = vec![1.0; 256]; + let shape = [1, 256]; + + let (ternary, _) = quantize_tensor(&tensor, shape, &config).unwrap(); + + // Serialize to GGUF format + let gguf_data = ternary.to_gguf_data(); + assert_eq!(gguf_data.len(), 66); // 1 block = 66 bytes + + // Deserialize + let recovered = TernaryTensor::from_gguf_data(&gguf_data, shape, 256).unwrap(); + + assert_eq!(recovered.packed_data, ternary.packed_data); + assert_eq!(recovered.scales, ternary.scales); +} +``` + +### Benchmark Tests + +#### 6. Performance Regression + +```rust +#[bench] +fn bench_absmean_ternary_256(b: &mut Bencher) { + let config = PtBitnetConfig::default(); + let block: Vec<f32> = (0..256).map(|i| (i as f32) / 256.0).collect(); + + b.iter(|| { + let _ = absmean_ternary(&block, &config); + }); +} + +#[bench] +fn bench_pack_ternary_1024(b: &mut Bencher) { + let values = vec![1i8; 1024]; + + b.iter(|| { + let _ = pack_ternary(&values); + }); +} +``` + +### Correctness Validation Tests + +#### 7. 
Bit-Exact Validation Against Reference + +```rust +#[test] +fn test_dequantize_matches_reference() { + // Reference implementation (naive) + fn reference_dequant(ternary: &[i8], scale: f32) -> Vec<f32> { + ternary.iter().map(|&t| (t as f32) * scale).collect() + } + + let config = PtBitnetConfig::default(); + let tensor = vec![1.5, -2.3, 0.1, -0.4]; // Extend to 256 + let tensor_256 = /* pad to 256 */; + let shape = [1, 256]; + + let (ternary, _) = quantize_tensor(&tensor_256, shape, &config).unwrap(); + + // Unpack and dequantize + let unpacked = unpack_ternary(&ternary.packed_data, 256); + let reference = reference_dequant(&unpacked, ternary.scales[0].to_f32()); + let optimized = ternary.to_fp32(); + + // Allow small floating-point error + for (r, o) in reference.iter().zip(optimized.iter()) { + assert!((r - o).abs() < 1e-5); + } +} +``` + +### Test Organization + +``` +crates/ruvllm/src/bitnet/tests/ +├── quantizer_tests.rs # absmean, pack/unpack +├── tensor_tests.rs # TernaryTensor validation +├── dequantize_tests.rs # BITNET_T158 dequant +├── integration_tests.rs # Full pipeline, GGUF round-trip +└── benches.rs # Performance benchmarks +``` + +--- + +## G. Implementation Phases + +### Phase 0.1: Core Data Structures (~2-3 days) +1. `bitnet/mod.rs` - module structure +2. `bitnet/config.rs` - `PtBitnetConfig` +3. `bitnet/ternary_tensor.rs` - `TernaryTensor`, `TernaryBlock` +4. Unit tests for validation + +### Phase 0.2: Quantization Algorithm (~3-4 days) +1. `bitnet/quantizer.rs` - `absmean_ternary()` +2. Pack/unpack functions +3. `quantize_tensor()` main entry point +4. Unit tests for correctness + +### Phase 0.3: Dequantization (~2 days) +1. `bitnet/dequantize.rs` - block and tensor dequant +2. Integration with existing `quantization.rs` +3. Round-trip tests + +### Phase 0.4: GGUF Integration (~2-3 days) +1. Modify `gguf/quantization.rs` - add `BITNET_T158` enum variant +2. Add metadata keys +3. GGUF serialization/deserialization +4. 
Integration tests + +### Phase 0.5: Validation & Benchmarks (~2 days) +1. Full pipeline integration tests +2. Performance benchmarks +3. Bit-exact validation +4. Documentation + +**Total Estimated Effort:** ~11-14 days (sum of Phases 0.1-0.5) for a clean, well-tested implementation + +--- + +## H. Open Design Questions + +| # | Question | Impact | Recommendation | +|---|----------|--------|----------------| +| 1 | Use `IQ1_S` (type 19) or new `BITNET_T158` (type 30)? | Compatibility | **New type 30** - cleaner separation, avoids confusion with IQ1_S's codebook format | +| 2 | Padding strategy for last block if not aligned? | Correctness | **Zero-pad** - simplest, matches BitNet spec | +| 3 | Should calibration be mandatory or optional? | Quality vs Speed | **Optional** - Phase 0 can work without it, add later if needed | +| 4 | F16 or F32 for internal scale computation? | Precision | **F32 internally, store as F16** - extra precision during compute | +| 5 | Handle NaN/Inf in input tensors? | Robustness | **Fail-fast** - corrupted weights should not be silently ignored | +| 6 | Support block sizes other than 256? | Flexibility | **No** - BitNet spec is 256, simplifies code | +| 7 | Multi-threading for per-block quantization? | Performance | **Not in Phase 0** - can add via rayon later | +| 8 | Store sparsity per-block in GGUF? | Kernel optimization | **No** - compute on-the-fly during dequant, saves space | + +--- + +## I. Dependencies and Prerequisites + +### Existing RuvLLM Components (Reused) +- `crates/ruvllm/src/error.rs` - `RuvLLMError` enum +- `crates/ruvllm/src/gguf/parser.rs` - GGUF parsing (unchanged) +- `crates/ruvllm/src/gguf/quantization.rs` - Enum + dispatch (modified) +- `half` crate - FP16 support (already in Cargo.toml) + +### New External Dependencies +None - uses only existing dependencies + +### Minimum Rust Version +Same as RuvLLM (likely 1.70+) + +--- + +## J. Non-Goals (Out of Scope) + +1. **Calibration implementation** - Deferred to future phase +2. 
**TL1/TL2 kernel implementation** - Separate ADR/DDD +3. **Model loader integration** - Separate backend implementation +4. **Performance optimization** - Phase 0 is correctness-first +5. **WASM support** - Desktop/server only for Phase 0 +6. **Dynamic quantization** - Only post-training static +7. **Mixed-precision strategies** - All-or-nothing ternary for Phase 0 + +--- + +## K. Success Criteria + +**This design is complete when:** + +1. All struct definitions have complete field specifications +2. All function signatures are documented with arguments, returns, errors +3. Module organization is clear and follows Rust conventions +4. GGUF integration points are precisely specified +5. Error handling covers all failure modes +6. Test plan covers correctness, integration, and performance +7. Implementation phases are realistic and sequenced +8. Open questions are documented with recommendations + +**Implementation is successful when:** + +1. All unit tests pass +2. Round-trip GGUF serialization is bit-exact +3. Dequantization produces correct FP32 output +4. Integration with existing GGUF pipeline works +5. Quantization of GLM-4.7-Flash completes without errors +6. 
Exported GGUF file is loadable by model loader + +--- + +## Appendix A: Code Size Estimates + +| File | Estimated Lines | Complexity | +|------|----------------|------------| +| `bitnet/mod.rs` | ~50 | Low | +| `bitnet/config.rs` | ~80 | Low | +| `bitnet/ternary_tensor.rs` | ~200 | Medium | +| `bitnet/quantizer.rs` | ~350 | High | +| `bitnet/dequantize.rs` | ~150 | Medium | +| `gguf/quantization.rs` (changes) | ~100 | Low | +| Tests | ~800 | Medium | +| **Total** | **~1,730 lines** | | + +**Comparison to ADR-018 estimate:** ~200-300 lines core quantizer → Actual ~350 lines (reasonable given struct overhead) + +--- + +## Appendix B: Memory Layout Examples + +### TernaryBlock Storage (66 bytes) + +``` +Byte Offset | Content +------------|-------- +0-63 | Packed 2-bit ternary (256 values) +64-65 | FP16 scale (little-endian) +``` + +### 2-Bit Packing Example + +``` +Values: [+1, -1, 0, +1] +Encoding: [10, 00, 01, 10] +Packed byte: 10_00_01_10 = 0x86 +``` + +### GGUF Tensor Data Layout + +``` +[TensorInfo] (in header) + name: "model.layers.0.mlp.gate_proj.weight" + shape: [4096, 11008] + dtype: BITNET_T158 (30) + offset: 0x1000 + +[Tensor Data] (at offset 0x1000) + Block 0: [64 bytes packed][2 bytes scale] + Block 1: [64 bytes packed][2 bytes scale] + ... + Block N: [64 bytes packed][2 bytes scale] +``` + +--- + +**End of Design Document** +