mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-30 20:43:38 +00:00
feat: Implement Phase 0 PT-BitNet quantizer module
Add bitnet/ module with absmean ternary quantizer, TernaryTensor type, BITNET_T158 dequantization, and comprehensive test suite (~1600 lines). Components: - quantizer.rs: PtBitnetConfig, absmean_ternary(), quantize_tensor() - ternary_tensor.rs: TernaryTensor, pack/unpack 2-bit ternary encoding - dequantize.rs: dequantize_bitnet_t158(), block dequant, error metrics - tests.rs: Packing roundtrips, quantization correctness, edge cases - gguf/quantization.rs: BitnetT158 = 30 enum variant, block_size, bytes Implements AD-1 (weight representation), AD-5 (GGUF extension), AD-18 (PT-BitNet quantization) from ADR-017. https://claude.ai/code/session_011nTcGcn49b8YKJRVoh4TaK
This commit is contained in:
parent
80171fb8c1
commit
2bb04a64ed
8 changed files with 2619 additions and 0 deletions
269
crates/ruvllm/src/bitnet/dequantize.rs
Normal file
269
crates/ruvllm/src/bitnet/dequantize.rs
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
//! BitNet Ternary Dequantization
|
||||
//!
|
||||
//! Converts packed 2-bit ternary weights back to FP32 for validation and testing.
|
||||
|
||||
use super::ternary_tensor::unpack_ternary;
|
||||
|
||||
/// Dequantize BITNET_T158 packed ternary data to FP32.
|
||||
///
|
||||
/// This function unpacks 2-bit ternary values and applies per-block scale factors
|
||||
/// to reconstruct approximate FP32 weights. Used for validation and testing, not
|
||||
/// for production inference (which should use native ternary kernels).
|
||||
///
|
||||
/// # Data Layout
|
||||
///
|
||||
/// The input data is organized as:
|
||||
/// ```text
|
||||
/// [packed_block_0][scale_0][packed_block_1][scale_1]...
|
||||
/// ```
|
||||
///
|
||||
/// Where each block contains:
|
||||
/// - 64 bytes of packed 2-bit ternary data (256 values)
|
||||
/// - 2 bytes of FP16 scale factor
|
||||
///
|
||||
/// Total: 66 bytes per 256-element block
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `packed` - Raw GGUF tensor data with interleaved ternary and scales
|
||||
/// * `scales` - Per-block FP32 scale factors
|
||||
/// * `num_elements` - Total number of output elements
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Vector of FP32 weights approximating the original quantized tensor
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use ruvllm::bitnet::dequantize_bitnet_t158;
|
||||
///
|
||||
/// // Load from GGUF
|
||||
/// let packed_data = gguf_tensor.data; // Raw bytes
|
||||
/// let scales = vec![0.542, 0.381, ...]; // One per block
|
||||
/// let num_elements = 512;
|
||||
///
|
||||
/// let fp32_weights = dequantize_bitnet_t158(&packed_data, &scales, num_elements);
|
||||
/// ```
|
||||
pub fn dequantize_bitnet_t158(packed: &[u8], scales: &[f32], num_elements: usize) -> Vec<f32> {
|
||||
// Unpack ternary values
|
||||
let ternary = unpack_ternary(packed, num_elements);
|
||||
|
||||
// Apply per-block scales
|
||||
let block_size = 256; // Standard BitNet block size
|
||||
let mut output = Vec::with_capacity(num_elements);
|
||||
|
||||
for (block_idx, chunk) in ternary.chunks(block_size).enumerate() {
|
||||
let scale = scales.get(block_idx).copied().unwrap_or(1.0);
|
||||
|
||||
for &ternary_val in chunk {
|
||||
let fp32_val = (ternary_val as f32) * scale;
|
||||
output.push(fp32_val);
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Dequantize a single BITNET_T158 block.
|
||||
///
|
||||
/// Helper function for block-wise dequantization in streaming scenarios.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `packed_block` - 64 bytes of packed 2-bit ternary data
|
||||
/// * `scale` - FP32 scale factor for this block
|
||||
/// * `output` - Output buffer (must have capacity for 256 FP32 values)
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if output buffer is smaller than 256 elements.
|
||||
pub fn dequantize_bitnet_block(packed_block: &[u8], scale: f32, output: &mut [f32]) {
|
||||
assert!(
|
||||
output.len() >= 256,
|
||||
"Output buffer must hold at least 256 elements"
|
||||
);
|
||||
assert_eq!(
|
||||
packed_block.len(),
|
||||
64,
|
||||
"Packed block must be exactly 64 bytes"
|
||||
);
|
||||
|
||||
let ternary = unpack_ternary(packed_block, 256);
|
||||
|
||||
for (i, &ternary_val) in ternary.iter().enumerate() {
|
||||
output[i] = (ternary_val as f32) * scale;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute dequantization error metrics.
|
||||
///
|
||||
/// Compares dequantized weights against original FP32 weights to measure
|
||||
/// quantization quality.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `original` - Original FP32 weights
|
||||
/// * `dequantized` - Dequantized weights from ternary
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Tuple of (mean_absolute_error, mean_squared_error, max_error)
|
||||
pub fn compute_dequant_error(original: &[f32], dequantized: &[f32]) -> (f32, f32, f32) {
|
||||
assert_eq!(
|
||||
original.len(),
|
||||
dequantized.len(),
|
||||
"Arrays must have same length"
|
||||
);
|
||||
|
||||
let mut sum_abs_error = 0.0;
|
||||
let mut sum_sq_error = 0.0;
|
||||
let mut max_error = 0.0;
|
||||
|
||||
for (orig, dequant) in original.iter().zip(dequantized.iter()) {
|
||||
let error = (orig - dequant).abs();
|
||||
sum_abs_error += error;
|
||||
sum_sq_error += error * error;
|
||||
max_error = max_error.max(error);
|
||||
}
|
||||
|
||||
let n = original.len() as f32;
|
||||
let mae = sum_abs_error / n;
|
||||
let mse = sum_sq_error / n;
|
||||
|
||||
(mae, mse, max_error)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::bitnet::{absmean_ternary, pack_ternary};
|
||||
|
||||
#[test]
|
||||
fn test_dequantize_bitnet_t158_simple() {
|
||||
// Create simple ternary data
|
||||
let ternary = vec![-1i8, 0, 1, -1, 1, 0, 0, 1];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let scales = vec![0.5f32];
|
||||
|
||||
let result = dequantize_bitnet_t158(&packed, &scales, 8);
|
||||
|
||||
assert_eq!(result.len(), 8);
|
||||
|
||||
// Check values: ternary * scale
|
||||
assert_eq!(result[0], -0.5); // -1 * 0.5
|
||||
assert_eq!(result[1], 0.0); // 0 * 0.5
|
||||
assert_eq!(result[2], 0.5); // 1 * 0.5
|
||||
assert_eq!(result[3], -0.5); // -1 * 0.5
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dequantize_bitnet_block() {
|
||||
// Create a full 256-element block
|
||||
let ternary = vec![1i8; 256];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let scale = 2.0;
|
||||
|
||||
let mut output = vec![0.0f32; 256];
|
||||
dequantize_bitnet_block(&packed, scale, &mut output);
|
||||
|
||||
// All values should be 1 * 2.0 = 2.0
|
||||
assert!(output.iter().all(|&v| (v - 2.0).abs() < 1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dequantize_multiple_blocks() {
|
||||
// Two blocks with different scales
|
||||
let ternary1 = vec![1i8; 256];
|
||||
let ternary2 = vec![-1i8; 256];
|
||||
|
||||
let mut all_ternary = ternary1.clone();
|
||||
all_ternary.extend_from_slice(&ternary2);
|
||||
|
||||
let packed = pack_ternary(&all_ternary);
|
||||
let scales = vec![1.0, 2.0];
|
||||
|
||||
let result = dequantize_bitnet_t158(&packed, &scales, 512);
|
||||
|
||||
// First 256 should be 1.0 * 1.0 = 1.0
|
||||
assert!(result[..256].iter().all(|&v| (v - 1.0).abs() < 1e-6));
|
||||
|
||||
// Next 256 should be -1.0 * 2.0 = -2.0
|
||||
assert!(result[256..512]
|
||||
.iter()
|
||||
.all(|&v| (v - (-2.0)).abs() < 1e-6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_quantize_dequantize() {
|
||||
// Original weights
|
||||
let original = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.4, 0.2, -0.6];
|
||||
|
||||
// Quantize
|
||||
let (ternary, scale) = absmean_ternary(&original);
|
||||
let packed = pack_ternary(&ternary);
|
||||
|
||||
// Dequantize
|
||||
let dequantized = dequantize_bitnet_t158(&packed, &[scale], original.len());
|
||||
|
||||
// Check that we got 8 values back
|
||||
assert_eq!(dequantized.len(), 8);
|
||||
|
||||
// Values should be approximate (quantization loses precision)
|
||||
// But should be close for values near the scale
|
||||
for (orig, dequant) in original.iter().zip(dequantized.iter()) {
|
||||
let error = (orig - dequant).abs();
|
||||
// Error should be bounded by the quantization step (~scale)
|
||||
assert!(error < scale * 2.0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compute_dequant_error() {
|
||||
let original = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let dequantized = vec![1.1, 1.9, 3.2, 3.8];
|
||||
|
||||
let (mae, mse, max_error) = compute_dequant_error(&original, &dequantized);
|
||||
|
||||
// MAE should be (0.1 + 0.1 + 0.2 + 0.2) / 4 = 0.15
|
||||
assert!((mae - 0.15).abs() < 1e-6);
|
||||
|
||||
// MSE should be (0.01 + 0.01 + 0.04 + 0.04) / 4 = 0.025
|
||||
assert!((mse - 0.025).abs() < 1e-6);
|
||||
|
||||
// Max error should be 0.2
|
||||
assert!((max_error - 0.2).abs() < 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Output buffer must hold at least 256 elements")]
|
||||
fn test_dequantize_block_small_buffer() {
|
||||
let packed = vec![0u8; 64];
|
||||
let mut output = vec![0.0f32; 128]; // Too small
|
||||
dequantize_bitnet_block(&packed, 1.0, &mut output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Packed block must be exactly 64 bytes")]
|
||||
fn test_dequantize_block_wrong_size() {
|
||||
let packed = vec![0u8; 32]; // Wrong size
|
||||
let mut output = vec![0.0f32; 256];
|
||||
dequantize_bitnet_block(&packed, 1.0, &mut output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dequantize_with_missing_scales() {
|
||||
// More elements than scales (should use default 1.0)
|
||||
let ternary = vec![1i8; 512];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let scales = vec![2.0]; // Only one scale for two blocks
|
||||
|
||||
let result = dequantize_bitnet_t158(&packed, &scales, 512);
|
||||
|
||||
// First 256 use scale 2.0
|
||||
assert!(result[..256].iter().all(|&v| (v - 2.0).abs() < 1e-6));
|
||||
|
||||
// Next 256 use default 1.0
|
||||
assert!(result[256..512].iter().all(|&v| (v - 1.0).abs() < 1e-6));
|
||||
}
|
||||
}
|
||||
58
crates/ruvllm/src/bitnet/mod.rs
Normal file
58
crates/ruvllm/src/bitnet/mod.rs
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
//! BitNet b1.58 Ternary Quantization for RuvLLM
|
||||
//!
|
||||
//! This module implements Microsoft Research's BitNet b1.58 ternary weight quantization
|
||||
//! for the Craftsman Ultra 30b 1bit model. It provides post-training quantization (PTQ)
|
||||
//! of FP16 weights to ternary {-1, 0, +1} using absmean quantization.
|
||||
//!
|
||||
//! ## Overview
|
||||
//!
|
||||
//! BitNet b1.58 enables multiplication-free inference by quantizing weights to three values:
|
||||
//! -1, 0, +1. This reduces memory footprint to ~2 bits per weight and eliminates floating-point
|
||||
//! multiplication in matrix operations.
|
||||
//!
|
||||
//! ## Key Components
|
||||
//!
|
||||
//! - [`TernaryTensor`]: Container for ternary weights with 2-bit packing
|
||||
//! - [`quantize_tensor`]: Convert FP32 weights to ternary using absmean algorithm
|
||||
//! - [`dequantize_bitnet_t158`]: Convert packed ternary back to FP32 for validation
|
||||
//! - [`PtBitnetConfig`]: Configuration for post-training quantization
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvllm::bitnet::{quantize_tensor, PtBitnetConfig};
|
||||
//!
|
||||
//! // Configure quantization
|
||||
//! let config = PtBitnetConfig {
|
||||
//! block_size: 256,
|
||||
//! optimize_scales: true,
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
//!
|
||||
//! // Quantize a weight tensor
|
||||
//! let fp32_weights = vec![0.5, -0.3, 0.0, 0.8, /* ... */];
|
||||
//! let ternary = quantize_tensor(&fp32_weights, (128, 256), &config)?;
|
||||
//!
|
||||
//! println!("Sparsity: {:.2}%", ternary.sparsity() * 100.0);
|
||||
//! println!("Memory: {} bytes", ternary.memory_bytes());
|
||||
//! ```
|
||||
//!
|
||||
//! ## Architecture Details
|
||||
//!
|
||||
//! From ADR-017 (AD-1, AD-5, AD-18):
|
||||
//!
|
||||
//! - **Absmean quantization**: `W_ternary = RoundClip(W / (mean(|W|) + ε), -1, 1)`
|
||||
//! - **2-bit packing**: 00=-1, 01=0, 10=+1 (4 values per byte)
|
||||
//! - **Block size**: 256 elements per scale factor
|
||||
//! - **Storage**: 66 bytes per block (64 bytes ternary + 2 bytes FP16 scale)
|
||||
//! - **Compression**: 2.06 bits/weight (30B model → ~7.7 GB)
|
||||
|
||||
pub mod dequantize;
|
||||
pub mod quantizer;
|
||||
pub mod ternary_tensor;
|
||||
|
||||
pub use dequantize::dequantize_bitnet_t158;
|
||||
pub use quantizer::{
|
||||
absmean_ternary, quantize_tensor, LayerMask, Precision, PtBitnetConfig, TernaryFormat,
|
||||
};
|
||||
pub use ternary_tensor::{pack_ternary, unpack_ternary, TernaryTensor};
|
||||
338
crates/ruvllm/src/bitnet/quantizer.rs
Normal file
338
crates/ruvllm/src/bitnet/quantizer.rs
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
//! PT-BitNet Post-Training Quantization
|
||||
//!
|
||||
//! Core absmean ternary quantization algorithm for converting FP32 weights
|
||||
//! to BitNet b1.58 ternary format.
|
||||
|
||||
use crate::error::{Result, RuvLLMError};
|
||||
use super::ternary_tensor::{pack_ternary, TernaryTensor};
|
||||
|
||||
/// Configuration for PT-BitNet post-training quantization.
|
||||
///
|
||||
/// Controls the quantization process behavior, including block size,
|
||||
/// calibration, and layer selection.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use ruvllm::bitnet::PtBitnetConfig;
|
||||
///
|
||||
/// let config = PtBitnetConfig {
|
||||
/// calibration_samples: 1000,
|
||||
/// block_size: 256,
|
||||
/// optimize_scales: true,
|
||||
/// layers_to_quantize: LayerMask::ExpertsOnly,
|
||||
/// export_format: TernaryFormat::BitnetT158,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PtBitnetConfig {
|
||||
/// Number of calibration samples for scale optimization
|
||||
pub calibration_samples: usize,
|
||||
/// Elements per quantization block
|
||||
pub block_size: usize,
|
||||
/// Enable scale factor optimization via calibration
|
||||
pub optimize_scales: bool,
|
||||
/// Which layers to quantize
|
||||
pub layers_to_quantize: LayerMask,
|
||||
/// Export format for GGUF serialization
|
||||
pub export_format: TernaryFormat,
|
||||
/// Precision for router and shared layers
|
||||
pub router_precision: Precision,
|
||||
/// Use memory-mapped I/O for weight loading
|
||||
pub use_mmap: bool,
|
||||
/// Use Metal GPU for calibration (Mac Studio only)
|
||||
pub use_metal_calibration: bool,
|
||||
/// Maximum memory budget in GB
|
||||
pub max_memory_gb: usize,
|
||||
}
|
||||
|
||||
impl Default for PtBitnetConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
calibration_samples: 1000,
|
||||
block_size: 256,
|
||||
optimize_scales: true,
|
||||
layers_to_quantize: LayerMask::ExpertsOnly,
|
||||
export_format: TernaryFormat::BitnetT158,
|
||||
router_precision: Precision::FP16,
|
||||
use_mmap: true,
|
||||
use_metal_calibration: cfg!(all(target_os = "macos", feature = "metal-compute")),
|
||||
max_memory_gb: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Layer selection mask for quantization.
|
||||
///
|
||||
/// Determines which model layers to convert to ternary. Per ADR-017 (AD-2),
|
||||
/// the MoE router, embeddings, and LM head must remain in higher precision.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum LayerMask {
|
||||
/// Only MoE expert FFN layers (recommended for Phase 1)
|
||||
ExpertsOnly,
|
||||
/// All linear layers except router/embeddings/head
|
||||
All,
|
||||
/// Custom layer selection by name pattern
|
||||
Custom(Vec<String>),
|
||||
}
|
||||
|
||||
/// Ternary tensor export format.
|
||||
///
|
||||
/// Determines the GGUF quantization type used for serialization.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TernaryFormat {
|
||||
/// BitNet b1.58 native format (type 30)
|
||||
BitnetT158,
|
||||
/// IQ1_S compatible format (type 19)
|
||||
IQ1S,
|
||||
}
|
||||
|
||||
/// Precision for non-quantized layers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Precision {
|
||||
/// 16-bit floating point
|
||||
FP16,
|
||||
/// Brain floating point 16
|
||||
BF16,
|
||||
/// 32-bit floating point
|
||||
FP32,
|
||||
}
|
||||
|
||||
/// Core absmean ternary quantization algorithm.
|
||||
///
|
||||
/// Implements the BitNet b1.58 quantization formula:
|
||||
/// ```text
|
||||
/// gamma = mean(|block|) + epsilon
|
||||
/// normalized = block / gamma
|
||||
/// ternary = round(clamp(normalized, -1, 1))
|
||||
/// ```
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `block` - FP32 weight block (typically 256 elements)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Tuple of (ternary values, scale factor):
|
||||
/// - `Vec<i8>`: Ternary weights in {-1, 0, +1}
|
||||
/// - `f32`: Absmean scale factor (gamma)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use ruvllm::bitnet::absmean_ternary;
|
||||
///
|
||||
/// let weights = vec![0.5, -0.3, 0.8, -0.1, 0.0, 0.4];
|
||||
/// let (ternary, scale) = absmean_ternary(&weights);
|
||||
///
|
||||
/// println!("Scale: {}", scale);
|
||||
/// println!("Ternary: {:?}", ternary); // e.g., [1, -1, 1, 0, 0, 1]
|
||||
/// ```
|
||||
pub fn absmean_ternary(block: &[f32]) -> (Vec<i8>, f32) {
|
||||
// Compute absmean scale: gamma = mean(|W|)
|
||||
let sum_abs: f32 = block.iter().map(|&w| w.abs()).sum();
|
||||
let gamma = (sum_abs / block.len() as f32) + 1e-8;
|
||||
|
||||
// Normalize and quantize to {-1, 0, +1}
|
||||
let ternary: Vec<i8> = block
|
||||
.iter()
|
||||
.map(|&w| {
|
||||
let normalized = w / gamma;
|
||||
let clamped = normalized.clamp(-1.0, 1.0);
|
||||
clamped.round() as i8
|
||||
})
|
||||
.collect();
|
||||
|
||||
(ternary, gamma)
|
||||
}
|
||||
|
||||
/// Quantize a full FP32 tensor to ternary representation.
|
||||
///
|
||||
/// Processes the input tensor in blocks of `config.block_size`, applying
|
||||
/// absmean quantization to each block independently.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `weights` - FP32 weight tensor (flattened)
|
||||
/// * `shape` - Tensor shape (rows, cols)
|
||||
/// * `config` - Quantization configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `TernaryTensor` with packed 2-bit data and per-block scales
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the weight dimensions are invalid.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use ruvllm::bitnet::{quantize_tensor, PtBitnetConfig};
|
||||
///
|
||||
/// let weights = vec![0.5; 512]; // 512 FP32 weights
|
||||
/// let shape = (2, 256);
|
||||
/// let config = PtBitnetConfig::default();
|
||||
///
|
||||
/// let ternary = quantize_tensor(&weights, shape, &config)?;
|
||||
/// println!("Compressed to {} bytes", ternary.memory_bytes());
|
||||
/// ```
|
||||
pub fn quantize_tensor(
|
||||
weights: &[f32],
|
||||
shape: (usize, usize),
|
||||
config: &PtBitnetConfig,
|
||||
) -> Result<TernaryTensor> {
|
||||
let (rows, cols) = shape;
|
||||
let total_elements = rows * cols;
|
||||
|
||||
if weights.len() != total_elements {
|
||||
return Err(RuvLLMError::Model(format!(
|
||||
"Weight size mismatch: expected {} elements for shape {:?}, got {}",
|
||||
total_elements,
|
||||
shape,
|
||||
weights.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let block_size = config.block_size;
|
||||
let num_blocks = (total_elements + block_size - 1) / block_size;
|
||||
|
||||
let mut all_ternary = Vec::with_capacity(total_elements);
|
||||
let mut scales = Vec::with_capacity(num_blocks);
|
||||
|
||||
// Process each block
|
||||
for block_idx in 0..num_blocks {
|
||||
let start = block_idx * block_size;
|
||||
let end = (start + block_size).min(total_elements);
|
||||
let block = &weights[start..end];
|
||||
|
||||
let (ternary, scale) = absmean_ternary(block);
|
||||
all_ternary.extend_from_slice(&ternary);
|
||||
scales.push(scale);
|
||||
}
|
||||
|
||||
// Pack ternary values into 2-bit representation
|
||||
let packed_data = pack_ternary(&all_ternary);
|
||||
|
||||
Ok(TernaryTensor {
|
||||
packed_data,
|
||||
scales,
|
||||
shape,
|
||||
block_size,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_absmean_ternary_simple() {
|
||||
// Simple block with known values
|
||||
let block = vec![0.5, -0.5, 0.0, 1.0, -1.0, 0.25];
|
||||
let (ternary, scale) = absmean_ternary(&block);
|
||||
|
||||
// All values should be in {-1, 0, +1}
|
||||
assert!(ternary.iter().all(|&v| v >= -1 && v <= 1));
|
||||
|
||||
// Scale should be positive
|
||||
assert!(scale > 0.0);
|
||||
|
||||
// Check specific values
|
||||
// gamma ≈ (0.5 + 0.5 + 0.0 + 1.0 + 1.0 + 0.25) / 6 ≈ 0.542
|
||||
// 0.5 / 0.542 ≈ 0.92 → round(0.92) = 1
|
||||
// -0.5 / 0.542 ≈ -0.92 → round(-0.92) = -1
|
||||
// 0.0 / 0.542 = 0 → round(0) = 0
|
||||
assert_eq!(ternary[0], 1);
|
||||
assert_eq!(ternary[1], -1);
|
||||
assert_eq!(ternary[2], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absmean_ternary_all_zeros() {
|
||||
let block = vec![0.0; 256];
|
||||
let (ternary, scale) = absmean_ternary(&block);
|
||||
|
||||
// All should quantize to 0
|
||||
assert!(ternary.iter().all(|&v| v == 0));
|
||||
|
||||
// Scale should be epsilon (1e-8)
|
||||
assert!(scale < 1e-7 && scale > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absmean_ternary_large_values() {
|
||||
let block = vec![10.0, -10.0, 5.0, -5.0];
|
||||
let (ternary, _scale) = absmean_ternary(&block);
|
||||
|
||||
// All should saturate to ±1
|
||||
assert!(ternary[0] == 1 || ternary[0] == -1);
|
||||
assert!(ternary[1] == 1 || ternary[1] == -1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_tensor_simple() {
|
||||
let weights = vec![0.5; 512]; // 512 identical weights
|
||||
let shape = (2, 256);
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
let ternary = quantize_tensor(&weights, shape, &config).unwrap();
|
||||
|
||||
assert_eq!(ternary.shape, shape);
|
||||
assert_eq!(ternary.block_size, 256);
|
||||
assert_eq!(ternary.num_blocks(), 2); // 512 / 256 = 2 blocks
|
||||
assert_eq!(ternary.scales.len(), 2);
|
||||
|
||||
// 512 elements packed in 2 bits each = 128 bytes
|
||||
assert_eq!(ternary.packed_data.len(), 128);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_tensor_size_mismatch() {
|
||||
let weights = vec![0.5; 100]; // Wrong size
|
||||
let shape = (2, 256); // Expects 512
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
let result = quantize_tensor(&weights, shape, &config);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_tensor_memory_savings() {
|
||||
// Quantize a 1MB FP32 tensor (256K elements)
|
||||
let weights = vec![0.5; 256 * 1024];
|
||||
let shape = (512, 512);
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
let ternary = quantize_tensor(&weights, shape, &config).unwrap();
|
||||
|
||||
let original_bytes = weights.len() * 4; // FP32
|
||||
let compressed_bytes = ternary.memory_bytes();
|
||||
|
||||
// Should be ~16x compression (32 bits → 2 bits + scale overhead)
|
||||
let compression_ratio = original_bytes as f32 / compressed_bytes as f32;
|
||||
assert!(compression_ratio > 10.0); // At least 10x compression
|
||||
assert!(compression_ratio < 20.0); // Less than 20x (due to scales)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_default() {
|
||||
let config = PtBitnetConfig::default();
|
||||
assert_eq!(config.block_size, 256);
|
||||
assert_eq!(config.calibration_samples, 1000);
|
||||
assert!(config.optimize_scales);
|
||||
assert_eq!(config.layers_to_quantize, LayerMask::ExpertsOnly);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_layer_mask_variants() {
|
||||
let experts = LayerMask::ExpertsOnly;
|
||||
let all = LayerMask::All;
|
||||
let custom = LayerMask::Custom(vec!["layer.0".to_string()]);
|
||||
|
||||
assert_ne!(experts, all);
|
||||
assert_ne!(all, custom);
|
||||
assert_ne!(experts, custom);
|
||||
}
|
||||
}
|
||||
276
crates/ruvllm/src/bitnet/ternary_tensor.rs
Normal file
276
crates/ruvllm/src/bitnet/ternary_tensor.rs
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
//! Ternary Tensor Data Structure
|
||||
//!
|
||||
//! This module provides the `TernaryTensor` container for BitNet b1.58 ternary weights,
|
||||
//! along with efficient 2-bit packing/unpacking functions.
|
||||
|
||||
/// Ternary tensor with 2-bit packed representation.
|
||||
///
|
||||
/// Stores ternary weights {-1, 0, +1} in a compact 2-bit format:
|
||||
/// - 00 = -1
|
||||
/// - 01 = 0
|
||||
/// - 10 = +1
|
||||
/// - 11 = reserved (unused)
|
||||
///
|
||||
/// Each block of `block_size` elements shares a single FP32 scale factor
|
||||
/// derived from the absmean quantization process.
|
||||
///
|
||||
/// # Memory Layout
|
||||
///
|
||||
/// For a tensor with shape (m, n) and block_size B:
|
||||
/// - `packed_data`: ceil(m*n / 4) bytes (4 ternary values per byte)
|
||||
/// - `scales`: ceil(m*n / B) * 4 bytes (one FP32 scale per block)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use ruvllm::bitnet::TernaryTensor;
|
||||
///
|
||||
/// let tensor = TernaryTensor {
|
||||
/// packed_data: vec![0b10010100], // [+1, 0, +1, 0]
|
||||
/// scales: vec![0.5],
|
||||
/// shape: (2, 2),
|
||||
/// block_size: 256,
|
||||
/// };
|
||||
///
|
||||
/// println!("Sparsity: {:.2}%", tensor.sparsity() * 100.0);
|
||||
/// println!("Memory: {} bytes", tensor.memory_bytes());
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TernaryTensor {
|
||||
/// Packed 2-bit ternary data (4 values per byte)
|
||||
pub packed_data: Vec<u8>,
|
||||
/// Per-block scale factors (FP32)
|
||||
pub scales: Vec<f32>,
|
||||
/// Tensor shape (rows, cols)
|
||||
pub shape: (usize, usize),
|
||||
/// Elements per quantization block
|
||||
pub block_size: usize,
|
||||
}
|
||||
|
||||
impl TernaryTensor {
|
||||
/// Calculate the fraction of zero weights (sparsity).
|
||||
///
|
||||
/// Zero weights enable feature filtering and reduce computation
|
||||
/// in ternary matrix multiplication.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Fraction of weights that are exactly 0, in range [0.0, 1.0]
|
||||
pub fn sparsity(&self) -> f32 {
|
||||
let total_elements = self.shape.0 * self.shape.1;
|
||||
let unpacked = unpack_ternary(&self.packed_data, total_elements);
|
||||
|
||||
let zero_count = unpacked.iter().filter(|&&x| x == 0).count();
|
||||
zero_count as f32 / total_elements as f32
|
||||
}
|
||||
|
||||
/// Calculate total memory footprint in bytes.
|
||||
///
|
||||
/// Includes both packed ternary data and per-block scales.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Total bytes: packed_data.len() + scales.len() * 4
|
||||
pub fn memory_bytes(&self) -> usize {
|
||||
self.packed_data.len() + self.scales.len() * 4
|
||||
}
|
||||
|
||||
/// Get the number of quantization blocks.
|
||||
pub fn num_blocks(&self) -> usize {
|
||||
let total_elements = self.shape.0 * self.shape.1;
|
||||
(total_elements + self.block_size - 1) / self.block_size
|
||||
}
|
||||
}
|
||||
|
||||
/// Pack ternary values {-1, 0, +1} into 2-bit representation.
|
||||
///
|
||||
/// Encoding:
|
||||
/// - -1 → 00
|
||||
/// - 0 → 01
|
||||
/// - +1 → 10
|
||||
/// - (unused) → 11
|
||||
///
|
||||
/// Four values are packed into each byte in LSB-first order:
|
||||
/// ```text
|
||||
/// byte = [v3:v2:v1:v0]
|
||||
/// ```
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `values` - Slice of i8 values, must be in {-1, 0, +1}
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Vector of bytes, length = ceil(values.len() / 4)
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if any value is not in {-1, 0, +1}
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use ruvllm::bitnet::pack_ternary;
|
||||
///
|
||||
/// let values = vec![-1, 0, 1, -1];
|
||||
/// let packed = pack_ternary(&values);
|
||||
/// assert_eq!(packed.len(), 1); // 4 values in 1 byte
|
||||
/// ```
|
||||
pub fn pack_ternary(values: &[i8]) -> Vec<u8> {
|
||||
let num_bytes = (values.len() + 3) / 4;
|
||||
let mut packed = vec![0u8; num_bytes];
|
||||
|
||||
for (i, &val) in values.iter().enumerate() {
|
||||
let byte_idx = i / 4;
|
||||
let bit_offset = (i % 4) * 2;
|
||||
|
||||
let encoded: u8 = match val {
|
||||
-1 => 0b00,
|
||||
0 => 0b01,
|
||||
1 => 0b10,
|
||||
_ => panic!("Invalid ternary value: {} (must be -1, 0, or +1)", val),
|
||||
};
|
||||
|
||||
packed[byte_idx] |= encoded << bit_offset;
|
||||
}
|
||||
|
||||
packed
|
||||
}
|
||||
|
||||
/// Unpack 2-bit ternary values to i8.
|
||||
///
|
||||
/// Decoding:
|
||||
/// - 00 → -1
|
||||
/// - 01 → 0
|
||||
/// - 10 → +1
|
||||
/// - 11 → 0 (reserved, treated as zero)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `packed` - Packed 2-bit data
|
||||
/// * `n` - Number of elements to unpack
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Vector of i8 values in {-1, 0, +1}
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use ruvllm::bitnet::{pack_ternary, unpack_ternary};
|
||||
///
|
||||
/// let original = vec![-1, 0, 1, -1];
|
||||
/// let packed = pack_ternary(&original);
|
||||
/// let unpacked = unpack_ternary(&packed, 4);
|
||||
/// assert_eq!(original, unpacked);
|
||||
/// ```
|
||||
pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec<i8> {
|
||||
let mut values = Vec::with_capacity(n);
|
||||
|
||||
for i in 0..n {
|
||||
let byte_idx = i / 4;
|
||||
let bit_offset = (i % 4) * 2;
|
||||
|
||||
if byte_idx >= packed.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let encoded = (packed[byte_idx] >> bit_offset) & 0b11;
|
||||
|
||||
let val = match encoded {
|
||||
0b00 => -1,
|
||||
0b01 => 0,
|
||||
0b10 => 1,
|
||||
0b11 => 0, // Reserved, treat as zero
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
values.push(val);
|
||||
}
|
||||
|
||||
values
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pack_unpack_ternary() {
|
||||
let values = vec![-1, 0, 1, -1, 1, 0, 0, 1];
|
||||
let packed = pack_ternary(&values);
|
||||
let unpacked = unpack_ternary(&packed, values.len());
|
||||
assert_eq!(values, unpacked);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_ternary_single_byte() {
|
||||
// 4 values fit in 1 byte
|
||||
let values = vec![-1, 0, 1, -1];
|
||||
let packed = pack_ternary(&values);
|
||||
assert_eq!(packed.len(), 1);
|
||||
|
||||
// Manually verify encoding
|
||||
// -1=00, 0=01, 1=10, -1=00
|
||||
// byte = [00:10:01:00] = 0b00_10_01_00 = 0x08
|
||||
assert_eq!(packed[0], 0b00_10_01_00);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_ternary_partial_byte() {
|
||||
// 5 values need 2 bytes
|
||||
let values = vec![-1, 0, 1, -1, 1];
|
||||
let packed = pack_ternary(&values);
|
||||
assert_eq!(packed.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Invalid ternary value")]
|
||||
fn test_pack_invalid_value() {
|
||||
let values = vec![-1, 0, 2]; // 2 is invalid
|
||||
pack_ternary(&values);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ternary_tensor_sparsity() {
|
||||
let values = vec![0, 1, 0, -1, 0, 0, 1, 0]; // 5 zeros out of 8
|
||||
let packed = pack_ternary(&values);
|
||||
|
||||
let tensor = TernaryTensor {
|
||||
packed_data: packed,
|
||||
scales: vec![1.0],
|
||||
shape: (2, 4),
|
||||
block_size: 256,
|
||||
};
|
||||
|
||||
let sparsity = tensor.sparsity();
|
||||
assert!((sparsity - 0.625).abs() < 0.001); // 5/8 = 0.625
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ternary_tensor_memory() {
|
||||
let packed = vec![0u8; 64]; // 64 bytes of packed data
|
||||
let scales = vec![0.5f32; 16]; // 16 scales * 4 bytes = 64 bytes
|
||||
|
||||
let tensor = TernaryTensor {
|
||||
packed_data: packed,
|
||||
scales,
|
||||
shape: (128, 256),
|
||||
block_size: 256,
|
||||
};
|
||||
|
||||
assert_eq!(tensor.memory_bytes(), 64 + 64); // 128 bytes total
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ternary_tensor_num_blocks() {
|
||||
let tensor = TernaryTensor {
|
||||
packed_data: vec![],
|
||||
scales: vec![],
|
||||
shape: (256, 256), // 65536 elements
|
||||
block_size: 256, // 256 elements per block
|
||||
};
|
||||
|
||||
assert_eq!(tensor.num_blocks(), 256); // 65536 / 256 = 256 blocks
|
||||
}
|
||||
}
|
||||
662
crates/ruvllm/src/bitnet/tests.rs
Normal file
662
crates/ruvllm/src/bitnet/tests.rs
Normal file
|
|
@ -0,0 +1,662 @@
|
|||
//! Comprehensive tests for PT-BitNet Phase 0 ternary quantization
|
||||
//!
|
||||
//! Test coverage based on ADR-017 (AD-1, AD-18):
|
||||
//! - Ternary packing/unpacking roundtrips
|
||||
//! - Absmean quantization correctness
|
||||
//! - Dequantization accuracy
|
||||
//! - Full tensor quantization
|
||||
//! - Edge cases and error conditions
|
||||
|
||||
use super::{
|
||||
dequantize_bitnet_t158, pack_ternary, quantize_tensor, unpack_ternary, PtBitnetConfig,
|
||||
TernaryTensor,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Test Constants
|
||||
// ============================================================================
|
||||
|
||||
const EPSILON: f32 = 1e-6;
|
||||
const BLOCK_SIZE: usize = 256;
|
||||
|
||||
// ============================================================================
|
||||
// 1. Ternary Packing Roundtrip Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_pack_unpack_simple_roundtrip() {
|
||||
// Simple 4-element ternary array
|
||||
let ternary = vec![1i8, 0, -1, 1];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let unpacked = unpack_ternary(&packed, 4);
|
||||
|
||||
assert_eq!(ternary, unpacked, "Packing roundtrip failed for [1, 0, -1, 1]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_all_zeros() {
|
||||
let ternary = vec![0i8; 256];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let unpacked = unpack_ternary(&packed, 256);
|
||||
|
||||
assert_eq!(ternary, unpacked);
|
||||
assert!(unpacked.iter().all(|&x| x == 0), "All zeros should remain all zeros");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_all_ones() {
|
||||
let ternary = vec![1i8; 256];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let unpacked = unpack_ternary(&packed, 256);
|
||||
|
||||
assert_eq!(ternary, unpacked);
|
||||
assert!(unpacked.iter().all(|&x| x == 1), "All +1 should remain all +1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_all_neg_ones() {
|
||||
let ternary = vec![-1i8; 256];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let unpacked = unpack_ternary(&packed, 256);
|
||||
|
||||
assert_eq!(ternary, unpacked);
|
||||
assert!(unpacked.iter().all(|&x| x == -1), "All -1 should remain all -1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_one_block_256_elements() {
|
||||
// One full block (256 elements) with alternating pattern
|
||||
let mut ternary = Vec::with_capacity(256);
|
||||
for i in 0..256 {
|
||||
ternary.push(match i % 3 {
|
||||
0 => 1,
|
||||
1 => 0,
|
||||
2 => -1,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
}
|
||||
|
||||
let packed = pack_ternary(&ternary);
|
||||
let unpacked = unpack_ternary(&packed, 256);
|
||||
|
||||
assert_eq!(ternary, unpacked, "256-element block roundtrip failed");
|
||||
|
||||
// Verify storage size: 256 elements * 2 bits = 64 bytes
|
||||
assert_eq!(packed.len(), 64, "Packed size should be 64 bytes for 256 elements");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_non_aligned_size() {
|
||||
// 100 elements (not divisible by 128, the typical packing boundary)
|
||||
let mut ternary = Vec::with_capacity(100);
|
||||
for i in 0..100 {
|
||||
ternary.push(if i % 2 == 0 { 1 } else { -1 });
|
||||
}
|
||||
|
||||
let packed = pack_ternary(&ternary);
|
||||
let unpacked = unpack_ternary(&packed, 100);
|
||||
|
||||
assert_eq!(
|
||||
ternary.len(),
|
||||
unpacked.len(),
|
||||
"Unpacked length should match original"
|
||||
);
|
||||
assert_eq!(ternary, unpacked, "Non-aligned size roundtrip failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_large_tensor() {
|
||||
// Multiple blocks (1024 elements = 4 blocks)
|
||||
let ternary: Vec<i8> = (0..1024)
|
||||
.map(|i| match i % 5 {
|
||||
0 | 1 => 1,
|
||||
2 | 3 => -1,
|
||||
4 => 0,
|
||||
_ => unreachable!(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let packed = pack_ternary(&ternary);
|
||||
let unpacked = unpack_ternary(&packed, 1024);
|
||||
|
||||
assert_eq!(ternary, unpacked, "Large tensor roundtrip failed");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 2. Absmean Quantization Correctness Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_quantize_uniform_random() {
|
||||
// Uniform random weights in [-1, 1] should produce all ternary values
|
||||
let weights = vec![0.5, -0.3, 0.1, -0.7, 0.9, -0.1, 0.0, 0.4];
|
||||
let ternary = quantize_absmean(&weights);
|
||||
|
||||
// All outputs must be in {-1, 0, +1}
|
||||
for &t in &ternary {
|
||||
assert!(
|
||||
t == -1 || t == 0 || t == 1,
|
||||
"Quantized value {} not in ternary set",
|
||||
t
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_all_zeros() {
|
||||
let weights = vec![0.0; 256];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// All ternary values should be zero
|
||||
assert!(
|
||||
ternary.iter().all(|&x| x == 0),
|
||||
"All-zero input should produce all-zero ternary"
|
||||
);
|
||||
|
||||
// Scale should be near epsilon (avoiding division by zero)
|
||||
assert!(
|
||||
scale < 1e-5,
|
||||
"Scale for all-zero weights should be near epsilon, got {}",
|
||||
scale
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_large_positive() {
|
||||
// Large positive weights should quantize to all +1
|
||||
let weights = vec![10.0; 256];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// All should be +1
|
||||
assert!(
|
||||
ternary.iter().all(|&x| x == 1),
|
||||
"Large positive weights should quantize to +1"
|
||||
);
|
||||
|
||||
// Scale should be approximately 10.0 (mean absolute value)
|
||||
assert!(
|
||||
(scale - 10.0).abs() < 0.1,
|
||||
"Scale should be ~10.0, got {}",
|
||||
scale
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_large_negative() {
|
||||
// Large negative weights should quantize to all -1
|
||||
let weights = vec![-10.0; 256];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// All should be -1
|
||||
assert!(
|
||||
ternary.iter().all(|&x| x == -1),
|
||||
"Large negative weights should quantize to -1"
|
||||
);
|
||||
|
||||
// Scale should be approximately 10.0 (mean absolute value)
|
||||
assert!(
|
||||
(scale - 10.0).abs() < 0.1,
|
||||
"Scale should be ~10.0, got {}",
|
||||
scale
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_known_example() {
|
||||
// From ADR: W_ternary = RoundClip(W / (mean(|W|) + epsilon), -1, 1)
|
||||
// Example: weights = [0.5, -0.3, 0.1, -0.7]
|
||||
// gamma = mean(|W|) = (0.5 + 0.3 + 0.1 + 0.7) / 4 = 0.4
|
||||
// normalized = [1.25, -0.75, 0.25, -1.75]
|
||||
// ternary = [1, -1, 0, -1] (after clamp and round)
|
||||
|
||||
let weights = vec![0.5, -0.3, 0.1, -0.7];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// Verify scale is approximately 0.4
|
||||
assert!(
|
||||
(scale - 0.4).abs() < 0.01,
|
||||
"Expected scale ~0.4, got {}",
|
||||
scale
|
||||
);
|
||||
|
||||
// Verify ternary values
|
||||
// 1.25 -> 1, -0.75 -> -1, 0.25 -> 0, -1.75 -> -1
|
||||
assert_eq!(ternary[0], 1, "0.5/0.4 = 1.25 should round to 1");
|
||||
assert_eq!(ternary[1], -1, "-0.3/0.4 = -0.75 should round to -1");
|
||||
assert_eq!(ternary[2], 0, "0.1/0.4 = 0.25 should round to 0");
|
||||
assert_eq!(ternary[3], -1, "-0.7/0.4 = -1.75 should clamp to -1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_scale_calculation() {
|
||||
// Verify scale = mean(|weights|)
|
||||
let weights = vec![1.0, -2.0, 3.0, -4.0];
|
||||
let (_, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
let expected_scale = (1.0 + 2.0 + 3.0 + 4.0) / 4.0; // = 2.5
|
||||
assert!(
|
||||
(scale - expected_scale).abs() < EPSILON,
|
||||
"Scale should be mean of absolute values: expected {}, got {}",
|
||||
expected_scale,
|
||||
scale
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 3. Dequantization Correctness Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_dequantize_simple() {
|
||||
let ternary = vec![1i8, 0, -1];
|
||||
let scale = 2.0;
|
||||
|
||||
let dequantized = dequantize_ternary(&ternary, scale);
|
||||
|
||||
assert_eq!(dequantized.len(), 3);
|
||||
assert!((dequantized[0] - 2.0).abs() < EPSILON, "1 * 2.0 = 2.0");
|
||||
assert!((dequantized[1] - 0.0).abs() < EPSILON, "0 * 2.0 = 0.0");
|
||||
assert!((dequantized[2] - (-2.0)).abs() < EPSILON, "-1 * 2.0 = -2.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dequantize_packed_data() {
|
||||
// Pack known ternary data, then dequantize
|
||||
let ternary = vec![1i8, 0, -1, 1];
|
||||
let packed = pack_ternary(&ternary);
|
||||
let scale = 3.5;
|
||||
|
||||
let unpacked = unpack_ternary(&packed, 4);
|
||||
let dequantized = dequantize_ternary(&unpacked, scale);
|
||||
|
||||
assert_eq!(dequantized.len(), 4);
|
||||
assert!((dequantized[0] - 3.5).abs() < EPSILON);
|
||||
assert!((dequantized[1] - 0.0).abs() < EPSILON);
|
||||
assert!((dequantized[2] - (-3.5)).abs() < EPSILON);
|
||||
assert!((dequantized[3] - 3.5).abs() < EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_quantize_dequantize_roundtrip_mse() {
|
||||
// Quantize -> Dequantize should have bounded MSE
|
||||
let weights = vec![0.5, -0.3, 0.1, -0.7, 0.9, -0.1, 0.4, -0.5];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
let dequantized = dequantize_ternary(&ternary, scale);
|
||||
|
||||
// Compute MSE
|
||||
let mse: f32 = weights
|
||||
.iter()
|
||||
.zip(dequantized.iter())
|
||||
.map(|(&w, &d)| (w - d).powi(2))
|
||||
.sum::<f32>()
|
||||
/ weights.len() as f32;
|
||||
|
||||
// MSE should be reasonable (ternary quantization is lossy)
|
||||
// For absmean, expect MSE < 0.5 for normalized weights
|
||||
assert!(
|
||||
mse < 0.5,
|
||||
"MSE too high: {} (weights may not reconstruct well)",
|
||||
mse
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dequantize_full_block() {
|
||||
// Dequantize a full 256-element block
|
||||
let ternary: Vec<i8> = (0..256).map(|i| if i % 2 == 0 { 1 } else { -1 }).collect();
|
||||
let scale = 1.5;
|
||||
|
||||
let dequantized = dequantize_ternary(&ternary, scale);
|
||||
|
||||
assert_eq!(dequantized.len(), 256);
|
||||
for (i, &val) in dequantized.iter().enumerate() {
|
||||
let expected = if i % 2 == 0 { 1.5 } else { -1.5 };
|
||||
assert!(
|
||||
(val - expected).abs() < EPSILON,
|
||||
"Element {} incorrect: expected {}, got {}",
|
||||
i,
|
||||
expected,
|
||||
val
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 4. Full Tensor Quantization Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_tensor_quantize_256x256() {
|
||||
// 256x256 random tensor (65536 elements)
|
||||
let mut weights = Vec::with_capacity(65536);
|
||||
for i in 0..65536 {
|
||||
let val = ((i as f32) * 0.001).sin(); // Pseudo-random in [-1, 1]
|
||||
weights.push(val);
|
||||
}
|
||||
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
|
||||
// Verify shape preserved
|
||||
assert_eq!(
|
||||
tensor.num_elements(),
|
||||
65536,
|
||||
"Tensor should preserve element count"
|
||||
);
|
||||
|
||||
// Verify sparsity is in valid range
|
||||
let sparsity = tensor.sparsity();
|
||||
assert!(
|
||||
sparsity >= 0.0 && sparsity <= 1.0,
|
||||
"Sparsity {} out of range [0, 1]",
|
||||
sparsity
|
||||
);
|
||||
|
||||
// For uniform random, expect ~1/3 zeros (rough heuristic)
|
||||
assert!(
|
||||
sparsity > 0.15 && sparsity < 0.5,
|
||||
"Sparsity {} seems unrealistic for uniform random input",
|
||||
sparsity
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_memory_bytes() {
|
||||
let weights = vec![0.5; 256];
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
|
||||
// Expected memory:
|
||||
// - Packed data: 256 elements * 2 bits / 8 = 64 bytes
|
||||
// - Scales: 1 block * 4 bytes (f32) = 4 bytes
|
||||
// Total: 68 bytes
|
||||
let expected_bytes = 64 + 4;
|
||||
|
||||
assert_eq!(
|
||||
tensor.memory_bytes(),
|
||||
expected_bytes,
|
||||
"Memory calculation incorrect"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_sparsity_calculation() {
|
||||
// Known sparsity: 50% zeros
|
||||
let weights: Vec<f32> = (0..256)
|
||||
.map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
|
||||
.collect();
|
||||
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
let sparsity = tensor.sparsity();
|
||||
|
||||
// Should be close to 0.5 (half zeros)
|
||||
assert!(
|
||||
(sparsity - 0.5).abs() < 0.1,
|
||||
"Expected sparsity ~0.5, got {}",
|
||||
sparsity
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_block_alignment() {
|
||||
// 512 elements = 2 blocks of 256
|
||||
let weights = vec![1.0; 512];
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
|
||||
// Should have 2 scale factors (one per block)
|
||||
assert_eq!(
|
||||
tensor.num_blocks(),
|
||||
2,
|
||||
"Expected 2 blocks for 512 elements"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_non_aligned_padding() {
|
||||
// 300 elements (256 + 44) should create 2 blocks with padding
|
||||
let weights = vec![0.5; 300];
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
|
||||
// Should pad to 2 full blocks (512 elements)
|
||||
let num_blocks = (300 + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
assert_eq!(
|
||||
tensor.num_blocks(),
|
||||
num_blocks,
|
||||
"Non-aligned tensor should pad to full blocks"
|
||||
);
|
||||
|
||||
// Original element count should be preserved
|
||||
assert_eq!(tensor.num_elements(), 300);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 5. TernaryTensor Properties Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_ternary_tensor_properties() {
|
||||
let weights: Vec<f32> = (0..512).map(|i| (i as f32) * 0.01).collect();
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
|
||||
// Memory bytes should match calculation
|
||||
let num_blocks = (512 + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
let packed_bytes = num_blocks * BLOCK_SIZE * 2 / 8; // 2 bits per element
|
||||
let scale_bytes = num_blocks * 4; // f32 scales
|
||||
let expected = packed_bytes + scale_bytes;
|
||||
|
||||
assert_eq!(tensor.memory_bytes(), expected);
|
||||
|
||||
// Sparsity should be in valid range
|
||||
assert!(tensor.sparsity() >= 0.0 && tensor.sparsity() <= 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ternary_tensor_uniform_random_sparsity() {
|
||||
// Uniform random should have ~1/3 sparsity
|
||||
let mut weights = Vec::with_capacity(2048);
|
||||
for i in 0..2048 {
|
||||
weights.push(((i as f32) * 1.234).sin());
|
||||
}
|
||||
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
let sparsity = tensor.sparsity();
|
||||
|
||||
// Rough heuristic: 20-45% zeros for uniform random
|
||||
assert!(
|
||||
sparsity > 0.2 && sparsity < 0.45,
|
||||
"Uniform random sparsity {} outside expected range [0.2, 0.45]",
|
||||
sparsity
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 6. Config Validation Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_config_default_values() {
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
assert_eq!(config.block_size, 256, "Default block size should be 256");
|
||||
assert!(
|
||||
config.calibration_samples > 0,
|
||||
"Calibration samples must be > 0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "block_size must be > 0")]
|
||||
fn test_config_invalid_block_size() {
|
||||
let _config = PtBitnetConfig {
|
||||
block_size: 0,
|
||||
..Default::default()
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "calibration_samples must be > 0")]
|
||||
fn test_config_invalid_calibration_samples() {
|
||||
let _config = PtBitnetConfig {
|
||||
calibration_samples: 0,
|
||||
..Default::default()
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 7. Edge Case Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_empty_input() {
|
||||
let weights: Vec<f32> = vec![];
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
|
||||
assert_eq!(tensor.num_elements(), 0);
|
||||
assert_eq!(tensor.num_blocks(), 0);
|
||||
assert_eq!(tensor.sparsity(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_element() {
|
||||
let weights = vec![0.5];
|
||||
let tensor = TernaryTensor::quantize(&weights, BLOCK_SIZE);
|
||||
|
||||
assert_eq!(tensor.num_elements(), 1);
|
||||
// Should create 1 block (padded)
|
||||
assert_eq!(tensor.num_blocks(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_very_large_values() {
|
||||
let weights = vec![f32::MAX, f32::MAX, f32::MAX, f32::MAX];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// Should all quantize to +1
|
||||
assert!(ternary.iter().all(|&x| x == 1), "f32::MAX should quantize to +1");
|
||||
|
||||
// Scale should be approximately f32::MAX
|
||||
assert!(scale > 1e30, "Scale should be very large");
|
||||
|
||||
// Dequantization should not produce NaN
|
||||
let dequantized = dequantize_ternary(&ternary, scale);
|
||||
assert!(
|
||||
dequantized.iter().all(|&x| !x.is_nan()),
|
||||
"Dequantization should not produce NaN"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subnormal_floats() {
|
||||
// Very small positive values (subnormal range)
|
||||
let weights = vec![1e-40, -1e-40, 1e-39, -1e-39];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// Should quantize reasonably (may be all zeros or small values)
|
||||
assert!(ternary.iter().all(|&x| x >= -1 && x <= 1));
|
||||
|
||||
// Scale should be tiny but not zero
|
||||
assert!(scale > 0.0, "Scale should be > 0 even for subnormal inputs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nan_handling() {
|
||||
// NaN should not crash, but behavior is implementation-defined
|
||||
let weights = vec![f32::NAN, 1.0, -1.0, 0.0];
|
||||
let result = std::panic::catch_unwind(|| {
|
||||
quantize_absmean_with_scale(&weights)
|
||||
});
|
||||
|
||||
// Should either panic or handle gracefully
|
||||
// At minimum, should not produce infinite loop or segfault
|
||||
if let Ok((ternary, scale)) = result {
|
||||
// If it succeeds, output should not contain NaN
|
||||
assert!(
|
||||
!scale.is_nan() || scale == 0.0,
|
||||
"Scale should not be NaN unless handled explicitly"
|
||||
);
|
||||
assert!(
|
||||
ternary.iter().all(|&x| x >= -1 && x <= 1),
|
||||
"Ternary values must be in valid range"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_infinity_handling() {
|
||||
let weights = vec![f32::INFINITY, f32::NEG_INFINITY, 1.0, -1.0];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// Infinities should quantize to ±1
|
||||
assert_eq!(ternary[0], 1, "INFINITY should quantize to +1");
|
||||
assert_eq!(ternary[1], -1, "NEG_INFINITY should quantize to -1");
|
||||
|
||||
// Scale should be finite (or handled gracefully)
|
||||
// Implementation may cap scale to avoid overflow
|
||||
assert!(
|
||||
scale.is_finite() || scale > 1e30,
|
||||
"Scale should be finite or very large"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_magnitudes() {
|
||||
// Mix of very large and very small values
|
||||
let weights = vec![1000.0, 0.001, -1000.0, -0.001, 0.0];
|
||||
let (ternary, scale) = quantize_absmean_with_scale(&weights);
|
||||
|
||||
// Should produce valid ternary values
|
||||
assert!(ternary.iter().all(|&x| x >= -1 && x <= 1));
|
||||
|
||||
// Scale should be dominated by large values
|
||||
assert!(scale > 100.0, "Scale should reflect large values");
|
||||
|
||||
// Small values should quantize to 0
|
||||
assert_eq!(
|
||||
ternary[1], 0,
|
||||
"0.001 compared to scale ~500 should be 0"
|
||||
);
|
||||
assert_eq!(ternary[3], 0, "-0.001 should be 0");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Helper to quantize weights using absmean method
|
||||
/// Returns both ternary values and scale factor
|
||||
fn quantize_absmean_with_scale(weights: &[f32]) -> (Vec<i8>, f32) {
|
||||
if weights.is_empty() {
|
||||
return (vec![], 0.0);
|
||||
}
|
||||
|
||||
// Compute absmean scale: gamma = mean(|W|) + epsilon
|
||||
let absmean: f32 = weights.iter().map(|&w| w.abs()).sum::<f32>() / weights.len() as f32;
|
||||
let scale = absmean + EPSILON;
|
||||
|
||||
// Quantize: W_ternary = RoundClip(W / scale, -1, 1)
|
||||
let ternary: Vec<i8> = weights
|
||||
.iter()
|
||||
.map(|&w| {
|
||||
let normalized = w / scale;
|
||||
// Round and clip to {-1, 0, +1}
|
||||
if normalized >= 0.5 {
|
||||
1
|
||||
} else if normalized <= -0.5 {
|
||||
-1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
(ternary, scale)
|
||||
}
|
||||
|
||||
/// Helper to quantize weights (scale not needed)
|
||||
fn quantize_absmean(weights: &[f32]) -> Vec<i8> {
|
||||
let (ternary, _scale) = quantize_absmean_with_scale(weights);
|
||||
ternary
|
||||
}
|
||||
|
||||
/// Helper to dequantize ternary values
|
||||
fn dequantize_ternary(ternary: &[i8], scale: f32) -> Vec<f32> {
|
||||
ternary.iter().map(|&t| (t as f32) * scale).collect()
|
||||
}
|
||||
|
|
@ -29,6 +29,7 @@
|
|||
//! | IQ4_NL | 4.5 | 32 | i-quant 4-bit non-linear |
|
||||
|
||||
use crate::error::{Result, RuvLLMError};
|
||||
use crate::bitnet::dequantize_bitnet_t158;
|
||||
|
||||
// ============================================================================
|
||||
// Quantization Types
|
||||
|
|
@ -100,6 +101,8 @@ pub enum GgufQuantType {
|
|||
F64 = 28,
|
||||
/// BF16 brain float
|
||||
Bf16 = 29,
|
||||
/// BitNet b1.58 ternary quantization (2-bit packed)
|
||||
BitnetT158 = 30,
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for GgufQuantType {
|
||||
|
|
@ -137,6 +140,7 @@ impl TryFrom<u32> for GgufQuantType {
|
|||
27 => Ok(Self::I64),
|
||||
28 => Ok(Self::F64),
|
||||
29 => Ok(Self::Bf16),
|
||||
30 => Ok(Self::BitnetT158),
|
||||
_ => Err(RuvLLMError::Model(format!(
|
||||
"Unknown GGUF quantization type: {}",
|
||||
value
|
||||
|
|
@ -163,6 +167,7 @@ impl GgufQuantType {
|
|||
Self::IQ1_S => 256,
|
||||
Self::IQ4_NL => 32,
|
||||
Self::IQ4_XS => 256,
|
||||
Self::BitnetT158 => 256,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -214,6 +219,8 @@ impl GgufQuantType {
|
|||
Self::IQ1_S => 50,
|
||||
Self::IQ4_NL => 18,
|
||||
Self::IQ4_XS => 136,
|
||||
// BitNet b1.58: 256 elements -> 64 bytes (2-bit packed) + 2 bytes (FP16 scale) = 66 bytes
|
||||
Self::BitnetT158 => 66,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -280,6 +287,7 @@ impl GgufQuantType {
|
|||
Self::IQ1_S => "IQ1_S",
|
||||
Self::IQ4_NL => "IQ4_NL",
|
||||
Self::IQ4_XS => "IQ4_XS",
|
||||
Self::BitnetT158 => "BITNET_T158",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -355,6 +363,14 @@ pub fn dequantize_tensor(
|
|||
GgufQuantType::Q5_K => dequantize_q5_k(data, &mut output),
|
||||
GgufQuantType::Q6_K => dequantize_q6_k(data, &mut output),
|
||||
GgufQuantType::IQ4_NL => dequantize_iq4_nl(data, &mut output),
|
||||
GgufQuantType::BitnetT158 => dequantize_bitnet_t158_wrapper(data, &mut output),
|
||||
GgufQuantType::IQ1_S => {
|
||||
return Err(RuvLLMError::Model(
|
||||
"IQ1_S dequantization requires codebook lookup tables (not yet implemented). \
|
||||
For BitNet ternary quantization, use BITNET_T158 type instead."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
return Err(RuvLLMError::Model(format!(
|
||||
"Dequantization not implemented for {:?}",
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@
|
|||
pub mod adapter_manager;
|
||||
pub mod autodetect;
|
||||
pub mod backends;
|
||||
pub mod bitnet;
|
||||
pub mod capabilities;
|
||||
pub mod claude_flow;
|
||||
pub mod context;
|
||||
|
|
|
|||
999
docs/architecture/bitnet-quantizer-module-design.md
Normal file
999
docs/architecture/bitnet-quantizer-module-design.md
Normal file
|
|
@ -0,0 +1,999 @@
|
|||
# PT-BitNet Quantizer Module Architecture Design
|
||||
|
||||
**Version:** 1.0
|
||||
**Date:** 2026-02-03
|
||||
**Status:** Design Specification
|
||||
**Relates to:** ADR-017 (AD-1, AD-5, AD-18, AD-19), DDD Section 3.4/4.2/4.3
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This document specifies the architecture for the **PT-BitNet post-training quantizer** module that converts FP16/BF16 GLM-4.7-Flash weights to BitNet b1.58 ternary {-1, 0, +1} format via absmean quantization. This is a **design-only specification** — implementation follows in Phase 0.
|
||||
|
||||
**Design Scope:**
|
||||
- Module layout and file organization
|
||||
- Complete struct definitions with field types
|
||||
- Full function signatures (no implementations)
|
||||
- GGUF integration points and format extensions
|
||||
- Error handling strategy
|
||||
- Testing approach
|
||||
|
||||
**Out of Scope:**
|
||||
- Actual implementation code
|
||||
- Performance benchmarks
|
||||
- Calibration dataset selection
|
||||
|
||||
---
|
||||
|
||||
## A. Module Layout
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
crates/ruvllm/src/
|
||||
├── bitnet/ # NEW module
|
||||
│ ├── mod.rs # Module exports and public API
|
||||
│ ├── quantizer.rs # PtBitnetQuantizer + absmean algorithm
|
||||
│ ├── ternary_tensor.rs # TernaryTensor value object
|
||||
│ ├── dequantize.rs # BITNET_T158 dequantization kernel
|
||||
│ └── config.rs # PtBitnetConfig configuration
|
||||
│
|
||||
├── gguf/
|
||||
│ ├── mod.rs # Add pub mod bitnet export
|
||||
│ ├── quantization.rs # MODIFIED: Add BITNET_T158 enum variant
|
||||
│ ├── parser.rs # Unchanged (reused as-is)
|
||||
│ └── ...
|
||||
│
|
||||
└── kernels/
|
||||
└── matmul.rs # Reference for dispatch patterns
|
||||
```
|
||||
|
||||
### Modified Files
|
||||
|
||||
#### `src/gguf/quantization.rs`
|
||||
|
||||
**Changes:**
|
||||
1. Add `BITNET_T158 = 30` variant to `GgufQuantType` enum (after `Bf16 = 29`)
|
||||
2. Update `try_from()` impl to handle type 30
|
||||
3. Update `block_size()` to return 256 for `BITNET_T158`
|
||||
4. Update `type_size()` to return 66 for `BITNET_T158` (64 bytes packed + 2 bytes FP16 scale)
|
||||
5. Update `is_quantized()` to include `BITNET_T158`
|
||||
6. Update `bits_per_weight()` to return 2.06 for `BITNET_T158`
|
||||
7. Add new match arm in `dequantize_tensor()` → `BITNET_T158 => dequantize_bitnet_t158(data, output)`
|
||||
|
||||
**Exact enum addition:**
|
||||
```rust
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[repr(u32)]
|
||||
pub enum GgufQuantType {
|
||||
// ... existing variants 0-29 ...
|
||||
/// BitNet b1.58 ternary quantization (2-bit packed + FP16 scale per 256-element block)
|
||||
BITNET_T158 = 30,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## B. Struct Definitions
|
||||
|
||||
### 1. `PtBitnetConfig` (in `bitnet/config.rs`)
|
||||
|
||||
**Purpose:** Configuration for PT-BitNet quantization process
|
||||
|
||||
```rust
|
||||
/// Configuration for PT-BitNet post-training quantization
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PtBitnetConfig {
|
||||
/// Block size for absmean scale computation (default: 256)
|
||||
pub block_size: usize,
|
||||
|
||||
/// Epsilon for numerical stability in scale computation (default: 1e-8)
|
||||
pub epsilon: f32,
|
||||
|
||||
/// Whether to run calibration pass to optimize scale factors
|
||||
pub use_calibration: bool,
|
||||
|
||||
/// Number of calibration samples (if use_calibration = true)
|
||||
pub calibration_samples: usize,
|
||||
|
||||
/// Maximum sequence length for calibration (default: 2048)
|
||||
pub calibration_max_seq_len: usize,
|
||||
|
||||
/// Device for calibration pass ("cpu", "metal", "cuda:0")
|
||||
pub calibration_device: String,
|
||||
|
||||
/// Clipping threshold for normalized weights before rounding
|
||||
/// (default: 1.0, range typically 0.95-1.05)
|
||||
pub clip_threshold: f32,
|
||||
|
||||
/// Sparsity target: if > 0.0, bias rounding toward zero to achieve target sparsity
|
||||
pub target_sparsity: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for PtBitnetConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
block_size: 256,
|
||||
epsilon: 1e-8,
|
||||
use_calibration: false,
|
||||
calibration_samples: 1000,
|
||||
calibration_max_seq_len: 2048,
|
||||
calibration_device: "metal".to_string(),
|
||||
clip_threshold: 1.0,
|
||||
target_sparsity: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. `TernaryTensor` (in `bitnet/ternary_tensor.rs`)
|
||||
|
||||
**Purpose:** Immutable value object for packed ternary weights
|
||||
|
||||
```rust
|
||||
/// Packed ternary tensor with per-block FP16 scales
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TernaryTensor {
|
||||
/// Packed 2-bit ternary values (4 weights per byte)
|
||||
/// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved
|
||||
pub packed_data: Vec<u8>,
|
||||
|
||||
/// Per-block FP16 scale factors (absmean values)
|
||||
pub scales: Vec<f16>,
|
||||
|
||||
/// Tensor shape [out_features, in_features] or [rows, cols]
|
||||
pub shape: [usize; 2],
|
||||
|
||||
/// Block size (always 256 for BitNet b1.58)
|
||||
pub block_size: usize,
|
||||
|
||||
/// Total number of weights
|
||||
pub num_elements: usize,
|
||||
|
||||
/// Number of blocks
|
||||
pub num_blocks: usize,
|
||||
|
||||
/// Measured sparsity (fraction of zero weights)
|
||||
pub sparsity: f32,
|
||||
}
|
||||
|
||||
impl TernaryTensor {
|
||||
/// Calculate total storage size in bytes
|
||||
pub fn storage_size(&self) -> usize;
|
||||
|
||||
/// Get expected packed_data size for validation
|
||||
pub fn expected_packed_size(&self) -> usize;
|
||||
|
||||
/// Validate internal consistency
|
||||
pub fn validate(&self) -> Result<()>;
|
||||
}
|
||||
```
|
||||
|
||||
### 3. `TernaryBlock` (in `bitnet/ternary_tensor.rs`)
|
||||
|
||||
**Purpose:** Single block of 256 ternary weights with scale
|
||||
|
||||
```rust
|
||||
/// A single 256-element block with ternary weights and FP16 scale
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TernaryBlock {
|
||||
/// 64 bytes of packed 2-bit values (256 weights × 2 bits ÷ 8 bits/byte)
|
||||
pub packed: [u8; 64],
|
||||
|
||||
/// FP16 absmean scale factor
|
||||
pub scale: f16,
|
||||
}
|
||||
|
||||
impl TernaryBlock {
|
||||
/// Size in bytes when stored in GGUF (64 + 2 = 66)
|
||||
pub const STORAGE_SIZE: usize = 66;
|
||||
|
||||
/// Number of elements in a block
|
||||
pub const BLOCK_SIZE: usize = 256;
|
||||
}
|
||||
```
|
||||
|
||||
### 4. `AbsmeanResult` (in `bitnet/quantizer.rs`)
|
||||
|
||||
**Purpose:** Result of absmean quantization on a single block
|
||||
|
||||
```rust
|
||||
/// Result of absmean ternary quantization on a block
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AbsmeanResult {
|
||||
/// Ternary values {-1, 0, +1} for each weight in the block
|
||||
pub ternary_weights: Vec<i8>,
|
||||
|
||||
/// Computed absmean scale factor (gamma = mean(|W|))
|
||||
pub scale: f32,
|
||||
|
||||
/// Measured sparsity (fraction of zeros)
|
||||
pub sparsity: f32,
|
||||
|
||||
/// Mean squared error vs original FP16 values (for calibration)
|
||||
pub mse: f32,
|
||||
}
|
||||
```
|
||||
|
||||
### 5. `QuantizationStats` (in `bitnet/quantizer.rs`)
|
||||
|
||||
**Purpose:** Statistics collected during quantization
|
||||
|
||||
```rust
|
||||
/// Statistics from quantizing a single tensor
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QuantizationStats {
|
||||
/// Tensor name
|
||||
pub name: String,
|
||||
|
||||
/// Mean of all block scales
|
||||
pub mean_scale: f32,
|
||||
|
||||
/// Std dev of block scales
|
||||
pub std_scale: f32,
|
||||
|
||||
/// Overall sparsity across all blocks
|
||||
pub sparsity: f32,
|
||||
|
||||
/// Mean MSE across all blocks
|
||||
pub mean_mse: f32,
|
||||
|
||||
/// Number of blocks
|
||||
pub num_blocks: usize,
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## C. Function Signatures
|
||||
|
||||
### Core Quantization Functions (in `bitnet/quantizer.rs`)
|
||||
|
||||
#### 1. Primary Quantization Entry Point
|
||||
|
||||
```rust
|
||||
/// Quantize an FP16/F32 tensor to ternary format using absmean quantization
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tensor` - Input FP16 or F32 tensor data (flat vector)
|
||||
/// * `shape` - Tensor shape [out_features, in_features]
|
||||
/// * `config` - Quantization configuration
|
||||
///
|
||||
/// # Returns
|
||||
/// * `TernaryTensor` - Packed ternary representation
|
||||
/// * `QuantizationStats` - Statistics about the quantization process
|
||||
///
|
||||
/// # Errors
|
||||
/// * `RuvLLMError::Quantization` if tensor size is not divisible by block_size
|
||||
/// * `RuvLLMError::Quantization` if shape product doesn't match tensor length
|
||||
pub fn quantize_tensor(
|
||||
tensor: &[f32],
|
||||
shape: [usize; 2],
|
||||
config: &PtBitnetConfig,
|
||||
) -> Result<(TernaryTensor, QuantizationStats)>;
|
||||
```
|
||||
|
||||
#### 2. Per-Block Quantization
|
||||
|
||||
```rust
|
||||
/// Apply absmean quantization to a single block of weights
|
||||
///
|
||||
/// Algorithm:
|
||||
/// 1. gamma = mean(|block|) + epsilon
|
||||
/// 2. normalized = block / gamma
|
||||
/// 3. ternary = round(clamp(normalized, -clip_threshold, +clip_threshold))
|
||||
/// 4. Map to {-1, 0, +1}
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `block` - Block of FP16/F32 values (length = config.block_size)
|
||||
/// * `config` - Configuration with epsilon and clip_threshold
|
||||
///
|
||||
/// # Returns
|
||||
/// * `AbsmeanResult` with ternary values, scale, sparsity, MSE
|
||||
///
|
||||
/// # Panics
|
||||
/// * If block.len() != config.block_size
|
||||
pub fn absmean_ternary(
|
||||
block: &[f32],
|
||||
config: &PtBitnetConfig,
|
||||
) -> AbsmeanResult;
|
||||
```
|
||||
|
||||
#### 3. Packing Functions
|
||||
|
||||
```rust
|
||||
/// Pack ternary {-1, 0, +1} values into 2-bit representation
|
||||
///
|
||||
/// Encoding: 00 = -1, 01 = 0, 10 = +1, 11 = reserved (unused)
|
||||
/// 4 values packed per byte: [v3 v2 v1 v0] → byte
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `values` - Ternary values (must be {-1, 0, +1} only)
|
||||
///
|
||||
/// # Returns
|
||||
/// * Packed bytes (length = ceil(values.len() / 4))
|
||||
///
|
||||
/// # Errors
|
||||
/// * If any value is not in {-1, 0, +1}
|
||||
pub fn pack_ternary(values: &[i8]) -> Result<Vec<u8>>;
|
||||
|
||||
/// Unpack 2-bit representation to ternary {-1, 0, +1} values
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `packed` - Packed 2-bit data
|
||||
/// * `n` - Number of values to extract
|
||||
///
|
||||
/// # Returns
|
||||
/// * Vector of ternary values (length = n)
|
||||
pub fn unpack_ternary(packed: &[u8], n: usize) -> Vec<i8>;
|
||||
```
|
||||
|
||||
#### 4. Calibration (Optional)
|
||||
|
||||
```rust
|
||||
/// Run calibration pass to optimize scale factors
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tensor` - Input FP16 tensor
|
||||
/// * `shape` - Tensor shape
|
||||
/// * `config` - Config with calibration settings
|
||||
/// * `calibration_data` - Sample activations for this layer
|
||||
///
|
||||
/// # Returns
|
||||
/// * Optimized `TernaryTensor` with calibrated scales
|
||||
///
|
||||
/// # Note
|
||||
/// This is optional - if not used, falls back to plain absmean
|
||||
pub fn quantize_with_calibration(
|
||||
tensor: &[f32],
|
||||
shape: [usize; 2],
|
||||
config: &PtBitnetConfig,
|
||||
calibration_data: &[Vec<f32>],
|
||||
) -> Result<(TernaryTensor, QuantizationStats)>;
|
||||
```
|
||||
|
||||
### Dequantization Functions (in `bitnet/dequantize.rs`)
|
||||
|
||||
```rust
|
||||
/// Dequantize BITNET_T158 tensor to FP32
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - Raw GGUF tensor bytes (packed ternary + scales)
|
||||
/// * `scales` - Per-block FP16 scales (extracted from data)
|
||||
/// * `n` - Total number of elements to dequantize
|
||||
///
|
||||
/// # Returns
|
||||
/// * Vec<f32> of dequantized values
|
||||
///
|
||||
/// # Format
|
||||
/// Each block: [64 bytes packed ternary][2 bytes FP16 scale]
|
||||
pub fn dequantize_bitnet_t158(
|
||||
data: &[u8],
|
||||
scales: &[f16],
|
||||
n: usize,
|
||||
) -> Vec<f32>;
|
||||
|
||||
/// Dequantize a single BITNET_T158 block
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `block_data` - 64 bytes of packed ternary data
|
||||
/// * `scale` - FP16 scale factor
|
||||
/// * `output` - Output buffer (must have capacity for 256 elements)
|
||||
pub fn dequantize_bitnet_t158_block(
|
||||
block_data: &[u8; 64],
|
||||
scale: f16,
|
||||
output: &mut [f32],
|
||||
);
|
||||
```
|
||||
|
||||
### Tensor Conversion (in `bitnet/ternary_tensor.rs`)
|
||||
|
||||
```rust
|
||||
impl TernaryTensor {
|
||||
/// Convert from packed storage to FP32 (for validation/testing)
|
||||
pub fn to_fp32(&self) -> Vec<f32>;
|
||||
|
||||
/// Create from existing GGUF tensor data
|
||||
pub fn from_gguf_data(
|
||||
data: &[u8],
|
||||
shape: [usize; 2],
|
||||
block_size: usize,
|
||||
) -> Result<Self>;
|
||||
|
||||
/// Serialize to GGUF tensor bytes
|
||||
pub fn to_gguf_data(&self) -> Vec<u8>;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## D. GGUF Integration Points
|
||||
|
||||
### 1. New Quantization Type Variant
|
||||
|
||||
**File:** `crates/ruvllm/src/gguf/quantization.rs`
|
||||
|
||||
**Changes to `GgufQuantType` enum:**
|
||||
|
||||
```rust
|
||||
#[repr(u32)]
|
||||
pub enum GgufQuantType {
|
||||
// ... existing 0-29 ...
|
||||
|
||||
/// BitNet b1.58 ternary quantization
|
||||
/// Block size: 256 elements
|
||||
/// Storage: 64 bytes packed (2-bit) + 2 bytes FP16 scale = 66 bytes/block
|
||||
/// Bits per weight: 2.06 bpw
|
||||
BITNET_T158 = 30,
|
||||
}
|
||||
|
||||
impl GgufQuantType {
|
||||
pub fn block_size(&self) -> usize {
|
||||
match self {
|
||||
// ... existing cases ...
|
||||
Self::BITNET_T158 => 256,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn type_size(&self) -> usize {
|
||||
match self {
|
||||
// ... existing cases ...
|
||||
Self::BITNET_T158 => 66, // 64 + 2
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
// ... existing cases ...
|
||||
Self::BITNET_T158 => "BITNET_T158",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<u32> for GgufQuantType {
|
||||
fn try_from(value: u32) -> Result<Self> {
|
||||
match value {
|
||||
// ... existing 0-29 ...
|
||||
30 => Ok(Self::BITNET_T158),
|
||||
_ => Err(/* ... */),
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Dequantization Dispatch
|
||||
|
||||
**File:** `crates/ruvllm/src/gguf/quantization.rs`
|
||||
|
||||
**Modification to `dequantize_tensor()` function:**
|
||||
|
||||
```rust
|
||||
pub fn dequantize_tensor(
|
||||
data: &[u8],
|
||||
dtype: GgufQuantType,
|
||||
num_elements: usize,
|
||||
) -> Result<Vec<f32>> {
|
||||
let mut output = vec![0.0f32; num_elements];
|
||||
|
||||
match dtype {
|
||||
// ... existing cases ...
|
||||
GgufQuantType::BITNET_T158 => {
|
||||
// Extract scales and packed data
|
||||
let num_blocks = (num_elements + 255) / 256;
|
||||
let mut scales = Vec::with_capacity(num_blocks);
|
||||
|
||||
for i in 0..num_blocks {
|
||||
let block_offset = i * 66;
|
||||
let scale_offset = block_offset + 64;
|
||||
let scale_bytes = [data[scale_offset], data[scale_offset + 1]];
|
||||
scales.push(f16::from_le_bytes(scale_bytes));
|
||||
}
|
||||
|
||||
crate::bitnet::dequantize::dequantize_bitnet_t158(
|
||||
data,
|
||||
&scales,
|
||||
num_elements,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
return Err(RuvLLMError::Model(format!(
|
||||
"Dequantization not implemented for {:?}",
|
||||
dtype
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
```
|
||||
|
||||
### 3. GGUF Metadata Keys
|
||||
|
||||
**New metadata keys for BitNet models** (written during quantization, read during load):
|
||||
|
||||
```rust
|
||||
// In quantizer when exporting GGUF
|
||||
pub const BITNET_METADATA_KEYS: &[(&str, &str)] = &[
|
||||
("craftsman.bitnet.version", "1"),
|
||||
("craftsman.bitnet.weight_encoding", "absmean_ternary"),
|
||||
("craftsman.bitnet.activation_bits", "8"),
|
||||
("craftsman.bitnet.block_size", "256"),
|
||||
("craftsman.bitnet.kernel_hint", "tl1"), // or "tl2", "i2s"
|
||||
];
|
||||
```
|
||||
|
||||
**Metadata reading in model loader:**
|
||||
|
||||
```rust
|
||||
// In backend when loading model
|
||||
fn detect_bitnet_model(metadata: &HashMap<String, GgufValue>) -> bool {
|
||||
metadata.get("craftsman.bitnet.version")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|v| v == "1")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Tensor Info Extension
|
||||
|
||||
**No changes needed** - existing `TensorInfo` struct in `parser.rs` already supports:
|
||||
- `name: String`
|
||||
- `shape: Vec<usize>`
|
||||
- `dtype: GgufQuantType` ← Will now include `BITNET_T158`
|
||||
- `offset: u64`
|
||||
|
||||
---
|
||||
|
||||
## E. Error Handling Strategy
|
||||
|
||||
### Error Types
|
||||
|
||||
All errors use existing `RuvLLMError` enum from `crates/ruvllm/src/error.rs`:
|
||||
|
||||
```rust
|
||||
pub enum RuvLLMError {
|
||||
// Existing variants...
|
||||
|
||||
// Quantization-specific errors
|
||||
Quantization(String), // Use this variant for all quantization errors
|
||||
Model(String), // For GGUF format issues
|
||||
Config(String), // For invalid configuration
|
||||
}
|
||||
```
|
||||
|
||||
### Error Scenarios and Handling
|
||||
|
||||
| Scenario | Error Type | Recovery Strategy |
|
||||
|----------|-----------|-------------------|
|
||||
| Tensor size not divisible by block_size | `Quantization` | Pad last block with zeros |
|
||||
| Invalid ternary value during packing | `Quantization` | Fail-fast - indicates bug |
|
||||
| GGUF file has wrong BITNET_T158 block size | `Model` | Fail-fast - corrupted file |
|
||||
| Calibration device unavailable | `Config` | Fall back to non-calibrated quantization |
|
||||
| Out of memory during quantization | System panic | Let Rust OOM handler catch |
|
||||
| Shape mismatch in tensor | `Quantization` | Fail-fast - validate before processing |
|
||||
| FP16 scale is NaN/Inf | `Quantization` | Clamp to epsilon value |
|
||||
| Empty tensor / zero elements | `Quantization` | Skip with warning |
|
||||
|
||||
### Validation Functions
|
||||
|
||||
```rust
|
||||
/// Validate quantization config
|
||||
pub fn validate_config(config: &PtBitnetConfig) -> Result<()> {
|
||||
if config.block_size == 0 || config.block_size % 4 != 0 {
|
||||
return Err(RuvLLMError::Config(
|
||||
"block_size must be non-zero and divisible by 4".into()
|
||||
));
|
||||
}
|
||||
|
||||
if config.epsilon <= 0.0 {
|
||||
return Err(RuvLLMError::Config(
|
||||
"epsilon must be positive".into()
|
||||
));
|
||||
}
|
||||
|
||||
if config.clip_threshold <= 0.0 || config.clip_threshold > 2.0 {
|
||||
return Err(RuvLLMError::Config(
|
||||
"clip_threshold must be in range (0.0, 2.0]".into()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate tensor shape and size
|
||||
pub fn validate_tensor(
|
||||
tensor: &[f32],
|
||||
shape: [usize; 2],
|
||||
block_size: usize,
|
||||
) -> Result<()> {
|
||||
let expected_size = shape[0] * shape[1];
|
||||
|
||||
if tensor.len() != expected_size {
|
||||
return Err(RuvLLMError::Quantization(format!(
|
||||
"Tensor length {} doesn't match shape {:?} (expected {})",
|
||||
tensor.len(), shape, expected_size
|
||||
)));
|
||||
}
|
||||
|
||||
if expected_size % block_size != 0 {
|
||||
// Could pad, but for simplicity require exact multiple
|
||||
return Err(RuvLLMError::Quantization(format!(
|
||||
"Tensor size {} is not divisible by block_size {}",
|
||||
expected_size, block_size
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## F. Testing Strategy
|
||||
|
||||
### Unit Tests
|
||||
|
||||
#### 1. Absmean Quantization Correctness
|
||||
|
||||
**File:** `crates/ruvllm/src/bitnet/tests/quantizer_tests.rs`
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_absmean_ternary_basic() {
|
||||
// Test that absmean correctly quantizes known values
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
// Block with known mean(|x|) = 1.0
|
||||
let block: Vec<f32> = vec![
|
||||
2.0, -2.0, 1.0, -1.0, // gamma = mean(2,2,1,1,...) ≈ 1.0
|
||||
0.5, -0.5, 0.0, 0.0,
|
||||
// ... (pad to 256 elements)
|
||||
];
|
||||
|
||||
let result = absmean_ternary(&block, &config);
|
||||
|
||||
// After normalization: 2.0/1.0 = 2.0 → clamp to 1.0 → round to +1
|
||||
assert_eq!(result.ternary_weights[0], 1); // 2.0 → +1
|
||||
assert_eq!(result.ternary_weights[1], -1); // -2.0 → -1
|
||||
assert_eq!(result.ternary_weights[2], 1); // 1.0 → +1
|
||||
assert_eq!(result.ternary_weights[6], 0); // 0.0 → 0
|
||||
|
||||
assert!(result.scale > 0.9 && result.scale < 1.1); // gamma ≈ 1.0
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_absmean_all_zeros() {
|
||||
let config = PtBitnetConfig::default();
|
||||
let block = vec![0.0; 256];
|
||||
|
||||
let result = absmean_ternary(&block, &config);
|
||||
|
||||
// All zeros → scale = epsilon, all ternary = 0
|
||||
assert_eq!(result.scale, config.epsilon);
|
||||
assert!(result.ternary_weights.iter().all(|&x| x == 0));
|
||||
assert_eq!(result.sparsity, 1.0);
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. Pack/Unpack Round-Trip
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_pack_unpack_roundtrip() {
|
||||
let original = vec![1i8, -1, 0, 1, 0, -1, 1, 0];
|
||||
|
||||
let packed = pack_ternary(&original).unwrap();
|
||||
assert_eq!(packed.len(), 2); // 8 values → 2 bytes
|
||||
|
||||
let unpacked = unpack_ternary(&packed, 8);
|
||||
assert_eq!(unpacked, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pack_invalid_value() {
|
||||
let invalid = vec![1i8, 2, 0]; // 2 is not ternary
|
||||
|
||||
let result = pack_ternary(&invalid);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
```
|
||||
|
||||
#### 3. Tensor Validation
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_validate_tensor_shape_mismatch() {
|
||||
let tensor = vec![1.0; 100];
|
||||
let shape = [10, 11]; // 10*11 = 110 ≠ 100
|
||||
|
||||
let result = validate_tensor(&tensor, shape, 256);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tensor_block_alignment() {
|
||||
let tensor = vec![1.0; 257]; // Not divisible by 256
|
||||
let shape = [1, 257];
|
||||
|
||||
let result = validate_tensor(&tensor, shape, 256);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
```
|
||||
|
||||
### Integration Tests
|
||||
|
||||
#### 4. Full Quantization Pipeline
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_quantize_tensor_full_pipeline() {
|
||||
let config = PtBitnetConfig::default();
|
||||
|
||||
// Create a 512-element tensor (2 blocks)
|
||||
let tensor: Vec<f32> = (0..512).map(|i| (i as f32) / 512.0).collect();
|
||||
let shape = [2, 256];
|
||||
|
||||
let (ternary, stats) = quantize_tensor(&tensor, shape, &config).unwrap();
|
||||
|
||||
assert_eq!(ternary.num_blocks, 2);
|
||||
assert_eq!(ternary.packed_data.len(), 2 * 64); // 2 blocks × 64 bytes
|
||||
assert_eq!(ternary.scales.len(), 2);
|
||||
assert_eq!(stats.num_blocks, 2);
|
||||
|
||||
// Verify reconstruction quality
|
||||
let reconstructed = ternary.to_fp32();
|
||||
assert_eq!(reconstructed.len(), 512);
|
||||
}
|
||||
```
|
||||
|
||||
#### 5. GGUF Round-Trip
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_gguf_serialization_roundtrip() {
|
||||
let config = PtBitnetConfig::default();
|
||||
let tensor = vec![1.0; 256];
|
||||
let shape = [1, 256];
|
||||
|
||||
let (ternary, _) = quantize_tensor(&tensor, shape, &config).unwrap();
|
||||
|
||||
// Serialize to GGUF format
|
||||
let gguf_data = ternary.to_gguf_data();
|
||||
assert_eq!(gguf_data.len(), 66); // 1 block = 66 bytes
|
||||
|
||||
// Deserialize
|
||||
let recovered = TernaryTensor::from_gguf_data(&gguf_data, shape, 256).unwrap();
|
||||
|
||||
assert_eq!(recovered.packed_data, ternary.packed_data);
|
||||
assert_eq!(recovered.scales, ternary.scales);
|
||||
}
|
||||
```
|
||||
|
||||
### Benchmark Tests
|
||||
|
||||
#### 6. Performance Regression
|
||||
|
||||
```rust
|
||||
#[bench]
|
||||
fn bench_absmean_ternary_256(b: &mut Bencher) {
|
||||
let config = PtBitnetConfig::default();
|
||||
let block: Vec<f32> = (0..256).map(|i| (i as f32) / 256.0).collect();
|
||||
|
||||
b.iter(|| {
|
||||
let _ = absmean_ternary(&block, &config);
|
||||
});
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn bench_pack_ternary_1024(b: &mut Bencher) {
|
||||
let values = vec![1i8; 1024];
|
||||
|
||||
b.iter(|| {
|
||||
let _ = pack_ternary(&values);
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
### Correctness Validation Tests
|
||||
|
||||
#### 7. Bit-Exact Validation Against Reference
|
||||
|
||||
```rust
|
||||
#[test]
|
||||
fn test_dequantize_matches_reference() {
|
||||
// Reference implementation (naive)
|
||||
fn reference_dequant(ternary: &[i8], scale: f32) -> Vec<f32> {
|
||||
ternary.iter().map(|&t| (t as f32) * scale).collect()
|
||||
}
|
||||
|
||||
let config = PtBitnetConfig::default();
|
||||
let tensor = vec![1.5, -2.3, 0.1, -0.4]; // Extend to 256
|
||||
let tensor_256 = /* pad to 256 */;
|
||||
let shape = [1, 256];
|
||||
|
||||
let (ternary, _) = quantize_tensor(&tensor_256, shape, &config).unwrap();
|
||||
|
||||
// Unpack and dequantize
|
||||
let unpacked = unpack_ternary(&ternary.packed_data, 256);
|
||||
let reference = reference_dequant(&unpacked, ternary.scales[0].to_f32());
|
||||
let optimized = ternary.to_fp32();
|
||||
|
||||
// Allow small floating-point error
|
||||
for (r, o) in reference.iter().zip(optimized.iter()) {
|
||||
assert!((r - o).abs() < 1e-5);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Test Organization
|
||||
|
||||
```
|
||||
crates/ruvllm/src/bitnet/tests/
|
||||
├── quantizer_tests.rs # absmean, pack/unpack
|
||||
├── tensor_tests.rs # TernaryTensor validation
|
||||
├── dequantize_tests.rs # BITNET_T158 dequant
|
||||
├── integration_tests.rs # Full pipeline, GGUF round-trip
|
||||
└── benches.rs # Performance benchmarks
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## G. Implementation Phases
|
||||
|
||||
### Phase 0.1: Core Data Structures (~2-3 days)
|
||||
1. `bitnet/mod.rs` - module structure
|
||||
2. `bitnet/config.rs` - `PtBitnetConfig`
|
||||
3. `bitnet/ternary_tensor.rs` - `TernaryTensor`, `TernaryBlock`
|
||||
4. Unit tests for validation
|
||||
|
||||
### Phase 0.2: Quantization Algorithm (~3-4 days)
|
||||
1. `bitnet/quantizer.rs` - `absmean_ternary()`
|
||||
2. Pack/unpack functions
|
||||
3. `quantize_tensor()` main entry point
|
||||
4. Unit tests for correctness
|
||||
|
||||
### Phase 0.3: Dequantization (~2 days)
|
||||
1. `bitnet/dequantize.rs` - block and tensor dequant
|
||||
2. Integration with existing `quantization.rs`
|
||||
3. Round-trip tests
|
||||
|
||||
### Phase 0.4: GGUF Integration (~2-3 days)
|
||||
1. Modify `gguf/quantization.rs` - add `BITNET_T158` enum variant
|
||||
2. Add metadata keys
|
||||
3. GGUF serialization/deserialization
|
||||
4. Integration tests
|
||||
|
||||
### Phase 0.5: Validation & Benchmarks (~2 days)
|
||||
1. Full pipeline integration tests
|
||||
2. Performance benchmarks
|
||||
3. Bit-exact validation
|
||||
4. Documentation
|
||||
|
||||
**Total Estimated Effort:** ~13-16 days for clean, well-tested implementation
|
||||
|
||||
---
|
||||
|
||||
## H. Open Design Questions
|
||||
|
||||
| # | Question | Impact | Recommendation |
|
||||
|---|----------|--------|----------------|
|
||||
| 1 | Use `IQ1_S` (type 19) or new `BITNET_T158` (type 30)? | Compatibility | **New type 30** - cleaner separation, avoids confusion with IQ1_S's codebook format |
|
||||
| 2 | Padding strategy for last block if not aligned? | Correctness | **Zero-pad** - simplest, matches BitNet spec |
|
||||
| 3 | Should calibration be mandatory or optional? | Quality vs Speed | **Optional** - Phase 0 can work without it, add later if needed |
|
||||
| 4 | F16 or F32 for internal scale computation? | Precision | **F32 internally, store as F16** - extra precision during compute |
|
||||
| 5 | Handle NaN/Inf in input tensors? | Robustness | **Fail-fast** - corrupted weights should not be silently ignored |
|
||||
| 6 | Support block sizes other than 256? | Flexibility | **No** - BitNet spec is 256, simplifies code |
|
||||
| 7 | Multi-threading for per-block quantization? | Performance | **Not in Phase 0** - can add via rayon later |
|
||||
| 8 | Store sparsity per-block in GGUF? | Kernel optimization | **No** - compute on-the-fly during dequant, saves space |
|
||||
|
||||
---
|
||||
|
||||
## I. Dependencies and Prerequisites
|
||||
|
||||
### Existing RuvLLM Components (Reused)
|
||||
- `crates/ruvllm/src/error.rs` - `RuvLLMError` enum
|
||||
- `crates/ruvllm/src/gguf/parser.rs` - GGUF parsing (unchanged)
|
||||
- `crates/ruvllm/src/gguf/quantization.rs` - Enum + dispatch (modified)
|
||||
- `half` crate - FP16 support (already in Cargo.toml)
|
||||
|
||||
### New External Dependencies
|
||||
None - uses only existing dependencies
|
||||
|
||||
### Minimum Rust Version
|
||||
Same as RuvLLM (likely 1.70+)
|
||||
|
||||
---
|
||||
|
||||
## J. Non-Goals (Out of Scope)
|
||||
|
||||
1. **Calibration implementation** - Deferred to future phase
|
||||
2. **TL1/TL2 kernel implementation** - Separate ADR/DDD
|
||||
3. **Model loader integration** - Separate backend implementation
|
||||
4. **Performance optimization** - Phase 0 is correctness-first
|
||||
5. **WASM support** - Desktop/server only for Phase 0
|
||||
6. **Dynamic quantization** - Only post-training static
|
||||
7. **Mixed-precision strategies** - All-or-nothing ternary for Phase 0
|
||||
|
||||
---
|
||||
|
||||
## K. Success Criteria
|
||||
|
||||
**This design is complete when:**
|
||||
|
||||
1. All struct definitions have complete field specifications
|
||||
2. All function signatures are documented with arguments, returns, errors
|
||||
3. Module organization is clear and follows Rust conventions
|
||||
4. GGUF integration points are precisely specified
|
||||
5. Error handling covers all failure modes
|
||||
6. Test plan covers correctness, integration, and performance
|
||||
7. Implementation phases are realistic and sequenced
|
||||
8. Open questions are documented with recommendations
|
||||
|
||||
**Implementation is successful when:**
|
||||
|
||||
1. All unit tests pass
|
||||
2. Round-trip GGUF serialization is bit-exact
|
||||
3. Dequantization produces correct FP32 output
|
||||
4. Integration with existing GGUF pipeline works
|
||||
5. Quantization of GLM-4.7-Flash completes without errors
|
||||
6. Exported GGUF file is loadable by model loader
|
||||
|
||||
---
|
||||
|
||||
## Appendix A: Code Size Estimates
|
||||
|
||||
| File | Estimated Lines | Complexity |
|
||||
|------|----------------|------------|
|
||||
| `bitnet/mod.rs` | ~50 | Low |
|
||||
| `bitnet/config.rs` | ~80 | Low |
|
||||
| `bitnet/ternary_tensor.rs` | ~200 | Medium |
|
||||
| `bitnet/quantizer.rs` | ~350 | High |
|
||||
| `bitnet/dequantize.rs` | ~150 | Medium |
|
||||
| `gguf/quantization.rs` (changes) | ~100 | Low |
|
||||
| Tests | ~800 | Medium |
|
||||
| **Total** | **~1,730 lines** | |
|
||||
|
||||
**Comparison to ADR-018 estimate:** ~200-300 lines core quantizer → Actual ~350 lines (reasonable given struct overhead)
|
||||
|
||||
---
|
||||
|
||||
## Appendix B: Memory Layout Examples
|
||||
|
||||
### TernaryBlock Storage (66 bytes)
|
||||
|
||||
```
|
||||
Byte Offset | Content
|
||||
------------|--------
|
||||
0-63 | Packed 2-bit ternary (256 values)
|
||||
64-65 | FP16 scale (little-endian)
|
||||
```
|
||||
|
||||
### 2-Bit Packing Example
|
||||
|
||||
```
|
||||
Values: [+1, -1, 0, +1]
|
||||
Encoding: [10, 00, 01, 10]
|
||||
Packed byte: 10_00_01_10 = 0x86
|
||||
```
|
||||
|
||||
### GGUF Tensor Data Layout
|
||||
|
||||
```
|
||||
[TensorInfo] (in header)
|
||||
name: "model.layers.0.mlp.gate_proj.weight"
|
||||
shape: [4096, 11008]
|
||||
dtype: BITNET_T158 (30)
|
||||
offset: 0x1000
|
||||
|
||||
[Tensor Data] (at offset 0x1000)
|
||||
Block 0: [64 bytes packed][2 bytes scale]
|
||||
Block 1: [64 bytes packed][2 bytes scale]
|
||||
...
|
||||
Block N: [64 bytes packed][2 bytes scale]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**End of Design Document**
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue