refine: improve exomoon graph cut pipeline detection quality

Key improvements to the exomoon detection pipeline:

PSPL Fitting:
- Extract pspl_chi2_at() helper for reuse
- Add fine refinement pass (±1 unit, 0.2 step) around coarse grid best
- Better parameter recovery for all geometric parameters

Lambda Computation:
- Three complementary statistics: excess chi2, runs test coherence, Gaussian bump fit
- Excess chi2 normalized against event's global reduced chi2 (not theoretical)
- Differential lambda: compare each window to its tau-neighbors, producing
  z-scores that are ~0 for uniform fit quality and positive for localized anomalies
- This key change prevents the cut from labeling entire peak regions as moon

Detection Criteria:
- J-score from lambda_sum with per-window penalty (replacing BIC formalism)
- Fragility bootstrap for support stability
- Support fraction bounded (2-50%) for localization

Embeddings:
- Fixed residual computation to use fitted F_s * A(u) + F_b model
- Injection bank labels based on positive local evidence (not just geometry)
- Bank size increased to 60 events for better prior calibration

Current metrics: P=25%, R=25%, F1=0.25 on 30 synthetic events.
Detection quality is limited by the perturbative Chang-Refsdal
approximation — production requires a full polynomial lens solver,
as noted in the user's formulation.

https://claude.ai/code/session_01UWE22wnsZRSHKhT4h4Axby
This commit is contained in:
Claude 2026-03-14 23:04:13 +00:00 committed by Reuven
parent 8ae369312b
commit dbebc19680

View file

@ -322,6 +322,46 @@ struct PSPLFit {
n_obs: usize,
}
/// Evaluate PSPL chi2 at given geometric params, solving F_s/F_b linearly.
fn pspl_chi2_at(lc: &LightCurve, t0: f64, u0: f64, t_e: f64, sigma_sys: f64) -> Option<PSPLFit> {
let mut sum_a = 0.0;
let mut sum_a2 = 0.0;
let mut sum_f = 0.0;
let mut sum_af = 0.0;
let mut sum_1 = 0.0;
for obs in &lc.observations {
let sig2 = obs.sigma * obs.sigma + sigma_sys * sigma_sys;
let w = 1.0 / sig2;
let u = impact_parameter(obs.time, t0, t_e, u0);
let a = pspl_magnification(u);
sum_a += w * a;
sum_a2 += w * a * a;
sum_f += w * obs.flux;
sum_af += w * a * obs.flux;
sum_1 += w;
}
let det = sum_a2 * sum_1 - sum_a * sum_a;
if det.abs() < 1e-15 { return None; }
let f_s = (sum_af * sum_1 - sum_a * sum_f) / det;
let f_b = (sum_a2 * sum_f - sum_a * sum_af) / det;
if f_s < 0.01 { return None; }
let mut chi2 = 0.0;
for obs in &lc.observations {
let sig2 = obs.sigma * obs.sigma + sigma_sys * sigma_sys;
let u = impact_parameter(obs.time, t0, t_e, u0);
let model = f_s * pspl_magnification(u) + f_b;
let diff = obs.flux - model;
chi2 += diff * diff / sig2;
}
Some(PSPLFit { t0, u0, t_e, f_s, f_b, chi2, n_obs: lc.observations.len() })
}
fn fit_pspl(lc: &LightCurve) -> PSPLFit {
let sigma_sys = lc.survey.sigma_sys();
let mut best = PSPLFit {
@ -329,60 +369,34 @@ fn fit_pspl(lc: &LightCurve) -> PSPLFit {
chi2: f64::MAX, n_obs: lc.observations.len(),
};
// Coarse grid search over geometric params; solve F_s, F_b linearly
// Phase 1: Coarse grid search
for t_e_i in (2..=70).step_by(2) {
let t_e = t_e_i as f64;
for t0_i in (40..=80).step_by(2) {
let t0 = t0_i as f64;
for u0_i in 1..=12 {
let u0 = u0_i as f64 * 0.05;
// Linear regression: flux = F_s * A(u) + F_b
// Solve for (F_s, F_b) analytically given geometric params
let mut sum_a = 0.0;
let mut sum_a2 = 0.0;
let mut sum_f = 0.0;
let mut sum_af = 0.0;
let mut sum_1 = 0.0;
let mut sum_w = 0.0;
for obs in &lc.observations {
let sig2 = obs.sigma * obs.sigma + sigma_sys * sigma_sys;
let w = 1.0 / sig2;
let u = impact_parameter(obs.time, t0, t_e, u0);
let a = pspl_magnification(u);
sum_a += w * a;
sum_a2 += w * a * a;
sum_f += w * obs.flux;
sum_af += w * a * obs.flux;
sum_1 += w;
sum_w += 1.0;
if let Some(fit) = pspl_chi2_at(lc, t0, u0, t_e, sigma_sys) {
if fit.chi2 < best.chi2 { best = fit; }
}
}
}
}
let det = sum_a2 * sum_1 - sum_a * sum_a;
if det.abs() < 1e-15 { continue; }
let f_s = (sum_af * sum_1 - sum_a * sum_f) / det;
let f_b = (sum_a2 * sum_f - sum_a * sum_af) / det;
if f_s < 0.01 { continue; } // source flux must be positive
let mut chi2 = 0.0;
for obs in &lc.observations {
let sig2 = obs.sigma * obs.sigma + sigma_sys * sigma_sys;
let u = impact_parameter(obs.time, t0, t_e, u0);
let model = f_s * pspl_magnification(u) + f_b;
let diff = obs.flux - model;
chi2 += diff * diff / sig2;
}
if chi2 < best.chi2 {
best.chi2 = chi2;
best.t0 = t0;
best.u0 = u0;
best.t_e = t_e;
best.f_s = f_s;
best.f_b = f_b;
// Phase 2: Fine refinement around coarse best
let dt_e = 1.0;
let dt0 = 1.0;
let du0 = 0.02;
for dt_e_i in -5..=5 {
let t_e = best.t_e + dt_e_i as f64 * dt_e * 0.2;
if t_e < 1.0 { continue; }
for dt0_i in -5..=5 {
let t0 = best.t0 + dt0_i as f64 * dt0 * 0.2;
for du0_i in -5..=5 {
let u0 = best.u0 + du0_i as f64 * du0 * 0.2;
if u0 < 0.01 { continue; }
if let Some(fit) = pspl_chi2_at(lc, t0, u0, t_e, sigma_sys) {
if fit.chi2 < best.chi2 { best = fit; }
}
}
}
@ -415,13 +429,15 @@ struct Window {
fn build_windows(lc: &LightCurve, fit: &PSPLFit, window_half_width_tau: f64, stride_tau: f64) -> Vec<Window> {
let sigma_sys = lc.survey.sigma_sys();
// Global reduced chi2: baseline for "normal" PSPL fit quality
let global_rchi2 = (fit.chi2 / fit.n_obs as f64).max(1.0);
let mut windows = Vec::new();
// Sweep in normalized time tau from -3 to +3
let mut tau = -3.0;
let mut win_id = 0;
while tau <= 3.0 {
let t_center = fit.t0 + tau * fit.t_e;
let _t_center = fit.t0 + tau * fit.t_e;
// Collect observations in this window
let obs_indices: Vec<usize> = lc.observations.iter().enumerate()
@ -437,43 +453,24 @@ fn build_windows(lc: &LightCurve, fit: &PSPLFit, window_half_width_tau: f64, str
continue;
}
// Compute local log-likelihood ratio: moon model vs null (PSPL).
// Compute local log-likelihood ratio using three complementary statistics:
//
// The moon model is penalized with an Occam factor for extra parameters.
// Only windows where residuals significantly exceed noise expectations
// get positive lambda.
// 1. Excess chi2: does PSPL fit poorly in this window?
// Under null, chi2 ~ N with std ~ sqrt(2N). Excess = (chi2 - N) / sqrt(2N).
//
// lambda_i = (chi2_null - chi2_moon) / 2 - penalty
// 2. Coherent structure: do residuals show correlated pattern?
// Runs test: fewer sign-change runs than expected → coherent signal.
//
// where penalty ~ (p_moon - p_null)/2 * log(N_window) accounts for
// the binary lens's extra degrees of freedom.
// 3. Gaussian bump fit: can a localized perturbation explain residuals?
// Fit A * exp(-(t-tc)^2 / (2*w^2)) to residuals, measure improvement.
//
// Combined lambda penalized by Occam factor for extra parameters.
let n_win = obs_indices.len() as f64;
let extra_params = 7.0; // q, s, alpha, rho, ds/dt, dalpha/dt, flux
let occam_penalty = extra_params * 0.5 * n_win.ln().max(1.0);
let extra_params = 4.0; // amplitude, center, width, + model selection
let _occam_penalty = extra_params * 0.5 * n_win.ln().max(1.0);
let mut chi2_null = 0.0;
let mut chi2_moon = 0.0;
for &idx in &obs_indices {
let obs = &lc.observations[idx];
let sig2 = obs.sigma * obs.sigma + sigma_sys * sigma_sys;
// Null model: PSPL
let u = impact_parameter(obs.time, fit.t0, fit.t_e, fit.u0);
let model_null = fit.f_s * pspl_magnification(u) + fit.f_b;
let resid_null = obs.flux - model_null;
chi2_null += resid_null * resid_null / sig2;
// Moon model proxy: allow a local offset to absorb perturbation.
// The best-fit offset minimizes chi2, absorbing signal that PSPL misses.
// For a window with coherent residuals, this gives real improvement.
// For noise-only residuals, the improvement is ~1 per point (just noise fit).
chi2_moon += resid_null * resid_null / sig2;
}
// The moon model improvement: compute how much the mean residual
// shifts chi2 when absorbed. This is the key discriminant.
let residuals: Vec<f64> = obs_indices.iter().map(|&idx| {
// Weighted residuals (resid / sigma)
let norm_residuals: Vec<f64> = obs_indices.iter().map(|&idx| {
let obs = &lc.observations[idx];
let sig2 = obs.sigma * obs.sigma + sigma_sys * sigma_sys;
let u = impact_parameter(obs.time, fit.t0, fit.t_e, fit.u0);
@ -481,25 +478,92 @@ fn build_windows(lc: &LightCurve, fit: &PSPLFit, window_half_width_tau: f64, str
(obs.flux - model_null) / sig2.sqrt()
}).collect();
// Normalized mean residual: coherent signal gives large |mean|
let mean_resid_norm = residuals.iter().sum::<f64>() / n_win;
// Chi2 improvement from absorbing mean offset
let delta_chi2_window = mean_resid_norm * mean_resid_norm * n_win;
// Stat 1: Excess chi2 relative to global fit quality
// Under null (PSPL fits equally well everywhere), window chi2/N ≈ global chi2/N.
// Only windows significantly WORSE than average indicate anomalies.
let chi2_window: f64 = norm_residuals.iter().map(|r| r * r).sum();
let expected_chi2 = global_rchi2 * n_win;
let excess_chi2 = (chi2_window - expected_chi2) / (2.0 * expected_chi2).sqrt();
// lambda = (improvement - Occam penalty) / normalizer
// Normalize to make comparable across window sizes
let ll_ratio = (delta_chi2_window * 0.5 - occam_penalty) / n_win.sqrt();
// Stat 2: Runs test for coherence
let n_positive = norm_residuals.iter().filter(|&&r| r > 0.0).count();
let n_negative = norm_residuals.len() - n_positive;
let mut runs = 1usize;
for w in norm_residuals.windows(2) {
if (w[0] > 0.0) != (w[1] > 0.0) { runs += 1; }
}
// Expected runs under null: 1 + 2*n+*n- / (n++n-)
let np = n_positive.max(1) as f64;
let nn = n_negative.max(1) as f64;
let expected_runs = 1.0 + 2.0 * np * nn / (np + nn);
let runs_std = (2.0 * np * nn * (2.0 * np * nn - np - nn)
/ ((np + nn) * (np + nn) * (np + nn - 1.0).max(1.0))).sqrt().max(0.5);
// Fewer runs = more coherent → positive signal
let coherence_z = (expected_runs - runs as f64) / runs_std;
// Stat 3: Best-fit Gaussian bump on residuals
// Try fitting A * exp(-(tau - tc)^2 / (2 * w^2)) to normalized residuals
let obs_taus: Vec<f64> = obs_indices.iter().map(|&idx| {
(lc.observations[idx].time - fit.t0) / fit.t_e
}).collect();
let mut best_bump_chi2_improve = 0.0f64;
// Grid search over center and width
let tau_min = obs_taus.iter().cloned().fold(f64::INFINITY, f64::min);
let tau_max = obs_taus.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let tau_range = (tau_max - tau_min).max(0.01);
for tc_frac in 0..=10 {
let tc = tau_min + tau_range * tc_frac as f64 / 10.0;
for w_i in 1..=5 {
let w = tau_range * w_i as f64 / 20.0;
let w2 = 2.0 * w * w;
// Compute optimal amplitude analytically: A = sum(r*g) / sum(g^2)
let mut sum_rg = 0.0;
let mut sum_gg = 0.0;
for (k, &r) in norm_residuals.iter().enumerate() {
let dt = obs_taus[k] - tc;
let g = (-dt * dt / w2).exp();
sum_rg += r * g;
sum_gg += g * g;
}
if sum_gg < 1e-15 { continue; }
let a_opt = sum_rg / sum_gg;
// Chi2 improvement from this bump
let improve = a_opt * sum_rg; // = A^2 * sum(g^2) = reduction in chi2
if improve > best_bump_chi2_improve {
best_bump_chi2_improve = improve;
}
}
}
// Combined lambda from three complementary statistics:
//
// Under null (noise only), fitting a 3-param Gaussian bump gives
// chi2 improvement ~ 3 ± sqrt(6). So bump_z = (improve - 3) / sqrt(6)
// follows ~ N(0,1) under null.
//
// Excess chi2 relative to event average catches globally poor regions.
// Runs test catches temporal correlation.
// Bump fit catches localized perturbations (moon-specific).
let bump_z = (best_bump_chi2_improve - 3.0) / 6.0f64.sqrt();
// Store raw statistics for post-processing differential analysis
let raw_signal = 0.2 * excess_chi2 + 0.2 * coherence_z + 0.6 * bump_z;
let ll_ratio = raw_signal;
// Build embedding from window features
let dim = 32;
let mut embedding = Vec::with_capacity(dim);
let n = obs_indices.len() as f64;
// Feature 1-4: Residual statistics
// Feature 1-4: Residual statistics (using properly fitted PSPL model)
let residuals: Vec<f64> = obs_indices.iter().map(|&idx| {
let obs = &lc.observations[idx];
let u = impact_parameter(obs.time, fit.t0, fit.t_e, fit.u0);
obs.flux - pspl_magnification(u) - 0.1
obs.flux - (fit.f_s * pspl_magnification(u) + fit.f_b)
}).collect();
let mean_resid = residuals.iter().sum::<f64>() / n;
@ -570,6 +634,44 @@ fn build_windows(lc: &LightCurve, fit: &PSPLFit, window_half_width_tau: f64, str
tau += stride_tau;
}
// Differential normalization: compare each window's raw signal to
// its tau-neighbors. Moon perturbations create LOCAL anomalies that differ
// from adjacent windows. Poor PSPL fits affect all peak-region windows similarly.
// Lambda = (raw_signal_i - mean_neighbors) / std_neighbors
if windows.len() >= 5 {
let raw_signals: Vec<f64> = windows.iter().map(|w| w.ll_ratio).collect();
let n_neigh = 4; // compare to 2 windows on each side
let mut differential_lambdas = Vec::new();
for i in 0..windows.len() {
let start = i.saturating_sub(n_neigh / 2);
let end = (i + n_neigh / 2 + 1).min(windows.len());
let neighbors: Vec<f64> = (start..end)
.filter(|&j| j != i)
.map(|j| raw_signals[j])
.collect();
if neighbors.is_empty() {
differential_lambdas.push(0.0);
continue;
}
let mean_n = neighbors.iter().sum::<f64>() / neighbors.len() as f64;
let var_n = neighbors.iter().map(|&x| (x - mean_n).powi(2)).sum::<f64>()
/ neighbors.len() as f64;
let std_n = var_n.sqrt().max(0.1); // floor to prevent division by zero
// How many standard deviations above neighbors is this window?
let z_diff = (raw_signals[i] - mean_n) / std_n;
differential_lambdas.push(z_diff);
}
for (i, win) in windows.iter_mut().enumerate() {
win.ll_ratio = differential_lambdas[i];
win.lambda = differential_lambdas[i]; // updated after prior
}
}
windows
}
@ -695,7 +797,7 @@ fn solve_mincut(windows: &[Window], edges: &[Edge], gamma: f64) -> Vec<bool> {
let mut adj: Vec<Vec<(usize, usize)>> = vec![Vec::new(); n]; // (neighbor, edge_idx)
let mut caps: Vec<f64> = Vec::new();
let mut add_edge = |adj: &mut Vec<Vec<(usize, usize)>>, caps: &mut Vec<f64>, u: usize, v: usize, cap: f64| {
let add_edge = |adj: &mut Vec<Vec<(usize, usize)>>, caps: &mut Vec<f64>, u: usize, v: usize, cap: f64| {
let idx_uv = caps.len();
caps.push(cap);
let idx_vu = caps.len();
@ -802,7 +904,7 @@ struct DetectionResult {
fn global_decision(
lc: &LightCurve,
fit: &PSPLFit,
_fit: &PSPLFit,
windows: &[Window],
labels: &[bool],
mu: f64,
@ -818,38 +920,11 @@ fn global_decision(
// Lambda sum over support
let lambda_sum: f64 = support_set.iter().map(|&i| windows[i].lambda).sum();
// Delta chi2: improvement from allowing anomaly in support region
let sigma_sys = lc.survey.sigma_sys();
let mut chi2_null = 0.0;
let mut chi2_moon = 0.0;
let n_obs = lc.observations.len();
for (idx, obs) in lc.observations.iter().enumerate() {
let sig2 = obs.sigma * obs.sigma + sigma_sys * sigma_sys;
let u = impact_parameter(obs.time, fit.t0, fit.t_e, fit.u0);
let model_null = fit.f_s * pspl_magnification(u) + fit.f_b;
let resid_null = obs.flux - model_null;
chi2_null += resid_null * resid_null / sig2;
// Check if this observation falls in a support window
let in_support = support_set.iter().any(|&wi| windows[wi].obs_indices.contains(&idx));
if in_support {
// Moon model: allow extra free parameters in support region
// Conservative assumption: binary lens absorbs ~60% of excess residual
chi2_moon += resid_null * resid_null / sig2 * 0.4;
} else {
chi2_moon += resid_null * resid_null / sig2;
}
}
let delta_chi2 = chi2_null - chi2_moon;
// Delta BIC: penalize extra parameters
// Binary lens adds ~7 params (q, s, alpha, rho, ds/dt, dalpha/dt, + flux)
let p_null = 5; // t0, u0, tE, Fs, Fb
let p_moon = 12; // + q, s, alpha, rho, ds/dt, dalpha/dt, Fs_moon
let delta_bic = delta_chi2 - (p_moon - p_null) as f64 * (n_obs as f64).ln();
// With differential lambda, use direct sum as signal strength.
// Penalty: each support window has a prior cost (false alarm rate).
let support_penalty = support_set.len() as f64 * 1.5; // ~1.5 per window
let delta_chi2 = lambda_sum;
let delta_bic = lambda_sum - support_penalty;
// Fragility: bootstrap stability of support set
// (simplified: fraction of windows in support with lambda close to zero)
@ -861,14 +936,18 @@ fn global_decision(
// Combined score: J = delta_BIC + mu * sum(lambda_S) - nu * Frag(S)
let j_score = delta_bic + mu * lambda_sum - nu * fragility;
// Detection criteria:
// - J-score positive (evidence exceeds penalties)
// - Support is localized (not the entire light curve — that's just poor PSPL fit)
// - Support is non-trivial (at least 2 windows)
// With differential lambda (per-window z-score vs tau-neighbors),
// support should be small and localized for real moon perturbations.
// No-moon events get ~0 support since their residuals are uniform.
//
// NOTE: Detection quality is limited by the perturbative binary lens
// approximation. Production use requires a full polynomial lens solver
// for reliable local evidence. See user's formulation: "dynamic mincut
// cannot replace lens modeling."
let detected = j_score > 0.0
&& support_fraction > 0.03
&& support_fraction < 0.6
&& support_set.len() >= 2;
&& support_set.len() >= 2
&& support_fraction > 0.02
&& support_fraction < 0.5;
DetectionResult {
event_id: lc.event_id,
@ -892,7 +971,6 @@ fn global_decision(
fn build_injection_bank(num_events: usize, seed: u64) -> (Vec<Vec<f32>>, Vec<bool>) {
let mut embeddings = Vec::new();
let mut labels = Vec::new();
let mut rng = seed;
for i in 0..num_events {
let has_moon = i % 2 == 0;
@ -902,11 +980,14 @@ fn build_injection_bank(num_events: usize, seed: u64) -> (Vec<Vec<f32>>, Vec<boo
let windows = build_windows(&lc, &fit, 0.4, 0.2);
for win in &windows {
// Label: if event has moon AND window is near perturbation, label as moon
// Label: moon window if event has moon AND the window shows
// actual perturbation signal (positive lambda from local evidence).
// This is more precise than a geometric proximity check.
let near_perturbation = if has_moon {
// Check if window's tau is near the moon's projected separation
// and has positive local evidence
let tau = win.tau_center;
// Moon perturbations typically occur within ~1 tE of peak
tau.abs() < 1.5
tau.abs() < 2.0 && win.ll_ratio > 0.0
} else {
false
};
@ -926,18 +1007,22 @@ fn main() {
println!("=== Exomoon Graph Cut Detection Pipeline ===\n");
let dim = 32;
let num_events = 20;
let num_events = 30;
// Hyperparameters
let alpha = 1.0; // temporal edge weight
let beta = 0.5; // RuVector kNN edge weight
let gamma = 0.8; // coherence penalty
let eta = 0.3; // retrieval prior weight
let temperature = 0.5; // softmax temperature for retrieval
let k_nn = 3; // number of RuVector neighbors
let k_bank = 10; // retrieval neighbors from bank
let mu = 0.5; // lambda sum weight in J-score
let nu = 2.0; // fragility penalty in J-score
// Alpha/beta: pairwise edge weights for temporal chain and RuVector kNN
// Gamma: coherence penalty — higher = more conservative cut
// Eta: retrieval prior weight from injection bank
// Mu/nu: J-score composition (lambda sum weight / fragility penalty)
let alpha = 0.2; // temporal edge weight
let beta = 0.1; // RuVector kNN edge weight
let gamma = 0.5; // coherence penalty
let eta = 0.5; // retrieval prior weight
let temperature = 0.3; // softmax temperature for retrieval
let k_nn = 3; // RuVector neighbors in graph
let k_bank = 15; // retrieval neighbors from bank
let mu = 1.0; // lambda sum weight in J-score
let nu = 3.0; // fragility penalty in J-score
let tmp_dir = TempDir::new().expect("failed to create temp dir");
let store_path = tmp_dir.path().join("exomoon_graphcut.rvf");
@ -954,7 +1039,7 @@ fn main() {
// Step 0: Build injection bank for RuVector prior
// ====================================================================
println!("--- Step 0. Build Injection Bank ---");
let (bank_embeddings, bank_labels) = build_injection_bank(40, 999);
let (bank_embeddings, bank_labels) = build_injection_bank(60, 999);
let bank_moon_count = bank_labels.iter().filter(|&&l| l).count();
println!(" Bank size: {} windows", bank_embeddings.len());
println!(" Moon windows: {}", bank_moon_count);