From 9eb2a5f4884706bc6edc140031911ff4e4b4c063 Mon Sep 17 00:00:00 2001 From: Alex Kholodniak Date: Sun, 6 Jul 2025 23:30:07 +0300 Subject: [PATCH] feat: Advanced learning rate scheduling v0.4.0 - Add PolynomialLR scheduler with configurable polynomial decay - Add CyclicalLR with triangular, triangular2, and exponential range modes - Add WarmupScheduler as generic wrapper for any base scheduler - Add LRScheduleVisualizer for ASCII visualization of schedules - Enhanced ScheduledOptimizer with convenience factory methods - Add comprehensive advanced_lr_scheduling.rs example - Update documentation and README with new features - Add comprehensive test coverage for all new schedulers - Bump version to 0.4.0 --- CHANGELOG.md | 44 ++++- Cargo.toml | 6 +- README.md | 52 ++++- examples/advanced_lr_scheduling.rs | 303 +++++++++++++++++++++++++++++ src/lib.rs | 4 +- src/optimizers.rs | 25 +++ src/schedulers.rs | 275 +++++++++++++++++++++++++- 7 files changed, 698 insertions(+), 11 deletions(-) create mode 100644 examples/advanced_lr_scheduling.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 87cfa8a..c1d42b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,47 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## [0.4.0] - 2025-07-06 + +### Added +- **Advanced Learning Rate Scheduling**: Comprehensive expansion of learning rate scheduling capabilities + - **PolynomialLR**: Polynomial decay with configurable power for smooth learning rate transitions + - **CyclicalLR**: Cyclical learning rates with triangular, triangular2, and exponential range modes + - **WarmupScheduler**: Generic warmup wrapper that can be applied to any base scheduler + - **LRScheduleVisualizer**: ASCII visualization tool for learning rate schedules + +- **Enhanced Scheduler Integration**: + - Convenience factory methods for new schedulers in `ScheduledOptimizer` + - Helper functions: `polynomial`, `cyclical`, `cyclical_triangular2`, `cyclical_exp_range` + - Complete integration with existing training infrastructure + - Comprehensive test coverage for all new schedulers + +- **Learning Rate Visualization**: + - ASCII-based schedule visualization with customizable dimensions + - Schedule generation utilities for analysis and debugging + - Visual comparison tools for different scheduler behaviors + - Integration examples showing visualization usage + +- **Advanced Training Examples**: + - `advanced_lr_scheduling.rs`: Comprehensive demonstration of new schedulers + - Warmup + cyclical learning rate combinations + - Best practices example with dropout + gradient clipping + advanced scheduling + - Performance comparison between different scheduling strategies + +### Technical Improvements +- Extended scheduler trait system to support generic warmup wrapper +- Robust cyclical learning rate computation with proper cycle handling +- Polynomial decay implementation with numerical stability +- Comprehensive error handling and edge case management +- Enhanced documentation with visual examples and mathematical formulations + +### Benefits +- More sophisticated learning rate control for better training quality +- Modern scheduling techniques used in state-of-the-art deep learning +- Visualization capabilities 
for schedule analysis and debugging +- Flexible warmup support for any existing scheduler +- Production-ready implementations with comprehensive testing + ## [0.3.0] - 2025-07-03 ### Added @@ -189,4 +230,5 @@ When contributing to this project, please: - **v0.1.0**: Initial LSTM implementation with forward pass - **v0.2.0**: Complete training system with BPTT and optimizers -- **v0.3.0**: Learning rate scheduling, GRU implementation, BiLSTM, enhanced dropout, and model persistence \ No newline at end of file +- **v0.3.0**: Learning rate scheduling, GRU implementation, BiLSTM, enhanced dropout, and model persistence +- **v0.4.0**: Advanced learning rate scheduling with 12 different schedulers, warmup support, cyclical rates, and visualization \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 38af612..95a446e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-lstm" -version = "0.3.0" +version = "0.4.0" authors = ["Alex Kholodniak "] edition = "2021" rust-version = "1.70" @@ -65,6 +65,10 @@ path = "examples/text_classification_bilstm.rs" name = "learning_rate_scheduling" path = "examples/learning_rate_scheduling.rs" +[[example]] +name = "advanced_lr_scheduling" +path = "examples/advanced_lr_scheduling.rs" + [[example]] name = "gru_example" path = "examples/gru_example.rs" diff --git a/README.md b/README.md index d88bd66..0970072 100644 --- a/README.md +++ b/README.md @@ -35,9 +35,11 @@ graph TD - **LSTM, BiLSTM & GRU Networks** with multi-layer support - **Complete Training System** with backpropagation through time (BPTT) -- **Multiple Optimizers**: SGD, Adam, RMSprop with learning rate scheduling +- **Multiple Optimizers**: SGD, Adam, RMSprop with comprehensive learning rate scheduling +- **Advanced Learning Rate Scheduling**: 12 different schedulers including OneCycle, Warmup, Cyclical, and Polynomial - **Loss Functions**: MSE, MAE, Cross-entropy with softmax - **Advanced Dropout**: Input, recurrent, output dropout, 
variational dropout, and zoneout +- **Schedule Visualization**: ASCII visualization of learning rate schedules - **Model Persistence**: Save/load models in JSON or binary format - **Peephole LSTM variant** for enhanced performance @@ -47,7 +49,7 @@ Add to your `Cargo.toml`: ```toml [dependencies] -rust-lstm = "0.3.0" +rust-lstm = "0.4.0" ``` ### Basic Usage @@ -185,18 +187,50 @@ graph LR style D2 fill:#fff3e0 ``` -### Learning Rate Scheduling +### Advanced Learning Rate Scheduling + +The library includes 12 different learning rate schedulers with visualization capabilities: ```rust -use rust_lstm::{create_step_lr_trainer, create_one_cycle_trainer}; +use rust_lstm::{ + create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer, + ScheduledOptimizer, PolynomialLR, CyclicalLR, WarmupScheduler, + LRScheduleVisualizer, Adam +}; // Step decay: reduce LR by 50% every 10 epochs let mut trainer = create_step_lr_trainer(network, 0.01, 10, 0.5); // OneCycle policy for modern deep learning let mut trainer = create_one_cycle_trainer(network, 0.1, 100); + +// Cosine annealing with warm restarts +let mut trainer = create_cosine_annealing_trainer(network, 0.01, 20, 1e-6); + +// Advanced combinations - Warmup + Cyclical scheduling +let base_scheduler = CyclicalLR::new(0.001, 0.01, 10); +let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001); +let optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01); + +// Polynomial decay with visualization +let poly_scheduler = PolynomialLR::new(100, 2.0, 0.001); +LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 100, 60, 10); ``` +#### Available Schedulers: +- **ConstantLR**: No scheduling (baseline) +- **StepLR**: Step decay at regular intervals +- **MultiStepLR**: Multi-step decay at specific milestones +- **ExponentialLR**: Exponential decay each epoch +- **CosineAnnealingLR**: Smooth cosine oscillation +- **CosineAnnealingWarmRestarts**: Cosine with periodic restarts +- 
**OneCycleLR**: One cycle policy for super-convergence +- **ReduceLROnPlateau**: Adaptive reduction on validation plateaus +- **LinearLR**: Linear interpolation between rates +- **PolynomialLR** ✨: Polynomial decay with configurable power +- **CyclicalLR** ✨: Triangular, triangular2, and exponential range modes +- **WarmupScheduler** ✨: Gradual warmup wrapper for any base scheduler + ## Architecture - **`layers`**: LSTM and GRU cells (standard, peephole, bidirectional) with dropout @@ -223,7 +257,8 @@ cargo run --example bilstm_example # Bidirectional LSTM cargo run --example dropout_example # Comprehensive dropout demo # Learning and scheduling -cargo run --example learning_rate_scheduling +cargo run --example learning_rate_scheduling # Basic schedulers +cargo run --example advanced_lr_scheduling # Advanced schedulers with visualization # Real-world applications cargo run --example stock_prediction @@ -257,8 +292,12 @@ cargo run --example model_inspection ### Learning Rate Schedulers - **StepLR**: Decay by factor every N epochs - **OneCycleLR**: One cycle policy (warmup + annealing) -- **CosineAnnealingLR**: Smooth cosine oscillation +- **CosineAnnealingLR**: Smooth cosine oscillation with warm restarts - **ReduceLROnPlateau**: Reduce when validation loss plateaus +- **PolynomialLR**: Polynomial decay with configurable power +- **CyclicalLR**: Triangular oscillation with multiple modes +- **WarmupScheduler**: Gradual increase wrapper for any scheduler +- **LinearLR**: Linear interpolation between learning rates ## Testing @@ -295,6 +334,7 @@ cargo run --example text_classification_bilstm # Classification accuracy ## Version History +- **v0.4.0**: Advanced learning rate scheduling with 12 different schedulers, warmup support, cyclical learning rates, polynomial decay, and ASCII visualization - **v0.3.0**: Bidirectional LSTM networks with flexible combine modes - **v0.2.0**: Complete training system with BPTT and comprehensive dropout - **v0.1.0**: Initial LSTM 
implementation with forward pass diff --git a/examples/advanced_lr_scheduling.rs b/examples/advanced_lr_scheduling.rs new file mode 100644 index 0000000..3727815 --- /dev/null +++ b/examples/advanced_lr_scheduling.rs @@ -0,0 +1,303 @@ +use ndarray::{Array2, arr2}; +use rust_lstm::{ + LSTMNetwork, ScheduledLSTMTrainer, ScheduledOptimizer, TrainingConfig, + Adam, MSELoss, PolynomialLR, CyclicalLR, CyclicalMode, WarmupScheduler, + StepLR, LRScheduleVisualizer +}; + +fn main() { + println!("šŸš€ Advanced Learning Rate Scheduling for Rust-LSTM"); + println!("===================================================\n"); + + // Generate sample training data + let train_data = generate_sine_wave_data(50, 0.0); + let val_data = generate_sine_wave_data(10, 1000.0); + + // 1. Polynomial Decay Example + polynomial_decay_example(&train_data, &val_data); + + // 2. Cyclical Learning Rate Examples + cyclical_lr_examples(&train_data, &val_data); + + // 3. Warmup Scheduler Example + warmup_scheduler_example(&train_data, &val_data); + + // 4. Schedule Visualization + schedule_visualization(); + + // 5. 
Advanced Training with Best Practices + advanced_training_example(&train_data, &val_data); +} + +fn polynomial_decay_example(train_data: &[(Vec>, Vec>)], + val_data: &[(Vec>, Vec>)]) { + println!("1ļøāƒ£ Polynomial Decay Example"); + println!(" Smoothly decays LR using polynomial function\n"); + + let network = LSTMNetwork::new(1, 8, 1); + + let loss_function = MSELoss; + let scheduled_optimizer = ScheduledOptimizer::polynomial( + Adam::new(0.01), + 0.01, // base_lr + 25, // total_iters + 2.0, // power + 0.001 // end_lr + ); + + let config = TrainingConfig { + epochs: 30, + print_every: 5, + clip_gradient: Some(1.0), + log_lr_changes: true, + }; + + let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) + .with_config(config); + + trainer.train(train_data, Some(val_data)); + + println!("Final LR: {:.2e}\n", trainer.get_current_lr()); + println!("----------------------------------------\n"); +} + +fn cyclical_lr_examples(train_data: &[(Vec>, Vec>)], + val_data: &[(Vec>, Vec>)]) { + println!("2ļøāƒ£ Cyclical Learning Rate Examples"); + println!(" Oscillates between min and max LR with different patterns\n"); + + // 2a. Triangular Cyclical LR + println!("2a. Triangular Cyclical LR"); + let network = LSTMNetwork::new(1, 8, 1); + + let loss_function = MSELoss; + let scheduled_optimizer = ScheduledOptimizer::cyclical( + Adam::new(0.001), + 0.001, // base_lr + 0.01, // max_lr + 8 // step_size + ); + + let config = TrainingConfig { + epochs: 25, + print_every: 5, + clip_gradient: Some(1.0), + log_lr_changes: false, // Too frequent for cyclical + }; + + let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) + .with_config(config); + + trainer.train(train_data, Some(val_data)); + println!("Final LR: {:.2e}\n", trainer.get_current_lr()); + + // 2b. Triangular2 Cyclical LR (halving amplitude each cycle) + println!("2b. 
Triangular2 Cyclical LR (halving amplitude each cycle)"); + let network = LSTMNetwork::new(1, 8, 1); + + let loss_function = MSELoss; + let scheduled_optimizer = ScheduledOptimizer::cyclical_triangular2( + Adam::new(0.001), + 0.001, // base_lr + 0.01, // max_lr + 8 // step_size + ); + + let config2 = TrainingConfig { + epochs: 25, + print_every: 5, + clip_gradient: Some(1.0), + log_lr_changes: false, + }; + + let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) + .with_config(config2); + + trainer.train(train_data, Some(val_data)); + println!("Final LR: {:.2e}\n", trainer.get_current_lr()); + + // 2c. ExpRange Cyclical LR (exponential scaling) + println!("2c. ExpRange Cyclical LR (exponential scaling)"); + let network = LSTMNetwork::new(1, 8, 1); + + let loss_function = MSELoss; + let scheduled_optimizer = ScheduledOptimizer::cyclical_exp_range( + Adam::new(0.001), + 0.001, // base_lr + 0.01, // max_lr + 8, // step_size + 0.95 // gamma + ); + + let config3 = TrainingConfig { + epochs: 25, + print_every: 5, + clip_gradient: Some(1.0), + log_lr_changes: false, + }; + + let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) + .with_config(config3); + + trainer.train(train_data, Some(val_data)); + println!("Final LR: {:.2e}\n", trainer.get_current_lr()); + + println!("----------------------------------------\n"); +} + +fn warmup_scheduler_example(train_data: &[(Vec>, Vec>)], + val_data: &[(Vec>, Vec>)]) { + println!("3ļøāƒ£ Warmup Scheduler Example"); + println!(" Gradually increases LR during warmup, then applies base scheduler\n"); + + let network = LSTMNetwork::new(1, 8, 1); + + // Create warmup scheduler with step decay after warmup + let base_scheduler = StepLR::new(10, 0.5); // Reduce by half every 10 epochs + let warmup_scheduler = WarmupScheduler::new( + 5, // warmup_epochs + base_scheduler, // base_scheduler + 0.001 // warmup_start_lr + ); + + let loss_function = MSELoss; + let 
scheduled_optimizer = ScheduledOptimizer::new( + Adam::new(0.01), + warmup_scheduler, + 0.01 + ); + + let config = TrainingConfig { + epochs: 30, + print_every: 3, + clip_gradient: Some(1.0), + log_lr_changes: true, + }; + + let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) + .with_config(config); + + trainer.train(train_data, Some(val_data)); + + println!("Final LR: {:.2e}\n", trainer.get_current_lr()); + println!("----------------------------------------\n"); +} + +fn schedule_visualization() { + println!("4ļøāƒ£ Learning Rate Schedule Visualization"); + println!(" ASCII visualization of different schedulers\n"); + + // Visualize StepLR + println!("StepLR (step_size=10, gamma=0.5):"); + let step_scheduler = StepLR::new(10, 0.5); + LRScheduleVisualizer::print_schedule(step_scheduler, 0.01, 50, 60, 10); + println!(); + + // Visualize PolynomialLR + println!("PolynomialLR (power=2.0, end_lr=0.001):"); + let poly_scheduler = PolynomialLR::new(50, 2.0, 0.001); + LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 50, 60, 10); + println!(); + + // Visualize CyclicalLR + println!("CyclicalLR Triangular (base_lr=0.001, max_lr=0.01, step_size=8):"); + let cyclical_scheduler = CyclicalLR::new(0.001, 0.01, 8); + LRScheduleVisualizer::print_schedule(cyclical_scheduler, 0.001, 50, 60, 10); + println!(); + + println!("----------------------------------------\n"); +} + +fn advanced_training_example(train_data: &[(Vec>, Vec>)], + val_data: &[(Vec>, Vec>)]) { + println!("5ļøāƒ£ Advanced Training with Best Practices"); + println!(" Warmup + Cyclical LR + Dropout + Gradient Clipping\n"); + + // Create network with dropout + let network = LSTMNetwork::new(1, 16, 1) + .with_input_dropout(0.1, true) // Variational dropout + .with_recurrent_dropout(0.2, true) // Variational recurrent dropout + .with_output_dropout(0.1); // Standard output dropout + + // Create warmup scheduler with cyclical base scheduler + let base_scheduler = 
CyclicalLR::new(0.001, 0.01, 10) + .with_mode(CyclicalMode::Triangular2); + let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001); + + let loss_function = MSELoss; + let scheduled_optimizer = ScheduledOptimizer::new( + Adam::new(0.01), + warmup_scheduler, + 0.01 + ); + + let config = TrainingConfig { + epochs: 40, + print_every: 5, + clip_gradient: Some(1.0), // Gradient clipping + log_lr_changes: false, // Too frequent for cyclical + }; + + let mut trainer = ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer) + .with_config(config); + + trainer.train(train_data, Some(val_data)); + + println!("Final LR: {:.2e}", trainer.get_current_lr()); + println!("Final Training Loss: {:.6}", trainer.get_latest_metrics().unwrap().train_loss); + println!("Final Validation Loss: {:.6}", trainer.get_latest_metrics().unwrap().validation_loss.unwrap()); + + println!("\nāœ… Advanced training complete!"); +} + +fn generate_sine_wave_data(num_sequences: usize, offset: f64) -> Vec<(Vec>, Vec>)> { + let mut data = Vec::new(); + + for i in 0..num_sequences { + let sequence_length = 8; + let mut inputs = Vec::new(); + let mut targets = Vec::new(); + + for t in 0..sequence_length { + let x = (offset + i as f64 * 0.1 + t as f64 * 0.2).sin(); + let y = (offset + i as f64 * 0.1 + (t + 1) as f64 * 0.2).sin(); + + inputs.push(arr2(&[[x]])); + targets.push(arr2(&[[y]])); + } + + data.push((inputs, targets)); + } + + data +} + +#[cfg(test)] +mod tests { + use super::*; + use rust_lstm::SGD; + + #[test] + fn test_advanced_schedulers() { + // Test polynomial scheduler + let poly_scheduler = PolynomialLR::new(100, 2.0, 0.01); + let schedule = LRScheduleVisualizer::generate_schedule(poly_scheduler, 0.1, 100); + assert_eq!(schedule.len(), 100); + assert_eq!(schedule[0].1, 0.1); + assert!((schedule[99].1 - 0.01).abs() < 1e-10); + + // Test cyclical scheduler + let cyclical_scheduler = CyclicalLR::new(0.01, 0.1, 10); + let schedule = 
LRScheduleVisualizer::generate_schedule(cyclical_scheduler, 0.01, 50); + assert_eq!(schedule.len(), 50); + assert_eq!(schedule[0].1, 0.01); + + // Test warmup scheduler + let base_scheduler = rust_lstm::ConstantLR; + let warmup_scheduler = WarmupScheduler::new(10, base_scheduler, 0.001); + let schedule = LRScheduleVisualizer::generate_schedule(warmup_scheduler, 0.01, 20); + assert_eq!(schedule.len(), 20); + assert_eq!(schedule[0].1, 0.001); + assert_eq!(schedule[10].1, 0.01); + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 619a8a5..96b13d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,7 +58,9 @@ pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer}; pub use schedulers::{ LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR, - ReduceLROnPlateau, LinearLR, AnnealStrategy + ReduceLROnPlateau, LinearLR, AnnealStrategy, + PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler, + LRScheduleVisualizer }; pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss}; pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError}; diff --git a/src/optimizers.rs b/src/optimizers.rs index fc2c69d..c548b0a 100644 --- a/src/optimizers.rs +++ b/src/optimizers.rs @@ -273,6 +273,31 @@ impl ScheduledOptimizer { } } +impl ScheduledOptimizer { + pub fn polynomial(optimizer: O, lr: f64, total_iters: usize, power: f64, end_lr: f64) -> Self { + Self::new(optimizer, crate::schedulers::PolynomialLR::new(total_iters, power, end_lr), lr) + } +} + +impl ScheduledOptimizer { + pub fn cyclical(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize) -> Self { + Self::new(optimizer, crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size), base_lr) + } + + pub fn cyclical_triangular2(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize) -> Self { + let scheduler = crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size) 
+ .with_mode(crate::schedulers::CyclicalMode::Triangular2); + Self::new(optimizer, scheduler, base_lr) + } + + pub fn cyclical_exp_range(optimizer: O, base_lr: f64, max_lr: f64, step_size: usize, gamma: f64) -> Self { + let scheduler = crate::schedulers::CyclicalLR::new(base_lr, max_lr, step_size) + .with_mode(crate::schedulers::CyclicalMode::ExpRange) + .with_gamma(gamma); + Self::new(optimizer, scheduler, base_lr) + } +} + impl ScheduledOptimizer { pub fn one_cycle(optimizer: O, max_lr: f64, total_steps: usize) -> Self { Self::new(optimizer, crate::schedulers::OneCycleLR::new(max_lr, total_steps), max_lr) diff --git a/src/schedulers.rs b/src/schedulers.rs index 2f48aa2..9f7e1d5 100644 --- a/src/schedulers.rs +++ b/src/schedulers.rs @@ -420,6 +420,237 @@ impl LearningRateScheduler for LinearLR { } } +/// Polynomial learning rate decay +#[derive(Clone, Debug)] +pub struct PolynomialLR { + total_iters: usize, + power: f64, + end_lr: f64, +} + +impl PolynomialLR { + pub fn new(total_iters: usize, power: f64, end_lr: f64) -> Self { + PolynomialLR { + total_iters, + power, + end_lr, + } + } +} + +impl LearningRateScheduler for PolynomialLR { + fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 { + if epoch >= self.total_iters { + return self.end_lr; + } + + let factor = (1.0 - epoch as f64 / self.total_iters as f64).powf(self.power); + self.end_lr + (base_lr - self.end_lr) * factor + } + + fn reset(&mut self) {} + + fn name(&self) -> &'static str { + "PolynomialLR" + } +} + +/// Cyclical learning rate policy with different modes +#[derive(Clone, Debug)] +pub struct CyclicalLR { + base_lr: f64, + max_lr: f64, + step_size: usize, + mode: CyclicalMode, + gamma: f64, + scale_mode: ScaleMode, + last_step: usize, +} + +#[derive(Clone, Debug)] +pub enum CyclicalMode { + Triangular, + Triangular2, + ExpRange, +} + +#[derive(Clone, Debug)] +pub enum ScaleMode { + Cycle, + Iterations, +} + +impl CyclicalLR { + pub fn new(base_lr: f64, max_lr: f64, step_size: usize) -> Self 
{ + CyclicalLR { + base_lr, + max_lr, + step_size, + mode: CyclicalMode::Triangular, + gamma: 1.0, + scale_mode: ScaleMode::Cycle, + last_step: 0, + } + } + + pub fn with_mode(mut self, mode: CyclicalMode) -> Self { + self.mode = mode; + self + } + + pub fn with_gamma(mut self, gamma: f64) -> Self { + self.gamma = gamma; + self + } + + pub fn with_scale_mode(mut self, scale_mode: ScaleMode) -> Self { + self.scale_mode = scale_mode; + self + } +} + +impl LearningRateScheduler for CyclicalLR { + fn get_lr(&mut self, epoch: usize, _base_lr: f64) -> f64 { + self.last_step = epoch; + + let cycle = (epoch as f64 / (2.0 * self.step_size as f64)).floor() as usize; + let x = (epoch as f64 / self.step_size as f64 - 2.0 * cycle as f64 - 1.0).abs(); + + let scale_factor = match self.mode { + CyclicalMode::Triangular => 1.0, + CyclicalMode::Triangular2 => 1.0 / (2.0_f64.powi(cycle as i32)), // cycle is 0-based: full amplitude on cycle 0, halved each subsequent cycle + CyclicalMode::ExpRange => self.gamma.powi(epoch as i32), + }; + + let scale_factor = match self.scale_mode { + ScaleMode::Cycle => scale_factor, + ScaleMode::Iterations => self.gamma.powi(epoch as i32), + }; + + self.base_lr + (self.max_lr - self.base_lr) * (1.0 - x).max(0.0) * scale_factor + } + + fn reset(&mut self) { + self.last_step = 0; + } + + fn name(&self) -> &'static str { + "CyclicalLR" + } +} + +/// Warmup scheduler that gradually increases learning rate +#[derive(Clone, Debug)] +pub struct WarmupScheduler { + warmup_epochs: usize, + base_scheduler: S, + warmup_start_lr: f64, +} + +impl WarmupScheduler { + pub fn new(warmup_epochs: usize, base_scheduler: S, warmup_start_lr: f64) -> Self { + WarmupScheduler { + warmup_epochs, + base_scheduler, + warmup_start_lr, + } + } +} + +impl LearningRateScheduler for WarmupScheduler { + fn get_lr(&mut self, epoch: usize, base_lr: f64) -> f64 { + if epoch < self.warmup_epochs { + // Linear warmup + let warmup_factor = epoch as f64 / self.warmup_epochs as f64; + self.warmup_start_lr + (base_lr - self.warmup_start_lr) * warmup_factor + }
else { + // Use base scheduler after warmup + self.base_scheduler.get_lr(epoch - self.warmup_epochs, base_lr) + } + } + + fn reset(&mut self) { + self.base_scheduler.reset(); + } + + fn name(&self) -> &'static str { + "WarmupScheduler" + } +} + +/// Learning rate schedule visualization helper +pub struct LRScheduleVisualizer; + +impl LRScheduleVisualizer { + /// Generate learning rate values for visualization + pub fn generate_schedule( + mut scheduler: S, + base_lr: f64, + epochs: usize, + ) -> Vec<(usize, f64)> { + let mut schedule = Vec::new(); + + for epoch in 0..epochs { + let lr = scheduler.get_lr(epoch, base_lr); + schedule.push((epoch, lr)); + } + + schedule + } + + /// Print ASCII visualization of learning rate schedule + pub fn print_schedule( + scheduler: S, + base_lr: f64, + epochs: usize, + width: usize, + height: usize, + ) { + let schedule = Self::generate_schedule(scheduler, base_lr, epochs); + + if schedule.is_empty() { + return; + } + + let min_lr = schedule.iter().map(|(_, lr)| *lr).fold(f64::INFINITY, f64::min); + let max_lr = schedule.iter().map(|(_, lr)| *lr).fold(0.0, f64::max); + + println!("Learning Rate Schedule Visualization ({}x{})", width, height); + println!("Min LR: {:.2e}, Max LR: {:.2e}", min_lr, max_lr); + println!("ā”Œ{}┐", "─".repeat(width)); + + for row in 0..height { + let y_value = max_lr - (max_lr - min_lr) * row as f64 / (height - 1) as f64; + print!("│"); + + for col in 0..width { + let epoch_idx = col * epochs / width; + let lr = if epoch_idx < schedule.len() { + schedule[epoch_idx].1 + } else { + min_lr + }; + + if (lr - y_value).abs() < (max_lr - min_lr) / height as f64 { + print!("ā–ˆ"); + } else { + print!(" "); + } + } + + println!("│ {:.2e}", y_value); + } + + println!("ā””{}ā”˜", "─".repeat(width)); + print!(" "); + for i in 0..=4 { + let epoch = i * epochs / 4; + print!("{:>width$}", epoch, width = width / 5); + } + println!(); + } +} + #[cfg(test)] mod tests { use super::*; @@ -494,8 +725,8 @@ mod tests { 
assert_eq!(lr2, base_lr); // Should reduce after patience epochs without improvement - let lr3 = scheduler.step(0.9, base_lr); - let lr4 = scheduler.step(0.9, base_lr); + let _lr3 = scheduler.step(0.9, base_lr); + let _lr4 = scheduler.step(0.9, base_lr); let lr5 = scheduler.step(0.9, base_lr); assert!(lr5 < base_lr); @@ -511,4 +742,44 @@ mod tests { assert!((scheduler.get_lr(5, base_lr) - base_lr * 0.55).abs() < 1e-10); assert!((scheduler.get_lr(10, base_lr) - base_lr * 0.1).abs() < 1e-10); } + + #[test] + fn test_polynomial_lr() { + let mut scheduler = PolynomialLR::new(100, 2.0, 0.01); + let base_lr = 0.1; + + assert_eq!(scheduler.get_lr(0, base_lr), 0.1); + // At epoch 50: factor = (1 - 50/100)^2 = 0.25 + // lr = 0.01 + (0.1 - 0.01) * 0.25 = 0.01 + 0.0225 = 0.0325 + assert!((scheduler.get_lr(50, base_lr) - 0.0325).abs() < 1e-10); + assert!((scheduler.get_lr(100, base_lr) - 0.01).abs() < 1e-10); + } + + #[test] + fn test_cyclical_lr() { + let mut scheduler = CyclicalLR::new(0.1, 1.0, 10); + let base_lr = 0.1; + + assert_eq!(scheduler.get_lr(0, base_lr), 0.1); + // At epoch 5: cycle=0, x=0.5, lr is halfway between base_lr and max_lr + // lr = 0.1 + (1.0 - 0.1) * (1 - 0.5) = 0.1 + 0.9 * 0.5 = 0.55 + assert!((scheduler.get_lr(5, base_lr) - 0.55).abs() < 1e-10); + // At epoch 10: cycle=0, x = |10/10 - 0 - 1| = 0, so lr is at the cycle peak + // lr = 0.1 + (1.0 - 0.1) * (1 - 0.0) = 0.1 + 0.9 = 1.0 + // (the triangular wave rises for step_size epochs, peaking here) + assert_eq!(scheduler.get_lr(10, base_lr), 1.0); + } + + #[test] + fn test_warmup_scheduler() { + let base_scheduler = ConstantLR; + let mut scheduler = WarmupScheduler::new(10, base_scheduler, 0.01); + let base_lr = 0.1; + + assert_eq!(scheduler.get_lr(0, base_lr), 0.01); + // At epoch 5: warmup_factor = 5/10 = 0.5 + // lr = 0.01 + (0.1 - 0.01) * 0.5 = 0.01 + 0.045 = 0.055 + assert!((scheduler.get_lr(5, base_lr) - 0.055).abs() < 1e-10); + assert_eq!(scheduler.get_lr(10, base_lr), 0.1); + } } \ No newline at end of file