rUv 3ed78842dd docs(research): add ultra-low-bit quantization & edge deployment research (#255)
Co-authored-by: Claude <noreply@anthropic.com>
2026-03-12 10:21:30 -04:00


# Quantization-Aware Training (QAT) for Ultra-Low-Bit LLMs
## Abstract
Quantization-aware training (QAT) is the process of training or fine-tuning a
neural network while simulating the effects of quantization in the forward pass.
Unlike post-training quantization (PTQ), QAT allows the model to adapt its weights
to compensate for quantization error, making it essential for ultra-low-bit
(2-bit) precision where PTQ alone degrades reasoning by 30-60%.
This document covers the theory, state-of-the-art implementations, and practical
considerations for integrating QAT into ruvLLM's Rust-based training pipeline.
## 1. QAT Fundamentals
### 1.1 The Core Idea
In standard training, weights are FP32:
```
Forward:  y = W * x            (FP32 multiply)
Backward: dW = dL/dW           (FP32 gradients)
Update:   W = W - lr * dW      (FP32 update)
```
In QAT, we simulate quantization during training:
```
Forward:  W_q = Quantize(W)    (round to low-bit)
          y = W_q * x          (quantized multiply)
Backward: dW = dL/dW via STE   (straight-through estimator)
Update:   W = W - lr * dW      (update LATENT FP32 weights)
```
The key insight: we maintain full-precision "latent" weights that accumulate
gradients, but the model only ever sees the quantized version during forward
passes. This lets the model learn weight configurations that are robust to
quantization.
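A minimal Rust sketch of this loop, assuming a toy linear model `y = W_q . x` with squared-error loss, a fixed 2-bit step `s`, and a clipped STE. `fake_quant`, `train_qat`, and `one_hot_data` are hypothetical helpers, not ruvLLM APIs. The inputs are one-hot so that each weight is supervised independently, which keeps the latent-versus-quantized behaviour easy to follow.

```rust
/// Fake-quantize one latent weight onto the uniform 2-bit grid {-2s, -s, 0, s}.
fn fake_quant(w: f32, s: f32) -> f32 {
    (w / s).round().clamp(-2.0, 1.0) * s
}

/// Toy QAT for y = W_q . x with loss 0.5 * (y - target)^2.
/// Returns (latent FP32 weights, quantized weights used by the forward pass).
fn train_qat(data: &[(Vec<f32>, f32)], s: f32, lr: f32, epochs: usize) -> (Vec<f32>, Vec<f32>) {
    let dim = data.first().map_or(0, |(x, _)| x.len());
    let mut latent = vec![0.0_f32; dim]; // full-precision latent weights
    for _ in 0..epochs {
        for (x, target) in data {
            // Forward: the model only ever sees the quantized weights.
            let w_q: Vec<f32> = latent.iter().map(|&w| fake_quant(w, s)).collect();
            let y: f32 = w_q.iter().zip(x).map(|(w, xi)| w * xi).sum();
            let dl_dy = y - target;
            // Backward: dL/dW_q = dl_dy * x; the clipped STE copies it onto the
            // latent weight while w stays inside the grid range [-2s, s].
            for (w, xi) in latent.iter_mut().zip(x) {
                if (-2.0..=1.0).contains(&(*w / s)) {
                    *w -= lr * dl_dy * xi; // update the LATENT weight
                }
            }
        }
    }
    let quantized = latent.iter().map(|&w| fake_quant(w, s)).collect();
    (latent, quantized)
}

/// One-hot training pairs: sample i supervises weight i with targets[i].
fn one_hot_data(targets: &[f32]) -> Vec<(Vec<f32>, f32)> {
    (0..targets.len())
        .map(|i| {
            let mut x = vec![0.0_f32; targets.len()];
            x[i] = 1.0;
            (x, targets[i])
        })
        .collect()
}
```

With targets `[0.5, -1.0, 0.0, -0.5]`, `s = 0.5`, `lr = 0.1`, and 50 epochs, the quantized weights end up exactly on the targets, while the latent weight behind the `0.5` entry stops near its rounding threshold (between 0.25 and 0.3) instead of at 0.5. Training stops moving a weight as soon as its quantized value is correct.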
### 1.2 Straight-Through Estimator (STE)
The quantization function `Q(w)` is a step function -- non-differentiable.
The straight-through estimator simply passes gradients through unchanged:
```
Forward: w_q = Q(w) (round to nearest quantized value)
Backward: dw = dw_q (gradient passes through Q unchanged)
```
With clipping (recommended for stability), the gradient is zeroed once the latent weight leaves the clipping range:
```
dw = dw_q * 1_{|w| <= clip_val}
```
This works in practice: the STE gradient is biased, but it still tends to move the latent
weights toward configurations that minimize loss under quantization.
### 1.3 Quantization Functions for 2-Bit
**Uniform 2-bit quantization (4 levels):**
```
Q(w) = clip(round(w / s), -2, 1) * s
where s = max(|W|) / 2 (per-channel or per-block scale)
Levels: {-2s, -s, 0, s}
```
**Non-uniform 2-bit (learned centroids):**
```
Q(w) = argmin_{c_i} |w - c_i|,  c_i in {c_0, c_1, c_2, c_3}
Centroids are learned during training.
```
**Ternary (BitNet b1.58 style):**
```
Q(w) = RoundClip(w / (mean(|W|) + eps), -1, 1)
Levels: {-1, 0, +1} times a scale gamma = mean(|W|), computed per tensor (as in BitNet b1.58) or per block
```
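A Rust sketch of the three quantizers, assuming weights arrive as a flat `f32` slice (one channel, group, or block at a time). The function names are hypothetical, and for the non-uniform case the centroids are passed in rather than learned.

```rust
/// Uniform 2-bit: levels {-2s, -s, 0, s} with s = max(|W|) / 2.
fn quantize_uniform_2bit(w: &[f32]) -> Vec<f32> {
    let max_abs = w.iter().fold(0.0_f32, |m, &x| m.max(x.abs()));
    let s = if max_abs > 0.0 { max_abs / 2.0 } else { 1.0 }; // guard all-zero input
    w.iter().map(|&x| (x / s).round().clamp(-2.0, 1.0) * s).collect()
}

/// Non-uniform 2-bit: snap each weight to the nearest of four centroids.
fn quantize_to_centroids(w: &[f32], centroids: &[f32; 4]) -> Vec<f32> {
    w.iter()
        .map(|&x| {
            *centroids
                .iter()
                .min_by(|a, b| (x - **a).abs().total_cmp(&(x - **b).abs()))
                .expect("four centroids")
        })
        .collect()
}

/// Ternary (BitNet b1.58 absmean): integer codes in {-1, 0, +1} plus scale gamma = mean(|W|).
fn quantize_ternary(w: &[f32], eps: f32) -> (Vec<i8>, f32) {
    let gamma = w.iter().map(|x| x.abs()).sum::<f32>() / w.len().max(1) as f32;
    let codes: Vec<i8> = w
        .iter()
        .map(|&x| (x / (gamma + eps)).round().clamp(-1.0, 1.0) as i8)
        .collect();
    (codes, gamma)
}
```

For example, `quantize_uniform_2bit(&[1.0, -1.0, 0.2, -0.6])` uses `s = 0.5` and yields `[0.5, -1.0, 0.0, -0.5]`. The asymmetric grid clips the largest positive weight down to `s`.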
## 2. Two-Stage Reasoning-Oriented QAT (ICLR'26)
### 2.1 Why Standard QAT Fails at 2-Bit for Reasoning
Standard QAT with language modeling loss preserves perplexity but not reasoning:
```
Perplexity (WikiText-2):
  FP16:      5.47
  QAT-2bit:  6.12   (+12% -- acceptable)
GSM8K accuracy:
  FP16:      56.8%
  QAT-2bit:  34.2%  (-40% relative, -22.6 points -- unacceptable)
```
The problem: language modeling loss optimizes for next-token prediction on
fluent text, but reasoning requires preserving structured multi-step computation
that occupies a small fraction of the weight space.
### 2.2 Stage 1: Mixed-Domain Calibration
**Goal**: Initialize the quantization grid so it preserves reasoning-critical
weight regions.
**Algorithm**:
```
Input:  Pre-trained model M,
        calibration datasets D = {D_math, D_code, D_nl, D_reasoning}
Output: Per-layer quantization parameters (scales, zero-points, centroids)

for each layer L in M:
    # Accumulate per-channel Fisher information over all domains
    F = zeros(num_channels(L.weight))
    n = 0
    for domain in D:
        for (x, y) in sample(domain, 1024):          # 1024 samples per domain
            g = grad(log p(y | x; M), L.weight)      # forward + backward pass
            F += per-channel sum of g^2
            n += 1
    F = F / n                                        # F_c ~= E[||d log p(y|x) / dW_c||^2]

    # Sensitivity-weighted centroid initialization:
    # every weight inherits its channel's F_c as its importance
    weights_flat = L.weight.flatten()
    importance = repeat_per_channel(F, L.weight.shape).flatten()

    # K-means with importance weighting
    centroids = weighted_kmeans(
        weights_flat,
        k=4,                 # 2-bit = 4 centroids
        weights=importance,
        n_iter=100,
    )
    L.quant_params = QuantParams(centroids=centroids,
                                 scales=per_channel_scales(L.weight))
```
**Key design decisions:**
1. **Mixed-domain data** -- math and code calibration data ensures reasoning
weight regions are properly characterized.
2. **Fisher information weighting** -- weights in channels with high Fisher
information (large expected squared gradients) get more influence on centroid placement.
3. **Per-channel granularity** -- different channels have different distributions;
shared quantization grids lose precision.
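The `weighted_kmeans` step can be made concrete as a 1-D, importance-weighted Lloyd's algorithm. A minimal Rust sketch, assuming a row-major weight matrix whose rows are output channels and an evenly spaced initialization (both assumptions; the ICLR'26 recipe may initialize differently). `expand_channel_importance` is a hypothetical helper that gives every weight its channel's Fisher score F_c.

```rust
/// Expand per-channel Fisher scores F_c to per-weight importances
/// for a row-major [channels x cols] weight matrix.
fn expand_channel_importance(fisher: &[f32], cols: usize) -> Vec<f32> {
    fisher.iter().flat_map(|&f| std::iter::repeat(f).take(cols)).collect()
}

/// Importance-weighted 1-D k-means (Lloyd's algorithm).
fn weighted_kmeans(values: &[f32], weights: &[f32], k: usize, n_iter: usize) -> Vec<f32> {
    assert_eq!(values.len(), weights.len());
    assert!(k >= 2 && !values.is_empty());
    // Deterministic init: k centroids spread evenly over [min, max].
    let lo = values.iter().copied().fold(f32::INFINITY, f32::min);
    let hi = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut centroids: Vec<f32> =
        (0..k).map(|j| lo + (hi - lo) * j as f32 / (k - 1) as f32).collect();
    for _ in 0..n_iter {
        let (mut num, mut den) = (vec![0.0_f64; k], vec![0.0_f64; k]);
        // Assignment step: each value joins its nearest centroid.
        for (&v, &w) in values.iter().zip(weights) {
            let j = nearest(&centroids, v);
            num[j] += f64::from(w) * f64::from(v);
            den[j] += f64::from(w);
        }
        // Update step: importance-weighted mean; empty clusters keep their centroid.
        for j in 0..k {
            if den[j] > 0.0 {
                centroids[j] = (num[j] / den[j]) as f32;
            }
        }
    }
    centroids
}

/// Index of the centroid closest to v.
fn nearest(centroids: &[f32], v: f32) -> usize {
    let mut best = 0;
    for (j, &c) in centroids.iter().enumerate() {
        if (v - c).abs() < (v - centroids[best]).abs() {
            best = j;
        }
    }
    best
}
```

Raising one weight's importance pulls its centroid toward it: with values `[-1.0, -0.9, ...]` and importances `[1.0, 3.0, ...]`, the lowest centroid settles at -0.925 rather than the unweighted -0.95.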
### 2.3 Stage 2: Teacher-Guided Fine-Tuning
**Goal**: Fine-tune the quantized model using a full-precision teacher to
preserve reasoning behavior.
**Loss function:**
```
L_total = alpha * L_task + beta * L_KD + gamma * L_reasoning

where:
  L_task      = CrossEntropy(student_logits, targets)
  L_KD        = KL(softmax(student/T), softmax(teacher/T)) * T^2
  L_reasoning = sum_i KL(P_s(step_i | context), P_t(step_i | context))
```
**Hyperparameters (from ICLR'26 paper):**
```
alpha  = 1.0    (task loss weight)
beta   = 0.5    (general distillation weight)
gamma  = 2.0    (reasoning distillation weight -- intentionally high)
T      = 4.0    (distillation temperature)
lr     = 1e-5   (learning rate -- low for fine-tuning)
epochs = 3      (sufficient with good initialization from Stage 1)
```
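A Rust sketch of `L_total` for a single token position, using the default hyperparameters above. It keeps the KL argument order exactly as written (student distribution first); classic Hinton-style distillation uses KL(P_t || P_s), which is a one-line swap of the `kl` arguments. Representing each reasoning step as its own logit vector evaluated at T = 1 is an assumption, and all names are hypothetical.

```rust
/// Temperature-scaled softmax, shifted by the max logit for numerical stability.
fn softmax_t(logits: &[f32], t: f32) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| ((z - max) / t).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// KL(p || q) = sum_i p_i * ln(p_i / q_i); terms with p_i = 0 contribute nothing.
fn kl(p: &[f32], q: &[f32]) -> f32 {
    p.iter()
        .zip(q)
        .filter(|(pi, _)| **pi > 0.0)
        .map(|(pi, qi)| pi * (pi / qi.max(1e-12)).ln())
        .sum()
}

/// Loss weights and distillation temperature.
struct LossWeights {
    alpha: f32,
    beta: f32,
    gamma: f32,
    t: f32,
}

impl Default for LossWeights {
    fn default() -> Self {
        Self { alpha: 1.0, beta: 0.5, gamma: 2.0, t: 4.0 }
    }
}

/// L_total for one position: task CE + temperature-scaled KD + per-step reasoning KL.
fn composite_loss(
    student_logits: &[f32],
    teacher_logits: &[f32],
    target: usize,
    student_steps: &[Vec<f32>], // logits for each reasoning step
    teacher_steps: &[Vec<f32>],
    w: &LossWeights,
) -> f32 {
    let l_task = -softmax_t(student_logits, 1.0)[target].max(1e-12).ln();
    let l_kd = kl(&softmax_t(student_logits, w.t), &softmax_t(teacher_logits, w.t)) * w.t * w.t;
    let l_reasoning: f32 = student_steps
        .iter()
        .zip(teacher_steps)
        .map(|(s, t)| kl(&softmax_t(s, 1.0), &softmax_t(t, 1.0)))
        .sum();
    w.alpha * l_task + w.beta * l_kd + w.gamma * l_reasoning
}
```

When student and teacher agree, both KL terms vanish and the loss reduces to `alpha * L_task`; any disagreement only adds to it. The T^2 factor keeps the KD gradient magnitude comparable across temperatures.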
**Training loop pseudocode:**
```
for epoch in range(epochs):                      # epochs = 3
    for batch in training_data:
        # Teacher forward (no grad)
        with no_grad():
            teacher_logits = teacher.forward(batch.tokens)
            teacher_reasoning = teacher.forward(batch.reasoning_prompts)

        # Student forward (weights fake-quantized on the fly)
        student.apply_quantization()
        student_logits = student.forward(batch.tokens)
        student_reasoning = student.forward(batch.reasoning_prompts)

        # Composite loss
        loss = compute_composite_loss(
            student_logits, teacher_logits,
            student_reasoning, teacher_reasoning,
            batch.targets, alpha, beta, gamma, T,
        )

        # Backward through STE: gradients reach the latent FP32 weights
        loss.backward()

        # Update latent weights
        optimizer.step()
        optimizer.zero_grad()
```
### 2.4 Calibration Dataset Composition
The ICLR'26 paper uses:

| Domain | Dataset | Samples | Purpose |
|--------|---------|---------|---------|
| Math | GSM8K train + MATH train | 15K | Arithmetic reasoning |
| Code | HumanEval + MBPP | 5K | Structured reasoning |
| Language | C4 / RedPajama subset | 20K | General fluency |
| Reasoning | ARC + HellaSwag | 10K | Common-sense reasoning |

Total: ~50K calibration samples for Stage 1, ~100K for Stage 2 fine-tuning.
## 3. Meta's LLM-QAT Framework
### 3.1 Architecture
Meta's LLM-QAT provides a reusable training loop with three quantization targets:
1. **Weight quantization**: Standard per-channel or per-group quantization
2. **Activation quantization**: Per-tensor dynamic quantization
3. **KV-cache quantization**: Unique contribution -- quantize cached keys/values
### 3.2 KV-Cache Quantization
This is particularly relevant for long-context edge inference:
```
Standard KV cache (FP16):
  Memory per token = 2 * n_layers * n_kv_heads * head_dim * 2 bytes
  LLaMA-7B:   2 * 32 * 32 * 128 * 2   = 524,288 bytes (512 KiB) per token
  4K context: 4096 * 512 KiB          = 2 GiB just for KV cache

QAT-trained KV cache (INT4, ignoring scale/zero-point overhead):
  Memory per token = 2 * n_layers * n_kv_heads * head_dim * 0.5 bytes
  LLaMA-7B:   2 * 32 * 32 * 128 * 0.5 = 131,072 bytes (128 KiB) per token
  4K context: 4096 * 128 KiB          = 512 MiB for KV cache (4x reduction)
```
ruvLLM's two-tier KV cache (`kv_cache.rs`) already does FP16+Q4 tiering.
LLM-QAT shows that training with Q4 KV from the start yields better quality
than post-hoc compression.
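The arithmetic above as a small Rust helper. The function names are illustrative, and like the figures above it ignores the per-group scale and zero-point storage that a real INT4 cache needs.

```rust
/// KV-cache bytes per token: K and V (factor 2) for every layer, KV head and
/// head-dim element, at `bits` per element. Scale/zero-point overhead is ignored.
fn kv_bytes_per_token(n_layers: u64, n_kv_heads: u64, head_dim: u64, bits: u64) -> u64 {
    2 * n_layers * n_kv_heads * head_dim * bits / 8
}

/// Total KV-cache bytes for a context of `tokens` tokens.
fn kv_cache_bytes(tokens: u64, n_layers: u64, n_kv_heads: u64, head_dim: u64, bits: u64) -> u64 {
    tokens * kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bits)
}
```

For LLaMA-7B (`32, 32, 128`), 16 bits gives 524,288 bytes per token and 2 GiB at 4096 tokens; 4 bits gives 131,072 bytes and 512 MiB. Models with grouped-query attention pass their smaller `n_kv_heads`, which shrinks the cache proportionally.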
### 3.3 Training Configuration
```python
# Illustrative LLM-QAT-style settings (field names are schematic)
config = QATConfig(
    weight_bits=4,                  # or 2 for ultra-low-bit
    activation_bits=8,              # keep activations higher precision
    kv_cache_bits=4,                # or 2 for aggressive compression
    weight_quant="per_channel",
    act_quant="per_tensor_dynamic",
    kv_quant="per_head",
    use_ste=True,
    clip_ratio=1.0,
    num_calibration_batches=128,
)
```
## 4. ParetoQ: Multi-Objective Ultra-Low-Bit
### 4.1 Core Idea
ParetoQ treats bit-width allocation as a multi-objective optimization problem:
- **Objective 1**: Minimize model size (total bits)
- **Objective 2**: Minimize task loss (quality)
- **Objective 3**: Minimize inference latency
Different layers have different sensitivity to quantization. ParetoQ
finds Pareto-optimal configurations:
```
Layer Type       Sensitivity    Recommended Bits
------------------------------------------------
Embedding        Low            2-3 bits
Attention Q/K    High           3-4 bits
Attention V/O    Medium         2-3 bits
FFN Gate/Up      Medium         2-3 bits
FFN Down         High           3-4 bits
LM Head          Very High      4-8 bits
```
### 4.2 Mixed-Precision Assignment
ParetoQ uses reinforcement learning to search the bit-width space:
```
State: Current bit-width assignment for all layers
Action: Increase/decrease bits for a specific layer
Reward: -alpha * size_increase + beta * quality_improvement
```
This produces mixed-precision models where critical layers get more bits
while less sensitive layers are aggressively compressed.
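ParetoQ's search is described above as reinforcement learning. The Rust sketch below swaps in a much simpler greedy allocator to illustrate the same size-versus-quality trade-off: every layer starts at the minimum bit-width, and the next bit goes to the layer with the largest proxy-quality gain per extra storage bit until the average-bit budget is spent. The sensitivity scores and the noise proxy `sensitivity * 4^-bits` are assumptions, and all names are hypothetical.

```rust
/// One layer for bit allocation; `sensitivity` is a hypothetical score
/// (for example, summed Fisher information from calibration).
struct Layer {
    params: u64,
    sensitivity: f32,
}

/// Proxy quality cost: uniform-quantization noise power falls roughly as 4^-bits,
/// scaled by the layer's sensitivity. An assumption for illustration only.
fn cost(layer: &Layer, bits: u32) -> f32 {
    layer.sensitivity * 4f32.powi(-(bits as i32))
}

/// Greedy mixed-precision allocation under a parameter-weighted average-bit budget.
fn allocate_bits(layers: &[Layer], min_bits: u32, max_bits: u32, avg_budget: f64) -> Vec<u32> {
    let total_params: u64 = layers.iter().map(|l| l.params).sum();
    let budget = (avg_budget * total_params as f64).floor() as u64; // storage bits allowed
    let mut bits = vec![min_bits; layers.len()];
    let mut used = total_params * u64::from(min_bits);
    loop {
        // Best cost reduction per extra storage bit among layers that can still grow.
        let best = layers
            .iter()
            .enumerate()
            .filter(|(i, l)| bits[*i] < max_bits && used + l.params <= budget)
            .map(|(i, l)| (i, (cost(l, bits[i]) - cost(l, bits[i] + 1)) / l.params as f32))
            .max_by(|a, b| a.1.total_cmp(&b.1));
        match best {
            Some((i, _)) => {
                bits[i] += 1;
                used += layers[i].params;
            }
            None => break,
        }
    }
    bits
}

/// Parameter-weighted average bit-width of an allocation.
fn average_bits(layers: &[Layer], bits: &[u32]) -> f64 {
    let total: u64 = layers.iter().map(|l| l.params).sum();
    let weighted: u64 = layers.iter().zip(bits).map(|(l, &b)| l.params * u64::from(b)).sum();
    weighted as f64 / total as f64
}
```

With two equally sized layers whose sensitivities differ by 16x and a 3.0-bit average budget, the sensitive layer receives both extra bits (`[2, 4]`). This is the same qualitative pattern as the table above, where high-sensitivity layers get more bits.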
### 4.3 Results
```
Model: LLaMA-7B, Target: 2.5 average bits

Method           Avg Bits   MMLU   GSM8K   Size
-----------------------------------------------------
Uniform 2-bit    2.0        28.7   21.3    1.75 GB
Uniform 3-bit    3.0        41.5   48.2    2.63 GB
ParetoQ mixed    2.5        40.8   45.1    2.19 GB  (best tradeoff)
```
## 5. Straight-Through Estimator Variants
### 5.1 Standard STE
```
Forward: q = round(w / s) * s
Backward: dw = dq (identity)
```
**Problem**: Gradient is biased -- ignores quantization error.
### 5.2 Clipped STE
```
Forward: q = clip(round(w / s), min_q, max_q) * s
Backward: dw = dq * (1 if min_q*s <= w <= max_q*s else 0)
```
**Benefit**: Prevents latent weights from drifting far outside quantization range.
### 5.3 Learned Step Size Quantization (LSQ)
```
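Forward:  q  = round(clip(w / s, -Q_N, Q_P)) * s
Backward: dw = dq                          if -Q_N < w/s < Q_P, else 0
          ds = dq * (round(w/s) - w/s)     if -Q_N < w/s < Q_P
          ds = dq * (-Q_N)                 if w/s <= -Q_N
          ds = dq * Q_P                    if w/s >= Q_P
          (LSQ also scales ds by g = 1 / sqrt(N_W * Q_P), N_W = number of weights)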
```
**Benefit**: Scale factor `s` is learned, adapting to weight distribution.
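### 5.4 EWGS (Element-Wise Gradient Scaling)
```
Backward: dw = dq * (1 + delta * sign(dq) * (w - q))
          (w and q on the quantizer's normalized grid; delta >= 0, delta = 0 is the standard STE)
```
**Benefit**: The scaling depends on where w sits inside its quantization bin. When the
gradient asks q to move to a neighbouring level but w lies on the opposite side of q, the
gradient is amplified so w can actually cross into that level; when w already lies on the
near side, the gradient is damped. This compensates for the STE ignoring discretization error.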
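The backward rules in this section can be compared side by side in a short Rust sketch. `BackwardRule`, `grad_w`, and `lsq_grads` are hypothetical names, distinct from the `SteVariant` enum proposed in Section 7. For EWGS the discretization error is measured in units of the step size `s` (so |w - q| / s <= 0.5); the EWGS paper normalizes its inputs differently, so delta values are not directly interchangeable.

```rust
/// Backward rules from this section (hypothetical enum for illustration).
enum BackwardRule {
    Standard,
    Clipped,
    Ewgs { delta: f32 },
}

/// Forward fake-quant on the integer grid [min_q, max_q] with step size s.
fn fake_quant(w: f32, s: f32, min_q: f32, max_q: f32) -> f32 {
    (w / s).round().clamp(min_q, max_q) * s
}

/// dL/dw for one latent weight, given the upstream gradient dq = dL/dq.
fn grad_w(rule: &BackwardRule, w: f32, dq: f32, s: f32, min_q: f32, max_q: f32) -> f32 {
    let in_range = (min_q * s..=max_q * s).contains(&w);
    match rule {
        BackwardRule::Standard => dq,
        BackwardRule::Clipped => if in_range { dq } else { 0.0 },
        BackwardRule::Ewgs { delta } => {
            if !in_range {
                return 0.0;
            }
            // Discretization error in units of the step size, so |err| <= 0.5.
            let err = (w - fake_quant(w, s, min_q, max_q)) / s;
            dq * (1.0 + delta * dq.signum() * err)
        }
    }
}

/// LSQ gradients (dL/dw, dL/ds); `g` is LSQ's step-size gradient scale,
/// typically 1 / sqrt(N_W * Q_P).
fn lsq_grads(w: f32, dq: f32, s: f32, q_n: f32, q_p: f32, g: f32) -> (f32, f32) {
    let v = w / s;
    if v <= -q_n {
        (0.0, dq * -q_n * g)
    } else if v >= q_p {
        (0.0, dq * q_p * g)
    } else {
        (dq, dq * (v.round() - v) * g)
    }
}
```

With `s = 0.5` on the grid [-2, 1] and `delta = 0.5`, a latent weight `w = 0.1` (quantized to 0) gets its gradient scaled to 1.1 when `dq = 1` pushes q down, and damped to -0.9 when `dq = -1` pushes q up.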
## 6. Practical Implementation Considerations
### 6.1 Memory Requirements
QAT requires more memory than standard training:
```
Standard inference: Model weights (quantized)
PTQ:                Model weights (FP16) + calibration activations
QAT:                Latent weights (FP32) + quantized weights + gradients + optimizer state

Memory for QAT on a 7B model (excluding activations):
  Latent FP32 weights:   28 GB
  Quantized weights:     1.75 GB (2-bit)
  Gradients:             28 GB (FP32)
  Adam optimizer state:  56 GB (2x FP32 for m, v)
  Total:                 ~114 GB
```
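The 7B estimate is bytes-per-parameter arithmetic. A Rust sketch that reproduces it, assuming FP32 latent weights and gradients, two FP32 Adam moments, and decimal gigabytes (1 GB = 1e9 bytes); `QatMemory` is a hypothetical type.

```rust
/// Memory breakdown in GB (1 GB = 1e9 bytes) for full QAT, excluding activations.
struct QatMemory {
    latent_fp32: f64,
    quantized: f64,
    gradients_fp32: f64,
    adam_state: f64,
}

impl QatMemory {
    fn for_params(params: f64, quant_bits: f64) -> Self {
        const GB: f64 = 1e9;
        Self {
            latent_fp32: params * 4.0 / GB,            // 4 bytes per latent weight
            quantized: params * quant_bits / 8.0 / GB, // packed low-bit copy
            gradients_fp32: params * 4.0 / GB,         // one FP32 gradient per weight
            adam_state: params * 8.0 / GB,             // Adam m and v, FP32 each
        }
    }

    fn total(&self) -> f64 {
        self.latent_fp32 + self.quantized + self.gradients_fp32 + self.adam_state
    }
}
```

`QatMemory::for_params(7e9, 2.0).total()` gives 113.75 GB, the ~114 GB above. Activation memory comes on top and grows with batch size and sequence length.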
**Mitigation strategies:**
1. **LoRA-QAT**: Only train low-rank adapters, not full weights (~1% of params)
2. **Gradient checkpointing**: Trade compute for memory
3. **Mixed-precision training**: FP16 gradients where possible
4. **Layer-wise QAT**: Quantize and fine-tune one layer at a time
### 6.2 LoRA-QAT (Recommended for ruvLLM)
Instead of full QAT, train LoRA adapters on top of quantized base weights:
```
Forward:
    W_q = Quantize(W_base)                    # quantize frozen base
    W_effective = W_q + (alpha / r) * B @ A   # add scaled LoRA delta
    y = W_effective @ x
Backward:
    dA, dB = compute_lora_gradients()         # only LoRA params get gradients
    # W_base stays frozen and quantized
```
**Advantages:**
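- Memory: only LoRA params (A, B) need gradients and optimizer state; for a 7B model at rank 16, FP32 Adam state is about 67 MB with adapters on the q/v projections and about 320 MB with adapters on all seven linear projections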
- Speed: much faster convergence (1 epoch vs 3)
- Flexibility: different LoRA adapters for different tasks on same quantized base
- Fits ruvLLM's existing MicroLoRA infrastructure
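A Rust sketch of the LoRA-QAT forward pass and the trainable-parameter count, assuming the quantized base has already been dequantized to `f32` rows; the function names are hypothetical. It computes B(Ax) instead of materializing B @ A, so the adapter adds only O(r * (d_in + d_out)) work per token.

```rust
/// Dot product of two equal-length slices.
fn dot(row: &[f32], v: &[f32]) -> f32 {
    row.iter().zip(v).map(|(p, q)| p * q).sum()
}

/// y = W_q x + (alpha / r) * B (A x), where W_q (d_out x d_in) is the frozen,
/// dequantized base, A is r x d_in and B is d_out x r.
fn lora_qat_forward(w_q: &[Vec<f32>], a: &[Vec<f32>], b: &[Vec<f32>], alpha: f32, x: &[f32]) -> Vec<f32> {
    let r = a.len() as f32;
    let ax: Vec<f32> = a.iter().map(|row| dot(row, x)).collect(); // A x, length r
    w_q.iter()
        .zip(b)
        .map(|(w_row, b_row)| dot(w_row, x) + (alpha / r) * dot(b_row, &ax))
        .collect()
}

/// Trainable LoRA parameters: r * (d_in + d_out) per adapted (d_out, d_in) matrix.
fn lora_params(shapes: &[(u64, u64)], r: u64) -> u64 {
    shapes.iter().map(|&(d_out, d_in)| r * (d_in + d_out)).sum()
}
```

With the LLaMA-7B q/v projection shapes (4096 x 4096 in each of 32 layers), rank 16 gives 8,388,608 trainable parameters, about 67 MB of FP32 Adam state (two moments at 4 bytes each).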
### 6.3 Training Data Quality
For reasoning-preserving QAT, training data must include:
1. **Chain-of-thought examples**: Step-by-step reasoning traces
2. **Multi-turn dialogues**: Context-dependent reasoning
3. **Code generation**: Structured output with syntax constraints
4. **Mathematical proofs**: Formal logical sequences
ruvLLM's `training/` directory already has dataset generators for
tool use (`tool_dataset.rs`) and Claude-style data (`claude_dataset.rs`).
A reasoning-focused dataset generator is needed.
## 7. Rust Implementation Strategy
### 7.1 Core QAT Types
```rust
use std::path::PathBuf;

/// Configuration for quantization-aware training
#[derive(Debug, Clone)]
pub struct QatConfig {
    /// Target bit-width for weights
    pub weight_bits: u8, // 2, 3, 4, or 8
    /// Target bit-width for KV cache
    pub kv_cache_bits: u8, // 2, 4, or 8
    /// Quantization granularity
    pub granularity: QuantGranularity, // PerTensor, PerChannel, PerGroup(n), PerBlock(n)
    /// STE variant
    pub ste_variant: SteVariant, // Standard, Clipped, LSQ, EWGS
    /// Clip ratio for clipped STE
    pub clip_ratio: f32, // default 1.0
    /// Whether to use mixed-precision (ParetoQ-style)
    pub mixed_precision: bool,
    /// Teacher model path (for distillation)
    pub teacher_model: Option<PathBuf>,
    /// Distillation temperature
    pub temperature: f32, // default 4.0
    /// Loss weights
    pub alpha_task: f32,      // default 1.0
    pub beta_kd: f32,         // default 0.5
    pub gamma_reasoning: f32, // default 2.0
}

/// STE variant for backward pass
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SteVariant {
    Standard,
    Clipped { clip_val: f32 },
    LearnedStepSize,
    /// EWGS (element-wise gradient scaling); `lambda` is the scaling factor
    ElasticWeightGradient { lambda: f32 },
}

/// Quantization granularity
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantGranularity {
    PerTensor,
    PerChannel,
    PerGroup(usize), // group size (e.g., 128 for GPTQ-style)
    PerBlock(usize), // block size (e.g., 256 for K-quants)
}
```
### 7.2 Integration Points with ruvLLM
```
Existing module            QAT integration needed
------------------------------------------------------------------
quantize/ruvltra_quant     Add differentiable quantize/dequantize
bitnet/quantizer           Add STE backward pass
training/real_trainer      Add QAT training loop
training/grpo              Support quantized policy model
lora/micro_lora            LoRA-QAT: adapt on quantized base
sona/integration           Post-deployment continual QAT
kv_cache                   Trainable KV quantization parameters
backends/candle_backend    Forward pass with simulated quantization
```
### 7.3 Gradient Computation through Quantization
The critical Rust implementation -- differentiable quantization:
```rust
/// Differentiable 2-bit quantization with a clipped STE.
///
/// Returns the packed 2-bit codes (4 per byte, LSB first) and the
/// gradient with respect to the latent FP32 weights.
pub fn quantize_2bit_ste(
    weights: &[f32],     // latent FP32 weights
    scale: f32,          // quantization scale (must be > 0)
    grad_output: &[f32], // upstream gradient dL/dW_q
) -> (Vec<u8>, Vec<f32>) { // (quantized, weight_gradients)
    assert_eq!(weights.len(), grad_output.len());
    assert!(scale > 0.0);
    let mut quantized = vec![0u8; weights.len().div_ceil(4)];
    let mut grad_input = Vec::with_capacity(weights.len());
    for (i, (&w, &g)) in weights.iter().zip(grad_output).enumerate() {
        // Forward: quantize to {-2, -1, 0, 1} * scale
        let w_scaled = w / scale;
        let q = w_scaled.round().clamp(-2.0, 1.0);
        // Backward: clipped STE, pass the gradient only inside [-2 * scale, 1 * scale]
        let in_range = (-2.0..=1.0).contains(&w_scaled);
        grad_input.push(if in_range { g } else { 0.0 });
        // Pack: map q in {-2, -1, 0, 1} to code {0, 1, 2, 3}, 4 codes per byte
        let code = (q as i8 + 2) as u8;
        quantized[i / 4] |= code << (2 * (i % 4));
    }
    (quantized, grad_input)
}
```
## 8. Open Questions
1. **STE bias at 2-bit**: The STE gradient bias is larger at lower bit-widths.
Does EWGS or LSQ compensate sufficiently, or do we need custom estimators?
2. **Calibration data size**: Is 50K samples enough for mixed-domain calibration?
Larger calibration may help but increases Stage 1 cost.
3. **LoRA rank for QAT**: What LoRA rank preserves reasoning at 2-bit?
ruvLLM's MicroLoRA uses rank-1; QAT may need rank-8 to rank-16.
4. **Continual QAT**: Can SONA's three-tier learning loop perform incremental
QAT, adjusting quantization parameters as the model adapts?
5. **Hardware-specific grids**: Should quantization grids be optimized for
specific hardware (ANE tile sizes, NEON SIMD widths)?