diff --git a/Cargo.lock b/Cargo.lock index 98437f50..12982909 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8260,6 +8260,32 @@ dependencies = [ "wasm-bindgen-test", ] +[[package]] +name = "ruvector-domain-expansion" +version = "2.0.3" +dependencies = [ + "criterion 0.5.1", + "proptest", + "rand 0.8.5", + "serde", + "serde_json", + "thiserror 2.0.17", +] + +[[package]] +name = "ruvector-domain-expansion-wasm" +version = "0.1.0" +dependencies = [ + "js-sys", + "rand 0.8.5", + "ruvector-domain-expansion", + "serde", + "serde-wasm-bindgen", + "serde_json", + "wasm-bindgen", + "wasm-bindgen-test", +] + [[package]] name = "ruvector-economy-wasm" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 1b433989..14ad8df6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,8 @@ members = [ "crates/ruqu-algorithms", "crates/ruqu-wasm", "crates/ruqu-exotic", + "crates/ruvector-domain-expansion", + "crates/ruvector-domain-expansion-wasm", "examples/dna", "examples/OSpipe", "crates/rvf/rvf-adapters/rvlite", diff --git a/crates/ruvector-domain-expansion-wasm/Cargo.toml b/crates/ruvector-domain-expansion-wasm/Cargo.toml new file mode 100644 index 00000000..2abdcc27 --- /dev/null +++ b/crates/ruvector-domain-expansion-wasm/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "ruvector-domain-expansion-wasm" +version = "0.1.0" +edition = "2021" +description = "WASM bindings for the domain expansion cross-domain transfer learning engine" +license = "MIT OR Apache-2.0" +repository = "https://github.com/ruvnet/ruvector" +homepage = "https://ruv.io" +authors = ["rUv "] +keywords = ["wasm", "transfer-learning", "domain-expansion", "generalization"] +categories = ["wasm", "algorithms", "science"] +rust-version = "1.70" + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +ruvector-domain-expansion = { path = "../ruvector-domain-expansion" } +wasm-bindgen = "0.2" +js-sys = "0.3" +serde = { version = "1.0", features = ["derive"] } +serde-wasm-bindgen = "0.6" +serde_json = "1.0" +rand = "0.8" + +[dev-dependencies] +wasm-bindgen-test = "0.3" + +[profile.release] +opt-level = "z" +lto = true +codegen-units = 1 +panic = "abort" +strip = true + +[package.metadata.wasm-pack.profile.release] +wasm-opt = false diff --git a/crates/ruvector-domain-expansion-wasm/src/lib.rs b/crates/ruvector-domain-expansion-wasm/src/lib.rs new file mode 100644 index 00000000..302041c8 --- /dev/null +++ b/crates/ruvector-domain-expansion-wasm/src/lib.rs @@ -0,0 +1,374 @@ +//! WASM bindings for the Domain Expansion Engine. +//! +//! Provides JavaScript-accessible interfaces for cross-domain transfer learning, +//! Meta Thompson Sampling, PolicyKernel population search, and the acceleration +//! scoreboard. All domain operations run at native speed in the browser/edge. + +use ruvector_domain_expansion::{ + AccelerationScoreboard, ArmId, ContextBucket, CostCurve, + DomainExpansionEngine, DomainId, Evaluation, MetaThompsonEngine, + PopulationSearch, Solution, Task, +}; +use wasm_bindgen::prelude::*; + +// ─── Engine ────────────────────────────────────────────────────────────────── + +/// WASM-exported domain expansion engine. +#[wasm_bindgen] +pub struct WasmDomainExpansionEngine { + inner: DomainExpansionEngine, +} + +#[wasm_bindgen] +impl WasmDomainExpansionEngine { + /// Create a new domain expansion engine with 3 core domains. + #[wasm_bindgen(constructor)] + pub fn new() -> Self { + Self { + inner: DomainExpansionEngine::new(), + } + } + + /// Get registered domain IDs as JSON array. + #[wasm_bindgen(js_name = domainIds)] + pub fn domain_ids(&self) -> JsValue { + let ids: Vec = self.inner.domain_ids().into_iter().map(|d| d.0).collect(); + serde_wasm_bindgen::to_value(&ids).unwrap_or(JsValue::NULL) + } + + /// Generate tasks for a domain. Returns JSON array of tasks. + #[wasm_bindgen(js_name = generateTasks)] + pub fn generate_tasks(&self, domain_id: &str, count: usize, difficulty: f32) -> JsValue { + let id = DomainId(domain_id.to_string()); + let tasks = self.inner.generate_tasks(&id, count, difficulty); + serde_wasm_bindgen::to_value(&tasks).unwrap_or(JsValue::NULL) + } + + /// Generate holdout tasks for all domains. + #[wasm_bindgen(js_name = generateHoldouts)] + pub fn generate_holdouts(&mut self, tasks_per_domain: usize, difficulty: f32) { + self.inner.generate_holdouts(tasks_per_domain, difficulty); + } + + /// Evaluate a solution (JSON). Returns evaluation JSON. + #[wasm_bindgen(js_name = evaluateAndRecord)] + pub fn evaluate_and_record( + &mut self, + domain_id: &str, + task_json: &str, + solution_json: &str, + difficulty_tier: &str, + category: &str, + arm: &str, + ) -> JsValue { + let task: Task = match serde_json::from_str(task_json) { + Ok(t) => t, + Err(_) => return JsValue::NULL, + }; + let solution: Solution = match serde_json::from_str(solution_json) { + Ok(s) => s, + Err(_) => return JsValue::NULL, + }; + + let bucket = ContextBucket { + difficulty_tier: difficulty_tier.to_string(), + category: category.to_string(), + }; + + let eval = self.inner.evaluate_and_record( + &DomainId(domain_id.to_string()), + &task, + &solution, + bucket, + ArmId(arm.to_string()), + ); + + serde_wasm_bindgen::to_value(&eval).unwrap_or(JsValue::NULL) + } + + /// Select the best arm for a context using Thompson Sampling. + #[wasm_bindgen(js_name = selectArm)] + pub fn select_arm( + &self, + domain_id: &str, + difficulty_tier: &str, + category: &str, + ) -> Option { + let bucket = ContextBucket { + difficulty_tier: difficulty_tier.to_string(), + category: category.to_string(), + }; + self.inner + .select_arm(&DomainId(domain_id.to_string()), &bucket) + .map(|a| a.0) + } + + /// Check if speculation should be triggered. + #[wasm_bindgen(js_name = shouldSpeculate)] + pub fn should_speculate( + &self, + domain_id: &str, + difficulty_tier: &str, + category: &str, + ) -> bool { + let bucket = ContextBucket { + difficulty_tier: difficulty_tier.to_string(), + category: category.to_string(), + }; + self.inner + .should_speculate(&DomainId(domain_id.to_string()), &bucket) + } + + /// Initiate transfer from source to target domain. + #[wasm_bindgen(js_name = initiateTransfer)] + pub fn initiate_transfer(&mut self, source: &str, target: &str) { + self.inner.initiate_transfer( + &DomainId(source.to_string()), + &DomainId(target.to_string()), + ); + } + + /// Verify a transfer delta. Returns verification JSON. + #[wasm_bindgen(js_name = verifyTransfer)] + pub fn verify_transfer( + &self, + source: &str, + target: &str, + source_before: f32, + source_after: f32, + target_before: f32, + target_after: f32, + baseline_cycles: u64, + transfer_cycles: u64, + ) -> JsValue { + let v = self.inner.verify_transfer( + &DomainId(source.to_string()), + &DomainId(target.to_string()), + source_before, + source_after, + target_before, + target_after, + baseline_cycles, + transfer_cycles, + ); + serde_wasm_bindgen::to_value(&v).unwrap_or(JsValue::NULL) + } + + /// Evaluate all policy kernels on holdout tasks. + #[wasm_bindgen(js_name = evaluatePopulation)] + pub fn evaluate_population(&mut self) { + self.inner.evaluate_population(); + } + + /// Evolve the policy kernel population. + #[wasm_bindgen(js_name = evolvePopulation)] + pub fn evolve_population(&mut self) { + self.inner.evolve_population(); + } + + /// Get population statistics as JSON. + #[wasm_bindgen(js_name = populationStats)] + pub fn population_stats(&self) -> JsValue { + let stats = self.inner.population_stats(); + serde_wasm_bindgen::to_value(&stats).unwrap_or(JsValue::NULL) + } + + /// Get the scoreboard summary as JSON. + #[wasm_bindgen(js_name = scoreboardSummary)] + pub fn scoreboard_summary(&self) -> JsValue { + let summary = self.inner.scoreboard_summary(); + serde_wasm_bindgen::to_value(&summary).unwrap_or(JsValue::NULL) + } + + /// Get the best policy kernel as JSON. + #[wasm_bindgen(js_name = bestKernel)] + pub fn best_kernel(&self) -> JsValue { + match self.inner.best_kernel() { + Some(k) => serde_wasm_bindgen::to_value(k).unwrap_or(JsValue::NULL), + None => JsValue::NULL, + } + } + + /// Get counterexamples for a domain as JSON. + #[wasm_bindgen(js_name = counterexamples)] + pub fn counterexamples(&self, domain_id: &str) -> JsValue { + let examples = self + .inner + .counterexamples(&DomainId(domain_id.to_string())); + let serializable: Vec<(&Task, &Solution, &Evaluation)> = examples + .iter() + .map(|(t, s, e)| (t, s, e)) + .collect(); + serde_wasm_bindgen::to_value(&serializable).unwrap_or(JsValue::NULL) + } +} + +// ─── Standalone Thompson Sampling ──────────────────────────────────────────── + +/// WASM-exported standalone Thompson Sampling engine. +#[wasm_bindgen] +pub struct WasmThompsonEngine { + inner: MetaThompsonEngine, +} + +#[wasm_bindgen] +impl WasmThompsonEngine { + /// Create a Thompson engine with the given arms. + #[wasm_bindgen(constructor)] + pub fn new(arms_json: &str) -> Self { + let arms: Vec = serde_json::from_str(arms_json).unwrap_or_default(); + Self { + inner: MetaThompsonEngine::new(arms), + } + } + + /// Initialize a domain with uniform priors. + #[wasm_bindgen(js_name = initDomain)] + pub fn init_domain(&mut self, domain_id: &str) { + self.inner + .init_domain_uniform(DomainId(domain_id.to_string())); + } + + /// Record an outcome. + #[wasm_bindgen(js_name = recordOutcome)] + pub fn record_outcome( + &mut self, + domain_id: &str, + difficulty_tier: &str, + category: &str, + arm: &str, + reward: f32, + cost: f32, + ) { + let bucket = ContextBucket { + difficulty_tier: difficulty_tier.to_string(), + category: category.to_string(), + }; + self.inner.record_outcome( + &DomainId(domain_id.to_string()), + bucket, + ArmId(arm.to_string()), + reward, + cost, + ); + } + + /// Select the best arm. + #[wasm_bindgen(js_name = selectArm)] + pub fn select_arm( + &self, + domain_id: &str, + difficulty_tier: &str, + category: &str, + ) -> Option { + let bucket = ContextBucket { + difficulty_tier: difficulty_tier.to_string(), + category: category.to_string(), + }; + let mut rng = rand::thread_rng(); + self.inner + .select_arm(&DomainId(domain_id.to_string()), &bucket, &mut rng) + .map(|a| a.0) + } + + /// Extract transfer prior as JSON. + #[wasm_bindgen(js_name = extractPrior)] + pub fn extract_prior(&self, domain_id: &str) -> JsValue { + match self.inner.extract_prior(&DomainId(domain_id.to_string())) { + Some(prior) => serde_wasm_bindgen::to_value(&prior).unwrap_or(JsValue::NULL), + None => JsValue::NULL, + } + } +} + +// ─── Population Search ─────────────────────────────────────────────────────── + +/// WASM-exported population-based policy search. +#[wasm_bindgen] +pub struct WasmPopulationSearch { + inner: PopulationSearch, +} + +#[wasm_bindgen] +impl WasmPopulationSearch { + /// Create a new population search with given size. + #[wasm_bindgen(constructor)] + pub fn new(pop_size: usize) -> Self { + Self { + inner: PopulationSearch::new(pop_size), + } + } + + /// Get current population size. + #[wasm_bindgen(js_name = populationSize)] + pub fn population_size(&self) -> usize { + self.inner.population().len() + } + + /// Get current generation. + pub fn generation(&self) -> u32 { + self.inner.generation() + } + + /// Evolve to next generation. + pub fn evolve(&mut self) { + self.inner.evolve(); + } + + /// Get stats as JSON. + pub fn stats(&self) -> JsValue { + let stats = self.inner.stats(); + serde_wasm_bindgen::to_value(&stats).unwrap_or(JsValue::NULL) + } +} + +// ─── Acceleration Scoreboard ───────────────────────────────────────────────── + +/// WASM-exported acceleration scoreboard. +#[wasm_bindgen] +pub struct WasmScoreboard { + inner: AccelerationScoreboard, +} + +#[wasm_bindgen] +impl WasmScoreboard { + /// Create a new scoreboard. + #[wasm_bindgen(constructor)] + pub fn new() -> Self { + Self { + inner: AccelerationScoreboard::new(), + } + } + + /// Add a cost curve from JSON. + #[wasm_bindgen(js_name = addCurve)] + pub fn add_curve(&mut self, curve_json: &str) { + if let Ok(curve) = serde_json::from_str::(curve_json) { + self.inner.add_curve(curve); + } + } + + /// Compute acceleration between two domains. + #[wasm_bindgen(js_name = computeAcceleration)] + pub fn compute_acceleration(&mut self, baseline: &str, transfer: &str) -> JsValue { + match self.inner.compute_acceleration( + &DomainId(baseline.to_string()), + &DomainId(transfer.to_string()), + ) { + Some(entry) => serde_wasm_bindgen::to_value(&entry).unwrap_or(JsValue::NULL), + None => JsValue::NULL, + } + } + + /// Check if progressive acceleration holds. + #[wasm_bindgen(js_name = progressiveAcceleration)] + pub fn progressive_acceleration(&self) -> bool { + self.inner.progressive_acceleration() + } + + /// Get full summary as JSON. + pub fn summary(&self) -> JsValue { + let s = self.inner.summary(); + serde_wasm_bindgen::to_value(&s).unwrap_or(JsValue::NULL) + } +} diff --git a/crates/ruvector-domain-expansion/Cargo.toml b/crates/ruvector-domain-expansion/Cargo.toml new file mode 100644 index 00000000..c760dfed --- /dev/null +++ b/crates/ruvector-domain-expansion/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "ruvector-domain-expansion" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +authors.workspace = true +repository.workspace = true +description = "Cross-domain transfer learning engine: Rust synthesis, structured planning, tool orchestration" +keywords = ["transfer-learning", "domain-expansion", "generalization", "rust-synthesis", "planning"] +categories = ["algorithms", "science"] + +[dependencies] +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +rand = { workspace = true } + +[dev-dependencies] +proptest = { workspace = true } +criterion = { workspace = true } + +[[bench]] +name = "domain_expansion_bench" +harness = false + +[lib] +crate-type = ["rlib"] diff --git a/crates/ruvector-domain-expansion/benches/domain_expansion_bench.rs b/crates/ruvector-domain-expansion/benches/domain_expansion_bench.rs new file mode 100644 index 00000000..8bae602e --- /dev/null +++ b/crates/ruvector-domain-expansion/benches/domain_expansion_bench.rs @@ -0,0 +1,181 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use ruvector_domain_expansion::{ + ArmId, ContextBucket, CostCurve, CostCurvePoint, ConvergenceThresholds, + AccelerationScoreboard, DomainExpansionEngine, DomainId, MetaThompsonEngine, + PolicyKnobs, PopulationSearch, Solution, TransferPrior, +}; + +fn bench_task_generation(c: &mut Criterion) { + let engine = DomainExpansionEngine::new(); + let domains = engine.domain_ids(); + + let mut group = c.benchmark_group("task_generation"); + + for domain_id in &domains { + group.bench_function(format!("{}", domain_id), |b| { + b.iter(|| { + engine.generate_tasks(black_box(domain_id), black_box(10), black_box(0.5)) + }) + }); + } + group.finish(); +} + +fn bench_evaluation(c: &mut Criterion) { + let engine = DomainExpansionEngine::new(); + let rust_id = DomainId("rust_synthesis".into()); + let tasks = engine.generate_tasks(&rust_id, 10, 0.5); + + let solution = Solution { + task_id: tasks[0].id.clone(), + content: "fn sum_positives(values: &[i64]) -> i64 { values.iter().filter(|&&x| x > 0).sum() }".into(), + data: serde_json::Value::Null, + }; + + c.bench_function("evaluate_rust_solution", |b| { + b.iter(|| { + let mut eng = DomainExpansionEngine::new(); + eng.evaluate_and_record( + black_box(&rust_id), + black_box(&tasks[0]), + black_box(&solution), + ContextBucket { + difficulty_tier: "medium".into(), + category: "transform".into(), + }, + ArmId("greedy".into()), + ) + }) + }); +} + +fn bench_embedding(c: &mut Criterion) { + let engine = DomainExpansionEngine::new(); + let rust_id = DomainId("rust_synthesis".into()); + + let solution = Solution { + task_id: "bench".into(), + content: "fn foo() { for i in 0..10 { if i > 5 { let x = i.max(3); } } }".into(), + data: serde_json::Value::Null, + }; + + c.bench_function("embed_solution", |b| { + b.iter(|| engine.embed(black_box(&rust_id), black_box(&solution))) + }); +} + +fn bench_thompson_sampling(c: &mut Criterion) { + let mut engine = MetaThompsonEngine::new(vec![ + "greedy".into(), + "exploratory".into(), + "conservative".into(), + "speculative".into(), + ]); + + let domain = DomainId("bench".into()); + engine.init_domain_uniform(domain.clone()); + + let bucket = ContextBucket { + difficulty_tier: "medium".into(), + category: "algorithm".into(), + }; + + // Pre-populate with data + for i in 0..100 { + let arm = ArmId(format!( + "{}", + ["greedy", "exploratory", "conservative", "speculative"][i % 4] + )); + let reward = if i % 4 == 0 { 0.9 } else { 0.4 }; + engine.record_outcome(&domain, bucket.clone(), arm, reward, 1.0); + } + + c.bench_function("thompson_select_arm", |b| { + b.iter(|| { + let mut rng = rand::thread_rng(); + engine.select_arm(black_box(&domain), black_box(&bucket), &mut rng) + }) + }); +} + +fn bench_population_evolve(c: &mut Criterion) { + let mut search = PopulationSearch::new(16); + + // Pre-populate fitness + for i in 0..16 { + if let Some(kernel) = search.kernel_mut(i) { + kernel.record_score(DomainId("bench".into()), i as f32 / 16.0, 1.0); + } + } + + c.bench_function("population_evolve_16", |b| { + b.iter(|| { + let mut s = search.clone(); + s.evolve(); + }) + }); +} + +fn bench_knobs_mutate(c: &mut Criterion) { + let knobs = PolicyKnobs::default_knobs(); + c.bench_function("knobs_mutate", |b| { + b.iter(|| { + let mut rng = rand::thread_rng(); + black_box(knobs.mutate(&mut rng, 0.3)) + }) + }); +} + +fn bench_cost_curve_auc(c: &mut Criterion) { + let mut curve = CostCurve::new(DomainId("bench".into()), ConvergenceThresholds::default()); + for i in 0..1000 { + curve.record(CostCurvePoint { + cycle: i, + accuracy: (i as f32 / 1000.0).min(1.0), + cost_per_solve: 1.0 / (i as f32 + 1.0), + robustness: (i as f32 / 1000.0).min(1.0), + policy_violations: 0, + timestamp: i as f64, + }); + } + + c.bench_function("cost_curve_auc_1000pts", |b| { + b.iter(|| black_box(curve.auc_accuracy())) + }); +} + +fn bench_transfer_prior_extract(c: &mut Criterion) { + let domain = DomainId("bench".into()); + let mut prior = TransferPrior::uniform(domain); + + // Populate with 100 buckets x 4 arms + for b in 0..100 { + for a in 0..4 { + let bucket = ContextBucket { + difficulty_tier: format!("tier_{}", b % 3), + category: format!("cat_{}", b), + }; + let arm = ArmId(format!("arm_{}", a)); + for _ in 0..20 { + prior.update_posterior(bucket.clone(), arm.clone(), 0.7); + } + } + } + + c.bench_function("transfer_prior_extract_100buckets", |b| { + b.iter(|| black_box(prior.extract_summary())) + }); +} + +criterion_group!( + benches, + bench_task_generation, + bench_evaluation, + bench_embedding, + bench_thompson_sampling, + bench_population_evolve, + bench_knobs_mutate, + bench_cost_curve_auc, + bench_transfer_prior_extract, +); +criterion_main!(benches); diff --git a/crates/ruvector-domain-expansion/docs/README.md b/crates/ruvector-domain-expansion/docs/README.md new file mode 100644 index 00000000..72094219 --- /dev/null +++ b/crates/ruvector-domain-expansion/docs/README.md @@ -0,0 +1,241 @@ +# ruvector-domain-expansion + +Cross-domain transfer learning engine for general problem-solving capability. + +## Core Insight + +> True IQ growth appears when a kernel trained on Domain 1 improves Domain 2 faster than Domain 2 alone. That is generalization. + +If cost curves compress faster in each new domain, you are increasing general problem-solving capability. + +## Architecture + +### Two-Layer Learning + +``` +Policy Learning Layer (Meta Thompson Sampling) + | + | TransferPrior: compact Beta posteriors per bucket/arm + | NOT raw trajectories. Ship priors, not memories. + | + v +Operator Layer (Domain Kernels) + | + | Rust Synthesis | Planning | Tool Orchestration + | Generate tasks, evaluate solutions, produce embeddings + | + v +Shared Embedding Space (64-dim) + Cross-domain similarity via cosine distance +``` + +### Domains + +| Domain | Description | Task Types | +|--------|-------------|------------| +| **Rust Program Synthesis** | Synthesize Rust functions from specs | Transform, DataStructure, Algorithm, TypeLevel, Concurrency | +| **Structured Planning** | Multi-step plans with constraints | ResourceAllocation, DependencyScheduling, StateSpaceSearch, ConstraintSatisfaction | +| **Tool Orchestration** | Coordinate multiple tools/agents | PipelineConstruction, ErrorRecovery, ParallelCoordination, ResourceNegotiation | + +### Transfer Protocol + +1. Train on Domain 1, extract `TransferPrior` (posterior summaries) +2. Initialize Domain 2 with dampened priors from Domain 1 +3. Measure acceleration: cycles to convergence with vs without transfer +4. **Generalization rule**: A delta is promotable only if it improves Domain 2 without regressing Domain 1 + +### Population-Based Policy Search + +Run a population of `PolicyKernel` variants in parallel. Each variant tunes knobs: +- Skip mode policy +- Prepass mode +- Speculation trigger thresholds +- Budget allocation + +Selection: keep top performers on holdouts, mutate knobs, repeat. Only merge deltas that pass replay-verify. + +### Speculative Dual-Path + +When posterior variance is high (top two arms within delta), run both strategies with bounded budgets. Pick the first correct, log the loser as a counterexample. + +## Usage + +### Rust + +```rust +use ruvector_domain_expansion::{ + DomainExpansionEngine, DomainId, ArmId, ContextBucket, +}; + +// Create engine with 3 core domains +let mut engine = DomainExpansionEngine::new(); + +// Generate tasks +let tasks = engine.generate_tasks( + &DomainId("rust_synthesis".into()), + 10, // count + 0.5, // difficulty +); + +// Select arm via Thompson Sampling +let bucket = ContextBucket { + difficulty_tier: "medium".into(), + category: "algorithm".into(), +}; +let arm = engine.select_arm( + &DomainId("rust_synthesis".into()), + &bucket, +).unwrap(); + +// Evaluate and record +let eval = engine.evaluate_and_record( + &DomainId("rust_synthesis".into()), + &tasks[0], + &solution, + bucket, + arm, +); + +// Transfer learning +engine.initiate_transfer( + &DomainId("rust_synthesis".into()), + &DomainId("structured_planning".into()), +); + +// Verify generalization +let v = engine.verify_transfer( + &DomainId("rust_synthesis".into()), + &DomainId("structured_planning".into()), + 0.85, 0.84, // source before/after + 0.3, 0.7, // target before/after + 100, 40, // baseline/transfer cycles +); +assert!(v.promotable); // improved target without regressing source +assert!(v.acceleration_factor > 1.0); // 2.5x faster convergence +``` + +### WASM (JavaScript) + +```javascript +import { WasmDomainExpansionEngine } from 'ruvector-domain-expansion-wasm'; + +const engine = new WasmDomainExpansionEngine(); + +// List domains +console.log(engine.domainIds()); +// ["rust_synthesis", "structured_planning", "tool_orchestration"] + +// Generate tasks +const tasks = engine.generateTasks("rust_synthesis", 10, 0.5); + +// Select strategy via Thompson Sampling +const arm = engine.selectArm("rust_synthesis", "medium", "algorithm"); + +// Check if dual-path speculation needed +if (engine.shouldSpeculate("rust_synthesis", "medium", "algorithm")) { + // Run both strategies, pick winner +} + +// Transfer priors between domains +engine.initiateTransfer("rust_synthesis", "structured_planning"); + +// Evolve policy kernels +engine.generateHoldouts(10, 0.5); +engine.evaluatePopulation(); +engine.evolvePopulation(); +console.log(engine.populationStats()); + +// Acceleration scoreboard +console.log(engine.scoreboardSummary()); +``` + +## Acceptance Test + +Domain 2 must converge faster than Domain 1. Measure cycles to reach: +- 95% accuracy +- Target cost per solve +- Target robustness +- Zero policy violations + +```rust +use ruvector_domain_expansion::{AccelerationScoreboard, CostCurve, DomainId}; + +let mut board = AccelerationScoreboard::new(); + +// Add baseline and transfer curves +board.add_curve(baseline_curve); +board.add_curve(transfer_curve); + +// Compute acceleration +let entry = board.compute_acceleration( + &DomainId("baseline".into()), + &DomainId("transfer".into()), +).unwrap(); + +assert!(entry.acceleration > 1.0); // transfer helped +assert!(entry.generalization_passed); + +// Check progressive improvement across multiple domains +assert!(board.progressive_acceleration()); +``` + +## RVF Packaging + +Transfer artifacts are designed for RVF segment packaging: + +| Segment | Content | Purpose | +|---------|---------|---------| +| `TransferPrior` | Beta posteriors per bucket/arm | Seeds new domain initialization | +| `PolicyKernel` | Knob configuration + fitness history | Best policy for a domain | +| `CostCurve` | Convergence data points | Acceleration measurement | +| `WitnessChain` | Hash of derivation + holdout results | Audit trail | +| `Counterexamples` | Failed solutions per context | Negative signal for future decisions | + +## Benchmarks + +```bash +cargo bench -p ruvector-domain-expansion +``` + +Benchmarks cover: +- Task generation (per domain) +- Solution evaluation +- Embedding extraction +- Thompson Sampling arm selection +- Population evolution +- PolicyKnobs mutation +- Cost curve AUC computation +- TransferPrior extraction + +## Module Structure + +``` +src/ + lib.rs -- Orchestrator: DomainExpansionEngine + domain.rs -- Core Domain trait, Task, Solution, Evaluation, Embedding + rust_synthesis.rs -- Rust program synthesis domain + planning.rs -- Structured planning tasks domain + tool_orchestration.rs -- Tool orchestration problems domain + transfer.rs -- Meta Thompson Sampling, TransferPrior, verification + policy_kernel.rs -- PolicyKernel, PopulationSearch, PolicyKnobs + cost_curve.rs -- CostCurve, AccelerationScoreboard +``` + +## Tests + +49 unit tests covering all modules: + +```bash +cargo test -p ruvector-domain-expansion +``` + +| Module | Tests | +|--------|-------| +| `domain` | 5 tests: types, embedding cosine similarity, evaluation | +| `rust_synthesis` | 5 tests: generation, evaluation, embedding, difficulty | +| `planning` | 5 tests: generation, reference, evaluation, embedding, scaling | +| `tool_orchestration` | 5 tests: generation, reference, evaluation, embedding, errors | +| `transfer` | 6 tests: Beta params, Thompson engine, prior extraction, verification | +| `policy_kernel` | 5 tests: knobs, fitness, evolution, stats, crossover | +| `cost_curve` | 5 tests: convergence, compression, AUC, acceleration, scoreboard | +| `lib` (integration) | 8 tests: engine, tasks, arms, evaluation, embedding, transfer, population | diff --git a/crates/ruvector-domain-expansion/src/cost_curve.rs b/crates/ruvector-domain-expansion/src/cost_curve.rs new file mode 100644 index 00000000..adacb290 --- /dev/null +++ b/crates/ruvector-domain-expansion/src/cost_curve.rs @@ -0,0 +1,476 @@ +//! Cost Curve Compression Tracker and Acceleration Scoreboard +//! +//! Measures whether cost curves compress faster in each new domain. +//! If they do, you are increasing general problem-solving capability. +//! +//! ## Acceptance Test +//! +//! Domain 2 must converge faster than Domain 1. +//! Measure cycles to reach: +//! - 95% accuracy +//! - Target cost per solve +//! - Target robustness +//! - Zero policy violations + +use crate::domain::DomainId; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// A single data point on the cost curve. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostCurvePoint { + /// Cycle number (training iteration). + pub cycle: u64, + /// Current accuracy [0.0, 1.0]. + pub accuracy: f32, + /// Cost per solve at this point. + pub cost_per_solve: f32, + /// Robustness score [0.0, 1.0]. + pub robustness: f32, + /// Number of policy violations in this cycle. + pub policy_violations: u32, + /// Wall-clock timestamp (seconds since epoch). + pub timestamp: f64, +} + +/// Convergence thresholds for the acceptance test. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConvergenceThresholds { + /// Target accuracy (default: 0.95). + pub target_accuracy: f32, + /// Target cost per solve. + pub target_cost: f32, + /// Target robustness (default: 0.90). + pub target_robustness: f32, + /// Maximum allowed policy violations (default: 0). + pub max_violations: u32, +} + +impl Default for ConvergenceThresholds { + fn default() -> Self { + Self { + target_accuracy: 0.95, + target_cost: 0.01, + target_robustness: 0.90, + max_violations: 0, + } + } +} + +/// Cost curve for a single domain, tracking convergence over cycles. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostCurve { + /// Domain this curve belongs to. + pub domain_id: DomainId, + /// Whether this was trained with transfer priors. + pub used_transfer: bool, + /// Source domain for transfer (if any). + pub transfer_source: Option, + /// Ordered data points. + pub points: Vec, + /// Convergence thresholds. + pub thresholds: ConvergenceThresholds, +} + +impl CostCurve { + /// Create a new cost curve for a domain. + pub fn new(domain_id: DomainId, thresholds: ConvergenceThresholds) -> Self { + Self { + domain_id, + used_transfer: false, + transfer_source: None, + points: Vec::new(), + thresholds, + } + } + + /// Create a cost curve with transfer metadata. + pub fn with_transfer( + domain_id: DomainId, + source: DomainId, + thresholds: ConvergenceThresholds, + ) -> Self { + Self { + domain_id, + used_transfer: true, + transfer_source: Some(source), + points: Vec::new(), + thresholds, + } + } + + /// Record a new data point. + pub fn record(&mut self, point: CostCurvePoint) { + self.points.push(point); + } + + /// Check if all convergence criteria are met at the latest point. + pub fn has_converged(&self) -> bool { + self.points.last().map_or(false, |p| { + p.accuracy >= self.thresholds.target_accuracy + && p.cost_per_solve <= self.thresholds.target_cost + && p.robustness >= self.thresholds.target_robustness + && p.policy_violations <= self.thresholds.max_violations + }) + } + + /// Cycles to reach target accuracy (None if not yet reached). + pub fn cycles_to_accuracy(&self) -> Option { + self.points + .iter() + .find(|p| p.accuracy >= self.thresholds.target_accuracy) + .map(|p| p.cycle) + } + + /// Cycles to reach target cost (None if not yet reached). + pub fn cycles_to_cost(&self) -> Option { + self.points + .iter() + .find(|p| p.cost_per_solve <= self.thresholds.target_cost) + .map(|p| p.cycle) + } + + /// Cycles to reach target robustness. + pub fn cycles_to_robustness(&self) -> Option { + self.points + .iter() + .find(|p| p.robustness >= self.thresholds.target_robustness) + .map(|p| p.cycle) + } + + /// Cycles to full convergence (all criteria met). + pub fn cycles_to_convergence(&self) -> Option { + self.points + .iter() + .find(|p| { + p.accuracy >= self.thresholds.target_accuracy + && p.cost_per_solve <= self.thresholds.target_cost + && p.robustness >= self.thresholds.target_robustness + && p.policy_violations <= self.thresholds.max_violations + }) + .map(|p| p.cycle) + } + + /// Area under the accuracy curve (higher = faster learning). + pub fn auc_accuracy(&self) -> f32 { + if self.points.len() < 2 { + return 0.0; + } + self.points + .windows(2) + .map(|w| { + let dx = (w[1].cycle - w[0].cycle) as f32; + let avg_y = (w[0].accuracy + w[1].accuracy) / 2.0; + dx * avg_y + }) + .sum() + } + + /// Compression ratio: how fast the cost curve drops. + /// Computed as initial_cost / final_cost (higher = more compression). + pub fn compression_ratio(&self) -> f32 { + if self.points.len() < 2 { + return 1.0; + } + let initial = self.points.first().unwrap().cost_per_solve; + let final_cost = self.points.last().unwrap().cost_per_solve; + if final_cost > 1e-10 { + initial / final_cost + } else { + initial / 1e-10 + } + } +} + +/// Acceleration scoreboard comparing domain learning curves. +/// Shows acceleration, not just improvement. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccelerationScoreboard { + /// Per-domain cost curves. + pub curves: HashMap, + /// Pairwise acceleration factors. + pub accelerations: Vec, +} + +/// An entry showing how transfer from source to target affected convergence. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccelerationEntry { + /// Source domain. + pub source: DomainId, + /// Target domain. + pub target: DomainId, + /// Cycles to convergence without transfer (baseline). + pub baseline_cycles: Option, + /// Cycles to convergence with transfer. + pub transfer_cycles: Option, + /// Acceleration factor: baseline / transfer (>1 = transfer helped). + pub acceleration: f32, + /// AUC comparison (higher = better learning curve). + pub auc_baseline: f32, + pub auc_transfer: f32, + /// Compression ratio comparison. + pub compression_baseline: f32, + pub compression_transfer: f32, + /// Whether generalization test passed. + pub generalization_passed: bool, +} + +impl AccelerationScoreboard { + pub fn new() -> Self { + Self { + curves: HashMap::new(), + accelerations: Vec::new(), + } + } + + /// Add a cost curve for a domain. + pub fn add_curve(&mut self, curve: CostCurve) { + self.curves.insert(curve.domain_id.clone(), curve); + } + + /// Compute acceleration between a baseline (no transfer) and transfer curve. + pub fn compute_acceleration( + &mut self, + baseline_domain: &DomainId, + transfer_domain: &DomainId, + ) -> Option { + let baseline = self.curves.get(baseline_domain)?; + let transfer = self.curves.get(transfer_domain)?; + + let baseline_cycles = baseline.cycles_to_convergence(); + let transfer_cycles = transfer.cycles_to_convergence(); + + let acceleration = match (baseline_cycles, transfer_cycles) { + (Some(b), Some(t)) if t > 0 => b as f32 / t as f32, + _ => 1.0, // No measurable acceleration + }; + + let entry = AccelerationEntry { + source: transfer + .transfer_source + .clone() + .unwrap_or_else(|| DomainId("none".into())), + target: transfer_domain.clone(), + baseline_cycles, + transfer_cycles, + acceleration, + auc_baseline: baseline.auc_accuracy(), + auc_transfer: transfer.auc_accuracy(), + compression_baseline: baseline.compression_ratio(), + compression_transfer: transfer.compression_ratio(), + generalization_passed: acceleration > 1.0, + }; + + self.accelerations.push(entry.clone()); + Some(entry) + } + + /// Check whether each successive domain converges faster (the IQ growth test). + pub fn progressive_acceleration(&self) -> bool { + if self.accelerations.len() < 2 { + return true; // Not enough data to judge + } + + self.accelerations + .windows(2) + .all(|w| w[1].acceleration >= w[0].acceleration) + } + + /// Summary report of all domains. + pub fn summary(&self) -> ScoreboardSummary { + let domain_summaries: Vec = self + .curves + .iter() + .map(|(id, curve)| DomainSummary { + domain_id: id.clone(), + total_cycles: curve.points.last().map(|p| p.cycle).unwrap_or(0), + final_accuracy: curve.points.last().map(|p| p.accuracy).unwrap_or(0.0), + final_cost: curve.points.last().map(|p| p.cost_per_solve).unwrap_or(f32::MAX), + converged: curve.has_converged(), + cycles_to_convergence: curve.cycles_to_convergence(), + compression_ratio: curve.compression_ratio(), + used_transfer: curve.used_transfer, + }) + .collect(); + + let overall_acceleration = if self.accelerations.is_empty() { + 1.0 + } else { + self.accelerations.iter().map(|a| a.acceleration).sum::() + / self.accelerations.len() as f32 + }; + + ScoreboardSummary { + domains: domain_summaries, + accelerations: self.accelerations.clone(), + overall_acceleration, + progressive_improvement: self.progressive_acceleration(), + } + } +} + +impl Default for AccelerationScoreboard { + fn default() -> Self { + Self::new() + } +} + +/// Summary of a single domain's learning. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DomainSummary { + pub domain_id: DomainId, + pub total_cycles: u64, + pub final_accuracy: f32, + pub final_cost: f32, + pub converged: bool, + pub cycles_to_convergence: Option, + pub compression_ratio: f32, + pub used_transfer: bool, +} + +/// Full scoreboard summary. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScoreboardSummary { + pub domains: Vec, + pub accelerations: Vec, + pub overall_acceleration: f32, + /// True if each new domain converges faster than the previous. + pub progressive_improvement: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_curve( + domain: &str, + transfer: bool, + accuracy_steps: &[(u64, f32, f32)], + ) -> CostCurve { + let mut curve = if transfer { + CostCurve::with_transfer( + DomainId(domain.into()), + DomainId("source".into()), + ConvergenceThresholds::default(), + ) + } else { + CostCurve::new(DomainId(domain.into()), ConvergenceThresholds::default()) + }; + + for &(cycle, accuracy, cost) in accuracy_steps { + curve.record(CostCurvePoint { + cycle, + accuracy, + cost_per_solve: cost, + robustness: accuracy * 0.95, + policy_violations: 0, + timestamp: cycle as f64, + }); + } + curve + } + + #[test] + fn test_cost_curve_convergence() { + let curve = make_curve( + "test", + false, + &[ + (0, 0.3, 0.1), + (10, 0.6, 0.05), + (20, 0.8, 0.02), + (30, 0.95, 0.008), + ], + ); + + assert!(curve.has_converged()); + assert_eq!(curve.cycles_to_accuracy(), Some(30)); + assert_eq!(curve.cycles_to_cost(), Some(30)); + } + + #[test] + fn test_cost_curve_not_converged() { + let curve = make_curve("test", false, &[(0, 0.3, 0.1), (10, 0.6, 0.05)]); + + assert!(!curve.has_converged()); + assert_eq!(curve.cycles_to_accuracy(), None); + } + + #[test] + fn test_compression_ratio() { + let curve = + make_curve("test", false, &[(0, 0.3, 1.0), (10, 0.6, 0.5), (20, 0.9, 0.1)]); + + let ratio = curve.compression_ratio(); + assert!((ratio - 10.0).abs() < 1e-4); // 1.0 / 0.1 = 10x + } + + #[test] + fn test_acceleration_scoreboard() { + let mut board = AccelerationScoreboard::new(); + + // Domain 1: baseline (slow convergence) + let baseline = make_curve( + "d1_baseline", + false, + &[ + (0, 0.2, 0.1), + (20, 0.5, 0.05), + (50, 0.8, 0.02), + (100, 0.95, 0.008), + ], + ); + + // Domain 2: with transfer (fast convergence) + let transfer = make_curve( + "d2_transfer", + true, + &[ + (0, 0.4, 0.08), + (10, 0.7, 0.03), + (20, 0.9, 0.01), + (40, 0.96, 0.007), + ], + ); + + board.add_curve(baseline); + board.add_curve(transfer); + + let entry = board + .compute_acceleration( + &DomainId("d1_baseline".into()), + &DomainId("d2_transfer".into()), + ) + .unwrap(); + + assert!(entry.acceleration > 1.0, "Transfer should accelerate"); + assert_eq!(entry.baseline_cycles, Some(100)); + assert_eq!(entry.transfer_cycles, Some(40)); + assert!((entry.acceleration - 2.5).abs() < 1e-4); + assert!(entry.generalization_passed); + } + + #[test] + fn test_scoreboard_summary() { + let mut board = AccelerationScoreboard::new(); + let curve = make_curve("d1", false, &[(0, 0.5, 0.1), (50, 0.96, 0.005)]); + board.add_curve(curve); + + let summary = board.summary(); + assert_eq!(summary.domains.len(), 1); + assert!(summary.domains[0].converged); + } + + #[test] + fn test_auc_accuracy() { + let curve = make_curve( + "test", + false, + &[(0, 0.0, 1.0), (10, 0.5, 0.5), (20, 1.0, 0.1)], + ); + + let auc = curve.auc_accuracy(); + // Trapezoid: (10*(0+0.5)/2) + (10*(0.5+1.0)/2) = 2.5 + 7.5 = 10.0 + assert!((auc - 10.0).abs() < 1e-4); + } +} diff --git a/crates/ruvector-domain-expansion/src/domain.rs b/crates/ruvector-domain-expansion/src/domain.rs new file mode 100644 index 00000000..61e7daef --- /dev/null +++ b/crates/ruvector-domain-expansion/src/domain.rs @@ -0,0 +1,212 @@ +//! Core domain trait and types for cross-domain transfer learning. +//! +//! A domain defines a problem space with: +//! - A task generator (produces training instances) +//! - An evaluator (scores solutions on [0.0, 1.0]) +//! - Embedding extraction (maps solutions into a shared representation space) +//! +//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2 +//! faster than Domain 2 alone. That is generalization. + +use serde::{Deserialize, Serialize}; +use std::fmt; + +/// Unique identifier for a domain. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct DomainId(pub String); + +impl fmt::Display for DomainId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +/// A single task instance within a domain. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Task { + /// Unique task identifier. + pub id: String, + /// Domain this task belongs to. + pub domain_id: DomainId, + /// Difficulty level [0.0, 1.0]. + pub difficulty: f32, + /// Structured task specification (domain-specific JSON). + pub spec: serde_json::Value, + /// Optional constraints the solution must satisfy. + pub constraints: Vec, +} + +/// A candidate solution to a domain task. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Solution { + /// The task this solves. + pub task_id: String, + /// Raw solution content (e.g., Rust source, plan steps, tool calls). + pub content: String, + /// Structured solution data (domain-specific). + pub data: serde_json::Value, +} + +/// Evaluation result for a solution. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Evaluation { + /// Overall score [0.0, 1.0] where 1.0 is perfect. + pub score: f32, + /// Correctness: does it produce the right answer? + pub correctness: f32, + /// Efficiency: resource usage relative to optimal. + pub efficiency: f32, + /// Elegance: structural quality, idiomatic patterns. + pub elegance: f32, + /// Per-constraint pass/fail results. + pub constraint_results: Vec, + /// Diagnostic notes from the evaluator. + pub notes: Vec, +} + +impl Evaluation { + /// Create a zero-score evaluation (failure). + pub fn zero(notes: Vec) -> Self { + Self { + score: 0.0, + correctness: 0.0, + efficiency: 0.0, + elegance: 0.0, + constraint_results: Vec::new(), + notes, + } + } + + /// Compute composite score from weighted sub-scores. + pub fn composite(correctness: f32, efficiency: f32, elegance: f32) -> Self { + let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance; + Self { + score: score.clamp(0.0, 1.0), + correctness, + efficiency, + elegance, + constraint_results: Vec::new(), + notes: Vec::new(), + } + } +} + +/// Embedding vector for cross-domain representation. +/// Solutions from different domains are projected into a shared space +/// so that transfer learning can identify structural similarities. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DomainEmbedding { + /// The embedding vector. + pub vector: Vec, + /// Which domain produced this embedding. + pub domain_id: DomainId, + /// Dimensionality. + pub dim: usize, +} + +impl DomainEmbedding { + /// Create a new embedding. + pub fn new(vector: Vec, domain_id: DomainId) -> Self { + let dim = vector.len(); + Self { + vector, + domain_id, + dim, + } + } + + /// Cosine similarity with another embedding. + pub fn cosine_similarity(&self, other: &DomainEmbedding) -> f32 { + assert_eq!(self.dim, other.dim, "Embedding dimensions must match"); + + let mut dot = 0.0f32; + let mut norm_a = 0.0f32; + let mut norm_b = 0.0f32; + + for i in 0..self.dim { + dot += self.vector[i] * other.vector[i]; + norm_a += self.vector[i] * self.vector[i]; + norm_b += other.vector[i] * other.vector[i]; + } + + let denom = (norm_a.sqrt() * norm_b.sqrt()).max(1e-10); + dot / denom + } +} + +/// Core trait that every domain must implement. +/// +/// Domains are problem spaces: Rust program synthesis, structured planning, +/// tool orchestration, etc. Each domain knows how to generate tasks, +/// evaluate solutions, and embed solutions into a shared representation space. +pub trait Domain: Send + Sync { + /// Unique identifier for this domain. + fn id(&self) -> &DomainId; + + /// Human-readable name. + fn name(&self) -> &str; + + /// Generate a batch of tasks at the given difficulty level. + /// + /// # Arguments + /// * `count` - Number of tasks to generate + /// * `difficulty` - Target difficulty [0.0, 1.0] + fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec; + + /// Evaluate a solution against its task. + fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation; + + /// Project a solution into the shared embedding space. + /// This enables cross-domain transfer by finding structural similarities + /// between solutions across different problem domains. + fn embed(&self, solution: &Solution) -> DomainEmbedding; + + /// Embedding dimensionality for this domain. + fn embedding_dim(&self) -> usize; + + /// Generate a reference (optimal or near-optimal) solution for a task. + /// Used for computing efficiency ratios and as training signal. + fn reference_solution(&self, task: &Task) -> Option; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_domain_id_display() { + let id = DomainId("rust_synthesis".to_string()); + assert_eq!(format!("{}", id), "rust_synthesis"); + } + + #[test] + fn test_evaluation_zero() { + let eval = Evaluation::zero(vec!["compile error".to_string()]); + assert_eq!(eval.score, 0.0); + assert_eq!(eval.notes.len(), 1); + } + + #[test] + fn test_evaluation_composite() { + let eval = Evaluation::composite(1.0, 0.8, 0.6); + // 0.6*1.0 + 0.25*0.8 + 0.15*0.6 = 0.6 + 0.2 + 0.09 = 0.89 + assert!((eval.score - 0.89).abs() < 1e-4); + } + + #[test] + fn test_embedding_cosine_similarity() { + let id = DomainId("test".to_string()); + let a = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone()); + let b = DomainEmbedding::new(vec![1.0, 0.0, 0.0], id.clone()); + assert!((a.cosine_similarity(&b) - 1.0).abs() < 1e-6); + + let c = DomainEmbedding::new(vec![0.0, 1.0, 0.0], id); + assert!(a.cosine_similarity(&c).abs() < 1e-6); + } + + #[test] + fn test_evaluation_clamp() { + let eval = Evaluation::composite(1.0, 1.0, 1.0); + assert!(eval.score <= 1.0); + } +} diff --git a/crates/ruvector-domain-expansion/src/error.rs b/crates/ruvector-domain-expansion/src/error.rs new file mode 100644 index 00000000..adfb1018 --- /dev/null +++ b/crates/ruvector-domain-expansion/src/error.rs @@ -0,0 +1,39 @@ +//! Error types for domain expansion. + +use thiserror::Error; + +/// Errors that can occur during domain expansion operations. +#[derive(Error, Debug)] +pub enum DomainError { + /// Problem generation failed. + #[error("problem generation failed: {0}")] + Generation(String), + + /// Solution evaluation failed. + #[error("evaluation failed: {0}")] + Evaluation(String), + + /// Dimension mismatch between domains. + #[error("dimension mismatch: expected {expected}, got {got}")] + DimensionMismatch { expected: usize, got: usize }, + + /// Domain not found in the expansion engine. + #[error("domain not found: {0}")] + DomainNotFound(String), + + /// Transfer failed between domains. + #[error("transfer failed from {source} to {target}: {reason}")] + TransferFailed { + source: String, + target: String, + reason: String, + }, + + /// Kernel has not been trained on any domain yet. + #[error("kernel not initialized: {0}")] + KernelNotInitialized(String), + + /// Invalid configuration. + #[error("invalid config: {0}")] + InvalidConfig(String), +} diff --git a/crates/ruvector-domain-expansion/src/lib.rs b/crates/ruvector-domain-expansion/src/lib.rs new file mode 100644 index 00000000..0a391c0a --- /dev/null +++ b/crates/ruvector-domain-expansion/src/lib.rs @@ -0,0 +1,500 @@ +//! # Domain Expansion Engine +//! +//! Cross-domain transfer learning for general problem-solving capability. +//! +//! ## Core Insight +//! +//! True IQ growth appears when a kernel trained on Domain 1 improves Domain 2 +//! faster than Domain 2 alone. That is generalization. +//! +//! ## Two-Layer Architecture +//! +//! **Policy learning layer**: Meta Thompson Sampling with Beta priors across +//! context buckets. Chooses strategies via uncertainty-aware selection. +//! Transfer happens through compact priors — not raw trajectories. +//! +//! **Operator layer**: Deterministic domain kernels (Rust synthesis, planning, +//! tool orchestration) that generate tasks, evaluate solutions, and produce +//! embeddings into a shared representation space. +//! +//! ## Domains +//! +//! - **Rust Program Synthesis**: Generate Rust functions from specifications +//! - **Structured Planning**: Multi-step plans with dependencies and resources +//! - **Tool Orchestration**: Coordinate multiple tools/agents for complex goals +//! +//! ## Transfer Protocol +//! +//! 1. Train on Domain 1, extract `TransferPrior` (posterior summaries) +//! 2. Initialize Domain 2 with dampened priors from Domain 1 +//! 3. Measure acceleration: cycles to convergence with/without transfer +//! 4. A delta is promotable only if it improves target without regressing source +//! +//! ## Population-Based Policy Search +//! +//! Run a population of `PolicyKernel` variants in parallel. +//! Each variant tunes knobs (skip mode, prepass, speculation thresholds). +//! Keep top performers on holdouts, mutate, repeat. +//! +//! ## Acceptance Test +//! +//! Domain 2 must converge faster than Domain 1 to target accuracy, cost, +//! robustness, and zero policy violations. + +#![warn(missing_docs)] + +pub mod cost_curve; +pub mod domain; +pub mod planning; +pub mod policy_kernel; +pub mod rust_synthesis; +pub mod tool_orchestration; +pub mod transfer; + +// Re-export core types. +pub use cost_curve::{ + AccelerationEntry, AccelerationScoreboard, ConvergenceThresholds, CostCurve, CostCurvePoint, + ScoreboardSummary, +}; +pub use domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task}; +pub use planning::PlanningDomain; +pub use policy_kernel::{PolicyKernel, PolicyKnobs, PopulationSearch, PopulationStats}; +pub use rust_synthesis::RustSynthesisDomain; +pub use tool_orchestration::ToolOrchestrationDomain; +pub use transfer::{ + ArmId, BetaParams, ContextBucket, DualPathResult, MetaThompsonEngine, TransferPrior, + TransferVerification, +}; + +use std::collections::HashMap; + +/// The domain expansion orchestrator. +/// +/// Manages multiple domains, transfer learning between them, +/// population-based policy search, and the acceleration scoreboard. +pub struct DomainExpansionEngine { + /// Registered domains. + domains: HashMap>, + /// Meta Thompson Sampling engine for cross-domain transfer. + pub thompson: MetaThompsonEngine, + /// Population-based policy search. + pub population: PopulationSearch, + /// Acceleration scoreboard tracking convergence across domains. + pub scoreboard: AccelerationScoreboard, + /// Holdout tasks per domain for verification. + holdouts: HashMap>, + /// Counterexample set: failed solutions that inform future decisions. + counterexamples: HashMap>, +} + +impl DomainExpansionEngine { + /// Create a new domain expansion engine with default configuration. + /// + /// Initializes the three core domains and the transfer engine. + pub fn new() -> Self { + let arms = vec![ + "greedy".into(), + "exploratory".into(), + "conservative".into(), + "speculative".into(), + ]; + + let mut engine = Self { + domains: HashMap::new(), + thompson: MetaThompsonEngine::new(arms), + population: PopulationSearch::new(8), + scoreboard: AccelerationScoreboard::new(), + holdouts: HashMap::new(), + counterexamples: HashMap::new(), + }; + + // Register the three core domains. + engine.register_domain(Box::new(RustSynthesisDomain::new())); + engine.register_domain(Box::new(PlanningDomain::new())); + engine.register_domain(Box::new(ToolOrchestrationDomain::new())); + + engine + } + + /// Register a new domain. + pub fn register_domain(&mut self, domain: Box) { + let id = domain.id().clone(); + self.thompson.init_domain_uniform(id.clone()); + self.domains.insert(id, domain); + } + + /// Generate holdout tasks for verification. + pub fn generate_holdouts(&mut self, tasks_per_domain: usize, difficulty: f32) { + for (id, domain) in &self.domains { + let tasks = domain.generate_tasks(tasks_per_domain, difficulty); + self.holdouts.insert(id.clone(), tasks); + } + } + + /// Generate training tasks for a specific domain. + pub fn generate_tasks( + &self, + domain_id: &DomainId, + count: usize, + difficulty: f32, + ) -> Vec { + self.domains + .get(domain_id) + .map(|d| d.generate_tasks(count, difficulty)) + .unwrap_or_default() + } + + /// Evaluate a solution and record the outcome. + pub fn evaluate_and_record( + &mut self, + domain_id: &DomainId, + task: &Task, + solution: &Solution, + bucket: ContextBucket, + arm: ArmId, + ) -> Evaluation { + let eval = self + .domains + .get(domain_id) + .map(|d| d.evaluate(task, solution)) + .unwrap_or_else(|| Evaluation::zero(vec!["Domain not found".into()])); + + // Record outcome in Thompson engine. + self.thompson.record_outcome( + domain_id, + bucket, + arm, + eval.score, + 1.0, // unit cost for now + ); + + // Store counterexamples for poor solutions. + if eval.score < 0.3 { + self.counterexamples + .entry(domain_id.clone()) + .or_default() + .push((task.clone(), solution.clone(), eval.clone())); + } + + eval + } + + /// Embed a solution into the shared representation space. + pub fn embed(&self, domain_id: &DomainId, solution: &Solution) -> Option { + self.domains.get(domain_id).map(|d| d.embed(solution)) + } + + /// Initiate transfer from source domain to target domain. + /// Extracts priors from source and seeds target. + pub fn initiate_transfer(&mut self, source: &DomainId, target: &DomainId) { + if let Some(prior) = self.thompson.extract_prior(source) { + self.thompson + .init_domain_with_transfer(target.clone(), &prior); + } + } + + /// Verify a transfer delta: did it improve target without regressing source? + pub fn verify_transfer( + &self, + source: &DomainId, + target: &DomainId, + source_before: f32, + source_after: f32, + target_before: f32, + target_after: f32, + baseline_cycles: u64, + transfer_cycles: u64, + ) -> TransferVerification { + TransferVerification::verify( + source.clone(), + target.clone(), + source_before, + source_after, + target_before, + target_after, + baseline_cycles, + transfer_cycles, + ) + } + + /// Evaluate all policy kernels on holdout tasks. + pub fn evaluate_population(&mut self) { + let holdout_snapshot: HashMap> = self.holdouts.clone(); + let domain_ids: Vec = self.domains.keys().cloned().collect(); + + for i in 0..self.population.population().len() { + for domain_id in &domain_ids { + if let Some(holdout_tasks) = holdout_snapshot.get(domain_id) { + let mut total_score = 0.0f32; + let mut count = 0; + + for task in holdout_tasks { + if let Some(domain) = self.domains.get(domain_id) { + if let Some(ref_sol) = domain.reference_solution(task) { + let eval = domain.evaluate(task, &ref_sol); + total_score += eval.score; + count += 1; + } + } + } + + let avg_score = if count > 0 { + total_score / count as f32 + } else { + 0.0 + }; + + if let Some(kernel) = self.population.kernel_mut(i) { + kernel.record_score(domain_id.clone(), avg_score, 1.0); + } + } + } + } + } + + /// Evolve the policy kernel population. + pub fn evolve_population(&mut self) { + self.population.evolve(); + } + + /// Get the best policy kernel found so far. + pub fn best_kernel(&self) -> Option<&PolicyKernel> { + self.population.best() + } + + /// Get population statistics. + pub fn population_stats(&self) -> PopulationStats { + self.population.stats() + } + + /// Get the scoreboard summary. + pub fn scoreboard_summary(&self) -> ScoreboardSummary { + self.scoreboard.summary() + } + + /// Get registered domain IDs. + pub fn domain_ids(&self) -> Vec { + self.domains.keys().cloned().collect() + } + + /// Get counterexamples for a domain. + pub fn counterexamples( + &self, + domain_id: &DomainId, + ) -> &[(Task, Solution, Evaluation)] { + self.counterexamples + .get(domain_id) + .map(|v| v.as_slice()) + .unwrap_or(&[]) + } + + /// Select best arm for a context using Thompson Sampling. + pub fn select_arm( + &self, + domain_id: &DomainId, + bucket: &ContextBucket, + ) -> Option { + let mut rng = rand::thread_rng(); + self.thompson.select_arm(domain_id, bucket, &mut rng) + } + + /// Check if dual-path speculation should be triggered. + pub fn should_speculate( + &self, + domain_id: &DomainId, + bucket: &ContextBucket, + ) -> bool { + self.thompson.is_uncertain(domain_id, bucket, 0.15) + } +} + +impl Default for DomainExpansionEngine { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_engine_creation() { + let engine = DomainExpansionEngine::new(); + let ids = engine.domain_ids(); + assert_eq!(ids.len(), 3); + } + + #[test] + fn test_generate_tasks_all_domains() { + let engine = DomainExpansionEngine::new(); + for domain_id in engine.domain_ids() { + let tasks = engine.generate_tasks(&domain_id, 5, 0.5); + assert_eq!(tasks.len(), 5); + } + } + + #[test] + fn test_arm_selection() { + let engine = DomainExpansionEngine::new(); + let bucket = ContextBucket { + difficulty_tier: "medium".into(), + category: "general".into(), + }; + for domain_id in engine.domain_ids() { + let arm = engine.select_arm(&domain_id, &bucket); + assert!(arm.is_some()); + } + } + + #[test] + fn test_evaluate_and_record() { + let mut engine = DomainExpansionEngine::new(); + let domain_id = DomainId("rust_synthesis".into()); + let tasks = engine.generate_tasks(&domain_id, 1, 0.3); + let task = &tasks[0]; + + let solution = Solution { + task_id: task.id.clone(), + content: "fn double(values: &[i64]) -> Vec { values.iter().map(|&x| x * 2).collect() }".into(), + data: serde_json::Value::Null, + }; + + let bucket = ContextBucket { + difficulty_tier: "easy".into(), + category: "transform".into(), + }; + let arm = ArmId("greedy".into()); + + let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm); + assert!(eval.score >= 0.0 && eval.score <= 1.0); + } + + #[test] + fn test_cross_domain_embedding() { + let engine = DomainExpansionEngine::new(); + + let rust_sol = Solution { + task_id: "rust".into(), + content: "fn foo() { for i in 0..10 { if i > 5 { } } }".into(), + data: serde_json::Value::Null, + }; + + let plan_sol = Solution { + task_id: "plan".into(), + content: "allocate cpu then schedule parallel jobs".into(), + data: serde_json::json!({"steps": []}), + }; + + let rust_emb = engine + .embed(&DomainId("rust_synthesis".into()), &rust_sol) + .unwrap(); + let plan_emb = engine + .embed(&DomainId("structured_planning".into()), &plan_sol) + .unwrap(); + + // Embeddings should be same dimension. + assert_eq!(rust_emb.dim, plan_emb.dim); + + // Cross-domain similarity should be defined. + let sim = rust_emb.cosine_similarity(&plan_emb); + assert!(sim >= -1.0 && sim <= 1.0); + } + + #[test] + fn test_transfer_flow() { + let mut engine = DomainExpansionEngine::new(); + let source = DomainId("rust_synthesis".into()); + let target = DomainId("structured_planning".into()); + + // Record some outcomes in source domain. + let bucket = ContextBucket { + difficulty_tier: "medium".into(), + category: "algorithm".into(), + }; + + for _ in 0..30 { + engine.thompson.record_outcome( + &source, + bucket.clone(), + ArmId("greedy".into()), + 0.85, + 1.0, + ); + } + + // Initiate transfer. + engine.initiate_transfer(&source, &target); + + // Verify the transfer. + let verification = engine.verify_transfer( + &source, + &target, + 0.85, // source before + 0.845, // source after (within tolerance) + 0.3, // target before + 0.7, // target after + 100, // baseline cycles + 45, // transfer cycles + ); + + assert!(verification.promotable); + assert!(verification.acceleration_factor > 1.0); + } + + #[test] + fn test_population_evolution() { + let mut engine = DomainExpansionEngine::new(); + engine.generate_holdouts(3, 0.3); + engine.evaluate_population(); + + let stats_before = engine.population_stats(); + assert_eq!(stats_before.generation, 0); + + engine.evolve_population(); + let stats_after = engine.population_stats(); + assert_eq!(stats_after.generation, 1); + } + + #[test] + fn test_speculation_trigger() { + let engine = DomainExpansionEngine::new(); + let bucket = ContextBucket { + difficulty_tier: "hard".into(), + category: "unknown".into(), + }; + + // With uniform priors, should be uncertain. + assert!(engine.should_speculate( + &DomainId("rust_synthesis".into()), + &bucket, + )); + } + + #[test] + fn test_counterexample_tracking() { + let mut engine = DomainExpansionEngine::new(); + let domain_id = DomainId("rust_synthesis".into()); + let tasks = engine.generate_tasks(&domain_id, 1, 0.9); + let task = &tasks[0]; + + // Submit a terrible solution. + let solution = Solution { + task_id: task.id.clone(), + content: "".into(), // empty = bad + data: serde_json::Value::Null, + }; + + let bucket = ContextBucket { + difficulty_tier: "hard".into(), + category: "algorithm".into(), + }; + let arm = ArmId("speculative".into()); + + let eval = engine.evaluate_and_record(&domain_id, task, &solution, bucket, arm); + assert!(eval.score < 0.3); + + // Should be recorded as counterexample. + assert!(!engine.counterexamples(&domain_id).is_empty()); + } +} diff --git a/crates/ruvector-domain-expansion/src/planning.rs b/crates/ruvector-domain-expansion/src/planning.rs new file mode 100644 index 00000000..5700d114 --- /dev/null +++ b/crates/ruvector-domain-expansion/src/planning.rs @@ -0,0 +1,646 @@ +//! Structured Planning Tasks Domain +//! +//! Generates tasks that require multi-step reasoning and plan construction. +//! Task types include: +//! +//! - **ResourceAllocation**: Assign limited resources to maximize objective +//! - **DependencyScheduling**: Order tasks respecting dependencies and deadlines +//! - **StateSpaceSearch**: Navigate from initial to goal state +//! - **ConstraintSatisfaction**: Find assignments satisfying all constraints +//! - **HierarchicalDecomposition**: Break complex goals into sub-goals +//! +//! Solutions are plans: ordered sequences of actions with preconditions and effects. +//! Cross-domain transfer from Rust synthesis helps because both require: +//! structured decomposition, constraint satisfaction, and efficient search. + +use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task}; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +const EMBEDDING_DIM: usize = 64; + +/// Categories of planning tasks. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum PlanningCategory { + /// Assign limited resources to competing demands. + ResourceAllocation, + /// Schedule tasks with precedence constraints and deadlines. + DependencyScheduling, + /// Find a path from initial state to goal state. + StateSpaceSearch, + /// Assign values to variables satisfying all constraints. + ConstraintSatisfaction, + /// Decompose a high-level goal into achievable sub-tasks. + HierarchicalDecomposition, +} + +/// A resource in the planning world. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Resource { + pub name: String, + pub capacity: u32, +} + +/// An action in a plan. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlanAction { + pub name: String, + pub preconditions: Vec, + pub effects: Vec, + pub cost: f32, + pub duration: u32, +} + +/// A dependency edge: task A must complete before task B. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Dependency { + pub from: String, + pub to: String, +} + +/// Specification for a planning task. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlanningTaskSpec { + pub category: PlanningCategory, + pub description: String, + /// Available actions in the planning domain. + pub available_actions: Vec, + /// Resources with capacity limits. + pub resources: Vec, + /// Dependency constraints. + pub dependencies: Vec, + /// Initial state predicates. + pub initial_state: Vec, + /// Goal state predicates. + pub goal_state: Vec, + /// Maximum allowed plan cost. + pub max_cost: Option, + /// Maximum allowed plan steps. + pub max_steps: Option, +} + +/// A parsed plan from a solution. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Plan { + pub steps: Vec, +} + +/// A single step in a plan. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PlanStep { + pub action: String, + pub args: Vec, + pub start_time: Option, +} + +/// Structured planning domain. +pub struct PlanningDomain { + id: DomainId, +} + +impl PlanningDomain { + pub fn new() -> Self { + Self { + id: DomainId("structured_planning".to_string()), + } + } + + fn gen_resource_allocation(&self, difficulty: f32) -> PlanningTaskSpec { + let num_tasks = if difficulty < 0.3 { + 3 + } else if difficulty < 0.7 { + 6 + } else { + 10 + }; + + let actions: Vec = (0..num_tasks) + .map(|i| PlanAction { + name: format!("task_{}", i), + preconditions: vec![format!("resource_available_{}", i % 3)], + effects: vec![format!("task_{}_complete", i)], + cost: (i as f32 + 1.0) * 10.0, + duration: (i as u32 % 5) + 1, + }) + .collect(); + + let resources = vec![ + Resource { + name: "cpu".into(), + capacity: if difficulty < 0.5 { 10 } else { 5 }, + }, + Resource { + name: "memory".into(), + capacity: if difficulty < 0.5 { 8 } else { 3 }, + }, + Resource { + name: "io".into(), + capacity: if difficulty < 0.5 { 6 } else { 2 }, + }, + ]; + + let goal_state: Vec = (0..num_tasks) + .map(|i| format!("task_{}_complete", i)) + .collect(); + + PlanningTaskSpec { + category: PlanningCategory::ResourceAllocation, + description: format!( + "Allocate {} resources to complete {} tasks within capacity.", + resources.len(), + num_tasks + ), + available_actions: actions, + resources, + dependencies: Vec::new(), + initial_state: vec![ + "resource_available_0".into(), + "resource_available_1".into(), + "resource_available_2".into(), + ], + goal_state, + max_cost: Some(num_tasks as f32 * 50.0), + max_steps: Some(num_tasks * 2), + } + } + + fn gen_dependency_scheduling(&self, difficulty: f32) -> PlanningTaskSpec { + let num_tasks = if difficulty < 0.3 { + 4 + } else if difficulty < 0.7 { + 7 + } else { + 12 + }; + + let actions: Vec = (0..num_tasks) + .map(|i| PlanAction { + name: format!("job_{}", i), + preconditions: if i > 0 { + vec![format!("job_{}_done", i - 1)] + } else { + Vec::new() + }, + effects: vec![format!("job_{}_done", i)], + cost: 1.0, + duration: (i as u32 % 3) + 1, + }) + .collect(); + + // Create dependency chain with some parallelism + let mut dependencies = Vec::new(); + for i in 1..num_tasks { + // Linear chain + dependencies.push(Dependency { + from: format!("job_{}", i - 1), + to: format!("job_{}", i), + }); + // Add cross-dependencies at higher difficulty + if difficulty > 0.5 && i >= 3 && i % 2 == 0 { + dependencies.push(Dependency { + from: format!("job_{}", i - 3), + to: format!("job_{}", i), + }); + } + } + + PlanningTaskSpec { + category: PlanningCategory::DependencyScheduling, + description: format!( + "Schedule {} jobs respecting {} dependencies, minimizing makespan.", + num_tasks, + dependencies.len() + ), + available_actions: actions, + resources: vec![Resource { + name: "worker".into(), + capacity: if difficulty < 0.5 { 3 } else { 2 }, + }], + dependencies, + initial_state: Vec::new(), + goal_state: (0..num_tasks) + .map(|i| format!("job_{}_done", i)) + .collect(), + max_cost: None, + max_steps: Some(num_tasks + 5), + } + } + + fn gen_state_space_search(&self, difficulty: f32) -> PlanningTaskSpec { + let grid_size = if difficulty < 0.3 { + 3 + } else if difficulty < 0.7 { + 5 + } else { + 8 + }; + + let actions = vec![ + PlanAction { + name: "move_up".into(), + preconditions: vec!["not_top_edge".into()], + effects: vec!["moved_up".into()], + cost: 1.0, + duration: 1, + }, + PlanAction { + name: "move_down".into(), + preconditions: vec!["not_bottom_edge".into()], + effects: vec!["moved_down".into()], + cost: 1.0, + duration: 1, + }, + PlanAction { + name: "move_left".into(), + preconditions: vec!["not_left_edge".into()], + effects: vec!["moved_left".into()], + cost: 1.0, + duration: 1, + }, + PlanAction { + name: "move_right".into(), + preconditions: vec!["not_right_edge".into()], + effects: vec!["moved_right".into()], + cost: 1.0, + duration: 1, + }, + ]; + + PlanningTaskSpec { + category: PlanningCategory::StateSpaceSearch, + description: format!( + "Navigate a {}x{} grid from (0,0) to ({},{}) avoiding obstacles.", + grid_size, + grid_size, + grid_size - 1, + grid_size - 1 + ), + available_actions: actions, + resources: Vec::new(), + dependencies: Vec::new(), + initial_state: vec!["at(0,0)".into()], + goal_state: vec![format!("at({},{})", grid_size - 1, grid_size - 1)], + max_cost: Some((grid_size as f32) * 4.0), + max_steps: Some(grid_size * grid_size), + } + } + + /// Extract structural features from a planning solution. + fn extract_features(&self, solution: &Solution) -> Vec { + let content = &solution.content; + let mut features = vec![0.0f32; EMBEDDING_DIM]; + + // Parse the plan + let plan: Plan = serde_json::from_str(&solution.data.to_string()) + .or_else(|_| serde_json::from_str(content)) + .unwrap_or(Plan { steps: Vec::new() }); + + // Feature 0-7: Plan structure + features[0] = plan.steps.len() as f32 / 20.0; + features[1] = { + let unique_actions: std::collections::HashSet<&str> = + plan.steps.iter().map(|s| s.action.as_str()).collect(); + unique_actions.len() as f32 / plan.steps.len().max(1) as f32 + }; + // Sequential vs parallel indicator + features[2] = plan + .steps + .windows(2) + .filter(|w| w[0].start_time == w[1].start_time) + .count() as f32 + / plan.steps.len().max(1) as f32; + // Average args per step + features[3] = plan.steps.iter().map(|s| s.args.len() as f32).sum::() + / plan.steps.len().max(1) as f32 + / 5.0; + + // Feature 8-15: Action type distribution + let action_counts: std::collections::HashMap<&str, usize> = + plan.steps.iter().fold(std::collections::HashMap::new(), |mut acc, s| { + *acc.entry(s.action.as_str()).or_insert(0) += 1; + acc + }); + let max_count = action_counts.values().max().copied().unwrap_or(0); + features[8] = action_counts.len() as f32 / 10.0; + features[9] = max_count as f32 / plan.steps.len().max(1) as f32; + + // Feature 16-23: Text-based features from content + features[16] = content.matches("allocate").count() as f32 / 5.0; + features[17] = content.matches("schedule").count() as f32 / 5.0; + features[18] = content.matches("move").count() as f32 / 10.0; + features[19] = content.matches("assign").count() as f32 / 5.0; + features[20] = content.matches("wait").count() as f32 / 5.0; + features[21] = content.matches("parallel").count() as f32 / 3.0; + features[22] = content.matches("constraint").count() as f32 / 5.0; + features[23] = content.matches("deadline").count() as f32 / 3.0; + + // Feature 32-39: Structural complexity indicators + features[32] = content.matches("->").count() as f32 / 10.0; + features[33] = content.matches("if ").count() as f32 / 5.0; + features[34] = content.matches("then ").count() as f32 / 5.0; + features[35] = content.matches("before").count() as f32 / 5.0; + features[36] = content.matches("after").count() as f32 / 5.0; + features[37] = content.matches("while").count() as f32 / 3.0; + features[38] = content.matches("until").count() as f32 / 3.0; + features[39] = content.matches("complete").count() as f32 / 5.0; + + // Feature 48-55: Resource usage indicators + features[48] = content.matches("cpu").count() as f32 / 3.0; + features[49] = content.matches("memory").count() as f32 / 3.0; + features[50] = content.matches("worker").count() as f32 / 3.0; + features[51] = content.matches("capacity").count() as f32 / 3.0; + features[52] = content.matches("cost").count() as f32 / 5.0; + features[53] = content.matches("time").count() as f32 / 5.0; + features[54] = content.matches("resource").count() as f32 / 5.0; + features[55] = content.matches("limit").count() as f32 / 3.0; + + // Normalize + let norm: f32 = features.iter().map(|x| x * x).sum::().sqrt(); + if norm > 1e-10 { + for f in &mut features { + *f /= norm; + } + } + + features + } + + /// Evaluate a planning solution. + fn score_plan(&self, spec: &PlanningTaskSpec, solution: &Solution) -> Evaluation { + let content = &solution.content; + let mut correctness = 0.0f32; + let mut efficiency = 0.5f32; + let mut elegance = 0.5f32; + let mut notes = Vec::new(); + + // Parse plan from solution + let plan: Option = serde_json::from_str(&solution.data.to_string()) + .ok() + .or_else(|| serde_json::from_str(content).ok()); + + let plan = match plan { + Some(p) => p, + None => { + // Fall back to text analysis + let has_steps = content.contains("step") || content.contains("action"); + if has_steps { + correctness = 0.2; + } + return Evaluation { + score: correctness * 0.6, + correctness, + efficiency: 0.0, + elegance: 0.0, + constraint_results: Vec::new(), + notes: vec!["Could not parse structured plan".into()], + }; + } + }; + + // Check plan is non-empty + if plan.steps.is_empty() { + return Evaluation::zero(vec!["Empty plan".into()]); + } + + // Check goal coverage: how many goal predicates are addressed + let goal_coverage = spec + .goal_state + .iter() + .filter(|goal| { + plan.steps.iter().any(|step| { + let action_name = &step.action; + // Check if any action's effects mention this goal + spec.available_actions + .iter() + .any(|a| a.name == *action_name && a.effects.iter().any(|e| e == *goal)) + }) + }) + .count() as f32 + / spec.goal_state.len().max(1) as f32; + + correctness = goal_coverage; + + // Check dependency ordering + let mut dep_violations = 0; + for dep in &spec.dependencies { + let from_pos = plan.steps.iter().position(|s| s.action == dep.from); + let to_pos = plan.steps.iter().position(|s| s.action == dep.to); + if let (Some(f), Some(t)) = (from_pos, to_pos) { + if f >= t { + dep_violations += 1; + notes.push(format!( + "Dependency violation: {} must come before {}", + dep.from, dep.to + )); + } + } + } + if !spec.dependencies.is_empty() { + let dep_score = + 1.0 - (dep_violations as f32 / spec.dependencies.len() as f32); + correctness = correctness * 0.5 + dep_score * 0.5; + } + + // Efficiency: compare to max allowed steps/cost + if let Some(max_steps) = spec.max_steps { + let step_ratio = plan.steps.len() as f32 / max_steps as f32; + efficiency = if step_ratio <= 1.0 { + 1.0 - (step_ratio * 0.5) // Fewer steps = better + } else { + 0.5 / step_ratio // Penalty for exceeding max + }; + } + + if let Some(max_cost) = spec.max_cost { + let total_cost: f32 = plan + .steps + .iter() + .filter_map(|step| { + spec.available_actions + .iter() + .find(|a| a.name == step.action) + .map(|a| a.cost) + }) + .sum(); + if total_cost > max_cost { + efficiency *= 0.5; + notes.push(format!( + "Plan cost {:.1} exceeds budget {:.1}", + total_cost, max_cost + )); + } + } + + // Elegance: minimal redundancy, good parallelism + let unique_actions: std::collections::HashSet<&str> = + plan.steps.iter().map(|s| s.action.as_str()).collect(); + let redundancy = 1.0 - (unique_actions.len() as f32 / plan.steps.len().max(1) as f32); + elegance = 1.0 - redundancy * 0.5; + + // Bonus for parallel scheduling + if plan.steps.windows(2).any(|w| w[0].start_time == w[1].start_time) { + elegance += 0.1; + } + elegance = elegance.clamp(0.0, 1.0); + + let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance; + Evaluation { + score: score.clamp(0.0, 1.0), + correctness, + efficiency, + elegance, + constraint_results: Vec::new(), + notes, + } + } +} + +impl Default for PlanningDomain { + fn default() -> Self { + Self::new() + } +} + +impl Domain for PlanningDomain { + fn id(&self) -> &DomainId { + &self.id + } + + fn name(&self) -> &str { + "Structured Planning" + } + + fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec { + let mut rng = rand::thread_rng(); + let difficulty = difficulty.clamp(0.0, 1.0); + + (0..count) + .map(|i| { + let category_roll: f32 = rng.gen(); + let spec = if category_roll < 0.35 { + self.gen_resource_allocation(difficulty) + } else if category_roll < 0.7 { + self.gen_dependency_scheduling(difficulty) + } else { + self.gen_state_space_search(difficulty) + }; + + Task { + id: format!("planning_{}_d{:.0}", i, difficulty * 100.0), + domain_id: self.id.clone(), + difficulty, + spec: serde_json::to_value(&spec).unwrap_or_default(), + constraints: Vec::new(), + } + }) + .collect() + } + + fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation { + let spec: PlanningTaskSpec = match serde_json::from_value(task.spec.clone()) { + Ok(s) => s, + Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]), + }; + self.score_plan(&spec, solution) + } + + fn embed(&self, solution: &Solution) -> DomainEmbedding { + let features = self.extract_features(solution); + DomainEmbedding::new(features, self.id.clone()) + } + + fn embedding_dim(&self) -> usize { + EMBEDDING_DIM + } + + fn reference_solution(&self, task: &Task) -> Option { + let spec: PlanningTaskSpec = serde_json::from_value(task.spec.clone()).ok()?; + + // Generate a naive sequential plan that executes all actions in order + let steps: Vec = spec + .available_actions + .iter() + .enumerate() + .map(|(i, a)| PlanStep { + action: a.name.clone(), + args: Vec::new(), + start_time: Some(i as u32), + }) + .collect(); + + let plan = Plan { steps }; + let content = serde_json::to_string_pretty(&plan).ok()?; + + Some(Solution { + task_id: task.id.clone(), + content, + data: serde_json::to_value(&plan).ok()?, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_planning_tasks() { + let domain = PlanningDomain::new(); + let tasks = domain.generate_tasks(5, 0.5); + assert_eq!(tasks.len(), 5); + for task in &tasks { + assert_eq!(task.domain_id, domain.id); + } + } + + #[test] + fn test_reference_solution_exists() { + let domain = PlanningDomain::new(); + let tasks = domain.generate_tasks(3, 0.3); + for task in &tasks { + let ref_sol = domain.reference_solution(task); + assert!(ref_sol.is_some(), "Should produce reference solution"); + } + } + + #[test] + fn test_evaluate_reference() { + let domain = PlanningDomain::new(); + let tasks = domain.generate_tasks(3, 0.3); + for task in &tasks { + if let Some(solution) = domain.reference_solution(task) { + let eval = domain.evaluate(task, &solution); + assert!(eval.score >= 0.0 && eval.score <= 1.0); + } + } + } + + #[test] + fn test_embed_planning() { + let domain = PlanningDomain::new(); + let solution = Solution { + task_id: "test".into(), + content: "allocate cpu to task_0, schedule job_1 after job_0".into(), + data: serde_json::json!({ "steps": [] }), + }; + let embedding = domain.embed(&solution); + assert_eq!(embedding.dim, EMBEDDING_DIM); + } + + #[test] + fn test_difficulty_scaling() { + let domain = PlanningDomain::new(); + let easy = domain.generate_tasks(1, 0.1); + let hard = domain.generate_tasks(1, 0.9); + + let easy_spec: PlanningTaskSpec = + serde_json::from_value(easy[0].spec.clone()).unwrap(); + let hard_spec: PlanningTaskSpec = + serde_json::from_value(hard[0].spec.clone()).unwrap(); + + assert!( + hard_spec.available_actions.len() >= easy_spec.available_actions.len(), + "Harder tasks should have more actions" + ); + } +} diff --git a/crates/ruvector-domain-expansion/src/policy_kernel.rs b/crates/ruvector-domain-expansion/src/policy_kernel.rs new file mode 100644 index 00000000..0307dbe0 --- /dev/null +++ b/crates/ruvector-domain-expansion/src/policy_kernel.rs @@ -0,0 +1,463 @@ +//! PolicyKernel: Population-Based Policy Search +//! +//! Run a small population of policy variants in parallel. +//! Each variant changes a small set of knobs: +//! - skip mode policy +//! - prepass mode +//! - speculation trigger thresholds +//! - budget allocation +//! +//! Selection: keep top performers on holdouts, mutate knobs, repeat. +//! Only merge deltas that pass replay-verify. + +use crate::domain::DomainId; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Configuration knobs that a PolicyKernel can tune. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyKnobs { + /// Whether to skip low-value operations. + pub skip_mode: bool, + /// Run a cheaper prepass before full execution. + pub prepass_enabled: bool, + /// Threshold for triggering speculative dual-path [0.0, 1.0]. + pub speculation_threshold: f32, + /// Budget fraction allocated to exploration vs exploitation [0.0, 1.0]. + pub exploration_budget: f32, + /// Maximum retries on failure. + pub max_retries: u32, + /// Batch size for parallel evaluation. + pub batch_size: usize, + /// Cost decay factor for EMA. + pub cost_decay: f32, + /// Minimum confidence to skip uncertainty check. + pub confidence_floor: f32, +} + +impl PolicyKnobs { + /// Sensible defaults. + pub fn default_knobs() -> Self { + Self { + skip_mode: false, + prepass_enabled: true, + speculation_threshold: 0.15, + exploration_budget: 0.2, + max_retries: 2, + batch_size: 8, + cost_decay: 0.9, + confidence_floor: 0.7, + } + } + + /// Mutate knobs with small random perturbations. + pub fn mutate(&self, rng: &mut impl Rng, mutation_rate: f32) -> Self { + let mut knobs = self.clone(); + + if rng.gen::() < mutation_rate { + knobs.skip_mode = !knobs.skip_mode; + } + if rng.gen::() < mutation_rate { + knobs.prepass_enabled = !knobs.prepass_enabled; + } + if rng.gen::() < mutation_rate { + let delta: f32 = rng.gen_range(-0.1..0.1); + knobs.speculation_threshold = (knobs.speculation_threshold + delta).clamp(0.01, 0.5); + } + if rng.gen::() < mutation_rate { + let delta: f32 = rng.gen_range(-0.1..0.1); + knobs.exploration_budget = (knobs.exploration_budget + delta).clamp(0.01, 0.5); + } + if rng.gen::() < mutation_rate { + knobs.max_retries = rng.gen_range(0..5); + } + if rng.gen::() < mutation_rate { + knobs.batch_size = rng.gen_range(1..32); + } + if rng.gen::() < mutation_rate { + let delta: f32 = rng.gen_range(-0.05..0.05); + knobs.cost_decay = (knobs.cost_decay + delta).clamp(0.5, 0.99); + } + if rng.gen::() < mutation_rate { + let delta: f32 = rng.gen_range(-0.1..0.1); + knobs.confidence_floor = (knobs.confidence_floor + delta).clamp(0.3, 0.95); + } + + knobs + } + + /// Crossover two parent knobs to produce a child. + pub fn crossover(&self, other: &PolicyKnobs, rng: &mut impl Rng) -> Self { + Self { + skip_mode: if rng.gen() { self.skip_mode } else { other.skip_mode }, + prepass_enabled: if rng.gen() { + self.prepass_enabled + } else { + other.prepass_enabled + }, + speculation_threshold: if rng.gen() { + self.speculation_threshold + } else { + other.speculation_threshold + }, + exploration_budget: if rng.gen() { + self.exploration_budget + } else { + other.exploration_budget + }, + max_retries: if rng.gen() { + self.max_retries + } else { + other.max_retries + }, + batch_size: if rng.gen() { + self.batch_size + } else { + other.batch_size + }, + cost_decay: if rng.gen() { + self.cost_decay + } else { + other.cost_decay + }, + confidence_floor: if rng.gen() { + self.confidence_floor + } else { + other.confidence_floor + }, + } + } +} + +/// A PolicyKernel is a versioned policy configuration with performance history. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyKernel { + /// Unique identifier. + pub id: String, + /// Configuration knobs. + pub knobs: PolicyKnobs, + /// Performance on holdout tasks (domain_id -> score). + pub holdout_scores: HashMap, + /// Total cost incurred. + pub total_cost: f32, + /// Number of evaluation cycles. + pub cycles: u64, + /// Generation (0 = initial, increments on mutation). + pub generation: u32, + /// Parent kernel ID (for lineage tracking). + pub parent_id: Option, + /// Whether this kernel has been verified via replay. + pub replay_verified: bool, +} + +impl PolicyKernel { + /// Create a new kernel with default knobs. + pub fn new(id: String) -> Self { + Self { + id, + knobs: PolicyKnobs::default_knobs(), + holdout_scores: HashMap::new(), + total_cost: 0.0, + cycles: 0, + generation: 0, + parent_id: None, + replay_verified: false, + } + } + + /// Create a mutated child kernel. + pub fn mutate(&self, child_id: String, rng: &mut impl Rng) -> Self { + Self { + id: child_id, + knobs: self.knobs.mutate(rng, 0.3), + holdout_scores: HashMap::new(), + total_cost: 0.0, + cycles: 0, + generation: self.generation + 1, + parent_id: Some(self.id.clone()), + replay_verified: false, + } + } + + /// Record a holdout score for a domain. + pub fn record_score(&mut self, domain_id: DomainId, score: f32, cost: f32) { + self.holdout_scores.insert(domain_id, score); + self.total_cost += cost; + self.cycles += 1; + } + + /// Fitness: average holdout score across all evaluated domains. + pub fn fitness(&self) -> f32 { + if self.holdout_scores.is_empty() { + return 0.0; + } + let total: f32 = self.holdout_scores.values().sum(); + total / self.holdout_scores.len() as f32 + } + + /// Cost-adjusted fitness: penalizes expensive kernels. + pub fn cost_adjusted_fitness(&self) -> f32 { + let raw = self.fitness(); + let cost_penalty = (self.total_cost / self.cycles.max(1) as f32).min(1.0); + raw * (1.0 - cost_penalty * 0.3) // 30% weight on cost + } +} + +/// Population-based policy search engine. +#[derive(Clone)] +pub struct PopulationSearch { + /// Current population of kernels. + population: Vec, + /// Population size. + pop_size: usize, + /// Best kernel seen so far. + best_kernel: Option, + /// Generation counter. + generation: u32, +} + +impl PopulationSearch { + /// Create a new population search with initial random population. + pub fn new(pop_size: usize) -> Self { + let mut rng = rand::thread_rng(); + let population: Vec = (0..pop_size) + .map(|i| { + let mut kernel = PolicyKernel::new(format!("kernel_g0_{}", i)); + // Random initial knobs + kernel.knobs = PolicyKnobs::default_knobs().mutate(&mut rng, 0.8); + kernel + }) + .collect(); + + Self { + population, + pop_size, + best_kernel: None, + generation: 0, + } + } + + /// Get current population for evaluation. + pub fn population(&self) -> &[PolicyKernel] { + &self.population + } + + /// Get mutable reference to a kernel by index. + pub fn kernel_mut(&mut self, index: usize) -> Option<&mut PolicyKernel> { + self.population.get_mut(index) + } + + /// Evolve to next generation: select top performers, mutate, fill population. + pub fn evolve(&mut self) { + let mut rng = rand::thread_rng(); + self.generation += 1; + + // Sort by cost-adjusted fitness (descending) + self.population + .sort_by(|a, b| { + b.cost_adjusted_fitness() + .partial_cmp(&a.cost_adjusted_fitness()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // Track best + if let Some(best) = self.population.first() { + if self + .best_kernel + .as_ref() + .map_or(true, |b| best.fitness() > b.fitness()) + { + self.best_kernel = Some(best.clone()); + } + } + + // Elite selection: keep top 25% + let elite_count = (self.pop_size / 4).max(1); + let elites: Vec = self.population[..elite_count].to_vec(); + + // Build next generation + let mut next_gen = Vec::with_capacity(self.pop_size); + + // Keep elites + for elite in &elites { + let mut kept = elite.clone(); + kept.id = format!("kernel_g{}_{}", self.generation, next_gen.len()); + kept.holdout_scores.clear(); + kept.total_cost = 0.0; + kept.cycles = 0; + next_gen.push(kept); + } + + // Fill rest with mutations and crossovers + while next_gen.len() < self.pop_size { + let parent_idx = rng.gen_range(0..elites.len()); + let child_id = format!("kernel_g{}_{}", self.generation, next_gen.len()); + + let child = if rng.gen::() < 0.3 && elites.len() > 1 { + // Crossover + let other_idx = (parent_idx + 1 + rng.gen_range(0..elites.len() - 1)) % elites.len(); + let mut child = PolicyKernel::new(child_id); + child.knobs = elites[parent_idx] + .knobs + .crossover(&elites[other_idx].knobs, &mut rng); + child.generation = self.generation; + child.parent_id = Some(elites[parent_idx].id.clone()); + child + } else { + // Mutation + elites[parent_idx].mutate(child_id, &mut rng) + }; + + next_gen.push(child); + } + + self.population = next_gen; + } + + /// Get the best kernel found so far. + pub fn best(&self) -> Option<&PolicyKernel> { + self.best_kernel.as_ref() + } + + /// Current generation number. + pub fn generation(&self) -> u32 { + self.generation + } + + /// Get fitness statistics for the current population. + pub fn stats(&self) -> PopulationStats { + let fitnesses: Vec = self.population.iter().map(|k| k.fitness()).collect(); + let mean = fitnesses.iter().sum::() / fitnesses.len().max(1) as f32; + let max = fitnesses + .iter() + .cloned() + .fold(f32::NEG_INFINITY, f32::max); + let min = fitnesses.iter().cloned().fold(f32::INFINITY, f32::min); + let variance = fitnesses.iter().map(|f| (f - mean).powi(2)).sum::() + / fitnesses.len().max(1) as f32; + + PopulationStats { + generation: self.generation, + pop_size: self.population.len(), + mean_fitness: mean, + max_fitness: max, + min_fitness: min, + fitness_variance: variance, + best_ever_fitness: self.best_kernel.as_ref().map(|k| k.fitness()).unwrap_or(0.0), + } + } +} + +/// Statistics about the current population. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PopulationStats { + pub generation: u32, + pub pop_size: usize, + pub mean_fitness: f32, + pub max_fitness: f32, + pub min_fitness: f32, + pub fitness_variance: f32, + pub best_ever_fitness: f32, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_policy_knobs_default() { + let knobs = PolicyKnobs::default_knobs(); + assert!(!knobs.skip_mode); + assert!(knobs.prepass_enabled); + assert!(knobs.speculation_threshold > 0.0); + } + + #[test] + fn test_policy_knobs_mutate() { + let knobs = PolicyKnobs::default_knobs(); + let mut rng = rand::thread_rng(); + let mutated = knobs.mutate(&mut rng, 1.0); // high mutation rate + // At least something should differ (probabilistically) + // Can't guarantee due to randomness, but bounds should hold + assert!(mutated.speculation_threshold >= 0.01 && mutated.speculation_threshold <= 0.5); + assert!(mutated.exploration_budget >= 0.01 && mutated.exploration_budget <= 0.5); + } + + #[test] + fn test_policy_kernel_fitness() { + let mut kernel = PolicyKernel::new("test".into()); + assert_eq!(kernel.fitness(), 0.0); + + kernel.record_score(DomainId("d1".into()), 0.8, 1.0); + kernel.record_score(DomainId("d2".into()), 0.6, 1.0); + assert!((kernel.fitness() - 0.7).abs() < 1e-6); + } + + #[test] + fn test_population_search_evolve() { + let mut search = PopulationSearch::new(8); + assert_eq!(search.population().len(), 8); + + // Simulate evaluation + for i in 0..8 { + if let Some(kernel) = search.kernel_mut(i) { + let score = 0.3 + (i as f32) * 0.08; + kernel.record_score(DomainId("test".into()), score, 1.0); + } + } + + search.evolve(); + assert_eq!(search.population().len(), 8); + assert_eq!(search.generation(), 1); + assert!(search.best().is_some()); + } + + #[test] + fn test_population_stats() { + let mut search = PopulationSearch::new(4); + + for i in 0..4 { + if let Some(kernel) = search.kernel_mut(i) { + kernel.record_score(DomainId("test".into()), (i as f32) * 0.25, 1.0); + } + } + + let stats = search.stats(); + assert_eq!(stats.pop_size, 4); + assert!(stats.max_fitness >= stats.min_fitness); + assert!(stats.mean_fitness >= stats.min_fitness); + assert!(stats.mean_fitness <= stats.max_fitness); + } + + #[test] + fn test_crossover() { + let a = PolicyKnobs { + skip_mode: true, + prepass_enabled: false, + speculation_threshold: 0.1, + exploration_budget: 0.1, + max_retries: 1, + batch_size: 4, + cost_decay: 0.8, + confidence_floor: 0.5, + }; + let b = PolicyKnobs { + skip_mode: false, + prepass_enabled: true, + speculation_threshold: 0.4, + exploration_budget: 0.4, + max_retries: 4, + batch_size: 16, + cost_decay: 0.95, + confidence_floor: 0.9, + }; + + let mut rng = rand::thread_rng(); + let child = a.crossover(&b, &mut rng); + + // Child values should come from one parent or the other + assert!(child.max_retries == 1 || child.max_retries == 4); + assert!(child.batch_size == 4 || child.batch_size == 16); + } +} diff --git a/crates/ruvector-domain-expansion/src/rust_synthesis.rs b/crates/ruvector-domain-expansion/src/rust_synthesis.rs new file mode 100644 index 00000000..6ddb74a5 --- /dev/null +++ b/crates/ruvector-domain-expansion/src/rust_synthesis.rs @@ -0,0 +1,601 @@ +//! Rust Program Synthesis Domain +//! +//! Generates tasks that require synthesizing Rust programs from specifications. +//! Task types include: +//! +//! - **Transform**: Apply a function to data (map, filter, fold) +//! - **DataStructure**: Implement a data structure with specific operations +//! - **Algorithm**: Implement a named algorithm (sorting, searching, graph) +//! - **TypeLevel**: Express constraints via Rust's type system +//! - **Concurrency**: Safe concurrent data access patterns +//! +//! Solutions are evaluated on correctness (do test cases pass?), +//! efficiency (complexity class), and elegance (idiomatic Rust patterns). + +use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task}; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +/// Embedding dimension for Rust synthesis domain. +const EMBEDDING_DIM: usize = 64; + +/// Categories of Rust synthesis tasks. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum RustTaskCategory { + /// Transform data: map, filter, fold, scan. + Transform, + /// Implement a data structure with trait impls. + DataStructure, + /// Implement a named algorithm. + Algorithm, + /// Type-level programming: generics, trait bounds, associated types. + TypeLevel, + /// Concurrent programming: Arc, Mutex, channels, atomics. + Concurrency, +} + +/// Specification for a Rust synthesis task. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RustTaskSpec { + /// Task category. + pub category: RustTaskCategory, + /// Function signature that must be implemented. + pub signature: String, + /// Natural language description of the required behavior. + pub description: String, + /// Test cases as (input_json, expected_output_json) pairs. + pub test_cases: Vec<(String, String)>, + /// Required traits the solution must implement. + pub required_traits: Vec, + /// Banned patterns (e.g., "unsafe", "unwrap"). + pub banned_patterns: Vec, + /// Expected complexity class (e.g., "O(n log n)"). + pub expected_complexity: Option, +} + +/// Rust program synthesis domain. +pub struct RustSynthesisDomain { + id: DomainId, +} + +impl RustSynthesisDomain { + /// Create a new Rust synthesis domain. + pub fn new() -> Self { + Self { + id: DomainId("rust_synthesis".to_string()), + } + } + + /// Generate a transform task at the given difficulty. + fn gen_transform(&self, difficulty: f32, rng: &mut impl Rng) -> RustTaskSpec { + let (signature, description, tests, complexity) = if difficulty < 0.3 { + // Easy: simple map + let ops = ["double", "negate", "abs", "square"]; + let op = ops[rng.gen_range(0..ops.len())]; + ( + format!("fn {}(values: &[i64]) -> Vec", op), + format!("Apply {} to each element in the slice.", op), + match op { + "double" => vec![ + ("[1, 2, 3]".into(), "[2, 4, 6]".into()), + ("[-1, 0, 5]".into(), "[-2, 0, 10]".into()), + ], + "negate" => vec![ + ("[1, -2, 3]".into(), "[-1, 2, -3]".into()), + ("[0]".into(), "[0]".into()), + ], + "abs" => vec![ + ("[-1, 2, -3]".into(), "[1, 2, 3]".into()), + ("[0, -0]".into(), "[0, 0]".into()), + ], + _ => vec![ + ("[2, 3, 4]".into(), "[4, 9, 16]".into()), + ("[0, -1]".into(), "[0, 1]".into()), + ], + }, + "O(n)", + ) + } else if difficulty < 0.7 { + // Medium: filter + fold combos + ( + "fn sum_positives(values: &[i64]) -> i64".into(), + "Sum all positive values in the slice.".into(), + vec![ + ("[1, -2, 3, -4, 5]".into(), "9".into()), + ("[-1, -2, -3]".into(), "0".into()), + ("[]".into(), "0".into()), + ], + "O(n)", + ) + } else { + // Hard: sliding window / scan + ( + "fn max_subarray_sum(values: &[i64]) -> i64".into(), + "Find the maximum sum contiguous subarray (Kadane's algorithm).".into(), + vec![ + ("[-2, 1, -3, 4, -1, 2, 1, -5, 4]".into(), "6".into()), + ("[-1, -2, -3]".into(), "-1".into()), + ("[5]".into(), "5".into()), + ], + "O(n)", + ) + }; + + RustTaskSpec { + category: RustTaskCategory::Transform, + signature, + description, + test_cases: tests, + required_traits: Vec::new(), + banned_patterns: vec!["unsafe".into()], + expected_complexity: Some(complexity.into()), + } + } + + /// Generate a data structure task. + fn gen_data_structure(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec { + if difficulty < 0.4 { + RustTaskSpec { + category: RustTaskCategory::DataStructure, + signature: "struct Stack".into(), + description: "Implement a generic stack with push, pop, peek, is_empty, len." + .into(), + test_cases: vec![ + ("push(1); push(2); pop()".into(), "Some(2)".into()), + ("is_empty()".into(), "true".into()), + ("push(1); len()".into(), "1".into()), + ], + required_traits: vec!["Default".into()], + banned_patterns: vec!["unsafe".into()], + expected_complexity: Some("O(1) per operation".into()), + } + } else if difficulty < 0.7 { + RustTaskSpec { + category: RustTaskCategory::DataStructure, + signature: "struct MinHeap".into(), + description: "Implement a binary min-heap with insert, extract_min, peek_min." + .into(), + test_cases: vec![ + ( + "insert(3); insert(1); insert(2); extract_min()".into(), + "Some(1)".into(), + ), + ("peek_min() on empty".into(), "None".into()), + ], + required_traits: vec!["Default".into()], + banned_patterns: vec!["unsafe".into(), "BinaryHeap".into()], + expected_complexity: Some("O(log n) insert/extract".into()), + } + } else { + RustTaskSpec { + category: RustTaskCategory::DataStructure, + signature: "struct LRUCache".into(), + description: + "Implement an LRU cache with get, put, and capacity eviction.".into(), + test_cases: vec![ + ( + "cap=2; put(1,'a'); put(2,'b'); get(1); put(3,'c'); get(2)".into(), + "None".into(), + ), + ("cap=1; put(1,'a'); put(2,'b'); get(1)".into(), "None".into()), + ], + required_traits: Vec::new(), + banned_patterns: vec!["unsafe".into()], + expected_complexity: Some("O(1) get/put".into()), + } + } + } + + /// Generate an algorithm task. + fn gen_algorithm(&self, difficulty: f32, _rng: &mut impl Rng) -> RustTaskSpec { + if difficulty < 0.4 { + RustTaskSpec { + category: RustTaskCategory::Algorithm, + signature: "fn binary_search(sorted: &[i64], target: i64) -> Option".into(), + description: "Implement binary search on a sorted slice.".into(), + test_cases: vec![ + ("[1,3,5,7,9], 5".into(), "Some(2)".into()), + ("[1,3,5,7,9], 4".into(), "None".into()), + ("[], 1".into(), "None".into()), + ], + required_traits: Vec::new(), + banned_patterns: vec!["unsafe".into()], + expected_complexity: Some("O(log n)".into()), + } + } else if difficulty < 0.7 { + RustTaskSpec { + category: RustTaskCategory::Algorithm, + signature: "fn merge_sort(values: &mut [i64])".into(), + description: "Implement stable merge sort in-place.".into(), + test_cases: vec![ + ("[3,1,4,1,5,9,2,6]".into(), "[1,1,2,3,4,5,6,9]".into()), + ("[1]".into(), "[1]".into()), + ("[]".into(), "[]".into()), + ], + required_traits: Vec::new(), + banned_patterns: vec!["unsafe".into(), ".sort".into()], + expected_complexity: Some("O(n log n)".into()), + } + } else { + RustTaskSpec { + category: RustTaskCategory::Algorithm, + signature: "fn shortest_path(adj: &[Vec<(usize, u64)>], src: usize, dst: usize) -> Option".into(), + description: "Implement Dijkstra's shortest path on a weighted directed graph.".into(), + test_cases: vec![ + ("3 nodes, 0->1:2, 1->2:3, 0->2:10; src=0, dst=2".into(), "Some(5)".into()), + ("2 nodes, no edges; src=0, dst=1".into(), "None".into()), + ], + required_traits: Vec::new(), + banned_patterns: vec!["unsafe".into()], + expected_complexity: Some("O((V + E) log V)".into()), + } + } + } + + /// Extract structural features from a Rust solution for embedding. + fn extract_features(&self, solution: &Solution) -> Vec { + let code = &solution.content; + let mut features = vec![0.0f32; EMBEDDING_DIM]; + + // Feature 0-7: Control flow complexity + features[0] = code.matches("if ").count() as f32 / 10.0; + features[1] = code.matches("for ").count() as f32 / 5.0; + features[2] = code.matches("while ").count() as f32 / 5.0; + features[3] = code.matches("match ").count() as f32 / 5.0; + features[4] = code.matches("loop ").count() as f32 / 3.0; + features[5] = code.matches("return ").count() as f32 / 5.0; + features[6] = code.matches("break").count() as f32 / 3.0; + features[7] = code.matches("continue").count() as f32 / 3.0; + + // Feature 8-15: Type system usage + features[8] = code.matches("impl ").count() as f32 / 5.0; + features[9] = code.matches("trait ").count() as f32 / 3.0; + features[10] = code.matches("struct ").count() as f32 / 3.0; + features[11] = code.matches("enum ").count() as f32 / 3.0; + features[12] = code.matches("where ").count() as f32 / 3.0; + features[13] = code.matches("dyn ").count() as f32 / 3.0; + features[14] = code.matches("Box<").count() as f32 / 3.0; + features[15] = code.matches("Rc<").count() as f32 / 3.0; + + // Feature 16-23: Functional patterns + features[16] = code.matches(".map(").count() as f32 / 5.0; + features[17] = code.matches(".filter(").count() as f32 / 5.0; + features[18] = code.matches(".fold(").count() as f32 / 3.0; + features[19] = code.matches(".collect()").count() as f32 / 3.0; + features[20] = code.matches(".iter()").count() as f32 / 5.0; + features[21] = code.matches("|").count() as f32 / 10.0; // closures + features[22] = code.matches("Some(").count() as f32 / 5.0; + features[23] = code.matches("None").count() as f32 / 5.0; + + // Feature 24-31: Memory/ownership patterns + features[24] = code.matches("&mut ").count() as f32 / 5.0; + features[25] = code.matches("&self").count() as f32 / 5.0; + features[26] = code.matches("mut ").count() as f32 / 10.0; + features[27] = code.matches(".clone()").count() as f32 / 5.0; + features[28] = code.matches("Vec<").count() as f32 / 5.0; + features[29] = code.matches("HashMap").count() as f32 / 3.0; + features[30] = code.matches("String").count() as f32 / 5.0; + features[31] = code.matches("Result<").count() as f32 / 3.0; + + // Feature 32-39: Concurrency patterns + features[32] = code.matches("Arc<").count() as f32 / 3.0; + features[33] = code.matches("Mutex<").count() as f32 / 3.0; + features[34] = code.matches("RwLock").count() as f32 / 3.0; + features[35] = code.matches("async ").count() as f32 / 3.0; + features[36] = code.matches("await").count() as f32 / 5.0; + features[37] = code.matches("spawn").count() as f32 / 3.0; + features[38] = code.matches("channel").count() as f32 / 3.0; + features[39] = code.matches("Atomic").count() as f32 / 3.0; + + // Feature 40-47: Code structure metrics + let lines: Vec<&str> = code.lines().collect(); + features[40] = (lines.len() as f32) / 100.0; + features[41] = lines.iter().filter(|l| l.trim().is_empty()).count() as f32 + / (lines.len().max(1) as f32); + features[42] = code.matches("fn ").count() as f32 / 10.0; + features[43] = code.matches("pub ").count() as f32 / 10.0; + features[44] = code.matches("mod ").count() as f32 / 5.0; + features[45] = code.matches("use ").count() as f32 / 10.0; + features[46] = code.matches("#[").count() as f32 / 5.0; // attributes + features[47] = code.matches("///").count() as f32 / 10.0; // doc comments + + // Feature 48-55: Error handling patterns + features[48] = code.matches("unwrap()").count() as f32 / 5.0; + features[49] = code.matches("expect(").count() as f32 / 5.0; + features[50] = code.matches("?;").count() as f32 / 5.0; // error propagation + features[51] = code.matches("Err(").count() as f32 / 5.0; + features[52] = code.matches("Ok(").count() as f32 / 5.0; + features[53] = code.matches("panic!").count() as f32 / 3.0; + features[54] = code.matches("assert").count() as f32 / 5.0; + features[55] = code.matches("debug_assert").count() as f32 / 3.0; + + // Feature 56-63: Algorithm indicators + features[56] = code.matches("sort").count() as f32 / 3.0; + features[57] = code.matches("binary_search").count() as f32 / 2.0; + features[58] = code.matches("push").count() as f32 / 5.0; + features[59] = code.matches("pop").count() as f32 / 5.0; + features[60] = code.matches("swap").count() as f32 / 5.0; + features[61] = code.matches("len()").count() as f32 / 5.0; + features[62] = code.matches("is_empty").count() as f32 / 3.0; + features[63] = code.matches("contains").count() as f32 / 3.0; + + // Normalize to unit length + let norm: f32 = features.iter().map(|x| x * x).sum::().sqrt(); + if norm > 1e-10 { + for f in &mut features { + *f /= norm; + } + } + + features + } + + /// Score a Rust solution based on pattern matching heuristics. + fn score_solution(&self, spec: &RustTaskSpec, solution: &Solution) -> Evaluation { + let code = &solution.content; + let mut correctness = 0.0f32; + let mut efficiency = 0.5f32; + let mut elegance = 0.5f32; + let mut notes = Vec::new(); + + // Check for banned patterns + let mut banned_found = false; + for pattern in &spec.banned_patterns { + if code.contains(pattern.as_str()) { + notes.push(format!("Banned pattern found: {}", pattern)); + banned_found = true; + } + } + + if banned_found { + elegance *= 0.5; + } + + // Check that the solution contains the expected signature + let sig_name = spec + .signature + .split('(') + .next() + .unwrap_or("") + .split_whitespace() + .last() + .unwrap_or(""); + + if code.contains(sig_name) { + correctness += 0.3; + } else { + notes.push(format!("Missing expected identifier: {}", sig_name)); + } + + // Check for fn definition + if code.contains("fn ") { + correctness += 0.2; + } + + // Check for test case coverage hints + let test_coverage = spec + .test_cases + .iter() + .filter(|(input, _)| { + // Heuristic: solution likely handles the input pattern + let key_tokens: Vec<&str> = input.split(|c: char| !c.is_alphanumeric()).collect(); + key_tokens.iter().any(|t| !t.is_empty() && code.contains(t)) + }) + .count() as f32 + / spec.test_cases.len().max(1) as f32; + correctness += test_coverage * 0.5; + correctness = correctness.clamp(0.0, 1.0); + + // Efficiency: penalize obviously quadratic patterns + let nested_loops = code.matches("for ").count() > 1 && code.matches("for ").count() > 2; + if nested_loops { + if let Some(ref expected) = spec.expected_complexity { + if expected.contains("O(n)") || expected.contains("O(log") { + efficiency *= 0.5; + notes.push("Possible O(n^2) when O(n) or O(log n) expected".into()); + } + } + } + + // Elegance: favor idiomatic Rust + let iterator_usage = code.matches(".iter()").count() + + code.matches(".map(").count() + + code.matches(".filter(").count() + + code.matches(".fold(").count(); + if iterator_usage > 0 { + elegance += 0.2; + } + + // Penalize excessive unwrap + let unwrap_count = code.matches("unwrap()").count(); + if unwrap_count > 3 { + elegance -= 0.2; + notes.push("Excessive unwrap() usage".into()); + } + + // Proper error handling bonus + if code.contains("Result<") || code.contains("?;") { + elegance += 0.1; + } + + elegance = elegance.clamp(0.0, 1.0); + + // Constraint results + let constraint_results = spec + .banned_patterns + .iter() + .map(|p| !code.contains(p.as_str())) + .collect(); + + let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance; + + Evaluation { + score: score.clamp(0.0, 1.0), + correctness, + efficiency, + elegance, + constraint_results, + notes, + } + } +} + +impl Default for RustSynthesisDomain { + fn default() -> Self { + Self::new() + } +} + +impl Domain for RustSynthesisDomain { + fn id(&self) -> &DomainId { + &self.id + } + + fn name(&self) -> &str { + "Rust Program Synthesis" + } + + fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec { + let mut rng = rand::thread_rng(); + let difficulty = difficulty.clamp(0.0, 1.0); + + (0..count) + .map(|i| { + let category_roll: f32 = rng.gen(); + let spec = if category_roll < 0.4 { + self.gen_transform(difficulty, &mut rng) + } else if category_roll < 0.7 { + self.gen_data_structure(difficulty, &mut rng) + } else { + self.gen_algorithm(difficulty, &mut rng) + }; + + Task { + id: format!("rust_synth_{}_d{:.0}", i, difficulty * 100.0), + domain_id: self.id.clone(), + difficulty, + spec: serde_json::to_value(&spec).unwrap_or_default(), + constraints: spec.banned_patterns.clone(), + } + }) + .collect() + } + + fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation { + let spec: RustTaskSpec = match serde_json::from_value(task.spec.clone()) { + Ok(s) => s, + Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]), + }; + self.score_solution(&spec, solution) + } + + fn embed(&self, solution: &Solution) -> DomainEmbedding { + let features = self.extract_features(solution); + DomainEmbedding::new(features, self.id.clone()) + } + + fn embedding_dim(&self) -> usize { + EMBEDDING_DIM + } + + fn reference_solution(&self, task: &Task) -> Option { + let spec: RustTaskSpec = serde_json::from_value(task.spec.clone()).ok()?; + + let content = match spec.category { + RustTaskCategory::Transform => { + if spec.signature.contains("sum_positives") { + "fn sum_positives(values: &[i64]) -> i64 {\n values.iter().filter(|&&x| x > 0).sum()\n}".to_string() + } else if spec.signature.contains("max_subarray_sum") { + "fn max_subarray_sum(values: &[i64]) -> i64 {\n let mut max_so_far = values[0];\n let mut max_ending = values[0];\n for &v in &values[1..] {\n max_ending = v.max(max_ending + v);\n max_so_far = max_so_far.max(max_ending);\n }\n max_so_far\n}".to_string() + } else { + format!( + "{} {{\n values.iter().map(|&x| x /* TODO */).collect()\n}}", + spec.signature + ) + } + } + _ => return None, + }; + + Some(Solution { + task_id: task.id.clone(), + content, + data: serde_json::Value::Null, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_tasks() { + let domain = RustSynthesisDomain::new(); + let tasks = domain.generate_tasks(5, 0.5); + assert_eq!(tasks.len(), 5); + for task in &tasks { + assert_eq!(task.domain_id, domain.id); + assert!((task.difficulty - 0.5).abs() < 1e-6); + } + } + + #[test] + fn test_evaluate_good_solution() { + let domain = RustSynthesisDomain::new(); + let tasks = domain.generate_tasks(1, 0.0); + let task = &tasks[0]; + + let solution = Solution { + task_id: task.id.clone(), + content: "fn double(values: &[i64]) -> Vec {\n values.iter().map(|&x| x * 2).collect()\n}".to_string(), + data: serde_json::Value::Null, + }; + + let eval = domain.evaluate(task, &solution); + assert!(eval.score > 0.0); + } + + #[test] + fn test_embed_produces_correct_dim() { + let domain = RustSynthesisDomain::new(); + let solution = Solution { + task_id: "test".into(), + content: "fn foo() { let x = 1; }".into(), + data: serde_json::Value::Null, + }; + let embedding = domain.embed(&solution); + assert_eq!(embedding.dim, EMBEDDING_DIM); + assert_eq!(embedding.vector.len(), EMBEDDING_DIM); + } + + #[test] + fn test_embedding_normalized() { + let domain = RustSynthesisDomain::new(); + let solution = Solution { + task_id: "test".into(), + content: "fn foo() { for i in 0..10 { if i > 5 { println!(\"{}\", i); } } }".into(), + data: serde_json::Value::Null, + }; + let embedding = domain.embed(&solution); + let norm: f32 = embedding.vector.iter().map(|x| x * x).sum::().sqrt(); + assert!((norm - 1.0).abs() < 1e-4); + } + + #[test] + fn test_difficulty_range() { + let domain = RustSynthesisDomain::new(); + // Easy tasks + let easy = domain.generate_tasks(3, 0.1); + for t in &easy { + let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap(); + assert!(!spec.signature.is_empty()); + } + // Hard tasks + let hard = domain.generate_tasks(3, 0.9); + for t in &hard { + let spec: RustTaskSpec = serde_json::from_value(t.spec.clone()).unwrap(); + assert!(!spec.signature.is_empty()); + } + } +} diff --git a/crates/ruvector-domain-expansion/src/tool_orchestration.rs b/crates/ruvector-domain-expansion/src/tool_orchestration.rs new file mode 100644 index 00000000..8064d3d9 --- /dev/null +++ b/crates/ruvector-domain-expansion/src/tool_orchestration.rs @@ -0,0 +1,711 @@ +//! Tool Orchestration Problems Domain +//! +//! Generates tasks requiring coordinating multiple tools/agents to achieve goals. +//! Task types include: +//! +//! - **PipelineConstruction**: Build a data processing pipeline from available tools +//! - **ErrorRecovery**: Handle failures in multi-step tool chains +//! - **ParallelCoordination**: Execute independent tool calls concurrently +//! - **ResourceNegotiation**: Manage shared resources across tool invocations +//! - **AdaptiveRouting**: Select tools dynamically based on intermediate results +//! +//! Cross-domain transfer is strongest here: planning decomposes goals, +//! Rust synthesis provides execution patterns, and orchestration combines them. + +use crate::domain::{Domain, DomainEmbedding, DomainId, Evaluation, Solution, Task}; +use rand::Rng; +use serde::{Deserialize, Serialize}; + +const EMBEDDING_DIM: usize = 64; + +/// Categories of tool orchestration tasks. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum OrchestrationCategory { + /// Build a pipeline: chain tools to transform input to desired output. + PipelineConstruction, + /// Handle failure: detect errors and apply fallback strategies. + ErrorRecovery, + /// Coordinate parallel: dispatch independent calls and merge results. + ParallelCoordination, + /// Negotiate resources: manage rate limits, quotas, shared state. + ResourceNegotiation, + /// Adaptive routing: choose tool based on intermediate result properties. + AdaptiveRouting, +} + +/// A tool available in the orchestration environment. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolSpec { + pub name: String, + pub description: String, + /// Input type signature (e.g., "text", "json", "binary"). + pub input_type: String, + /// Output type signature. + pub output_type: String, + /// Average latency in milliseconds. + pub latency_ms: u32, + /// Failure rate [0.0, 1.0]. + pub failure_rate: f32, + /// Cost per invocation. + pub cost: f32, + /// Rate limit (max calls per minute), 0 = unlimited. + pub rate_limit: u32, +} + +/// An orchestration task specification. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OrchestrationTaskSpec { + pub category: OrchestrationCategory, + pub description: String, + /// Available tools in the environment. + pub available_tools: Vec, + /// Input to the pipeline. + pub input: serde_json::Value, + /// Expected output type/shape. + pub expected_output_type: String, + /// Maximum total latency budget (ms). + pub latency_budget_ms: u32, + /// Maximum total cost budget. + pub cost_budget: f32, + /// Required reliability (min success rate). + pub min_reliability: f32, + /// Error scenarios that must be handled. + pub error_scenarios: Vec, +} + +/// A tool call in an orchestration solution. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub tool_name: String, + /// Input to this tool call (ref to previous output or literal). + pub input_ref: String, + /// Whether this can run in parallel with other calls. + pub parallel_group: Option, + /// Fallback tool if this one fails. + pub fallback: Option, + /// Retry count on failure. + pub retries: u32, +} + +/// A parsed orchestration plan. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OrchestrationPlan { + pub calls: Vec, + /// Error handling strategy description. + pub error_strategy: String, +} + +/// Tool orchestration domain. +pub struct ToolOrchestrationDomain { + id: DomainId, +} + +impl ToolOrchestrationDomain { + pub fn new() -> Self { + Self { + id: DomainId("tool_orchestration".to_string()), + } + } + + fn base_tools() -> Vec { + vec![ + ToolSpec { + name: "text_extract".into(), + description: "Extract text from documents".into(), + input_type: "binary".into(), + output_type: "text".into(), + latency_ms: 50, + failure_rate: 0.02, + cost: 0.001, + rate_limit: 100, + }, + ToolSpec { + name: "text_embed".into(), + description: "Generate embeddings from text".into(), + input_type: "text".into(), + output_type: "vector".into(), + latency_ms: 30, + failure_rate: 0.01, + cost: 0.002, + rate_limit: 200, + }, + ToolSpec { + name: "vector_search".into(), + description: "Search vector index for similar items".into(), + input_type: "vector".into(), + output_type: "json".into(), + latency_ms: 10, + failure_rate: 0.005, + cost: 0.0005, + rate_limit: 500, + }, + ToolSpec { + name: "llm_generate".into(), + description: "Generate text using language model".into(), + input_type: "text".into(), + output_type: "text".into(), + latency_ms: 2000, + failure_rate: 0.05, + cost: 0.01, + rate_limit: 30, + }, + ToolSpec { + name: "json_transform".into(), + description: "Apply JQ-like transformations to JSON".into(), + input_type: "json".into(), + output_type: "json".into(), + latency_ms: 5, + failure_rate: 0.001, + cost: 0.0001, + rate_limit: 0, + }, + ToolSpec { + name: "code_execute".into(), + description: "Execute code in sandboxed environment".into(), + input_type: "text".into(), + output_type: "json".into(), + latency_ms: 500, + failure_rate: 0.1, + cost: 0.005, + rate_limit: 20, + }, + ToolSpec { + name: "http_fetch".into(), + description: "Fetch data from external HTTP endpoint".into(), + input_type: "text".into(), + output_type: "json".into(), + latency_ms: 300, + failure_rate: 0.15, + cost: 0.0, + rate_limit: 60, + }, + ToolSpec { + name: "cache_lookup".into(), + description: "Check local cache for previously computed results".into(), + input_type: "text".into(), + output_type: "json".into(), + latency_ms: 1, + failure_rate: 0.0, + cost: 0.0, + rate_limit: 0, + }, + ToolSpec { + name: "validator".into(), + description: "Validate output against schema".into(), + input_type: "json".into(), + output_type: "json".into(), + latency_ms: 2, + failure_rate: 0.0, + cost: 0.0, + rate_limit: 0, + }, + ToolSpec { + name: "aggregator".into(), + description: "Merge multiple results into one".into(), + input_type: "json".into(), + output_type: "json".into(), + latency_ms: 5, + failure_rate: 0.0, + cost: 0.0001, + rate_limit: 0, + }, + ] + } + + fn gen_pipeline(&self, difficulty: f32) -> OrchestrationTaskSpec { + let tools = Self::base_tools(); + let num_tools = if difficulty < 0.3 { + 3 + } else if difficulty < 0.7 { + 6 + } else { + 10 + }; + + OrchestrationTaskSpec { + category: OrchestrationCategory::PipelineConstruction, + description: format!( + "Build a RAG pipeline using {} tools: extract, embed, search, generate.", + num_tools + ), + available_tools: tools[..num_tools.min(tools.len())].to_vec(), + input: serde_json::json!({"type": "binary", "format": "pdf"}), + expected_output_type: "text".into(), + latency_budget_ms: if difficulty < 0.5 { 5000 } else { 2000 }, + cost_budget: if difficulty < 0.5 { 0.1 } else { 0.02 }, + min_reliability: if difficulty < 0.5 { 0.9 } else { 0.99 }, + error_scenarios: Vec::new(), + } + } + + fn gen_error_recovery(&self, difficulty: f32) -> OrchestrationTaskSpec { + let tools = Self::base_tools(); + let error_scenarios = if difficulty < 0.3 { + vec!["timeout on llm_generate".into()] + } else if difficulty < 0.7 { + vec![ + "timeout on llm_generate".into(), + "http_fetch returns 429".into(), + "code_execute sandbox OOM".into(), + ] + } else { + vec![ + "timeout on llm_generate".into(), + "http_fetch returns 429".into(), + "code_execute sandbox OOM".into(), + "vector_search index corruption".into(), + "cascading failure: embed + search both down".into(), + ] + }; + + OrchestrationTaskSpec { + category: OrchestrationCategory::ErrorRecovery, + description: format!( + "Handle {} error scenarios in a multi-tool pipeline with graceful degradation.", + error_scenarios.len() + ), + available_tools: tools, + input: serde_json::json!({"type": "text", "content": "query"}), + expected_output_type: "json".into(), + latency_budget_ms: 10000, + cost_budget: 0.1, + min_reliability: 0.95, + error_scenarios, + } + } + + fn gen_parallel_coordination(&self, difficulty: f32) -> OrchestrationTaskSpec { + let tools = Self::base_tools(); + let parallelism = if difficulty < 0.3 { 2 } else if difficulty < 0.7 { 4 } else { 8 }; + + OrchestrationTaskSpec { + category: OrchestrationCategory::ParallelCoordination, + description: format!( + "Execute {} independent tool chains in parallel, merge results within latency budget.", + parallelism + ), + available_tools: tools, + input: serde_json::json!({"queries": (0..parallelism).map(|i| format!("query_{}", i)).collect::>()}), + expected_output_type: "json".into(), + latency_budget_ms: if difficulty < 0.5 { 3000 } else { 1000 }, + cost_budget: 0.05 * parallelism as f32, + min_reliability: 0.95, + error_scenarios: Vec::new(), + } + } + + fn extract_features(&self, solution: &Solution) -> Vec { + let content = &solution.content; + let mut features = vec![0.0f32; EMBEDDING_DIM]; + + let plan: OrchestrationPlan = serde_json::from_str(&solution.data.to_string()) + .or_else(|_| serde_json::from_str(content)) + .unwrap_or(OrchestrationPlan { + calls: Vec::new(), + error_strategy: String::new(), + }); + + // Feature 0-7: Plan structure + features[0] = plan.calls.len() as f32 / 20.0; + let unique_tools: std::collections::HashSet<&str> = + plan.calls.iter().map(|c| c.tool_name.as_str()).collect(); + features[1] = unique_tools.len() as f32 / 10.0; + // Parallelism ratio + let parallel_calls = plan.calls.iter().filter(|c| c.parallel_group.is_some()).count(); + features[2] = parallel_calls as f32 / plan.calls.len().max(1) as f32; + // Fallback coverage + let fallback_calls = plan.calls.iter().filter(|c| c.fallback.is_some()).count(); + features[3] = fallback_calls as f32 / plan.calls.len().max(1) as f32; + // Average retries + let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum(); + features[4] = total_retries as f32 / plan.calls.len().max(1) as f32 / 5.0; + + // Feature 8-15: Tool type usage + let tool_names = [ + "extract", "embed", "search", "generate", "transform", + "execute", "fetch", "cache", + ]; + for (i, name) in tool_names.iter().enumerate() { + features[8 + i] = plan + .calls + .iter() + .filter(|c| c.tool_name.contains(name)) + .count() as f32 + / plan.calls.len().max(1) as f32; + } + + // Feature 16-23: Text pattern features + features[16] = content.matches("pipeline").count() as f32 / 3.0; + features[17] = content.matches("parallel").count() as f32 / 5.0; + features[18] = content.matches("fallback").count() as f32 / 5.0; + features[19] = content.matches("retry").count() as f32 / 5.0; + features[20] = content.matches("cache").count() as f32 / 5.0; + features[21] = content.matches("timeout").count() as f32 / 3.0; + features[22] = content.matches("merge").count() as f32 / 3.0; + features[23] = content.matches("validate").count() as f32 / 3.0; + + // Feature 32-39: Error handling patterns + features[32] = content.matches("error").count() as f32 / 5.0; + features[33] = content.matches("recover").count() as f32 / 3.0; + features[34] = content.matches("degrade").count() as f32 / 3.0; + features[35] = content.matches("circuit_break").count() as f32 / 2.0; + features[36] = content.matches("rate_limit").count() as f32 / 3.0; + features[37] = content.matches("backoff").count() as f32 / 3.0; + features[38] = content.matches("health_check").count() as f32 / 2.0; + features[39] = content.matches("monitor").count() as f32 / 3.0; + + // Feature 48-55: Coordination patterns + features[48] = content.matches("scatter").count() as f32 / 2.0; + features[49] = content.matches("gather").count() as f32 / 2.0; + features[50] = content.matches("fan_out").count() as f32 / 2.0; + features[51] = content.matches("aggregate").count() as f32 / 3.0; + features[52] = content.matches("route").count() as f32 / 3.0; + features[53] = content.matches("dispatch").count() as f32 / 3.0; + features[54] = content.matches("await").count() as f32 / 5.0; + features[55] = content.matches("join").count() as f32 / 3.0; + + // Normalize + let norm: f32 = features.iter().map(|x| x * x).sum::().sqrt(); + if norm > 1e-10 { + for f in &mut features { + *f /= norm; + } + } + + features + } + + fn score_orchestration( + &self, + spec: &OrchestrationTaskSpec, + solution: &Solution, + ) -> Evaluation { + let content = &solution.content; + let mut correctness = 0.0f32; + let mut efficiency = 0.5f32; + let mut elegance = 0.5f32; + let mut notes = Vec::new(); + + let plan: Option = serde_json::from_str(&solution.data.to_string()) + .ok() + .or_else(|| serde_json::from_str(content).ok()); + + let plan = match plan { + Some(p) => p, + None => { + let has_tools = spec + .available_tools + .iter() + .any(|t| content.contains(&t.name)); + if has_tools { + correctness = 0.2; + } + return Evaluation { + score: correctness * 0.6, + correctness, + efficiency: 0.0, + elegance: 0.0, + constraint_results: Vec::new(), + notes: vec!["Could not parse orchestration plan".into()], + }; + } + }; + + if plan.calls.is_empty() { + return Evaluation::zero(vec!["Empty orchestration plan".into()]); + } + + // Correctness: type chain validity + let mut type_errors = 0; + for window in plan.calls.windows(2) { + let output_tool = spec + .available_tools + .iter() + .find(|t| t.name == window[0].tool_name); + let input_tool = spec + .available_tools + .iter() + .find(|t| t.name == window[1].tool_name); + + if let (Some(out_t), Some(in_t)) = (output_tool, input_tool) { + if window[1].parallel_group.is_none() && out_t.output_type != in_t.input_type { + type_errors += 1; + notes.push(format!( + "Type mismatch: {} outputs {} but {} expects {}", + out_t.name, out_t.output_type, in_t.name, in_t.input_type + )); + } + } + } + let chain_len = (plan.calls.len() - 1).max(1); + correctness = 1.0 - (type_errors as f32 / chain_len as f32); + + // Tool coverage: do we use tools that produce the expected output? + let produces_output = plan.calls.iter().any(|c| { + spec.available_tools + .iter() + .any(|t| t.name == c.tool_name && t.output_type == spec.expected_output_type) + }); + if !produces_output { + correctness *= 0.5; + notes.push("No tool produces the expected output type".into()); + } + + // Error handling coverage + if !spec.error_scenarios.is_empty() { + let handled = spec + .error_scenarios + .iter() + .filter(|scenario| { + plan.calls.iter().any(|c| c.fallback.is_some() || c.retries > 0) + || plan.error_strategy.contains(&scenario.as_str()[..scenario.len().min(10)]) + }) + .count() as f32 + / spec.error_scenarios.len() as f32; + correctness = correctness * 0.7 + handled * 0.3; + } + + // Efficiency: estimated latency and cost + let est_latency: u32 = { + let mut groups: std::collections::HashMap = std::collections::HashMap::new(); + let mut sequential_latency = 0u32; + for call in &plan.calls { + let tool_latency = spec + .available_tools + .iter() + .find(|t| t.name == call.tool_name) + .map(|t| t.latency_ms) + .unwrap_or(100); + + if let Some(group) = call.parallel_group { + let entry = groups.entry(group).or_insert(0); + *entry = (*entry).max(tool_latency); + } else { + sequential_latency += tool_latency; + } + } + sequential_latency + groups.values().sum::() + }; + + if est_latency <= spec.latency_budget_ms { + efficiency = 1.0 - (est_latency as f32 / spec.latency_budget_ms as f32 * 0.5); + } else { + efficiency = spec.latency_budget_ms as f32 / est_latency as f32 * 0.5; + notes.push(format!( + "Estimated latency {}ms exceeds budget {}ms", + est_latency, spec.latency_budget_ms + )); + } + + let est_cost: f32 = plan + .calls + .iter() + .filter_map(|c| { + spec.available_tools + .iter() + .find(|t| t.name == c.tool_name) + .map(|t| t.cost * (1.0 + c.retries as f32)) + }) + .sum(); + + if est_cost > spec.cost_budget { + efficiency *= 0.7; + notes.push(format!( + "Cost {:.4} exceeds budget {:.4}", + est_cost, spec.cost_budget + )); + } + + // Elegance: parallelism, caching, minimal redundancy + let parallelism_used = plan.calls.iter().any(|c| c.parallel_group.is_some()); + if parallelism_used { + elegance += 0.15; + } + + let cache_used = plan.calls.iter().any(|c| c.tool_name.contains("cache")); + if cache_used { + elegance += 0.1; + } + + let validation_used = plan + .calls + .iter() + .any(|c| c.tool_name.contains("validat")); + if validation_used { + elegance += 0.1; + } + + // Penalize excessive retries + let total_retries: u32 = plan.calls.iter().map(|c| c.retries).sum(); + if total_retries > plan.calls.len() as u32 * 2 { + elegance -= 0.2; + notes.push("Excessive retry configuration".into()); + } + + elegance = elegance.clamp(0.0, 1.0); + + let score = 0.6 * correctness + 0.25 * efficiency + 0.15 * elegance; + Evaluation { + score: score.clamp(0.0, 1.0), + correctness, + efficiency, + elegance, + constraint_results: Vec::new(), + notes, + } + } +} + +impl Default for ToolOrchestrationDomain { + fn default() -> Self { + Self::new() + } +} + +impl Domain for ToolOrchestrationDomain { + fn id(&self) -> &DomainId { + &self.id + } + + fn name(&self) -> &str { + "Tool Orchestration" + } + + fn generate_tasks(&self, count: usize, difficulty: f32) -> Vec { + let mut rng = rand::thread_rng(); + let difficulty = difficulty.clamp(0.0, 1.0); + + (0..count) + .map(|i| { + let roll: f32 = rng.gen(); + let spec = if roll < 0.4 { + self.gen_pipeline(difficulty) + } else if roll < 0.7 { + self.gen_error_recovery(difficulty) + } else { + self.gen_parallel_coordination(difficulty) + }; + + Task { + id: format!("orch_{}_d{:.0}", i, difficulty * 100.0), + domain_id: self.id.clone(), + difficulty, + spec: serde_json::to_value(&spec).unwrap_or_default(), + constraints: Vec::new(), + } + }) + .collect() + } + + fn evaluate(&self, task: &Task, solution: &Solution) -> Evaluation { + let spec: OrchestrationTaskSpec = match serde_json::from_value(task.spec.clone()) { + Ok(s) => s, + Err(e) => return Evaluation::zero(vec![format!("Invalid task spec: {}", e)]), + }; + self.score_orchestration(&spec, solution) + } + + fn embed(&self, solution: &Solution) -> DomainEmbedding { + let features = self.extract_features(solution); + DomainEmbedding::new(features, self.id.clone()) + } + + fn embedding_dim(&self) -> usize { + EMBEDDING_DIM + } + + fn reference_solution(&self, task: &Task) -> Option { + let spec: OrchestrationTaskSpec = serde_json::from_value(task.spec.clone()).ok()?; + + // Build a sequential pipeline through available tools + let calls: Vec = spec + .available_tools + .iter() + .map(|t| ToolCall { + tool_name: t.name.clone(), + input_ref: "previous".into(), + parallel_group: None, + fallback: None, + retries: if t.failure_rate > 0.05 { 2 } else { 0 }, + }) + .collect(); + + let plan = OrchestrationPlan { + calls, + error_strategy: "retry with exponential backoff".into(), + }; + + let content = serde_json::to_string_pretty(&plan).ok()?; + Some(Solution { + task_id: task.id.clone(), + content, + data: serde_json::to_value(&plan).ok()?, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_orchestration_tasks() { + let domain = ToolOrchestrationDomain::new(); + let tasks = domain.generate_tasks(5, 0.5); + assert_eq!(tasks.len(), 5); + for task in &tasks { + assert_eq!(task.domain_id, domain.id); + } + } + + #[test] + fn test_reference_solution() { + let domain = ToolOrchestrationDomain::new(); + let tasks = domain.generate_tasks(3, 0.3); + for task in &tasks { + let ref_sol = domain.reference_solution(task); + assert!(ref_sol.is_some()); + } + } + + #[test] + fn test_evaluate_reference() { + let domain = ToolOrchestrationDomain::new(); + let tasks = domain.generate_tasks(3, 0.3); + for task in &tasks { + if let Some(solution) = domain.reference_solution(task) { + let eval = domain.evaluate(task, &solution); + assert!(eval.score >= 0.0 && eval.score <= 1.0); + } + } + } + + #[test] + fn test_embed_orchestration() { + let domain = ToolOrchestrationDomain::new(); + let solution = Solution { + task_id: "test".into(), + content: "pipeline: extract -> embed -> search with fallback and retry".into(), + data: serde_json::json!({ + "calls": [ + {"tool_name": "text_extract", "input_ref": "input", "retries": 1} + ], + "error_strategy": "retry" + }), + }; + let embedding = domain.embed(&solution); + assert_eq!(embedding.dim, EMBEDDING_DIM); + } + + #[test] + fn test_difficulty_affects_error_scenarios() { + let domain = ToolOrchestrationDomain::new(); + // Generate many tasks at high difficulty to get error recovery tasks + let hard = domain.generate_tasks(20, 0.9); + let has_error_tasks = hard.iter().any(|t| { + let spec: OrchestrationTaskSpec = serde_json::from_value(t.spec.clone()).unwrap(); + !spec.error_scenarios.is_empty() + }); + assert!(has_error_tasks, "High difficulty should produce error scenarios"); + } +} diff --git a/crates/ruvector-domain-expansion/src/transfer.rs b/crates/ruvector-domain-expansion/src/transfer.rs new file mode 100644 index 00000000..a7ab30e3 --- /dev/null +++ b/crates/ruvector-domain-expansion/src/transfer.rs @@ -0,0 +1,599 @@ +//! Cross-Domain Transfer Engine with Meta Thompson Sampling +//! +//! Transfer happens through priors, not raw memories. +//! Ship compact priors and verified kernels between domains. +//! +//! ## Two-Layer Learning Architecture +//! +//! **Policy learning layer**: Chooses strategies, budgets, and tool paths +//! using uncertainty-aware selection (Thompson Sampling with Beta priors). +//! +//! **Operator layer**: Executes deterministic kernels and graders, +//! logs witnesses, and commits state through gates. +//! +//! ## Meta Thompson Sampling +//! +//! After each cycle, compute posterior summary per bucket and arm. +//! Store as TransferPrior. When a new domain starts, initialize its +//! buckets with these priors instead of uniform, enabling faster adaptation. +//! +//! ## Cross-Domain Transfer Protocol +//! +//! A delta is promotable only if it improves Domain 2 without regressing +//! Domain 1, or improves Domain 1 without regressing Domain 2. +//! That is generalization. + +use crate::domain::DomainId; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Beta distribution parameters for Thompson Sampling. +/// Represents uncertainty about an arm's reward probability. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BetaParams { + /// Success count + prior (alpha). + pub alpha: f32, + /// Failure count + prior (beta). + pub beta: f32, +} + +impl BetaParams { + /// Uniform (uninformative) prior: Beta(1, 1). + pub fn uniform() -> Self { + Self { + alpha: 1.0, + beta: 1.0, + } + } + + /// Create from observed successes and failures. + pub fn from_observations(successes: f32, failures: f32) -> Self { + Self { + alpha: successes + 1.0, + beta: failures + 1.0, + } + } + + /// Mean of the Beta distribution: E[X] = alpha / (alpha + beta). + pub fn mean(&self) -> f32 { + self.alpha / (self.alpha + self.beta) + } + + /// Variance: measures uncertainty. Lower = more confident. + pub fn variance(&self) -> f32 { + let total = self.alpha + self.beta; + (self.alpha * self.beta) / (total * total * (total + 1.0)) + } + + /// Sample from the Beta distribution using the Kumaraswamy approximation. + /// Fast, no special functions needed, good enough for Thompson Sampling. + pub fn sample(&self, rng: &mut impl Rng) -> f32 { + // Use inverse CDF of Beta via simple approximation + let u: f32 = rng.gen_range(0.001..0.999); + // Kumaraswamy approximation: x = (1 - (1 - u^(1/b))^(1/a)) + // Better approximation using ratio of gammas via the normal approach + let x = Self::beta_inv_approx(u, self.alpha, self.beta); + x.clamp(0.0, 1.0) + } + + /// Approximate inverse CDF of Beta distribution. + fn beta_inv_approx(p: f32, a: f32, b: f32) -> f32 { + // Use normal approximation for Beta when a,b are not too small + if a > 1.0 && b > 1.0 { + let mean = a / (a + b); + let var = (a * b) / ((a + b) * (a + b) * (a + b + 1.0)); + let std = var.sqrt(); + // Inverse normal approximation (Abramowitz & Stegun) + let t = if p < 0.5 { + (-2.0 * (p).ln()).sqrt() + } else { + (-2.0 * (1.0 - p).ln()).sqrt() + }; + let x = if p < 0.5 { + mean - std * t + } else { + mean + std * t + }; + x.clamp(0.001, 0.999) + } else { + // Fallback: simple power approximation + p.powf(1.0 / a) * (1.0 - (1.0 - p).powf(1.0 / b)) + + p.powf(1.0 / a) * 0.5 + } + } + + /// Update with an observation (Bayesian posterior update). + pub fn update(&mut self, reward: f32) { + self.alpha += reward; + self.beta += 1.0 - reward; + } + + /// Merge two Beta distributions (approximate: sum parameters). + pub fn merge(&self, other: &BetaParams) -> BetaParams { + BetaParams { + alpha: self.alpha + other.alpha - 1.0, // subtract uniform prior + beta: self.beta + other.beta - 1.0, + } + } +} + +/// A context bucket groups similar problem instances for targeted learning. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct ContextBucket { + /// Difficulty tier: "easy", "medium", "hard". + pub difficulty_tier: String, + /// Problem category within the domain. + pub category: String, +} + +/// An arm in the multi-armed bandit: a strategy choice. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct ArmId(pub String); + +/// Transfer prior: compact posterior summary from a source domain. +/// This is what gets shipped between domains — not raw trajectories. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TransferPrior { + /// Source domain that generated this prior. + pub source_domain: DomainId, + /// Per-bucket, per-arm Beta parameters (posterior summaries). + pub bucket_priors: HashMap>, + /// Cost EMA (exponential moving average) priors per bucket. + pub cost_ema_priors: HashMap, + /// Number of cycles this prior was trained on. + pub training_cycles: u64, + /// Witness hash: proof of how this prior was derived. + pub witness_hash: String, +} + +impl TransferPrior { + /// Create an empty (uniform) prior for a domain. + pub fn uniform(source_domain: DomainId) -> Self { + Self { + source_domain, + bucket_priors: HashMap::new(), + cost_ema_priors: HashMap::new(), + training_cycles: 0, + witness_hash: String::new(), + } + } + + /// Get the prior for a specific bucket and arm, defaulting to uniform. + pub fn get_prior(&self, bucket: &ContextBucket, arm: &ArmId) -> BetaParams { + self.bucket_priors + .get(bucket) + .and_then(|arms| arms.get(arm)) + .cloned() + .unwrap_or_else(BetaParams::uniform) + } + + /// Update the posterior for a bucket/arm with a new observation. + pub fn update_posterior( + &mut self, + bucket: ContextBucket, + arm: ArmId, + reward: f32, + ) { + let arms = self.bucket_priors.entry(bucket.clone()).or_default(); + let params = arms.entry(arm).or_insert_with(BetaParams::uniform); + params.update(reward); + self.training_cycles += 1; + } + + /// Update cost EMA for a bucket. + pub fn update_cost_ema(&mut self, bucket: ContextBucket, cost: f32, decay: f32) { + let entry = self.cost_ema_priors.entry(bucket).or_insert(cost); + *entry = decay * (*entry) + (1.0 - decay) * cost; + } + + /// Extract a compact summary suitable for shipping to another domain. + pub fn extract_summary(&self) -> TransferPrior { + // Only ship buckets with sufficient evidence (>10 observations) + let filtered: HashMap> = self + .bucket_priors + .iter() + .filter_map(|(bucket, arms)| { + let significant_arms: HashMap = arms + .iter() + .filter(|(_, params)| (params.alpha + params.beta) > 12.0) + .map(|(arm, params)| (arm.clone(), params.clone())) + .collect(); + if significant_arms.is_empty() { + None + } else { + Some((bucket.clone(), significant_arms)) + } + }) + .collect(); + + TransferPrior { + source_domain: self.source_domain.clone(), + bucket_priors: filtered, + cost_ema_priors: self.cost_ema_priors.clone(), + training_cycles: self.training_cycles, + witness_hash: self.witness_hash.clone(), + } + } +} + +/// Meta Thompson Sampling engine that manages priors across domains. +pub struct MetaThompsonEngine { + /// Active priors per domain. + domain_priors: HashMap, + /// Available arms (strategies) shared across domains. + arms: Vec, + /// Difficulty tiers for bucketing. + difficulty_tiers: Vec, +} + +impl MetaThompsonEngine { + /// Create a new engine with the given strategy arms. + pub fn new(arms: Vec) -> Self { + Self { + domain_priors: HashMap::new(), + arms: arms.into_iter().map(ArmId).collect(), + difficulty_tiers: vec!["easy".into(), "medium".into(), "hard".into()], + } + } + + /// Initialize a domain with uniform priors. + pub fn init_domain_uniform(&mut self, domain_id: DomainId) { + self.domain_priors + .insert(domain_id.clone(), TransferPrior::uniform(domain_id)); + } + + /// Initialize a domain using transfer priors from a source domain. + /// This is the key mechanism: Meta-TS seeds new domains with learned priors. + pub fn init_domain_with_transfer( + &mut self, + target_domain: DomainId, + source_prior: &TransferPrior, + ) { + let mut prior = TransferPrior::uniform(target_domain.clone()); + + // Copy bucket priors from source, scaling by confidence + for (bucket, arms) in &source_prior.bucket_priors { + for (arm, params) in arms { + // Dampen the prior: don't fully trust cross-domain evidence. + // Use sqrt scaling: reduces confidence while preserving mean. + let dampened = BetaParams { + alpha: 1.0 + (params.alpha - 1.0).sqrt(), + beta: 1.0 + (params.beta - 1.0).sqrt(), + }; + prior + .bucket_priors + .entry(bucket.clone()) + .or_default() + .insert(arm.clone(), dampened); + } + } + + // Transfer cost EMAs with dampening + for (bucket, &cost) in &source_prior.cost_ema_priors { + prior.cost_ema_priors.insert(bucket.clone(), cost * 1.5); // pessimistic transfer + } + + prior.witness_hash = format!("transfer_from_{}", source_prior.source_domain); + self.domain_priors.insert(target_domain, prior); + } + + /// Select an arm for a given domain and context using Thompson Sampling. + pub fn select_arm( + &self, + domain_id: &DomainId, + bucket: &ContextBucket, + rng: &mut impl Rng, + ) -> Option { + let prior = self.domain_priors.get(domain_id)?; + + let mut best_arm = None; + let mut best_sample = f32::NEG_INFINITY; + + for arm in &self.arms { + let params = prior.get_prior(bucket, arm); + let sample = params.sample(rng); + if sample > best_sample { + best_sample = sample; + best_arm = Some(arm.clone()); + } + } + + best_arm + } + + /// Record the outcome of using an arm in a domain. + pub fn record_outcome( + &mut self, + domain_id: &DomainId, + bucket: ContextBucket, + arm: ArmId, + reward: f32, + cost: f32, + ) { + if let Some(prior) = self.domain_priors.get_mut(domain_id) { + prior.update_posterior(bucket.clone(), arm, reward); + prior.update_cost_ema(bucket, cost, 0.9); + } + } + + /// Extract transfer prior from a domain (for shipping to another domain). + pub fn extract_prior(&self, domain_id: &DomainId) -> Option { + self.domain_priors.get(domain_id).map(|p| p.extract_summary()) + } + + /// Get all domain IDs currently tracked. + pub fn domain_ids(&self) -> Vec<&DomainId> { + self.domain_priors.keys().collect() + } + + /// Check if posterior variance is high (triggers speculative dual-path). + pub fn is_uncertain( + &self, + domain_id: &DomainId, + bucket: &ContextBucket, + threshold: f32, + ) -> bool { + let prior = match self.domain_priors.get(domain_id) { + Some(p) => p, + None => return true, // No data = maximum uncertainty + }; + + // Check if top two arms are within delta of each other + let mut samples: Vec<(f32, &ArmId)> = self + .arms + .iter() + .map(|arm| { + let params = prior.get_prior(bucket, arm); + (params.mean(), arm) + }) + .collect(); + samples.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + + if samples.len() < 2 { + return true; + } + + let gap = samples[0].0 - samples[1].0; + gap < threshold + } +} + +/// Speculative dual-path execution for high-uncertainty decisions. +/// When the top two arms are within delta, run both and pick the winner. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DualPathResult { + /// Primary arm and its outcome. + pub primary: (ArmId, f32), + /// Secondary arm and its outcome. + pub secondary: (ArmId, f32), + /// Which arm won. + pub winner: ArmId, + /// The loser becomes a counterexample for that context. + pub counterexample: ArmId, +} + +/// Cross-domain transfer verification result. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TransferVerification { + /// Source domain. + pub source: DomainId, + /// Target domain. + pub target: DomainId, + /// Did transfer improve the target domain? + pub improved_target: bool, + /// Did transfer regress the source domain? + pub regressed_source: bool, + /// Is this delta promotable? (improved target AND not regressed source). + pub promotable: bool, + /// Acceleration factor: ratio of convergence speeds. + pub acceleration_factor: f32, + /// Source score before/after. + pub source_scores: (f32, f32), + /// Target score before/after. + pub target_scores: (f32, f32), +} + +impl TransferVerification { + /// Verify a transfer delta against the generalization rule: + /// promotable iff it improves Domain 2 without regressing Domain 1. + pub fn verify( + source: DomainId, + target: DomainId, + source_before: f32, + source_after: f32, + target_before: f32, + target_after: f32, + target_baseline_cycles: u64, + target_transfer_cycles: u64, + ) -> Self { + let improved_target = target_after > target_before; + let regressed_source = source_after < source_before - 0.01; // small tolerance + + let promotable = improved_target && !regressed_source; + + // Acceleration = baseline_cycles / transfer_cycles (higher = better transfer) + let acceleration_factor = if target_transfer_cycles > 0 { + target_baseline_cycles as f32 / target_transfer_cycles as f32 + } else { + 1.0 + }; + + Self { + source, + target, + improved_target, + regressed_source, + promotable, + acceleration_factor, + source_scores: (source_before, source_after), + target_scores: (target_before, target_after), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_beta_params_uniform() { + let p = BetaParams::uniform(); + assert_eq!(p.alpha, 1.0); + assert_eq!(p.beta, 1.0); + assert!((p.mean() - 0.5).abs() < 1e-6); + } + + #[test] + fn test_beta_params_update() { + let mut p = BetaParams::uniform(); + p.update(1.0); // success + assert_eq!(p.alpha, 2.0); + assert_eq!(p.beta, 1.0); + assert!(p.mean() > 0.5); + } + + #[test] + fn test_beta_params_sample_in_range() { + let p = BetaParams::from_observations(10.0, 5.0); + let mut rng = rand::thread_rng(); + for _ in 0..100 { + let s = p.sample(&mut rng); + assert!(s >= 0.0 && s <= 1.0, "Sample {} out of [0,1]", s); + } + } + + #[test] + fn test_transfer_prior_round_trip() { + let domain = DomainId("test".into()); + let mut prior = TransferPrior::uniform(domain); + + let bucket = ContextBucket { + difficulty_tier: "easy".into(), + category: "transform".into(), + }; + let arm = ArmId("strategy_a".into()); + + for _ in 0..20 { + prior.update_posterior(bucket.clone(), arm.clone(), 0.8); + } + + let summary = prior.extract_summary(); + assert!(!summary.bucket_priors.is_empty()); + + let retrieved = summary.get_prior(&bucket, &arm); + assert!(retrieved.mean() > 0.5); + } + + #[test] + fn test_meta_thompson_engine() { + let mut engine = MetaThompsonEngine::new(vec![ + "strategy_a".into(), + "strategy_b".into(), + "strategy_c".into(), + ]); + + let domain1 = DomainId("rust_synthesis".into()); + engine.init_domain_uniform(domain1.clone()); + + let bucket = ContextBucket { + difficulty_tier: "medium".into(), + category: "algorithm".into(), + }; + + let mut rng = rand::thread_rng(); + + // Record some outcomes + for _ in 0..50 { + let arm = engine.select_arm(&domain1, &bucket, &mut rng).unwrap(); + let reward = if arm.0 == "strategy_a" { 0.9 } else { 0.3 }; + engine.record_outcome(&domain1, bucket.clone(), arm, reward, 1.0); + } + + // Extract prior and transfer to domain2 + let prior = engine.extract_prior(&domain1).unwrap(); + let domain2 = DomainId("planning".into()); + engine.init_domain_with_transfer(domain2.clone(), &prior); + + // Domain2 should now have informative priors + let d2_prior = engine.domain_priors.get(&domain2).unwrap(); + let a_params = d2_prior.get_prior(&bucket, &ArmId("strategy_a".into())); + assert!(a_params.mean() > 0.5, "Transferred prior should favor strategy_a"); + } + + #[test] + fn test_transfer_verification() { + let v = TransferVerification::verify( + DomainId("d1".into()), + DomainId("d2".into()), + 0.8, // source before + 0.79, // source after (slight decrease, within tolerance) + 0.3, // target before + 0.7, // target after (big improvement) + 100, // baseline cycles + 40, // transfer cycles + ); + + assert!(v.improved_target); + assert!(!v.regressed_source); // within tolerance + assert!(v.promotable); + assert!((v.acceleration_factor - 2.5).abs() < 1e-4); + } + + #[test] + fn test_transfer_not_promotable_on_regression() { + let v = TransferVerification::verify( + DomainId("d1".into()), + DomainId("d2".into()), + 0.8, // source before + 0.5, // source after (regression!) + 0.3, // target before + 0.7, // target after + 100, + 40, + ); + + assert!(v.improved_target); + assert!(v.regressed_source); + assert!(!v.promotable); + } + + #[test] + fn test_uncertainty_detection() { + let mut engine = MetaThompsonEngine::new(vec![ + "a".into(), + "b".into(), + ]); + + let domain = DomainId("test".into()); + engine.init_domain_uniform(domain.clone()); + + let bucket = ContextBucket { + difficulty_tier: "easy".into(), + category: "test".into(), + }; + + // With uniform priors, should be uncertain + assert!(engine.is_uncertain(&domain, &bucket, 0.1)); + + // After many observations favoring one arm, should be certain + for _ in 0..100 { + engine.record_outcome( + &domain, + bucket.clone(), + ArmId("a".into()), + 0.95, + 1.0, + ); + engine.record_outcome( + &domain, + bucket.clone(), + ArmId("b".into()), + 0.1, + 1.0, + ); + } + + assert!(!engine.is_uncertain(&domain, &bucket, 0.1)); + } +}