mirror of
https://github.com/ruvnet/RuView.git
synced 2026-04-28 05:59:32 +00:00
Complete implementation of all 24 vendor-integrated sensing modules across 7 categories, compiled to wasm32-unknown-unknown for ESP32-S3 WASM3 runtime deployment. All 243 unit tests pass. Signal Intelligence (6): flash attention, coherence gate, temporal compress, sparse recovery, min-cut person match, optimal transport. Adaptive Learning (4): DTW gesture learn, anomaly attractor, meta adapt, EWC++ lifelong learning. Spatial Reasoning (3): PageRank influence, micro-HNSW, spiking tracker. Temporal Analysis (3): pattern sequence, temporal logic guard, GOAP. AI Security (2): prompt shield, behavioral profiler. Quantum-Inspired (2): quantum coherence, interference search. Autonomous Systems (2): psycho-symbolic engine, self-healing mesh. Exotic (2): time crystal detector, hyperbolic space embedding. Includes vendor_common.rs shared library, security audit with 5 fixes, and security audit report. Co-Authored-By: claude-flow <ruv@ruv.net>
611 lines
20 KiB
Rust
//! Elastic Weight Consolidation for lifelong on-device learning — ADR-041 adaptive module.
//!
//! # Algorithm
//!
//! Implements EWC (Kirkpatrick et al., 2017) on a tiny 8-input, 4-output
//! linear classifier running entirely on the ESP32-S3 WASM3 interpreter.
//! The classifier maps 8D CSI feature vectors to 4 zone predictions.
//!
//! ## Core EWC Mechanism
//!
//! When learning a new task (e.g., a new room layout), naive gradient descent
//! overwrites parameters important for previous tasks -- "catastrophic
//! forgetting." EWC prevents this by adding a penalty term:
//!
//! ```text
//! L_total = L_current + (lambda/2) * sum_i( F_i * (theta_i - theta_i*)^2 )
//! ```
//!
//! where:
//! - `L_current` = MSE between predicted zone and actual zone
//! - `F_i` = Fisher Information diagonal (parameter importance)
//! - `theta_i*` = parameters at end of previous task
//! - `lambda` = 1000 (regularization strength)
//!
//! ## Fisher Information Estimation
//!
//! The Fisher diagonal approximates parameter importance:
//! `F_i = E[(d log p / d theta_i)^2] ~ running_average(gradient_i^2)`
//!
//! Gradients are estimated via finite differences (perturb each parameter
//! by epsilon=0.01, measure loss change).
//!
//! ## Task Boundary Detection
//!
//! A new task is detected when the system achieves 100 consecutive frames
//! with stable performance (loss below threshold). At this point:
//! 1. Snapshot current parameters as `theta_star`
//! 2. Update Fisher diagonal from accumulated gradient squares
//! 3. Increment task counter
//!
//! # Events (745-series: Adaptive Learning)
//!
//! - `KNOWLEDGE_RETAINED` (745): EWC penalty magnitude (lower = less forgetting).
//! - `NEW_TASK_LEARNED` (746): Task count after learning a new task.
//! - `FISHER_UPDATE` (747): Mean Fisher information value.
//! - `FORGETTING_RISK` (748): Ratio of EWC penalty to current loss.
//!
//! # Budget
//!
//! L (lightweight, < 2 ms) -- only updates a few params per frame using
//! a round-robin finite-difference gradient schedule.
// ── Constants ────────────────────────────────────────────────────────────────

/// Number of learnable parameters (8 inputs * 4 outputs = 32).
const N_PARAMS: usize = 32;

/// Input dimension (8 subcarrier groups).
const N_INPUT: usize = 8;

/// Output dimension (4 zones).
const N_OUTPUT: usize = 4;

/// EWC regularization strength (`lambda` in the loss formula above).
const LAMBDA: f32 = 1000.0;

/// Finite-difference epsilon for gradient estimation.
const EPSILON: f32 = 0.01;

/// Number of parameters to update per frame (round-robin), keeping each
/// frame inside the < 2 ms budget.
const PARAMS_PER_FRAME: usize = 4;

/// Learning rate for parameter updates.
const LEARNING_RATE: f32 = 0.001;

/// Consecutive stable frames required to trigger a task boundary.
const STABLE_FRAMES_THRESHOLD: u32 = 100;

/// Loss threshold below which a frame is considered "stable".
const STABLE_LOSS_THRESHOLD: f32 = 0.1;

/// EMA smoothing for Fisher diagonal updates.
const FISHER_ALPHA: f32 = 0.01;

/// Maximum number of tasks before Fisher memory saturates.
const MAX_TASKS: u8 = 32;

/// Reporting interval (frames between event emissions).
const REPORT_INTERVAL: u32 = 20;

// ── Event IDs (745-series: Adaptive Learning) ────────────────────────────────

pub const EVENT_KNOWLEDGE_RETAINED: i32 = 745;
pub const EVENT_NEW_TASK_LEARNED: i32 = 746;
pub const EVENT_FISHER_UPDATE: i32 = 747;
pub const EVENT_FORGETTING_RISK: i32 = 748;
// ── EWC Lifelong Learner ─────────────────────────────────────────────────────

/// Elastic Weight Consolidation lifelong on-device learner.
///
/// All state is fixed-size (no heap allocation), suiting the
/// wasm32-unknown-unknown / WASM3 embedded target.
pub struct EwcLifelong {
    /// Current learnable parameters [N_PARAMS] (flattened [N_OUTPUT][N_INPUT]).
    params: [f32; N_PARAMS],
    /// Fisher Information diagonal [N_PARAMS] (parameter importance).
    fisher: [f32; N_PARAMS],
    /// Snapshot of parameters at the previous task boundary (`theta*`).
    theta_star: [f32; N_PARAMS],
    /// Accumulated (EMA of) gradient squares for Fisher estimation.
    grad_accum: [f32; N_PARAMS],
    /// Number of gradient samples accumulated since the last task boundary.
    grad_count: u32,
    /// Number of completed tasks (capped at MAX_TASKS).
    task_count: u8,
    /// Consecutive frames with loss below STABLE_LOSS_THRESHOLD.
    stable_frames: u32,
    /// Current round-robin parameter index for finite-difference updates.
    param_cursor: usize,
    /// Frame counter (drives periodic event reporting).
    frame_count: u32,
    /// Last computed total loss (current + EWC penalty).
    last_loss: f32,
    /// Last computed EWC penalty.
    last_penalty: f32,
    /// Whether theta_star has been set (false until the first task completes).
    has_prior: bool,
}
impl EwcLifelong {
|
|
    /// Create a learner in its initial state: deterministically seeded
    /// weights, zeroed Fisher/snapshot/accumulators, and no prior task.
    ///
    /// `const` so an instance can live in a `static` on the embedded target.
    pub const fn new() -> Self {
        Self {
            params: Self::default_params(),
            fisher: [0.0; N_PARAMS],
            theta_star: [0.0; N_PARAMS],
            grad_accum: [0.0; N_PARAMS],
            grad_count: 0,
            task_count: 0,
            stable_frames: 0,
            param_cursor: 0,
            frame_count: 0,
            last_loss: 0.0,
            last_penalty: 0.0,
            has_prior: false,
        }
    }
/// Initialize parameters with small diverse values to break symmetry.
|
|
/// Uses a deterministic pattern (no RNG needed in const context).
|
|
const fn default_params() -> [f32; N_PARAMS] {
|
|
let mut p = [0.0f32; N_PARAMS];
|
|
let mut i = 0;
|
|
while i < N_PARAMS {
|
|
// Deterministic pseudo-random initialization: scaled index with alternation.
|
|
let sign = if i % 2 == 0 { 1.0 } else { -1.0 };
|
|
// (i * 0.037 + 0.01) * sign via integer scaling for const compatibility.
|
|
let magnitude = (i as f32 * 37.0 + 10.0) / 1000.0 * sign;
|
|
p[i] = magnitude;
|
|
i += 1;
|
|
}
|
|
p
|
|
}
|
|
|
|
/// Process one frame with learning.
|
|
///
|
|
/// `features` -- 8D CSI feature vector (mean amplitude per subcarrier group).
|
|
/// `target_zone` -- ground truth zone label (0-3), or -1 if no label available.
|
|
///
|
|
/// When `target_zone >= 0`, the system performs a gradient step and updates
|
|
/// parameters. When -1, it only runs inference.
|
|
///
|
|
/// Returns events as `(event_id, value)` pairs.
|
|
pub fn process_frame(&mut self, features: &[f32], target_zone: i32) -> &[(i32, f32)] {
|
|
static mut EVENTS: [(i32, f32); 4] = [(0, 0.0); 4];
|
|
let mut n_ev = 0usize;
|
|
|
|
if features.len() < N_INPUT {
|
|
return &[];
|
|
}
|
|
|
|
self.frame_count += 1;
|
|
|
|
// Run forward pass: predict zone from features.
|
|
let predicted = self.forward(features);
|
|
|
|
// If we have a ground truth label, compute loss and update.
|
|
if target_zone >= 0 && (target_zone as usize) < N_OUTPUT {
|
|
let tz = target_zone as usize;
|
|
|
|
// Compute MSE loss against one-hot target.
|
|
let current_loss = self.compute_mse_loss(&predicted, tz);
|
|
|
|
// Compute EWC penalty.
|
|
let ewc_penalty = if self.has_prior {
|
|
self.compute_ewc_penalty()
|
|
} else {
|
|
0.0
|
|
};
|
|
|
|
let total_loss = current_loss + ewc_penalty;
|
|
self.last_loss = total_loss;
|
|
self.last_penalty = ewc_penalty;
|
|
|
|
// Finite-difference gradient estimation (round-robin subset).
|
|
self.update_gradients(features, tz);
|
|
|
|
// Gradient descent step.
|
|
self.gradient_step(features, tz);
|
|
|
|
// Track stability for task boundary detection.
|
|
if current_loss < STABLE_LOSS_THRESHOLD {
|
|
self.stable_frames += 1;
|
|
} else {
|
|
self.stable_frames = 0;
|
|
}
|
|
|
|
// Task boundary detection.
|
|
if self.stable_frames >= STABLE_FRAMES_THRESHOLD
|
|
&& self.task_count < MAX_TASKS
|
|
{
|
|
self.commit_task();
|
|
unsafe {
|
|
EVENTS[n_ev] = (EVENT_NEW_TASK_LEARNED, self.task_count as f32);
|
|
}
|
|
n_ev += 1;
|
|
|
|
// Emit mean Fisher value.
|
|
let mean_fisher = self.mean_fisher();
|
|
if n_ev < 4 {
|
|
unsafe {
|
|
EVENTS[n_ev] = (EVENT_FISHER_UPDATE, mean_fisher);
|
|
}
|
|
n_ev += 1;
|
|
}
|
|
}
|
|
|
|
// Periodic reporting.
|
|
if self.frame_count % REPORT_INTERVAL == 0 {
|
|
if n_ev < 4 {
|
|
unsafe {
|
|
EVENTS[n_ev] = (EVENT_KNOWLEDGE_RETAINED, ewc_penalty);
|
|
}
|
|
n_ev += 1;
|
|
}
|
|
|
|
// Forgetting risk: ratio of penalty to current loss.
|
|
let risk = if current_loss > 1e-8 {
|
|
ewc_penalty / current_loss
|
|
} else {
|
|
0.0
|
|
};
|
|
if n_ev < 4 {
|
|
unsafe {
|
|
EVENTS[n_ev] = (EVENT_FORGETTING_RISK, risk);
|
|
}
|
|
n_ev += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
unsafe { &EVENTS[..n_ev] }
|
|
}
|
|
|
|
/// Forward pass: linear classifier `output = params * features`.
|
|
///
|
|
/// Params are stored as [output_0_weights..., output_1_weights..., ...].
|
|
fn forward(&self, features: &[f32]) -> [f32; N_OUTPUT] {
|
|
let mut output = [0.0f32; N_OUTPUT];
|
|
for o in 0..N_OUTPUT {
|
|
let base = o * N_INPUT;
|
|
let mut sum = 0.0f32;
|
|
for i in 0..N_INPUT {
|
|
sum += self.params[base + i] * features[i];
|
|
}
|
|
output[o] = sum;
|
|
}
|
|
output
|
|
}
|
|
|
|
/// Compute MSE loss against a one-hot target for `target_zone`.
|
|
fn compute_mse_loss(&self, predicted: &[f32; N_OUTPUT], target: usize) -> f32 {
|
|
let mut loss = 0.0f32;
|
|
for o in 0..N_OUTPUT {
|
|
let target_val = if o == target { 1.0 } else { 0.0 };
|
|
let diff = predicted[o] - target_val;
|
|
loss += diff * diff;
|
|
}
|
|
loss / N_OUTPUT as f32
|
|
}
|
|
|
|
/// Compute the EWC penalty: (lambda/2) * sum(F_i * (theta_i - theta_i*)^2).
|
|
fn compute_ewc_penalty(&self) -> f32 {
|
|
let mut penalty = 0.0f32;
|
|
for i in 0..N_PARAMS {
|
|
let diff = self.params[i] - self.theta_star[i];
|
|
penalty += self.fisher[i] * diff * diff;
|
|
}
|
|
(LAMBDA / 2.0) * penalty
|
|
}
|
|
|
|
/// Estimate gradients via finite differences for a subset of parameters.
|
|
///
|
|
/// Uses round-robin scheduling: PARAMS_PER_FRAME parameters per call.
|
|
fn update_gradients(&mut self, features: &[f32], target: usize) {
|
|
let predicted = self.forward(features);
|
|
let base_loss = self.compute_mse_loss(&predicted, target);
|
|
|
|
for _step in 0..PARAMS_PER_FRAME {
|
|
let idx = self.param_cursor;
|
|
self.param_cursor = (self.param_cursor + 1) % N_PARAMS;
|
|
|
|
// Perturb parameter positively.
|
|
self.params[idx] += EPSILON;
|
|
let perturbed_pred = self.forward(features);
|
|
let perturbed_loss = self.compute_mse_loss(&perturbed_pred, target);
|
|
self.params[idx] -= EPSILON; // Restore.
|
|
|
|
// Finite-difference gradient.
|
|
let grad = (perturbed_loss - base_loss) / EPSILON;
|
|
|
|
// Accumulate gradient squared for Fisher estimation.
|
|
self.grad_accum[idx] =
|
|
FISHER_ALPHA * grad * grad + (1.0 - FISHER_ALPHA) * self.grad_accum[idx];
|
|
self.grad_count += 1;
|
|
}
|
|
}
|
|
|
|
/// Apply gradient descent with EWC regularization.
|
|
fn gradient_step(&mut self, features: &[f32], target: usize) {
|
|
// Compute output error: predicted - target (one-hot).
|
|
let predicted = self.forward(features);
|
|
|
|
for o in 0..N_OUTPUT {
|
|
let target_val = if o == target { 1.0 } else { 0.0 };
|
|
let error = predicted[o] - target_val;
|
|
|
|
let base = o * N_INPUT;
|
|
for i in 0..N_INPUT {
|
|
// Gradient of MSE w.r.t. weight: 2 * error * feature / N_OUTPUT.
|
|
let grad_mse = 2.0 * error * features[i] / N_OUTPUT as f32;
|
|
|
|
// EWC gradient: lambda * F_i * (theta_i - theta_i*).
|
|
let grad_ewc = if self.has_prior {
|
|
LAMBDA * self.fisher[base + i]
|
|
* (self.params[base + i] - self.theta_star[base + i])
|
|
} else {
|
|
0.0
|
|
};
|
|
|
|
let total_grad = grad_mse + grad_ewc;
|
|
self.params[base + i] -= LEARNING_RATE * total_grad;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Commit the current state as a learned task.
|
|
fn commit_task(&mut self) {
|
|
// Snapshot parameters.
|
|
self.theta_star = self.params;
|
|
|
|
// Update Fisher diagonal from accumulated gradient squares.
|
|
if self.has_prior {
|
|
// Merge with existing Fisher (online consolidation).
|
|
for i in 0..N_PARAMS {
|
|
self.fisher[i] = 0.5 * self.fisher[i] + 0.5 * self.grad_accum[i];
|
|
}
|
|
} else {
|
|
// First task: Fisher = accumulated gradient squares.
|
|
self.fisher = self.grad_accum;
|
|
}
|
|
|
|
// Reset accumulators.
|
|
self.grad_accum = [0.0; N_PARAMS];
|
|
self.grad_count = 0;
|
|
self.stable_frames = 0;
|
|
self.task_count += 1;
|
|
self.has_prior = true;
|
|
}
|
|
|
|
/// Compute mean Fisher information across all parameters.
|
|
fn mean_fisher(&self) -> f32 {
|
|
let mut sum = 0.0f32;
|
|
for i in 0..N_PARAMS {
|
|
sum += self.fisher[i];
|
|
}
|
|
sum / N_PARAMS as f32
|
|
}
|
|
|
|
/// Run inference only (no learning). Returns the predicted zone (argmax).
|
|
pub fn predict(&self, features: &[f32]) -> u8 {
|
|
if features.len() < N_INPUT {
|
|
return 0;
|
|
}
|
|
let output = self.forward(features);
|
|
let mut best = 0u8;
|
|
let mut best_val = output[0];
|
|
for o in 1..N_OUTPUT {
|
|
if output[o] > best_val {
|
|
best_val = output[o];
|
|
best = o as u8;
|
|
}
|
|
}
|
|
best
|
|
}
|
|
|
|
    /// Get the current parameter vector.
    pub fn parameters(&self) -> &[f32; N_PARAMS] {
        &self.params
    }

    /// Get the Fisher diagonal (parameter importance estimates).
    pub fn fisher_diagonal(&self) -> &[f32; N_PARAMS] {
        &self.fisher
    }

    /// Get the number of completed tasks.
    pub fn task_count(&self) -> u8 {
        self.task_count
    }

    /// Get the last computed total loss (current MSE + EWC penalty).
    pub fn last_loss(&self) -> f32 {
        self.last_loss
    }

    /// Get the last computed EWC penalty.
    pub fn last_penalty(&self) -> f32 {
        self.last_penalty
    }

    /// Get total frames processed.
    pub fn frame_count(&self) -> u32 {
        self.frame_count
    }

    /// Whether a prior task has been committed (i.e. theta_star is valid).
    pub fn has_prior_task(&self) -> bool {
        self.has_prior
    }

    /// Reset to initial state, discarding all learned weights and Fisher data.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}
|
|
|
|
// ── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    // libm provides float math on the no_std wasm32 target (f32::abs is std-only).
    use libm::fabsf;

    #[test]
    fn test_const_new() {
        let ewc = EwcLifelong::new();
        assert_eq!(ewc.frame_count(), 0);
        assert_eq!(ewc.task_count(), 0);
        assert!(!ewc.has_prior_task());
    }

    #[test]
    fn test_default_params_nonzero() {
        let ewc = EwcLifelong::new();
        let params = ewc.parameters();
        // At least some params should be nonzero (symmetry breaking).
        let nonzero = params.iter().filter(|&&p| fabsf(p) > 1e-6).count();
        assert!(nonzero > N_PARAMS / 2,
            "default params should have diverse nonzero values, got {}/{}", nonzero, N_PARAMS);
    }

    #[test]
    fn test_forward_produces_output() {
        let ewc = EwcLifelong::new();
        let features = [1.0f32; N_INPUT];
        let output = ewc.predict(&features);
        assert!(output < N_OUTPUT as u8, "predicted zone should be 0-3");
    }

    #[test]
    fn test_insufficient_features_no_events() {
        let mut ewc = EwcLifelong::new();
        let features = [1.0f32; 4]; // Only 4, need 8.
        let events = ewc.process_frame(&features, 0);
        assert!(events.is_empty());
    }

    #[test]
    fn test_inference_only_no_learning() {
        let mut ewc = EwcLifelong::new();
        let features = [1.0f32; N_INPUT];
        // target_zone = -1 means no label -> no learning.
        let events = ewc.process_frame(&features, -1);
        assert!(events.is_empty(), "inference-only should emit no events");
        assert_eq!(ewc.task_count(), 0);
    }

    #[test]
    fn test_learning_reduces_loss() {
        let mut ewc = EwcLifelong::new();
        let features = [0.5f32, 0.3, 0.8, 0.1, 0.6, 0.2, 0.9, 0.4];
        let target = 2; // Zone 2.

        // Train for many frames on a fixed (features, target) pair.
        for _ in 0..200 {
            ewc.process_frame(&features, target);
        }

        // After training, the loss should have decreased.
        assert!(ewc.last_loss() < 1.0,
            "loss should decrease after training, got {}", ewc.last_loss());
    }

    #[test]
    fn test_ewc_penalty_zero_without_prior() {
        let mut ewc = EwcLifelong::new();
        let features = [1.0f32; N_INPUT];
        ewc.process_frame(&features, 0);
        assert!(!ewc.has_prior_task());
        assert!(ewc.last_penalty() < 1e-8,
            "EWC penalty should be 0 without prior task");
    }

    #[test]
    fn test_task_boundary_detection() {
        let mut ewc = EwcLifelong::new();
        let features = [0.5f32; N_INPUT];
        let target = 1;

        // Run enough frames to potentially trigger a task boundary.
        for _ in 0..500 {
            ewc.process_frame(&features, target);
        }

        // Exercise the accessor -- exact timing depends on convergence.
        let _ = ewc.task_count();
    }

    #[test]
    fn test_fisher_starts_zero() {
        let ewc = EwcLifelong::new();
        let fisher = ewc.fisher_diagonal();
        for &f in fisher.iter() {
            assert!(fabsf(f) < 1e-8, "Fisher should start at 0");
        }
    }

    #[test]
    fn test_commit_task_sets_prior() {
        let mut ewc = EwcLifelong::new();
        ewc.stable_frames = STABLE_FRAMES_THRESHOLD;
        ewc.commit_task();
        assert!(ewc.has_prior_task());
        assert_eq!(ewc.task_count(), 1);
    }

    #[test]
    fn test_ewc_penalty_nonzero_after_drift() {
        let mut ewc = EwcLifelong::new();

        // Set up a prior task with nonzero Fisher.
        ewc.fisher = [0.1; N_PARAMS];
        ewc.theta_star = [0.0; N_PARAMS];
        ewc.has_prior = true;

        // Shift parameters away from theta_star.
        for i in 0..N_PARAMS {
            ewc.params[i] = 0.5;
        }

        let penalty = ewc.compute_ewc_penalty();
        // Expected: (1000/2) * 32 * 0.1 * 0.25 = 400.0
        assert!(penalty > 100.0,
            "EWC penalty should be large when params drift, got {}", penalty);
    }

    #[test]
    fn test_predict_deterministic() {
        let ewc = EwcLifelong::new();
        let features = [0.5f32; N_INPUT];
        let p1 = ewc.predict(&features);
        let p2 = ewc.predict(&features);
        assert_eq!(p1, p2, "predict should be deterministic");
    }

    #[test]
    fn test_reset() {
        let mut ewc = EwcLifelong::new();
        let features = [1.0f32; N_INPUT];
        for _ in 0..50 {
            ewc.process_frame(&features, 0);
        }
        assert!(ewc.frame_count() > 0);
        ewc.reset();
        assert_eq!(ewc.frame_count(), 0);
        assert_eq!(ewc.task_count(), 0);
        assert!(!ewc.has_prior_task());
    }

    #[test]
    fn test_max_tasks_cap() {
        let mut ewc = EwcLifelong::new();
        ewc.task_count = MAX_TASKS;
        ewc.stable_frames = STABLE_FRAMES_THRESHOLD;
        let features = [1.0f32; N_INPUT];
        let events = ewc.process_frame(&features, 0);
        let new_task_events = events.iter()
            .filter(|e| e.0 == EVENT_NEW_TASK_LEARNED)
            .count();
        assert_eq!(new_task_events, 0,
            "should not learn new task when at MAX_TASKS");
    }
}