4 changes: 4 additions & 0 deletions Cargo.toml
@@ -92,3 +92,7 @@ path = "examples/batch_processing_example.rs"
[[example]]
name = "early_stopping_example"
path = "examples/early_stopping_example.rs"

[[example]]
name = "linear_layer_example"
path = "examples/linear_layer_example.rs"
75 changes: 35 additions & 40 deletions README.md
@@ -34,14 +34,14 @@ graph TD
## Features

- **LSTM, BiLSTM & GRU Networks** with multi-layer support
- **Linear (Dense) Layer** for classification and output projection
- **Complete Training System** with backpropagation through time (BPTT)
- **Multiple Optimizers**: SGD, Adam, RMSprop with comprehensive learning rate scheduling
- **Advanced Learning Rate Scheduling**: 12 different schedulers including OneCycle, Warmup, Cyclical, and Polynomial
- **Early Stopping**: Prevent overfitting with configurable patience and metric monitoring
- **Multiple Optimizers**: SGD, Adam, RMSprop with learning rate scheduling
- **Learning Rate Scheduling**: 12 schedulers including OneCycle, Warmup, Cyclical, Polynomial
- **Early Stopping**: Configurable patience and metric monitoring
- **Loss Functions**: MSE, MAE, Cross-entropy with softmax
- **Advanced Dropout**: Input, recurrent, output dropout, variational dropout, and zoneout
- **Batch Processing**: 4-5x training speedup with efficient batch operations
- **Schedule Visualization**: ASCII visualization of learning rate schedules
- **Advanced Dropout**: Input, recurrent, output, variational dropout, and zoneout
- **Batch Processing**: Efficient batch operations
- **Model Persistence**: Save/load models in JSON or binary format
- **Peephole LSTM variant** for enhanced performance
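The early-stopping feature above monitors a metric with configurable patience; the core logic can be sketched in a few lines of plain Rust (an illustrative sketch only — the crate's actual `EarlyStopping` type and method names may differ):

```rust
/// Minimal early-stopping tracker: stop when the monitored metric
/// has not improved for `patience` consecutive checks.
/// (Illustrative sketch; not the crate's real API.)
struct EarlyStopping {
    patience: usize,
    best: f64,
    bad_epochs: usize,
}

impl EarlyStopping {
    fn new(patience: usize) -> Self {
        Self { patience, best: f64::INFINITY, bad_epochs: 0 }
    }

    /// Returns true if training should stop after seeing this validation loss.
    fn should_stop(&mut self, val_loss: f64) -> bool {
        if val_loss < self.best {
            self.best = val_loss;
            self.bad_epochs = 0;
        } else {
            self.bad_epochs += 1;
        }
        self.bad_epochs >= self.patience
    }
}

fn main() {
    let mut stopper = EarlyStopping::new(2);
    let losses = [1.0, 0.8, 0.9, 0.85, 0.95]; // improves, then plateaus
    for (epoch, &loss) in losses.iter().enumerate() {
        if stopper.should_stop(loss) {
            println!("stopping at epoch {}", epoch);
            break;
        }
    }
}
```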

@@ -51,7 +51,7 @@ Add to your `Cargo.toml`:

```toml
[dependencies]
rust-lstm = "0.5.0"
rust-lstm = "0.6"
```

### Basic Usage
@@ -181,6 +181,24 @@ let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers)
let (output, _) = gru.forward(&input, &hidden_state);
```

### Linear Layer

```rust
use rust_lstm::layers::linear::LinearLayer;
use rust_lstm::optimizers::Adam;

// Create linear layer for classification: hidden_size -> num_classes
let mut classifier = LinearLayer::new(hidden_size, num_classes);
let mut optimizer = Adam::new(0.001);

// Forward pass
let logits = classifier.forward(&lstm_output);

// Backward pass
let (gradients, input_grad) = classifier.backward(&grad_output);
classifier.update_parameters(&gradients, &mut optimizer, "classifier");
```
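The `logits` produced above are unnormalized scores; for classification they are typically normalized with a softmax. A self-contained sketch of that step, using plain `Vec<f64>` rather than the crate's array types:

```rust
/// Numerically stable softmax: subtract the max logit before
/// exponentiating so large values do not overflow.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let probs = softmax(&[2.0, 1.0, 0.1]);
    // Probabilities sum to 1; the largest logit gets the largest probability.
    println!("{:?}", probs);
    assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-12);
}
```

Cross-entropy with softmax is listed among the crate's loss functions; the sketch above only illustrates the normalization step itself.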

#### LSTM vs GRU Cell Comparison

```mermaid
@@ -261,13 +279,13 @@ LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 100, 60, 10);
- **OneCycleLR**: One cycle policy for super-convergence
- **ReduceLROnPlateau**: Adaptive reduction on validation plateaus
- **LinearLR**: Linear interpolation between rates
- **PolynomialLR**: Polynomial decay with configurable power
- **CyclicalLR**: Triangular, triangular2, and exponential range modes
- **WarmupScheduler**: Gradual warmup wrapper for any base scheduler
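To illustrate how a warmup wrapper composes with a base schedule, here is a standalone sketch of linear warmup into cosine annealing — the formula such schedulers commonly implement, not the crate's exact `WarmupScheduler` code:

```rust
use std::f64::consts::PI;

/// Learning rate at `step`: linear warmup for the first `warmup` steps,
/// then cosine annealing from `base_lr` down to `min_lr` over `total` steps.
fn warmup_cosine_lr(step: usize, warmup: usize, total: usize,
                    base_lr: f64, min_lr: f64) -> f64 {
    if step < warmup {
        // Ramp from base_lr/warmup up to base_lr.
        base_lr * (step as f64 + 1.0) / warmup as f64
    } else {
        // Cosine decay over the remaining steps.
        let progress = (step - warmup) as f64 / (total - warmup) as f64;
        min_lr + 0.5 * (base_lr - min_lr) * (1.0 + (PI * progress).cos())
    }
}

fn main() {
    for step in [0, 4, 5, 50, 99] {
        println!("step {:3}: lr = {:.5}",
            step, warmup_cosine_lr(step, 5, 100, 0.01, 0.0001));
    }
}
```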

## Architecture

- **`layers`**: LSTM and GRU cells (standard, peephole, bidirectional) with dropout
- **`layers`**: LSTM cells, GRU cells, Linear (dense) layer, dropout, peephole LSTM, bidirectional LSTM
- **`models`**: High-level network architectures (LSTM, BiLSTM, GRU)
- **`training`**: Training utilities with automatic train/eval mode switching
- **`optimizers`**: SGD, Adam, RMSprop with scheduling
@@ -288,7 +306,8 @@ cargo run --example time_series_prediction
# Advanced architectures
cargo run --example gru_example # GRU vs LSTM comparison
cargo run --example bilstm_example # Bidirectional LSTM
cargo run --example dropout_example # Dropout demo
cargo run --example dropout_example # Dropout regularization
cargo run --example linear_layer_example # Linear layer for classification

# Learning and scheduling
cargo run --example learning_rate_scheduling # Basic schedulers
@@ -343,36 +362,12 @@ cargo run --example model_inspection
cargo test
```

## Performance Examples

The library includes comprehensive examples that demonstrate its capabilities:

### Training with Different Schedulers
Run the learning rate scheduling examples to see different scheduler behaviors:
```bash
cargo run --example learning_rate_scheduling # Compare basic schedulers
cargo run --example advanced_lr_scheduling # Advanced schedulers with ASCII visualization
```

### Architecture Comparison
Compare LSTM vs GRU performance:
```bash
cargo run --example gru_example
```

### Real-world Applications
Test the library with practical examples:
```bash
cargo run --example stock_prediction # Stock price predictions
cargo run --example weather_prediction # Weather forecasting
cargo run --example text_classification_bilstm # Classification accuracy
```

The examples output training metrics, loss values, and predictions that you can analyze or plot with external tools.

## Version History

- **v0.4.0**: Advanced learning rate scheduling with 12 different schedulers, warmup support, cyclical learning rates, polynomial decay, and ASCII visualization
- **v0.6.1**: Fixed text generation in advanced example
- **v0.6.0**: Early stopping support with configurable patience and metric monitoring
- **v0.5.0**: Model persistence (JSON/binary), batch processing
- **v0.4.0**: Advanced learning rate scheduling (12 schedulers), warmup, cyclical LR, visualization
- **v0.3.0**: Bidirectional LSTM networks with flexible combine modes
- **v0.2.0**: Complete training system with BPTT and comprehensive dropout
- **v0.1.0**: Initial LSTM implementation with forward pass
252 changes: 252 additions & 0 deletions examples/linear_layer_example.rs
@@ -0,0 +1,252 @@
use ndarray::arr2;
use rust_lstm::layers::linear::LinearLayer;
use rust_lstm::optimizers::{SGD, Adam};
use rust_lstm::models::lstm_network::LSTMNetwork;

/// Example 1: Basic LinearLayer usage for classification
fn basic_classification_example() {
    println!("=== Basic Classification Example ===");

    // Create a linear layer: 4 input features -> 3 classes
    let mut linear = LinearLayer::new(4, 3);
    let mut optimizer = SGD::new(0.1);

    // Sample input: batch of 2 samples, each with 4 features (one column per sample)
    let input = arr2(&[
        [1.0, 0.5],  // feature 1
        [0.8, -0.2], // feature 2
        [1.2, 0.9],  // feature 3
        [-0.1, 0.3], // feature 4
    ]); // Shape: (4, 2)

    // Target classes (one-hot encoded, one column per sample)
    let targets = arr2(&[
        [1.0, 0.0], // sample 1 is class 1
        [0.0, 1.0], // sample 2 is class 2
        [0.0, 0.0],
    ]); // Shape: (3, 2)

    println!("Input shape: {:?}", input.shape());
    println!("Target shape: {:?}", targets.shape());

    // Training loop
    for epoch in 0..10 {
        // Forward pass
        let output = linear.forward(&input);

        // Simple loss: mean squared error
        let loss = (&output - &targets).map(|x| x * x).sum() / (output.len() as f64);

        // Backward pass
        let grad_output = 2.0 * (&output - &targets) / (output.len() as f64);
        let (gradients, _input_grad) = linear.backward(&grad_output);

        // Update parameters
        linear.update_parameters(&gradients, &mut optimizer, "classifier");

        if epoch % 2 == 0 {
            println!("Epoch {}: Loss = {:.4}", epoch, loss);
        }
    }

    // Final prediction
    let final_output = linear.forward(&input);
    println!("Final output:\n{:.3}", final_output);
    println!("Target:\n{:.3}", targets);
    println!();
}

/// Example 2: LSTM + LinearLayer for sequence classification
fn lstm_with_linear_example() {
    println!("=== LSTM + LinearLayer Example ===");

    // Create LSTM network: 5 input features -> 8 hidden units -> 3 classes
    let mut lstm = LSTMNetwork::new(5, 8, 1);
    let mut classifier = LinearLayer::new(8, 3);
    let mut optimizer = Adam::new(0.001);

    // Sample sequence data: 4 time steps, 5 features, batch size 1
    let sequence = vec![
        arr2(&[[1.0], [0.5], [0.2], [0.8], [0.1]]), // t=0
        arr2(&[[0.9], [0.6], [0.3], [0.7], [0.2]]), // t=1
        arr2(&[[0.8], [0.7], [0.4], [0.6], [0.3]]), // t=2
        arr2(&[[0.7], [0.8], [0.5], [0.5], [0.4]]), // t=3
    ];

    // Target: classify the entire sequence (shape: 3 classes, 1 sample)
    let target = arr2(&[[0.0], [1.0], [0.0]]); // Class 2

    println!("Sequence length: {}", sequence.len());
    println!("Input features: {}", sequence[0].nrows());
    println!("LSTM hidden size: {}", 8);
    println!("Output classes: {}", target.nrows());

    // Training loop
    for epoch in 0..20 {
        // LSTM forward pass
        let (lstm_outputs, _) = lstm.forward_sequence_with_cache(&sequence);

        // Use the last LSTM output for classification
        let last_hidden = &lstm_outputs.last().unwrap().0;

        // Linear layer forward pass
        let class_logits = classifier.forward(last_hidden);

        // Loss calculation
        let loss = (&class_logits - &target).map(|x| x * x).sum() / (class_logits.len() as f64);

        // Backward pass through linear layer
        let grad_output = 2.0 * (&class_logits - &target) / (class_logits.len() as f64);
        let (linear_grads, _lstm_grad) = classifier.backward(&grad_output);

        // Update linear layer
        classifier.update_parameters(&linear_grads, &mut optimizer, "classifier");

        // Note: a complete implementation would also backpropagate through the LSTM;
        // this example focuses on demonstrating LinearLayer usage.

        if epoch % 5 == 0 {
            println!("Epoch {}: Loss = {:.4}", epoch, loss);
        }
    }

    // Final prediction
    let (final_lstm_outputs, _) = lstm.forward_sequence_with_cache(&sequence);
    let final_hidden = &final_lstm_outputs.last().unwrap().0;
    let final_prediction = classifier.forward(final_hidden);

    println!("Final prediction: [{:.3}, {:.3}, {:.3}]",
        final_prediction[[0, 0]], final_prediction[[1, 0]], final_prediction[[2, 0]]);
    println!("Target: [{:.3}, {:.3}, {:.3}]",
        target[[0, 0]], target[[1, 0]], target[[2, 0]]);
    println!();
}

/// Example 3: Multi-layer perceptron using multiple LinearLayers
fn multilayer_perceptron_example() {
    println!("=== Multi-Layer Perceptron Example ===");

    // Create a 3-layer MLP: 2 -> 4 -> 4 -> 1
    let mut layer1 = LinearLayer::new(2, 4);
    let mut layer2 = LinearLayer::new(4, 4);
    let mut layer3 = LinearLayer::new(4, 1);
    let mut optimizer = Adam::new(0.01);

    // XOR problem dataset
    let inputs = arr2(&[
        [0.0, 1.0, 0.0, 1.0], // input 1
        [0.0, 0.0, 1.0, 1.0], // input 2
    ]); // Shape: (2, 4)

    let targets = arr2(&[[0.0, 1.0, 1.0, 0.0]]); // XOR outputs

    println!("Training MLP on XOR problem...");
    println!("Input shape: {:?}", inputs.shape());
    println!("Target shape: {:?}", targets.shape());

    // Training loop
    for epoch in 0..100 {
        // Forward pass
        let h1 = layer1.forward(&inputs);
        let h1_relu = h1.map(|&x| if x > 0.0 { x } else { 0.0 }); // ReLU activation

        let h2 = layer2.forward(&h1_relu);
        let h2_relu = h2.map(|&x| if x > 0.0 { x } else { 0.0 }); // ReLU activation

        let output = layer3.forward(&h2_relu);

        // Loss calculation
        let loss = (&output - &targets).map(|x| x * x).sum() / (output.len() as f64);

        // Backward pass
        let grad_output = 2.0 * (&output - &targets) / (output.len() as f64);

        // Layer 3 backward
        let (grad3, grad_h2) = layer3.backward(&grad_output);

        // ReLU backward for h2
        let grad_h2_relu = &grad_h2 * &h2.map(|&x| if x > 0.0 { 1.0 } else { 0.0 });

        // Layer 2 backward
        let (grad2, grad_h1) = layer2.backward(&grad_h2_relu);

        // ReLU backward for h1
        let grad_h1_relu = &grad_h1 * &h1.map(|&x| if x > 0.0 { 1.0 } else { 0.0 });

        // Layer 1 backward
        let (grad1, _) = layer1.backward(&grad_h1_relu);

        // Update all layers
        layer1.update_parameters(&grad1, &mut optimizer, "layer1");
        layer2.update_parameters(&grad2, &mut optimizer, "layer2");
        layer3.update_parameters(&grad3, &mut optimizer, "layer3");

        if epoch % 20 == 0 {
            println!("Epoch {}: Loss = {:.4}", epoch, loss);
        }
    }

    // Final predictions
    let h1 = layer1.forward(&inputs);
    let h1_relu = h1.map(|&x| if x > 0.0 { x } else { 0.0 });
    let h2 = layer2.forward(&h1_relu);
    let h2_relu = h2.map(|&x| if x > 0.0 { x } else { 0.0 });
    let final_output = layer3.forward(&h2_relu);

    println!("Final predictions:");
    for i in 0..4 {
        let input_vals = (inputs[[0, i]], inputs[[1, i]]);
        let prediction = final_output[[0, i]];
        let target_val = targets[[0, i]];
        println!("  {:?} -> {:.3} (target: {:.1})", input_vals, prediction, target_val);
    }
    println!();
}

/// Example 4: Demonstrating different initialization methods
fn initialization_example() {
    println!("=== Initialization Methods Example ===");

    // Method 1: Default random initialization (Xavier/Glorot)
    let layer_random = LinearLayer::new(3, 2);
    println!("Random initialization:");
    println!("  Weight range: [{:.3}, {:.3}]",
        layer_random.weight.iter().min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(),
        layer_random.weight.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap());

    // Method 2: Zero initialization
    let layer_zeros = LinearLayer::new_zeros(3, 2);
    println!("Zero initialization:");
    println!("  All weights: {}", layer_zeros.weight.iter().all(|&x| x == 0.0));

    // Method 3: Custom initialization
    let custom_weights = arr2(&[[1.0, 0.5, -0.2], [0.8, -0.1, 0.3]]);
    let custom_bias = arr2(&[[0.1], [-0.05]]);
    let layer_custom = LinearLayer::from_weights(custom_weights.clone(), custom_bias.clone());
    println!("Custom initialization:");
    println!("  Custom weights shape: {:?}", layer_custom.weight.shape());
    println!("  Custom bias shape: {:?}", layer_custom.bias.shape());

    // Show layer information
    println!("Layer dimensions: {:?}", layer_custom.dimensions());
    println!("Number of parameters: {}", layer_custom.num_parameters());
    println!();
}

fn main() {
    println!("LinearLayer Examples");
    println!("===================\n");

    basic_classification_example();
    lstm_with_linear_example();
    multilayer_perceptron_example();
    initialization_example();

    println!("All examples completed successfully! 🎉");
    println!("\nKey takeaways:");
    println!("- LinearLayer enables standard neural network architectures");
    println!("- Works seamlessly with LSTM networks for classification");
    println!("- Supports multiple initialization methods");
    println!("- Integrates with all existing optimizers");
    println!("- Essential for text generation and classification tasks");
}