mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-24 13:54:31 +00:00
feat(sparse-mario): iter 2-3 — retrieval LM + ASCII generation
Wires `SubquadraticSparseAttention` as an inference-only retrieval
language model over the embedded SMB corpus:
K[i] = embed(corpus[i]) + 0.5·pos(i)
V[i] = embed(corpus[i+1]) ← next-token supervision baked into V
Q[i] = K[i]
out = forward(Q, K, V)
logits[v] = out[last] · embed(v)
next = sample(softmax(logits / T))
- Unit-variance embedding matrix (vocab × 64), deterministic xorshift32
seed; combined with the kernel's 1/sqrt(d) scale this gives matched
embed dot-product ≈ sqrt(d) above the noise floor.
- Light positional encoding (POS_SCALE=0.5) — enough for level-depth
awareness without drowning the token signal.
- Non-causal attention with window=256 + log-stride + landmarks so the
last query position can reach the whole 2.8K-token combined sequence
through sparse hops.
- End-to-end `cargo run --release --example sparse_mario` produces a
full 14-row × 50-col ASCII level slice in ~25s on a 9950X.
5 new tests (10 total, all passing): embedding determinism, finite
logits, generation determinism for a fixed seed, in-vocab outputs,
and a corpus-shape distribution check.
Known limitation: pure bigram retrieval saturates on the most-common
next-token (sky → sky → ... or X → X → ...). Iter 5 will add top-k
sampling, repetition penalty, and KvCache-backed `decode_step` for
incremental O(log T) per-token cost.
Iter-plan progress:
✓ 1. corpus + tokenizer scaffold (3f5d13edf)
✓ 2. retrieval LM wired ← here
✓ 3. autoregressive ASCII generation ← here (folded in)
4. dense vs sparse vs sparse+FastGRNN bench
5. fp16 KV cache + FastGRNN gate + top-k optimization
6. validation + final summary
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
3f5d13edfc
commit
2962c104e3
1 changed files with 316 additions and 1 deletions
|
|
@ -139,6 +139,222 @@ pub fn tile_distribution(tokens: &[u8]) -> HashMap<char, usize> {
|
|||
m
|
||||
}
|
||||
|
||||
// =================================================================
|
||||
// Iter 2 — sparse-attention retrieval LM
|
||||
//
|
||||
// The crate is inference-only (no autograd), so instead of training a
|
||||
// transformer we use the sparse attention kernel as an associative
|
||||
// memory:
|
||||
//
|
||||
//   K[i]   = embed(seq[i]) + POS_SCALE·pos(i)      (seq = corpus ++ prefix)
|
||||
// V[i] = embed(corpus[i+1]) ← "supervision" baked in
|
||||
//   Q[i]   = K[i]                                  (queries reuse the key rows)
|
||||
// out = SubquadraticSparseAttention.forward(Q, K, V)
|
||||
// logits = out[last] · embedW^T
|
||||
// next = sample(softmax(logits / T))
|
||||
//
|
||||
// V is the corpus shifted by one position, so attention output is a
|
||||
// soft-pointer to the empirical next-token distribution. Embeddings are
|
||||
// random-normal with a fixed seed; self-matches embed(t)·embed(t) score
|
||||
// strongest, so attention naturally retrieves "what tile usually follows
|
||||
// this tile in the corpus" — without any training.
|
||||
// =================================================================
|
||||
|
||||
use ruvllm_sparse_attention::{
|
||||
AttentionBackend, SparseAttentionConfig, SubquadraticSparseAttention, Tensor3,
|
||||
};
|
||||
|
||||
const HEAD_DIM: usize = 64;
|
||||
const N_HEADS: usize = 1;
|
||||
pub const VOCAB_SIZE: usize = 15;
|
||||
|
||||
const _: () = assert!(VOCAB.len() == VOCAB_SIZE, "VOCAB_SIZE drift vs VOCAB[]");
|
||||
|
||||
/// Steps a 32-bit xorshift generator (shift triple 13 / 17 / 5).
///
/// The new value is written back to `state` and also returned. A zero state
/// is replaced by a fixed non-zero constant before stepping, because zero is
/// a fixed point of the xorshift recurrence. Deterministic and free of
/// external dependencies.
fn xorshift32(state: &mut u32) -> u32 {
    // Zero would map to zero forever; swap in a non-zero starting point.
    let start = match *state {
        0 => 0x9E37_79B9,
        s => s,
    };
    // Bits shifted past the word boundary are simply dropped.
    let a = start ^ (start << 13);
    let b = a ^ (a >> 17);
    let next = b ^ (b << 5);
    *state = next;
    next
}
|
||||
|
||||
fn next_uniform(state: &mut u32) -> f32 {
|
||||
(xorshift32(state) as f32) / (u32::MAX as f32 + 1.0)
|
||||
}
|
||||
|
||||
fn next_normal(state: &mut u32) -> f32 {
|
||||
// Box-Muller — return one of the two samples.
|
||||
loop {
|
||||
let u1 = next_uniform(state);
|
||||
let u2 = next_uniform(state);
|
||||
if u1 > 1e-9 {
|
||||
let r = (-2.0 * u1.ln()).sqrt();
|
||||
let theta = 2.0 * std::f32::consts::PI * u2;
|
||||
return r * theta.cos();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_embedding_matrix(seed: u32) -> Vec<f32> {
|
||||
// Unit-variance per dimension. Combined with the kernel's 1/sqrt(d)
|
||||
// softmax scale, embed(t)·embed(t)/sqrt(d) ≈ sqrt(d) for matched tokens
|
||||
// and ≈ N(0,1) for unmatched — enough separation that exp() picks out
|
||||
// matches strongly. /sqrt(d)-scaled embeddings drown in the noise floor.
|
||||
let mut state = seed.max(1);
|
||||
let mut w = vec![0.0f32; VOCAB_SIZE * HEAD_DIM];
|
||||
for v in w.iter_mut() {
|
||||
*v = next_normal(&mut state);
|
||||
}
|
||||
w
|
||||
}
|
||||
|
||||
fn token_embedding(t: u8, w: &[f32]) -> &[f32] {
|
||||
let i = t as usize * HEAD_DIM;
|
||||
&w[i..i + HEAD_DIM]
|
||||
}
|
||||
|
||||
/// Writes the sinusoidal positional encoding of position `i` into the first
/// `dim` slots of `out`.
///
/// Slots are handled in (sin, cos) pairs: pair `k` uses the angle
/// `i / 10000^(2k / dim)`, with the even slot receiving the sine and the odd
/// slot the cosine.
///
/// Callers in this file add the result at scale 0.5. The matched token term
/// (embed·embed ≈ d) then outweighs the positional term (0.5²·pos·pos = d/8),
/// while nearby positions still get a small nudge.
///
/// # Panics
/// Panics if `out` holds fewer than `dim` elements.
fn pos_encoding_into(i: usize, dim: usize, out: &mut [f32]) {
    (0..dim).for_each(|slot| {
        let pair = slot / 2;
        let angle = (i as f32) / 10000_f32.powf((2 * pair) as f32 / dim as f32);
        out[slot] = match slot % 2 {
            0 => angle.sin(),
            _ => angle.cos(),
        };
    });
}
|
||||
|
||||
const POS_SCALE: f32 = 0.5;
|
||||
|
||||
pub struct MarioRetriever {
|
||||
pub corpus: Vec<u8>,
|
||||
w: Vec<f32>,
|
||||
cfg: SparseAttentionConfig,
|
||||
}
|
||||
|
||||
impl MarioRetriever {
|
||||
pub fn new(corpus: Vec<u8>, embedding_seed: u32) -> Self {
|
||||
// Non-causal so the last query position can reach the whole corpus
|
||||
// through window + log-stride + landmark hops. window=256 + log-stride
|
||||
// + landmarks gives ≈ 14% sparse coverage of a 2.8K-token combined
|
||||
// sequence, which is enough to recover bigram-grade statistics for
|
||||
// 15-token tile vocab.
|
||||
let cfg = SparseAttentionConfig {
|
||||
window: 256,
|
||||
block_size: 64,
|
||||
global_tokens: vec![0],
|
||||
causal: false,
|
||||
use_log_stride: true,
|
||||
use_landmarks: true,
|
||||
sort_candidates: false,
|
||||
};
|
||||
Self {
|
||||
corpus,
|
||||
w: make_embedding_matrix(embedding_seed),
|
||||
cfg,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a [seq, 1, HEAD_DIM] tensor where row i = embed(token[i]) +
|
||||
/// POS_SCALE · pos(i). Token match dominates softmax (matched dot-product
|
||||
/// = d after /sqrt(d) → exp(sqrt(d))) but positional similarity nudges
|
||||
/// retrieval toward corpus positions at comparable level-depth.
|
||||
/// If `shift_for_value`, encodes token[i+1] for the V tensor (the
|
||||
/// empirical "next-token" supervision baked into V).
|
||||
fn make_row_tensor(&self, tokens: &[u8], shift_for_value: bool) -> Tensor3 {
|
||||
let seq = tokens.len();
|
||||
let mut t = Tensor3::zeros(seq, N_HEADS, HEAD_DIM);
|
||||
let mut pos = vec![0.0f32; HEAD_DIM];
|
||||
for i in 0..seq {
|
||||
let tok = if shift_for_value {
|
||||
if i + 1 < seq {
|
||||
tokens[i + 1]
|
||||
} else {
|
||||
tokens[i]
|
||||
}
|
||||
} else {
|
||||
tokens[i]
|
||||
};
|
||||
let emb = token_embedding(tok, &self.w);
|
||||
pos_encoding_into(i, HEAD_DIM, &mut pos);
|
||||
let row = t.row_mut(i, 0);
|
||||
for d in 0..HEAD_DIM {
|
||||
row[d] = emb[d] + POS_SCALE * pos[d];
|
||||
}
|
||||
}
|
||||
t
|
||||
}
|
||||
|
||||
/// Compute logits over VOCAB_SIZE for the next token after `prefix`.
|
||||
pub fn next_token_logits(&self, prefix: &[u8]) -> [f32; VOCAB_SIZE] {
|
||||
let mut combined = self.corpus.clone();
|
||||
combined.extend_from_slice(prefix);
|
||||
let q = self.make_row_tensor(&combined, false);
|
||||
let v = self.make_row_tensor(&combined, true);
|
||||
let attn = SubquadraticSparseAttention::new(self.cfg.clone()).expect("config");
|
||||
let out = attn.forward(&q, &q, &v).expect("attention");
|
||||
let last = combined.len() - 1;
|
||||
let mut logits = [0.0f32; VOCAB_SIZE];
|
||||
for v_idx in 0..VOCAB_SIZE {
|
||||
let emb = token_embedding(v_idx as u8, &self.w);
|
||||
let mut dot = 0.0f32;
|
||||
for d in 0..HEAD_DIM {
|
||||
dot += out.get(last, 0, d) * emb[d];
|
||||
}
|
||||
logits[v_idx] = dot;
|
||||
}
|
||||
logits
|
||||
}
|
||||
|
||||
pub fn generate(&self, prefix: &[u8], n: usize, temperature: f32, sampler_seed: u32) -> Vec<u8> {
|
||||
let mut state = sampler_seed.max(1);
|
||||
let mut out = prefix.to_vec();
|
||||
for _ in 0..n {
|
||||
let logits = self.next_token_logits(&out);
|
||||
let next = sample_logits(&logits, temperature, &mut state);
|
||||
out.push(next);
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_logits(logits: &[f32; VOCAB_SIZE], temperature: f32, state: &mut u32) -> u8 {
|
||||
let temp = temperature.max(1e-3);
|
||||
let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut probs = [0.0f32; VOCAB_SIZE];
|
||||
let mut sum = 0.0f32;
|
||||
for i in 0..VOCAB_SIZE {
|
||||
probs[i] = ((logits[i] - max_l) / temp).exp();
|
||||
sum += probs[i];
|
||||
}
|
||||
if sum <= 0.0 {
|
||||
return 0; // fallback to sky
|
||||
}
|
||||
for p in probs.iter_mut() {
|
||||
*p /= sum;
|
||||
}
|
||||
let r = next_uniform(state);
|
||||
let mut acc = 0.0f32;
|
||||
for i in 0..VOCAB_SIZE {
|
||||
acc += probs[i];
|
||||
if r < acc {
|
||||
return i as u8;
|
||||
}
|
||||
}
|
||||
(VOCAB_SIZE - 1) as u8
|
||||
}
|
||||
|
||||
/// Render a flat token stream as ASCII (newline tokens already encode rows).
|
||||
pub fn render_level(tokens: &[u8]) -> String {
|
||||
tokens.iter().map(|&t| decode_token(t)).collect()
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let tokens = encode_corpus();
|
||||
let dist = tile_distribution(&tokens);
|
||||
|
|
@ -158,8 +374,46 @@ fn main() {
|
|||
let label = if *c == '\n' { "\\n".to_string() } else { c.to_string() };
|
||||
println!(" {:>3} {:>5} {:>5.1}%", label, n, pct);
|
||||
}
|
||||
|
||||
// ---------- iter 2: retrieval generation ----------
|
||||
println!();
|
||||
println!("(iter 1 scaffold — model + generation land in iter 2-3)");
|
||||
println!("== Sparse-attention retrieval generation ==");
|
||||
let retriever = MarioRetriever::new(tokens.clone(), 0x4D41_5249); // "MARI"
|
||||
let row_w = 50 + 1; // 50 cols + newline
|
||||
let n_rows = 14;
|
||||
let n_gen = row_w * n_rows;
|
||||
|
||||
// Seed with a level-shaped fragment so the bigram chain has somewhere to
|
||||
// go besides "sky after sky → sky forever". Mario start + ground row +
|
||||
// newline + sky gives the retrieval bigrams from several distinct contexts.
|
||||
let seed_chars: Vec<u8> = "M-XXXXX\n--------\n"
|
||||
.chars()
|
||||
.filter_map(encode_char)
|
||||
.collect();
|
||||
let t0 = std::time::Instant::now();
|
||||
let generated = retriever.generate(&seed_chars, n_gen, 1.2, 0xC0FF_EE42);
|
||||
let dt = t0.elapsed();
|
||||
let rendered = render_level(&generated);
|
||||
|
||||
println!("seed prefix : {:?}", seed_chars);
|
||||
println!("generated : {} tokens in {:.2?}", n_gen, dt);
|
||||
println!();
|
||||
println!("{}", rendered);
|
||||
println!();
|
||||
|
||||
let gen_only = &generated[seed_chars.len()..];
|
||||
let gen_dist = tile_distribution(gen_only);
|
||||
let pct = |c: char| -> f64 {
|
||||
*gen_dist.get(&c).unwrap_or(&0) as f64 / gen_only.len() as f64 * 100.0
|
||||
};
|
||||
println!(
|
||||
"tile mix in generated: sky {:.1}% ground {:.1}% brick {:.1}% enemy {:.1}% newline {:.1}%",
|
||||
pct('-'),
|
||||
pct('X'),
|
||||
pct('S'),
|
||||
pct('E'),
|
||||
pct('\n')
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -216,4 +470,65 @@ mod tests {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- iter 2 tests ----------
|
||||
|
||||
#[test]
fn embedding_matrix_deterministic() {
    // Two builds from the same seed must agree element for element and have
    // the full table size.
    const SEED: u32 = 0x1234;
    let first = make_embedding_matrix(SEED);
    let second = make_embedding_matrix(SEED);
    assert_eq!(first, second);
    assert_eq!(first.len(), VOCAB_SIZE * HEAD_DIM);
}
|
||||
|
||||
#[test]
fn next_token_logits_finite() {
    // A short prompt over the real corpus must not produce NaN or infinite
    // logits.
    let retriever = MarioRetriever::new(encode_corpus(), 0xABCD);
    let prompt: Vec<u8> = "----X".chars().filter_map(encode_char).collect();
    let logits = retriever.next_token_logits(&prompt);
    for (i, l) in logits.iter().copied().enumerate() {
        assert!(l.is_finite(), "non-finite logit at vocab idx {}: {}", i, l);
    }
}
|
||||
|
||||
#[test]
fn generate_is_deterministic() {
    // Two independently built retrievers with the same embedding seed and
    // sampler seed must emit the same sequence.
    let prompt: Vec<u8> = "--".chars().filter_map(encode_char).collect();
    let run = || MarioRetriever::new(encode_corpus(), 0xABCD).generate(&prompt, 64, 0.8, 0xDEAD_BEEF);
    let a = run();
    let b = run();
    assert_eq!(a, b, "same seed should give same output");
    assert_eq!(a.len(), prompt.len() + 64);
}
|
||||
|
||||
#[test]
fn generated_tiles_are_in_vocab() {
    // Every emitted id, prompt included, must index into VOCAB.
    let retriever = MarioRetriever::new(encode_corpus(), 0xABCD);
    let prompt: Vec<u8> = "--".chars().filter_map(encode_char).collect();
    let out = retriever.generate(&prompt, 200, 1.0, 0x4242);
    if let Some(&bad) = out.iter().find(|&&t| t as usize >= VOCAB.len()) {
        panic!("out-of-vocab token {}", bad);
    }
}
|
||||
|
||||
#[test]
fn generated_distribution_is_corpus_like() {
    // At temperature 0.5 retrieval should stay close to the corpus
    // distribution: sky, ground and newline together should make up well
    // over half of the generated tiles.
    let r = MarioRetriever::new(encode_corpus(), 0xABCD);
    let p: Vec<u8> = "----".chars().filter_map(encode_char).collect();
    let out = r.generate(&p, 300, 0.5, 0x9001);
    // Named `generated` rather than `gen`, which is a reserved keyword in
    // the 2024 edition.
    let generated = &out[p.len()..];
    let dist = tile_distribution(generated);
    let common: usize = ['-', 'X', '\n']
        .iter()
        .map(|tile| dist.get(tile).copied().unwrap_or(0))
        .sum();
    let frac = common as f64 / generated.len() as f64;
    assert!(
        frac > 0.7,
        "expected >70% sky/ground/newline, got {:.1}%",
        frac * 100.0
    );
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue