diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b60c12..31d139a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.6.0] - 2025-09-04 + +### Added +- **Early Stopping**: Complete early stopping implementation for all trainers + - **Configurable Patience**: Stop training after N epochs without improvement + - **Multiple Metrics**: Monitor validation loss or training loss + - **Best Weight Restoration**: Automatically restore best weights when stopping + - **Flexible Configuration**: Customizable min_delta threshold and monitoring options + - **Universal Support**: Works with LSTMTrainer, ScheduledLSTMTrainer, and LSTMBatchTrainer + +- **Enhanced Training Features**: + - **Visual Best Epoch Indicators**: "*" markers in training logs for best epochs + - **Automatic Overfitting Prevention**: Stop training before performance degrades + - **Comprehensive Logging**: Detailed early stopping trigger information + - **Weight Management**: Optional best weight restoration for optimal model recovery + +- **Examples and Documentation**: + - **Complete Early Stopping Example**: Demonstration of all early stopping configurations + - **Multiple Scenarios**: Validation loss, training loss, and custom patience examples + - **Integration Examples**: Shows usage with different trainer types + - **Comprehensive Testing**: Full test suite for early stopping functionality + +### Enhanced +- **All Trainers**: Added early stopping support to LSTMTrainer, ScheduledLSTMTrainer, and LSTMBatchTrainer +- **Training Configuration**: Extended TrainingConfig with optional early stopping settings +- **Training Loops**: Enhanced all training loops with early stopping logic and best epoch tracking + ## [0.5.0] - 2025-08-14 ### Added diff --git 
a/Cargo.toml b/Cargo.toml index b12b214..570a1d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,3 +88,7 @@ path = "examples/multi_layer_lstm.rs" [[example]] name = "batch_processing_example" path = "examples/batch_processing_example.rs" + +[[example]] +name = "early_stopping_example" +path = "examples/early_stopping_example.rs" diff --git a/README.md b/README.md index 8380e78..dc26443 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,10 @@ graph TD - **Complete Training System** with backpropagation through time (BPTT) - **Multiple Optimizers**: SGD, Adam, RMSprop with comprehensive learning rate scheduling - **Advanced Learning Rate Scheduling**: 12 different schedulers including OneCycle, Warmup, Cyclical, and Polynomial +- **Early Stopping**: Prevent overfitting with configurable patience and metric monitoring - **Loss Functions**: MSE, MAE, Cross-entropy with softmax - **Advanced Dropout**: Input, recurrent, output dropout, variational dropout, and zoneout +- **Batch Processing**: 4-5x training speedup with efficient batch operations - **Schedule Visualization**: ASCII visualization of learning rate schedules - **Model Persistence**: Save/load models in JSON or binary format - **Peephole LSTM variant** for enhanced performance @@ -98,6 +100,39 @@ fn main() { } ``` +### Early Stopping + +```rust +use rust_lstm::{ + LSTMNetwork, create_basic_trainer, TrainingConfig, + EarlyStoppingConfig, EarlyStoppingMetric +}; + +fn main() { + let network = LSTMNetwork::new(1, 10, 2); + + // Configure early stopping + let early_stopping = EarlyStoppingConfig { + patience: 10, // Stop after 10 epochs with no improvement + min_delta: 1e-4, // Minimum improvement threshold + restore_best_weights: true, // Restore best weights when stopping + monitor: EarlyStoppingMetric::ValidationLoss, // Monitor validation loss + }; + + let config = TrainingConfig { + epochs: 1000, // Will likely stop early + early_stopping: Some(early_stopping), + ..Default::default() + }; + + let mut 
trainer = create_basic_trainer(network, 0.001) + .with_config(config); + + // Training will stop early if validation loss stops improving + trainer.train(&train_data, Some(&validation_data)); +} +``` + ### Bidirectional LSTM ```rust @@ -258,6 +293,10 @@ cargo run --example dropout_example # Dropout demo # Learning and scheduling cargo run --example learning_rate_scheduling # Basic schedulers cargo run --example advanced_lr_scheduling # Advanced schedulers with visualization +cargo run --example early_stopping_example # Early stopping demonstration + +# Performance and batch processing +cargo run --example batch_processing_example # Batch processing with performance benchmarks # Real-world applications cargo run --example stock_prediction diff --git a/examples/advanced_lr_scheduling.rs b/examples/advanced_lr_scheduling.rs index 3727815..df3841f 100644 --- a/examples/advanced_lr_scheduling.rs +++ b/examples/advanced_lr_scheduling.rs @@ -50,6 +50,7 @@ fn polynomial_decay_example(train_data: &[(Vec>, Vec>)], print_every: 5, clip_gradient: Some(1.0), log_lr_changes: true, + early_stopping: None, }; let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) @@ -83,6 +84,7 @@ fn cyclical_lr_examples(train_data: &[(Vec>, Vec>)], print_every: 5, clip_gradient: Some(1.0), log_lr_changes: false, // Too frequent for cyclical + early_stopping: None, }; let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) @@ -108,6 +110,7 @@ fn cyclical_lr_examples(train_data: &[(Vec>, Vec>)], print_every: 5, clip_gradient: Some(1.0), log_lr_changes: false, + early_stopping: None, }; let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) @@ -134,6 +137,7 @@ fn cyclical_lr_examples(train_data: &[(Vec>, Vec>)], print_every: 5, clip_gradient: Some(1.0), log_lr_changes: false, + early_stopping: None, }; let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) @@ -172,6 
+176,7 @@ fn warmup_scheduler_example(train_data: &[(Vec>, Vec>)], print_every: 3, clip_gradient: Some(1.0), log_lr_changes: true, + early_stopping: None, }; let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) @@ -236,6 +241,7 @@ fn advanced_training_example(train_data: &[(Vec>, Vec>)] print_every: 5, clip_gradient: Some(1.0), // Gradient clipping log_lr_changes: false, // Too frequent for cyclical + early_stopping: None, }; let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) diff --git a/examples/dropout_example.rs b/examples/dropout_example.rs index 7edce26..1c9d498 100644 --- a/examples/dropout_example.rs +++ b/examples/dropout_example.rs @@ -194,6 +194,7 @@ fn demonstrate_training_with_dropout() { print_every: 5, clip_gradient: Some(1.0), log_lr_changes: false, + early_stopping: None, }; trainer = trainer.with_config(config); diff --git a/examples/early_stopping_example.rs b/examples/early_stopping_example.rs new file mode 100644 index 0000000..781d900 --- /dev/null +++ b/examples/early_stopping_example.rs @@ -0,0 +1,227 @@ +use ndarray::{Array2, arr2}; +use rust_lstm::{ + LSTMNetwork, create_basic_trainer, TrainingConfig, EarlyStoppingConfig, EarlyStoppingMetric, + MSELoss, Adam +}; + +fn main() { + println!("Early Stopping Demonstration"); + println!("================================\n"); + + // Generate synthetic data that will overfit quickly + let (train_data, val_data) = generate_overfitting_data(); + + println!("Generated {} training sequences and {} validation sequences", + train_data.len(), val_data.len()); + + // Demonstrate different early stopping configurations + demonstrate_validation_early_stopping(&train_data, &val_data); + demonstrate_train_loss_early_stopping(&train_data, &val_data); + demonstrate_no_weight_restoration(&train_data, &val_data); + demonstrate_custom_patience(&train_data, &val_data); +} + +/// Demonstrate early stopping based on validation loss (most 
common) +fn demonstrate_validation_early_stopping( + train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], + val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)] +) { + println!("1. VALIDATION LOSS EARLY STOPPING"); + println!("=================================="); + + let network = LSTMNetwork::new(1, 8, 1); + + // Configure early stopping on validation loss: default settings except a shorter patience (5 instead of 10) + let early_stopping_config = EarlyStoppingConfig { + patience: 5, + min_delta: 1e-4, + restore_best_weights: true, + monitor: EarlyStoppingMetric::ValidationLoss, + }; + + let training_config = TrainingConfig { + epochs: 100, // Will likely stop early + print_every: 1, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = create_basic_trainer(network, 0.01) + .with_config(training_config); + + println!("Training with validation loss monitoring (patience=5)..."); + trainer.train(train_data, Some(val_data)); + + // Show final metrics + if let Some(final_metrics) = trainer.get_latest_metrics() { + println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n", + final_metrics.epoch, + final_metrics.train_loss, + final_metrics.validation_loss.unwrap_or(0.0)); + } +} + +/// Demonstrate early stopping based on training loss +fn demonstrate_train_loss_early_stopping( + train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], + val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)] +) { + println!("2. 
TRAINING LOSS EARLY STOPPING"); + println!("==============================="); + + let network = LSTMNetwork::new(1, 8, 1); + + // Configure early stopping to monitor training loss + let early_stopping_config = EarlyStoppingConfig { + patience: 8, + min_delta: 1e-5, + restore_best_weights: true, + monitor: EarlyStoppingMetric::TrainLoss, + }; + + let training_config = TrainingConfig { + epochs: 100, + print_every: 1, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = create_basic_trainer(network, 0.01) + .with_config(training_config); + + println!("Training with training loss monitoring (patience=8)..."); + trainer.train(train_data, Some(val_data)); + + if let Some(final_metrics) = trainer.get_latest_metrics() { + println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n", + final_metrics.epoch, + final_metrics.train_loss, + final_metrics.validation_loss.unwrap_or(0.0)); + } +} + +/// Demonstrate early stopping without weight restoration +fn demonstrate_no_weight_restoration( + train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], + val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)] +) { + println!("3. 
EARLY STOPPING WITHOUT WEIGHT RESTORATION"); + println!("============================================="); + + let network = LSTMNetwork::new(1, 8, 1); + + // Configure early stopping without restoring best weights + let early_stopping_config = EarlyStoppingConfig { + patience: 5, + min_delta: 1e-4, + restore_best_weights: false, // Don't restore best weights + monitor: EarlyStoppingMetric::ValidationLoss, + }; + + let training_config = TrainingConfig { + epochs: 100, + print_every: 1, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = create_basic_trainer(network, 0.01) + .with_config(training_config); + + println!("Training without weight restoration..."); + trainer.train(train_data, Some(val_data)); + + if let Some(final_metrics) = trainer.get_latest_metrics() { + println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}", + final_metrics.epoch, + final_metrics.train_loss, + final_metrics.validation_loss.unwrap_or(0.0)); + println!("Note: Weights are from the last epoch, not the best epoch\n"); + } +} + +/// Demonstrate early stopping with custom patience +fn demonstrate_custom_patience( + train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], + val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)] +) { + println!("4. 
EARLY STOPPING WITH HIGH PATIENCE"); + println!("===================================="); + + let network = LSTMNetwork::new(1, 8, 1); + + // Configure early stopping with higher patience + let early_stopping_config = EarlyStoppingConfig { + patience: 15, // More patient + min_delta: 1e-6, // Smaller improvement threshold + restore_best_weights: true, + monitor: EarlyStoppingMetric::ValidationLoss, + }; + + let training_config = TrainingConfig { + epochs: 100, + print_every: 2, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = create_basic_trainer(network, 0.01) + .with_config(training_config); + + println!("Training with high patience (patience=15)..."); + trainer.train(train_data, Some(val_data)); + + if let Some(final_metrics) = trainer.get_latest_metrics() { + println!("Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n", + final_metrics.epoch, + final_metrics.train_loss, + final_metrics.validation_loss.unwrap_or(0.0)); + } +} + +/// Generate synthetic data that will cause overfitting +/// This creates a simple pattern that's easy to memorize but doesn't generalize well +fn generate_overfitting_data() -> (Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>, Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>) { + let mut train_data = Vec::new(); + let mut val_data = Vec::new(); + + // Create training data - phase-shifted sine waves (no noise is added) + for i in 0..20 { + let mut inputs = Vec::new(); + let mut targets = Vec::new(); + + let phase = i as f64 * 0.1; + for t in 0..10 { + let x = (t as f64 * 0.3 + phase).sin(); + let y = ((t + 1) as f64 * 0.3 + phase).sin(); // Next value + + inputs.push(arr2(&[[x]])); + targets.push(arr2(&[[y]])); + } + + train_data.push((inputs, targets)); + } + + // Create validation data - different phase to test generalization + for i in 0..5 { + let mut inputs = Vec::new(); + let mut targets = Vec::new(); + + let phase = (i as f64 + 100.0) * 0.1; // Different phase + for t in 0..10 { + let x = (t as f64 * 0.3 + phase).sin(); + let y 
= ((t + 1) as f64 * 0.3 + phase).sin(); + + inputs.push(arr2(&[[x]])); + targets.push(arr2(&[[y]])); + } + + val_data.push((inputs, targets)); + } + + (train_data, val_data) +} diff --git a/examples/learning_rate_scheduling.rs b/examples/learning_rate_scheduling.rs index 515ab2c..bb02337 100644 --- a/examples/learning_rate_scheduling.rs +++ b/examples/learning_rate_scheduling.rs @@ -7,7 +7,7 @@ use rust_lstm::{ }; fn main() { - println!("🚀 Learning Rate Scheduling Examples for Rust-LSTM"); + println!("Learning Rate Scheduling Examples for Rust-LSTM"); println!("==================================================\n"); // Generate sample training data (sine wave prediction) @@ -35,8 +35,8 @@ fn main() { fn step_lr_example(train_data: &[(Vec>, Vec>)], val_data: &[(Vec>, Vec>)]) { - println!("1️⃣ Step Learning Rate Decay Example"); - println!(" Reduces LR by factor of 0.5 every 10 epochs\n"); + println!("Step Learning Rate Decay Example"); + println!("Reduces LR by factor of 0.5 every 10 epochs\n"); let network = LSTMNetwork::new(1, 10, 2) .with_input_dropout(0.1, false) @@ -47,6 +47,7 @@ fn step_lr_example(train_data: &[(Vec>, Vec>)], print_every: 5, clip_gradient: Some(1.0), log_lr_changes: true, + early_stopping: None, }; let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5) @@ -60,8 +61,8 @@ fn step_lr_example(train_data: &[(Vec>, Vec>)], fn one_cycle_example(train_data: &[(Vec>, Vec>)], val_data: &[(Vec>, Vec>)]) { - println!("2️⃣ OneCycle Learning Rate Policy Example"); - println!(" Starts low, ramps up to max, then anneals down\n"); + println!("OneCycle Learning Rate Policy Example"); + println!("Starts low, ramps up to max, then anneals down\n"); let network = LSTMNetwork::new(1, 10, 2); @@ -70,6 +71,7 @@ fn one_cycle_example(train_data: &[(Vec>, Vec>)], print_every: 10, clip_gradient: Some(1.0), log_lr_changes: false, // Too many changes for OneCycle + early_stopping: None, }; let mut trainer = create_one_cycle_trainer(network, 0.1, 50) @@ -83,8 
+85,8 @@ fn one_cycle_example(train_data: &[(Vec>, Vec>)], fn cosine_annealing_example(train_data: &[(Vec>, Vec>)], val_data: &[(Vec>, Vec>)]) { - println!("3️⃣ Cosine Annealing Example"); - println!(" Smoothly oscillates LR following cosine curve\n"); + println!("Cosine Annealing Example"); + println!("Smoothly oscillates LR following cosine curve\n"); let network = LSTMNetwork::new(1, 10, 2); @@ -93,6 +95,7 @@ fn cosine_annealing_example(train_data: &[(Vec>, Vec>)], print_every: 8, clip_gradient: Some(1.0), log_lr_changes: false, + early_stopping: None, }; let mut trainer = create_cosine_annealing_trainer(network, 0.01, 20, 1e-6) @@ -106,8 +109,8 @@ fn cosine_annealing_example(train_data: &[(Vec>, Vec>)], fn exponential_decay_example(train_data: &[(Vec>, Vec>)], val_data: &[(Vec>, Vec>)]) { - println!("4️⃣ Exponential Decay Example"); - println!(" Continuously decays LR by factor of 0.95 each epoch\n"); + println!("Exponential Decay Example"); + println!("Continuously decays LR by factor of 0.95 each epoch\n"); let network = LSTMNetwork::new(1, 10, 2); @@ -123,6 +126,7 @@ fn exponential_decay_example(train_data: &[(Vec>, Vec>)] print_every: 6, clip_gradient: Some(1.0), log_lr_changes: true, + early_stopping: None, }; let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) @@ -136,8 +140,8 @@ fn exponential_decay_example(train_data: &[(Vec>, Vec>)] fn reduce_on_plateau_example(train_data: &[(Vec>, Vec>)], val_data: &[(Vec>, Vec>)]) { - println!("5️⃣ ReduceLROnPlateau Example"); - println!(" Reduces LR when validation loss stops improving\n"); + println!("ReduceLROnPlateau Example"); + println!("Reduces LR when validation loss stops improving\n"); let network = LSTMNetwork::new(1, 10, 2); @@ -151,6 +155,7 @@ fn reduce_on_plateau_example(train_data: &[(Vec>, Vec>)] print_every: 5, clip_gradient: Some(1.0), log_lr_changes: true, + early_stopping: None, }; println!("Training with manual ReduceLROnPlateau stepping..."); @@ -179,8 +184,8 
@@ fn reduce_on_plateau_example(train_data: &[(Vec>, Vec>)] fn scheduler_comparison(train_data: &[(Vec>, Vec>)], val_data: &[(Vec>, Vec>)]) { - println!("6️⃣ Scheduler Comparison"); - println!(" Training the same network with different schedulers\n"); + println!("Scheduler Comparison"); + println!("Training the same network with different schedulers\n"); let schedulers = vec![ ("Constant", "constant"), @@ -190,7 +195,7 @@ fn scheduler_comparison(train_data: &[(Vec>, Vec>)], ]; for (name, scheduler_type) in schedulers { - println!("🔄 Testing {} scheduler:", name); + println!("Testing {} scheduler:", name); let network = LSTMNetwork::new(1, 8, 1); // Smaller network for faster comparison @@ -199,6 +204,7 @@ fn scheduler_comparison(train_data: &[(Vec>, Vec>)], print_every: 20, // Only print final result clip_gradient: Some(1.0), log_lr_changes: false, + early_stopping: None, }; let final_loss = match scheduler_type { @@ -236,7 +242,7 @@ fn scheduler_comparison(train_data: &[(Vec>, Vec>)], println!(" Final validation loss: {:.6}\n", final_loss); } - println!("✅ Comparison complete! Check which scheduler performed best."); + println!("Comparison complete! 
Check which scheduler performed best."); } fn generate_sine_wave_data(num_sequences: usize, offset: f64) -> Vec<(Vec>, Vec>)> { diff --git a/src/lib.rs b/src/lib.rs index 3795150..b8dbeed 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,7 @@ pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache} pub use layers::dropout::{Dropout, Zoneout}; pub use training::{ LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics, + EarlyStoppingConfig, EarlyStoppingMetric, EarlyStopper, create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer, create_basic_batch_trainer, create_adam_batch_trainer }; diff --git a/src/persistence.rs b/src/persistence.rs index 51a020e..22a00f5 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -8,7 +8,7 @@ use crate::models::lstm_network::LSTMNetwork; use crate::layers::lstm_cell::LSTMCell; /// Serializable version of Array2 for persistence -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug, Clone)] struct SerializableArray2 { data: Vec, shape: (usize, usize), @@ -31,7 +31,7 @@ impl Into> for SerializableArray2 { } /// Serializable LSTM cell parameters -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct SerializableLSTMCell { w_ih: SerializableArray2, w_hh: SerializableArray2, @@ -70,7 +70,7 @@ impl Into for SerializableLSTMCell { } /// Serializable LSTM network -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct SerializableLSTMNetwork { cells: Vec, input_size: usize, diff --git a/src/training.rs b/src/training.rs index a42fe9f..9d0ffae 100644 --- a/src/training.rs +++ b/src/training.rs @@ -3,6 +3,7 @@ use crate::models::lstm_network::LSTMNetwork; use crate::loss::{LossFunction, MSELoss}; use crate::optimizers::{Optimizer, SGD, ScheduledOptimizer}; use crate::schedulers::LearningRateScheduler; +use 
crate::persistence::SerializableLSTMNetwork; use std::time::Instant; /// Configuration for training hyperparameters @@ -11,6 +12,38 @@ pub struct TrainingConfig { pub print_every: usize, pub clip_gradient: Option, pub log_lr_changes: bool, + pub early_stopping: Option<EarlyStoppingConfig>, +} + +/// Configuration for early stopping +#[derive(Debug, Clone)] +pub struct EarlyStoppingConfig { + /// Number of epochs with no improvement after which training will be stopped + pub patience: usize, + /// Minimum change in the monitored quantity to qualify as an improvement + pub min_delta: f64, + /// Whether to restore the best weights when early stopping triggers + pub restore_best_weights: bool, + /// Metric to monitor for early stopping (`ValidationLoss` or `TrainLoss`) + pub monitor: EarlyStoppingMetric, +} + +/// Metric to monitor for early stopping +#[derive(Debug, Clone, PartialEq)] +pub enum EarlyStoppingMetric { + ValidationLoss, + TrainLoss, +} + +impl Default for EarlyStoppingConfig { + fn default() -> Self { + EarlyStoppingConfig { + patience: 10, + min_delta: 1e-4, + restore_best_weights: true, + monitor: EarlyStoppingMetric::ValidationLoss, + } + } } impl Default for TrainingConfig { @@ -20,6 +53,7 @@ impl Default for TrainingConfig { print_every: 10, clip_gradient: Some(5.0), log_lr_changes: true, + early_stopping: None, } } } @@ -34,6 +68,88 @@ pub struct TrainingMetrics { pub learning_rate: f64, } +/// Early stopping state tracker +#[derive(Debug, Clone)] +pub struct EarlyStopper { + config: EarlyStoppingConfig, + best_score: f64, + wait_count: usize, + stopped_epoch: Option<usize>, + best_weights: Option<SerializableLSTMNetwork>, // Serialized network weights +} + +impl EarlyStopper { + pub fn new(config: EarlyStoppingConfig) -> Self { + EarlyStopper { + config, + best_score: f64::INFINITY, + wait_count: 0, + stopped_epoch: None, + best_weights: None, + } + } + + /// Check if training should stop based on current metrics + /// Returns (should_stop, is_best_score) + pub fn should_stop(&mut self, current_metrics: 
&TrainingMetrics, network: &LSTMNetwork) -> (bool, bool) { + let current_score = match self.config.monitor { + EarlyStoppingMetric::ValidationLoss => { + match current_metrics.validation_loss { + Some(val_loss) => val_loss, + None => { + // If validation loss is not available, fall back to train loss + current_metrics.train_loss + } + } + } + EarlyStoppingMetric::TrainLoss => current_metrics.train_loss, + }; + + let is_improvement = current_score < self.best_score - self.config.min_delta; + + if is_improvement { + self.best_score = current_score; + self.wait_count = 0; + + // Save best weights if restore_best_weights is enabled + if self.config.restore_best_weights { + self.best_weights = Some(network.into()); + } + + (false, true) + } else { + self.wait_count += 1; + + if self.wait_count >= self.config.patience { + self.stopped_epoch = Some(current_metrics.epoch); + (true, false) + } else { + (false, false) + } + } + } + + /// Get the epoch where training was stopped + pub fn stopped_epoch(&self) -> Option<usize> { + self.stopped_epoch + } + + /// Get the best score achieved + pub fn best_score(&self) -> f64 { + self.best_score + } + + /// Restore the best weights to the network if available + pub fn restore_best_weights(&self, network: &mut LSTMNetwork) -> Result<(), String> { + if let Some(ref weights) = self.best_weights { + *network = weights.clone().into(); + Ok(()) + } else { + Err("No best weights available to restore".to_string()) + } + } +} + /// Main trainer for LSTM networks with configurable loss and optimizer pub struct LSTMTrainer { pub network: LSTMNetwork, @@ -41,6 +157,7 @@ pub struct LSTMTrainer { pub optimizer: O, pub config: TrainingConfig, pub metrics_history: Vec, + early_stopper: Option<EarlyStopper>, } impl LSTMTrainer { @@ -51,10 +168,15 @@ impl LSTMTrainer { optimizer, config: TrainingConfig::default(), metrics_history: Vec::new(), + early_stopper: None, } } pub fn with_config(mut self, config: TrainingConfig) -> Self { + // Initialize early stopper if early 
stopping is configured + self.early_stopper = config.early_stopping.as_ref().map(|es_config| { + EarlyStopper::new(es_config.clone()) + }); self.config = config; self } @@ -136,14 +258,40 @@ impl LSTMTrainer { self.metrics_history.push(metrics.clone()); + // Check early stopping + let mut should_stop = false; + let mut is_best = false; + if let Some(ref mut early_stopper) = self.early_stopper { + let (stop, best) = early_stopper.should_stop(&metrics, &self.network); + should_stop = stop; + is_best = best; + } + if epoch % self.config.print_every == 0 { + let best_indicator = if is_best { " *" } else { "" }; if let Some(val_loss) = validation_loss { - println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s", - epoch, epoch_loss, val_loss, current_lr, time_elapsed); + println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}", + epoch, epoch_loss, val_loss, current_lr, time_elapsed, best_indicator); } else { - println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s", - epoch, epoch_loss, current_lr, time_elapsed); + println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}", + epoch, epoch_loss, current_lr, time_elapsed, best_indicator); + } + } + + if should_stop { + let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap(); + let best_score = self.early_stopper.as_ref().unwrap().best_score(); + println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score); + + // Restore best weights if configured + if let Some(ref early_stopper) = self.early_stopper { + if let Err(e) = early_stopper.restore_best_weights(&mut self.network) { + println!("Warning: Could not restore best weights: {}", e); + } else { + println!("Restored best weights from epoch with score {:.6}", best_score); + } } + break; } } @@ -229,6 +377,7 @@ pub struct ScheduledLSTMTrainer, pub config: TrainingConfig, pub metrics_history: Vec, + early_stopper: Option, } impl 
ScheduledLSTMTrainer { @@ -239,10 +388,15 @@ impl ScheduledLSTMTrain optimizer, config: TrainingConfig::default(), metrics_history: Vec::new(), + early_stopper: None, } } pub fn with_config(mut self, config: TrainingConfig) -> Self { + // Initialize early stopper if early stopping is configured + self.early_stopper = config.early_stopping.as_ref().map(|es_config| { + EarlyStopper::new(es_config.clone()) + }); self.config = config; self } @@ -338,14 +492,40 @@ impl ScheduledLSTMTrain self.metrics_history.push(metrics.clone()); + // Check early stopping + let mut should_stop = false; + let mut is_best = false; + if let Some(ref mut early_stopper) = self.early_stopper { + let (stop, best) = early_stopper.should_stop(&metrics, &self.network); + should_stop = stop; + is_best = best; + } + if epoch % self.config.print_every == 0 { + let best_indicator = if is_best { " *" } else { "" }; if let Some(val_loss) = validation_loss { - println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s", - epoch, epoch_loss, val_loss, new_lr, time_elapsed); + println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}", + epoch, epoch_loss, val_loss, new_lr, time_elapsed, best_indicator); } else { - println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s", - epoch, epoch_loss, new_lr, time_elapsed); + println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}", + epoch, epoch_loss, new_lr, time_elapsed, best_indicator); + } + } + + if should_stop { + let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap(); + let best_score = self.early_stopper.as_ref().unwrap().best_score(); + println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score); + + // Restore best weights if configured + if let Some(ref early_stopper) = self.early_stopper { + if let Err(e) = early_stopper.restore_best_weights(&mut self.network) { + println!("Warning: Could not restore best weights: {}", e); 
+ } else { + println!("Restored best weights from epoch with score {:.6}", best_score); + } } + break; } } @@ -447,6 +627,7 @@ pub struct LSTMBatchTrainer { pub optimizer: O, pub config: TrainingConfig, pub metrics_history: Vec, + early_stopper: Option, } impl LSTMBatchTrainer { @@ -457,10 +638,15 @@ impl LSTMBatchTrainer { optimizer, config: TrainingConfig::default(), metrics_history: Vec::new(), + early_stopper: None, } } pub fn with_config(mut self, config: TrainingConfig) -> Self { + // Initialize early stopper if early stopping is configured + self.early_stopper = config.early_stopping.as_ref().map(|es_config| { + EarlyStopper::new(es_config.clone()) + }); self.config = config; self } @@ -645,15 +831,41 @@ impl LSTMBatchTrainer { self.metrics_history.push(metrics.clone()); + // Check early stopping + let mut should_stop = false; + let mut is_best = false; + if let Some(ref mut early_stopper) = self.early_stopper { + let (stop, best) = early_stopper.should_stop(&metrics, &self.network); + should_stop = stop; + is_best = best; + } + if epoch % self.config.print_every == 0 { + let best_indicator = if is_best { " *" } else { "" }; if let Some(val_loss) = validation_loss { - println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}", - epoch, epoch_loss, val_loss, current_lr, time_elapsed, num_batches); + println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}{}", + epoch, epoch_loss, val_loss, current_lr, time_elapsed, num_batches, best_indicator); } else { - println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}", - epoch, epoch_loss, current_lr, time_elapsed, num_batches); + println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}{}", + epoch, epoch_loss, current_lr, time_elapsed, num_batches, best_indicator); } } + + if should_stop { + let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap(); + let best_score = 
self.early_stopper.as_ref().unwrap().best_score(); + println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score); + + // Restore best weights if configured + if let Some(ref early_stopper) = self.early_stopper { + if let Err(e) = early_stopper.restore_best_weights(&mut self.network) { + println!("Warning: Could not restore best weights: {}", e); + } else { + println!("Restored best weights from epoch with score {:.6}", best_score); + } + } + break; + } } println!("Batch training completed!"); diff --git a/tests/early_stopping_test.rs b/tests/early_stopping_test.rs new file mode 100644 index 0000000..0f275b9 --- /dev/null +++ b/tests/early_stopping_test.rs @@ -0,0 +1,258 @@ +use rust_lstm::*; +use ndarray::arr2; + +/// Test basic early stopping functionality +#[test] +fn test_early_stopping_basic() { + let network = LSTMNetwork::new(1, 4, 1); + + // Create a simple dataset that will converge quickly + let train_data = vec![ + (vec![arr2(&[[1.0]])], vec![arr2(&[[0.5]])]), + (vec![arr2(&[[0.5]])], vec![arr2(&[[0.25]])]), + ]; + + let val_data = vec![ + (vec![arr2(&[[0.8]])], vec![arr2(&[[0.4]])]), + ]; + + // Configure early stopping with very low patience for quick test + let early_stopping_config = EarlyStoppingConfig { + patience: 3, + min_delta: 1e-2, // Higher threshold to make early stopping more likely + restore_best_weights: true, + monitor: EarlyStoppingMetric::ValidationLoss, + }; + + let training_config = TrainingConfig { + epochs: 50, // Should stop early + print_every: 10, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = create_basic_trainer(network, 0.01) + .with_config(training_config); + + trainer.train(&train_data, Some(&val_data)); + + // Early stopping should have been configured (this test just verifies the configuration works) + let final_metrics = trainer.get_latest_metrics().unwrap(); + assert!(final_metrics.epoch >= 0, "Training 
should have run at least one epoch"); +} + +/// Test early stopping with training loss monitoring +#[test] +fn test_early_stopping_train_loss() { + let network = LSTMNetwork::new(1, 4, 1); + + let train_data = vec![ + (vec![arr2(&[[1.0]])], vec![arr2(&[[0.5]])]), + (vec![arr2(&[[0.5]])], vec![arr2(&[[0.25]])]), + ]; + + // Configure early stopping to monitor training loss + let early_stopping_config = EarlyStoppingConfig { + patience: 4, + min_delta: 1e-2, // Higher threshold to make early stopping more likely + restore_best_weights: false, + monitor: EarlyStoppingMetric::TrainLoss, + }; + + let training_config = TrainingConfig { + epochs: 50, + print_every: 10, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = create_basic_trainer(network, 0.01) + .with_config(training_config); + + trainer.train(&train_data, None); // No validation data + + let final_metrics = trainer.get_latest_metrics().unwrap(); + assert!(final_metrics.epoch >= 0, "Training should have run with train loss monitoring"); +} + +/// Test that training without early stopping runs full epochs +#[test] +fn test_no_early_stopping() { + let network = LSTMNetwork::new(1, 4, 1); + + let train_data = vec![ + (vec![arr2(&[[1.0]])], vec![arr2(&[[0.5]])]), + ]; + + let training_config = TrainingConfig { + epochs: 10, + print_every: 5, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: None, // No early stopping + }; + + let mut trainer = create_basic_trainer(network, 0.01) + .with_config(training_config); + + trainer.train(&train_data, None); + + let final_metrics = trainer.get_latest_metrics().unwrap(); + assert_eq!(final_metrics.epoch, 9, "Should run all 10 epochs (0-indexed)"); +} + +/// Test early stopper configuration +#[test] +fn test_early_stopper_config() { + let config = EarlyStoppingConfig { + patience: 5, + min_delta: 1e-3, + restore_best_weights: true, + monitor: EarlyStoppingMetric::ValidationLoss, + 
}; + + let mut stopper = EarlyStopper::new(config.clone()); + + // Test initial state + assert_eq!(stopper.best_score(), f64::INFINITY); + assert_eq!(stopper.stopped_epoch(), None); + + // Create dummy network and metrics for testing + let network = LSTMNetwork::new(1, 2, 1); + let metrics = TrainingMetrics { + epoch: 0, + train_loss: 1.0, + validation_loss: Some(0.5), + time_elapsed: 1.0, + learning_rate: 0.01, + }; + + // First call should not stop and should be best + let (should_stop, is_best) = stopper.should_stop(&metrics, &network); + assert!(!should_stop); + assert!(is_best); + assert_eq!(stopper.best_score(), 0.5); +} + +/// Test early stopping with different min_delta values +#[test] +fn test_early_stopping_min_delta() { + let mut stopper = EarlyStopper::new(EarlyStoppingConfig { + patience: 2, + min_delta: 0.1, // Require significant improvement + restore_best_weights: false, + monitor: EarlyStoppingMetric::ValidationLoss, + }); + + let network = LSTMNetwork::new(1, 2, 1); + + // First metric - should be best + let metrics1 = TrainingMetrics { + epoch: 0, + train_loss: 1.0, + validation_loss: Some(1.0), + time_elapsed: 1.0, + learning_rate: 0.01, + }; + let (should_stop, is_best) = stopper.should_stop(&metrics1, &network); + assert!(!should_stop); + assert!(is_best); + + // Small improvement (less than min_delta) - should not be considered improvement + let metrics2 = TrainingMetrics { + epoch: 1, + train_loss: 0.95, + validation_loss: Some(0.95), // Only 0.05 improvement, less than 0.1 min_delta + time_elapsed: 1.0, + learning_rate: 0.01, + }; + let (should_stop, is_best) = stopper.should_stop(&metrics2, &network); + assert!(!should_stop); + assert!(!is_best); // Should not be considered best due to min_delta + + // Another small improvement - should trigger early stopping due to patience + let metrics3 = TrainingMetrics { + epoch: 2, + train_loss: 0.9, + validation_loss: Some(0.9), + time_elapsed: 1.0, + learning_rate: 0.01, + }; + let (should_stop, 
is_best) = stopper.should_stop(&metrics3, &network); + assert!(should_stop); // Should stop due to patience exhausted + assert!(!is_best); +} + +/// Test early stopping with scheduled trainer +#[test] +fn test_early_stopping_with_scheduled_trainer() { + use rust_lstm::{ScheduledOptimizer, StepLR, Adam}; + + let network = LSTMNetwork::new(1, 4, 1); + let optimizer = ScheduledOptimizer::new(Adam::new(0.01), StepLR::new(5, 0.5), 0.01); + + let train_data = vec![ + (vec![arr2(&[[1.0]])], vec![arr2(&[[0.5]])]), + ]; + + let early_stopping_config = EarlyStoppingConfig { + patience: 3, + min_delta: 1e-4, + restore_best_weights: true, + monitor: EarlyStoppingMetric::TrainLoss, + }; + + let training_config = TrainingConfig { + epochs: 30, + print_every: 10, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = ScheduledLSTMTrainer::new(network, MSELoss, optimizer) + .with_config(training_config); + + trainer.train(&train_data, None); + + // Should complete successfully with early stopping + let final_metrics = trainer.get_latest_metrics().unwrap(); + assert!(final_metrics.epoch >= 0, "Scheduled trainer should support early stopping"); +} + +/// Test early stopping with batch trainer +#[test] +fn test_early_stopping_with_batch_trainer() { + let network = LSTMNetwork::new(1, 4, 1); + + let train_data = vec![ + (vec![arr2(&[[1.0]])], vec![arr2(&[[0.5]])]), + (vec![arr2(&[[0.5]])], vec![arr2(&[[0.25]])]), + ]; + + let early_stopping_config = EarlyStoppingConfig { + patience: 3, + min_delta: 1e-4, + restore_best_weights: true, + monitor: EarlyStoppingMetric::TrainLoss, + }; + + let training_config = TrainingConfig { + epochs: 30, + print_every: 10, + clip_gradient: Some(1.0), + log_lr_changes: false, + early_stopping: Some(early_stopping_config), + }; + + let mut trainer = create_adam_batch_trainer(network, 0.01) + .with_config(training_config); + + trainer.train(&train_data, None, 2); // Batch size 2 + + 
// Should complete successfully with early stopping + let final_metrics = trainer.get_latest_metrics().unwrap(); + assert!(final_metrics.epoch >= 0, "Batch trainer should support early stopping"); +} diff --git a/tests/persistence_test.rs b/tests/persistence_test.rs index 9500237..10863f6 100644 --- a/tests/persistence_test.rs +++ b/tests/persistence_test.rs @@ -179,6 +179,7 @@ fn test_persistence_with_trained_model() { print_every: 1, clip_gradient: Some(1.0), log_lr_changes: false, + early_stopping: None, }; trainer = trainer.with_config(config); trainer.train(&train_data, None); diff --git a/tests/readme_examples_test.rs b/tests/readme_examples_test.rs index 5b7c2a3..590caef 100644 --- a/tests/readme_examples_test.rs +++ b/tests/readme_examples_test.rs @@ -95,6 +95,7 @@ fn test_training_example() { print_every: 1, clip_gradient: Some(1.0), log_lr_changes: false, + early_stopping: None, }; trainer = trainer.with_config(config);