diff --git a/Cargo.lock b/Cargo.lock index 938720c6..0460cdeb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6672,11 +6672,7 @@ dependencies = [ [[package]] name = "ruvector-postgres" -<<<<<<< HEAD version = "0.2.6" -======= -version = "0.2.5" ->>>>>>> origin/main dependencies = [ "approx", "bincode 1.3.3", diff --git a/crates/ruvector-gnn/examples/loss_demo.rs b/crates/ruvector-gnn/examples/loss_demo.rs new file mode 100644 index 00000000..1efe2877 --- /dev/null +++ b/crates/ruvector-gnn/examples/loss_demo.rs @@ -0,0 +1,121 @@ +//! Manual test/demo for Loss functions +//! +//! Run with: cargo run -p ruvector-gnn --example loss_demo + +use ndarray::Array2; +use ruvector_gnn::training::{Loss, LossType, Optimizer, OptimizerType}; + +fn main() { + println!("=== RuVector GNN Loss Functions Demo ===\n"); + + // 1. Basic MSE Loss + println!("1. MSE Loss Demo"); + println!(" -----------------"); + let predictions = Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.7, 0.8, 0.1, 0.1]).unwrap(); + let targets = Array2::from_shape_vec((2, 3), vec![0.0, 0.0, 1.0, 1.0, 0.0, 0.0]).unwrap(); + + let mse_loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap(); + let mse_grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap(); + + println!(" Predictions: {:?}", predictions.as_slice().unwrap()); + println!(" Targets: {:?}", targets.as_slice().unwrap()); + println!(" MSE Loss: {:.6}", mse_loss); + println!(" Gradient: {:?}\n", mse_grad.as_slice().unwrap()); + + // 2. Binary Cross Entropy Loss + println!("2. Binary Cross Entropy Demo"); + println!(" --------------------------"); + let pred_bce = Array2::from_shape_vec((1, 4), vec![0.9, 0.1, 0.8, 0.3]).unwrap(); + let target_bce = Array2::from_shape_vec((1, 4), vec![1.0, 0.0, 1.0, 0.0]).unwrap(); + + let bce_loss = Loss::compute(LossType::BinaryCrossEntropy, &pred_bce, &target_bce).unwrap(); + let bce_grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred_bce, &target_bce).unwrap(); + + println!(" Predictions: {:?}", pred_bce.as_slice().unwrap()); + println!(" Targets: {:?}", target_bce.as_slice().unwrap()); + println!(" BCE Loss: {:.6}", bce_loss); + println!(" Gradient: {:?}\n", bce_grad.as_slice().unwrap()); + + // 3. Cross Entropy Loss (multi-class) + println!("3. Cross Entropy Demo (multi-class)"); + println!(" ----------------------------------"); + // Softmax-like predictions (each row sums to ~1) + let pred_ce = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.1, 0.8]).unwrap(); + let target_ce = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0]).unwrap(); + + let ce_loss = Loss::compute(LossType::CrossEntropy, &pred_ce, &target_ce).unwrap(); + let ce_grad = Loss::gradient(LossType::CrossEntropy, &pred_ce, &target_ce).unwrap(); + + println!(" Predictions (row 1): {:?}", &pred_ce.as_slice().unwrap()[0..3]); + println!(" Predictions (row 2): {:?}", &pred_ce.as_slice().unwrap()[3..6]); + println!(" Targets (one-hot): [1,0,0] and [0,0,1]"); + println!(" CE Loss: {:.6}", ce_loss); + println!(" Gradient: {:?}\n", ce_grad.as_slice().unwrap()); + + // 4. Training loop demo - minimize MSE + println!("4. Training Loop Demo (minimizing MSE)"); + println!(" ------------------------------------"); + + let target = Array2::from_shape_vec((1, 4), vec![1.0, 0.0, 1.0, 0.0]).unwrap(); + let mut pred = Array2::from_shape_vec((1, 4), vec![0.5, 0.5, 0.5, 0.5]).unwrap(); + + let mut optimizer = Optimizer::new(OptimizerType::Adam { + learning_rate: 0.1, + beta1: 0.9, + beta2: 0.999, + epsilon: 1e-8, + }); + + println!(" Target: {:?}", target.as_slice().unwrap()); + println!(" Initial: {:?}", pred.as_slice().unwrap()); + + let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + println!(" Initial loss: {:.6}\n", initial_loss); + + for epoch in 0..20 { + let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap(); + optimizer.step(&mut pred, &grad).unwrap(); + + if epoch % 5 == 0 || epoch == 19 { + println!( + " Epoch {:2}: loss={:.6}, pred={:?}", + epoch, + loss, + pred.as_slice() + .unwrap() + .iter() + .map(|x| format!("{:.3}", x)) + .collect::>() + ); + } + } + + let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + println!("\n Final loss: {:.6}", final_loss); + println!( + " Improvement: {:.1}%", + (1.0 - final_loss / initial_loss) * 100.0 + ); + + // 5. Numerical stability test + println!("\n5. Numerical Stability Test"); + println!(" -------------------------"); + + // Test with extreme values + let extreme_pred = Array2::from_shape_vec((1, 2), vec![1e-10, 1.0 - 1e-10]).unwrap(); + let extreme_target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap(); + + let bce_extreme = Loss::compute(LossType::BinaryCrossEntropy, &extreme_pred, &extreme_target); + let ce_extreme = Loss::compute(LossType::CrossEntropy, &extreme_pred, &extreme_target); + + println!(" Extreme predictions: [{:.2e}, {:.2e}]", 1e-10, 1.0 - 1e-10); + println!(" BCE result: {:?}", bce_extreme); + println!(" CE result: {:?}", ce_extreme); + + // Test gradient stability + let grad_extreme = Loss::gradient(LossType::BinaryCrossEntropy, &extreme_pred, &extreme_target); + println!(" BCE gradient: {:?}", grad_extreme); + + println!("\n=== Demo Complete ==="); +} diff --git a/crates/ruvector-gnn/src/lib.rs b/crates/ruvector-gnn/src/lib.rs index 74323087..e100ffbc 100644 --- a/crates/ruvector-gnn/src/lib.rs +++ b/crates/ruvector-gnn/src/lib.rs @@ -70,8 +70,8 @@ pub use replay::{DistributionStats, ReplayBuffer, ReplayEntry}; pub use scheduler::{LearningRateScheduler, SchedulerType}; pub use search::{cosine_similarity, differentiable_search, hierarchical_forward}; pub use training::{ - info_nce_loss, local_contrastive_loss, sgd_step, OnlineConfig, Optimizer, OptimizerType, - TrainConfig, + info_nce_loss, local_contrastive_loss, sgd_step, Loss, LossType, OnlineConfig, Optimizer, + OptimizerType, TrainConfig, }; #[cfg(all(not(target_arch = "wasm32"), feature = "mmap"))] diff --git a/crates/ruvector-gnn/src/training.rs b/crates/ruvector-gnn/src/training.rs index 4b415171..38827e04 100644 --- a/crates/ruvector-gnn/src/training.rs +++ b/crates/ruvector-gnn/src/training.rs @@ -227,20 +227,45 @@ pub enum LossType { BinaryCrossEntropy, } -/// TODO: Implement loss functions +/// Loss function implementations for neural network training. +/// +/// Provides forward (loss computation) and backward (gradient computation) passes +/// for common loss functions used in GNN training. +/// +/// # Numerical Stability +/// +/// All loss functions use epsilon clamping and gradient clipping to prevent +/// numerical instability with extreme prediction values (near 0 or 1). pub struct Loss; impl Loss { - /// Compute loss value + /// Small epsilon value for numerical stability in logarithms and divisions. + const EPS: f32 = 1e-7; + + /// Maximum absolute gradient value to prevent explosion. + const MAX_GRAD: f32 = 1e6; + + /// Compute the loss value between predictions and targets. /// /// # Arguments - /// * `loss_type` - Type of loss function to use - /// * `predictions` - Model predictions (must match targets shape) - /// * `targets` - Ground truth targets + /// * `loss_type` - The type of loss function to use + /// * `predictions` - Model predictions as a 2D array + /// * `targets` - Ground truth targets as a 2D array (same shape as predictions) /// /// # Returns - /// * `Ok(f32)` - Computed loss value - /// * `Err(GnnError)` - If shapes don't match + /// * `Ok(f32)` - The computed scalar loss value + /// * `Err(GnnError)` - If shapes don't match or computation fails + /// + /// # Example + /// ``` + /// use ndarray::Array2; + /// use ruvector_gnn::training::{Loss, LossType}; + /// + /// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap(); + /// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap(); + /// let loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap(); + /// assert!(loss >= 0.0); + /// ``` pub fn compute( loss_type: LossType, predictions: &Array2, @@ -254,47 +279,38 @@ impl Loss { )); } - const EPSILON: f32 = 1e-7; + if predictions.is_empty() { + return Err(GnnError::invalid_input("Cannot compute loss on empty arrays")); + } - let loss = match loss_type { - LossType::Mse => { - // Mean Squared Error: mean((pred - target)^2) - let diff = predictions - targets; - let squared_diff = diff.mapv(|x| x * x); - squared_diff.mean().unwrap_or(0.0) - } - LossType::CrossEntropy => { - // Cross Entropy: -sum(target * log(pred + epsilon)) - let log_preds = predictions.mapv(|x| (x + EPSILON).ln()); - let product = targets * log_preds; - -product.sum() - } - LossType::BinaryCrossEntropy => { - // Binary Cross Entropy: -mean(target * log(pred) + (1-target) * log(1-pred)) - let term1 = targets.iter().zip(predictions.iter()).map(|(&t, &p)| { - t * (p + EPSILON).ln() - }); - let term2 = targets.iter().zip(predictions.iter()).map(|(&t, &p)| { - (1.0 - t) * (1.0 - p + EPSILON).ln() - }); - let sum: f32 = term1.zip(term2).map(|(t1, t2)| t1 + t2).sum(); - -sum / (predictions.len() as f32) - } - }; - - Ok(loss) + match loss_type { + LossType::Mse => Self::mse_forward(predictions, targets), + LossType::CrossEntropy => Self::cross_entropy_forward(predictions, targets), + LossType::BinaryCrossEntropy => Self::bce_forward(predictions, targets), + } } - /// Compute loss gradient + /// Compute the gradient of the loss with respect to predictions. /// /// # Arguments - /// * `loss_type` - Type of loss function to use - /// * `predictions` - Model predictions (must match targets shape) - /// * `targets` - Ground truth targets + /// * `loss_type` - The type of loss function to use + /// * `predictions` - Model predictions as a 2D array + /// * `targets` - Ground truth targets as a 2D array (same shape as predictions) /// /// # Returns - /// * `Ok(Array2)` - Gradient of loss with respect to predictions - /// * `Err(GnnError)` - If shapes don't match + /// * `Ok(Array2)` - Gradient array with same shape as predictions + /// * `Err(GnnError)` - If shapes don't match or computation fails + /// + /// # Example + /// ``` + /// use ndarray::Array2; + /// use ruvector_gnn::training::{Loss, LossType}; + /// + /// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap(); + /// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap(); + /// let grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap(); + /// assert_eq!(grad.shape(), predictions.shape()); + /// ``` pub fn gradient( loss_type: LossType, predictions: &Array2, @@ -308,37 +324,95 @@ impl Loss { )); } - const EPSILON: f32 = 1e-7; + if predictions.is_empty() { + return Err(GnnError::invalid_input( + "Cannot compute gradient on empty arrays", + )); + } + + match loss_type { + LossType::Mse => Self::mse_backward(predictions, targets), + LossType::CrossEntropy => Self::cross_entropy_backward(predictions, targets), + LossType::BinaryCrossEntropy => Self::bce_backward(predictions, targets), + } + } + + /// Mean Squared Error: MSE = mean((predictions - targets)^2) + fn mse_forward(predictions: &Array2, targets: &Array2) -> Result { + let diff = predictions - targets; + let squared = diff.mapv(|x| x * x); + Ok(squared.mean().unwrap_or(0.0)) + } + + /// MSE gradient: d(MSE)/d(pred) = 2 * (predictions - targets) / n + fn mse_backward(predictions: &Array2, targets: &Array2) -> Result> { let n = predictions.len() as f32; + let diff = predictions - targets; + Ok(diff.mapv(|x| 2.0 * x / n)) + } - let gradient = match loss_type { - LossType::Mse => { - // MSE gradient: 2 * (pred - target) / n - let diff = predictions - targets; - diff.mapv(|x| 2.0 * x / n) - } - LossType::CrossEntropy => { - // Cross Entropy gradient: -target / (pred + epsilon) - let mut grad = Array2::zeros(predictions.dim()); - for (i, (&t, &p)) in targets.iter().zip(predictions.iter()).enumerate() { - let (row, col) = (i / predictions.ncols(), i % predictions.ncols()); - grad[[row, col]] = -t / (p + EPSILON); - } - grad - } - LossType::BinaryCrossEntropy => { - // Binary Cross Entropy gradient: (pred - target) / (pred * (1 - pred) + epsilon) - let mut grad = Array2::zeros(predictions.dim()); - for (i, (&t, &p)) in targets.iter().zip(predictions.iter()).enumerate() { - let (row, col) = (i / predictions.ncols(), i % predictions.ncols()); - let denom = p * (1.0 - p) + EPSILON; - grad[[row, col]] = (p - t) / denom; - } - grad - } - }; + /// Cross Entropy: CE = -mean(sum(targets * log(predictions), axis=1)) + /// + /// Used for multi-class classification where targets are one-hot encoded + /// and predictions are softmax probabilities. + fn cross_entropy_forward(predictions: &Array2, targets: &Array2) -> Result { + let log_pred = predictions.mapv(|x| (x.max(Self::EPS)).ln()); + let elementwise = targets * &log_pred; + let loss = -elementwise.sum() / predictions.nrows() as f32; + Ok(loss) + } - Ok(gradient) + /// Cross Entropy gradient: d(CE)/d(pred) = -targets / predictions / n + /// + /// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion. + fn cross_entropy_backward( + predictions: &Array2, + targets: &Array2, + ) -> Result> { + let n = predictions.nrows() as f32; + // Clamp predictions to avoid division by zero + let safe_pred = predictions.mapv(|x| x.max(Self::EPS)); + let grad = targets / &safe_pred; + // Apply gradient clipping + Ok(grad.mapv(|x| (-x / n).clamp(-Self::MAX_GRAD, Self::MAX_GRAD))) + } + + /// Binary Cross Entropy: BCE = -mean(targets * log(pred) + (1 - targets) * log(1 - pred)) + /// + /// Used for binary classification or multi-label classification. + fn bce_forward(predictions: &Array2, targets: &Array2) -> Result { + let n = predictions.len() as f32; + let loss: f32 = predictions + .iter() + .zip(targets.iter()) + .map(|(&p, &t)| { + // Clamp predictions to (eps, 1-eps) for numerical stability + let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS); + -(t * p_safe.ln() + (1.0 - t) * (1.0 - p_safe).ln()) + }) + .sum(); + Ok(loss / n) + } + + /// BCE gradient: d(BCE)/d(pred) = (-targets/pred + (1-targets)/(1-pred)) / n + /// + /// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion. + fn bce_backward(predictions: &Array2, targets: &Array2) -> Result> { + let n = predictions.len() as f32; + let grad_vec: Vec = predictions + .iter() + .zip(targets.iter()) + .map(|(&p, &t)| { + // Clamp predictions for numerical stability + let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS); + let grad = (-t / p_safe + (1.0 - t) / (1.0 - p_safe)) / n; + // Clip gradient to prevent explosion + grad.clamp(-Self::MAX_GRAD, Self::MAX_GRAD) + }) + .collect(); + + Array2::from_shape_vec(predictions.dim(), grad_vec) + .map_err(|e| GnnError::training(format!("Failed to reshape gradient: {}", e))) } } @@ -1027,4 +1101,206 @@ mod tests { assert!(params[[0, 0]].abs() < 0.5); assert!(params[[0, 1]].abs() < 0.5); } + + // ==================== Loss Function Tests ==================== + + #[test] + fn test_mse_loss_zero_when_equal() { + let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); + let target = pred.clone(); + let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + assert!((loss - 0.0).abs() < 1e-6, "MSE should be 0 when pred == target"); + } + + #[test] + fn test_mse_loss_positive() { + let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); + let target = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 4.0, 5.0]).unwrap(); + let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + // Each element differs by 1, so squared diff = 1, mean = 1 + assert!((loss - 1.0).abs() < 1e-6, "MSE should be 1.0, got {}", loss); + } + + #[test] + fn test_mse_loss_varying_diffs() { + let pred = Array2::from_shape_vec((1, 4), vec![0.0, 0.0, 0.0, 0.0]).unwrap(); + let target = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); + let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + // Squared diffs: 1, 4, 9, 16. Mean = 30/4 = 7.5 + assert!((loss - 7.5).abs() < 1e-6, "MSE should be 7.5, got {}", loss); + } + + #[test] + fn test_mse_gradient_shape() { + let pred = Array2::from_shape_vec((2, 3), vec![0.0; 6]).unwrap(); + let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); + let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap(); + assert_eq!(grad.shape(), pred.shape()); + } + + #[test] + fn test_mse_gradient_direction() { + let pred = Array2::from_shape_vec((1, 2), vec![0.0, 2.0]).unwrap(); + let target = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap(); + let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap(); + // grad = 2*(pred - target)/n = 2*(-1, 1)/2 = (-1, 1) + assert!(grad[[0, 0]] < 0.0, "Gradient should be negative when pred < target"); + assert!(grad[[0, 1]] > 0.0, "Gradient should be positive when pred > target"); + } + + #[test] + fn test_mse_gradient_zero_when_equal() { + let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap(); + let target = pred.clone(); + let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap(); + assert!(grad.iter().all(|&x| x.abs() < 1e-6), "Gradient should be zero when pred == target"); + } + + #[test] + fn test_bce_loss_perfect_predictions() { + let pred = Array2::from_shape_vec((1, 2), vec![0.999, 0.001]).unwrap(); + let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap(); + let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap(); + // Near-perfect predictions should have low loss + assert!(loss < 0.1, "BCE should be low for good predictions, got {}", loss); + } + + #[test] + fn test_bce_loss_bad_predictions() { + let pred = Array2::from_shape_vec((1, 2), vec![0.001, 0.999]).unwrap(); + let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap(); + let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap(); + // Bad predictions should have high loss + assert!(loss > 1.0, "BCE should be high for bad predictions, got {}", loss); + } + + #[test] + fn test_bce_loss_numerical_stability() { + // Test with extreme values that could cause numerical issues + let pred = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap(); + let target = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap(); + let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap(); + assert!(loss.is_finite(), "BCE should be finite even with extreme values"); + } + + #[test] + fn test_bce_gradient_shape() { + let pred = Array2::from_shape_vec((3, 2), vec![0.5; 6]).unwrap(); + let target = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap(); + let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap(); + assert_eq!(grad.shape(), pred.shape()); + } + + #[test] + fn test_bce_gradient_direction() { + let pred = Array2::from_shape_vec((1, 2), vec![0.3, 0.7]).unwrap(); + let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap(); + let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap(); + // When target=1 and pred<1, gradient should push pred up (negative gradient) + assert!(grad[[0, 0]] < 0.0, "Gradient should be negative to increase pred towards 1"); + // When target=0 and pred>0, gradient should push pred down (positive gradient) + assert!(grad[[0, 1]] > 0.0, "Gradient should be positive to decrease pred towards 0"); + } + + #[test] + fn test_cross_entropy_one_hot() { + // Softmax-like predictions (sum to 1 per row) + let pred = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.8, 0.1]).unwrap(); + let target = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap(); + let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap(); + // Good predictions should have reasonable loss + assert!(loss > 0.0 && loss < 1.0, "CE should be reasonable for good predictions, got {}", loss); + } + + #[test] + fn test_cross_entropy_wrong_class() { + let pred = Array2::from_shape_vec((1, 3), vec![0.1, 0.1, 0.8]).unwrap(); + let target = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap(); + let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap(); + // Predicting wrong class should have high loss + assert!(loss > 1.0, "CE should be high for wrong predictions, got {}", loss); + } + + #[test] + fn test_cross_entropy_gradient_shape() { + let pred = Array2::from_shape_vec((2, 4), vec![0.25; 8]).unwrap(); + let target = Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap(); + let grad = Loss::gradient(LossType::CrossEntropy, &pred, &target).unwrap(); + assert_eq!(grad.shape(), pred.shape()); + } + + #[test] + fn test_loss_dimension_mismatch_error() { + let pred = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap(); + let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); + + let result = Loss::compute(LossType::Mse, &pred, &target); + assert!(result.is_err(), "Should error on dimension mismatch"); + + let result = Loss::gradient(LossType::Mse, &pred, &target); + assert!(result.is_err(), "Gradient should error on dimension mismatch"); + } + + #[test] + fn test_loss_empty_array_error() { + let pred = Array2::from_shape_vec((0, 2), vec![]).unwrap(); + let target = Array2::from_shape_vec((0, 2), vec![]).unwrap(); + + let result = Loss::compute(LossType::Mse, &pred, &target); + assert!(result.is_err(), "Should error on empty arrays"); + + let result = Loss::gradient(LossType::Mse, &pred, &target); + assert!(result.is_err(), "Gradient should error on empty arrays"); + } + + #[test] + fn test_loss_gradient_numerical_check() { + // Numerical gradient check for MSE + let pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.8]).unwrap(); + let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap(); + + let analytical_grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap(); + + // Compute numerical gradient + let eps = 1e-5; + for i in 0..2 { + let mut pred_plus = pred.clone(); + let mut pred_minus = pred.clone(); + pred_plus[[0, i]] += eps; + pred_minus[[0, i]] -= eps; + + let loss_plus = Loss::compute(LossType::Mse, &pred_plus, &target).unwrap(); + let loss_minus = Loss::compute(LossType::Mse, &pred_minus, &target).unwrap(); + + let numerical_grad = (loss_plus - loss_minus) / (2.0 * eps); + let error = (analytical_grad[[0, i]] - numerical_grad).abs(); + + assert!(error < 1e-3, "Numerical gradient check failed: analytical={}, numerical={}", + analytical_grad[[0, i]], numerical_grad); + } + } + + #[test] + fn test_training_loop_integration() { + // Integration test: use Loss with Optimizer + let mut optimizer = Optimizer::new(OptimizerType::Sgd { + learning_rate: 0.1, + momentum: 0.0, + }); + + let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap(); + let mut pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap(); + + let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + + // Perform a few optimization steps + for _ in 0..10 { + let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap(); + optimizer.step(&mut pred, &grad).unwrap(); + } + + let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap(); + + assert!(final_loss < initial_loss, "Loss should decrease during training"); + } }