refactor(rvdna): consolidate SNP arrays, cache metadata, optimize streaming

Structural improvements from deep code review:

- Consolidate 5 parallel arrays (SNP_WEIGHTS, HOM_REF, HOM_ALT, HET,
  ALLELE_FREQS) into single SnpDef struct array — eliminates entire class
  of parallel-array misalignment bugs
- Cache category_meta() with LazyLock — avoids per-call Vec allocation
  (critical in generate_synthetic_population hot path)
- Hoist Normal::new out of the inner loop in generate_readings — build the
  distributions once per biomarker instead of once per step per biomarker
- Add clinically meaningful lower bounds: LDL normal_low 0→50 mg/dL
  (critical_low 25), Triglycerides normal_low 0→35 mg/dL (critical_low 20)
- Optimize RingBuffer::clear from O(capacity) to O(1) — head/len reset
  is sufficient since push overwrites before read
- Use NUM_SNPS const for vector encoding bounds instead of magic number 51

All 172 tests pass, zero clippy warnings for rvdna.

https://claude.ai/code/session_014FpaYVohmyLH5dcBZTgmSY
This commit is contained in:
Claude 2026-02-22 06:31:44 +00:00
parent b4c230f4b5
commit 8b85624352
No known key found for this signature in database
2 changed files with 96 additions and 107 deletions

View file

@ -33,9 +33,9 @@ pub enum BiomarkerClassification {
static REFERENCES: &[BiomarkerReference] = &[
BiomarkerReference { name: "Total Cholesterol", unit: "mg/dL", normal_low: 125.0, normal_high: 200.0, critical_low: Some(100.0), critical_high: Some(300.0), category: "Lipid" },
BiomarkerReference { name: "LDL", unit: "mg/dL", normal_low: 0.0, normal_high: 100.0, critical_low: None, critical_high: Some(190.0), category: "Lipid" },
BiomarkerReference { name: "LDL", unit: "mg/dL", normal_low: 50.0, normal_high: 100.0, critical_low: Some(25.0), critical_high: Some(190.0), category: "Lipid" },
BiomarkerReference { name: "HDL", unit: "mg/dL", normal_low: 40.0, normal_high: 90.0, critical_low: Some(20.0), critical_high: None, category: "Lipid" },
BiomarkerReference { name: "Triglycerides", unit: "mg/dL", normal_low: 0.0, normal_high: 150.0, critical_low: None, critical_high: Some(500.0), category: "Lipid" },
BiomarkerReference { name: "Triglycerides", unit: "mg/dL", normal_low: 35.0, normal_high: 150.0, critical_low: Some(20.0), critical_high: Some(500.0), category: "Lipid" },
BiomarkerReference { name: "Fasting Glucose", unit: "mg/dL", normal_low: 70.0, normal_high: 100.0, critical_low: Some(50.0), critical_high: Some(250.0), category: "Metabolic" },
BiomarkerReference { name: "HbA1c", unit: "%", normal_low: 4.0, normal_high: 5.7, critical_low: None, critical_high: Some(9.0), category: "Metabolic" },
BiomarkerReference { name: "Homocysteine", unit: "umol/L", normal_low: 5.0, normal_high: 15.0, critical_low: None, critical_high: Some(30.0), category: "Metabolic" },
@ -100,40 +100,49 @@ pub struct BiomarkerProfile {
pub biomarker_values: HashMap<String, f64>,
}
// SNP Risk Weight Matrix -- (rsid, category, hom_ref, het, hom_alt)
static SNP_WEIGHTS: &[(&str, &str, f64, f64, f64)] = &[
("rs429358", "Neurological", 0.0, 0.4, 0.9),
("rs7412", "Neurological", 0.0, -0.15, -0.3),
("rs1042522", "Cancer Risk", 0.0, 0.25, 0.5), // Pro72Arg: CC/ProPro not risk-associated (SOTA)
("rs80357906","Cancer Risk", 0.0, 0.7, 0.95),
("rs28897696","Cancer Risk", 0.0, 0.3, 0.6),
("rs11571833","Cancer Risk", 0.0, 0.20, 0.5), // K3326X: OR 1.28 breast (Meeks 2016, iCOGS)
("rs1801133", "Metabolism", 0.0, 0.35, 0.7), // C677T: het=40% enzyme decrease (geneticlifehacks)
("rs1801131", "Metabolism", 0.0, 0.10, 0.25), // A1298C: hom_alt=~20% decrease (geneticlifehacks)
("rs4680", "Neurological", 0.0, 0.2, 0.45),
("rs1799971", "Neurological", 0.0, 0.2, 0.4),
("rs762551", "Metabolism", 0.0, 0.15, 0.35),
("rs4988235", "Metabolism", 0.0, 0.05, 0.15),
("rs53576", "Neurological", 0.0, 0.1, 0.25),
("rs6311", "Neurological", 0.0, 0.15, 0.3),
("rs1800497", "Neurological", 0.0, 0.25, 0.5),
("rs4363657", "Cardiovascular", 0.0, 0.35, 0.7),
("rs1800566", "Cancer Risk", 0.0, 0.15, 0.30), // Pro187Ser: OR 1.18 TT (Lajin 2013 meta-analysis)
];
// Genotype encoding: 0 = hom_ref, 1 = het, 2 = hom_alt
static HOM_REF: &[&str] = &[
"TT", "CC", "CC", "DD", "GG", "AA", "GG", "TT",
"GG", "AA", "AA", "AA", "GG", "CC", "GG", "TT", "CC",
];
fn genotype_code(i: usize, gt: &str) -> u8 {
if gt == HOM_REF[i] { 0 } else if gt.len() == 2 && gt.as_bytes()[0] != gt.as_bytes()[1] { 1 } else { 2 }
/// Unified SNP descriptor — eliminates parallel-array fragility.
///
/// Each value describes one scored SNP: its identifier, the risk category it
/// contributes to, one risk weight per genotype code, the genotype string for
/// each zygosity, and the minor allele frequency used when sampling synthetic
/// genotypes.
struct SnpDef {
/// SNP identifier (rsID); the key looked up in genotype maps.
rsid: &'static str,
/// Risk category name; compared against the entries of `CAT_ORDER`.
category: &'static str,
/// Weight returned by `snp_weight` for genotype code 0 (homozygous reference).
w_ref: f64,
/// Weight returned by `snp_weight` for genotype code 1 (heterozygous).
w_het: f64,
/// Weight returned by `snp_weight` for any other code (homozygous alternate);
/// its non-negative part is summed into the category's `max_possible`.
w_alt: f64,
/// Genotype string that `genotype_code` maps to code 0.
hom_ref: &'static str,
/// Heterozygous genotype string emitted by `random_genotype`.
het: &'static str,
/// Homozygous-alternate genotype string emitted by `random_genotype`.
hom_alt: &'static str,
/// Allele frequency `p` used by `random_genotype` to draw genotypes with
/// probabilities q², 2pq and p² (q = 1 - p).
maf: f64, // minor allele frequency
}
fn snp_weight(i: usize, code: u8) -> f64 {
let (_, _, w0, w1, w2) = SNP_WEIGHTS[i];
match code { 0 => w0, 1 => w1, _ => w2 }
/// Number of entries in [`SNPS`].
///
/// The constant is part of the table's type, so adding or removing a row
/// without updating it is a compile error rather than a silent mismatch.
/// Code that indexes `SNPS[NUM_SNPS - 1]` or sizes the one-hot genotype
/// encoding as `NUM_SNPS * 3` relies on this staying in sync with the table.
const NUM_SNPS: usize = 17;

/// The scored SNP panel, one [`SnpDef`] per row.
///
/// Row order is significant: the vector encoding assigns three consecutive
/// dimensions per row in this order, and callers reference the first row
/// (rs429358) and the last row (rs1800566) by position.
static SNPS: &[SnpDef; NUM_SNPS] = &[
    SnpDef { rsid: "rs429358", category: "Neurological", w_ref: 0.0, w_het: 0.4, w_alt: 0.9, hom_ref: "TT", het: "CT", hom_alt: "CC", maf: 0.14 },
    SnpDef { rsid: "rs7412", category: "Neurological", w_ref: 0.0, w_het: -0.15, w_alt: -0.3, hom_ref: "CC", het: "CT", hom_alt: "TT", maf: 0.08 },
    SnpDef { rsid: "rs1042522", category: "Cancer Risk", w_ref: 0.0, w_het: 0.25, w_alt: 0.5, hom_ref: "CC", het: "CG", hom_alt: "GG", maf: 0.40 },
    SnpDef { rsid: "rs80357906", category: "Cancer Risk", w_ref: 0.0, w_het: 0.7, w_alt: 0.95, hom_ref: "DD", het: "DI", hom_alt: "II", maf: 0.003 },
    SnpDef { rsid: "rs28897696", category: "Cancer Risk", w_ref: 0.0, w_het: 0.3, w_alt: 0.6, hom_ref: "GG", het: "AG", hom_alt: "AA", maf: 0.005 },
    SnpDef { rsid: "rs11571833", category: "Cancer Risk", w_ref: 0.0, w_het: 0.20, w_alt: 0.5, hom_ref: "AA", het: "AT", hom_alt: "TT", maf: 0.01 },
    SnpDef { rsid: "rs1801133", category: "Metabolism", w_ref: 0.0, w_het: 0.35, w_alt: 0.7, hom_ref: "GG", het: "AG", hom_alt: "AA", maf: 0.32 },
    SnpDef { rsid: "rs1801131", category: "Metabolism", w_ref: 0.0, w_het: 0.10, w_alt: 0.25, hom_ref: "TT", het: "GT", hom_alt: "GG", maf: 0.30 },
    SnpDef { rsid: "rs4680", category: "Neurological", w_ref: 0.0, w_het: 0.2, w_alt: 0.45, hom_ref: "GG", het: "AG", hom_alt: "AA", maf: 0.50 },
    SnpDef { rsid: "rs1799971", category: "Neurological", w_ref: 0.0, w_het: 0.2, w_alt: 0.4, hom_ref: "AA", het: "AG", hom_alt: "GG", maf: 0.15 },
    SnpDef { rsid: "rs762551", category: "Metabolism", w_ref: 0.0, w_het: 0.15, w_alt: 0.35, hom_ref: "AA", het: "AC", hom_alt: "CC", maf: 0.37 },
    SnpDef { rsid: "rs4988235", category: "Metabolism", w_ref: 0.0, w_het: 0.05, w_alt: 0.15, hom_ref: "AA", het: "AG", hom_alt: "GG", maf: 0.24 },
    SnpDef { rsid: "rs53576", category: "Neurological", w_ref: 0.0, w_het: 0.1, w_alt: 0.25, hom_ref: "GG", het: "AG", hom_alt: "AA", maf: 0.35 },
    SnpDef { rsid: "rs6311", category: "Neurological", w_ref: 0.0, w_het: 0.15, w_alt: 0.3, hom_ref: "CC", het: "CT", hom_alt: "TT", maf: 0.45 },
    SnpDef { rsid: "rs1800497", category: "Neurological", w_ref: 0.0, w_het: 0.25, w_alt: 0.5, hom_ref: "GG", het: "AG", hom_alt: "AA", maf: 0.20 },
    SnpDef { rsid: "rs4363657", category: "Cardiovascular", w_ref: 0.0, w_het: 0.35, w_alt: 0.7, hom_ref: "TT", het: "CT", hom_alt: "CC", maf: 0.15 },
    SnpDef { rsid: "rs1800566", category: "Cancer Risk", w_ref: 0.0, w_het: 0.15, w_alt: 0.30, hom_ref: "CC", het: "CT", hom_alt: "TT", maf: 0.22 },
];
/// Map a genotype string to its numeric code for `snp`.
///
/// Returns 0 when `gt` equals the SNP's homozygous-reference genotype, 1 when
/// `gt` is exactly two bytes long and the two bytes differ (heterozygous, in
/// either allele order), and 2 for everything else.
///
/// NOTE(review): "everything else" includes strings that are not two bytes
/// long and two identical bytes that are not the alternate allele (e.g. a
/// `"--"` no-call), all of which are reported as 2.
fn genotype_code(snp: &SnpDef, gt: &str) -> u8 {
    if gt == snp.hom_ref {
        return 0;
    }
    match gt.as_bytes() {
        [first, second] if first != second => 1,
        _ => 2,
    }
}
/// Look up the risk weight of `snp` for a genotype code.
///
/// Code 0 selects `w_ref`, code 1 selects `w_het`, and every other code
/// selects `w_alt`.
fn snp_weight(snp: &SnpDef, code: u8) -> f64 {
    if code == 0 {
        return snp.w_ref;
    }
    if code == 1 {
        return snp.w_het;
    }
    snp.w_alt
}
struct Interaction {
@ -153,12 +162,12 @@ static INTERACTIONS: &[Interaction] = &[
];
fn snp_idx(rsid: &str) -> Option<usize> {
SNP_WEIGHTS.iter().position(|(r, _, _, _, _)| *r == rsid)
SNPS.iter().position(|s| s.rsid == rsid)
}
fn is_non_ref(gts: &HashMap<String, String>, rsid: &str) -> bool {
match (gts.get(rsid), snp_idx(rsid)) {
(Some(g), Some(idx)) => g != HOM_REF[idx],
(Some(g), Some(idx)) => g != SNPS[idx].hom_ref,
_ => false,
}
}
@ -175,12 +184,16 @@ struct CategoryMeta { name: &'static str, max_possible: f64, expected_count: usi
/// Risk category names in canonical order. `category_meta` produces one entry
/// per name in this order, and the profile vector encoding writes the
/// per-category scores to consecutive dimensions starting at 51 in the same
/// order, so reordering this list changes the vector layout.
static CAT_ORDER: &[&str] = &["Cancer Risk", "Cardiovascular", "Neurological", "Metabolism"];
fn category_meta() -> Vec<CategoryMeta> {
CAT_ORDER.iter().map(|&cat| {
let (mp, ec) = SNP_WEIGHTS.iter().filter(|(_, c, _, _, _)| *c == cat)
.fold((0.0, 0usize), |(s, n), (_, _, _, _, w2)| (s + w2.max(0.0), n + 1));
CategoryMeta { name: cat, max_possible: mp.max(1.0), expected_count: ec }
}).collect()
/// Per-category normalisation metadata, computed on first use and cached.
///
/// For each name in `CAT_ORDER` (in that order) the table records:
/// - `max_possible`: the sum of the non-negative `w_alt` weights of the SNPs
///   in that category, floored at 1.0;
/// - `expected_count`: how many SNPs in `SNPS` belong to the category.
///
/// The returned slice borrows from a function-local `LazyLock`, so repeated
/// calls do not rebuild or reallocate the table.
fn category_meta() -> &'static [CategoryMeta] {
    use std::sync::LazyLock;

    static META: LazyLock<Vec<CategoryMeta>> = LazyLock::new(|| {
        let mut table = Vec::with_capacity(CAT_ORDER.len());
        for &cat in CAT_ORDER {
            let mut weight_sum = 0.0_f64;
            let mut members = 0usize;
            for snp in SNPS.iter().filter(|s| s.category == cat) {
                // Protective (negative) weights do not raise the ceiling.
                weight_sum += snp.w_alt.max(0.0);
                members += 1;
            }
            table.push(CategoryMeta {
                name: cat,
                max_possible: weight_sum.max(1.0),
                expected_count: members,
            });
        }
        table
    });

    META.as_slice()
}
/// Compute composite risk scores from genotype data.
@ -188,15 +201,15 @@ pub fn compute_risk_scores(genotypes: &HashMap<String, String>) -> BiomarkerProf
let meta = category_meta();
let mut cat_scores: HashMap<&str, (f64, Vec<String>, usize)> = HashMap::with_capacity(4);
for (i, (rsid, cat, _, _, _)) in SNP_WEIGHTS.iter().enumerate() {
if let Some(gt) = genotypes.get(*rsid) {
let code = genotype_code(i, gt);
let w = snp_weight(i, code);
let entry = cat_scores.entry(cat).or_insert_with(|| (0.0, Vec::new(), 0));
for snp in SNPS {
if let Some(gt) = genotypes.get(snp.rsid) {
let code = genotype_code(snp, gt);
let w = snp_weight(snp, code);
let entry = cat_scores.entry(snp.category).or_insert_with(|| (0.0, Vec::new(), 0));
entry.0 += w;
entry.2 += 1;
if code > 0 {
entry.1.push(rsid.to_string());
entry.1.push(snp.rsid.to_string());
}
}
}
@ -211,7 +224,7 @@ pub fn compute_risk_scores(genotypes: &HashMap<String, String>) -> BiomarkerProf
}
let mut category_scores = HashMap::with_capacity(meta.len());
for cm in &meta {
for cm in meta {
let (raw, variants, count) = cat_scores.remove(cm.name).unwrap_or((0.0, Vec::new(), 0));
let score = (raw / cm.max_possible).clamp(0.0, 1.0);
let confidence = if count > 0 { (count as f64 / cm.expected_count.max(1) as f64).min(1.0) } else { 0.0 };
@ -238,74 +251,51 @@ pub fn encode_profile_vector(profile: &BiomarkerProfile) -> Vec<f32> {
fn encode_profile_vector_with_genotypes(profile: &BiomarkerProfile, genotypes: &HashMap<String, String>) -> Vec<f32> {
let mut v = vec![0.0f32; 64];
for (i, (rsid, _, _, _, _)) in SNP_WEIGHTS.iter().enumerate() {
let code = genotypes.get(*rsid).map(|gt| genotype_code(i, gt)).unwrap_or(0);
// Dims 0..50: one-hot genotype encoding (17 SNPs x 3 = 51 dims)
for (i, snp) in SNPS.iter().enumerate() {
let code = genotypes.get(snp.rsid).map(|gt| genotype_code(snp, gt)).unwrap_or(0);
let base = i * 3;
if base + 2 < 51 {
if base + 2 < NUM_SNPS * 3 {
v[base + code as usize] = 1.0;
}
}
// Dims 51..54: category scores
for (j, cat) in CAT_ORDER.iter().enumerate() {
v[51 + j] = profile.category_scores.get(*cat).map(|c| c.score as f32).unwrap_or(0.0);
}
v[55] = profile.global_risk_score as f32;
// Encode first 4 interactions in dims 56-59; additional interactions
// affect category scores (dims 51-54) but don't need dedicated dims.
// Dims 56..59: first 4 interaction modifiers
for (j, inter) in INTERACTIONS.iter().take(4).enumerate() {
let m = interaction_mod(genotypes, inter);
v[56 + j] = if m > 1.0 { (m - 1.0) as f32 } else { 0.0 };
}
// Dims 60..63: derived clinical scores
v[60] = analyze_mthfr(genotypes).score as f32 / 4.0;
v[61] = analyze_pain(genotypes).map(|p| p.score as f32 / 4.0).unwrap_or(0.0);
v[62] = genotypes.get("rs429358").map(|g| genotype_code(0, g) as f32 / 2.0).unwrap_or(0.0);
v[63] = genotypes.get("rs1800566").map(|g| genotype_code(16, g) as f32 / 2.0).unwrap_or(0.0);
v[62] = genotypes.get("rs429358").map(|g| genotype_code(&SNPS[0], g) as f32 / 2.0).unwrap_or(0.0);
v[63] = genotypes.get("rs1800566").map(|g| genotype_code(&SNPS[NUM_SNPS - 1], g) as f32 / 2.0).unwrap_or(0.0);
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut v {
*x /= norm;
}
}
if norm > 0.0 { v.iter_mut().for_each(|x| *x /= norm); }
v
}
// Population allele frequencies (minor allele freq) per SNP
static ALLELE_FREQS: &[f64] = &[
0.14, 0.08, 0.40, 0.003, 0.005, 0.01,
0.32, 0.30, 0.50, 0.15, 0.37, 0.24,
0.35, 0.45, 0.20, 0.15, 0.22,
];
static HOM_ALT: &[&str] = &[
"CC", "TT", "GG", "II", "AA", "TT", "AA", "GG",
"AA", "GG", "CC", "GG", "AA", "TT", "AA", "CC", "TT",
];
static HET: &[&str] = &[
"CT", "CT", "CG", "DI", "AG", "AT", "AG", "GT",
"AG", "AG", "AC", "AG", "AG", "CT", "AG", "CT", "CT",
];
fn random_genotype(rng: &mut StdRng, idx: usize) -> String {
let p = ALLELE_FREQS[idx];
let r: f64 = rng.gen();
fn random_genotype(rng: &mut StdRng, snp: &SnpDef) -> String {
let p = snp.maf;
let q = 1.0 - p;
if r < q * q { HOM_REF[idx] } else if r < q * q + 2.0 * p * q { HET[idx] } else { HOM_ALT[idx] }.to_string()
let r: f64 = rng.gen();
if r < q * q { snp.hom_ref } else if r < q * q + 2.0 * p * q { snp.het } else { snp.hom_alt }.to_string()
}
/// Generate a deterministic synthetic population of biomarker profiles.
pub fn generate_synthetic_population(count: usize, seed: u64) -> Vec<BiomarkerProfile> {
let mut rng = StdRng::seed_from_u64(seed);
let mut pop = Vec::with_capacity(count);
let num_snps = SNP_WEIGHTS.len();
for i in 0..count {
let mut genotypes = HashMap::with_capacity(num_snps);
for (idx, (rsid, _, _, _, _)) in SNP_WEIGHTS.iter().enumerate() {
genotypes.insert(rsid.to_string(), random_genotype(&mut rng, idx));
let mut genotypes = HashMap::with_capacity(NUM_SNPS);
for snp in SNPS {
genotypes.insert(snp.rsid.to_string(), random_genotype(&mut rng, snp));
}
let mut profile = compute_risk_scores(&genotypes);
@ -313,8 +303,8 @@ pub fn generate_synthetic_population(count: usize, seed: u64) -> Vec<BiomarkerPr
profile.timestamp = 1700000000 + i as i64;
let mthfr_score = analyze_mthfr(&genotypes).score;
let apoe_code = genotypes.get("rs429358").map(|g| genotype_code(0, g)).unwrap_or(0);
let nqo1_code = genotypes.get("rs1800566").map(|g| genotype_code(16, g)).unwrap_or(0);
let apoe_code = genotypes.get("rs429358").map(|g| genotype_code(&SNPS[0], g)).unwrap_or(0);
let nqo1_code = genotypes.get("rs1800566").map(|g| genotype_code(&SNPS[NUM_SNPS - 1], g)).unwrap_or(0);
profile.biomarker_values.reserve(REFERENCES.len());
for bref in REFERENCES {
@ -345,9 +335,7 @@ mod tests {
use super::*;
fn full_hom_ref() -> HashMap<String, String> {
SNP_WEIGHTS.iter().enumerate().map(|(i, (rsid, _, _, _, _))| {
(rsid.to_string(), HOM_REF[i].to_string())
}).collect()
SNPS.iter().map(|s| (s.rsid.to_string(), s.hom_ref.to_string())).collect()
}
#[test]

View file

@ -84,7 +84,6 @@ impl<T: Clone + Default> RingBuffer<T> {
/// Returns `true` when the stored element count (`len`) equals `capacity`.
pub fn is_full(&self) -> bool { self.len == self.capacity }
pub fn clear(&mut self) {
self.buffer.iter_mut().for_each(|s| *s = T::default());
self.head = 0;
self.len = 0;
}
@ -113,18 +112,23 @@ pub fn generate_readings(
let mut rng = StdRng::seed_from_u64(seed);
let active = &BIOMARKER_DEFS[..config.num_biomarkers.min(BIOMARKER_DEFS.len())];
let mut readings = Vec::with_capacity(count * active.len());
// Pre-compute distributions per biomarker (avoids Normal::new in inner loop)
let dists: Vec<_> = active.iter().map(|def| {
let range = def.high - def.low;
let mid = (def.low + def.high) / 2.0;
let sigma = (config.noise_amplitude * range).max(1e-12);
let normal = Normal::new(0.0, sigma).unwrap();
let spike = Normal::new(0.0, sigma * config.anomaly_magnitude).unwrap();
(mid, range, normal, spike)
}).collect();
let mut ts: u64 = 0;
for step in 0..count {
for def in active {
let range = def.high - def.low;
let mid = (def.low + def.high) / 2.0;
let sigma = config.noise_amplitude * range;
let normal = Normal::new(0.0, sigma.max(1e-12)).unwrap();
for (j, def) in active.iter().enumerate() {
let (mid, range, ref normal, ref spike) = dists[j];
let drift = config.drift_rate * range * step as f64;
let is_anom = rng.gen::<f64>() < config.anomaly_probability;
let value = if is_anom {
let spike = Normal::new(0.0, sigma * config.anomaly_magnitude).unwrap();
(mid + rng.sample::<f64, _>(spike) + drift).max(0.0)
} else {
(mid + rng.sample::<f64, _>(normal) + drift).max(0.0)
@ -153,11 +157,8 @@ pub struct StreamStats {
pub anomaly_rate: f64,
pub trend_slope: f64,
pub ema: f64,
/// CUSUM statistic for changepoint detection (positive direction).
pub cusum_pos: f64,
/// CUSUM statistic for changepoint detection (negative direction).
pub cusum_neg: f64,
/// Whether a changepoint has been detected in the current window.
pub cusum_pos: f64, // CUSUM positive direction
pub cusum_neg: f64, // CUSUM negative direction
pub changepoint_detected: bool,
}