diff --git a/crates/ruvllm_sparse_attention/examples/sparse_mario.rs b/crates/ruvllm_sparse_attention/examples/sparse_mario.rs index ead07b3d..f39dad26 100644 --- a/crates/ruvllm_sparse_attention/examples/sparse_mario.rs +++ b/crates/ruvllm_sparse_attention/examples/sparse_mario.rs @@ -232,6 +232,51 @@ fn pos_encoding_into(i: usize, dim: usize, out: &mut [f32]) { const POS_SCALE: f32 = 0.5; +/// Sampling controls for `MarioRetriever::generate`. +/// +/// Bare softmax over the retrieval logits saturates on the dominant bigram +/// (sky → sky, ground → ground), producing all-`-` or all-`X` levels. Top-k +/// + repetition penalty + a small no-repeat window break the chain so the +/// sparse attention kernel actually has room to surface diverse candidates. +#[derive(Clone, Debug)] +pub struct SamplingConfig { + /// Softmax temperature. >1 flattens, <1 sharpens. <=0 falls back to 1e-3. + pub temperature: f32, + /// Restrict sampling to the top-k logits. 0 disables (use full softmax). + pub top_k: usize, + /// Divide positive logits by this and multiply negative ones by it for + /// every token that appears in the recent window. 1.0 disables. + pub repetition_penalty: f32, + /// Window size (in recent generated tokens) over which the repetition + /// penalty applies. 0 disables. + pub no_repeat_window: usize, +} + +impl Default for SamplingConfig { + fn default() -> Self { + // Plain temperature-only softmax (legacy iter-2 behaviour). + Self { + temperature: 1.0, + top_k: 0, + repetition_penalty: 1.0, + no_repeat_window: 0, + } + } +} + +impl SamplingConfig { + /// The configuration the iter-5 sweep selected for visual quality on + /// the embedded SMB corpus. Trades exact bigram fidelity for variety. + pub fn quality() -> Self { + Self { + temperature: 1.0, + top_k: 5, + repetition_penalty: 1.6, + no_repeat_window: 12, + } + } +} + pub struct MarioRetriever { pub corpus: Vec, w: Vec, @@ -312,26 +357,78 @@ impl MarioRetriever { logits } - pub fn generate(&self, prefix: &[u8], n: usize, temperature: f32, sampler_seed: u32) -> Vec { + pub fn generate( + &self, + prefix: &[u8], + n: usize, + sampling: &SamplingConfig, + sampler_seed: u32, + ) -> Vec { let mut state = sampler_seed.max(1); let mut out = prefix.to_vec(); for _ in 0..n { let logits = self.next_token_logits(&out); - let next = sample_logits(&logits, temperature, &mut state); + let win = sampling.no_repeat_window.min(out.len()); + let recent = &out[out.len() - win..]; + let next = sample_logits(&logits, sampling, recent, &mut state); out.push(next); } out } } -fn sample_logits(logits: &[f32; VOCAB_SIZE], temperature: f32, state: &mut u32) -> u8 { - let temp = temperature.max(1e-3); - let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max); +fn sample_logits( + logits: &[f32; VOCAB_SIZE], + cfg: &SamplingConfig, + recent: &[u8], + state: &mut u32, +) -> u8 { + let mut adjusted = *logits; + + // Repetition penalty over the recent window — HuggingFace-style + // (positive logits divided, negative multiplied). + if cfg.repetition_penalty > 1.0 + f32::EPSILON && !recent.is_empty() { + let pen = cfg.repetition_penalty; + for &t in recent { + let i = t as usize; + if i < VOCAB_SIZE { + adjusted[i] = if adjusted[i] > 0.0 { + adjusted[i] / pen + } else { + adjusted[i] * pen + }; + } + } + } + + // Top-k mask — set the rest to -inf so the softmax ignores them. + if cfg.top_k > 0 && cfg.top_k < VOCAB_SIZE { + let mut idx: [usize; VOCAB_SIZE] = [0; VOCAB_SIZE]; + for i in 0..VOCAB_SIZE { + idx[i] = i; + } + idx.sort_unstable_by(|&a, &b| { + adjusted[b] + .partial_cmp(&adjusted[a]) + .unwrap_or(core::cmp::Ordering::Equal) + }); + let kth = adjusted[idx[cfg.top_k - 1]]; + for v in 0..VOCAB_SIZE { + if adjusted[v] < kth { + adjusted[v] = f32::NEG_INFINITY; + } + } + } + + let temp = cfg.temperature.max(1e-3); + let max_l = adjusted.iter().cloned().fold(f32::NEG_INFINITY, f32::max); let mut probs = [0.0f32; VOCAB_SIZE]; let mut sum = 0.0f32; for i in 0..VOCAB_SIZE { - probs[i] = ((logits[i] - max_l) / temp).exp(); - sum += probs[i]; + if adjusted[i].is_finite() { + probs[i] = ((adjusted[i] - max_l) / temp).exp(); + sum += probs[i]; + } } if sum <= 0.0 { return 0; // fallback to sky @@ -390,8 +487,9 @@ fn main() { .chars() .filter_map(encode_char) .collect(); + let sampling = SamplingConfig::quality(); let t0 = std::time::Instant::now(); - let generated = retriever.generate(&seed_chars, n_gen, 1.2, 0xC0FF_EE42); + let generated = retriever.generate(&seed_chars, n_gen, &sampling, 0xC0FF_EE42); let dt = t0.elapsed(); let rendered = render_level(&generated); @@ -496,8 +594,12 @@ mod tests { let r1 = MarioRetriever::new(encode_corpus(), 0xABCD); let r2 = MarioRetriever::new(encode_corpus(), 0xABCD); let p: Vec = "--".chars().filter_map(encode_char).collect(); - let a = r1.generate(&p, 64, 0.8, 0xDEAD_BEEF); - let b = r2.generate(&p, 64, 0.8, 0xDEAD_BEEF); + let cfg = SamplingConfig { + temperature: 0.8, + ..SamplingConfig::default() + }; + let a = r1.generate(&p, 64, &cfg, 0xDEAD_BEEF); + let b = r2.generate(&p, 64, &cfg, 0xDEAD_BEEF); assert_eq!(a, b, "same seed should give same output"); assert_eq!(a.len(), p.len() + 64); } @@ -506,7 +608,7 @@ mod tests { fn generated_tiles_are_in_vocab() { let r = MarioRetriever::new(encode_corpus(), 0xABCD); let p: Vec = "--".chars().filter_map(encode_char).collect(); - let out = r.generate(&p, 200, 1.0, 0x4242); + let out = r.generate(&p, 200, &SamplingConfig::default(), 0x4242); for &t in &out { assert!((t as usize) < VOCAB.len(), "out-of-vocab token {}", t); } @@ -514,11 +616,15 @@ mod tests { #[test] fn generated_distribution_is_corpus_like() { - // With temperature=0.5 the retrieval should bias toward the corpus - // distribution: most tiles should be sky or ground. + // With low temperature and no top-k, retrieval biases hard toward the + // dominant bigram — most tiles end up sky / ground / newline. let r = MarioRetriever::new(encode_corpus(), 0xABCD); let p: Vec = "----".chars().filter_map(encode_char).collect(); - let out = r.generate(&p, 300, 0.5, 0x9001); + let cfg = SamplingConfig { + temperature: 0.5, + ..SamplingConfig::default() + }; + let out = r.generate(&p, 300, &cfg, 0x9001); let gen = &out[p.len()..]; let dist = tile_distribution(gen); let sky_or_ground = *dist.get(&'-').unwrap_or(&0) @@ -531,4 +637,105 @@ mod tests { frac * 100.0 ); } + + // ---------- iter 5 tests ---------- + + #[test] + fn quality_config_is_more_diverse() { + // The quality sampling config (top-k + repetition penalty) should + // produce a strictly higher unique-tile count over a long generation + // than bare softmax — the whole point of iter 5. + let r = MarioRetriever::new(encode_corpus(), 0xABCD); + let p: Vec = "M-XXXXX\n".chars().filter_map(encode_char).collect(); + + let bare = r.generate(&p, 400, &SamplingConfig::default(), 0xBEEF); + let qual = r.generate(&p, 400, &SamplingConfig::quality(), 0xBEEF); + + let unique = |toks: &[u8]| -> usize { + let mut s = std::collections::HashSet::new(); + for &t in toks { + s.insert(t); + } + s.len() + }; + let bare_unique = unique(&bare[p.len()..]); + let qual_unique = unique(&qual[p.len()..]); + assert!( + qual_unique > bare_unique, + "quality config should produce more distinct tiles than bare softmax \ + (bare={}, quality={})", + bare_unique, + qual_unique + ); + assert!( + qual_unique >= 5, + "quality config should hit at least 5 distinct tiles, got {}", + qual_unique + ); + } + + #[test] + fn top_k_mask_restricts_sampling() { + // With top_k=1 the sampler is greedy and deterministic across seeds. + let r = MarioRetriever::new(encode_corpus(), 0xABCD); + let p: Vec = "X-".chars().filter_map(encode_char).collect(); + let cfg = SamplingConfig { + temperature: 1.0, + top_k: 1, + ..SamplingConfig::default() + }; + let a = r.generate(&p, 32, &cfg, 0x1111); + let b = r.generate(&p, 32, &cfg, 0x2222); + assert_eq!(a, b, "top_k=1 should be greedy regardless of sampler seed"); + } + + #[test] + fn repetition_penalty_reduces_max_streak() { + // Repetition penalty should shorten the longest run of any single tile. + let r = MarioRetriever::new(encode_corpus(), 0xABCD); + let p: Vec = "M-XXXXX\n".chars().filter_map(encode_char).collect(); + let no_pen = SamplingConfig { + temperature: 1.0, + top_k: 4, + repetition_penalty: 1.0, + no_repeat_window: 0, + }; + let with_pen = SamplingConfig { + temperature: 1.0, + top_k: 4, + repetition_penalty: 1.8, + no_repeat_window: 12, + }; + + let max_streak = |toks: &[u8]| -> usize { + let mut best = 0; + let mut cur = 0; + let mut prev: Option = None; + for &t in toks { + if Some(t) == prev { + cur += 1; + } else { + cur = 1; + } + if cur > best { + best = cur; + } + prev = Some(t); + } + best + }; + + let a = r.generate(&p, 400, &no_pen, 0x3333); + let b = r.generate(&p, 400, &with_pen, 0x3333); + + let s_no = max_streak(&a[p.len()..]); + let s_with = max_streak(&b[p.len()..]); + assert!( + s_with < s_no, + "repetition penalty should shorten the longest streak \ + (no penalty: {}, with penalty: {})", + s_no, + s_with + ); + } }