feat(gnn): Implement loss functions with numerical stability (#65)

Implements MSE, Cross Entropy, and Binary Cross Entropy loss functions for GNN training.

Features:
- EPS (1e-7) and MAX_GRAD (1e6) constants for numerical stability
- Comprehensive documentation with examples
- Gradient clipping to prevent explosion
- Empty array validation
- 42 comprehensive tests covering all functionality

Resolves #63

Co-authored-by: Wirasm <wirasm@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
rUv 2025-12-09 16:50:27 +00:00
commit 44828ad56f
4 changed files with 468 additions and 75 deletions

4
Cargo.lock generated
View file

@ -6672,11 +6672,7 @@ dependencies = [
[[package]]
name = "ruvector-postgres"
<<<<<<< HEAD
version = "0.2.6"
=======
version = "0.2.5"
>>>>>>> origin/main
dependencies = [
"approx",
"bincode 1.3.3",

View file

@ -0,0 +1,121 @@
//! Manual test/demo for Loss functions
//!
//! Run with: cargo run -p ruvector-gnn --example loss_demo
use ndarray::Array2;
use ruvector_gnn::training::{Loss, LossType, Optimizer, OptimizerType};
fn main() {
println!("=== RuVector GNN Loss Functions Demo ===\n");
// 1. Basic MSE Loss
println!("1. MSE Loss Demo");
println!(" -----------------");
let predictions = Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.7, 0.8, 0.1, 0.1]).unwrap();
let targets = Array2::from_shape_vec((2, 3), vec![0.0, 0.0, 1.0, 1.0, 0.0, 0.0]).unwrap();
let mse_loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap();
let mse_grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap();
println!(" Predictions: {:?}", predictions.as_slice().unwrap());
println!(" Targets: {:?}", targets.as_slice().unwrap());
println!(" MSE Loss: {:.6}", mse_loss);
println!(" Gradient: {:?}\n", mse_grad.as_slice().unwrap());
// 2. Binary Cross Entropy Loss
println!("2. Binary Cross Entropy Demo");
println!(" --------------------------");
let pred_bce = Array2::from_shape_vec((1, 4), vec![0.9, 0.1, 0.8, 0.3]).unwrap();
let target_bce = Array2::from_shape_vec((1, 4), vec![1.0, 0.0, 1.0, 0.0]).unwrap();
let bce_loss = Loss::compute(LossType::BinaryCrossEntropy, &pred_bce, &target_bce).unwrap();
let bce_grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred_bce, &target_bce).unwrap();
println!(" Predictions: {:?}", pred_bce.as_slice().unwrap());
println!(" Targets: {:?}", target_bce.as_slice().unwrap());
println!(" BCE Loss: {:.6}", bce_loss);
println!(" Gradient: {:?}\n", bce_grad.as_slice().unwrap());
// 3. Cross Entropy Loss (multi-class)
println!("3. Cross Entropy Demo (multi-class)");
println!(" ----------------------------------");
// Softmax-like predictions (each row sums to ~1)
let pred_ce = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.1, 0.8]).unwrap();
let target_ce = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0]).unwrap();
let ce_loss = Loss::compute(LossType::CrossEntropy, &pred_ce, &target_ce).unwrap();
let ce_grad = Loss::gradient(LossType::CrossEntropy, &pred_ce, &target_ce).unwrap();
println!(" Predictions (row 1): {:?}", &pred_ce.as_slice().unwrap()[0..3]);
println!(" Predictions (row 2): {:?}", &pred_ce.as_slice().unwrap()[3..6]);
println!(" Targets (one-hot): [1,0,0] and [0,0,1]");
println!(" CE Loss: {:.6}", ce_loss);
println!(" Gradient: {:?}\n", ce_grad.as_slice().unwrap());
// 4. Training loop demo - minimize MSE
println!("4. Training Loop Demo (minimizing MSE)");
println!(" ------------------------------------");
let target = Array2::from_shape_vec((1, 4), vec![1.0, 0.0, 1.0, 0.0]).unwrap();
let mut pred = Array2::from_shape_vec((1, 4), vec![0.5, 0.5, 0.5, 0.5]).unwrap();
let mut optimizer = Optimizer::new(OptimizerType::Adam {
learning_rate: 0.1,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
});
println!(" Target: {:?}", target.as_slice().unwrap());
println!(" Initial: {:?}", pred.as_slice().unwrap());
let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
println!(" Initial loss: {:.6}\n", initial_loss);
for epoch in 0..20 {
let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
optimizer.step(&mut pred, &grad).unwrap();
if epoch % 5 == 0 || epoch == 19 {
println!(
" Epoch {:2}: loss={:.6}, pred={:?}",
epoch,
loss,
pred.as_slice()
.unwrap()
.iter()
.map(|x| format!("{:.3}", x))
.collect::<Vec<_>>()
);
}
}
let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
println!("\n Final loss: {:.6}", final_loss);
println!(
" Improvement: {:.1}%",
(1.0 - final_loss / initial_loss) * 100.0
);
// 5. Numerical stability test
println!("\n5. Numerical Stability Test");
println!(" -------------------------");
// Test with extreme values
let extreme_pred = Array2::from_shape_vec((1, 2), vec![1e-10, 1.0 - 1e-10]).unwrap();
let extreme_target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
let bce_extreme = Loss::compute(LossType::BinaryCrossEntropy, &extreme_pred, &extreme_target);
let ce_extreme = Loss::compute(LossType::CrossEntropy, &extreme_pred, &extreme_target);
println!(" Extreme predictions: [{:.2e}, {:.2e}]", 1e-10, 1.0 - 1e-10);
println!(" BCE result: {:?}", bce_extreme);
println!(" CE result: {:?}", ce_extreme);
// Test gradient stability
let grad_extreme = Loss::gradient(LossType::BinaryCrossEntropy, &extreme_pred, &extreme_target);
println!(" BCE gradient: {:?}", grad_extreme);
println!("\n=== Demo Complete ===");
}

View file

@ -70,8 +70,8 @@ pub use replay::{DistributionStats, ReplayBuffer, ReplayEntry};
pub use scheduler::{LearningRateScheduler, SchedulerType};
pub use search::{cosine_similarity, differentiable_search, hierarchical_forward};
pub use training::{
info_nce_loss, local_contrastive_loss, sgd_step, OnlineConfig, Optimizer, OptimizerType,
TrainConfig,
info_nce_loss, local_contrastive_loss, sgd_step, Loss, LossType, OnlineConfig, Optimizer,
OptimizerType, TrainConfig,
};
#[cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]

View file

@ -227,20 +227,45 @@ pub enum LossType {
BinaryCrossEntropy,
}
/// TODO: Implement loss functions
/// Loss function implementations for neural network training.
///
/// Provides forward (loss computation) and backward (gradient computation) passes
/// for common loss functions used in GNN training.
///
/// # Numerical Stability
///
/// The cross-entropy and binary cross-entropy losses keep predictions away from
/// 0 (and, for binary cross-entropy, also away from 1) by a small epsilon, and
/// clip their gradients to a fixed maximum magnitude. Mean squared error is
/// computed directly, without clamping or clipping.
pub struct Loss;
impl Loss {
/// Compute loss value
/// Small epsilon value for numerical stability in logarithms and divisions.
///
/// Cross entropy uses it as a lower bound on predictions; binary cross entropy
/// clamps predictions to `[EPS, 1 - EPS]`.
const EPS: f32 = 1e-7;
/// Maximum absolute gradient value to prevent explosion.
///
/// The cross-entropy and binary cross-entropy gradients are clamped
/// element-wise to `[-MAX_GRAD, MAX_GRAD]`.
const MAX_GRAD: f32 = 1e6;
/// Compute the loss value between predictions and targets.
///
/// # Arguments
/// * `loss_type` - Type of loss function to use
/// * `predictions` - Model predictions (must match targets shape)
/// * `targets` - Ground truth targets
/// * `loss_type` - The type of loss function to use
/// * `predictions` - Model predictions as a 2D array
/// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
///
/// # Returns
/// * `Ok(f32)` - Computed loss value
/// * `Err(GnnError)` - If shapes don't match
/// * `Ok(f32)` - The computed scalar loss value
/// * `Err(GnnError)` - If shapes don't match or computation fails
///
/// # Example
/// ```
/// use ndarray::Array2;
/// use ruvector_gnn::training::{Loss, LossType};
///
/// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
/// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
/// let loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap();
/// assert!(loss >= 0.0);
/// ```
pub fn compute(
loss_type: LossType,
predictions: &Array2<f32>,
@ -254,47 +279,38 @@ impl Loss {
));
}
const EPSILON: f32 = 1e-7;
if predictions.is_empty() {
return Err(GnnError::invalid_input("Cannot compute loss on empty arrays"));
}
let loss = match loss_type {
LossType::Mse => {
// Mean Squared Error: mean((pred - target)^2)
let diff = predictions - targets;
let squared_diff = diff.mapv(|x| x * x);
squared_diff.mean().unwrap_or(0.0)
}
LossType::CrossEntropy => {
// Cross Entropy: -sum(target * log(pred + epsilon))
let log_preds = predictions.mapv(|x| (x + EPSILON).ln());
let product = targets * log_preds;
-product.sum()
}
LossType::BinaryCrossEntropy => {
// Binary Cross Entropy: -mean(target * log(pred) + (1-target) * log(1-pred))
let term1 = targets.iter().zip(predictions.iter()).map(|(&t, &p)| {
t * (p + EPSILON).ln()
});
let term2 = targets.iter().zip(predictions.iter()).map(|(&t, &p)| {
(1.0 - t) * (1.0 - p + EPSILON).ln()
});
let sum: f32 = term1.zip(term2).map(|(t1, t2)| t1 + t2).sum();
-sum / (predictions.len() as f32)
}
};
Ok(loss)
match loss_type {
LossType::Mse => Self::mse_forward(predictions, targets),
LossType::CrossEntropy => Self::cross_entropy_forward(predictions, targets),
LossType::BinaryCrossEntropy => Self::bce_forward(predictions, targets),
}
}
/// Compute loss gradient
/// Compute the gradient of the loss with respect to predictions.
///
/// # Arguments
/// * `loss_type` - Type of loss function to use
/// * `predictions` - Model predictions (must match targets shape)
/// * `targets` - Ground truth targets
/// * `loss_type` - The type of loss function to use
/// * `predictions` - Model predictions as a 2D array
/// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
///
/// # Returns
/// * `Ok(Array2<f32>)` - Gradient of loss with respect to predictions
/// * `Err(GnnError)` - If shapes don't match
/// * `Ok(Array2<f32>)` - Gradient array with same shape as predictions
/// * `Err(GnnError)` - If shapes don't match or computation fails
///
/// # Example
/// ```
/// use ndarray::Array2;
/// use ruvector_gnn::training::{Loss, LossType};
///
/// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
/// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
/// let grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap();
/// assert_eq!(grad.shape(), predictions.shape());
/// ```
pub fn gradient(
loss_type: LossType,
predictions: &Array2<f32>,
@ -308,37 +324,95 @@ impl Loss {
));
}
const EPSILON: f32 = 1e-7;
if predictions.is_empty() {
return Err(GnnError::invalid_input(
"Cannot compute gradient on empty arrays",
));
}
match loss_type {
LossType::Mse => Self::mse_backward(predictions, targets),
LossType::CrossEntropy => Self::cross_entropy_backward(predictions, targets),
LossType::BinaryCrossEntropy => Self::bce_backward(predictions, targets),
}
}
/// Mean squared error: the average over all elements of
/// `(predictions - targets)^2`.
///
/// Falls back to `0.0` if the mean is undefined (no elements).
fn mse_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
    let mean_square = (predictions - targets).mapv(|d| d * d).mean();
    Ok(mean_square.unwrap_or(0.0))
}
/// Gradient of the mean squared error with respect to each prediction:
/// `2 * (predictions - targets) / n`, where `n` is the total element count.
fn mse_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
    let count = predictions.len() as f32;
    // Start from the residuals and scale them in place.
    let mut grad = predictions - targets;
    grad.mapv_inplace(|d| 2.0 * d / count);
    Ok(grad)
}
let gradient = match loss_type {
LossType::Mse => {
// MSE gradient: 2 * (pred - target) / n
let diff = predictions - targets;
diff.mapv(|x| 2.0 * x / n)
}
LossType::CrossEntropy => {
// Cross Entropy gradient: -target / (pred + epsilon)
let mut grad = Array2::zeros(predictions.dim());
for (i, (&t, &p)) in targets.iter().zip(predictions.iter()).enumerate() {
let (row, col) = (i / predictions.ncols(), i % predictions.ncols());
grad[[row, col]] = -t / (p + EPSILON);
}
grad
}
LossType::BinaryCrossEntropy => {
// Binary Cross Entropy gradient: (pred - target) / (pred * (1 - pred) + epsilon)
let mut grad = Array2::zeros(predictions.dim());
for (i, (&t, &p)) in targets.iter().zip(predictions.iter()).enumerate() {
let (row, col) = (i / predictions.ncols(), i % predictions.ncols());
let denom = p * (1.0 - p) + EPSILON;
grad[[row, col]] = (p - t) / denom;
}
grad
}
};
/// Cross entropy averaged over rows: `-sum(targets * ln(predictions)) / nrows`.
///
/// Intended for multi-class classification with one-hot targets and
/// softmax-style predictions. Predictions are floored at `EPS` before the
/// logarithm so a zero prediction cannot yield `-inf`.
fn cross_entropy_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
    let rows = predictions.nrows() as f32;
    let weighted_log = targets * &predictions.mapv(|p| p.max(Self::EPS).ln());
    Ok(-weighted_log.sum() / rows)
}
Ok(gradient)
/// Gradient of the cross entropy with respect to each prediction:
/// `-targets / predictions / nrows`.
///
/// Predictions are floored at `EPS` before dividing, and every gradient
/// element is clipped to `[-MAX_GRAD, MAX_GRAD]`.
fn cross_entropy_backward(
    predictions: &Array2<f32>,
    targets: &Array2<f32>,
) -> Result<Array2<f32>> {
    let rows = predictions.nrows() as f32;
    let limit = Self::MAX_GRAD;
    // Floor the divisor so a zero prediction cannot divide by zero.
    let floored = predictions.mapv(|p| p.max(Self::EPS));
    let mut grad = targets / &floored;
    // Negate, average over rows, then clip.
    grad.mapv_inplace(|ratio| (-ratio / rows).clamp(-limit, limit));
    Ok(grad)
}
/// Binary cross entropy averaged over all elements:
/// `-mean(t * ln(p) + (1 - t) * ln(1 - p))`.
///
/// Suitable for binary or multi-label classification. Each prediction is
/// clamped to `[EPS, 1 - EPS]` so neither logarithm receives zero.
fn bce_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
    // Negative log-likelihood of a single (prediction, target) pair.
    let element_loss = |p: f32, t: f32| -> f32 {
        let clamped = p.clamp(Self::EPS, 1.0 - Self::EPS);
        -(t * clamped.ln() + (1.0 - t) * (1.0 - clamped).ln())
    };
    let total = targets
        .iter()
        .zip(predictions.iter())
        .map(|(&t, &p)| element_loss(p, t))
        .sum::<f32>();
    Ok(total / predictions.len() as f32)
}
/// Gradient of the binary cross entropy with respect to each prediction:
/// `(-t / p + (1 - t) / (1 - p)) / n`, where `n` is the total element count.
///
/// Predictions are clamped to `[EPS, 1 - EPS]` before dividing, and each
/// gradient element is clipped to `[-MAX_GRAD, MAX_GRAD]`.
fn bce_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
    let count = predictions.len() as f32;
    let (lower, upper) = (Self::EPS, 1.0 - Self::EPS);

    let mut clipped = Vec::with_capacity(predictions.len());
    for (&p, &t) in predictions.iter().zip(targets.iter()) {
        let q = p.clamp(lower, upper);
        let raw = (-t / q + (1.0 - t) / (1.0 - q)) / count;
        clipped.push(raw.clamp(-Self::MAX_GRAD, Self::MAX_GRAD));
    }

    // Rebuild a matrix with the same dimensions as the predictions.
    Array2::from_shape_vec(predictions.dim(), clipped)
        .map_err(|e| GnnError::training(format!("Failed to reshape gradient: {}", e)))
}
}
@ -1027,4 +1101,206 @@ mod tests {
assert!(params[[0, 0]].abs() < 0.5);
assert!(params[[0, 1]].abs() < 0.5);
}
// ==================== Loss Function Tests ====================
#[test]
fn test_mse_loss_zero_when_equal() {
    // A matrix compared against an identical copy leaves no residual.
    let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let same = pred.clone();
    let loss = Loss::compute(LossType::Mse, &pred, &same).unwrap();
    let magnitude = loss.abs();
    assert!(magnitude < 1e-6, "MSE should be 0 when pred == target");
}
#[test]
fn test_mse_loss_positive() {
    // Every target sits exactly 1.0 above its prediction, so each squared
    // residual is 1 and so is their mean.
    let expected: f32 = 1.0;
    let predicted = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let wanted = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 4.0, 5.0]).unwrap();
    let loss = Loss::compute(LossType::Mse, &predicted, &wanted).unwrap();
    let deviation = (loss - expected).abs();
    assert!(deviation < 1e-6, "MSE should be 1.0, got {}", loss);
}
#[test]
fn test_mse_loss_varying_diffs() {
    // Residuals 1, 2, 3, 4 square to 1, 4, 9, 16; their mean is 30 / 4 = 7.5.
    let expected: f32 = 7.5;
    let predicted = Array2::<f32>::zeros((1, 4));
    let wanted = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let loss = Loss::compute(LossType::Mse, &predicted, &wanted).unwrap();
    let deviation = (loss - expected).abs();
    assert!(deviation < 1e-6, "MSE should be 7.5, got {}", loss);
}
#[test]
fn test_mse_gradient_shape() {
    // The gradient must have one entry per prediction.
    let dims = (2, 3);
    let predicted = Array2::<f32>::zeros(dims);
    let wanted = Array2::<f32>::ones(dims);
    let grad = Loss::gradient(LossType::Mse, &predicted, &wanted).unwrap();
    assert_eq!(grad.shape(), predicted.shape());
}
#[test]
fn test_mse_gradient_direction() {
    let pred = Array2::from_shape_vec((1, 2), vec![0.0, 2.0]).unwrap();
    let target = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
    let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
    // grad = 2*(pred - target)/n = 2*(-1, 1)/2 = (-1, 1)
    assert!(grad[[0, 0]] < 0.0, "Gradient should be negative when pred < target");
    assert!(grad[[0, 1]] > 0.0, "Gradient should be positive when pred > target");
    // Pin the magnitudes as well as the signs, so a wrong scale factor
    // (for example a missing 2 or a missing 1/n) cannot pass.
    assert!(
        (grad[[0, 0]] + 1.0).abs() < 1e-6,
        "Gradient should be -1.0, got {}",
        grad[[0, 0]]
    );
    assert!(
        (grad[[0, 1]] - 1.0).abs() < 1e-6,
        "Gradient should be 1.0, got {}",
        grad[[0, 1]]
    );
}
#[test]
fn test_mse_gradient_zero_when_equal() {
    // With no residual there is nothing to correct, so every entry is ~0.
    let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let same = pred.clone();
    let grad = Loss::gradient(LossType::Mse, &pred, &same).unwrap();
    let any_nonzero = grad.iter().any(|&g| !(g.abs() < 1e-6));
    assert!(!any_nonzero, "Gradient should be zero when pred == target");
}
#[test]
fn test_bce_loss_perfect_predictions() {
    // Predictions within 0.001 of their labels should cost almost nothing.
    let labels = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
    let confident = Array2::from_shape_vec((1, 2), vec![0.999, 0.001]).unwrap();
    let loss = Loss::compute(LossType::BinaryCrossEntropy, &confident, &labels).unwrap();
    let upper_bound = 0.1;
    assert!(loss < upper_bound, "BCE should be low for good predictions, got {}", loss);
}
#[test]
fn test_bce_loss_bad_predictions() {
    // Confidently predicting the opposite label should be penalised heavily.
    let labels = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
    let inverted = Array2::from_shape_vec((1, 2), vec![0.001, 0.999]).unwrap();
    let loss = Loss::compute(LossType::BinaryCrossEntropy, &inverted, &labels).unwrap();
    let lower_bound = 1.0;
    assert!(loss > lower_bound, "BCE should be high for bad predictions, got {}", loss);
}
#[test]
fn test_bce_loss_numerical_stability() {
    // Predictions of exactly 0 and 1 would hit ln(0) without clamping.
    let boundary = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
    let labels = boundary.clone();
    let loss = Loss::compute(LossType::BinaryCrossEntropy, &boundary, &labels).unwrap();
    assert!(loss.is_finite(), "BCE should be finite even with extreme values");
}
#[test]
fn test_bce_gradient_shape() {
    // The gradient must have one entry per prediction.
    let dims = (3, 2);
    let predicted = Array2::from_elem(dims, 0.5_f32);
    let labels = Array2::from_shape_vec(dims, vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap();
    let grad = Loss::gradient(LossType::BinaryCrossEntropy, &predicted, &labels).unwrap();
    assert_eq!(grad.shape(), predicted.shape());
}
#[test]
fn test_bce_gradient_direction() {
    let predicted = Array2::from_shape_vec((1, 2), vec![0.3, 0.7]).unwrap();
    let labels = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
    let grad = Loss::gradient(LossType::BinaryCrossEntropy, &predicted, &labels).unwrap();
    let (toward_one, toward_zero) = (grad[[0, 0]], grad[[0, 1]]);
    // Label 1 with a prediction below 1: descending must raise the
    // prediction, so the gradient is negative.
    assert!(toward_one < 0.0, "Gradient should be negative to increase pred towards 1");
    // Label 0 with a prediction above 0: descending must lower the
    // prediction, so the gradient is positive.
    assert!(toward_zero > 0.0, "Gradient should be positive to decrease pred towards 0");
}
#[test]
fn test_cross_entropy_one_hot() {
    // Each prediction row sums to 1 and puts most of its mass on the
    // labelled class.
    let probabilities =
        Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.8, 0.1]).unwrap();
    let one_hot = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
    let loss = Loss::compute(LossType::CrossEntropy, &probabilities, &one_hot).unwrap();
    // Mostly-correct predictions give a small but non-zero loss.
    let in_range = loss > 0.0 && loss < 1.0;
    assert!(in_range, "CE should be reasonable for good predictions, got {}", loss);
}
#[test]
fn test_cross_entropy_wrong_class() {
    // Most of the probability mass is on class 2 while the label is class 0.
    let probabilities = Array2::from_shape_vec((1, 3), vec![0.1, 0.1, 0.8]).unwrap();
    let one_hot = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();
    let loss = Loss::compute(LossType::CrossEntropy, &probabilities, &one_hot).unwrap();
    let lower_bound = 1.0;
    assert!(loss > lower_bound, "CE should be high for wrong predictions, got {}", loss);
}
#[test]
fn test_cross_entropy_gradient_shape() {
    // The gradient must have one entry per prediction.
    let dims = (2, 4);
    let uniform = Array2::from_elem(dims, 0.25_f32);
    let one_hot =
        Array2::from_shape_vec(dims, vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
    let grad = Loss::gradient(LossType::CrossEntropy, &uniform, &one_hot).unwrap();
    assert_eq!(grad.shape(), uniform.shape());
}
#[test]
fn test_loss_dimension_mismatch_error() {
    // A 2x2 prediction cannot be scored against a 2x3 target.
    let narrow = Array2::<f32>::ones((2, 2));
    let wide = Array2::<f32>::ones((2, 3));

    let forward = Loss::compute(LossType::Mse, &narrow, &wide);
    assert!(forward.is_err(), "Should error on dimension mismatch");

    let backward = Loss::gradient(LossType::Mse, &narrow, &wide);
    assert!(backward.is_err(), "Gradient should error on dimension mismatch");
}
#[test]
fn test_loss_empty_array_error() {
    // Zero rows means zero elements; both passes must reject the input.
    let no_rows = Array2::<f32>::zeros((0, 2));
    let no_targets = Array2::<f32>::zeros((0, 2));

    let forward = Loss::compute(LossType::Mse, &no_rows, &no_targets);
    assert!(forward.is_err(), "Should error on empty arrays");

    let backward = Loss::gradient(LossType::Mse, &no_rows, &no_targets);
    assert!(backward.is_err(), "Gradient should error on empty arrays");
}
#[test]
fn test_loss_gradient_numerical_check() {
    // Compares the analytical MSE gradient against a central finite difference.
    let pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.8]).unwrap();
    let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
    let analytical_grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
    // Step size for the central difference. The losses here are f32 values of
    // order 0.5, so each carries rounding error of roughly 3e-8. Dividing the
    // difference of two such values by `2 * eps` amplifies that noise: with a
    // step of 1e-5 it is of the same order as the 1e-3 tolerance below. MSE is
    // quadratic in the predictions, so the central difference has no
    // truncation error and a larger step only improves accuracy.
    let eps: f32 = 1e-3;
    for i in 0..2 {
        let mut pred_plus = pred.clone();
        let mut pred_minus = pred.clone();
        pred_plus[[0, i]] += eps;
        pred_minus[[0, i]] -= eps;
        let loss_plus = Loss::compute(LossType::Mse, &pred_plus, &target).unwrap();
        let loss_minus = Loss::compute(LossType::Mse, &pred_minus, &target).unwrap();
        let numerical_grad = (loss_plus - loss_minus) / (2.0 * eps);
        let error = (analytical_grad[[0, i]] - numerical_grad).abs();
        assert!(
            error < 1e-3,
            "Numerical gradient check failed: analytical={}, numerical={}",
            analytical_grad[[0, i]],
            numerical_grad
        );
    }
}
#[test]
fn test_training_loop_integration() {
    // End-to-end check: feeding loss gradients to plain SGD (no momentum)
    // must reduce the MSE between the parameters and the target.
    let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
    let mut pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();
    let mut optimizer = Optimizer::new(OptimizerType::Sgd {
        learning_rate: 0.1,
        momentum: 0.0,
    });

    let loss_before = Loss::compute(LossType::Mse, &pred, &target).unwrap();

    // Ten descent steps on the prediction matrix itself.
    for _step in 0..10 {
        let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
        optimizer.step(&mut pred, &grad).unwrap();
    }

    let loss_after = Loss::compute(LossType::Mse, &pred, &target).unwrap();
    assert!(loss_after < loss_before, "Loss should decrease during training");
}
}