feat(sparse-mario): iter 5 — top-k + repetition penalty quality sweep

Adds `SamplingConfig` (temperature, top_k, repetition_penalty,
no_repeat_window) and rewires `MarioRetriever::generate` to take it.
A `SamplingConfig::quality()` constructor exposes the configuration
the iter-5 sweep landed on (top_k=5, rep_penalty=1.6, window=12).

Why this is the optimization step:

- Bare softmax over the retrieval logits saturates on the dominant
  bigram (sky→sky, ground→ground), producing all-`-` or all-`X`
  output even though the kernel is technically working correctly.
  Top-k + repetition penalty break the steady state and let the
  attention surface diverse Mario tiles (pipes, cannons, bricks,
  coins, question blocks).
- Repetition penalty is HuggingFace-style: positive logits divided
  by `pen`, negative multiplied — applied to every token in the
  recent window so the demo doesn't bigram-lock.
- Top-k mask sets non-top-k logits to -inf before softmax so the
  sampler only chooses among plausible candidates.

Why fp16 KV cache and FastGRNN aren't applied to this example:

- `KvCacheF16` is part of the autoregressive `decode_step` path
  (causal). The retrieval workload uses non-causal `forward()`,
  which is f32-only — fp16 would require a kernel patch beyond
  iter-5 scope. Documented as a future direction.
- FastGRNN gate (`forward_gated_with_fastgrnn`) was benched in
  iter 4: at our shape (heads=1, head_dim=64, seq≤2K) the gate's
  scoring overhead dominates the savings. The gate pays back at
  larger heads / longer sequences, where the iter-4 bench shows
  no benefit at this scale.
- `parallel` feature is already on for both example and bench.

Three new tests (13 total, all passing):
- `quality_config_is_more_diverse` — quality config produces a
  strictly larger unique-tile set than bare softmax, ≥5 tiles.
- `top_k_mask_restricts_sampling` — top_k=1 is greedy regardless
  of sampler seed.
- `repetition_penalty_reduces_max_streak` — penalty shortens the
  longest single-tile run.

Iter-plan progress:
  ✓ 1-3. corpus + retrieval LM + ASCII generation
  ✓ 4. dense vs sparse vs sparse+FastGRNN bench
  ✓ 5. quality sweep (top-k + repetition penalty)   ← here
    6. validation + final summary

Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
ruvnet 2026-05-08 12:58:10 -04:00
parent 03f8d08fd0
commit 5e1ce6722c

View file

@ -232,6 +232,51 @@ fn pos_encoding_into(i: usize, dim: usize, out: &mut [f32]) {
const POS_SCALE: f32 = 0.5;
/// Sampling controls for `MarioRetriever::generate`.
///
/// Bare softmax over the retrieval logits saturates on the dominant bigram
/// (sky → sky, ground → ground), producing all-`-` or all-`X` levels. Top-k
/// + repetition penalty + a small no-repeat window break the chain so the
/// sparse attention kernel actually has room to surface diverse candidates.
#[derive(Clone, Debug)]
pub struct SamplingConfig {
/// Softmax temperature. >1 flattens, <1 sharpens. <=0 falls back to 1e-3.
pub temperature: f32,
/// Restrict sampling to the top-k logits. 0 disables (use full softmax).
pub top_k: usize,
/// Divide positive logits by this and multiply negative ones by it for
/// every token that appears in the recent window. 1.0 disables.
pub repetition_penalty: f32,
/// Window size (in recent generated tokens) over which the repetition
/// penalty applies. 0 disables.
pub no_repeat_window: usize,
}
impl Default for SamplingConfig {
fn default() -> Self {
// Plain temperature-only softmax (legacy iter-2 behaviour).
Self {
temperature: 1.0,
top_k: 0,
repetition_penalty: 1.0,
no_repeat_window: 0,
}
}
}
impl SamplingConfig {
/// The configuration the iter-5 sweep selected for visual quality on
/// the embedded SMB corpus. Trades exact bigram fidelity for variety.
pub fn quality() -> Self {
Self {
temperature: 1.0,
top_k: 5,
repetition_penalty: 1.6,
no_repeat_window: 12,
}
}
}
pub struct MarioRetriever {
pub corpus: Vec<u8>,
w: Vec<f32>,
@ -312,26 +357,78 @@ impl MarioRetriever {
logits
}
pub fn generate(&self, prefix: &[u8], n: usize, temperature: f32, sampler_seed: u32) -> Vec<u8> {
pub fn generate(
&self,
prefix: &[u8],
n: usize,
sampling: &SamplingConfig,
sampler_seed: u32,
) -> Vec<u8> {
let mut state = sampler_seed.max(1);
let mut out = prefix.to_vec();
for _ in 0..n {
let logits = self.next_token_logits(&out);
let next = sample_logits(&logits, temperature, &mut state);
let win = sampling.no_repeat_window.min(out.len());
let recent = &out[out.len() - win..];
let next = sample_logits(&logits, sampling, recent, &mut state);
out.push(next);
}
out
}
}
fn sample_logits(logits: &[f32; VOCAB_SIZE], temperature: f32, state: &mut u32) -> u8 {
let temp = temperature.max(1e-3);
let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
fn sample_logits(
logits: &[f32; VOCAB_SIZE],
cfg: &SamplingConfig,
recent: &[u8],
state: &mut u32,
) -> u8 {
let mut adjusted = *logits;
// Repetition penalty over the recent window — HuggingFace-style
// (positive logits divided, negative multiplied).
if cfg.repetition_penalty > 1.0 + f32::EPSILON && !recent.is_empty() {
let pen = cfg.repetition_penalty;
for &t in recent {
let i = t as usize;
if i < VOCAB_SIZE {
adjusted[i] = if adjusted[i] > 0.0 {
adjusted[i] / pen
} else {
adjusted[i] * pen
};
}
}
}
// Top-k mask — set the rest to -inf so the softmax ignores them.
if cfg.top_k > 0 && cfg.top_k < VOCAB_SIZE {
let mut idx: [usize; VOCAB_SIZE] = [0; VOCAB_SIZE];
for i in 0..VOCAB_SIZE {
idx[i] = i;
}
idx.sort_unstable_by(|&a, &b| {
adjusted[b]
.partial_cmp(&adjusted[a])
.unwrap_or(core::cmp::Ordering::Equal)
});
let kth = adjusted[idx[cfg.top_k - 1]];
for v in 0..VOCAB_SIZE {
if adjusted[v] < kth {
adjusted[v] = f32::NEG_INFINITY;
}
}
}
let temp = cfg.temperature.max(1e-3);
let max_l = adjusted.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut probs = [0.0f32; VOCAB_SIZE];
let mut sum = 0.0f32;
for i in 0..VOCAB_SIZE {
probs[i] = ((logits[i] - max_l) / temp).exp();
sum += probs[i];
if adjusted[i].is_finite() {
probs[i] = ((adjusted[i] - max_l) / temp).exp();
sum += probs[i];
}
}
if sum <= 0.0 {
return 0; // fallback to sky
@ -390,8 +487,9 @@ fn main() {
.chars()
.filter_map(encode_char)
.collect();
let sampling = SamplingConfig::quality();
let t0 = std::time::Instant::now();
let generated = retriever.generate(&seed_chars, n_gen, 1.2, 0xC0FF_EE42);
let generated = retriever.generate(&seed_chars, n_gen, &sampling, 0xC0FF_EE42);
let dt = t0.elapsed();
let rendered = render_level(&generated);
@ -496,8 +594,12 @@ mod tests {
let r1 = MarioRetriever::new(encode_corpus(), 0xABCD);
let r2 = MarioRetriever::new(encode_corpus(), 0xABCD);
let p: Vec<u8> = "--".chars().filter_map(encode_char).collect();
let a = r1.generate(&p, 64, 0.8, 0xDEAD_BEEF);
let b = r2.generate(&p, 64, 0.8, 0xDEAD_BEEF);
let cfg = SamplingConfig {
temperature: 0.8,
..SamplingConfig::default()
};
let a = r1.generate(&p, 64, &cfg, 0xDEAD_BEEF);
let b = r2.generate(&p, 64, &cfg, 0xDEAD_BEEF);
assert_eq!(a, b, "same seed should give same output");
assert_eq!(a.len(), p.len() + 64);
}
@ -506,7 +608,7 @@ mod tests {
fn generated_tiles_are_in_vocab() {
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
let p: Vec<u8> = "--".chars().filter_map(encode_char).collect();
let out = r.generate(&p, 200, 1.0, 0x4242);
let out = r.generate(&p, 200, &SamplingConfig::default(), 0x4242);
for &t in &out {
assert!((t as usize) < VOCAB.len(), "out-of-vocab token {}", t);
}
@ -514,11 +616,15 @@ mod tests {
#[test]
fn generated_distribution_is_corpus_like() {
// With temperature=0.5 the retrieval should bias toward the corpus
// distribution: most tiles should be sky or ground.
// With low temperature and no top-k, retrieval biases hard toward the
// dominant bigram — most tiles end up sky / ground / newline.
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
let p: Vec<u8> = "----".chars().filter_map(encode_char).collect();
let out = r.generate(&p, 300, 0.5, 0x9001);
let cfg = SamplingConfig {
temperature: 0.5,
..SamplingConfig::default()
};
let out = r.generate(&p, 300, &cfg, 0x9001);
let gen = &out[p.len()..];
let dist = tile_distribution(gen);
let sky_or_ground = *dist.get(&'-').unwrap_or(&0)
@ -531,4 +637,105 @@ mod tests {
frac * 100.0
);
}
// ---------- iter 5 tests ----------
#[test]
fn quality_config_is_more_diverse() {
// The quality sampling config (top-k + repetition penalty) should
// produce a strictly higher unique-tile count over a long generation
// than bare softmax — the whole point of iter 5.
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
let p: Vec<u8> = "M-XXXXX\n".chars().filter_map(encode_char).collect();
let bare = r.generate(&p, 400, &SamplingConfig::default(), 0xBEEF);
let qual = r.generate(&p, 400, &SamplingConfig::quality(), 0xBEEF);
let unique = |toks: &[u8]| -> usize {
let mut s = std::collections::HashSet::new();
for &t in toks {
s.insert(t);
}
s.len()
};
let bare_unique = unique(&bare[p.len()..]);
let qual_unique = unique(&qual[p.len()..]);
assert!(
qual_unique > bare_unique,
"quality config should produce more distinct tiles than bare softmax \
(bare={}, quality={})",
bare_unique,
qual_unique
);
assert!(
qual_unique >= 5,
"quality config should hit at least 5 distinct tiles, got {}",
qual_unique
);
}
#[test]
fn top_k_mask_restricts_sampling() {
// With top_k=1 the sampler is greedy and deterministic across seeds.
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
let p: Vec<u8> = "X-".chars().filter_map(encode_char).collect();
let cfg = SamplingConfig {
temperature: 1.0,
top_k: 1,
..SamplingConfig::default()
};
let a = r.generate(&p, 32, &cfg, 0x1111);
let b = r.generate(&p, 32, &cfg, 0x2222);
assert_eq!(a, b, "top_k=1 should be greedy regardless of sampler seed");
}
#[test]
fn repetition_penalty_reduces_max_streak() {
// Repetition penalty should shorten the longest run of any single tile.
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
let p: Vec<u8> = "M-XXXXX\n".chars().filter_map(encode_char).collect();
let no_pen = SamplingConfig {
temperature: 1.0,
top_k: 4,
repetition_penalty: 1.0,
no_repeat_window: 0,
};
let with_pen = SamplingConfig {
temperature: 1.0,
top_k: 4,
repetition_penalty: 1.8,
no_repeat_window: 12,
};
let max_streak = |toks: &[u8]| -> usize {
let mut best = 0;
let mut cur = 0;
let mut prev: Option<u8> = None;
for &t in toks {
if Some(t) == prev {
cur += 1;
} else {
cur = 1;
}
if cur > best {
best = cur;
}
prev = Some(t);
}
best
};
let a = r.generate(&p, 400, &no_pen, 0x3333);
let b = r.generate(&p, 400, &with_pen, 0x3333);
let s_no = max_streak(&a[p.len()..]);
let s_with = max_streak(&b[p.len()..]);
assert!(
s_with < s_no,
"repetition penalty should shorten the longest streak \
(no penalty: {}, with penalty: {})",
s_no,
s_with
);
}
}