refine(ablation): risk_score policy, normalized penalty, witness log

PolicyKernel refinements:
- Fixed policy (Mode A): risk_score = R + k*D, k=30, T=140
  Fixed constants (not learned) — Mode A is the control arm.
  One distractor raises perceived risk by ~30 range-days.
  Weekday only when range is large AND distractor-free.
- Normalized EarlyCommitPenalty: (remaining/initial) * scale
  Committing at 5% scan = cheap (0.05), at 90% = expensive (0.90).
  Only charged on wrong commits.
- Hybrid minimum evidence: stop_after_first disabled in Hybrid mode
  so solver checks all matching weekdays before committing.

Witness log:
- SolutionAttempt now carries skip_mode and context_bucket strings
- record_attempt_witnessed() for full policy audit trail
- Every trajectory records which skip mode was chosen and why

Observability:
- Puzzle tags now include distractor_count and has_dow (deterministic)
- count_distractors() made public for generator to tag puzzles

Ablation assertions (two new):
- a_skip_nonzero: Mode A uses skip at least sometimes (proves not hobbled)
- c_multi_mode: Mode C uses different skip modes across contexts (proves learning)
- Skip-mode distribution table printed per context bucket for Mode C

posterior_target monotonicity verified: 2→4→8→12→18→25→35→50→70→100
(never shrinks with difficulty)

81 tests passing (61 lib + 20 integration).

https://claude.ai/code/session_01RnwD4x5cbpB7FPvoyYQz8G
This commit is contained in:
Claude 2026-02-15 23:08:02 +00:00
parent f6117d051d
commit f9742e6b0e
4 changed files with 151 additions and 17 deletions

View file

@ -27,6 +27,7 @@ use crate::temporal::{AdaptiveSolver, KnowledgeCompiler, PolicyKernel, TemporalC
use crate::timepuzzles::{PuzzleGenerator, PuzzleGeneratorConfig};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ═══════════════════════════════════════════════════════════════════════════
// Ablation Modes
@ -73,6 +74,8 @@ pub struct AblationResult {
pub early_commit_rate: f64,
pub early_commit_penalties: f64,
pub policy_context_buckets: usize,
/// Skip-mode distribution by context bucket: bucket → (mode → count)
pub skip_mode_distribution: HashMap<String, HashMap<String, usize>>,
}
/// Full ablation comparison across all three modes.
@ -87,6 +90,10 @@ pub struct AblationComparison {
pub c_beats_b_robustness: bool,
/// Compiler false hit rate under 5%
pub compiler_safe: bool,
/// Mode A uses skip at least sometimes (proves not hobbled)
pub a_skip_nonzero: bool,
/// Mode C uses different skip modes across contexts (proves learning)
pub c_multi_mode: bool,
/// All modes passed
pub all_passed: bool,
}
@ -138,8 +145,25 @@ impl AblationComparison {
println!(" B beats A on cost (>=15%): {}", if self.b_beats_a_cost { "PASS" } else { "FAIL" });
println!(" C beats B on robustness (>=10%): {}", if self.c_beats_b_robustness { "PASS" } else { "FAIL" });
println!(" Compiler false-hit rate <5%: {}", if self.compiler_safe { "PASS" } else { "FAIL" });
println!(" A skip usage nonzero: {}", if self.a_skip_nonzero { "PASS" } else { "FAIL" });
println!(" C uses multiple skip modes: {}", if self.c_multi_mode { "PASS" } else { "FAIL" });
println!();
// Skip-mode distribution table for Mode C
if !self.mode_c.skip_mode_distribution.is_empty() {
println!(" Mode C Skip-Mode Distribution by Context:");
println!(" {:<20} {:>8} {:>8} {:>8}", "Bucket", "None", "Weekday", "Hybrid");
println!(" {}", "-".repeat(48));
for (bucket, dist) in &self.mode_c.skip_mode_distribution {
let total = dist.values().sum::<usize>().max(1);
let none_pct = *dist.get("none").unwrap_or(&0) as f64 / total as f64 * 100.0;
let weekday_pct = *dist.get("weekday").unwrap_or(&0) as f64 / total as f64 * 100.0;
let hybrid_pct = *dist.get("hybrid").unwrap_or(&0) as f64 / total as f64 * 100.0;
println!(" {:<20} {:>6.1}% {:>6.1}% {:>6.1}%", bucket, none_pct, weekday_pct, hybrid_pct);
}
println!();
}
if self.all_passed {
println!(" ABLATION RESULT: ALL PASSED");
} else {
@ -553,6 +577,15 @@ pub fn run_acceptance_test_mode(config: &HoldoutConfig, mode: &AblationMode) ->
policy_kernel.print_diagnostics();
}
// Build skip-mode distribution from PolicyKernel context stats
let mut skip_dist: HashMap<String, HashMap<String, usize>> = HashMap::new();
for (bucket, modes) in &policy_kernel.context_stats {
let entry = skip_dist.entry(bucket.clone()).or_default();
for (mode_name, stats) in modes {
*entry.entry(mode_name.clone()).or_insert(0) += stats.attempts;
}
}
Ok(AblationResult {
mode: mode.clone(),
result: acceptance_result,
@ -563,6 +596,7 @@ pub fn run_acceptance_test_mode(config: &HoldoutConfig, mode: &AblationMode) ->
early_commit_rate: policy_kernel.early_commit_rate(),
early_commit_penalties: policy_kernel.early_commit_penalties,
policy_context_buckets: policy_kernel.context_stats.len(),
skip_mode_distribution: skip_dist,
})
}
@ -602,7 +636,23 @@ pub fn run_ablation_comparison(config: &HoldoutConfig) -> Result<AblationCompari
true
};
// Mode A skip usage is nonzero: proves it is not hobbled
let a_total_skip_uses: usize = mode_a.skip_mode_distribution.values()
.flat_map(|modes| modes.iter())
.filter(|(name, _)| *name != "none")
.map(|(_, count)| *count)
.sum();
let a_skip_nonzero = a_total_skip_uses > 0;
// Mode C uses different skip modes across contexts: proves learning
let c_unique_modes: std::collections::HashSet<&str> = mode_c.skip_mode_distribution.values()
.flat_map(|modes| modes.keys())
.map(|s| s.as_str())
.collect();
let c_multi_mode = c_unique_modes.len() >= 2;
let all_passed = b_beats_a_cost && c_beats_b_robustness && compiler_safe
&& a_skip_nonzero && c_multi_mode
&& mode_a.result.passed && mode_b.result.passed && mode_c.result.passed;
Ok(AblationComparison {
@ -612,6 +662,8 @@ pub fn run_ablation_comparison(config: &HoldoutConfig) -> Result<AblationCompari
b_beats_a_cost,
c_beats_b_robustness,
compiler_safe,
a_skip_nonzero,
c_multi_mode,
all_passed,
})
}

View file

@ -57,6 +57,10 @@ pub struct SolutionAttempt {
pub tool_calls: usize,
/// Strategy used
pub strategy: String,
/// Skip mode used (witness for policy audit: "none", "weekday", "hybrid")
pub skip_mode: String,
/// Context bucket key (witness for policy audit: "range:distractor")
pub context_bucket: String,
}
/// Trajectory tracking for a single puzzle
@ -105,6 +109,30 @@ impl Trajectory {
steps,
tool_calls,
strategy: strategy.to_string(),
skip_mode: String::new(),
context_bucket: String::new(),
});
}
/// Record attempt with full policy witness (skip_mode + context_bucket).
pub fn record_attempt_witnessed(
&mut self,
solution: String,
confidence: f64,
steps: usize,
tool_calls: usize,
strategy: &str,
skip_mode: &str,
context_bucket: &str,
) {
self.attempts.push(SolutionAttempt {
solution,
confidence,
steps,
tool_calls,
strategy: strategy.to_string(),
skip_mode: skip_mode.to_string(),
context_bucket: context_bucket.to_string(),
});
}

View file

@ -564,6 +564,10 @@ pub struct SkipOutcome {
pub steps: usize,
/// Whether this was an early commit that turned out wrong
pub early_commit_wrong: bool,
/// Initial candidate count (for normalized penalty)
pub initial_candidates: usize,
/// Remaining candidates at commit time (for normalized penalty)
pub remaining_at_commit: usize,
}
/// Per-context skip-mode statistics for learned policy.
@ -622,21 +626,28 @@ impl PolicyKernel {
}
/// Fixed baseline policy (Mode A):
/// Uses posterior_range + distractor_count to decide.
/// - If DayOfWeek is present AND posterior_range > 30 AND distractor_count == 0: Weekday
/// - If DayOfWeek is present AND distractor_count > 0: Hybrid (safe fallback)
/// - Otherwise: None
/// Uses risk_score = R + k*D where R=posterior_range, D=distractor_count.
///
/// Constants (fixed, not learned — Mode A is the control arm):
/// k = 30 (one distractor raises perceived risk by ~30 range-days)
/// T = 140 (threshold: skip only when range is large enough to justify it)
///
/// Decision:
/// If no DayOfWeek: None (nothing to skip to)
/// Else risk_score = R + 30*D
/// risk_score >= 140 → Weekday (large range, few distractors)
/// risk_score < 140 → None (small range or distractor-heavy)
const BASELINE_K: usize = 30;
const BASELINE_T: usize = 140;
pub fn fixed_policy(ctx: &PolicyContext) -> SkipMode {
if !ctx.has_day_of_week {
return SkipMode::None;
}
if ctx.distractor_count == 0 && ctx.posterior_range > 30 {
let risk_score = ctx.posterior_range + Self::BASELINE_K * ctx.distractor_count;
if risk_score >= Self::BASELINE_T {
SkipMode::Weekday
} else if ctx.distractor_count > 0 {
// Distractors present: skip is risky, use hybrid for safety
SkipMode::Hybrid
} else {
// Small range: skip saves little, linear is fine
SkipMode::None
}
}
@ -692,6 +703,15 @@ impl PolicyKernel {
}
/// Record the outcome of a skip-mode decision.
///
/// EarlyCommitPenalty is normalized:
/// penalty = (remaining_at_commit / initial_candidates) * PENALTY_SCALE
///
/// Committing at 5% of scan = cheap (penalty ≈ 0.05).
/// Committing at 90% of scan = expensive (penalty ≈ 0.90).
/// Only charged when the commit is *wrong*.
const PENALTY_SCALE: f64 = 1.0;
pub fn record_outcome(&mut self, ctx: &PolicyContext, outcome: &SkipOutcome) {
let bucket = Self::context_bucket(ctx);
let mode_name = outcome.mode.to_string();
@ -704,9 +724,14 @@ impl PolicyKernel {
if outcome.early_commit_wrong {
stats.early_commit_wrongs += 1;
self.early_commits_wrong += 1;
// Penalty proportional to how early the commit was
// (fewer steps = earlier commit = higher penalty)
let penalty = 1.0 - (outcome.steps as f64 / 200.0).min(1.0);
// Normalized penalty: remaining/initial fraction
let penalty = if outcome.initial_candidates > 0 {
(outcome.remaining_at_commit as f64 / outcome.initial_candidates as f64)
* Self::PENALTY_SCALE
} else {
// Fallback: use step-based estimate
1.0 - (outcome.steps as f64 / 200.0).min(1.0)
};
self.early_commit_penalties += penalty;
}
self.early_commits_total += 1;
@ -718,6 +743,11 @@ impl PolicyKernel {
self.early_commits_wrong as f64 / self.early_commits_total as f64
}
/// Build a context bucket key for stats grouping (public for witnesses).
pub fn context_bucket_static(ctx: &PolicyContext) -> String {
Self::context_bucket(ctx)
}
/// Build a context bucket key for stats grouping.
fn context_bucket(ctx: &PolicyContext) -> String {
let range_bucket = match ctx.posterior_range {
@ -1298,11 +1328,15 @@ impl AdaptiveSolver {
}
SkipMode::Hybrid => {
// Hybrid: use weekday skip for initial scan (set here),
// then do a refinement pass below if needed
// then do a refinement pass below if needed.
// Force minimum evidence: never stop_after_first in Hybrid mode.
self.solver.skip_weekday = puzzle.constraints.iter().find_map(|c| match c {
TemporalConstraint::DayOfWeek(w) => Some(*w),
_ => None,
});
// Hybrid safety: disable early termination so solver checks
// all matching weekdays before committing
self.solver.stop_after_first = false;
}
}
@ -1342,8 +1376,10 @@ impl AdaptiveSolver {
trajectory.latency_ms = latency;
let sol_str = result.solutions.first()
.map(|d| d.to_string()).unwrap_or_else(|| "none".to_string());
trajectory.record_attempt(
let bucket_key = PolicyKernel::context_bucket_static(&policy_ctx);
trajectory.record_attempt_witnessed(
sol_str, 0.95, result.steps, result.tool_calls, "compiler",
&skip_mode.to_string(), &bucket_key,
);
trajectory.set_verdict(
Verdict::Success,
@ -1358,6 +1394,8 @@ impl AdaptiveSolver {
correct: true,
steps: result.steps,
early_commit_wrong: false,
initial_candidates: policy_ctx.posterior_range,
remaining_at_commit: 0,
};
self.policy_kernel.record_outcome(&policy_ctx, &outcome);
@ -1374,11 +1412,15 @@ impl AdaptiveSolver {
// Record early commit wrong if solver claimed solved but was wrong
if result.solved && !result.correct {
// Estimate remaining: initial minus steps scanned
let remaining = policy_ctx.posterior_range.saturating_sub(result.steps);
let outcome = SkipOutcome {
mode: skip_mode.clone(),
correct: false,
steps: result.steps,
early_commit_wrong: true,
initial_candidates: policy_ctx.posterior_range,
remaining_at_commit: remaining,
};
self.policy_kernel.record_outcome(&policy_ctx, &outcome);
}
@ -1479,12 +1521,15 @@ impl AdaptiveSolver {
let confidence = self.calculate_confidence(&result, puzzle);
trajectory.record_attempt(
let bucket_key = PolicyKernel::context_bucket_static(&policy_ctx);
trajectory.record_attempt_witnessed(
solution_str,
confidence,
result.steps,
result.tool_calls,
&self.current_strategy.name,
&skip_mode.to_string(),
&bucket_key,
);
// Determine verdict
@ -1509,11 +1554,14 @@ impl AdaptiveSolver {
// ─── Record PolicyKernel outcome ─────────────────────────────────
let early_commit_wrong = result.solved && !result.correct;
let remaining = policy_ctx.posterior_range.saturating_sub(result.steps);
let outcome = SkipOutcome {
mode: skip_mode,
correct: result.correct,
steps: result.steps,
early_commit_wrong,
initial_candidates: policy_ctx.posterior_range,
remaining_at_commit: remaining,
};
self.policy_kernel.record_outcome(&policy_ctx, &outcome);
@ -1580,7 +1628,8 @@ impl AdaptiveSolver {
/// Count distractor constraints in a puzzle.
/// A distractor is a constraint that is likely redundant (doesn't narrow the search much).
fn count_distractors(puzzle: &TemporalPuzzle) -> usize {
/// Public so the generator can tag puzzles with their distractor count.
pub fn count_distractors(puzzle: &TemporalPuzzle) -> usize {
let mut count = 0;
let mut seen_between = false;
let mut seen_inyear = false;

View file

@ -382,13 +382,18 @@ impl PuzzleGenerator {
// for aggressive skip modes.
}
// Tags
// Count actual distractors injected (deterministic, observable)
let actual_distractor_count = crate::temporal::count_distractors(&puzzle);
// Tags: all features visible to policies for deterministic observability
puzzle.tags = vec![
format!("difficulty:{}", difficulty),
format!("year:{}", year),
format!("range_size:{}", dv.range_size),
format!("distractor_rate:{:.2}", dv.distractor_rate),
format!("distractor_count:{}", actual_distractor_count),
format!("ambiguity:{}", dv.ambiguity_count),
format!("has_dow:{}", use_day_of_week),
];
Ok(puzzle)