feat: Implement Phase 0 PT-BitNet quantizer module

Add bitnet/ module with absmean ternary quantizer, TernaryTensor type,
BITNET_T158 dequantization, and comprehensive test suite (~1600 lines).

Components:
- quantizer.rs: PtBitnetConfig, absmean_ternary(), quantize_tensor()
- ternary_tensor.rs: TernaryTensor, pack/unpack 2-bit ternary encoding
- dequantize.rs: dequantize_bitnet_t158(), block dequant, error metrics
- tests.rs: Packing roundtrips, quantization correctness, edge cases
- gguf/quantization.rs: BitnetT158 = 30 enum variant, block_size, bytes

Implements AD-1 (weight representation), AD-5 (GGUF extension), AD-18
(PT-BitNet quantization) from ADR-017.

https://claude.ai/code/session_011nTcGcn49b8YKJRVoh4TaK
This commit is contained in:
Claude 2026-02-03 12:40:18 +00:00
parent 80171fb8c1
commit 2bb04a64ed
8 changed files with 2619 additions and 0 deletions

View file

@ -0,0 +1,269 @@
//! BitNet Ternary Dequantization
//!
//! Converts packed 2-bit ternary weights back to FP32 for validation and testing.
use super::ternary_tensor::unpack_ternary;
/// Dequantize BITNET_T158 packed ternary data to FP32.
///
/// This function unpacks 2-bit ternary values and applies per-block scale factors
/// to reconstruct approximate FP32 weights. Used for validation and testing, not
/// for production inference (which should use native ternary kernels).
///
/// # Data Layout
///
/// The input data is organized as:
/// ```text
/// [packed_block_0][scale_0][packed_block_1][scale_1]...
/// ```
///
/// Where each block contains:
/// - 64 bytes of packed 2-bit ternary data (256 values)
/// - 2 bytes of FP16 scale factor
///
/// Total: 66 bytes per 256-element block
///
/// # Arguments
///
/// * `packed` - Raw GGUF tensor data with interleaved ternary and scales
/// * `scales` - Per-block FP32 scale factors
/// * `num_elements` - Total number of output elements
///
/// # Returns
///
/// Vector of FP32 weights approximating the original quantized tensor
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::dequantize_bitnet_t158;
///
/// // Load from GGUF
/// let packed_data = gguf_tensor.data; // Raw bytes
/// let scales = vec![0.542, 0.381, ...]; // One per block
/// let num_elements = 512;
///
/// let fp32_weights = dequantize_bitnet_t158(&packed_data, &scales, num_elements);
/// ```
pub fn dequantize_bitnet_t158(packed: &[u8], scales: &[f32], num_elements: usize) -> Vec<f32> {
// Unpack ternary values
let ternary = unpack_ternary(packed, num_elements);
// Apply per-block scales
let block_size = 256; // Standard BitNet block size
let mut output = Vec::with_capacity(num_elements);
for (block_idx, chunk) in ternary.chunks(block_size).enumerate() {
let scale = scales.get(block_idx).copied().unwrap_or(1.0);
for &ternary_val in chunk {
let fp32_val = (ternary_val as f32) * scale;
output.push(fp32_val);
}
}
output
}
/// Dequantize a single BITNET_T158 block.
///
/// Helper function for block-wise dequantization in streaming scenarios.
///
/// # Arguments
///
/// * `packed_block` - 64 bytes of packed 2-bit ternary data
/// * `scale` - FP32 scale factor for this block
/// * `output` - Output buffer (must have capacity for 256 FP32 values)
///
/// # Panics
///
/// Panics if output buffer is smaller than 256 elements.
pub fn dequantize_bitnet_block(packed_block: &[u8], scale: f32, output: &mut [f32]) {
assert!(
output.len() >= 256,
"Output buffer must hold at least 256 elements"
);
assert_eq!(
packed_block.len(),
64,
"Packed block must be exactly 64 bytes"
);
let ternary = unpack_ternary(packed_block, 256);
for (i, &ternary_val) in ternary.iter().enumerate() {
output[i] = (ternary_val as f32) * scale;
}
}
/// Compute dequantization error metrics.
///
/// Compares dequantized weights against original FP32 weights to measure
/// quantization quality.
///
/// # Arguments
///
/// * `original` - Original FP32 weights
/// * `dequantized` - Dequantized weights from ternary
///
/// # Returns
///
/// Tuple of (mean_absolute_error, mean_squared_error, max_error)
pub fn compute_dequant_error(original: &[f32], dequantized: &[f32]) -> (f32, f32, f32) {
assert_eq!(
original.len(),
dequantized.len(),
"Arrays must have same length"
);
let mut sum_abs_error = 0.0;
let mut sum_sq_error = 0.0;
let mut max_error = 0.0;
for (orig, dequant) in original.iter().zip(dequantized.iter()) {
let error = (orig - dequant).abs();
sum_abs_error += error;
sum_sq_error += error * error;
max_error = max_error.max(error);
}
let n = original.len() as f32;
let mae = sum_abs_error / n;
let mse = sum_sq_error / n;
(mae, mse, max_error)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bitnet::{absmean_ternary, pack_ternary};
#[test]
fn test_dequantize_bitnet_t158_simple() {
// Create simple ternary data
let ternary = vec![-1i8, 0, 1, -1, 1, 0, 0, 1];
let packed = pack_ternary(&ternary);
let scales = vec![0.5f32];
let result = dequantize_bitnet_t158(&packed, &scales, 8);
assert_eq!(result.len(), 8);
// Check values: ternary * scale
assert_eq!(result[0], -0.5); // -1 * 0.5
assert_eq!(result[1], 0.0); // 0 * 0.5
assert_eq!(result[2], 0.5); // 1 * 0.5
assert_eq!(result[3], -0.5); // -1 * 0.5
}
#[test]
fn test_dequantize_bitnet_block() {
// Create a full 256-element block
let ternary = vec![1i8; 256];
let packed = pack_ternary(&ternary);
let scale = 2.0;
let mut output = vec![0.0f32; 256];
dequantize_bitnet_block(&packed, scale, &mut output);
// All values should be 1 * 2.0 = 2.0
assert!(output.iter().all(|&v| (v - 2.0).abs() < 1e-6));
}
#[test]
fn test_dequantize_multiple_blocks() {
// Two blocks with different scales
let ternary1 = vec![1i8; 256];
let ternary2 = vec![-1i8; 256];
let mut all_ternary = ternary1.clone();
all_ternary.extend_from_slice(&ternary2);
let packed = pack_ternary(&all_ternary);
let scales = vec![1.0, 2.0];
let result = dequantize_bitnet_t158(&packed, &scales, 512);
// First 256 should be 1.0 * 1.0 = 1.0
assert!(result[..256].iter().all(|&v| (v - 1.0).abs() < 1e-6));
// Next 256 should be -1.0 * 2.0 = -2.0
assert!(result[256..512]
.iter()
.all(|&v| (v - (-2.0)).abs() < 1e-6));
}
#[test]
fn test_roundtrip_quantize_dequantize() {
// Original weights
let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.4, 0.2, -0.6];
// Quantize
let (ternary, scale) = absmean_ternary(&original);
let packed = pack_ternary(&ternary);
// Dequantize
let dequantized = dequantize_bitnet_t158(&packed, &[scale], original.len());
// Check that we got 8 values back
assert_eq!(dequantized.len(), 8);
// Values should be approximate (quantization loses precision)
// But should be close for values near the scale
for (orig, dequant) in original.iter().zip(dequantized.iter()) {
let error = (orig - dequant).abs();
// Error should be bounded by the quantization step (~scale)
assert!(error < scale * 2.0);
}
}
#[test]
fn test_compute_dequant_error() {
let original = vec![1.0, 2.0, 3.0, 4.0];
let dequantized = vec![1.1, 1.9, 3.2, 3.8];
let (mae, mse, max_error) = compute_dequant_error(&original, &dequantized);
// MAE should be (0.1 + 0.1 + 0.2 + 0.2) / 4 = 0.15
assert!((mae - 0.15).abs() < 1e-6);
// MSE should be (0.01 + 0.01 + 0.04 + 0.04) / 4 = 0.025
assert!((mse - 0.025).abs() < 1e-6);
// Max error should be 0.2
assert!((max_error - 0.2).abs() < 1e-6);
}
#[test]
#[should_panic(expected = "Output buffer must hold at least 256 elements")]
fn test_dequantize_block_small_buffer() {
let packed = vec![0u8; 64];
let mut output = vec![0.0f32; 128]; // Too small
dequantize_bitnet_block(&packed, 1.0, &mut output);
}
#[test]
#[should_panic(expected = "Packed block must be exactly 64 bytes")]
fn test_dequantize_block_wrong_size() {
let packed = vec![0u8; 32]; // Wrong size
let mut output = vec![0.0f32; 256];
dequantize_bitnet_block(&packed, 1.0, &mut output);
}
#[test]
fn test_dequantize_with_missing_scales() {
// More elements than scales (should use default 1.0)
let ternary = vec![1i8; 512];
let packed = pack_ternary(&ternary);
let scales = vec![2.0]; // Only one scale for two blocks
let result = dequantize_bitnet_t158(&packed, &scales, 512);
// First 256 use scale 2.0
assert!(result[..256].iter().all(|&v| (v - 2.0).abs() < 1e-6));
// Next 256 use default 1.0
assert!(result[256..512].iter().all(|&v| (v - 1.0).abs() < 1e-6));
}
}

View file

@ -0,0 +1,58 @@
//! BitNet b1.58 Ternary Quantization for RuvLLM
//!
//! This module implements Microsoft Research's BitNet b1.58 ternary weight quantization
//! for the Craftsman Ultra 30b 1bit model. It provides post-training quantization (PTQ)
//! of FP16 weights to ternary {-1, 0, +1} using absmean quantization.
//!
//! ## Overview
//!
//! BitNet b1.58 enables multiplication-free inference by quantizing weights to three values:
//! -1, 0, +1. This reduces memory footprint to ~2 bits per weight and eliminates floating-point
//! multiplication in matrix operations.
//!
//! ## Key Components
//!
//! - [`TernaryTensor`]: Container for ternary weights with 2-bit packing
//! - [`quantize_tensor`]: Convert FP32 weights to ternary using absmean algorithm
//! - [`dequantize_bitnet_t158`]: Convert packed ternary back to FP32 for validation
//! - [`PtBitnetConfig`]: Configuration for post-training quantization
//!
//! ## Example
//!
//! ```rust,ignore
//! use ruvllm::bitnet::{quantize_tensor, PtBitnetConfig};
//!
//! // Configure quantization
//! let config = PtBitnetConfig {
//! block_size: 256,
//! optimize_scales: true,
//! ..Default::default()
//! };
//!
//! // Quantize a weight tensor
//! let fp32_weights = vec![0.5, -0.3, 0.0, 0.8, /* ... */];
//! let ternary = quantize_tensor(&fp32_weights, (128, 256), &config)?;
//!
//! println!("Sparsity: {:.2}%", ternary.sparsity() * 100.0);
//! println!("Memory: {} bytes", ternary.memory_bytes());
//! ```
//!
//! ## Architecture Details
//!
//! From ADR-017 (AD-1, AD-5, AD-18):
//!
//! - **Absmean quantization**: `W_ternary = RoundClip(W / (mean(|W|) + ε), -1, 1)`
//! - **2-bit packing**: 00=-1, 01=0, 10=+1 (4 values per byte)
//! - **Block size**: 256 elements per scale factor
//! - **Storage**: 66 bytes per block (64 bytes ternary + 2 bytes FP16 scale)
//! - **Compression**: 2.06 bits/weight (30B model → ~7.7 GB)
pub mod dequantize;
pub mod quantizer;
pub mod ternary_tensor;
pub use dequantize::dequantize_bitnet_t158;
pub use quantizer::{
absmean_ternary, quantize_tensor, LayerMask, Precision, PtBitnetConfig, TernaryFormat,
};
pub use ternary_tensor::{pack_ternary, unpack_ternary, TernaryTensor};

View file

@ -0,0 +1,338 @@
//! PT-BitNet Post-Training Quantization
//!
//! Core absmean ternary quantization algorithm for converting FP32 weights
//! to BitNet b1.58 ternary format.
use crate::error::{Result, RuvLLMError};
use super::ternary_tensor::{pack_ternary, TernaryTensor};
/// Configuration for PT-BitNet post-training quantization.
///
/// Controls the quantization process behavior, including block size,
/// calibration, and layer selection.
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::PtBitnetConfig;
///
/// let config = PtBitnetConfig {
/// calibration_samples: 1000,
/// block_size: 256,
/// optimize_scales: true,
/// layers_to_quantize: LayerMask::ExpertsOnly,
/// export_format: TernaryFormat::BitnetT158,
/// ..Default::default()
/// };
/// ```
#[derive(Debug, Clone)]
pub struct PtBitnetConfig {
/// Number of calibration samples for scale optimization
pub calibration_samples: usize,
/// Elements per quantization block
pub block_size: usize,
/// Enable scale factor optimization via calibration
pub optimize_scales: bool,
/// Which layers to quantize
pub layers_to_quantize: LayerMask,
/// Export format for GGUF serialization
pub export_format: TernaryFormat,
/// Precision for router and shared layers
pub router_precision: Precision,
/// Use memory-mapped I/O for weight loading
pub use_mmap: bool,
/// Use Metal GPU for calibration (Mac Studio only)
pub use_metal_calibration: bool,
/// Maximum memory budget in GB
pub max_memory_gb: usize,
}
impl Default for PtBitnetConfig {
fn default() -> Self {
Self {
calibration_samples: 1000,
block_size: 256,
optimize_scales: true,
layers_to_quantize: LayerMask::ExpertsOnly,
export_format: TernaryFormat::BitnetT158,
router_precision: Precision::FP16,
use_mmap: true,
use_metal_calibration: cfg!(all(target_os = "macos", feature = "metal-compute")),
max_memory_gb: 64,
}
}
}
/// Layer selection mask for quantization.
///
/// Determines which model layers to convert to ternary. Per ADR-017 (AD-2),
/// the MoE router, embeddings, and LM head must remain in higher precision.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LayerMask {
/// Only MoE expert FFN layers (recommended for Phase 1)
ExpertsOnly,
/// All linear layers except router/embeddings/head
All,
/// Custom layer selection by name pattern
Custom(Vec<String>),
}
/// Ternary tensor export format.
///
/// Determines the GGUF quantization type used for serialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TernaryFormat {
/// BitNet b1.58 native format (type 30)
BitnetT158,
/// IQ1_S compatible format (type 19)
IQ1S,
}
/// Precision for non-quantized layers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
/// 16-bit floating point
FP16,
/// Brain floating point 16
BF16,
/// 32-bit floating point
FP32,
}
/// Core absmean ternary quantization algorithm.
///
/// Implements the BitNet b1.58 quantization formula:
/// ```text
/// gamma = mean(|block|) + epsilon
/// normalized = block / gamma
/// ternary = round(clamp(normalized, -1, 1))
/// ```
///
/// # Arguments
///
/// * `block` - FP32 weight block (typically 256 elements)
///
/// # Returns
///
/// Tuple of (ternary values, scale factor):
/// - `Vec<i8>`: Ternary weights in {-1, 0, +1}
/// - `f32`: Absmean scale factor (gamma)
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::absmean_ternary;
///
/// let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.4];
/// let (ternary, scale) = absmean_ternary(&weights);
///
/// println!("Scale: {}", scale);
/// println!("Ternary: {:?}", ternary); // e.g., [1, -1, 1, 0, 0, 1]
/// ```
pub fn absmean_ternary(block: &[f32]) -> (Vec<i8>, f32) {
// Compute absmean scale: gamma = mean(|W|)
let sum_abs: f32 = block.iter().map(|&w| w.abs()).sum();
let gamma = (sum_abs / block.len() as f32) + 1e-8;
// Normalize and quantize to {-1, 0, +1}
let ternary: Vec<i8> = block
.iter()
.map(|&w| {
let normalized = w / gamma;
let clamped = normalized.clamp(-1.0, 1.0);
clamped.round() as i8
})
.collect();
(ternary, gamma)
}
/// Quantize a full FP32 tensor to ternary representation.
///
/// Processes the input tensor in blocks of `config.block_size`, applying
/// absmean quantization to each block independently.
///
/// # Arguments
///
/// * `weights` - FP32 weight tensor (flattened)
/// * `shape` - Tensor shape (rows, cols)
/// * `config` - Quantization configuration
///
/// # Returns
///
/// `TernaryTensor` with packed 2-bit data and per-block scales
///
/// # Errors
///
/// Returns an error if the weight dimensions are invalid.
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::{quantize_tensor, PtBitnetConfig};
///
/// let weights = vec![0.5; 512]; // 512 FP32 weights
/// let shape = (2, 256);
/// let config = PtBitnetConfig::default();
///
/// let ternary = quantize_tensor(&weights, shape, &config)?;
/// println!("Compressed to {} bytes", ternary.memory_bytes());
/// ```
pub fn quantize_tensor(
weights: &[f32],
shape: (usize, usize),
config: &PtBitnetConfig,
) -> Result<TernaryTensor> {
let (rows, cols) = shape;
let total_elements = rows * cols;
if weights.len() != total_elements {
return Err(RuvLLMError::Model(format!(
"Weight size mismatch: expected {} elements for shape {:?}, got {}",
total_elements,
shape,
weights.len()
)));
}
let block_size = config.block_size;
let num_blocks = (total_elements + block_size - 1) / block_size;
let mut all_ternary = Vec::with_capacity(total_elements);
let mut scales = Vec::with_capacity(num_blocks);
// Process each block
for block_idx in 0..num_blocks {
let start = block_idx * block_size;
let end = (start + block_size).min(total_elements);
let block = &weights[start..end];
let (ternary, scale) = absmean_ternary(block);
all_ternary.extend_from_slice(&ternary);
scales.push(scale);
}
// Pack ternary values into 2-bit representation
let packed_data = pack_ternary(&all_ternary);
Ok(TernaryTensor {
packed_data,
scales,
shape,
block_size,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_absmean_ternary_simple() {
// Simple block with known values
let block = vec![0.5, -0.5, 0.0, 1.0, -1.0, 0.25];
let (ternary, scale) = absmean_ternary(&block);
// All values should be in {-1, 0, +1}
assert!(ternary.iter().all(|&v| v >= -1 && v <= 1));
// Scale should be positive
assert!(scale > 0.0);
// Check specific values
// gamma ≈ (0.5 + 0.5 + 0.0 + 1.0 + 1.0 + 0.25) / 6 ≈ 0.542
// 0.5 / 0.542 ≈ 0.92 → round(0.92) = 1
// -0.5 / 0.542 ≈ -0.92 → round(-0.92) = -1
// 0.0 / 0.542 = 0 → round(0) = 0
assert_eq!(ternary[0], 1);
assert_eq!(ternary[1], -1);
assert_eq!(ternary[2], 0);
}
#[test]
fn test_absmean_ternary_all_zeros() {
let block = vec![0.0; 256];
let (ternary, scale) = absmean_ternary(&block);
// All should quantize to 0
assert!(ternary.iter().all(|&v| v == 0));
// Scale should be epsilon (1e-8)
assert!(scale < 1e-7 && scale > 0.0);
}
#[test]
fn test_absmean_ternary_large_values() {
let block = vec![10.0, -10.0, 5.0, -5.0];
let (ternary, _scale) = absmean_ternary(&block);
// All should saturate to ±1
assert!(ternary[0] == 1 || ternary[0] == -1);
assert!(ternary[1] == 1 || ternary[1] == -1);
}
#[test]
fn test_quantize_tensor_simple() {
let weights = vec![0.5; 512]; // 512 identical weights
let shape = (2, 256);
let config = PtBitnetConfig::default();
let ternary = quantize_tensor(&weights, shape, &config).unwrap();
assert_eq!(ternary.shape, shape);
assert_eq!(ternary.block_size, 256);
assert_eq!(ternary.num_blocks(), 2); // 512 / 256 = 2 blocks
assert_eq!(ternary.scales.len(), 2);
// 512 elements packed in 2 bits each = 128 bytes
assert_eq!(ternary.packed_data.len(), 128);
}
#[test]
fn test_quantize_tensor_size_mismatch() {
let weights = vec![0.5; 100]; // Wrong size
let shape = (2, 256); // Expects 512
let config = PtBitnetConfig::default();
let result = quantize_tensor(&weights, shape, &config);
assert!(result.is_err());
}
#[test]
fn test_quantize_tensor_memory_savings() {
// Quantize a 1MB FP32 tensor (256K elements)
let weights = vec![0.5; 256 * 1024];
let shape = (512, 512);
let config = PtBitnetConfig::default();
let ternary = quantize_tensor(&weights, shape, &config).unwrap();
let original_bytes = weights.len() * 4; // FP32
let compressed_bytes = ternary.memory_bytes();
// Should be ~16x compression (32 bits → 2 bits + scale overhead)
let compression_ratio = original_bytes as f32 / compressed_bytes as f32;
assert!(compression_ratio > 10.0); // At least 10x compression
assert!(compression_ratio < 20.0); // Less than 20x (due to scales)
}
#[test]
fn test_config_default() {
let config = PtBitnetConfig::default();
assert_eq!(config.block_size, 256);
assert_eq!(config.calibration_samples, 1000);
assert!(config.optimize_scales);
assert_eq!(config.layers_to_quantize, LayerMask::ExpertsOnly);
}
#[test]
fn test_layer_mask_variants() {
let experts = LayerMask::ExpertsOnly;
let all = LayerMask::All;
let custom = LayerMask::Custom(vec!["layer.0".to_string()]);
assert_ne!(experts, all);
assert_ne!(all, custom);
assert_ne!(experts, custom);
}
}

View file

@ -0,0 +1,276 @@
//! Ternary Tensor Data Structure
//!
//! This module provides the `TernaryTensor` container for BitNet b1.58 ternary weights,
//! along with efficient 2-bit packing/unpacking functions.
/// Ternary tensor with 2-bit packed representation.
///
/// Stores ternary weights {-1, 0, +1} in a compact 2-bit format:
/// - 00 = -1
/// - 01 = 0
/// - 10 = +1
/// - 11 = reserved (unused)
///
/// Each block of `block_size` elements shares a single FP32 scale factor
/// derived from the absmean quantization process.
///
/// # Memory Layout
///
/// For a tensor with shape (m, n) and block_size B:
/// - `packed_data`: ceil(m*n / 4) bytes (4 ternary values per byte)
/// - `scales`: ceil(m*n / B) * 4 bytes (one FP32 scale per block)
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::TernaryTensor;
///
/// let tensor = TernaryTensor {
/// packed_data: vec![0b10010100], // [+1, 0, +1, 0]
/// scales: vec![0.5],
/// shape: (2, 2),
/// block_size: 256,
/// };
///
/// println!("Sparsity: {:.2}%", tensor.sparsity() * 100.0);
/// println!("Memory: {} bytes", tensor.memory_bytes());
/// ```
#[derive(Debug, Clone)]
pub struct TernaryTensor {
/// Packed 2-bit ternary data (4 values per byte)
pub packed_data: Vec<u8>,
/// Per-block scale factors (FP32)
pub scales: Vec<f32>,
/// Tensor shape (rows, cols)
pub shape: (usize, usize),
/// Elements per quantization block
pub block_size: usize,
}
impl TernaryTensor {
/// Calculate the fraction of zero weights (sparsity).
///
/// Zero weights enable feature filtering and reduce computation
/// in ternary matrix multiplication.
///
/// # Returns
///
/// Fraction of weights that are exactly 0, in range [0.0, 1.0]
pub fn sparsity(&self) -> f32 {
let total_elements = self.shape.0 * self.shape.1;
let unpacked = unpack_ternary(&self.packed_data, total_elements);
let zero_count = unpacked.iter().filter(|&&x| x == 0).count();
zero_count as f32 / total_elements as f32
}
/// Calculate total memory footprint in bytes.
///
/// Includes both packed ternary data and per-block scales.
///
/// # Returns
///
/// Total bytes: packed_data.len() + scales.len() * 4
pub fn memory_bytes(&self) -> usize {
self.packed_data.len() + self.scales.len() * 4
}
/// Get the number of quantization blocks.
pub fn num_blocks(&self) -> usize {
let total_elements = self.shape.0 * self.shape.1;
(total_elements + self.block_size - 1) / self.block_size
}
}
/// Pack ternary values {-1, 0, +1} into 2-bit representation.
///
/// Encoding:
/// - -1 → 00
/// - 0 → 01
/// - +1 → 10
/// - (unused) → 11
///
/// Four values are packed into each byte in LSB-first order:
/// ```text
/// byte = [v3:v2:v1:v0]
/// ```
///
/// # Arguments
///
/// * `values` - Slice of i8 values, must be in {-1, 0, +1}
///
/// # Returns
///
/// Vector of bytes, length = ceil(values.len() / 4)
///
/// # Panics
///
/// Panics if any value is not in {-1, 0, +1}
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::pack_ternary;
///
/// let values = vec![-1, 0, 1, -1];
/// let packed = pack_ternary(&values);
/// assert_eq!(packed.len(), 1); // 4 values in 1 byte
/// ```
pub fn pack_ternary(values: &[i8]) -> Vec<u8> {
let num_bytes = (values.len() + 3) / 4;
let mut packed = vec![0u8; num_bytes];
for (i, &val) in values.iter().enumerate() {
let byte_idx = i / 4;
let bit_offset = (i % 4) * 2;
let encoded: u8 = match val {
-1 => 0b00,
0 => 0b01,
1 => 0b10,
_ => panic!("Invalid ternary value: {} (must be -1, 0, or +1)", val),
};
packed[byte_idx] |= encoded << bit_offset;
}
packed
}
/// Unpack 2-bit ternary values to i8.
///
/// Decoding:
/// - 00 → -1
/// - 01 → 0
/// - 10 → +1
/// - 11 → 0 (reserved, treated as zero)
///
/// # Arguments
///
/// * `packed` - Packed 2-bit data
/// * `n` - Number of elements to unpack
///
/// # Returns
///
/// Vector of i8 values in {-1, 0, +1}
///
/// # Example
///
/// ```rust,ignore
/// use ruvllm::bitnet::{pack_ternary, unpack_ternary};
///
/// let original = vec![-1, 0, 1, -1];
/// let packed = pack_ternary(&original);
/// let unpacked = unpack_ternary(&packed, 4);
/// assert_eq!(original, unpacked);
/// ```
pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec<i8> {
let mut values = Vec::with_capacity(n);
for i in 0..n {
let byte_idx = i / 4;
let bit_offset = (i % 4) * 2;
if byte_idx >= packed.len() {
break;
}
let encoded = (packed[byte_idx] >> bit_offset) & 0b11;
let val = match encoded {
0b00 => -1,
0b01 => 0,
0b10 => 1,
0b11 => 0, // Reserved, treat as zero
_ => unreachable!(),
};
values.push(val);
}
values
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pack_unpack_ternary() {
let values = vec![-1, 0, 1, -1, 1, 0, 0, 1];
let packed = pack_ternary(&values);
let unpacked = unpack_ternary(&packed, values.len());
assert_eq!(values, unpacked);
}
#[test]
fn test_pack_ternary_single_byte() {
// 4 values fit in 1 byte
let values = vec![-1, 0, 1, -1];
let packed = pack_ternary(&values);
assert_eq!(packed.len(), 1);
// Manually verify encoding
// -1=00, 0=01, 1=10, -1=00
// byte = [00:10:01:00] = 0b00_10_01_00 = 0x08
assert_eq!(packed[0], 0b00_10_01_00);
}
#[test]
fn test_pack_ternary_partial_byte() {
// 5 values need 2 bytes
let values = vec![-1, 0, 1, -1, 1];
let packed = pack_ternary(&values);
assert_eq!(packed.len(), 2);
}
#[test]
#[should_panic(expected = "Invalid ternary value")]
fn test_pack_invalid_value() {
let values = vec![-1, 0, 2]; // 2 is invalid
pack_ternary(&values);
}
#[test]
fn test_ternary_tensor_sparsity() {
let values = vec![0, 1, 0, -1, 0, 0, 1, 0]; // 5 zeros out of 8
let packed = pack_ternary(&values);
let tensor = TernaryTensor {
packed_data: packed,
scales: vec![1.0],
shape: (2, 4),
block_size: 256,
};
let sparsity = tensor.sparsity();
assert!((sparsity - 0.625).abs() < 0.001); // 5/8 = 0.625
}
#[test]
fn test_ternary_tensor_memory() {
let packed = vec![0u8; 64]; // 64 bytes of packed data
let scales = vec![0.5f32; 16]; // 16 scales * 4 bytes = 64 bytes
let tensor = TernaryTensor {
packed_data: packed,
scales,
shape: (128, 256),
block_size: 256,
};
assert_eq!(tensor.memory_bytes(), 64 + 64); // 128 bytes total
}
#[test]
fn test_ternary_tensor_num_blocks() {
let tensor = TernaryTensor {
packed_data: vec![],
scales: vec![],
shape: (256, 256), // 65536 elements
block_size: 256, // 256 elements per block
};
assert_eq!(tensor.num_blocks(), 256); // 65536 / 256 = 256 blocks
}
}

View file

@ -0,0 +1,662 @@
//! Comprehensive tests for PT-BitNet Phase 0 ternary quantization
//!
//! Test coverage based on ADR-017 (AD-1, AD-18):
//! - Ternary packing/unpacking roundtrips
//! - Absmean quantization correctness
//! - Dequantization accuracy
//! - Full tensor quantization
//! - Edge cases and error conditions
use super::{
dequantize_bitnet_t158, pack_ternary, quantize_tensor, unpack_ternary, PtBitnetConfig,
TernaryTensor,
};
// ============================================================================
// Test Constants
// ============================================================================
const EPSILON: f32 = 1e-6;
const BLOCK_SIZE: usize = 256;
// ============================================================================
// 1. Ternary Packing Roundtrip Tests
// ============================================================================
#[test]
fn test_pack_unpack_simple_roundtrip() {
// Simple 4-element ternary array
let ternary = vec![1i8, 0, -1, 1];
let packed = pack_ternary(&ternary);
let unpacked = unpack_ternary(&packed, 4);
assert_eq!(ternary, unpacked, "Packing roundtrip failed for [1, 0, -1, 1]");
}
#[test]
fn test_pack_all_zeros() {
let ternary = vec![0i8; 256];
let packed = pack_ternary(&ternary);
let unpacked = unpack_ternary(&packed, 256);
assert_eq!(ternary, unpacked);
assert!(unpacked.iter().all(|&x| x == 0), "All zeros should remain all zeros");
}
#[test]
fn test_pack_all_ones() {
let ternary = vec![1i8; 256];
let packed = pack_ternary(&ternary);
let unpacked = unpack_ternary(&packed, 256);
assert_eq!(ternary, unpacked);
assert!(unpacked.iter().all(|&x| x == 1), "All +1 should remain all +1");
}
#[test]
fn test_pack_all_neg_ones() {
let ternary = vec![-1i8; 256];
let packed = pack_ternary(&ternary);
let unpacked = unpack_ternary(&packed, 256);
assert_eq!(ternary, unpacked);
assert!(unpacked.iter().all(|&x| x == -1), "All -1 should remain all -1");
}
#[test]
fn test_pack_one_block_256_elements() {
// One full block (256 elements) with alternating pattern
let mut ternary = Vec::with_capacity(256);
for i in 0..256 {
ternary.push(match i % 3 {
0 => 1,
1 => 0,
2 => -1,
_ => unreachable!(),
});
}
let packed = pack_ternary(&ternary);
let unpacked = unpack_ternary(&packed, 256);
assert_eq!(ternary, unpacked, "256-element block roundtrip failed");
// Verify storage size: 256 elements * 2 bits = 64 bytes
assert_eq!(packed.len(), 64, "Packed size should be 64 bytes for 256 elements");
}
#[test]
fn test_pack_non_aligned_size() {
// 100 elements (not divisible by 128, the typical packing boundary)
let mut ternary = Vec::with_capacity(100);
for i in 0..100 {
ternary.push(if i % 2 == 0 { 1 } else { -1 });
}
let packed = pack_ternary(&ternary);
let unpacked = unpack_ternary(&packed, 100);
assert_eq!(
ternary.len(),
unpacked.len(),
"Unpacked length should match original"
);
assert_eq!(ternary, unpacked, "Non-aligned size roundtrip failed");
}
#[test]
fn test_pack_large_tensor() {
// Multiple blocks (1024 elements = 4 blocks)
let ternary: Vec<i8> = (0..1024)
.map(|i| match i % 5 {
0 | 1 => 1,
2 | 3 => -1,
4 => 0,
_ => unreachable!(),
})
.collect();
let packed = pack_ternary(&ternary);
let unpacked = unpack_ternary(&packed, 1024);
assert_eq!(ternary, unpacked, "Large tensor roundtrip failed");
}
// ============================================================================
// 2. Absmean Quantization Correctness Tests
// ============================================================================
#[test]
fn test_quantize_uniform_random() {
// Uniform random weights in [-1, 1] should produce all ternary values
let weights = vec![0.5, -0.3, 0.1, -0.7, 0.9, -0.1, 0.0, 0.4];
let ternary = quantize_absmean(&weights);
// All outputs must be in {-1, 0, +1}
for &t in &ternary {
assert!(
t == -1 || t == 0 || t == 1,
"Quantized value {} not in ternary set",
t
);
}
}
#[test]
fn test_quantize_all_zeros() {
let weights = vec![0.0; 256];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// All ternary values should be zero
assert!(
ternary.iter().all(|&x| x == 0),
"All-zero input should produce all-zero ternary"
);
// Scale should be near epsilon (avoiding division by zero)
assert!(
scale < 1e-5,
"Scale for all-zero weights should be near epsilon, got {}",
scale
);
}
#[test]
fn test_quantize_large_positive() {
// Large positive weights should quantize to all +1
let weights = vec![10.0; 256];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// All should be +1
assert!(
ternary.iter().all(|&x| x == 1),
"Large positive weights should quantize to +1"
);
// Scale should be approximately 10.0 (mean absolute value)
assert!(
(scale - 10.0).abs() < 0.1,
"Scale should be ~10.0, got {}",
scale
);
}
#[test]
fn test_quantize_large_negative() {
// Large negative weights should quantize to all -1
let weights = vec![-10.0; 256];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// All should be -1
assert!(
ternary.iter().all(|&x| x == -1),
"Large negative weights should quantize to -1"
);
// Scale should be approximately 10.0 (mean absolute value)
assert!(
(scale - 10.0).abs() < 0.1,
"Scale should be ~10.0, got {}",
scale
);
}
#[test]
fn test_quantize_known_example() {
// From ADR: W_ternary = RoundClip(W / (mean(|W|) + epsilon), -1, 1)
// Example: weights = [0.5, -0.3, 0.1, -0.7]
// gamma = mean(|W|) = (0.5 + 0.3 + 0.1 + 0.7) / 4 = 0.4
// normalized = [1.25, -0.75, 0.25, -1.75]
// ternary = [1, -1, 0, -1] (after clamp and round)
let weights = vec![0.5, -0.3, 0.1, -0.7];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// Verify scale is approximately 0.4
assert!(
(scale - 0.4).abs() < 0.01,
"Expected scale ~0.4, got {}",
scale
);
// Verify ternary values
// 1.25 -> 1, -0.75 -> -1, 0.25 -> 0, -1.75 -> -1
assert_eq!(ternary[0], 1, "0.5/0.4 = 1.25 should round to 1");
assert_eq!(ternary[1], -1, "-0.3/0.4 = -0.75 should round to -1");
assert_eq!(ternary[2], 0, "0.1/0.4 = 0.25 should round to 0");
assert_eq!(ternary[3], -1, "-0.7/0.4 = -1.75 should clamp to -1");
}
#[test]
fn test_quantize_scale_calculation() {
// Verify scale = mean(|weights|)
let weights = vec![1.0, -2.0, 3.0, -4.0];
let (_, scale) = quantize_absmean_with_scale(&weights);
let expected_scale = (1.0 + 2.0 + 3.0 + 4.0) / 4.0; // = 2.5
assert!(
(scale - expected_scale).abs() < EPSILON,
"Scale should be mean of absolute values: expected {}, got {}",
expected_scale,
scale
);
}
// ============================================================================
// 3. Dequantization Correctness Tests
// ============================================================================
#[test]
fn test_dequantize_simple() {
let ternary = vec![1i8, 0, -1];
let scale = 2.0;
let dequantized = dequantize_ternary(&ternary, scale);
assert_eq!(dequantized.len(), 3);
assert!((dequantized[0] - 2.0).abs() < EPSILON, "1 * 2.0 = 2.0");
assert!((dequantized[1] - 0.0).abs() < EPSILON, "0 * 2.0 = 0.0");
assert!((dequantized[2] - (-2.0)).abs() < EPSILON, "-1 * 2.0 = -2.0");
}
#[test]
fn test_dequantize_packed_data() {
// Pack known ternary data, then dequantize
let ternary = vec![1i8, 0, -1, 1];
let packed = pack_ternary(&ternary);
let scale = 3.5;
let unpacked = unpack_ternary(&packed, 4);
let dequantized = dequantize_ternary(&unpacked, scale);
assert_eq!(dequantized.len(), 4);
assert!((dequantized[0] - 3.5).abs() < EPSILON);
assert!((dequantized[1] - 0.0).abs() < EPSILON);
assert!((dequantized[2] - (-3.5)).abs() < EPSILON);
assert!((dequantized[3] - 3.5).abs() < EPSILON);
}
#[test]
fn test_quantize_dequantize_roundtrip_mse() {
// Quantize -> Dequantize should have bounded MSE
let weights = vec![0.5, -0.3, 0.1, -0.7, 0.9, -0.1, 0.4, -0.5];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
let dequantized = dequantize_ternary(&ternary, scale);
// Compute MSE
let mse: f32 = weights
.iter()
.zip(dequantized.iter())
.map(|(&w, &d)| (w - d).powi(2))
.sum::<f32>()
/ weights.len() as f32;
// MSE should be reasonable (ternary quantization is lossy)
// For absmean, expect MSE < 0.5 for normalized weights
assert!(
mse < 0.5,
"MSE too high: {} (weights may not reconstruct well)",
mse
);
}
#[test]
fn test_dequantize_full_block() {
// Dequantize a full 256-element block
let ternary: Vec<i8> = (0..256).map(|i| if i % 2 == 0 { 1 } else { -1 }).collect();
let scale = 1.5;
let dequantized = dequantize_ternary(&ternary, scale);
assert_eq!(dequantized.len(), 256);
for (i, &val) in dequantized.iter().enumerate() {
let expected = if i % 2 == 0 { 1.5 } else { -1.5 };
assert!(
(val - expected).abs() < EPSILON,
"Element {} incorrect: expected {}, got {}",
i,
expected,
val
);
}
}
// ============================================================================
// 4. Full Tensor Quantization Tests
// ============================================================================
#[test]
fn test_tensor_quantize_256x256() {
// 256x256 random tensor (65536 elements)
let mut weights = Vec::with_capacity(65536);
for i in 0..65536 {
let val = ((i as f32) * 0.001).sin(); // Pseudo-random in [-1, 1]
weights.push(val);
}
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
// Verify shape preserved
assert_eq!(
tensor.num_elements(),
65536,
"Tensor should preserve element count"
);
// Verify sparsity is in valid range
let sparsity = tensor.sparsity();
assert!(
sparsity >= 0.0 && sparsity <= 1.0,
"Sparsity {} out of range [0, 1]",
sparsity
);
// For uniform random, expect ~1/3 zeros (rough heuristic)
assert!(
sparsity > 0.15 && sparsity < 0.5,
"Sparsity {} seems unrealistic for uniform random input",
sparsity
);
}
#[test]
fn test_tensor_memory_bytes() {
let weights = vec![0.5; 256];
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
// Expected memory:
// - Packed data: 256 elements * 2 bits / 8 = 64 bytes
// - Scales: 1 block * 4 bytes (f32) = 4 bytes
// Total: 68 bytes
let expected_bytes = 64 + 4;
assert_eq!(
tensor.memory_bytes(),
expected_bytes,
"Memory calculation incorrect"
);
}
#[test]
fn test_tensor_sparsity_calculation() {
// Known sparsity: 50% zeros
let weights: Vec<f32> = (0..256)
.map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
.collect();
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
let sparsity = tensor.sparsity();
// Should be close to 0.5 (half zeros)
assert!(
(sparsity - 0.5).abs() < 0.1,
"Expected sparsity ~0.5, got {}",
sparsity
);
}
#[test]
fn test_tensor_block_alignment() {
// 512 elements = 2 blocks of 256
let weights = vec![1.0; 512];
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
// Should have 2 scale factors (one per block)
assert_eq!(
tensor.num_blocks(),
2,
"Expected 2 blocks for 512 elements"
);
}
#[test]
fn test_tensor_non_aligned_padding() {
// 300 elements (256 + 44) should create 2 blocks with padding
let weights = vec![0.5; 300];
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
// Should pad to 2 full blocks (512 elements)
let num_blocks = (300 + BLOCK_SIZE - 1) / BLOCK_SIZE;
assert_eq!(
tensor.num_blocks(),
num_blocks,
"Non-aligned tensor should pad to full blocks"
);
// Original element count should be preserved
assert_eq!(tensor.num_elements(), 300);
}
// ============================================================================
// 5. TernaryTensor Properties Tests
// ============================================================================
#[test]
fn test_ternary_tensor_properties() {
let weights: Vec<f32> = (0..512).map(|i| (i as f32) * 0.01).collect();
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
// Memory bytes should match calculation
let num_blocks = (512 + BLOCK_SIZE - 1) / BLOCK_SIZE;
let packed_bytes = num_blocks * BLOCK_SIZE * 2 / 8; // 2 bits per element
let scale_bytes = num_blocks * 4; // f32 scales
let expected = packed_bytes + scale_bytes;
assert_eq!(tensor.memory_bytes(), expected);
// Sparsity should be in valid range
assert!(tensor.sparsity() >= 0.0 && tensor.sparsity() <= 1.0);
}
#[test]
fn test_ternary_tensor_uniform_random_sparsity() {
// Uniform random should have ~1/3 sparsity
let mut weights = Vec::with_capacity(2048);
for i in 0..2048 {
weights.push(((i as f32) * 1.234).sin());
}
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
let sparsity = tensor.sparsity();
// Rough heuristic: 20-45% zeros for uniform random
assert!(
sparsity > 0.2 && sparsity < 0.45,
"Uniform random sparsity {} outside expected range [0.2, 0.45]",
sparsity
);
}
// ============================================================================
// 6. Config Validation Tests
// ============================================================================
#[test]
fn test_config_default_values() {
let config = PtBitnetConfig::default();
assert_eq!(config.block_size, 256, "Default block size should be 256");
assert!(
config.calibration_samples > 0,
"Calibration samples must be > 0"
);
}
#[test]
#[should_panic(expected = "block_size must be > 0")]
fn test_config_invalid_block_size() {
let _config = PtBitnetConfig {
block_size: 0,
..Default::default()
};
}
#[test]
#[should_panic(expected = "calibration_samples must be > 0")]
fn test_config_invalid_calibration_samples() {
let _config = PtBitnetConfig {
calibration_samples: 0,
..Default::default()
};
}
// ============================================================================
// 7. Edge Case Tests
// ============================================================================
#[test]
fn test_empty_input() {
let weights: Vec<f32> = vec![];
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
assert_eq!(tensor.num_elements(), 0);
assert_eq!(tensor.num_blocks(), 0);
assert_eq!(tensor.sparsity(), 0.0);
}
#[test]
fn test_single_element() {
let weights = vec![0.5];
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
assert_eq!(tensor.num_elements(), 1);
// Should create 1 block (padded)
assert_eq!(tensor.num_blocks(), 1);
}
#[test]
fn test_very_large_values() {
let weights = vec![f32::MAX, f32::MAX, f32::MAX, f32::MAX];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// Should all quantize to +1
assert!(ternary.iter().all(|&x| x == 1), "f32::MAX should quantize to +1");
// Scale should be approximately f32::MAX
assert!(scale > 1e30, "Scale should be very large");
// Dequantization should not produce NaN
let dequantized = dequantize_ternary(&ternary, scale);
assert!(
dequantized.iter().all(|&x| !x.is_nan()),
"Dequantization should not produce NaN"
);
}
#[test]
fn test_subnormal_floats() {
// Very small positive values (subnormal range)
let weights = vec![1e-40, -1e-40, 1e-39, -1e-39];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// Should quantize reasonably (may be all zeros or small values)
assert!(ternary.iter().all(|&x| x >= -1 && x <= 1));
// Scale should be tiny but not zero
assert!(scale > 0.0, "Scale should be > 0 even for subnormal inputs");
}
#[test]
fn test_nan_handling() {
// NaN should not crash, but behavior is implementation-defined
let weights = vec![f32::NAN, 1.0, -1.0, 0.0];
let result = std::panic::catch_unwind(|| {
quantize_absmean_with_scale(&weights)
});
// Should either panic or handle gracefully
// At minimum, should not produce infinite loop or segfault
if let Ok((ternary, scale)) = result {
// If it succeeds, output should not contain NaN
assert!(
!scale.is_nan() || scale == 0.0,
"Scale should not be NaN unless handled explicitly"
);
assert!(
ternary.iter().all(|&x| x >= -1 && x <= 1),
"Ternary values must be in valid range"
);
}
}
#[test]
fn test_infinity_handling() {
let weights = vec![f32::INFINITY, f32::NEG_INFINITY, 1.0, -1.0];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// Infinities should quantize to ±1
assert_eq!(ternary[0], 1, "INFINITY should quantize to +1");
assert_eq!(ternary[1], -1, "NEG_INFINITY should quantize to -1");
// Scale should be finite (or handled gracefully)
// Implementation may cap scale to avoid overflow
assert!(
scale.is_finite() || scale > 1e30,
"Scale should be finite or very large"
);
}
#[test]
fn test_mixed_magnitudes() {
// Mix of very large and very small values
let weights = vec![1000.0, 0.001, -1000.0, -0.001, 0.0];
let (ternary, scale) = quantize_absmean_with_scale(&weights);
// Should produce valid ternary values
assert!(ternary.iter().all(|&x| x >= -1 && x <= 1));
// Scale should be dominated by large values
assert!(scale > 100.0, "Scale should reflect large values");
// Small values should quantize to 0
assert_eq!(
ternary[1], 0,
"0.001 compared to scale ~500 should be 0"
);
assert_eq!(ternary[3], 0, "-0.001 should be 0");
}
// ============================================================================
// Helper Functions
// ============================================================================
/// Helper to quantize weights using absmean method
/// Returns both ternary values and scale factor
fn quantize_absmean_with_scale(weights: &[f32]) -> (Vec<i8>, f32) {
if weights.is_empty() {
return (vec![], 0.0);
}
// Compute absmean scale: gamma = mean(|W|) + epsilon
let absmean: f32 = weights.iter().map(|&w| w.abs()).sum::<f32>() / weights.len() as f32;
let scale = absmean + EPSILON;
// Quantize: W_ternary = RoundClip(W / scale, -1, 1)
let ternary: Vec<i8> = weights
.iter()
.map(|&w| {
let normalized = w / scale;
// Round and clip to {-1, 0, +1}
if normalized >= 0.5 {
1
} else if normalized <= -0.5 {
-1
} else {
0
}
})
.collect();
(ternary, scale)
}
/// Helper to quantize weights (scale not needed)
fn quantize_absmean(weights: &[f32]) -> Vec<i8> {
let (ternary, _scale) = quantize_absmean_with_scale(weights);
ternary
}
/// Helper to dequantize ternary values
fn dequantize_ternary(ternary: &[i8], scale: f32) -> Vec<f32> {
ternary.iter().map(|&t| (t as f32) * scale).collect()
}

View file

@ -29,6 +29,7 @@
//! | IQ4_NL | 4.5 | 32 | i-quant 4-bit non-linear |
use crate::error::{Result, RuvLLMError};
use crate::bitnet::dequantize_bitnet_t158;
// ============================================================================
// Quantization Types
@ -100,6 +101,8 @@ pub enum GgufQuantType {
F64 = 28,
/// BF16 brain float
Bf16 = 29,
/// BitNet b1.58 ternary quantization (2-bit packed)
BitnetT158 = 30,
}
impl TryFrom<u32> for GgufQuantType {
@ -137,6 +140,7 @@ impl TryFrom<u32> for GgufQuantType {
27 => Ok(Self::I64),
28 => Ok(Self::F64),
29 => Ok(Self::Bf16),
30 => Ok(Self::BitnetT158),
_ => Err(RuvLLMError::Model(format!(
"Unknown GGUF quantization type: {}",
value
@ -163,6 +167,7 @@ impl GgufQuantType {
Self::IQ1_S => 256,
Self::IQ4_NL => 32,
Self::IQ4_XS => 256,
Self::BitnetT158 => 256,
}
}
@ -214,6 +219,8 @@ impl GgufQuantType {
Self::IQ1_S => 50,
Self::IQ4_NL => 18,
Self::IQ4_XS => 136,
// BitNet b1.58: 256 elements -> 64 bytes (2-bit packed) + 2 bytes (FP16 scale) = 66 bytes
Self::BitnetT158 => 66,
}
}
@ -280,6 +287,7 @@ impl GgufQuantType {
Self::IQ1_S => "IQ1_S",
Self::IQ4_NL => "IQ4_NL",
Self::IQ4_XS => "IQ4_XS",
Self::BitnetT158 => "BITNET_T158",
}
}
}
@ -355,6 +363,14 @@ pub fn dequantize_tensor(
GgufQuantType::Q5_K => dequantize_q5_k(data, &mut output),
GgufQuantType::Q6_K => dequantize_q6_k(data, &mut output),
GgufQuantType::IQ4_NL => dequantize_iq4_nl(data, &mut output),
GgufQuantType::BitnetT158 => dequantize_bitnet_t158_wrapper(data, &mut output),
GgufQuantType::IQ1_S => {
return Err(RuvLLMError::Model(
"IQ1_S dequantization requires codebook lookup tables (not yet implemented). \
For BitNet ternary quantization, use BITNET_T158 type instead."
.to_string(),
));
}
_ => {
return Err(RuvLLMError::Model(format!(
"Dequantization not implemented for {:?}",

View file

@ -44,6 +44,7 @@
pub mod adapter_manager;
pub mod autodetect;
pub mod backends;
pub mod bitnet;
pub mod capabilities;
pub mod claude_flow;
pub mod context;

View file

@ -0,0 +1,999 @@
# PT-BitNet Quantizer Module Architecture Design
**Version:** 1.0
**Date:** 2026-02-03
**Status:** Design Specification
**Relates to:** ADR-017 (AD-1, AD-5, AD-18, AD-19), DDD Section 3.4/4.2/4.3
---
## Executive Summary
This document specifies the architecture for the **PT-BitNet post-training quantizer** module that converts FP16/BF16 GLM-4.7-Flash weights to BitNet b1.58 ternary {-1, 0, +1} format via absmean quantization. This is a **design-only specification** — implementation follows in Phase 0.
**Design Scope:**
- Module layout and file organization
- Complete struct definitions with field types
- Full function signatures (no implementations)
- GGUF integration points and format extensions
- Error handling strategy
- Testing approach
**Out of Scope:**
- Actual implementation code
- Performance benchmarks
- Calibration dataset selection
---
## A. Module Layout
### Directory Structure
```
crates/ruvllm/src/
├── bitnet/ # NEW module
│ ├── mod.rs # Module exports and public API
│ ├── quantizer.rs # PtBitnetQuantizer + absmean algorithm
│ ├── ternary_tensor.rs # TernaryTensor value object
│ ├── dequantize.rs # BITNET_T158 dequantization kernel
│ └── config.rs # PtBitnetConfig configuration
├── gguf/
│ ├── mod.rs # Add pub mod bitnet export
│ ├── quantization.rs # MODIFIED: Add BITNET_T158 enum variant
│ ├── parser.rs # Unchanged (reused as-is)
│ └── ...
└── kernels/
└── matmul.rs # Reference for dispatch patterns
```
### Modified Files
#### `src/gguf/quantization.rs`
**Changes:**
1. Add `BITNET_T158 = 30` variant to `GgufQuantType` enum (after `Bf16 = 29`)
2. Update `try_from()` impl to handle type 30
3. Update `block_size()` to return 256 for `BITNET_T158`
4. Update `type_size()` to return 66 for `BITNET_T158` (64 bytes packed + 2 bytes FP16 scale)
5. Update `is_quantized()` to include `BITNET_T158`
6. Update `bits_per_weight()` to return 2.06 for `BITNET_T158`
7. Add new match arm in `dequantize_tensor()``BITNET_T158 => dequantize_bitnet_t158(data, output)`
**Exact enum addition:**
```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum GgufQuantType {
// ... existing variants 0-29 ...
/// BitNet b1.58 ternary quantization (2-bit packed + FP16 scale per 256-element block)
BITNET_T158 = 30,
}
```
---
## B. Struct Definitions
### 1. `PtBitnetConfig` (in `bitnet/config.rs`)
**Purpose:** Configuration for PT-BitNet quantization process
```rust
/// Configuration for PT-BitNet post-training quantization
#[derive(Debug, Clone)]
pub struct PtBitnetConfig {
/// Block size for absmean scale computation (default: 256)
pub block_size: usize,
/// Epsilon for numerical stability in scale computation (default: 1e-8)
pub epsilon: f32,
/// Whether to run calibration pass to optimize scale factors
pub use_calibration: bool,
/// Number of calibration samples (if use_calibration = true)
pub calibration_samples: usize,
/// Maximum sequence length for calibration (default: 2048)
pub calibration_max_seq_len: usize,
/// Device for calibration pass ("cpu", "metal", "cuda:0")
pub calibration_device: String,
/// Clipping threshold for normalized weights before rounding
/// (default: 1.0, range typically 0.95-1.05)
pub clip_threshold: f32,
/// Sparsity target: if > 0.0, bias rounding toward zero to achieve target sparsity
pub target_sparsity: Option<f32>,
}
impl Default for PtBitnetConfig {
fn default() -> Self {
Self {
block_size: 256,
epsilon: 1e-8,
use_calibration: false,
calibration_samples: 1000,
calibration_max_seq_len: 2048,
calibration_device: "metal".to_string(),
clip_threshold: 1.0,
target_sparsity: None,
}
}
}
```
### 2. `TernaryTensor` (in `bitnet/ternary_tensor.rs`)
**Purpose:** Immutable value object for packed ternary weights
```rust
/// Packed ternary tensor with per-block FP16 scales
#[derive(Debug, Clone)]
pub struct TernaryTensor {
/// Packed 2-bit ternary values (4 weights per byte)
/// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved
pub packed_data: Vec<u8>,
/// Per-block FP16 scale factors (absmean values)
pub scales: Vec<f16>,
/// Tensor shape [out_features, in_features] or [rows, cols]
pub shape: [usize; 2],
/// Block size (always 256 for BitNet b1.58)
pub block_size: usize,
/// Total number of weights
pub num_elements: usize,
/// Number of blocks
pub num_blocks: usize,
/// Measured sparsity (fraction of zero weights)
pub sparsity: f32,
}
impl TernaryTensor {
/// Calculate total storage size in bytes
pub fn storage_size(&self) -> usize;
/// Get expected packed_data size for validation
pub fn expected_packed_size(&self) -> usize;
/// Validate internal consistency
pub fn validate(&self) -> Result<()>;
}
```
### 3. `TernaryBlock` (in `bitnet/ternary_tensor.rs`)
**Purpose:** Single block of 256 ternary weights with scale
```rust
/// A single 256-element block with ternary weights and FP16 scale
#[derive(Debug, Clone)]
pub struct TernaryBlock {
/// 64 bytes of packed 2-bit values (256 weights × 2 bits ÷ 8 bits/byte)
pub packed: [u8; 64],
/// FP16 absmean scale factor
pub scale: f16,
}
impl TernaryBlock {
/// Size in bytes when stored in GGUF (64 + 2 = 66)
pub const STORAGE_SIZE: usize = 66;
/// Number of elements in a block
pub const BLOCK_SIZE: usize = 256;
}
```
### 4. `AbsmeanResult` (in `bitnet/quantizer.rs`)
**Purpose:** Result of absmean quantization on a single block
```rust
/// Result of absmean ternary quantization on a block
#[derive(Debug, Clone)]
pub struct AbsmeanResult {
/// Ternary values {-1, 0, +1} for each weight in the block
pub ternary_weights: Vec<i8>,
/// Computed absmean scale factor (gamma = mean(|W|))
pub scale: f32,
/// Measured sparsity (fraction of zeros)
pub sparsity: f32,
/// Mean squared error vs original FP16 values (for calibration)
pub mse: f32,
}
```
### 5. `QuantizationStats` (in `bitnet/quantizer.rs`)
**Purpose:** Statistics collected during quantization
```rust
/// Statistics from quantizing a single tensor
#[derive(Debug, Clone)]
pub struct QuantizationStats {
/// Tensor name
pub name: String,
/// Mean of all block scales
pub mean_scale: f32,
/// Std dev of block scales
pub std_scale: f32,
/// Overall sparsity across all blocks
pub sparsity: f32,
/// Mean MSE across all blocks
pub mean_mse: f32,
/// Number of blocks
pub num_blocks: usize,
}
```
---
## C. Function Signatures
### Core Quantization Functions (in `bitnet/quantizer.rs`)
#### 1. Primary Quantization Entry Point
```rust
/// Quantize an FP16/F32 tensor to ternary format using absmean quantization
///
/// # Arguments
/// * `tensor` - Input FP16 or F32 tensor data (flat vector)
/// * `shape` - Tensor shape [out_features, in_features]
/// * `config` - Quantization configuration
///
/// # Returns
/// * `TernaryTensor` - Packed ternary representation
/// * `QuantizationStats` - Statistics about the quantization process
///
/// # Errors
/// * `RuvLLMError::Quantization` if tensor size is not divisible by block_size
/// * `RuvLLMError::Quantization` if shape product doesn't match tensor length
pub fn quantize_tensor(
tensor: &[f32],
shape: [usize; 2],
config: &PtBitnetConfig,
) -> Result<(TernaryTensor, QuantizationStats)>;
```
#### 2. Per-Block Quantization
```rust
/// Apply absmean quantization to a single block of weights
///
/// Algorithm:
/// 1. gamma = mean(|block|) + epsilon
/// 2. normalized = block / gamma
/// 3. ternary = round(clamp(normalized, -clip_threshold, +clip_threshold))
/// 4. Map to {-1, 0, +1}
///
/// # Arguments
/// * `block` - Block of FP16/F32 values (length = config.block_size)
/// * `config` - Configuration with epsilon and clip_threshold
///
/// # Returns
/// * `AbsmeanResult` with ternary values, scale, sparsity, MSE
///
/// # Panics
/// * If block.len() != config.block_size
pub fn absmean_ternary(
block: &[f32],
config: &PtBitnetConfig,
) -> AbsmeanResult;
```
#### 3. Packing Functions
```rust
/// Pack ternary {-1, 0, +1} values into 2-bit representation
///
/// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved (unused)
/// 4 values packed per byte: [v3 v2 v1 v0] → byte
///
/// # Arguments
/// * `values` - Ternary values (must be {-1, 0, +1} only)
///
/// # Returns
/// * Packed bytes (length = ceil(values.len() / 4))
///
/// # Errors
/// * If any value is not in {-1, 0, +1}
pub fn pack_ternary(values: &[i8]) -> Result<Vec<u8>>;
/// Unpack 2-bit representation to ternary {-1, 0, +1} values
///
/// # Arguments
/// * `packed` - Packed 2-bit data
/// * `n` - Number of values to extract
///
/// # Returns
/// * Vector of ternary values (length = n)
pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec<i8>;
```
#### 4. Calibration (Optional)
```rust
/// Run calibration pass to optimize scale factors
///
/// # Arguments
/// * `tensor` - Input FP16 tensor
/// * `shape` - Tensor shape
/// * `config` - Config with calibration settings
/// * `calibration_data` - Sample activations for this layer
///
/// # Returns
/// * Optimized `TernaryTensor` with calibrated scales
///
/// # Note
/// This is optional - if not used, falls back to plain absmean
pub fn quantize_with_calibration(
tensor: &[f32],
shape: [usize; 2],
config: &PtBitnetConfig,
calibration_data: &[Vec<f32>],
) -> Result<(TernaryTensor, QuantizationStats)>;
```
### Dequantization Functions (in `bitnet/dequantize.rs`)
```rust
/// Dequantize BITNET_T158 tensor to FP32
///
/// # Arguments
/// * `data` - Raw GGUF tensor bytes (packed ternary + scales)
/// * `scales` - Per-block FP16 scales (extracted from data)
/// * `n` - Total number of elements to dequantize
///
/// # Returns
/// * Vec<f32> of dequantized values
///
/// # Format
/// Each block: [64 bytes packed ternary][2 bytes FP16 scale]
pub fn dequantize_bitnet_t158(
data: &[u8],
scales: &[f16],
n: usize,
) -> Vec<f32>;
/// Dequantize a single BITNET_T158 block
///
/// # Arguments
/// * `block_data` - 64 bytes of packed ternary data
/// * `scale` - FP16 scale factor
/// * `output` - Output buffer (must have capacity for 256 elements)
pub fn dequantize_bitnet_t158_block(
block_data: &[u8; 64],
scale: f16,
output: &mut [f32],
);
```
### Tensor Conversion (in `bitnet/ternary_tensor.rs`)
```rust
impl TernaryTensor {
/// Convert from packed storage to FP32 (for validation/testing)
pub fn to_fp32(&self) -> Vec<f32>;
/// Create from existing GGUF tensor data
pub fn from_gguf_data(
data: &[u8],
shape: [usize; 2],
block_size: usize,
) -> Result<Self>;
/// Serialize to GGUF tensor bytes
pub fn to_gguf_data(&self) -> Vec<u8>;
}
```
---
## D. GGUF Integration Points
### 1. New Quantization Type Variant
**File:** `crates/ruvllm/src/gguf/quantization.rs`
**Changes to `GgufQuantType` enum:**
```rust
#[repr(u32)]
pub enum GgufQuantType {
// ... existing 0-29 ...
/// BitNet b1.58 ternary quantization
/// Block size: 256 elements
/// Storage: 64 bytes packed (2-bit) + 2 bytes FP16 scale = 66 bytes/block
/// Bits per weight: 2.06 bpw
BITNET_T158 = 30,
}
impl GgufQuantType {
pub fn block_size(&self) -> usize {
match self {
// ... existing cases ...
Self::BITNET_T158 => 256,
}
}
pub fn type_size(&self) -> usize {
match self {
// ... existing cases ...
Self::BITNET_T158 => 66, // 64 + 2
}
}
pub fn name(&self) -> &'static str {
match self {
// ... existing cases ...
Self::BITNET_T158 => "BITNET_T158",
}
}
}
impl TryFrom<u32> for GgufQuantType {
fn try_from(value: u32) -> Result<Self> {
match value {
// ... existing 0-29 ...
30 => Ok(Self::BITNET_T158),
_ => Err(/* ... */),
}
}
}
```
### 2. Dequantization Dispatch
**File:** `crates/ruvllm/src/gguf/quantization.rs`
**Modification to `dequantize_tensor()` function:**
```rust
pub fn dequantize_tensor(
data: &[u8],
dtype: GgufQuantType,
num_elements: usize,
) -> Result<Vec<f32>> {
let mut output = vec![0.0f32; num_elements];
match dtype {
// ... existing cases ...
GgufQuantType::BITNET_T158 => {
// Extract scales and packed data
let num_blocks = (num_elements + 255) / 256;
let mut scales = Vec::with_capacity(num_blocks);
for i in 0..num_blocks {
let block_offset = i * 66;
let scale_offset = block_offset + 64;
let scale_bytes = [data[scale_offset], data[scale_offset + 1]];
scales.push(f16::from_le_bytes(scale_bytes));
}
crate::bitnet::dequantize::dequantize_bitnet_t158(
data,
&scales,
num_elements,
);
}
_ => {
return Err(RuvLLMError::Model(format!(
"Dequantization not implemented for {:?}",
dtype
)));
}
}
Ok(output)
}
```
### 3. GGUF Metadata Keys
**New metadata keys for BitNet models** (written during quantization, read during load):
```rust
// In quantizer when exporting GGUF
pub const BITNET_METADATA_KEYS: &[(&str, &str)] = &[
("craftsman.bitnet.version", "1"),
("craftsman.bitnet.weight_encoding", "absmean_ternary"),
("craftsman.bitnet.activation_bits", "8"),
("craftsman.bitnet.block_size", "256"),
("craftsman.bitnet.kernel_hint", "tl1"), // or "tl2", "i2s"
];
```
**Metadata reading in model loader:**
```rust
// In backend when loading model
fn detect_bitnet_model(metadata: &HashMap<String, GgufValue>) -> bool {
metadata.get("craftsman.bitnet.version")
.and_then(|v| v.as_str())
.map(|v| v == "1")
.unwrap_or(false)
}
```
### 4. Tensor Info Extension
**No changes needed** - existing `TensorInfo` struct in `parser.rs` already supports:
- `name: String`
- `shape: Vec<usize>`
- `dtype: GgufQuantType` ← Will now include `BITNET_T158`
- `offset: u64`
---
## E. Error Handling Strategy
### Error Types
All errors use existing `RuvLLMError` enum from `crates/ruvllm/src/error.rs`:
```rust
pub enum RuvLLMError {
// Existing variants...
// Quantization-specific errors
Quantization(String), // Use this variant for all quantization errors
Model(String), // For GGUF format issues
Config(String), // For invalid configuration
}
```
### Error Scenarios and Handling
| Scenario | Error Type | Recovery Strategy |
|----------|-----------|-------------------|
| Tensor size not divisible by block_size | `Quantization` | Pad last block with zeros |
| Invalid ternary value during packing | `Quantization` | Fail-fast - indicates bug |
| GGUF file has wrong BITNET_T158 block size | `Model` | Fail-fast - corrupted file |
| Calibration device unavailable | `Config` | Fall back to non-calibrated quantization |
| Out of memory during quantization | System panic | Let Rust OOM handler catch |
| Shape mismatch in tensor | `Quantization` | Fail-fast - validate before processing |
| FP16 scale is NaN/Inf | `Quantization` | Clamp to epsilon value |
| Empty tensor / zero elements | `Quantization` | Skip with warning |
### Validation Functions
```rust
/// Validate quantization config
pub fn validate_config(config: &PtBitnetConfig) -> Result<()> {
if config.block_size == 0 || config.block_size % 4 != 0 {
return Err(RuvLLMError::Config(
"block_size must be non-zero and divisible by 4".into()
));
}
if config.epsilon <= 0.0 {
return Err(RuvLLMError::Config(
"epsilon must be positive".into()
));
}
if config.clip_threshold <= 0.0 || config.clip_threshold > 2.0 {
return Err(RuvLLMError::Config(
"clip_threshold must be in range (0.0, 2.0]".into()
));
}
Ok(())
}
/// Validate tensor shape and size
pub fn validate_tensor(
tensor: &[f32],
shape: [usize; 2],
block_size: usize,
) -> Result<()> {
let expected_size = shape[0] * shape[1];
if tensor.len() != expected_size {
return Err(RuvLLMError::Quantization(format!(
"Tensor length {} doesn't match shape {:?} (expected {})",
tensor.len(), shape, expected_size
)));
}
if expected_size % block_size != 0 {
// Could pad, but for simplicity require exact multiple
return Err(RuvLLMError::Quantization(format!(
"Tensor size {} is not divisible by block_size {}",
expected_size, block_size
)));
}
Ok(())
}
```
---
## F. Testing Strategy
### Unit Tests
#### 1. Absmean Quantization Correctness
**File:** `crates/ruvllm/src/bitnet/tests/quantizer_tests.rs`
```rust
#[test]
fn test_absmean_ternary_basic() {
// Test that absmean correctly quantizes known values
let config = PtBitnetConfig::default();
// Block with known mean(|x|) = 1.0
let block: Vec<f32> = vec![
2.0, -2.0, 1.0, -1.0, // gamma = mean(2,2,1,1,...) ≈ 1.0
0.5, -0.5, 0.0, 0.0,
// ... (pad to 256 elements)
];
let result = absmean_ternary(&block, &config);
// After normalization: 2.0/1.0 = 2.0 → clamp to 1.0 → round to +1
assert_eq!(result.ternary_weights[0], 1); // 2.0 → +1
assert_eq!(result.ternary_weights[1], -1); // -2.0 → -1
assert_eq!(result.ternary_weights[2], 1); // 1.0 → +1
assert_eq!(result.ternary_weights[6], 0); // 0.0 → 0
assert!(result.scale > 0.9 && result.scale < 1.1); // gamma 1.0
}
#[test]
fn test_absmean_all_zeros() {
let config = PtBitnetConfig::default();
let block = vec![0.0; 256];
let result = absmean_ternary(&block, &config);
// All zeros → scale = epsilon, all ternary = 0
assert_eq!(result.scale, config.epsilon);
assert!(result.ternary_weights.iter().all(|&x| x == 0));
assert_eq!(result.sparsity, 1.0);
}
```
#### 2. Pack/Unpack Round-Trip
```rust
#[test]
fn test_pack_unpack_roundtrip() {
let original = vec![1i8, -1, 0, 1, 0, -1, 1, 0];
let packed = pack_ternary(&original).unwrap();
assert_eq!(packed.len(), 2); // 8 values → 2 bytes
let unpacked = unpack_ternary(&packed, 8);
assert_eq!(unpacked, original);
}
#[test]
fn test_pack_invalid_value() {
let invalid = vec![1i8, 2, 0]; // 2 is not ternary
let result = pack_ternary(&invalid);
assert!(result.is_err());
}
```
#### 3. Tensor Validation
```rust
#[test]
fn test_validate_tensor_shape_mismatch() {
let tensor = vec![1.0; 100];
let shape = [10, 11]; // 10*11 = 110 ≠ 100
let result = validate_tensor(&tensor, shape, 256);
assert!(result.is_err());
}
#[test]
fn test_validate_tensor_block_alignment() {
let tensor = vec![1.0; 257]; // Not divisible by 256
let shape = [1, 257];
let result = validate_tensor(&tensor, shape, 256);
assert!(result.is_err());
}
```
### Integration Tests
#### 4. Full Quantization Pipeline
```rust
#[test]
fn test_quantize_tensor_full_pipeline() {
let config = PtBitnetConfig::default();
// Create a 512-element tensor (2 blocks)
let tensor: Vec<f32> = (0..512).map(|i| (i as f32) / 512.0).collect();
let shape = [2, 256];
let (ternary, stats) = quantize_tensor(&tensor, shape, &config).unwrap();
assert_eq!(ternary.num_blocks, 2);
assert_eq!(ternary.packed_data.len(), 2 * 64); // 2 blocks × 64 bytes
assert_eq!(ternary.scales.len(), 2);
assert_eq!(stats.num_blocks, 2);
// Verify reconstruction quality
let reconstructed = ternary.to_fp32();
assert_eq!(reconstructed.len(), 512);
}
```
#### 5. GGUF Round-Trip
```rust
#[test]
fn test_gguf_serialization_roundtrip() {
let config = PtBitnetConfig::default();
let tensor = vec![1.0; 256];
let shape = [1, 256];
let (ternary, _) = quantize_tensor(&tensor, shape, &config).unwrap();
// Serialize to GGUF format
let gguf_data = ternary.to_gguf_data();
assert_eq!(gguf_data.len(), 66); // 1 block = 66 bytes
// Deserialize
let recovered = TernaryTensor::from_gguf_data(&gguf_data, shape, 256).unwrap();
assert_eq!(recovered.packed_data, ternary.packed_data);
assert_eq!(recovered.scales, ternary.scales);
}
```
### Benchmark Tests
#### 6. Performance Regression
```rust
#[bench]
fn bench_absmean_ternary_256(b: &mut Bencher) {
let config = PtBitnetConfig::default();
let block: Vec<f32> = (0..256).map(|i| (i as f32) / 256.0).collect();
b.iter(|| {
let _ = absmean_ternary(&block, &config);
});
}
#[bench]
fn bench_pack_ternary_1024(b: &mut Bencher) {
let values = vec![1i8; 1024];
b.iter(|| {
let _ = pack_ternary(&values);
});
}
```
### Correctness Validation Tests
#### 7. Bit-Exact Validation Against Reference
```rust
#[test]
fn test_dequantize_matches_reference() {
// Reference implementation (naive)
fn reference_dequant(ternary: &[i8], scale: f32) -> Vec<f32> {
ternary.iter().map(|&t| (t as f32) * scale).collect()
}
let config = PtBitnetConfig::default();
let tensor = vec![1.5, -2.3, 0.1, -0.4]; // Extend to 256
let tensor_256 = /* pad to 256 */;
let shape = [1, 256];
let (ternary, _) = quantize_tensor(&tensor_256, shape, &config).unwrap();
// Unpack and dequantize
let unpacked = unpack_ternary(&ternary.packed_data, 256);
let reference = reference_dequant(&unpacked, ternary.scales[0].to_f32());
let optimized = ternary.to_fp32();
// Allow small floating-point error
for (r, o) in reference.iter().zip(optimized.iter()) {
assert!((r - o).abs() < 1e-5);
}
}
```
### Test Organization
```
crates/ruvllm/src/bitnet/tests/
├── quantizer_tests.rs # absmean, pack/unpack
├── tensor_tests.rs # TernaryTensor validation
├── dequantize_tests.rs # BITNET_T158 dequant
├── integration_tests.rs # Full pipeline, GGUF round-trip
└── benches.rs # Performance benchmarks
```
---
## G. Implementation Phases
### Phase 0.1: Core Data Structures (~2-3 days)
1. `bitnet/mod.rs` - module structure
2. `bitnet/config.rs` - `PtBitnetConfig`
3. `bitnet/ternary_tensor.rs` - `TernaryTensor`, `TernaryBlock`
4. Unit tests for validation
### Phase 0.2: Quantization Algorithm (~3-4 days)
1. `bitnet/quantizer.rs` - `absmean_ternary()`
2. Pack/unpack functions
3. `quantize_tensor()` main entry point
4. Unit tests for correctness
### Phase 0.3: Dequantization (~2 days)
1. `bitnet/dequantize.rs` - block and tensor dequant
2. Integration with existing `quantization.rs`
3. Round-trip tests
### Phase 0.4: GGUF Integration (~2-3 days)
1. Modify `gguf/quantization.rs` - add `BITNET_T158` enum variant
2. Add metadata keys
3. GGUF serialization/deserialization
4. Integration tests
### Phase 0.5: Validation & Benchmarks (~2 days)
1. Full pipeline integration tests
2. Performance benchmarks
3. Bit-exact validation
4. Documentation
**Total Estimated Effort:** ~13-16 days for clean, well-tested implementation
---
## H. Open Design Questions
| # | Question | Impact | Recommendation |
|---|----------|--------|----------------|
| 1 | Use `IQ1_S` (type 19) or new `BITNET_T158` (type 30)? | Compatibility | **New type 30** - cleaner separation, avoids confusion with IQ1_S's codebook format |
| 2 | Padding strategy for last block if not aligned? | Correctness | **Zero-pad** - simplest, matches BitNet spec |
| 3 | Should calibration be mandatory or optional? | Quality vs Speed | **Optional** - Phase 0 can work without it, add later if needed |
| 4 | F16 or F32 for internal scale computation? | Precision | **F32 internally, store as F16** - extra precision during compute |
| 5 | Handle NaN/Inf in input tensors? | Robustness | **Fail-fast** - corrupted weights should not be silently ignored |
| 6 | Support block sizes other than 256? | Flexibility | **No** - BitNet spec is 256, simplifies code |
| 7 | Multi-threading for per-block quantization? | Performance | **Not in Phase 0** - can add via rayon later |
| 8 | Store sparsity per-block in GGUF? | Kernel optimization | **No** - compute on-the-fly during dequant, saves space |
---
## I. Dependencies and Prerequisites
### Existing RuvLLM Components (Reused)
- `crates/ruvllm/src/error.rs` - `RuvLLMError` enum
- `crates/ruvllm/src/gguf/parser.rs` - GGUF parsing (unchanged)
- `crates/ruvllm/src/gguf/quantization.rs` - Enum + dispatch (modified)
- `half` crate - FP16 support (already in Cargo.toml)
### New External Dependencies
None - uses only existing dependencies
### Minimum Rust Version
Same as RuvLLM (likely 1.70+)
---
## J. Non-Goals (Out of Scope)
1. **Calibration implementation** - Deferred to future phase
2. **TL1/TL2 kernel implementation** - Separate ADR/DDD
3. **Model loader integration** - Separate backend implementation
4. **Performance optimization** - Phase 0 is correctness-first
5. **WASM support** - Desktop/server only for Phase 0
6. **Dynamic quantization** - Only post-training static
7. **Mixed-precision strategies** - All-or-nothing ternary for Phase 0
---
## K. Success Criteria
**This design is complete when:**
1. All struct definitions have complete field specifications
2. All function signatures are documented with arguments, returns, errors
3. Module organization is clear and follows Rust conventions
4. GGUF integration points are precisely specified
5. Error handling covers all failure modes
6. Test plan covers correctness, integration, and performance
7. Implementation phases are realistic and sequenced
8. Open questions are documented with recommendations
**Implementation is successful when:**
1. All unit tests pass
2. Round-trip GGUF serialization is bit-exact
3. Dequantization produces correct FP32 output
4. Integration with existing GGUF pipeline works
5. Quantization of GLM-4.7-Flash completes without errors
6. Exported GGUF file is loadable by model loader
---
## Appendix A: Code Size Estimates
| File | Estimated Lines | Complexity |
|------|----------------|------------|
| `bitnet/mod.rs` | ~50 | Low |
| `bitnet/config.rs` | ~80 | Low |
| `bitnet/ternary_tensor.rs` | ~200 | Medium |
| `bitnet/quantizer.rs` | ~350 | High |
| `bitnet/dequantize.rs` | ~150 | Medium |
| `gguf/quantization.rs` (changes) | ~100 | Low |
| Tests | ~800 | Medium |
| **Total** | **~1,730 lines** | |
**Comparison to ADR-018 estimate:** ~200-300 lines core quantizer → Actual ~350 lines (reasonable given struct overhead)
---
## Appendix B: Memory Layout Examples
### TernaryBlock Storage (66 bytes)
```
Byte Offset | Content
------------|--------
0-63 | Packed 2-bit ternary (256 values)
64-65 | FP16 scale (little-endian)
```
### 2-Bit Packing Example
```
Values: [+1, -1, 0, +1]
Encoding: [10, 00, 01, 10]
Packed byte: 10_00_01_10 = 0x86
```
### GGUF Tensor Data Layout
```
[TensorInfo] (in header)
name: "model.layers.0.mlp.gate_proj.weight"
shape: [4096, 11008]
dtype: BITNET_T158 (30)
offset: 0x1000
[Tensor Data] (at offset 0x1000)
Block 0: [64 bytes packed][2 bytes scale]
Block 1: [64 bytes packed][2 bytes scale]
...
Block N: [64 bytes packed][2 bytes scale]
```
---
**End of Design Document**