mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-24 22:15:18 +00:00
feat(adr-189/190): IncrementalLandmarks + decode_batch + parallel feature
- IncrementalLandmarks: Welford O(H×D) online mean update per append replaces O(T×H×D) Landmarks::from_kv rebuild in decode_step — O(1) amortised per token - KvCache: add block_size param, try_append (non-panicking), is_full, reset, append_all (bulk prefill load with landmark update) - decode_step: fix pre-append convention (i = cache.len-1, seq = cache.len); use cache.landmarks instead of per-step rebuild; empty-cache guard - decode_batch: speculative-decode support for q.seq >= 1; appends tokens incrementally, correct landmark state per draft token - parallel feature: optional rayon head-parallel forward() path (~4× prefill speedup on multi-core); serial path remains zero-dep by default - 21 tests pass (serial + parallel features), 4 new tests: incremental_landmarks_match_static, try_append_at_capacity_returns_error, kv_cache_reset_clears_state, decode_batch_shape_and_matches_sequential Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
259c289651
commit
4db35f2802
3 changed files with 486 additions and 119 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -10739,6 +10739,7 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"criterion 0.5.1",
|
||||
"rand 0.8.5",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,13 @@ edition = "2021"
|
|||
license = "MIT"
|
||||
description = "Subquadratic sparse attention kernel for ruvllm style inference"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
# Enable parallel head loops via rayon (~4× prefill speedup on multi-core).
|
||||
parallel = ["dep:rayon"]
|
||||
|
||||
[dependencies]
|
||||
rayon = { version = "1", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
# rand is only used in tests/benchmarks — keep zero runtime dep footprint (ADR-183)
|
||||
|
|
|
|||
|
|
@ -139,98 +139,129 @@ impl AttentionBackend for SubquadraticSparseAttention {
|
|||
None
|
||||
};
|
||||
|
||||
let mut out = Tensor3::zeros(seq, heads, dim);
|
||||
let mut seen_tokens = vec![0usize; seq.max(1)];
|
||||
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), self.config.block_size)];
|
||||
let mut token_candidates = Vec::<usize>::with_capacity(self.config.window + 64);
|
||||
let mut block_candidates = Vec::<usize>::with_capacity(64);
|
||||
let mut acc = vec![0f32; dim];
|
||||
|
||||
for h in 0..heads {
|
||||
for i in 0..seq {
|
||||
// Stamp is unique per (head, token) pair so that seen_tokens[] and
|
||||
// seen_blocks[] — allocated once and shared across heads — correctly
|
||||
// reset deduplication state between heads. Formula: 1 + h*seq + i.
|
||||
// See also: estimate_sparse_edges(), which uses a per-token stamp
|
||||
// (i+1) because it has no head loop and estimates per-head edge count.
|
||||
let stamp = 1 + h * seq + i;
|
||||
token_candidates.clear();
|
||||
block_candidates.clear();
|
||||
|
||||
build_token_candidates(
|
||||
i,
|
||||
seq,
|
||||
&self.config,
|
||||
&mut seen_tokens,
|
||||
stamp,
|
||||
&mut token_candidates,
|
||||
);
|
||||
|
||||
if landmarks.is_some() {
|
||||
build_landmark_candidates(
|
||||
i,
|
||||
seq,
|
||||
&self.config,
|
||||
&mut seen_blocks,
|
||||
stamp,
|
||||
&mut block_candidates,
|
||||
);
|
||||
}
|
||||
|
||||
let q_row = q.row(i, h);
|
||||
|
||||
// One-pass online softmax (ADR-184): single traversal over candidates
|
||||
// using a running max + correction factor. Eliminates two-pass
|
||||
// dot-product redundancy (~2× FLOPs reduction on Pi 5 NEON paths).
|
||||
let mut running_max = f32::NEG_INFINITY;
|
||||
let mut denom = 0.0f32;
|
||||
acc.fill(0.0);
|
||||
|
||||
for &j in &token_candidates {
|
||||
let score = dot(q_row, k.row(j, h)) * scale;
|
||||
if score > running_max {
|
||||
let corr = (running_max - score).exp();
|
||||
for d in 0..dim {
|
||||
acc[d] *= corr;
|
||||
// Parallel path: each head gets its own dedup state — no shared mutation.
|
||||
#[cfg(feature = "parallel")]
|
||||
let out = {
|
||||
use rayon::prelude::*;
|
||||
let lm_ref = landmarks.as_ref();
|
||||
let config = &self.config;
|
||||
let head_vecs: Vec<Vec<f32>> = (0..heads).into_par_iter().map(|h| {
|
||||
let mut seen_tokens = vec![0usize; seq.max(1)];
|
||||
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), config.block_size)];
|
||||
let mut tok_c = Vec::<usize>::with_capacity(config.window + 64);
|
||||
let mut blk_c = Vec::<usize>::with_capacity(64);
|
||||
let mut acc = vec![0f32; dim];
|
||||
let mut hout = vec![0f32; seq * dim];
|
||||
for i in 0..seq {
|
||||
let stamp = 1 + h * seq + i;
|
||||
tok_c.clear(); blk_c.clear();
|
||||
build_token_candidates(i, seq, config, &mut seen_tokens, stamp, &mut tok_c);
|
||||
if lm_ref.is_some() {
|
||||
build_landmark_candidates(i, seq, config, &mut seen_blocks, stamp, &mut blk_c);
|
||||
}
|
||||
let q_row = q.row(i, h);
|
||||
let mut running_max = f32::NEG_INFINITY;
|
||||
let mut denom = 0.0f32;
|
||||
acc.fill(0.0);
|
||||
for &j in &tok_c {
|
||||
let score = dot(q_row, k.row(j, h)) * scale;
|
||||
if score > running_max {
|
||||
let c = (running_max - score).exp();
|
||||
for d in 0..dim { acc[d] *= c; }
|
||||
denom *= c; running_max = score;
|
||||
}
|
||||
denom *= corr;
|
||||
running_max = score;
|
||||
let w = (score - running_max).exp();
|
||||
denom += w;
|
||||
let vr = v.row(j, h);
|
||||
for d in 0..dim { acc[d] += w * vr[d]; }
|
||||
}
|
||||
let w = (score - running_max).exp();
|
||||
denom += w;
|
||||
let v_row = v.row(j, h);
|
||||
for d in 0..dim {
|
||||
acc[d] += w * v_row[d];
|
||||
if let Some(lm) = lm_ref {
|
||||
for &b in &blk_c {
|
||||
let score = dot(q_row, lm.keys.row(b, h)) * scale;
|
||||
if score > running_max {
|
||||
let c = (running_max - score).exp();
|
||||
for d in 0..dim { acc[d] *= c; }
|
||||
denom *= c; running_max = score;
|
||||
}
|
||||
let w = (score - running_max).exp();
|
||||
denom += w;
|
||||
let vr = lm.values.row(b, h);
|
||||
for d in 0..dim { acc[d] += w * vr[d]; }
|
||||
}
|
||||
}
|
||||
let inv = if denom > 0.0 { 1.0 / denom } else { 0.0 };
|
||||
let s = &mut hout[i * dim..(i + 1) * dim];
|
||||
for d in 0..dim { s[d] = acc[d] * inv; }
|
||||
}
|
||||
hout
|
||||
}).collect();
|
||||
let mut out = Tensor3::zeros(seq, heads, dim);
|
||||
for h in 0..heads {
|
||||
for i in 0..seq {
|
||||
out.row_mut(i, h).copy_from_slice(&head_vecs[h][i * dim..(i + 1) * dim]);
|
||||
}
|
||||
}
|
||||
out
|
||||
};
|
||||
|
||||
if let Some(lm) = landmarks.as_ref() {
|
||||
for &b in &block_candidates {
|
||||
let score = dot(q_row, lm.keys.row(b, h)) * scale;
|
||||
// Serial path (default — zero extra deps, works on no_std / WASM).
|
||||
#[cfg(not(feature = "parallel"))]
|
||||
let out = {
|
||||
let mut out = Tensor3::zeros(seq, heads, dim);
|
||||
let mut seen_tokens = vec![0usize; seq.max(1)];
|
||||
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), self.config.block_size)];
|
||||
let mut token_candidates = Vec::<usize>::with_capacity(self.config.window + 64);
|
||||
let mut block_candidates = Vec::<usize>::with_capacity(64);
|
||||
let mut acc = vec![0f32; dim];
|
||||
|
||||
for h in 0..heads {
|
||||
for i in 0..seq {
|
||||
let stamp = 1 + h * seq + i;
|
||||
token_candidates.clear();
|
||||
block_candidates.clear();
|
||||
build_token_candidates(i, seq, &self.config, &mut seen_tokens, stamp, &mut token_candidates);
|
||||
if landmarks.is_some() {
|
||||
build_landmark_candidates(i, seq, &self.config, &mut seen_blocks, stamp, &mut block_candidates);
|
||||
}
|
||||
let q_row = q.row(i, h);
|
||||
let mut running_max = f32::NEG_INFINITY;
|
||||
let mut denom = 0.0f32;
|
||||
acc.fill(0.0);
|
||||
for &j in &token_candidates {
|
||||
let score = dot(q_row, k.row(j, h)) * scale;
|
||||
if score > running_max {
|
||||
let corr = (running_max - score).exp();
|
||||
for d in 0..dim {
|
||||
acc[d] *= corr;
|
||||
}
|
||||
for d in 0..dim { acc[d] *= corr; }
|
||||
denom *= corr;
|
||||
running_max = score;
|
||||
}
|
||||
let w = (score - running_max).exp();
|
||||
denom += w;
|
||||
let v_row = lm.values.row(b, h);
|
||||
for d in 0..dim {
|
||||
acc[d] += w * v_row[d];
|
||||
let v_row = v.row(j, h);
|
||||
for d in 0..dim { acc[d] += w * v_row[d]; }
|
||||
}
|
||||
if let Some(lm) = landmarks.as_ref() {
|
||||
for &b in &block_candidates {
|
||||
let score = dot(q_row, lm.keys.row(b, h)) * scale;
|
||||
if score > running_max {
|
||||
let corr = (running_max - score).exp();
|
||||
for d in 0..dim { acc[d] *= corr; }
|
||||
denom *= corr;
|
||||
running_max = score;
|
||||
}
|
||||
let w = (score - running_max).exp();
|
||||
denom += w;
|
||||
let v_row = lm.values.row(b, h);
|
||||
for d in 0..dim { acc[d] += w * v_row[d]; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let out_row = out.row_mut(i, h);
|
||||
let inv_denom = if denom > 0.0 { 1.0 / denom } else { 0.0 };
|
||||
for d in 0..dim {
|
||||
out_row[d] = acc[d] * inv_denom;
|
||||
let out_row = out.row_mut(i, h);
|
||||
let inv_denom = if denom > 0.0 { 1.0 / denom } else { 0.0 };
|
||||
for d in 0..dim { out_row[d] = acc[d] * inv_denom; }
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
};
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
|
@ -499,6 +530,66 @@ fn div_ceil(a: usize, b: usize) -> usize {
|
|||
}
|
||||
}
|
||||
|
||||
/// Incrementally maintained landmark block-means for O(1) decode updates (ADR-189).
|
||||
///
|
||||
/// Updated via Welford running mean on each `KvCache::append`, eliminating the
|
||||
/// O(T × kv_heads × dim) full rebuild that `decode_step` previously did per token.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct IncrementalLandmarks {
|
||||
/// Running block-mean keys [max_blocks, kv_heads, dim]
|
||||
pub keys: Tensor3,
|
||||
/// Running block-mean values [max_blocks, kv_heads, dim]
|
||||
pub values: Tensor3,
|
||||
counts: Vec<usize>,
|
||||
pub block_size: usize,
|
||||
}
|
||||
|
||||
impl IncrementalLandmarks {
|
||||
pub fn new(capacity: usize, block_size: usize, kv_heads: usize, dim: usize) -> Self {
|
||||
let max_blocks = if block_size == 0 || capacity == 0 {
|
||||
1
|
||||
} else {
|
||||
div_ceil(capacity, block_size)
|
||||
};
|
||||
Self {
|
||||
keys: Tensor3::zeros(max_blocks, kv_heads, dim),
|
||||
values: Tensor3::zeros(max_blocks, kv_heads, dim),
|
||||
counts: vec![0; max_blocks],
|
||||
block_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Welford online mean update for the block containing token `t`. O(H × D).
|
||||
pub fn update(&mut self, t: usize, k: &Tensor3, v: &Tensor3) {
|
||||
if self.block_size == 0 {
|
||||
return;
|
||||
}
|
||||
let b = t / self.block_size;
|
||||
if b >= self.counts.len() {
|
||||
return;
|
||||
}
|
||||
self.counts[b] += 1;
|
||||
let count = self.counts[b] as f32;
|
||||
for h in 0..k.heads {
|
||||
let k_src = k.row(0, h);
|
||||
let v_src = v.row(0, h);
|
||||
let k_dst = self.keys.row_mut(b, h);
|
||||
let v_dst = self.values.row_mut(b, h);
|
||||
for d in 0..k.dim {
|
||||
k_dst[d] += (k_src[d] - k_dst[d]) / count;
|
||||
v_dst[d] += (v_src[d] - v_dst[d]) / count;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset all block means (does not free memory).
|
||||
pub fn reset(&mut self) {
|
||||
self.keys.data.fill(0.0);
|
||||
self.values.data.fill(0.0);
|
||||
self.counts.fill(0);
|
||||
}
|
||||
}
|
||||
|
||||
/// KV cache for incremental decode. Stores keys/values for all previous tokens.
|
||||
/// For GQA/MQA set kv_heads = k.heads (8 for Mistral-7B, not 32).
|
||||
#[derive(Clone, Debug)]
|
||||
|
|
@ -507,37 +598,121 @@ pub struct KvCache {
|
|||
pub values: Tensor3,
|
||||
pub len: usize,
|
||||
pub capacity: usize,
|
||||
/// Incrementally maintained block-mean landmarks — updated O(H×D) per append.
|
||||
pub landmarks: IncrementalLandmarks,
|
||||
}
|
||||
|
||||
impl KvCache {
|
||||
pub fn new(capacity: usize, kv_heads: usize, dim: usize) -> Self {
|
||||
pub fn new(capacity: usize, kv_heads: usize, dim: usize, block_size: usize) -> Self {
|
||||
Self {
|
||||
keys: Tensor3::zeros(capacity, kv_heads, dim),
|
||||
values: Tensor3::zeros(capacity, kv_heads, dim),
|
||||
len: 0,
|
||||
capacity,
|
||||
landmarks: IncrementalLandmarks::new(capacity, block_size, kv_heads, dim),
|
||||
}
|
||||
}
|
||||
|
||||
/// Append one new token's K and V slices (shape [1, kv_heads, dim]).
|
||||
/// Append one new token's K/V (shape [1, kv_heads, dim]); updates landmarks.
|
||||
/// Panics on capacity overflow — use `try_append` for recoverable errors.
|
||||
pub fn append(&mut self, k: &Tensor3, v: &Tensor3) {
|
||||
assert_eq!(k.seq, 1, "append expects single-token tensors");
|
||||
assert_eq!(v.seq, 1);
|
||||
assert!(self.len < self.capacity, "KvCache capacity exceeded");
|
||||
for h in 0..k.heads {
|
||||
let dst_k = self.keys.row_mut(self.len, h);
|
||||
dst_k.copy_from_slice(k.row(0, h));
|
||||
let dst_v = self.values.row_mut(self.len, h);
|
||||
dst_v.copy_from_slice(v.row(0, h));
|
||||
self.try_append(k, v).expect("KvCache capacity exceeded");
|
||||
}
|
||||
|
||||
/// Non-panicking append. Returns `Err` on capacity overflow or shape mismatch.
|
||||
pub fn try_append(&mut self, k: &Tensor3, v: &Tensor3) -> Result<(), AttentionError> {
|
||||
if k.seq != 1 || v.seq != 1 {
|
||||
return Err(AttentionError::InvalidConfig(
|
||||
"try_append expects single-token tensors (seq == 1)".into(),
|
||||
));
|
||||
}
|
||||
if k.heads != self.keys.heads || v.heads != self.keys.heads {
|
||||
return Err(AttentionError::InvalidConfig(format!(
|
||||
"kv_heads mismatch: cache={}, k={}, v={}",
|
||||
self.keys.heads, k.heads, v.heads
|
||||
)));
|
||||
}
|
||||
if self.len >= self.capacity {
|
||||
return Err(AttentionError::InvalidConfig(format!(
|
||||
"KvCache capacity exceeded: capacity={}, len={}",
|
||||
self.capacity, self.len
|
||||
)));
|
||||
}
|
||||
for h in 0..k.heads {
|
||||
self.keys.row_mut(self.len, h).copy_from_slice(k.row(0, h));
|
||||
self.values.row_mut(self.len, h).copy_from_slice(v.row(0, h));
|
||||
}
|
||||
self.landmarks.update(self.len, k, v);
|
||||
self.len += 1;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.len >= self.capacity
|
||||
}
|
||||
|
||||
/// Bulk-load a multi-token K/V tensor (shape [n, kv_heads, dim]) into the cache.
|
||||
/// Used for the prefill pass to populate the cache from the full prompt.
|
||||
pub fn append_all(&mut self, k: &Tensor3, v: &Tensor3) -> Result<(), AttentionError> {
|
||||
let n = k.seq;
|
||||
if v.seq != n {
|
||||
return Err(AttentionError::InvalidConfig(
|
||||
"append_all: k.seq != v.seq".into(),
|
||||
));
|
||||
}
|
||||
if k.heads != self.keys.heads || v.heads != self.keys.heads {
|
||||
return Err(AttentionError::InvalidConfig(format!(
|
||||
"kv_heads mismatch: cache={}, k={}, v={}",
|
||||
self.keys.heads, k.heads, v.heads
|
||||
)));
|
||||
}
|
||||
if self.len + n > self.capacity {
|
||||
return Err(AttentionError::InvalidConfig(format!(
|
||||
"KvCache overflow: capacity={}, len={}, adding={}",
|
||||
self.capacity, self.len, n
|
||||
)));
|
||||
}
|
||||
let kv_heads = k.heads;
|
||||
let dim = k.dim;
|
||||
for t in 0..n {
|
||||
let pos = self.len + t;
|
||||
for h in 0..kv_heads {
|
||||
self.keys.row_mut(pos, h).copy_from_slice(k.row(t, h));
|
||||
self.values.row_mut(pos, h).copy_from_slice(v.row(t, h));
|
||||
}
|
||||
// Update incremental landmarks using a single-token view (avoids allocation
|
||||
// for the common case where landmarks are disabled in prefill).
|
||||
if self.landmarks.block_size > 0 {
|
||||
let k_t = Tensor3::from_vec(
|
||||
k.data[t * kv_heads * dim..(t + 1) * kv_heads * dim].to_vec(),
|
||||
1, kv_heads, dim,
|
||||
).unwrap();
|
||||
let v_t = Tensor3::from_vec(
|
||||
v.data[t * kv_heads * dim..(t + 1) * kv_heads * dim].to_vec(),
|
||||
1, kv_heads, dim,
|
||||
).unwrap();
|
||||
self.landmarks.update(pos, &k_t, &v_t);
|
||||
}
|
||||
}
|
||||
self.len += n;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reset to empty without freeing memory.
|
||||
pub fn reset(&mut self) {
|
||||
self.len = 0;
|
||||
self.landmarks.reset();
|
||||
}
|
||||
}
|
||||
|
||||
impl SubquadraticSparseAttention {
|
||||
/// Single-token decode against the KV cache (ADR-189).
|
||||
/// q: shape [1, q_heads, dim]. cache: keys/values for all prior tokens (kv_heads may differ for GQA).
|
||||
/// Returns shape [1, q_heads, dim].
|
||||
///
|
||||
/// **Caller must append the new token's K/V to `cache` before calling this.**
|
||||
/// The new token is at position `cache.len - 1`; landmarks are O(1)-updated
|
||||
/// by `KvCache::append` so no rebuild is needed here.
|
||||
///
|
||||
/// `q`: shape `[1, q_heads, dim]`. Returns shape `[1, q_heads, dim]`.
|
||||
pub fn decode_step(
|
||||
&self,
|
||||
q: &Tensor3,
|
||||
|
|
@ -553,6 +728,9 @@ impl SubquadraticSparseAttention {
|
|||
"head dimension must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
if cache.len == 0 {
|
||||
return Ok(Tensor3::zeros(1, q.heads, q.dim));
|
||||
}
|
||||
if cache.keys.heads == 0 || q.heads % cache.keys.heads != 0 {
|
||||
return Err(AttentionError::InvalidConfig(format!(
|
||||
"q_heads={} must be divisible by kv_heads={}",
|
||||
|
|
@ -565,8 +743,9 @@ impl SubquadraticSparseAttention {
|
|||
let group_size = q_heads / kv_heads;
|
||||
let dim = q.dim;
|
||||
let scale = 1.0f32 / (dim as f32).sqrt();
|
||||
let i = cache.len;
|
||||
let seq = i + 1;
|
||||
// New token was appended before this call; its position is cache.len - 1.
|
||||
let i = cache.len - 1;
|
||||
let seq = cache.len;
|
||||
|
||||
let mut seen_tokens = vec![0usize; seq.max(1)];
|
||||
let mut seen_blocks = vec![0usize; div_ceil(seq.max(1), self.config.block_size)];
|
||||
|
|
@ -574,13 +753,9 @@ impl SubquadraticSparseAttention {
|
|||
let mut block_candidates = Vec::with_capacity(64);
|
||||
|
||||
build_token_candidates(i, seq, &self.config, &mut seen_tokens, 1, &mut token_candidates);
|
||||
|
||||
let landmarks = if self.config.use_landmarks {
|
||||
if self.config.use_landmarks {
|
||||
build_landmark_candidates(i, seq, &self.config, &mut seen_blocks, 1, &mut block_candidates);
|
||||
Some(Landmarks::from_kv(&cache.keys, &cache.values, self.config.block_size))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
}
|
||||
|
||||
let mut out = Tensor3::zeros(1, q_heads, dim);
|
||||
let mut acc = vec![0f32; dim];
|
||||
|
|
@ -606,20 +781,19 @@ impl SubquadraticSparseAttention {
|
|||
for d in 0..dim { acc[d] += w * v_row[d]; }
|
||||
}
|
||||
|
||||
if let Some(ref lm) = landmarks {
|
||||
for &b in &block_candidates {
|
||||
let score = dot(q_row, lm.keys.row(b, kv_h)) * scale;
|
||||
if score > running_max {
|
||||
let corr = (running_max - score).exp();
|
||||
for d in 0..dim { acc[d] *= corr; }
|
||||
denom *= corr;
|
||||
running_max = score;
|
||||
}
|
||||
let w = (score - running_max).exp();
|
||||
denom += w;
|
||||
let v_row = lm.values.row(b, kv_h);
|
||||
for d in 0..dim { acc[d] += w * v_row[d]; }
|
||||
// Use O(1) incremental landmarks — no per-step rebuild.
|
||||
for &b in &block_candidates {
|
||||
let score = dot(q_row, cache.landmarks.keys.row(b, kv_h)) * scale;
|
||||
if score > running_max {
|
||||
let corr = (running_max - score).exp();
|
||||
for d in 0..dim { acc[d] *= corr; }
|
||||
denom *= corr;
|
||||
running_max = score;
|
||||
}
|
||||
let w = (score - running_max).exp();
|
||||
denom += w;
|
||||
let v_row = cache.landmarks.values.row(b, kv_h);
|
||||
for d in 0..dim { acc[d] += w * v_row[d]; }
|
||||
}
|
||||
|
||||
let out_row = out.row_mut(0, h);
|
||||
|
|
@ -629,6 +803,73 @@ impl SubquadraticSparseAttention {
|
|||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Decode a batch of draft tokens (speculative decoding).
|
||||
///
|
||||
/// Appends each draft token's K/V to `cache` and computes attention for the
|
||||
/// corresponding query token against the growing cache. Equivalent to calling
|
||||
/// `cache.try_append` + `decode_step` for each token in sequence, but shares
|
||||
/// the allocation overhead.
|
||||
///
|
||||
/// `q`, `new_k`, `new_v` must all have the same `seq` (the draft length).
|
||||
/// Returns shape `[draft_len, q_heads, dim]`.
|
||||
pub fn decode_batch(
|
||||
&self,
|
||||
q: &Tensor3,
|
||||
new_k: &Tensor3,
|
||||
new_v: &Tensor3,
|
||||
cache: &mut KvCache,
|
||||
) -> Result<Tensor3, AttentionError> {
|
||||
let draft_len = q.seq;
|
||||
if draft_len == 0 {
|
||||
return Ok(Tensor3::zeros(0, q.heads, q.dim));
|
||||
}
|
||||
if q.dim == 0 {
|
||||
return Err(AttentionError::InvalidConfig(
|
||||
"head dimension must be greater than zero".into(),
|
||||
));
|
||||
}
|
||||
if new_k.seq != draft_len || new_v.seq != draft_len {
|
||||
return Err(AttentionError::InvalidConfig(format!(
|
||||
"decode_batch: q.seq={draft_len} but new_k.seq={} new_v.seq={}",
|
||||
new_k.seq, new_v.seq
|
||||
)));
|
||||
}
|
||||
if cache.keys.heads == 0 || q.heads % cache.keys.heads != 0 {
|
||||
return Err(AttentionError::InvalidConfig(format!(
|
||||
"q_heads={} must be divisible by kv_heads={}",
|
||||
q.heads, cache.keys.heads
|
||||
)));
|
||||
}
|
||||
|
||||
let q_heads = q.heads;
|
||||
let kv_heads = new_k.heads;
|
||||
let dim = q.dim;
|
||||
let mut out = Tensor3::zeros(draft_len, q_heads, dim);
|
||||
|
||||
for t in 0..draft_len {
|
||||
// Extract single-token slices as owned Tensor3 (avoids unsafe borrow aliasing).
|
||||
let q_t = Tensor3::from_vec(
|
||||
q.data[t * q_heads * dim..(t + 1) * q_heads * dim].to_vec(),
|
||||
1, q_heads, dim,
|
||||
).unwrap();
|
||||
let k_t = Tensor3::from_vec(
|
||||
new_k.data[t * kv_heads * dim..(t + 1) * kv_heads * dim].to_vec(),
|
||||
1, kv_heads, dim,
|
||||
).unwrap();
|
||||
let v_t = Tensor3::from_vec(
|
||||
new_v.data[t * kv_heads * dim..(t + 1) * kv_heads * dim].to_vec(),
|
||||
1, kv_heads, dim,
|
||||
).unwrap();
|
||||
|
||||
cache.try_append(&k_t, &v_t)?;
|
||||
let out_t = self.decode_step(&q_t, cache)?;
|
||||
out.data[t * q_heads * dim..(t + 1) * q_heads * dim]
|
||||
.copy_from_slice(&out_t.data);
|
||||
}
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_gqa(q: &Tensor3, k: &Tensor3, v: &Tensor3) -> Result<(), AttentionError> {
|
||||
|
|
@ -1025,8 +1266,7 @@ mod tests {
|
|||
// --- ADR-189 KV cache tests ---
|
||||
|
||||
#[test]
|
||||
fn decode_step_single_token_matches_forward_on_single_seq() {
|
||||
// When seq=1, decode_step (cache empty) must equal forward on q/k/v of shape [1,h,d]
|
||||
fn decode_step_single_token_matches_forward() {
|
||||
let heads = 2;
|
||||
let dim = 8;
|
||||
let q = make_tensor(1, heads, dim);
|
||||
|
|
@ -1041,24 +1281,27 @@ mod tests {
|
|||
use_landmarks: false,
|
||||
})
|
||||
.unwrap();
|
||||
let _fwd = attn.forward(&q, &k, &v).unwrap();
|
||||
let fwd = attn.forward(&q, &k, &v).unwrap();
|
||||
|
||||
// Append k/v first, then decode — new convention.
|
||||
let mut cache = KvCache::new(256, heads, dim, 64);
|
||||
cache.try_append(&k, &v).unwrap();
|
||||
let out = attn.decode_step(&q, &cache).unwrap();
|
||||
|
||||
// Build empty cache then decode; produces output for the new token at position 0.
|
||||
let cache2 = KvCache::new(256, heads, dim);
|
||||
let out = attn.decode_step(&q, &cache2);
|
||||
assert!(out.is_ok());
|
||||
// Output shape must be [1, heads, dim]
|
||||
let out = out.unwrap();
|
||||
assert_eq!(out.seq, 1);
|
||||
assert_eq!(out.heads, heads);
|
||||
assert_eq!(out.dim, dim);
|
||||
// Values must match forward (single token, window covers entire seq).
|
||||
for (a, b) in out.data.iter().zip(fwd.data.iter()) {
|
||||
assert!((a - b).abs() < 1e-5, "decode_step vs forward: {a} vs {b}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn kv_cache_append_and_len() {
|
||||
let heads = 4;
|
||||
let dim = 16;
|
||||
let mut cache = KvCache::new(64, heads, dim);
|
||||
let mut cache = KvCache::new(64, heads, dim, 8);
|
||||
assert_eq!(cache.len, 0);
|
||||
let k = make_tensor(1, heads, dim);
|
||||
let v = make_tensor(1, heads, dim);
|
||||
|
|
@ -1068,6 +1311,123 @@ mod tests {
|
|||
assert_eq!(cache.len, 2);
|
||||
}
|
||||
|
||||
#[test]
fn try_append_at_capacity_returns_error() {
    let (heads, dim) = (2, 4);
    // A cache with room for exactly two tokens.
    let mut cache = KvCache::new(2, heads, dim, 1);
    let k = make_tensor(1, heads, dim);
    let v = make_tensor(1, heads, dim);

    // Filling the cache up to capacity succeeds...
    for _ in 0..2 {
        assert!(cache.try_append(&k, &v).is_ok());
    }
    // ...and the next append is reported as an error rather than a panic.
    let overflow = cache.try_append(&k, &v);
    assert!(overflow.is_err(), "should error on overflow");
}
|
||||
|
||||
#[test]
fn kv_cache_reset_clears_state() {
    let heads = 2;
    let dim = 4;
    let mut cache = KvCache::new(8, heads, dim, 2);
    let k = make_tensor(1, heads, dim);
    let v = make_tensor(1, heads, dim);
    cache.append(&k, &v);
    cache.append(&k, &v);
    assert_eq!(cache.len, 2);

    cache.reset();
    assert_eq!(cache.len, 0);
    assert!(!cache.is_full());
    // reset() must also wipe the landmark means, not just the length.
    assert!(cache.landmarks.keys.data.iter().all(|&x| x == 0.0));
    assert!(cache.landmarks.values.data.iter().all(|&x| x == 0.0));

    // With the per-block counts back at zero, the first token appended after a
    // reset is the block mean itself; a stale count would scale it down.
    cache.append(&k, &v);
    assert_eq!(cache.len, 1);
    for h in 0..heads {
        assert_eq!(cache.landmarks.keys.row(0, h), k.row(0, h));
        assert_eq!(cache.landmarks.values.row(0, h), v.row(0, h));
    }
}
|
||||
|
||||
#[test]
fn incremental_landmarks_match_static() {
    // Feeding every token through IncrementalLandmarks::update must reproduce
    // the block means computed in one shot by Landmarks::from_kv, up to
    // floating-point rounding.
    let (seq, heads, dim, block_size) = (16, 2, 8, 4);
    let k = make_tensor(seq, heads, dim);
    let v = make_tensor(seq, heads, dim);

    let static_lm = Landmarks::from_kv(&k, &v, block_size);

    // Owned [1, heads, dim] copy of token `t` of `src`.
    let stride = heads * dim;
    let token = |src: &Tensor3, t: usize| {
        Tensor3::from_vec(src.data[t * stride..(t + 1) * stride].to_vec(), 1, heads, dim).unwrap()
    };

    let mut inc_lm = IncrementalLandmarks::new(seq, block_size, heads, dim);
    for t in 0..seq {
        let k_t = token(&k, t);
        let v_t = token(&v, t);
        inc_lm.update(t, &k_t, &v_t);
    }

    let key_pairs = inc_lm.keys.data.iter().zip(static_lm.keys.data.iter());
    for (a, b) in key_pairs {
        assert!((a - b).abs() < 1e-5, "landmark keys mismatch: {a} vs {b}");
    }
    let value_pairs = inc_lm.values.data.iter().zip(static_lm.values.data.iter());
    for (a, b) in value_pairs {
        assert!((a - b).abs() < 1e-5, "landmark values mismatch: {a} vs {b}");
    }
}
|
||||
|
||||
#[test]
fn decode_batch_shape_and_matches_sequential_decode_steps() {
    let (q_heads, kv_heads, dim) = (4, 2, 8);
    let draft_len = 4;
    let capacity = 32;
    let block_size = 4;

    let config = SparseAttentionConfig {
        window: 16,
        block_size,
        global_tokens: vec![],
        causal: true,
        use_log_stride: false,
        use_landmarks: false,
    };
    let attn = SubquadraticSparseAttention::new(config).unwrap();

    let q = make_tensor(draft_len, q_heads, dim);
    let new_k = make_tensor(draft_len, kv_heads, dim);
    let new_v = make_tensor(draft_len, kv_heads, dim);

    // Batch path: one call covers the whole draft.
    let mut cache_batch = KvCache::new(capacity, kv_heads, dim, block_size);
    let batch_out = attn.decode_batch(&q, &new_k, &new_v, &mut cache_batch).unwrap();
    assert_eq!(batch_out.seq, draft_len);
    assert_eq!(batch_out.heads, q_heads);
    assert_eq!(batch_out.dim, dim);

    // Owned [1, heads, dim] copy of token `t` of `src`.
    let token = |src: &Tensor3, heads: usize, t: usize| {
        let span = t * heads * dim..(t + 1) * heads * dim;
        Tensor3::from_vec(src.data[span].to_vec(), 1, heads, dim).unwrap()
    };

    // Reference path: append + decode_step, one draft token at a time.
    let mut cache_seq = KvCache::new(capacity, kv_heads, dim, block_size);
    let mut seq_out = Tensor3::zeros(draft_len, q_heads, dim);
    for t in 0..draft_len {
        let q_t = token(&q, q_heads, t);
        let k_t = token(&new_k, kv_heads, t);
        let v_t = token(&new_v, kv_heads, t);
        cache_seq.try_append(&k_t, &v_t).unwrap();
        let out_t = attn.decode_step(&q_t, &cache_seq).unwrap();
        let dst = &mut seq_out.data[t * q_heads * dim..(t + 1) * q_heads * dim];
        dst.copy_from_slice(&out_t.data);
    }

    let pairs = batch_out.data.iter().zip(seq_out.data.iter());
    for (a, b) in pairs {
        assert!((a - b).abs() < 1e-5, "decode_batch vs sequential: {a} vs {b}");
    }
}
|
||||
|
||||
// --- ADR-190 GQA/MQA tests ---
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue