feat(adr-189/190): IncrementalLandmarks + decode_batch + parallel feature

- IncrementalLandmarks: Welford O(H×D) online mean update per append replaces
  O(T×H×D) Landmarks::from_kv rebuild in decode_step — per-token landmark cost no longer grows with T
- KvCache: add block_size param, try_append (non-panicking), is_full, reset,
  append_all (bulk prefill load with landmark update)
- decode_step: require the caller to append K/V before decoding (i = cache.len-1, seq = cache.len);
  use cache.landmarks instead of per-step rebuild; empty-cache guard
- decode_batch: speculative-decode support for q.seq >= 1; appends tokens
  incrementally, correct landmark state per draft token
- parallel feature: optional rayon head-parallel forward() path (~4× prefill
  speedup on multi-core); serial path remains zero-dep by default
- 21 tests pass (serial + parallel features), 4 new tests:
  incremental_landmarks_match_static, try_append_at_capacity_returns_error,
  kv_cache_reset_clears_state, decode_batch_shape_and_matches_sequential_decode_steps

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruvnet 2026-05-06 12:33:31 -04:00
parent 259c289651
commit 4db35f2802
3 changed files with 486 additions and 119 deletions

1
Cargo.lock generated
View file

@ -10739,6 +10739,7 @@ version = "0.1.0"
dependencies = [
"criterion 0.5.1",
"rand 0.8.5",
"rayon",
]
[[package]]

View file

@ -5,7 +5,13 @@ edition = "2021"
license = "MIT"
description = "Subquadratic sparse attention kernel for ruvllm style inference"
[features]
default = []
# Enable parallel head loops via rayon (~4× prefill speedup on multi-core).
parallel = ["dep:rayon"]
[dependencies]
rayon = { version = "1", optional = true }
[dev-dependencies]
# rand is only used in tests/benchmarks — keep zero runtime dep footprint (ADR-183)

View file

@ -139,98 +139,129 @@ impl AttentionBackend for SubquadraticSparseAttention {
None
};
let mut out = Tensor3::zeros(seq, heads, dim);
let mut seen_tokens = vec![0usize; seq.max(1)];
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), self.config.block_size)];
let mut token_candidates = Vec::<usize>::with_capacity(self.config.window + 64);
let mut block_candidates = Vec::<usize>::with_capacity(64);
let mut acc = vec![0f32; dim];
for h in 0..heads {
for i in 0..seq {
// Stamp is unique per (head, token) pair so that seen_tokens[] and
// seen_blocks[] — allocated once and shared across heads — correctly
// reset deduplication state between heads. Formula: 1 + h*seq + i.
// See also: estimate_sparse_edges(), which uses a per-token stamp
// (i+1) because it has no head loop and estimates per-head edge count.
let stamp = 1 + h * seq + i;
token_candidates.clear();
block_candidates.clear();
build_token_candidates(
i,
seq,
&self.config,
&mut seen_tokens,
stamp,
&mut token_candidates,
);
if landmarks.is_some() {
build_landmark_candidates(
i,
seq,
&self.config,
&mut seen_blocks,
stamp,
&mut block_candidates,
);
}
let q_row = q.row(i, h);
// One-pass online softmax (ADR-184): single traversal over candidates
// using a running max + correction factor. Eliminates two-pass
// dot-product redundancy (~2× FLOPs reduction on Pi 5 NEON paths).
let mut running_max = f32::NEG_INFINITY;
let mut denom = 0.0f32;
acc.fill(0.0);
for &j in &token_candidates {
let score = dot(q_row, k.row(j, h)) * scale;
if score > running_max {
let corr = (running_max - score).exp();
for d in 0..dim {
acc[d] *= corr;
// Parallel path: each head gets its own dedup state — no shared mutation.
#[cfg(feature = "parallel")]
let out = {
use rayon::prelude::*;
let lm_ref = landmarks.as_ref();
let config = &self.config;
let head_vecs: Vec<Vec<f32>> = (0..heads).into_par_iter().map(|h| {
let mut seen_tokens = vec![0usize; seq.max(1)];
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), config.block_size)];
let mut tok_c = Vec::<usize>::with_capacity(config.window + 64);
let mut blk_c = Vec::<usize>::with_capacity(64);
let mut acc = vec![0f32; dim];
let mut hout = vec![0f32; seq * dim];
for i in 0..seq {
let stamp = 1 + h * seq + i;
tok_c.clear(); blk_c.clear();
build_token_candidates(i, seq, config, &mut seen_tokens, stamp, &mut tok_c);
if lm_ref.is_some() {
build_landmark_candidates(i, seq, config, &mut seen_blocks, stamp, &mut blk_c);
}
let q_row = q.row(i, h);
let mut running_max = f32::NEG_INFINITY;
let mut denom = 0.0f32;
acc.fill(0.0);
for &j in &tok_c {
let score = dot(q_row, k.row(j, h)) * scale;
if score > running_max {
let c = (running_max - score).exp();
for d in 0..dim { acc[d] *= c; }
denom *= c; running_max = score;
}
denom *= corr;
running_max = score;
let w = (score - running_max).exp();
denom += w;
let vr = v.row(j, h);
for d in 0..dim { acc[d] += w * vr[d]; }
}
let w = (score - running_max).exp();
denom += w;
let v_row = v.row(j, h);
for d in 0..dim {
acc[d] += w * v_row[d];
if let Some(lm) = lm_ref {
for &b in &blk_c {
let score = dot(q_row, lm.keys.row(b, h)) * scale;
if score > running_max {
let c = (running_max - score).exp();
for d in 0..dim { acc[d] *= c; }
denom *= c; running_max = score;
}
let w = (score - running_max).exp();
denom += w;
let vr = lm.values.row(b, h);
for d in 0..dim { acc[d] += w * vr[d]; }
}
}
let inv = if denom > 0.0 { 1.0 / denom } else { 0.0 };
let s = &mut hout[i * dim..(i + 1) * dim];
for d in 0..dim { s[d] = acc[d] * inv; }
}
hout
}).collect();
let mut out = Tensor3::zeros(seq, heads, dim);
for h in 0..heads {
for i in 0..seq {
out.row_mut(i, h).copy_from_slice(&head_vecs[h][i * dim..(i + 1) * dim]);
}
}
out
};
if let Some(lm) = landmarks.as_ref() {
for &b in &block_candidates {
let score = dot(q_row, lm.keys.row(b, h)) * scale;
// Serial path (default — zero extra deps, works on no_std / WASM).
#[cfg(not(feature = "parallel"))]
let out = {
let mut out = Tensor3::zeros(seq, heads, dim);
let mut seen_tokens = vec![0usize; seq.max(1)];
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), self.config.block_size)];
let mut token_candidates = Vec::<usize>::with_capacity(self.config.window + 64);
let mut block_candidates = Vec::<usize>::with_capacity(64);
let mut acc = vec![0f32; dim];
for h in 0..heads {
for i in 0..seq {
let stamp = 1 + h * seq + i;
token_candidates.clear();
block_candidates.clear();
build_token_candidates(i, seq, &self.config, &mut seen_tokens, stamp, &mut token_candidates);
if landmarks.is_some() {
build_landmark_candidates(i, seq, &self.config, &mut seen_blocks, stamp, &mut block_candidates);
}
let q_row = q.row(i, h);
let mut running_max = f32::NEG_INFINITY;
let mut denom = 0.0f32;
acc.fill(0.0);
for &j in &token_candidates {
let score = dot(q_row, k.row(j, h)) * scale;
if score > running_max {
let corr = (running_max - score).exp();
for d in 0..dim {
acc[d] *= corr;
}
for d in 0..dim { acc[d] *= corr; }
denom *= corr;
running_max = score;
}
let w = (score - running_max).exp();
denom += w;
let v_row = lm.values.row(b, h);
for d in 0..dim {
acc[d] += w * v_row[d];
let v_row = v.row(j, h);
for d in 0..dim { acc[d] += w * v_row[d]; }
}
if let Some(lm) = landmarks.as_ref() {
for &b in &block_candidates {
let score = dot(q_row, lm.keys.row(b, h)) * scale;
if score > running_max {
let corr = (running_max - score).exp();
for d in 0..dim { acc[d] *= corr; }
denom *= corr;
running_max = score;
}
let w = (score - running_max).exp();
denom += w;
let v_row = lm.values.row(b, h);
for d in 0..dim { acc[d] += w * v_row[d]; }
}
}
}
let out_row = out.row_mut(i, h);
let inv_denom = if denom > 0.0 { 1.0 / denom } else { 0.0 };
for d in 0..dim {
out_row[d] = acc[d] * inv_denom;
let out_row = out.row_mut(i, h);
let inv_denom = if denom > 0.0 { 1.0 / denom } else { 0.0 };
for d in 0..dim { out_row[d] = acc[d] * inv_denom; }
}
}
}
out
};
Ok(out)
}
@ -499,6 +530,66 @@ fn div_ceil(a: usize, b: usize) -> usize {
}
}
/// Incrementally maintained landmark block-means for O(1) decode updates (ADR-189).
///
/// Updated via Welford running mean on each `KvCache::append`, eliminating the
/// O(T × kv_heads × dim) full rebuild that `decode_step` previously did per token.
#[derive(Clone, Debug)]
pub struct IncrementalLandmarks {
/// Running block-mean keys [max_blocks, kv_heads, dim]
pub keys: Tensor3,
/// Running block-mean values [max_blocks, kv_heads, dim]
pub values: Tensor3,
/// Number of tokens folded into each block mean so far (one entry per block);
/// used as the divisor of the running-mean update.
counts: Vec<usize>,
/// Tokens per landmark block: token `t` belongs to block `t / block_size`.
/// A value of 0 turns `update` into a no-op.
pub block_size: usize,
}
impl IncrementalLandmarks {
    /// Allocates zeroed landmark storage for a cache of `capacity` tokens grouped
    /// into blocks of `block_size` tokens.
    ///
    /// At least one block is always allocated, so `block_size == 0` or
    /// `capacity == 0` still yields valid (single-block) tensors.
    pub fn new(capacity: usize, block_size: usize, kv_heads: usize, dim: usize) -> Self {
        let max_blocks = if block_size == 0 || capacity == 0 {
            1
        } else {
            div_ceil(capacity, block_size)
        };
        Self {
            keys: Tensor3::zeros(max_blocks, kv_heads, dim),
            values: Tensor3::zeros(max_blocks, kv_heads, dim),
            counts: vec![0; max_blocks],
            block_size,
        }
    }

    /// Welford online mean update for the block containing token `t`. O(H × D).
    ///
    /// Row 0 of `k` and `v` is folded into the running key/value means of block
    /// `t / block_size`.
    ///
    /// The call is a no-op when:
    /// - `block_size` is 0,
    /// - the block index falls outside the allocated blocks, or
    /// - `k` / `v` are empty or their `heads` / `dim` differ from the landmark
    ///   storage (previously this either indexed out of bounds or updated only
    ///   part of the block after the sample count had already been bumped).
    pub fn update(&mut self, t: usize, k: &Tensor3, v: &Tensor3) {
        if self.block_size == 0 {
            return;
        }
        let b = t / self.block_size;
        if b >= self.counts.len() {
            return;
        }

        // Validate before touching `counts` so a rejected update leaves the
        // running means and their divisors consistent.
        let (heads, dim) = (self.keys.heads, self.keys.dim);
        let shape_ok = |x: &Tensor3| x.seq != 0 && x.heads == heads && x.dim == dim;
        if !shape_ok(k) || !shape_ok(v) {
            return;
        }

        self.counts[b] += 1;
        let count = self.counts[b] as f32;
        for h in 0..heads {
            // mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n
            for (mean, &x) in self.keys.row_mut(b, h).iter_mut().zip(k.row(0, h)) {
                *mean += (x - *mean) / count;
            }
            for (mean, &x) in self.values.row_mut(b, h).iter_mut().zip(v.row(0, h)) {
                *mean += (x - *mean) / count;
            }
        }
    }

    /// Reset all block means (does not free memory).
    pub fn reset(&mut self) {
        self.keys.data.fill(0.0);
        self.values.data.fill(0.0);
        self.counts.fill(0);
    }
}
/// KV cache for incremental decode. Stores keys/values for all previous tokens.
/// For GQA/MQA set kv_heads = k.heads (8 for Mistral-7B, not 32).
#[derive(Clone, Debug)]
@ -507,37 +598,121 @@ pub struct KvCache {
pub values: Tensor3,
pub len: usize,
pub capacity: usize,
/// Incrementally maintained block-mean landmarks — updated O(H×D) per append.
pub landmarks: IncrementalLandmarks,
}
impl KvCache {
pub fn new(capacity: usize, kv_heads: usize, dim: usize) -> Self {
pub fn new(capacity: usize, kv_heads: usize, dim: usize, block_size: usize) -> Self {
Self {
keys: Tensor3::zeros(capacity, kv_heads, dim),
values: Tensor3::zeros(capacity, kv_heads, dim),
len: 0,
capacity,
landmarks: IncrementalLandmarks::new(capacity, block_size, kv_heads, dim),
}
}
/// Append one new token's K and V slices (shape [1, kv_heads, dim]).
/// Append one new token's K/V (shape [1, kv_heads, dim]); updates landmarks.
/// Panics on capacity overflow — use `try_append` for recoverable errors.
pub fn append(&mut self, k: &Tensor3, v: &Tensor3) {
assert_eq!(k.seq, 1, "append expects single-token tensors");
assert_eq!(v.seq, 1);
assert!(self.len < self.capacity, "KvCache capacity exceeded");
for h in 0..k.heads {
let dst_k = self.keys.row_mut(self.len, h);
dst_k.copy_from_slice(k.row(0, h));
let dst_v = self.values.row_mut(self.len, h);
dst_v.copy_from_slice(v.row(0, h));
self.try_append(k, v).expect("KvCache capacity exceeded");
}
/// Non-panicking append. Returns `Err` on capacity overflow or shape mismatch.
///
/// `k` and `v` must each hold exactly one token (`seq == 1`) whose `heads` and
/// `dim` equal the cache's. All validation happens before anything is written,
/// so a failed call leaves the cache and its landmarks untouched.
///
/// # Errors
/// `AttentionError::InvalidConfig` when the tensors are not single-token, when
/// their head count or head dimension differs from the cache, or when the cache
/// is already full.
pub fn try_append(&mut self, k: &Tensor3, v: &Tensor3) -> Result<(), AttentionError> {
    if k.seq != 1 || v.seq != 1 {
        return Err(AttentionError::InvalidConfig(
            "try_append expects single-token tensors (seq == 1)".into(),
        ));
    }
    if k.heads != self.keys.heads || v.heads != self.keys.heads {
        return Err(AttentionError::InvalidConfig(format!(
            "kv_heads mismatch: cache={}, k={}, v={}",
            self.keys.heads, k.heads, v.heads
        )));
    }
    // `copy_from_slice` panics on a length mismatch, so the head dimension must
    // be checked here for this method to stay non-panicking.
    if k.dim != self.keys.dim || v.dim != self.keys.dim {
        return Err(AttentionError::InvalidConfig(format!(
            "head dim mismatch: cache={}, k={}, v={}",
            self.keys.dim, k.dim, v.dim
        )));
    }
    if self.len >= self.capacity {
        return Err(AttentionError::InvalidConfig(format!(
            "KvCache capacity exceeded: capacity={}, len={}",
            self.capacity, self.len
        )));
    }

    let pos = self.len;
    for h in 0..k.heads {
        self.keys.row_mut(pos, h).copy_from_slice(k.row(0, h));
        self.values.row_mut(pos, h).copy_from_slice(v.row(0, h));
    }
    // Fold the new token into the running mean of its landmark block.
    self.landmarks.update(pos, k, v);
    self.len = pos + 1;
    Ok(())
}
/// Returns `true` once `len` has reached `capacity`, i.e. `try_append` would fail.
pub fn is_full(&self) -> bool {
    self.capacity <= self.len
}
/// Bulk-load a multi-token K/V tensor (shape [n, kv_heads, dim]) into the cache.
/// Used for the prefill pass to populate the cache from the full prompt.
///
/// Every token is copied to the next free position and, when landmarks are
/// enabled (`block_size > 0`), folded into its landmark block mean. All shape
/// and capacity checks run before the first write, so a failed call leaves the
/// cache unchanged.
///
/// # Errors
/// `AttentionError::InvalidConfig` when `k.seq != v.seq`, when the head count or
/// head dimension differs from the cache, or when the tokens do not fit in the
/// remaining capacity.
pub fn append_all(&mut self, k: &Tensor3, v: &Tensor3) -> Result<(), AttentionError> {
    let n = k.seq;
    if v.seq != n {
        return Err(AttentionError::InvalidConfig(
            "append_all: k.seq != v.seq".into(),
        ));
    }
    if k.heads != self.keys.heads || v.heads != self.keys.heads {
        return Err(AttentionError::InvalidConfig(format!(
            "kv_heads mismatch: cache={}, k={}, v={}",
            self.keys.heads, k.heads, v.heads
        )));
    }
    // `copy_from_slice` below panics on a row-length mismatch; reject it up front.
    if k.dim != self.keys.dim || v.dim != self.keys.dim {
        return Err(AttentionError::InvalidConfig(format!(
            "head dim mismatch: cache={}, k={}, v={}",
            self.keys.dim, k.dim, v.dim
        )));
    }
    // checked_add: an overflowing `len + n` can never fit either.
    let fits = self
        .len
        .checked_add(n)
        .map_or(false, |end| end <= self.capacity);
    if !fits {
        return Err(AttentionError::InvalidConfig(format!(
            "KvCache overflow: capacity={}, len={}, adding={}",
            self.capacity, self.len, n
        )));
    }

    let kv_heads = k.heads;
    let dim = k.dim;
    let row_len = kv_heads * dim;
    let track_landmarks = self.landmarks.block_size > 0;
    for t in 0..n {
        let pos = self.len + t;
        for h in 0..kv_heads {
            self.keys.row_mut(pos, h).copy_from_slice(k.row(t, h));
            self.values.row_mut(pos, h).copy_from_slice(v.row(t, h));
        }
        // `IncrementalLandmarks::update` takes single-token tensors, so token `t`
        // is copied out into owned [1, kv_heads, dim] tensors. The copies are
        // skipped entirely when landmarks are disabled (block_size == 0).
        if track_landmarks {
            let span = t * row_len..(t + 1) * row_len;
            let k_t = Tensor3::from_vec(k.data[span.clone()].to_vec(), 1, kv_heads, dim)
                .expect("token slice holds exactly kv_heads * dim elements");
            let v_t = Tensor3::from_vec(v.data[span].to_vec(), 1, kv_heads, dim)
                .expect("token slice holds exactly kv_heads * dim elements");
            self.landmarks.update(pos, &k_t, &v_t);
        }
    }
    self.len += n;
    Ok(())
}
/// Empties the cache while keeping its allocations: the landmark means are
/// cleared and `len` returns to 0. The key/value buffers are left as they are.
pub fn reset(&mut self) {
    self.landmarks.reset();
    self.len = 0;
}
}
impl SubquadraticSparseAttention {
/// Single-token decode against the KV cache (ADR-189).
/// q: shape [1, q_heads, dim]. cache: keys/values for all prior tokens (kv_heads may differ for GQA).
/// Returns shape [1, q_heads, dim].
///
/// **Caller must append the new token's K/V to `cache` before calling this.**
/// The new token is at position `cache.len - 1`; landmarks are O(1)-updated
/// by `KvCache::append` so no rebuild is needed here.
///
/// `q`: shape `[1, q_heads, dim]`. Returns shape `[1, q_heads, dim]`.
pub fn decode_step(
&self,
q: &Tensor3,
@ -553,6 +728,9 @@ impl SubquadraticSparseAttention {
"head dimension must be greater than zero".to_string(),
));
}
if cache.len == 0 {
return Ok(Tensor3::zeros(1, q.heads, q.dim));
}
if cache.keys.heads == 0 || q.heads % cache.keys.heads != 0 {
return Err(AttentionError::InvalidConfig(format!(
"q_heads={} must be divisible by kv_heads={}",
@ -565,8 +743,9 @@ impl SubquadraticSparseAttention {
let group_size = q_heads / kv_heads;
let dim = q.dim;
let scale = 1.0f32 / (dim as f32).sqrt();
let i = cache.len;
let seq = i + 1;
// New token was appended before this call; its position is cache.len - 1.
let i = cache.len - 1;
let seq = cache.len;
let mut seen_tokens = vec![0usize; seq.max(1)];
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), self.config.block_size)];
@ -574,13 +753,9 @@ impl SubquadraticSparseAttention {
let mut block_candidates = Vec::with_capacity(64);
build_token_candidates(i, seq, &self.config, &mut seen_tokens, 1, &mut token_candidates);
let landmarks = if self.config.use_landmarks {
if self.config.use_landmarks {
build_landmark_candidates(i, seq, &self.config, &mut seen_blocks, 1, &mut block_candidates);
Some(Landmarks::from_kv(&cache.keys, &cache.values, self.config.block_size))
} else {
None
};
}
let mut out = Tensor3::zeros(1, q_heads, dim);
let mut acc = vec![0f32; dim];
@ -606,20 +781,19 @@ impl SubquadraticSparseAttention {
for d in 0..dim { acc[d] += w * v_row[d]; }
}
if let Some(ref lm) = landmarks {
for &b in &block_candidates {
let score = dot(q_row, lm.keys.row(b, kv_h)) * scale;
if score > running_max {
let corr = (running_max - score).exp();
for d in 0..dim { acc[d] *= corr; }
denom *= corr;
running_max = score;
}
let w = (score - running_max).exp();
denom += w;
let v_row = lm.values.row(b, kv_h);
for d in 0..dim { acc[d] += w * v_row[d]; }
// Use O(1) incremental landmarks — no per-step rebuild.
for &b in &block_candidates {
let score = dot(q_row, cache.landmarks.keys.row(b, kv_h)) * scale;
if score > running_max {
let corr = (running_max - score).exp();
for d in 0..dim { acc[d] *= corr; }
denom *= corr;
running_max = score;
}
let w = (score - running_max).exp();
denom += w;
let v_row = cache.landmarks.values.row(b, kv_h);
for d in 0..dim { acc[d] += w * v_row[d]; }
}
let out_row = out.row_mut(0, h);
@ -629,6 +803,73 @@ impl SubquadraticSparseAttention {
Ok(out)
}
/// Decode a batch of draft tokens (speculative decoding).
///
/// Appends each draft token's K/V to `cache` and computes attention for the
/// corresponding query token against the growing cache. Equivalent to calling
/// `cache.try_append` + `decode_step` for each token in sequence.
///
/// `q`, `new_k`, `new_v` must all have the same `seq` (the draft length) and the
/// same head dimension as the cache; `new_k` / `new_v` must have the cache's
/// head count. Returns shape `[draft_len, q_heads, dim]`.
///
/// Shapes and remaining capacity are validated before the first append, so a
/// rejected draft does not leave the cache partially extended.
///
/// # Errors
/// `AttentionError::InvalidConfig` on any shape mismatch or when the draft does
/// not fit in the cache; errors from `decode_step` are propagated.
pub fn decode_batch(
    &self,
    q: &Tensor3,
    new_k: &Tensor3,
    new_v: &Tensor3,
    cache: &mut KvCache,
) -> Result<Tensor3, AttentionError> {
    let draft_len = q.seq;
    if draft_len == 0 {
        return Ok(Tensor3::zeros(0, q.heads, q.dim));
    }
    if q.dim == 0 {
        return Err(AttentionError::InvalidConfig(
            "head dimension must be greater than zero".into(),
        ));
    }
    if new_k.seq != draft_len || new_v.seq != draft_len {
        return Err(AttentionError::InvalidConfig(format!(
            "decode_batch: q.seq={draft_len} but new_k.seq={} new_v.seq={}",
            new_k.seq, new_v.seq
        )));
    }
    if cache.keys.heads == 0 || q.heads % cache.keys.heads != 0 {
        return Err(AttentionError::InvalidConfig(format!(
            "q_heads={} must be divisible by kv_heads={}",
            q.heads, cache.keys.heads
        )));
    }
    if new_k.heads != cache.keys.heads || new_v.heads != cache.keys.heads {
        return Err(AttentionError::InvalidConfig(format!(
            "kv_heads mismatch: cache={}, k={}, v={}",
            cache.keys.heads, new_k.heads, new_v.heads
        )));
    }
    // The per-token slicing below uses one `dim` for q, new_k and new_v, and the
    // cache rows are copied with `copy_from_slice`; all four must agree.
    if new_k.dim != q.dim || new_v.dim != q.dim || cache.keys.dim != q.dim {
        return Err(AttentionError::InvalidConfig(format!(
            "decode_batch: head dim mismatch: q={}, new_k={}, new_v={}, cache={}",
            q.dim, new_k.dim, new_v.dim, cache.keys.dim
        )));
    }
    // Reject an oversized draft before mutating the cache.
    let fits = cache
        .len
        .checked_add(draft_len)
        .map_or(false, |end| end <= cache.capacity);
    if !fits {
        return Err(AttentionError::InvalidConfig(format!(
            "KvCache overflow: capacity={}, len={}, adding={}",
            cache.capacity, cache.len, draft_len
        )));
    }

    let q_heads = q.heads;
    let kv_heads = new_k.heads;
    let dim = q.dim;
    let q_row_len = q_heads * dim;
    let kv_row_len = kv_heads * dim;

    let mut out = Tensor3::zeros(draft_len, q_heads, dim);
    for t in 0..draft_len {
        let q_span = t * q_row_len..(t + 1) * q_row_len;
        let kv_span = t * kv_row_len..(t + 1) * kv_row_len;

        // Extract single-token slices as owned Tensor3 (avoids unsafe borrow aliasing).
        let q_t = Tensor3::from_vec(q.data[q_span.clone()].to_vec(), 1, q_heads, dim)
            .expect("query slice holds exactly q_heads * dim elements");
        let k_t = Tensor3::from_vec(new_k.data[kv_span.clone()].to_vec(), 1, kv_heads, dim)
            .expect("key slice holds exactly kv_heads * dim elements");
        let v_t = Tensor3::from_vec(new_v.data[kv_span].to_vec(), 1, kv_heads, dim)
            .expect("value slice holds exactly kv_heads * dim elements");

        // decode_step expects the token's K/V to already be in the cache.
        cache.try_append(&k_t, &v_t)?;
        let out_t = self.decode_step(&q_t, cache)?;
        out.data[q_span].copy_from_slice(&out_t.data);
    }
    Ok(out)
}
}
fn validate_gqa(q: &Tensor3, k: &Tensor3, v: &Tensor3) -> Result<(), AttentionError> {
@ -1025,8 +1266,7 @@ mod tests {
// --- ADR-189 KV cache tests ---
#[test]
fn decode_step_single_token_matches_forward_on_single_seq() {
// When seq=1, decode_step (cache empty) must equal forward on q/k/v of shape [1,h,d]
fn decode_step_single_token_matches_forward() {
let heads = 2;
let dim = 8;
let q = make_tensor(1, heads, dim);
@ -1041,24 +1281,27 @@ mod tests {
use_landmarks: false,
})
.unwrap();
let _fwd = attn.forward(&q, &k, &v).unwrap();
let fwd = attn.forward(&q, &k, &v).unwrap();
// Append k/v first, then decode — new convention.
let mut cache = KvCache::new(256, heads, dim, 64);
cache.try_append(&k, &v).unwrap();
let out = attn.decode_step(&q, &cache).unwrap();
// Build empty cache then decode; produces output for the new token at position 0.
let cache2 = KvCache::new(256, heads, dim);
let out = attn.decode_step(&q, &cache2);
assert!(out.is_ok());
// Output shape must be [1, heads, dim]
let out = out.unwrap();
assert_eq!(out.seq, 1);
assert_eq!(out.heads, heads);
assert_eq!(out.dim, dim);
// Values must match forward (single token, window covers entire seq).
for (a, b) in out.data.iter().zip(fwd.data.iter()) {
assert!((a - b).abs() < 1e-5, "decode_step vs forward: {a} vs {b}");
}
}
#[test]
fn kv_cache_append_and_len() {
let heads = 4;
let dim = 16;
let mut cache = KvCache::new(64, heads, dim);
let mut cache = KvCache::new(64, heads, dim, 8);
assert_eq!(cache.len, 0);
let k = make_tensor(1, heads, dim);
let v = make_tensor(1, heads, dim);
@ -1068,6 +1311,123 @@ mod tests {
assert_eq!(cache.len, 2);
}
#[test]
fn try_append_at_capacity_returns_error() {
    // A two-token cache accepts exactly two appends; the third must be rejected.
    let (heads, dim) = (2, 4);
    let mut cache = KvCache::new(2, heads, dim, 1);
    let key = make_tensor(1, heads, dim);
    let value = make_tensor(1, heads, dim);
    for _ in 0..2 {
        assert!(cache.try_append(&key, &value).is_ok());
    }
    let overflow = cache.try_append(&key, &value);
    assert!(overflow.is_err(), "should error on overflow");
}
#[test]
fn kv_cache_reset_clears_state() {
    let heads = 2;
    let dim = 4;
    let mut cache = KvCache::new(8, heads, dim, 2);
    let k = make_tensor(1, heads, dim);
    let v = make_tensor(1, heads, dim);
    cache.append(&k, &v);
    cache.append(&k, &v);
    assert_eq!(cache.len, 2);

    cache.reset();
    assert_eq!(cache.len, 0);
    assert!(!cache.is_full());
    // Landmark means must be cleared along with the length.
    assert!(
        cache.landmarks.keys.data.iter().all(|&x| x == 0.0),
        "landmark keys not cleared by reset"
    );
    assert!(
        cache.landmarks.values.data.iter().all(|&x| x == 0.0),
        "landmark values not cleared by reset"
    );

    // The per-block sample counts must be cleared too: the first token appended
    // after a reset is a mean over one sample, so block 0 equals that token exactly.
    cache.append(&k, &v);
    assert_eq!(cache.len, 1);
    for h in 0..heads {
        assert_eq!(cache.landmarks.keys.row(0, h), k.row(0, h));
        assert_eq!(cache.landmarks.values.row(0, h), v.row(0, h));
    }
}
#[test]
fn incremental_landmarks_match_static() {
    // After appending all tokens, IncrementalLandmarks means must equal
    // the static Landmarks::from_kv result (within fp rounding).
    let seq = 16;
    let heads = 2;
    let dim = 8;
    let block_size = 4;
    let k = make_tensor(seq, heads, dim);
    let v = make_tensor(seq, heads, dim);
    let static_lm = Landmarks::from_kv(&k, &v, block_size);

    // Copies token `t` of `src` into an owned [1, heads, dim] tensor.
    let token = |src: &Tensor3, t: usize| {
        Tensor3::from_vec(
            src.data[t * heads * dim..(t + 1) * heads * dim].to_vec(),
            1, heads, dim,
        ).unwrap()
    };

    let mut inc_lm = IncrementalLandmarks::new(seq, block_size, heads, dim);
    for t in 0..seq {
        inc_lm.update(t, &token(&k, t), &token(&v, t));
    }

    // `zip` stops at the shorter side, so pin the lengths first; otherwise a
    // shape divergence would let the element-wise comparison pass vacuously.
    assert_eq!(inc_lm.keys.data.len(), static_lm.keys.data.len());
    assert_eq!(inc_lm.values.data.len(), static_lm.values.data.len());

    for (a, b) in inc_lm.keys.data.iter().zip(static_lm.keys.data.iter()) {
        assert!((a - b).abs() < 1e-5, "landmark keys mismatch: {a} vs {b}");
    }
    for (a, b) in inc_lm.values.data.iter().zip(static_lm.values.data.iter()) {
        assert!((a - b).abs() < 1e-5, "landmark values mismatch: {a} vs {b}");
    }
}
#[test]
fn decode_batch_shape_and_matches_sequential_decode_steps() {
    let (q_heads, kv_heads, dim) = (4, 2, 8);
    let draft_len = 4;
    let capacity = 32;
    let block_size = 4;

    let config = SparseAttentionConfig {
        window: 16,
        block_size,
        global_tokens: vec![],
        causal: true,
        use_log_stride: false,
        use_landmarks: false,
    };
    let attn = SubquadraticSparseAttention::new(config).unwrap();

    let q = make_tensor(draft_len, q_heads, dim);
    let new_k = make_tensor(draft_len, kv_heads, dim);
    let new_v = make_tensor(draft_len, kv_heads, dim);

    // Batched path: one call appends and decodes the whole draft.
    let mut cache_batch = KvCache::new(capacity, kv_heads, dim, block_size);
    let batch_out = attn.decode_batch(&q, &new_k, &new_v, &mut cache_batch).unwrap();
    assert_eq!(batch_out.seq, draft_len);
    assert_eq!(batch_out.heads, q_heads);
    assert_eq!(batch_out.dim, dim);

    // Copies token `t` of `src` into an owned [1, heads, dim] tensor.
    let token = |src: &Tensor3, t: usize, heads: usize| {
        let row_len = heads * dim;
        Tensor3::from_vec(src.data[t * row_len..(t + 1) * row_len].to_vec(), 1, heads, dim)
            .unwrap()
    };

    // Reference path: append + decode_step, one draft token at a time.
    let mut cache_seq = KvCache::new(capacity, kv_heads, dim, block_size);
    let mut seq_out = Tensor3::zeros(draft_len, q_heads, dim);
    for t in 0..draft_len {
        let q_t = token(&q, t, q_heads);
        let k_t = token(&new_k, t, kv_heads);
        let v_t = token(&new_v, t, kv_heads);
        cache_seq.try_append(&k_t, &v_t).unwrap();
        let step = attn.decode_step(&q_t, &cache_seq).unwrap();
        let span = t * q_heads * dim..(t + 1) * q_heads * dim;
        seq_out.data[span].copy_from_slice(&step.data);
    }

    for (a, b) in batch_out.data.iter().zip(seq_out.data.iter()) {
        assert!((a - b).abs() < 1e-5, "decode_batch vs sequential: {a} vs {b}");
    }
}
// --- ADR-190 GQA/MQA tests ---
#[test]