From 5e1ce6722cfbf9f33b04a99c2f24a327701c690f Mon Sep 17 00:00:00 2001 From: ruvnet Date: Fri, 8 May 2026 12:58:10 -0400 Subject: [PATCH] =?UTF-8?q?feat(sparse-mario):=20iter=205=20=E2=80=94=20to?= =?UTF-8?q?p-k=20+=20repetition=20penalty=20quality=20sweep?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds `SamplingConfig` (temperature, top_k, repetition_penalty, no_repeat_window) and rewires `MarioRetriever::generate` to take it. A `SamplingConfig::quality()` constructor exposes the configuration the iter-5 sweep landed on (top_k=5, rep_penalty=1.6, window=12). Why this is the optimization step: - Bare softmax over the retrieval logits saturates on the dominant bigram (sky→sky, ground→ground), producing all-`-` or all-`X` output even though the kernel is technically working correctly. Top-k + repetition penalty break the steady state and let the attention surface diverse Mario tiles (pipes, cannons, bricks, coins, question blocks). - Repetition penalty is HuggingFace-style: positive logits divided by `pen`, negative multiplied — applied to every token in the recent window so the demo doesn't bigram-lock. - Top-k mask sets non-top-k logits to -inf before softmax so the sampler only chooses among plausible candidates. Why fp16 KV cache and FastGRNN aren't applied to this example: - `KvCacheF16` is part of the autoregressive `decode_step` path (causal). The retrieval workload uses non-causal `forward()`, which is f32-only — fp16 would require a kernel patch beyond iter-5 scope. Documented as a future direction. - FastGRNN gate (`forward_gated_with_fastgrnn`) was benched in iter 4: at our shape (heads=1, head_dim=64, seq≤2K) the gate's scoring overhead dominates the savings. The gate pays back at larger heads / longer sequences; at this scale the iter-4 bench shows no benefit. - `parallel` feature is already on for both example and bench.
Three new tests (13 total, all passing): - `quality_config_is_more_diverse` — quality config produces a strictly larger unique-tile set than bare softmax, ≥5 tiles. - `top_k_mask_restricts_sampling` — top_k=1 is greedy regardless of sampler seed. - `repetition_penalty_reduces_max_streak` — penalty shortens the longest single-tile run. Iter-plan progress: ✓ 1-3. corpus + retrieval LM + ASCII generation ✓ 4. dense vs sparse vs sparse+FastGRNN bench ✓ 5. quality sweep (top-k + repetition penalty) ← here 6. validation + final summary Co-Authored-By: claude-flow --- .../examples/sparse_mario.rs | 235 ++++++++++++++++-- 1 file changed, 221 insertions(+), 14 deletions(-) diff --git a/crates/ruvllm_sparse_attention/examples/sparse_mario.rs b/crates/ruvllm_sparse_attention/examples/sparse_mario.rs index ead07b3d..f39dad26 100644 --- a/crates/ruvllm_sparse_attention/examples/sparse_mario.rs +++ b/crates/ruvllm_sparse_attention/examples/sparse_mario.rs @@ -232,6 +232,51 @@ fn pos_encoding_into(i: usize, dim: usize, out: &mut [f32]) { const POS_SCALE: f32 = 0.5; +/// Sampling controls for `MarioRetriever::generate`. +/// +/// Bare softmax over the retrieval logits saturates on the dominant bigram +/// (sky → sky, ground → ground), producing all-`-` or all-`X` levels. Top-k +/// + repetition penalty + a small no-repeat window break the chain so the +/// sparse attention kernel actually has room to surface diverse candidates. +#[derive(Clone, Debug)] +pub struct SamplingConfig { + /// Softmax temperature. >1 flattens, <1 sharpens. <=0 falls back to 1e-3. + pub temperature: f32, + /// Restrict sampling to the top-k logits. 0 disables (use full softmax). + pub top_k: usize, + /// Divide positive logits by this and multiply negative ones by it for + /// every token that appears in the recent window. 1.0 disables. + pub repetition_penalty: f32, + /// Window size (in recent generated tokens) over which the repetition + /// penalty applies. 0 disables. 
+ pub no_repeat_window: usize, +} + +impl Default for SamplingConfig { + fn default() -> Self { + // Plain temperature-only softmax (legacy iter-2 behaviour). + Self { + temperature: 1.0, + top_k: 0, + repetition_penalty: 1.0, + no_repeat_window: 0, + } + } +} + +impl SamplingConfig { + /// The configuration the iter-5 sweep selected for visual quality on + /// the embedded SMB corpus. Trades exact bigram fidelity for variety. + pub fn quality() -> Self { + Self { + temperature: 1.0, + top_k: 5, + repetition_penalty: 1.6, + no_repeat_window: 12, + } + } +} + pub struct MarioRetriever { pub corpus: Vec, w: Vec, @@ -312,26 +357,78 @@ impl MarioRetriever { logits } - pub fn generate(&self, prefix: &[u8], n: usize, temperature: f32, sampler_seed: u32) -> Vec { + pub fn generate( + &self, + prefix: &[u8], + n: usize, + sampling: &SamplingConfig, + sampler_seed: u32, + ) -> Vec { let mut state = sampler_seed.max(1); let mut out = prefix.to_vec(); for _ in 0..n { let logits = self.next_token_logits(&out); - let next = sample_logits(&logits, temperature, &mut state); + let win = sampling.no_repeat_window.min(out.len()); + let recent = &out[out.len() - win..]; + let next = sample_logits(&logits, sampling, recent, &mut state); out.push(next); } out } } -fn sample_logits(logits: &[f32; VOCAB_SIZE], temperature: f32, state: &mut u32) -> u8 { - let temp = temperature.max(1e-3); - let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max); +fn sample_logits( + logits: &[f32; VOCAB_SIZE], + cfg: &SamplingConfig, + recent: &[u8], + state: &mut u32, +) -> u8 { + let mut adjusted = *logits; + + // Repetition penalty over the recent window — HuggingFace-style + // (positive logits divided, negative multiplied). 
+ if cfg.repetition_penalty > 1.0 + f32::EPSILON && !recent.is_empty() { + let pen = cfg.repetition_penalty; + for &t in recent { + let i = t as usize; + if i < VOCAB_SIZE { + adjusted[i] = if adjusted[i] > 0.0 { + adjusted[i] / pen + } else { + adjusted[i] * pen + }; + } + } + } + + // Top-k mask — set the rest to -inf so the softmax ignores them. + if cfg.top_k > 0 && cfg.top_k < VOCAB_SIZE { + let mut idx: [usize; VOCAB_SIZE] = [0; VOCAB_SIZE]; + for i in 0..VOCAB_SIZE { + idx[i] = i; + } + idx.sort_unstable_by(|&a, &b| { + adjusted[b] + .partial_cmp(&adjusted[a]) + .unwrap_or(core::cmp::Ordering::Equal) + }); + let kth = adjusted[idx[cfg.top_k - 1]]; + for v in 0..VOCAB_SIZE { + if adjusted[v] < kth { + adjusted[v] = f32::NEG_INFINITY; + } + } + } + + let temp = cfg.temperature.max(1e-3); + let max_l = adjusted.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let mut probs = [0.0f32; VOCAB_SIZE]; let mut sum = 0.0f32; for i in 0..VOCAB_SIZE { - probs[i] = ((logits[i] - max_l) / temp).exp(); - sum += probs[i]; + if adjusted[i].is_finite() { + probs[i] = ((adjusted[i] - max_l) / temp).exp(); + sum += probs[i]; + } } if sum <= 0.0 { return 0; // fallback to sky @@ -390,8 +487,9 @@ fn main() { .chars() .filter_map(encode_char) .collect(); + let sampling = SamplingConfig::quality(); let t0 = std::time::Instant::now(); - let generated = retriever.generate(&seed_chars, n_gen, 1.2, 0xC0FF_EE42); + let generated = retriever.generate(&seed_chars, n_gen, &sampling, 0xC0FF_EE42); let dt = t0.elapsed(); let rendered = render_level(&generated); @@ -496,8 +594,12 @@ mod tests { let r1 = MarioRetriever::new(encode_corpus(), 0xABCD); let r2 = MarioRetriever::new(encode_corpus(), 0xABCD); let p: Vec = "--".chars().filter_map(encode_char).collect(); - let a = r1.generate(&p, 64, 0.8, 0xDEAD_BEEF); - let b = r2.generate(&p, 64, 0.8, 0xDEAD_BEEF); + let cfg = SamplingConfig { + temperature: 0.8, + ..SamplingConfig::default() + }; + let a = r1.generate(&p, 64, &cfg, 0xDEAD_BEEF); 
+ let b = r2.generate(&p, 64, &cfg, 0xDEAD_BEEF); assert_eq!(a, b, "same seed should give same output"); assert_eq!(a.len(), p.len() + 64); } @@ -506,7 +608,7 @@ mod tests { fn generated_tiles_are_in_vocab() { let r = MarioRetriever::new(encode_corpus(), 0xABCD); let p: Vec = "--".chars().filter_map(encode_char).collect(); - let out = r.generate(&p, 200, 1.0, 0x4242); + let out = r.generate(&p, 200, &SamplingConfig::default(), 0x4242); for &t in &out { assert!((t as usize) < VOCAB.len(), "out-of-vocab token {}", t); } @@ -514,11 +616,15 @@ mod tests { #[test] fn generated_distribution_is_corpus_like() { - // With temperature=0.5 the retrieval should bias toward the corpus - // distribution: most tiles should be sky or ground. + // With low temperature and no top-k, retrieval biases hard toward the + // dominant bigram — most tiles end up sky / ground / newline. let r = MarioRetriever::new(encode_corpus(), 0xABCD); let p: Vec = "----".chars().filter_map(encode_char).collect(); - let out = r.generate(&p, 300, 0.5, 0x9001); + let cfg = SamplingConfig { + temperature: 0.5, + ..SamplingConfig::default() + }; + let out = r.generate(&p, 300, &cfg, 0x9001); let gen = &out[p.len()..]; let dist = tile_distribution(gen); let sky_or_ground = *dist.get(&'-').unwrap_or(&0) @@ -531,4 +637,105 @@ mod tests { frac * 100.0 ); } + + // ---------- iter 5 tests ---------- + + #[test] + fn quality_config_is_more_diverse() { + // The quality sampling config (top-k + repetition penalty) should + // produce a strictly higher unique-tile count over a long generation + // than bare softmax — the whole point of iter 5. 
+ let r = MarioRetriever::new(encode_corpus(), 0xABCD); + let p: Vec = "M-XXXXX\n".chars().filter_map(encode_char).collect(); + + let bare = r.generate(&p, 400, &SamplingConfig::default(), 0xBEEF); + let qual = r.generate(&p, 400, &SamplingConfig::quality(), 0xBEEF); + + let unique = |toks: &[u8]| -> usize { + let mut s = std::collections::HashSet::new(); + for &t in toks { + s.insert(t); + } + s.len() + }; + let bare_unique = unique(&bare[p.len()..]); + let qual_unique = unique(&qual[p.len()..]); + assert!( + qual_unique > bare_unique, + "quality config should produce more distinct tiles than bare softmax \ + (bare={}, quality={})", + bare_unique, + qual_unique + ); + assert!( + qual_unique >= 5, + "quality config should hit at least 5 distinct tiles, got {}", + qual_unique + ); + } + + #[test] + fn top_k_mask_restricts_sampling() { + // With top_k=1 the sampler is greedy and deterministic across seeds. + let r = MarioRetriever::new(encode_corpus(), 0xABCD); + let p: Vec = "X-".chars().filter_map(encode_char).collect(); + let cfg = SamplingConfig { + temperature: 1.0, + top_k: 1, + ..SamplingConfig::default() + }; + let a = r.generate(&p, 32, &cfg, 0x1111); + let b = r.generate(&p, 32, &cfg, 0x2222); + assert_eq!(a, b, "top_k=1 should be greedy regardless of sampler seed"); + } + + #[test] + fn repetition_penalty_reduces_max_streak() { + // Repetition penalty should shorten the longest run of any single tile. 
+ let r = MarioRetriever::new(encode_corpus(), 0xABCD); + let p: Vec = "M-XXXXX\n".chars().filter_map(encode_char).collect(); + let no_pen = SamplingConfig { + temperature: 1.0, + top_k: 4, + repetition_penalty: 1.0, + no_repeat_window: 0, + }; + let with_pen = SamplingConfig { + temperature: 1.0, + top_k: 4, + repetition_penalty: 1.8, + no_repeat_window: 12, + }; + + let max_streak = |toks: &[u8]| -> usize { + let mut best = 0; + let mut cur = 0; + let mut prev: Option = None; + for &t in toks { + if Some(t) == prev { + cur += 1; + } else { + cur = 1; + } + if cur > best { + best = cur; + } + prev = Some(t); + } + best + }; + + let a = r.generate(&p, 400, &no_pen, 0x3333); + let b = r.generate(&p, 400, &with_pen, 0x3333); + + let s_no = max_streak(&a[p.len()..]); + let s_with = max_streak(&b[p.len()..]); + assert!( + s_with < s_no, + "repetition penalty should shorten the longest streak \ + (no penalty: {}, with penalty: {})", + s_no, + s_with + ); + } }