mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-24 22:15:18 +00:00
feat(sparse-mario): iter 5 — top-k + repetition penalty quality sweep
Adds `SamplingConfig` (temperature, top_k, repetition_penalty,
no_repeat_window) and rewires `MarioRetriever::generate` to take it.
A `SamplingConfig::quality()` constructor exposes the configuration
the iter-5 sweep landed on (top_k=5, rep_penalty=1.6, window=12).
Why this is the optimization step:
- Bare softmax over the retrieval logits saturates on the dominant
bigram (sky→sky, ground→ground), producing all-`-` or all-`X`
output even though the kernel is technically working correctly.
Top-k + repetition penalty break the steady state and let the
attention surface diverse Mario tiles (pipes, cannons, bricks,
coins, question blocks).
- Repetition penalty is HuggingFace-style: positive logits divided
by `pen`, negative multiplied — applied to every token in the
recent window so the demo doesn't bigram-lock.
- Top-k mask sets non-top-k logits to -inf before softmax so the
sampler only chooses among plausible candidates.
Why fp16 KV cache and FastGRNN aren't applied to this example:
- `KvCacheF16` is part of the autoregressive `decode_step` path
(causal). The retrieval workload uses non-causal `forward()`,
which is f32-only — fp16 would require a kernel patch beyond
iter-5 scope. Documented as a future direction.
- FastGRNN gate (`forward_gated_with_fastgrnn`) was benched in
iter 4: at our shape (heads=1, head_dim=64, seq≤2K) the gate's
scoring overhead dominates the savings. The gate pays back at
larger heads / longer sequences, where the iter-4 bench shows
no benefit at this scale.
- `parallel` feature is already on for both example and bench.
Three new tests (13 total, all passing):
- `quality_config_is_more_diverse` — quality config produces a
strictly larger unique-tile set than bare softmax, ≥5 tiles.
- `top_k_mask_restricts_sampling` — top_k=1 is greedy regardless
of sampler seed.
- `repetition_penalty_reduces_max_streak` — penalty shortens the
longest single-tile run.
Iter-plan progress:
✓ 1-3. corpus + retrieval LM + ASCII generation
✓ 4. dense vs sparse vs sparse+FastGRNN bench
✓ 5. quality sweep (top-k + repetition penalty) ← here
6. validation + final summary
Co-Authored-By: claude-flow <ruv@ruv.net>
This commit is contained in:
parent
03f8d08fd0
commit
5e1ce6722c
1 changed files with 221 additions and 14 deletions
|
|
@ -232,6 +232,51 @@ fn pos_encoding_into(i: usize, dim: usize, out: &mut [f32]) {
|
|||
|
||||
const POS_SCALE: f32 = 0.5;
|
||||
|
||||
/// Sampling controls for `MarioRetriever::generate`.
|
||||
///
|
||||
/// Bare softmax over the retrieval logits saturates on the dominant bigram
|
||||
/// (sky → sky, ground → ground), producing all-`-` or all-`X` levels. Top-k
|
||||
/// + repetition penalty + a small no-repeat window break the chain so the
|
||||
/// sparse attention kernel actually has room to surface diverse candidates.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SamplingConfig {
|
||||
/// Softmax temperature. >1 flattens, <1 sharpens. <=0 falls back to 1e-3.
|
||||
pub temperature: f32,
|
||||
/// Restrict sampling to the top-k logits. 0 disables (use full softmax).
|
||||
pub top_k: usize,
|
||||
/// Divide positive logits by this and multiply negative ones by it for
|
||||
/// every token that appears in the recent window. 1.0 disables.
|
||||
pub repetition_penalty: f32,
|
||||
/// Window size (in recent generated tokens) over which the repetition
|
||||
/// penalty applies. 0 disables.
|
||||
pub no_repeat_window: usize,
|
||||
}
|
||||
|
||||
impl Default for SamplingConfig {
|
||||
fn default() -> Self {
|
||||
// Plain temperature-only softmax (legacy iter-2 behaviour).
|
||||
Self {
|
||||
temperature: 1.0,
|
||||
top_k: 0,
|
||||
repetition_penalty: 1.0,
|
||||
no_repeat_window: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SamplingConfig {
|
||||
/// The configuration the iter-5 sweep selected for visual quality on
|
||||
/// the embedded SMB corpus. Trades exact bigram fidelity for variety.
|
||||
pub fn quality() -> Self {
|
||||
Self {
|
||||
temperature: 1.0,
|
||||
top_k: 5,
|
||||
repetition_penalty: 1.6,
|
||||
no_repeat_window: 12,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MarioRetriever {
|
||||
pub corpus: Vec<u8>,
|
||||
w: Vec<f32>,
|
||||
|
|
@ -312,26 +357,78 @@ impl MarioRetriever {
|
|||
logits
|
||||
}
|
||||
|
||||
pub fn generate(&self, prefix: &[u8], n: usize, temperature: f32, sampler_seed: u32) -> Vec<u8> {
|
||||
pub fn generate(
|
||||
&self,
|
||||
prefix: &[u8],
|
||||
n: usize,
|
||||
sampling: &SamplingConfig,
|
||||
sampler_seed: u32,
|
||||
) -> Vec<u8> {
|
||||
let mut state = sampler_seed.max(1);
|
||||
let mut out = prefix.to_vec();
|
||||
for _ in 0..n {
|
||||
let logits = self.next_token_logits(&out);
|
||||
let next = sample_logits(&logits, temperature, &mut state);
|
||||
let win = sampling.no_repeat_window.min(out.len());
|
||||
let recent = &out[out.len() - win..];
|
||||
let next = sample_logits(&logits, sampling, recent, &mut state);
|
||||
out.push(next);
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_logits(logits: &[f32; VOCAB_SIZE], temperature: f32, state: &mut u32) -> u8 {
|
||||
let temp = temperature.max(1e-3);
|
||||
let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
fn sample_logits(
|
||||
logits: &[f32; VOCAB_SIZE],
|
||||
cfg: &SamplingConfig,
|
||||
recent: &[u8],
|
||||
state: &mut u32,
|
||||
) -> u8 {
|
||||
let mut adjusted = *logits;
|
||||
|
||||
// Repetition penalty over the recent window — HuggingFace-style
|
||||
// (positive logits divided, negative multiplied).
|
||||
if cfg.repetition_penalty > 1.0 + f32::EPSILON && !recent.is_empty() {
|
||||
let pen = cfg.repetition_penalty;
|
||||
for &t in recent {
|
||||
let i = t as usize;
|
||||
if i < VOCAB_SIZE {
|
||||
adjusted[i] = if adjusted[i] > 0.0 {
|
||||
adjusted[i] / pen
|
||||
} else {
|
||||
adjusted[i] * pen
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Top-k mask — set the rest to -inf so the softmax ignores them.
|
||||
if cfg.top_k > 0 && cfg.top_k < VOCAB_SIZE {
|
||||
let mut idx: [usize; VOCAB_SIZE] = [0; VOCAB_SIZE];
|
||||
for i in 0..VOCAB_SIZE {
|
||||
idx[i] = i;
|
||||
}
|
||||
idx.sort_unstable_by(|&a, &b| {
|
||||
adjusted[b]
|
||||
.partial_cmp(&adjusted[a])
|
||||
.unwrap_or(core::cmp::Ordering::Equal)
|
||||
});
|
||||
let kth = adjusted[idx[cfg.top_k - 1]];
|
||||
for v in 0..VOCAB_SIZE {
|
||||
if adjusted[v] < kth {
|
||||
adjusted[v] = f32::NEG_INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let temp = cfg.temperature.max(1e-3);
|
||||
let max_l = adjusted.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
|
||||
let mut probs = [0.0f32; VOCAB_SIZE];
|
||||
let mut sum = 0.0f32;
|
||||
for i in 0..VOCAB_SIZE {
|
||||
probs[i] = ((logits[i] - max_l) / temp).exp();
|
||||
sum += probs[i];
|
||||
if adjusted[i].is_finite() {
|
||||
probs[i] = ((adjusted[i] - max_l) / temp).exp();
|
||||
sum += probs[i];
|
||||
}
|
||||
}
|
||||
if sum <= 0.0 {
|
||||
return 0; // fallback to sky
|
||||
|
|
@ -390,8 +487,9 @@ fn main() {
|
|||
.chars()
|
||||
.filter_map(encode_char)
|
||||
.collect();
|
||||
let sampling = SamplingConfig::quality();
|
||||
let t0 = std::time::Instant::now();
|
||||
let generated = retriever.generate(&seed_chars, n_gen, 1.2, 0xC0FF_EE42);
|
||||
let generated = retriever.generate(&seed_chars, n_gen, &sampling, 0xC0FF_EE42);
|
||||
let dt = t0.elapsed();
|
||||
let rendered = render_level(&generated);
|
||||
|
||||
|
|
@ -496,8 +594,12 @@ mod tests {
|
|||
let r1 = MarioRetriever::new(encode_corpus(), 0xABCD);
|
||||
let r2 = MarioRetriever::new(encode_corpus(), 0xABCD);
|
||||
let p: Vec<u8> = "--".chars().filter_map(encode_char).collect();
|
||||
let a = r1.generate(&p, 64, 0.8, 0xDEAD_BEEF);
|
||||
let b = r2.generate(&p, 64, 0.8, 0xDEAD_BEEF);
|
||||
let cfg = SamplingConfig {
|
||||
temperature: 0.8,
|
||||
..SamplingConfig::default()
|
||||
};
|
||||
let a = r1.generate(&p, 64, &cfg, 0xDEAD_BEEF);
|
||||
let b = r2.generate(&p, 64, &cfg, 0xDEAD_BEEF);
|
||||
assert_eq!(a, b, "same seed should give same output");
|
||||
assert_eq!(a.len(), p.len() + 64);
|
||||
}
|
||||
|
|
@ -506,7 +608,7 @@ mod tests {
|
|||
fn generated_tiles_are_in_vocab() {
|
||||
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
|
||||
let p: Vec<u8> = "--".chars().filter_map(encode_char).collect();
|
||||
let out = r.generate(&p, 200, 1.0, 0x4242);
|
||||
let out = r.generate(&p, 200, &SamplingConfig::default(), 0x4242);
|
||||
for &t in &out {
|
||||
assert!((t as usize) < VOCAB.len(), "out-of-vocab token {}", t);
|
||||
}
|
||||
|
|
@ -514,11 +616,15 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn generated_distribution_is_corpus_like() {
|
||||
// With temperature=0.5 the retrieval should bias toward the corpus
|
||||
// distribution: most tiles should be sky or ground.
|
||||
// With low temperature and no top-k, retrieval biases hard toward the
|
||||
// dominant bigram — most tiles end up sky / ground / newline.
|
||||
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
|
||||
let p: Vec<u8> = "----".chars().filter_map(encode_char).collect();
|
||||
let out = r.generate(&p, 300, 0.5, 0x9001);
|
||||
let cfg = SamplingConfig {
|
||||
temperature: 0.5,
|
||||
..SamplingConfig::default()
|
||||
};
|
||||
let out = r.generate(&p, 300, &cfg, 0x9001);
|
||||
let gen = &out[p.len()..];
|
||||
let dist = tile_distribution(gen);
|
||||
let sky_or_ground = *dist.get(&'-').unwrap_or(&0)
|
||||
|
|
@ -531,4 +637,105 @@ mod tests {
|
|||
frac * 100.0
|
||||
);
|
||||
}
|
||||
|
||||
// ---------- iter 5 tests ----------
|
||||
|
||||
#[test]
|
||||
fn quality_config_is_more_diverse() {
|
||||
// The quality sampling config (top-k + repetition penalty) should
|
||||
// produce a strictly higher unique-tile count over a long generation
|
||||
// than bare softmax — the whole point of iter 5.
|
||||
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
|
||||
let p: Vec<u8> = "M-XXXXX\n".chars().filter_map(encode_char).collect();
|
||||
|
||||
let bare = r.generate(&p, 400, &SamplingConfig::default(), 0xBEEF);
|
||||
let qual = r.generate(&p, 400, &SamplingConfig::quality(), 0xBEEF);
|
||||
|
||||
let unique = |toks: &[u8]| -> usize {
|
||||
let mut s = std::collections::HashSet::new();
|
||||
for &t in toks {
|
||||
s.insert(t);
|
||||
}
|
||||
s.len()
|
||||
};
|
||||
let bare_unique = unique(&bare[p.len()..]);
|
||||
let qual_unique = unique(&qual[p.len()..]);
|
||||
assert!(
|
||||
qual_unique > bare_unique,
|
||||
"quality config should produce more distinct tiles than bare softmax \
|
||||
(bare={}, quality={})",
|
||||
bare_unique,
|
||||
qual_unique
|
||||
);
|
||||
assert!(
|
||||
qual_unique >= 5,
|
||||
"quality config should hit at least 5 distinct tiles, got {}",
|
||||
qual_unique
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn top_k_mask_restricts_sampling() {
|
||||
// With top_k=1 the sampler is greedy and deterministic across seeds.
|
||||
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
|
||||
let p: Vec<u8> = "X-".chars().filter_map(encode_char).collect();
|
||||
let cfg = SamplingConfig {
|
||||
temperature: 1.0,
|
||||
top_k: 1,
|
||||
..SamplingConfig::default()
|
||||
};
|
||||
let a = r.generate(&p, 32, &cfg, 0x1111);
|
||||
let b = r.generate(&p, 32, &cfg, 0x2222);
|
||||
assert_eq!(a, b, "top_k=1 should be greedy regardless of sampler seed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn repetition_penalty_reduces_max_streak() {
|
||||
// Repetition penalty should shorten the longest run of any single tile.
|
||||
let r = MarioRetriever::new(encode_corpus(), 0xABCD);
|
||||
let p: Vec<u8> = "M-XXXXX\n".chars().filter_map(encode_char).collect();
|
||||
let no_pen = SamplingConfig {
|
||||
temperature: 1.0,
|
||||
top_k: 4,
|
||||
repetition_penalty: 1.0,
|
||||
no_repeat_window: 0,
|
||||
};
|
||||
let with_pen = SamplingConfig {
|
||||
temperature: 1.0,
|
||||
top_k: 4,
|
||||
repetition_penalty: 1.8,
|
||||
no_repeat_window: 12,
|
||||
};
|
||||
|
||||
let max_streak = |toks: &[u8]| -> usize {
|
||||
let mut best = 0;
|
||||
let mut cur = 0;
|
||||
let mut prev: Option<u8> = None;
|
||||
for &t in toks {
|
||||
if Some(t) == prev {
|
||||
cur += 1;
|
||||
} else {
|
||||
cur = 1;
|
||||
}
|
||||
if cur > best {
|
||||
best = cur;
|
||||
}
|
||||
prev = Some(t);
|
||||
}
|
||||
best
|
||||
};
|
||||
|
||||
let a = r.generate(&p, 400, &no_pen, 0x3333);
|
||||
let b = r.generate(&p, 400, &with_pen, 0x3333);
|
||||
|
||||
let s_no = max_streak(&a[p.len()..]);
|
||||
let s_with = max_streak(&b[p.len()..]);
|
||||
assert!(
|
||||
s_with < s_no,
|
||||
"repetition penalty should shorten the longest streak \
|
||||
(no penalty: {}, with penalty: {})",
|
||||
s_no,
|
||||
s_with
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue