This repository contains a PyTorch implementation of a Vision Transformer (ViT) trained on the MNIST dataset. The project demonstrates a modular implementation of the Transformer architecture from scratch, including custom components for multi-head self-attention, patch encoding, and transformer blocks.
The goal of this project is to implement the Vision Transformer architecture for image classification without relying on PyTorch's pre-built Transformer layers (such as nn.TransformerEncoderLayer). The model treats images as sequences of patches, processing them through self-attention mechanisms to capture global dependencies.
Key implementation details include (a code sketch of these components follows the list):
- Patch Embedding: Images are divided into fixed-size patches, each linearly projected to an embedding vector by a convolutional layer.
- Positional Embeddings: Learnable parameters are added to patch embeddings to retain positional information.
- Multi-Head Self-Attention (MHSA): A custom implementation of the attention mechanism, calculating Query, Key, and Value projections and attention scores manually.
- Transformer Blocks: A pre-norm architecture combining Layer Normalization, MHSA, and Multi-Layer Perceptrons (MLP) with residual connections.
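The following is a minimal, self-contained sketch of these components in PyTorch: patch embedding via a strided convolution, manual multi-head self-attention, and a pre-norm transformer block. Class names, defaults, and tensor shapes here are illustrative and may not match the actual modules/ package.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each to emb_dim."""
    def __init__(self, in_channels=1, patch_size=7, emb_dim=64):
        super().__init__()
        # A conv with kernel_size == stride == patch_size extracts one embedding per patch.
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                        # x: (B, C, H, W)
        x = self.proj(x)                         # (B, emb_dim, H/p, W/p)
        return x.flatten(2).transpose(1, 2)      # (B, num_patches, emb_dim)

class MultiHeadSelfAttention(nn.Module):
    """Manual Q/K/V projections and scaled dot-product attention."""
    def __init__(self, emb_dim=64, num_heads=2, attn_drop=0.1, out_drop=0.1):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim)
        self.proj = nn.Linear(emb_dim, emb_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_drop = nn.Dropout(out_drop)

    def forward(self, x):                        # x: (B, N, emb_dim)
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)     # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        attn = self.attn_drop(attn.softmax(dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return self.out_drop(self.proj(out))

class TransformerBlock(nn.Module):
    """Pre-norm block: LayerNorm -> MHSA -> residual, then LayerNorm -> MLP -> residual."""
    def __init__(self, emb_dim=64, num_heads=2, mlp_ratio=4, mlp_drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = MultiHeadSelfAttention(emb_dim, num_heads)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, mlp_ratio * emb_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(mlp_ratio * emb_dim, emb_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
```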
The codebase is organized into a training script and a modular package for the model architecture:
- train_mnist.py: The main entry point. Handles argument parsing, data loading (MNIST), the training loop, validation metrics, and checkpoint saving.
- modules/: The model package.
  - vision_transformer.py: Assembles the full model, including the class token (CLS) and classification head (see the assembly sketch after this list).
  - transformer_block.py: Defines a single transformer layer with skip connections.
  - attention_module.py: Implements the Multi-Head Self-Attention logic.
  - image_encoder.py: Handles the conversion of input images into flattened patch embeddings.
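Building on the component sketch above, vision_transformer.py could assemble the full model roughly as follows. The constructor arguments and the reuse of the PatchEmbedding and TransformerBlock classes from the previous sketch are assumptions for illustration, not the repository's exact API.

```python
import torch
import torch.nn as nn

# Assumes PatchEmbedding and TransformerBlock from the sketch above are in scope.

class VisionTransformer(nn.Module):
    """Patch embedding + CLS token + positional embeddings + blocks + classification head."""
    def __init__(self, image_size=28, patch_size=7, in_channels=1, num_classes=10,
                 emb_dim=64, num_layers=2, num_heads=2, mlp_ratio=4):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_dim)
        # Learnable CLS token and positional embeddings (one position per patch, plus the CLS slot).
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim) * 0.02)
        self.blocks = nn.Sequential(*[
            TransformerBlock(emb_dim, num_heads, mlp_ratio) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):                               # x: (B, 1, 28, 28) for MNIST
        tokens = self.patch_embed(x)                    # (B, num_patches, emb_dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)  # (B, 1, emb_dim)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.norm(self.blocks(tokens))
        return self.head(tokens[:, 0])                  # classify from the CLS token
```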
The project requires Python and the libraries listed in requirements.txt, including PyTorch, TorchVision, TorchMetrics, and TQDM.
To install the dependencies:
pip install -r requirements.txt

The training script train_mnist.py accepts various command-line arguments to configure hyperparameters and system settings. The only required argument is --out, which specifies the directory for saving the model checkpoint.
To train the model with default hyperparameters and save the best checkpoint to a local folder:
python train_mnist.py --out ./checkpoints

To utilize GPU acceleration:

python train_mnist.py --out ./checkpoints --cuda

To print detailed validation metrics (Accuracy, Precision, Recall, F1) at the end of every epoch:

python train_mnist.py --out ./checkpoints --verbose

The script supports the following arguments (an argparse sketch follows the lists below):
General Configuration
- --out: (Required) Path where the model will be saved.
- --seed: Random seed for reproducibility (default: 42).
- --cuda: Boolean flag to enable CUDA.
- --verbose: Boolean flag to print validation results per epoch.
Training Hyperparameters
- --epochs: Number of training epochs (default: 20).
- --bsize: Batch size (default: 128).
- --lr: Learning rate (default: 0.001).
Model Architecture
- --patch_size: Size of the image patches (default: 7).
- --emb_dim: Internal embedding dimension (default: 64).
- --num_heads: Number of attention heads (default: 2).
- --num_layers: Number of transformer layers (default: 2).
- --mlp_ratio: Expansion ratio for the internal MLP (default: 4).
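For reference, with the default patch size of 7 and MNIST's 28x28 inputs, each image is split into (28/7)^2 = 16 non-overlapping patches; prepending the CLS token gives a sequence of 17 tokens per image.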
Regularization (Dropout)
- --attn_drop: Dropout rate for attention weights (default: 0.1).
- --mlp_drop: Dropout rate for the MLP layers (default: 0.1).
- --out_drop: Dropout rate for the projection layers (default: 0.1).
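A minimal sketch of how these flags might be declared with argparse is shown below. The flag names and defaults mirror the lists above, but the actual parser in train_mnist.py may differ in help text and structure.

```python
import argparse

def parse_args():
    # Hypothetical CLI sketch; flag names and defaults follow the argument lists above.
    parser = argparse.ArgumentParser(description="Train a Vision Transformer on MNIST")
    # General configuration
    parser.add_argument("--out", type=str, required=True, help="Directory for the saved checkpoint")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--cuda", action="store_true", help="Enable CUDA")
    parser.add_argument("--verbose", action="store_true", help="Print validation metrics per epoch")
    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--bsize", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    # Model architecture
    parser.add_argument("--patch_size", type=int, default=7)
    parser.add_argument("--emb_dim", type=int, default=64)
    parser.add_argument("--num_heads", type=int, default=2)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--mlp_ratio", type=int, default=4)
    # Regularization (dropout)
    parser.add_argument("--attn_drop", type=float, default=0.1)
    parser.add_argument("--mlp_drop", type=float, default=0.1)
    parser.add_argument("--out_drop", type=float, default=0.1)
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    print(args)
```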
The training loop tracks the CrossEntropyLoss. The validation loop utilizes torchmetrics to calculate and report the following metrics (a usage sketch follows the list):
- Average Loss
- Accuracy
- Precision (Macro average)
- Recall (Macro average)
- F1 Score
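A minimal sketch of a validation pass computing these metrics with torchmetrics is shown below. The class names follow the multiclass API of recent torchmetrics releases (>= 0.11), and the averaging choices (micro accuracy, macro precision/recall/F1) are assumptions; the exact configuration in train_mnist.py may differ.

```python
import torch
import torchmetrics

def evaluate(model, loader, device, num_classes=10):
    # Collection of the validation metrics listed above.
    metrics = torchmetrics.MetricCollection({
        "accuracy":  torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes, average="micro"),
        "precision": torchmetrics.classification.MulticlassPrecision(num_classes=num_classes, average="macro"),
        "recall":    torchmetrics.classification.MulticlassRecall(num_classes=num_classes, average="macro"),
        "f1":        torchmetrics.classification.MulticlassF1Score(num_classes=num_classes, average="macro"),
    }).to(device)
    criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            total_loss += criterion(logits, labels).item()
            num_batches += 1
            metrics.update(logits, labels)       # accumulates statistics over the whole loader

    results = {name: value.item() for name, value in metrics.compute().items()}
    results["avg_loss"] = total_loss / max(num_batches, 1)
    return results
```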