Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.6.0] - 2025-09-04

### Added
- **Early Stopping**: Complete early stopping implementation for all trainers
- **Configurable Patience**: Stop training after N epochs without improvement
- **Multiple Metrics**: Monitor validation loss or training loss
- **Best Weight Restoration**: Automatically restore best weights when stopping
- **Flexible Configuration**: Customizable min_delta threshold and monitoring options
- **Universal Support**: Works with LSTMTrainer, ScheduledLSTMTrainer, and LSTMBatchTrainer

- **Enhanced Training Features**:
- **Visual Best Epoch Indicators**: "*" markers in training logs for best epochs
- **Automatic Overfitting Prevention**: Stop training before performance degrades
- **Comprehensive Logging**: Detailed early stopping trigger information
- **Weight Management**: Optional best weight restoration for optimal model recovery

- **Examples and Documentation**:
- **Complete Early Stopping Example**: Demonstration of all early stopping configurations
- **Multiple Scenarios**: Validation loss, training loss, and custom patience examples
- **Integration Examples**: Shows usage with different trainer types
- **Comprehensive Testing**: Full test suite for early stopping functionality

### Enhanced
- **All Trainers**: Added early stopping support to LSTMTrainer, ScheduledLSTMTrainer, and LSTMBatchTrainer
- **Training Configuration**: Extended TrainingConfig with optional early stopping settings
- **Training Loops**: Enhanced all training loops with early stopping logic and best epoch tracking

## [0.5.0] - 2025-08-14

### Added
Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ path = "examples/multi_layer_lstm.rs"
[[example]]
name = "batch_processing_example"
path = "examples/batch_processing_example.rs"

[[example]]
name = "early_stopping_example"
path = "examples/early_stopping_example.rs"
39 changes: 39 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ graph TD
- **Complete Training System** with backpropagation through time (BPTT)
- **Multiple Optimizers**: SGD, Adam, RMSprop with comprehensive learning rate scheduling
- **Advanced Learning Rate Scheduling**: 12 different schedulers including OneCycle, Warmup, Cyclical, and Polynomial
- **Early Stopping**: Prevent overfitting with configurable patience and metric monitoring
- **Loss Functions**: MSE, MAE, Cross-entropy with softmax
- **Advanced Dropout**: Input, recurrent, output dropout, variational dropout, and zoneout
- **Batch Processing**: 4-5x training speedup with efficient batch operations
- **Schedule Visualization**: ASCII visualization of learning rate schedules
- **Model Persistence**: Save/load models in JSON or binary format
- **Peephole LSTM variant** for enhanced performance
Expand Down Expand Up @@ -98,6 +100,39 @@ fn main() {
}
```

### Early Stopping

```rust
use rust_lstm::{
LSTMNetwork, create_basic_trainer, TrainingConfig,
EarlyStoppingConfig, EarlyStoppingMetric
};

fn main() {
let network = LSTMNetwork::new(1, 10, 2);

// Configure early stopping
let early_stopping = EarlyStoppingConfig {
patience: 10, // Stop after 10 epochs with no improvement
min_delta: 1e-4, // Minimum improvement threshold
restore_best_weights: true, // Restore best weights when stopping
monitor: EarlyStoppingMetric::ValidationLoss, // Monitor validation loss
};

let config = TrainingConfig {
epochs: 1000, // Will likely stop early
early_stopping: Some(early_stopping),
..Default::default()
};

let mut trainer = create_basic_trainer(network, 0.001)
.with_config(config);

// Training will stop early if validation loss stops improving
trainer.train(&train_data, Some(&validation_data));
}
```

### Bidirectional LSTM

```rust
Expand Down Expand Up @@ -258,6 +293,10 @@ cargo run --example dropout_example # Dropout demo
# Learning and scheduling
cargo run --example learning_rate_scheduling # Basic schedulers
cargo run --example advanced_lr_scheduling # Advanced schedulers with visualization
cargo run --example early_stopping_example # Early stopping demonstration

# Performance and batch processing
cargo run --example batch_processing_example # Batch processing with performance benchmarks

# Real-world applications
cargo run --example stock_prediction
Expand Down
6 changes: 6 additions & 0 deletions examples/advanced_lr_scheduling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ fn polynomial_decay_example(train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
};

let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
Expand Down Expand Up @@ -83,6 +84,7 @@ fn cyclical_lr_examples(train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: false, // Too frequent for cyclical
early_stopping: None,
};

let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
Expand All @@ -108,6 +110,7 @@ fn cyclical_lr_examples(train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
};

let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
Expand All @@ -134,6 +137,7 @@ fn cyclical_lr_examples(train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
};

let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
Expand Down Expand Up @@ -172,6 +176,7 @@ fn warmup_scheduler_example(train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
print_every: 3,
clip_gradient: Some(1.0),
log_lr_changes: true,
early_stopping: None,
};

let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
Expand Down Expand Up @@ -236,6 +241,7 @@ fn advanced_training_example(train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]
print_every: 5,
clip_gradient: Some(1.0), // Gradient clipping
log_lr_changes: false, // Too frequent for cyclical
early_stopping: None,
};

let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
Expand Down
1 change: 1 addition & 0 deletions examples/dropout_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ fn demonstrate_training_with_dropout() {
print_every: 5,
clip_gradient: Some(1.0),
log_lr_changes: false,
early_stopping: None,
};
trainer = trainer.with_config(config);

Expand Down
227 changes: 227 additions & 0 deletions examples/early_stopping_example.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
use ndarray::{Array2, arr2};
use rust_lstm::{
LSTMNetwork, create_basic_trainer, TrainingConfig, EarlyStoppingConfig, EarlyStoppingMetric,
MSELoss, Adam
};

fn main() {
println!("Early Stopping Demonstration");
println!("================================\n");

// Generate synthetic data that will overfit quickly
let (train_data, val_data) = generate_overfitting_data();

println!("Generated {} training sequences and {} validation sequences",
train_data.len(), val_data.len());

// Demonstrate different early stopping configurations
demonstrate_validation_early_stopping(&train_data, &val_data);
demonstrate_train_loss_early_stopping(&train_data, &val_data);
demonstrate_no_weight_restoration(&train_data, &val_data);
demonstrate_custom_patience(&train_data, &val_data);
}

/// Demonstrate early stopping based on validation loss (most common).
///
/// Monitors the validation loss with patience 5 and restores the best
/// weights once training halts.
fn demonstrate_validation_early_stopping(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
) {
    println!("1. VALIDATION LOSS EARLY STOPPING");
    println!("==================================");

    // Training setup with the early-stopping config inlined: stop after 5
    // epochs without at least a 1e-4 improvement in validation loss.
    let training_config = TrainingConfig {
        epochs: 100, // Will likely stop early
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(EarlyStoppingConfig {
            patience: 5,
            min_delta: 1e-4,
            restore_best_weights: true,
            monitor: EarlyStoppingMetric::ValidationLoss,
        }),
    };

    let mut trainer =
        create_basic_trainer(LSTMNetwork::new(1, 8, 1), 0.01).with_config(training_config);

    println!("Training with validation loss monitoring (patience=5)...");
    trainer.train(train_data, Some(val_data));

    // Report where training ended up, if any metrics were recorded.
    if let Some(metrics) = trainer.get_latest_metrics() {
        println!(
            "Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0)
        );
    }
}

/// Demonstrate early stopping based on training loss.
///
/// Useful when no validation split is available; here patience is 8 with a
/// tighter 1e-5 improvement threshold.
fn demonstrate_train_loss_early_stopping(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
) {
    println!("2. TRAINING LOSS EARLY STOPPING");
    println!("===============================");

    // Watch the training loss instead of the validation loss.
    let training_config = TrainingConfig {
        epochs: 100,
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(EarlyStoppingConfig {
            patience: 8,
            min_delta: 1e-5,
            restore_best_weights: true,
            monitor: EarlyStoppingMetric::TrainLoss,
        }),
    };

    let mut trainer =
        create_basic_trainer(LSTMNetwork::new(1, 8, 1), 0.01).with_config(training_config);

    println!("Training with training loss monitoring (patience=8)...");
    trainer.train(train_data, Some(val_data));

    // Summarize the last recorded epoch.
    if let Some(metrics) = trainer.get_latest_metrics() {
        println!(
            "Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0)
        );
    }
}

/// Demonstrate early stopping without weight restoration.
///
/// With `restore_best_weights: false`, the model keeps the weights from the
/// final epoch rather than rolling back to the best-scoring epoch.
fn demonstrate_no_weight_restoration(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
) {
    println!("3. EARLY STOPPING WITHOUT WEIGHT RESTORATION");
    println!("=============================================");

    // Same validation-loss monitoring as scenario 1, but keep last-epoch weights.
    let training_config = TrainingConfig {
        epochs: 100,
        print_every: 1,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(EarlyStoppingConfig {
            patience: 5,
            min_delta: 1e-4,
            restore_best_weights: false, // Don't restore best weights
            monitor: EarlyStoppingMetric::ValidationLoss,
        }),
    };

    let mut trainer =
        create_basic_trainer(LSTMNetwork::new(1, 8, 1), 0.01).with_config(training_config);

    println!("Training without weight restoration...");
    trainer.train(train_data, Some(val_data));

    // Summarize the run and call out the weight-restoration difference.
    if let Some(metrics) = trainer.get_latest_metrics() {
        println!(
            "Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0)
        );
        println!("Note: Weights are from the last epoch, not the best epoch\n");
    }
}

/// Demonstrate early stopping with custom patience.
///
/// A larger patience (15) plus a tiny min_delta (1e-6) tolerates long
/// plateaus before giving up, trading extra epochs for robustness.
fn demonstrate_custom_patience(
    train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
    val_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
) {
    println!("4. EARLY STOPPING WITH HIGH PATIENCE");
    println!("====================================");

    // More forgiving stopping criteria than the earlier scenarios.
    let training_config = TrainingConfig {
        epochs: 100,
        print_every: 2,
        clip_gradient: Some(1.0),
        log_lr_changes: false,
        early_stopping: Some(EarlyStoppingConfig {
            patience: 15,    // More patient
            min_delta: 1e-6, // Smaller improvement threshold
            restore_best_weights: true,
            monitor: EarlyStoppingMetric::ValidationLoss,
        }),
    };

    let mut trainer =
        create_basic_trainer(LSTMNetwork::new(1, 8, 1), 0.01).with_config(training_config);

    println!("Training with high patience (patience=15)...");
    trainer.train(train_data, Some(val_data));

    // Summarize the last recorded epoch.
    if let Some(metrics) = trainer.get_latest_metrics() {
        println!(
            "Final epoch: {}, Train loss: {:.6}, Val loss: {:.6}\n",
            metrics.epoch,
            metrics.train_loss,
            metrics.validation_loss.unwrap_or(0.0)
        );
    }
}

/// Generate synthetic data that will cause overfitting.
///
/// Creates a simple phase-shifted sine pattern that is easy to memorize but
/// doesn't generalize well. Returns `(train_data, val_data)`: 20 training
/// sequences with closely spaced phases, and 5 validation sequences with a
/// distant phase offset to test generalization. Each sequence is 10 timesteps
/// of `(input, target)` pairs where the target is the next sine value.
fn generate_overfitting_data() -> (
    Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>,
    Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>,
) {
    const SEQ_LEN: usize = 10;

    // Build one sequence for a given phase: inputs are sin(0.3*t + phase),
    // targets are the value one step ahead. (Shared by both splits; the two
    // original loops duplicated this body verbatim.)
    fn make_sequence(phase: f64) -> (Vec<Array2<f64>>, Vec<Array2<f64>>) {
        let mut inputs = Vec::with_capacity(SEQ_LEN);
        let mut targets = Vec::with_capacity(SEQ_LEN);
        for t in 0..SEQ_LEN {
            let x = (t as f64 * 0.3 + phase).sin();
            let y = ((t + 1) as f64 * 0.3 + phase).sin(); // Next value
            inputs.push(arr2(&[[x]]));
            targets.push(arr2(&[[y]]));
        }
        (inputs, targets)
    }

    // Training data - simple sine wave with closely spaced phases.
    let train_data: Vec<_> = (0..20).map(|i| make_sequence(i as f64 * 0.1)).collect();

    // Validation data - different phase to test generalization.
    let val_data: Vec<_> = (0..5)
        .map(|i| make_sequence((i as f64 + 100.0) * 0.1))
        .collect();

    (train_data, val_data)
}
Loading