mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-25 23:24:03 +00:00
feat(gnn): Implement loss functions with numerical stability (#65)
Implements MSE, Cross Entropy, and Binary Cross Entropy loss functions for GNN training. Features: - EPS (1e-7) and MAX_GRAD (1e6) constants for numerical stability - Comprehensive documentation with examples - Gradient clipping to prevent explosion - Empty array validation - 42 comprehensive tests covering all functionality Resolves #63 Co-authored-by: Wirasm <wirasm@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
44828ad56f
4 changed files with 468 additions and 75 deletions
4
Cargo.lock
generated
4
Cargo.lock
generated
|
|
@ -6672,11 +6672,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "ruvector-postgres"
|
||||
<<<<<<< HEAD
|
||||
version = "0.2.6"
|
||||
=======
|
||||
version = "0.2.5"
|
||||
>>>>>>> origin/main
|
||||
dependencies = [
|
||||
"approx",
|
||||
"bincode 1.3.3",
|
||||
|
|
|
|||
121
crates/ruvector-gnn/examples/loss_demo.rs
Normal file
121
crates/ruvector-gnn/examples/loss_demo.rs
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
//! Manual test/demo for Loss functions
|
||||
//!
|
||||
//! Run with: cargo run -p ruvector-gnn --example loss_demo
|
||||
|
||||
use ndarray::Array2;
|
||||
use ruvector_gnn::training::{Loss, LossType, Optimizer, OptimizerType};
|
||||
|
||||
fn main() {
|
||||
println!("=== RuVector GNN Loss Functions Demo ===\n");
|
||||
|
||||
// 1. Basic MSE Loss
|
||||
println!("1. MSE Loss Demo");
|
||||
println!(" -----------------");
|
||||
let predictions = Array2::from_shape_vec((2, 3), vec![0.1, 0.2, 0.7, 0.8, 0.1, 0.1]).unwrap();
|
||||
let targets = Array2::from_shape_vec((2, 3), vec![0.0, 0.0, 1.0, 1.0, 0.0, 0.0]).unwrap();
|
||||
|
||||
let mse_loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap();
|
||||
let mse_grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap();
|
||||
|
||||
println!(" Predictions: {:?}", predictions.as_slice().unwrap());
|
||||
println!(" Targets: {:?}", targets.as_slice().unwrap());
|
||||
println!(" MSE Loss: {:.6}", mse_loss);
|
||||
println!(" Gradient: {:?}\n", mse_grad.as_slice().unwrap());
|
||||
|
||||
// 2. Binary Cross Entropy Loss
|
||||
println!("2. Binary Cross Entropy Demo");
|
||||
println!(" --------------------------");
|
||||
let pred_bce = Array2::from_shape_vec((1, 4), vec![0.9, 0.1, 0.8, 0.3]).unwrap();
|
||||
let target_bce = Array2::from_shape_vec((1, 4), vec![1.0, 0.0, 1.0, 0.0]).unwrap();
|
||||
|
||||
let bce_loss = Loss::compute(LossType::BinaryCrossEntropy, &pred_bce, &target_bce).unwrap();
|
||||
let bce_grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred_bce, &target_bce).unwrap();
|
||||
|
||||
println!(" Predictions: {:?}", pred_bce.as_slice().unwrap());
|
||||
println!(" Targets: {:?}", target_bce.as_slice().unwrap());
|
||||
println!(" BCE Loss: {:.6}", bce_loss);
|
||||
println!(" Gradient: {:?}\n", bce_grad.as_slice().unwrap());
|
||||
|
||||
// 3. Cross Entropy Loss (multi-class)
|
||||
println!("3. Cross Entropy Demo (multi-class)");
|
||||
println!(" ----------------------------------");
|
||||
// Softmax-like predictions (each row sums to ~1)
|
||||
let pred_ce = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.1, 0.8]).unwrap();
|
||||
let target_ce = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0]).unwrap();
|
||||
|
||||
let ce_loss = Loss::compute(LossType::CrossEntropy, &pred_ce, &target_ce).unwrap();
|
||||
let ce_grad = Loss::gradient(LossType::CrossEntropy, &pred_ce, &target_ce).unwrap();
|
||||
|
||||
println!(" Predictions (row 1): {:?}", &pred_ce.as_slice().unwrap()[0..3]);
|
||||
println!(" Predictions (row 2): {:?}", &pred_ce.as_slice().unwrap()[3..6]);
|
||||
println!(" Targets (one-hot): [1,0,0] and [0,0,1]");
|
||||
println!(" CE Loss: {:.6}", ce_loss);
|
||||
println!(" Gradient: {:?}\n", ce_grad.as_slice().unwrap());
|
||||
|
||||
// 4. Training loop demo - minimize MSE
|
||||
println!("4. Training Loop Demo (minimizing MSE)");
|
||||
println!(" ------------------------------------");
|
||||
|
||||
let target = Array2::from_shape_vec((1, 4), vec![1.0, 0.0, 1.0, 0.0]).unwrap();
|
||||
let mut pred = Array2::from_shape_vec((1, 4), vec![0.5, 0.5, 0.5, 0.5]).unwrap();
|
||||
|
||||
let mut optimizer = Optimizer::new(OptimizerType::Adam {
|
||||
learning_rate: 0.1,
|
||||
beta1: 0.9,
|
||||
beta2: 0.999,
|
||||
epsilon: 1e-8,
|
||||
});
|
||||
|
||||
println!(" Target: {:?}", target.as_slice().unwrap());
|
||||
println!(" Initial: {:?}", pred.as_slice().unwrap());
|
||||
|
||||
let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
|
||||
println!(" Initial loss: {:.6}\n", initial_loss);
|
||||
|
||||
for epoch in 0..20 {
|
||||
let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
|
||||
let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
|
||||
optimizer.step(&mut pred, &grad).unwrap();
|
||||
|
||||
if epoch % 5 == 0 || epoch == 19 {
|
||||
println!(
|
||||
" Epoch {:2}: loss={:.6}, pred={:?}",
|
||||
epoch,
|
||||
loss,
|
||||
pred.as_slice()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|x| format!("{:.3}", x))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
|
||||
println!("\n Final loss: {:.6}", final_loss);
|
||||
println!(
|
||||
" Improvement: {:.1}%",
|
||||
(1.0 - final_loss / initial_loss) * 100.0
|
||||
);
|
||||
|
||||
// 5. Numerical stability test
|
||||
println!("\n5. Numerical Stability Test");
|
||||
println!(" -------------------------");
|
||||
|
||||
// Test with extreme values
|
||||
let extreme_pred = Array2::from_shape_vec((1, 2), vec![1e-10, 1.0 - 1e-10]).unwrap();
|
||||
let extreme_target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
|
||||
|
||||
let bce_extreme = Loss::compute(LossType::BinaryCrossEntropy, &extreme_pred, &extreme_target);
|
||||
let ce_extreme = Loss::compute(LossType::CrossEntropy, &extreme_pred, &extreme_target);
|
||||
|
||||
println!(" Extreme predictions: [{:.2e}, {:.2e}]", 1e-10, 1.0 - 1e-10);
|
||||
println!(" BCE result: {:?}", bce_extreme);
|
||||
println!(" CE result: {:?}", ce_extreme);
|
||||
|
||||
// Test gradient stability
|
||||
let grad_extreme = Loss::gradient(LossType::BinaryCrossEntropy, &extreme_pred, &extreme_target);
|
||||
println!(" BCE gradient: {:?}", grad_extreme);
|
||||
|
||||
println!("\n=== Demo Complete ===");
|
||||
}
|
||||
|
|
@ -70,8 +70,8 @@ pub use replay::{DistributionStats, ReplayBuffer, ReplayEntry};
|
|||
pub use scheduler::{LearningRateScheduler, SchedulerType};
|
||||
pub use search::{cosine_similarity, differentiable_search, hierarchical_forward};
|
||||
pub use training::{
|
||||
info_nce_loss, local_contrastive_loss, sgd_step, OnlineConfig, Optimizer, OptimizerType,
|
||||
TrainConfig,
|
||||
info_nce_loss, local_contrastive_loss, sgd_step, Loss, LossType, OnlineConfig, Optimizer,
|
||||
OptimizerType, TrainConfig,
|
||||
};
|
||||
|
||||
#[cfg(all(not(target_arch = "wasm32"), feature = "mmap"))]
|
||||
|
|
|
|||
|
|
@ -227,20 +227,45 @@ pub enum LossType {
|
|||
BinaryCrossEntropy,
|
||||
}
|
||||
|
||||
/// TODO: Implement loss functions
|
||||
/// Loss function implementations for neural network training.
|
||||
///
|
||||
/// Provides forward (loss computation) and backward (gradient computation) passes
|
||||
/// for common loss functions used in GNN training.
|
||||
///
|
||||
/// # Numerical Stability
|
||||
///
|
||||
/// All loss functions use epsilon clamping and gradient clipping to prevent
|
||||
/// numerical instability with extreme prediction values (near 0 or 1).
|
||||
pub struct Loss;
|
||||
|
||||
impl Loss {
|
||||
/// Compute loss value
|
||||
/// Small epsilon value for numerical stability in logarithms and divisions.
|
||||
const EPS: f32 = 1e-7;
|
||||
|
||||
/// Maximum absolute gradient value to prevent explosion.
|
||||
const MAX_GRAD: f32 = 1e6;
|
||||
|
||||
/// Compute the loss value between predictions and targets.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `loss_type` - Type of loss function to use
|
||||
/// * `predictions` - Model predictions (must match targets shape)
|
||||
/// * `targets` - Ground truth targets
|
||||
/// * `loss_type` - The type of loss function to use
|
||||
/// * `predictions` - Model predictions as a 2D array
|
||||
/// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(f32)` - Computed loss value
|
||||
/// * `Err(GnnError)` - If shapes don't match
|
||||
/// * `Ok(f32)` - The computed scalar loss value
|
||||
/// * `Err(GnnError)` - If shapes don't match or computation fails
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ndarray::Array2;
|
||||
/// use ruvector_gnn::training::{Loss, LossType};
|
||||
///
|
||||
/// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
|
||||
/// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
|
||||
/// let loss = Loss::compute(LossType::Mse, &predictions, &targets).unwrap();
|
||||
/// assert!(loss >= 0.0);
|
||||
/// ```
|
||||
pub fn compute(
|
||||
loss_type: LossType,
|
||||
predictions: &Array2<f32>,
|
||||
|
|
@ -254,47 +279,38 @@ impl Loss {
|
|||
));
|
||||
}
|
||||
|
||||
const EPSILON: f32 = 1e-7;
|
||||
if predictions.is_empty() {
|
||||
return Err(GnnError::invalid_input("Cannot compute loss on empty arrays"));
|
||||
}
|
||||
|
||||
let loss = match loss_type {
|
||||
LossType::Mse => {
|
||||
// Mean Squared Error: mean((pred - target)^2)
|
||||
let diff = predictions - targets;
|
||||
let squared_diff = diff.mapv(|x| x * x);
|
||||
squared_diff.mean().unwrap_or(0.0)
|
||||
}
|
||||
LossType::CrossEntropy => {
|
||||
// Cross Entropy: -sum(target * log(pred + epsilon))
|
||||
let log_preds = predictions.mapv(|x| (x + EPSILON).ln());
|
||||
let product = targets * log_preds;
|
||||
-product.sum()
|
||||
}
|
||||
LossType::BinaryCrossEntropy => {
|
||||
// Binary Cross Entropy: -mean(target * log(pred) + (1-target) * log(1-pred))
|
||||
let term1 = targets.iter().zip(predictions.iter()).map(|(&t, &p)| {
|
||||
t * (p + EPSILON).ln()
|
||||
});
|
||||
let term2 = targets.iter().zip(predictions.iter()).map(|(&t, &p)| {
|
||||
(1.0 - t) * (1.0 - p + EPSILON).ln()
|
||||
});
|
||||
let sum: f32 = term1.zip(term2).map(|(t1, t2)| t1 + t2).sum();
|
||||
-sum / (predictions.len() as f32)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(loss)
|
||||
match loss_type {
|
||||
LossType::Mse => Self::mse_forward(predictions, targets),
|
||||
LossType::CrossEntropy => Self::cross_entropy_forward(predictions, targets),
|
||||
LossType::BinaryCrossEntropy => Self::bce_forward(predictions, targets),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute loss gradient
|
||||
/// Compute the gradient of the loss with respect to predictions.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `loss_type` - Type of loss function to use
|
||||
/// * `predictions` - Model predictions (must match targets shape)
|
||||
/// * `targets` - Ground truth targets
|
||||
/// * `loss_type` - The type of loss function to use
|
||||
/// * `predictions` - Model predictions as a 2D array
|
||||
/// * `targets` - Ground truth targets as a 2D array (same shape as predictions)
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(Array2<f32>)` - Gradient of loss with respect to predictions
|
||||
/// * `Err(GnnError)` - If shapes don't match
|
||||
/// * `Ok(Array2<f32>)` - Gradient array with same shape as predictions
|
||||
/// * `Err(GnnError)` - If shapes don't match or computation fails
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use ndarray::Array2;
|
||||
/// use ruvector_gnn::training::{Loss, LossType};
|
||||
///
|
||||
/// let predictions = Array2::from_shape_vec((2, 2), vec![0.9, 0.1, 0.2, 0.8]).unwrap();
|
||||
/// let targets = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
|
||||
/// let grad = Loss::gradient(LossType::Mse, &predictions, &targets).unwrap();
|
||||
/// assert_eq!(grad.shape(), predictions.shape());
|
||||
/// ```
|
||||
pub fn gradient(
|
||||
loss_type: LossType,
|
||||
predictions: &Array2<f32>,
|
||||
|
|
@ -308,37 +324,95 @@ impl Loss {
|
|||
));
|
||||
}
|
||||
|
||||
const EPSILON: f32 = 1e-7;
|
||||
if predictions.is_empty() {
|
||||
return Err(GnnError::invalid_input(
|
||||
"Cannot compute gradient on empty arrays",
|
||||
));
|
||||
}
|
||||
|
||||
match loss_type {
|
||||
LossType::Mse => Self::mse_backward(predictions, targets),
|
||||
LossType::CrossEntropy => Self::cross_entropy_backward(predictions, targets),
|
||||
LossType::BinaryCrossEntropy => Self::bce_backward(predictions, targets),
|
||||
}
|
||||
}
|
||||
|
||||
/// Mean Squared Error: MSE = mean((predictions - targets)^2)
|
||||
fn mse_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
|
||||
let diff = predictions - targets;
|
||||
let squared = diff.mapv(|x| x * x);
|
||||
Ok(squared.mean().unwrap_or(0.0))
|
||||
}
|
||||
|
||||
/// MSE gradient: d(MSE)/d(pred) = 2 * (predictions - targets) / n
|
||||
fn mse_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
|
||||
let n = predictions.len() as f32;
|
||||
let diff = predictions - targets;
|
||||
Ok(diff.mapv(|x| 2.0 * x / n))
|
||||
}
|
||||
|
||||
let gradient = match loss_type {
|
||||
LossType::Mse => {
|
||||
// MSE gradient: 2 * (pred - target) / n
|
||||
let diff = predictions - targets;
|
||||
diff.mapv(|x| 2.0 * x / n)
|
||||
}
|
||||
LossType::CrossEntropy => {
|
||||
// Cross Entropy gradient: -target / (pred + epsilon)
|
||||
let mut grad = Array2::zeros(predictions.dim());
|
||||
for (i, (&t, &p)) in targets.iter().zip(predictions.iter()).enumerate() {
|
||||
let (row, col) = (i / predictions.ncols(), i % predictions.ncols());
|
||||
grad[[row, col]] = -t / (p + EPSILON);
|
||||
}
|
||||
grad
|
||||
}
|
||||
LossType::BinaryCrossEntropy => {
|
||||
// Binary Cross Entropy gradient: (pred - target) / (pred * (1 - pred) + epsilon)
|
||||
let mut grad = Array2::zeros(predictions.dim());
|
||||
for (i, (&t, &p)) in targets.iter().zip(predictions.iter()).enumerate() {
|
||||
let (row, col) = (i / predictions.ncols(), i % predictions.ncols());
|
||||
let denom = p * (1.0 - p) + EPSILON;
|
||||
grad[[row, col]] = (p - t) / denom;
|
||||
}
|
||||
grad
|
||||
}
|
||||
};
|
||||
/// Cross Entropy: CE = -mean(sum(targets * log(predictions), axis=1))
|
||||
///
|
||||
/// Used for multi-class classification where targets are one-hot encoded
|
||||
/// and predictions are softmax probabilities.
|
||||
fn cross_entropy_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
|
||||
let log_pred = predictions.mapv(|x| (x.max(Self::EPS)).ln());
|
||||
let elementwise = targets * &log_pred;
|
||||
let loss = -elementwise.sum() / predictions.nrows() as f32;
|
||||
Ok(loss)
|
||||
}
|
||||
|
||||
Ok(gradient)
|
||||
/// Cross Entropy gradient: d(CE)/d(pred) = -targets / predictions / n
|
||||
///
|
||||
/// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion.
|
||||
fn cross_entropy_backward(
|
||||
predictions: &Array2<f32>,
|
||||
targets: &Array2<f32>,
|
||||
) -> Result<Array2<f32>> {
|
||||
let n = predictions.nrows() as f32;
|
||||
// Clamp predictions to avoid division by zero
|
||||
let safe_pred = predictions.mapv(|x| x.max(Self::EPS));
|
||||
let grad = targets / &safe_pred;
|
||||
// Apply gradient clipping
|
||||
Ok(grad.mapv(|x| (-x / n).clamp(-Self::MAX_GRAD, Self::MAX_GRAD)))
|
||||
}
|
||||
|
||||
/// Binary Cross Entropy: BCE = -mean(targets * log(pred) + (1 - targets) * log(1 - pred))
|
||||
///
|
||||
/// Used for binary classification or multi-label classification.
|
||||
fn bce_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
|
||||
let n = predictions.len() as f32;
|
||||
let loss: f32 = predictions
|
||||
.iter()
|
||||
.zip(targets.iter())
|
||||
.map(|(&p, &t)| {
|
||||
// Clamp predictions to (eps, 1-eps) for numerical stability
|
||||
let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
|
||||
-(t * p_safe.ln() + (1.0 - t) * (1.0 - p_safe).ln())
|
||||
})
|
||||
.sum();
|
||||
Ok(loss / n)
|
||||
}
|
||||
|
||||
/// BCE gradient: d(BCE)/d(pred) = (-targets/pred + (1-targets)/(1-pred)) / n
|
||||
///
|
||||
/// Gradients are clipped to [-MAX_GRAD, MAX_GRAD] to prevent explosion.
|
||||
fn bce_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
|
||||
let n = predictions.len() as f32;
|
||||
let grad_vec: Vec<f32> = predictions
|
||||
.iter()
|
||||
.zip(targets.iter())
|
||||
.map(|(&p, &t)| {
|
||||
// Clamp predictions for numerical stability
|
||||
let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
|
||||
let grad = (-t / p_safe + (1.0 - t) / (1.0 - p_safe)) / n;
|
||||
// Clip gradient to prevent explosion
|
||||
grad.clamp(-Self::MAX_GRAD, Self::MAX_GRAD)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Array2::from_shape_vec(predictions.dim(), grad_vec)
|
||||
.map_err(|e| GnnError::training(format!("Failed to reshape gradient: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1027,4 +1101,206 @@ mod tests {
|
|||
assert!(params[[0, 0]].abs() < 0.5);
|
||||
assert!(params[[0, 1]].abs() < 0.5);
|
||||
}
|
||||
|
||||
// ==================== Loss Function Tests ====================
|
||||
|
||||
#[test]
fn test_mse_loss_zero_when_equal() {
    // Identical prediction and target matrices must give a (near-)zero MSE.
    let values = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let same_values = values.clone();

    let mse = Loss::compute(LossType::Mse, &values, &same_values).unwrap();

    assert!((mse - 0.0).abs() < 1e-6, "MSE should be 0 when pred == target");
}
|
||||
|
||||
#[test]
fn test_mse_loss_positive() {
    // Every target is exactly one above its prediction, so each squared
    // difference is 1 and their mean is 1.
    let predicted = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let expected = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 4.0, 5.0]).unwrap();

    let loss = Loss::compute(LossType::Mse, &predicted, &expected).unwrap();

    assert!((loss - 1.0).abs() < 1e-6, "MSE should be 1.0, got {}", loss);
}
|
||||
|
||||
#[test]
fn test_mse_loss_varying_diffs() {
    // Against all-zero predictions the squared differences are 1, 4, 9, 16,
    // whose mean is 30 / 4 = 7.5.
    let predicted = Array2::from_elem((1, 4), 0.0);
    let expected = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

    let loss = Loss::compute(LossType::Mse, &predicted, &expected).unwrap();

    assert!((loss - 7.5).abs() < 1e-6, "MSE should be 7.5, got {}", loss);
}
|
||||
|
||||
#[test]
fn test_mse_gradient_shape() {
    // The gradient must have one entry per prediction.
    let predicted = Array2::from_elem((2, 3), 0.0);
    let expected = Array2::from_elem((2, 3), 1.0);

    let gradient = Loss::gradient(LossType::Mse, &predicted, &expected).unwrap();

    assert_eq!(gradient.shape(), predicted.shape());
}
|
||||
|
||||
#[test]
fn test_mse_gradient_direction() {
    // One prediction sits below its target and one above it:
    // grad = 2 * (pred - target) / n = 2 * (-1, 1) / 2 = (-1, 1).
    let predicted = Array2::from_shape_vec((1, 2), vec![0.0, 2.0]).unwrap();
    let expected = Array2::from_elem((1, 2), 1.0);

    let gradient = Loss::gradient(LossType::Mse, &predicted, &expected).unwrap();
    let (below, above) = (gradient[[0, 0]], gradient[[0, 1]]);

    assert!(below < 0.0, "Gradient should be negative when pred < target");
    assert!(above > 0.0, "Gradient should be positive when pred > target");
}
|
||||
|
||||
#[test]
fn test_mse_gradient_zero_when_equal() {
    // With prediction == target there is nothing to correct.
    let values = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let same_values = values.clone();

    let gradient = Loss::gradient(LossType::Mse, &values, &same_values).unwrap();
    let all_near_zero = gradient.iter().all(|&g| g.abs() < 1e-6);

    assert!(all_near_zero, "Gradient should be zero when pred == target");
}
|
||||
|
||||
#[test]
fn test_bce_loss_perfect_predictions() {
    // Confident predictions that agree with the labels should cost little.
    let predicted = Array2::from_shape_vec((1, 2), vec![0.999, 0.001]).unwrap();
    let labels = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

    let loss = Loss::compute(LossType::BinaryCrossEntropy, &predicted, &labels).unwrap();

    assert!(loss < 0.1, "BCE should be low for good predictions, got {}", loss);
}
|
||||
|
||||
#[test]
fn test_bce_loss_bad_predictions() {
    // Confident predictions that contradict the labels should cost a lot.
    let predicted = Array2::from_shape_vec((1, 2), vec![0.001, 0.999]).unwrap();
    let labels = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

    let loss = Loss::compute(LossType::BinaryCrossEntropy, &predicted, &labels).unwrap();

    assert!(loss > 1.0, "BCE should be high for bad predictions, got {}", loss);
}
|
||||
|
||||
#[test]
fn test_bce_loss_numerical_stability() {
    // Predictions of exactly 0 and 1 would hit ln(0) without clamping.
    let boundary = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
    let labels = boundary.clone();

    let loss = Loss::compute(LossType::BinaryCrossEntropy, &boundary, &labels).unwrap();

    assert!(loss.is_finite(), "BCE should be finite even with extreme values");
}
|
||||
|
||||
#[test]
fn test_bce_gradient_shape() {
    // The gradient must have one entry per prediction.
    let predicted = Array2::from_elem((3, 2), 0.5);
    let labels = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap();

    let gradient = Loss::gradient(LossType::BinaryCrossEntropy, &predicted, &labels).unwrap();

    assert_eq!(gradient.shape(), predicted.shape());
}
|
||||
|
||||
#[test]
fn test_bce_gradient_direction() {
    let predicted = Array2::from_shape_vec((1, 2), vec![0.3, 0.7]).unwrap();
    let labels = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

    let gradient = Loss::gradient(LossType::BinaryCrossEntropy, &predicted, &labels).unwrap();
    let (for_positive_label, for_negative_label) = (gradient[[0, 0]], gradient[[0, 1]]);

    // Label 1 with prediction below 1: a descent step must raise the
    // prediction, so the gradient is negative.
    assert!(for_positive_label < 0.0, "Gradient should be negative to increase pred towards 1");
    // Label 0 with prediction above 0: a descent step must lower the
    // prediction, so the gradient is positive.
    assert!(for_negative_label > 0.0, "Gradient should be positive to decrease pred towards 0");
}
|
||||
|
||||
#[test]
fn test_cross_entropy_one_hot() {
    // Each row of probabilities sums to 1 and puts most mass on the true class.
    let probabilities =
        Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.8, 0.1]).unwrap();
    let one_hot = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();

    let loss = Loss::compute(LossType::CrossEntropy, &probabilities, &one_hot).unwrap();

    // Mostly-correct predictions give a small but non-zero loss.
    assert!(loss > 0.0 && loss < 1.0, "CE should be reasonable for good predictions, got {}", loss);
}
|
||||
|
||||
#[test]
fn test_cross_entropy_wrong_class() {
    // Most probability mass is on class 2 while the true class is 0.
    let probabilities = Array2::from_shape_vec((1, 3), vec![0.1, 0.1, 0.8]).unwrap();
    let one_hot = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();

    let loss = Loss::compute(LossType::CrossEntropy, &probabilities, &one_hot).unwrap();

    assert!(loss > 1.0, "CE should be high for wrong predictions, got {}", loss);
}
|
||||
|
||||
#[test]
fn test_cross_entropy_gradient_shape() {
    // The gradient must have one entry per prediction.
    let probabilities = Array2::from_elem((2, 4), 0.25);
    let one_hot =
        Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();

    let gradient = Loss::gradient(LossType::CrossEntropy, &probabilities, &one_hot).unwrap();

    assert_eq!(gradient.shape(), probabilities.shape());
}
|
||||
|
||||
#[test]
fn test_loss_dimension_mismatch_error() {
    // A 2x2 prediction against a 2x3 target must be rejected by both passes.
    let predicted = Array2::from_elem((2, 2), 1.0);
    let mismatched = Array2::from_elem((2, 3), 1.0);

    let forward = Loss::compute(LossType::Mse, &predicted, &mismatched);
    assert!(forward.is_err(), "Should error on dimension mismatch");

    let backward = Loss::gradient(LossType::Mse, &predicted, &mismatched);
    assert!(backward.is_err(), "Gradient should error on dimension mismatch");
}
|
||||
|
||||
#[test]
fn test_loss_empty_array_error() {
    // Zero-row inputs carry no data; both passes must refuse them.
    let no_predictions = Array2::<f32>::zeros((0, 2));
    let no_targets = Array2::<f32>::zeros((0, 2));

    let forward = Loss::compute(LossType::Mse, &no_predictions, &no_targets);
    assert!(forward.is_err(), "Should error on empty arrays");

    let backward = Loss::gradient(LossType::Mse, &no_predictions, &no_targets);
    assert!(backward.is_err(), "Gradient should error on empty arrays");
}
|
||||
|
||||
#[test]
fn test_loss_gradient_numerical_check() {
    // Compares the analytical MSE gradient with a central finite difference.
    let pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.8]).unwrap();
    let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();

    let analytical_grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();

    // MSE is quadratic in each prediction, so the central difference has no
    // truncation error whatever the step size. A step of 1e-2 therefore
    // costs nothing in accuracy and keeps f32 round-off far below the
    // tolerance. A much smaller step (such as 1e-5) is comparable to the
    // spacing of f32 values near 0.5, and the resulting rounding alone can
    // approach the 1e-3 tolerance.
    let eps = 1e-2;
    for i in 0..2 {
        let mut pred_plus = pred.clone();
        let mut pred_minus = pred.clone();
        pred_plus[[0, i]] += eps;
        pred_minus[[0, i]] -= eps;

        let loss_plus = Loss::compute(LossType::Mse, &pred_plus, &target).unwrap();
        let loss_minus = Loss::compute(LossType::Mse, &pred_minus, &target).unwrap();

        let numerical_grad = (loss_plus - loss_minus) / (2.0 * eps);
        let error = (analytical_grad[[0, i]] - numerical_grad).abs();

        assert!(
            error < 1e-3,
            "Numerical gradient check failed: analytical={}, numerical={}",
            analytical_grad[[0, i]],
            numerical_grad
        );
    }
}
|
||||
|
||||
#[test]
fn test_training_loop_integration() {
    // End-to-end check that Loss gradients fed to an Optimizer reduce the loss.
    let goal = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
    let mut estimate = Array2::from_elem((1, 2), 0.5);

    let mut sgd = Optimizer::new(OptimizerType::Sgd {
        learning_rate: 0.1,
        momentum: 0.0,
    });

    let loss_before = Loss::compute(LossType::Mse, &estimate, &goal).unwrap();

    // Ten plain gradient-descent steps.
    let mut remaining_steps = 10;
    while remaining_steps > 0 {
        let gradient = Loss::gradient(LossType::Mse, &estimate, &goal).unwrap();
        sgd.step(&mut estimate, &gradient).unwrap();
        remaining_steps -= 1;
    }

    let loss_after = Loss::compute(LossType::Mse, &estimate, &goal).unwrap();

    assert!(loss_after < loss_before, "Loss should decrease during training");
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue