Vision Transformer (ViT) Implementation on MNIST

This repository contains a PyTorch implementation of a Vision Transformer (ViT) trained on the MNIST dataset. The project demonstrates a modular implementation of the Transformer architecture from scratch, including custom components for multi-head self-attention, patch encoding, and transformer blocks.

Project Overview

The goal of this project is to implement the Vision Transformer architecture for image classification without relying on PyTorch's pre-built Transformer layers (e.g. torch.nn.TransformerEncoderLayer). The model treats images as sequences of patches, processing them through self-attention to capture global dependencies.

Key implementation details, illustrated by the sketch after this list, include:

  • Patch Embedding: Images are divided into fixed-size patches and projected into linear embeddings using a convolutional layer.
  • Positional Embeddings: Learnable parameters are added to patch embeddings to retain positional information.
  • Multi-Head Self-Attention (MHSA): A custom implementation of the attention mechanism, calculating Query, Key, and Value projections and attention scores manually.
  • Transformer Blocks: A pre-norm architecture combining Layer Normalization, MHSA, and a multi-layer perceptron (MLP), with residual connections around each.
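
The listing below is a minimal sketch of these components in PyTorch, written for 28×28 single-channel MNIST inputs. Class names, argument names, and defaults are illustrative and are not copied from the repository's modules/ package.

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each to emb_dim."""
    def __init__(self, img_size=28, patch_size=7, in_channels=1, emb_dim=64):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size extracts and
        # linearly projects each patch in a single operation.
        self.proj = nn.Conv2d(in_channels, emb_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                      # x: (B, C, H, W)
        x = self.proj(x)                       # (B, emb_dim, H/P, W/P)
        return x.flatten(2).transpose(1, 2)    # (B, N, emb_dim)

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product attention with Q/K/V and attention scores computed explicitly."""
    def __init__(self, emb_dim=64, num_heads=2, attn_drop=0.1, out_drop=0.1):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(emb_dim, emb_dim)
        self.out_drop = nn.Dropout(out_drop)

    def forward(self, x):                      # x: (B, N, emb_dim)
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)   # each: (B, heads, N, head_dim)
        scores = (q @ k.transpose(-2, -1)) / self.head_dim ** 0.5
        attn = self.attn_drop(scores.softmax(dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return self.out_drop(self.out_proj(out))

class TransformerBlock(nn.Module):
    """Pre-norm block: LayerNorm -> MHSA -> residual, then LayerNorm -> MLP -> residual."""
    def __init__(self, emb_dim=64, num_heads=2, mlp_ratio=4,
                 mlp_drop=0.1, attn_drop=0.1, out_drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = MultiHeadSelfAttention(emb_dim, num_heads, attn_drop, out_drop)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, mlp_ratio * emb_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(mlp_ratio * emb_dim, emb_dim),
            nn.Dropout(mlp_drop),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))       # normalization before attention (pre-norm)
        x = x + self.mlp(self.norm2(x))
        return x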

Project Structure

The codebase is organized into a training script and a modular package for the model architecture; a sketch of how these modules fit together follows the list:

  • train_mnist.py: The main entry point. Handles argument parsing, data loading (MNIST), the training loop, validation metrics, and checkpoint saving.
  • modules/:
    • vision_transformer.py: Assembles the full model, including the class token (CLS) and classification head.
    • transformer_block.py: Defines a single transformer layer with skip connections.
    • attention_module.py: Implements the Multi-Head Self-Attention logic.
    • image_encoder.py: Handles the conversion of input images into flattened patch embeddings.
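
Building on the PatchEmbedding and TransformerBlock classes from the sketch above, the full model could be assembled roughly as follows. This is only an illustration of the structure described here; the actual vision_transformer.py may differ in naming and defaults.

class VisionTransformer(nn.Module):
    """Illustrative assembly: patch embedding -> CLS token + positional embedding -> blocks -> head."""
    def __init__(self, img_size=28, patch_size=7, in_channels=1, emb_dim=64,
                 num_heads=2, num_layers=2, mlp_ratio=4, num_classes=10):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, emb_dim)
        num_tokens = self.patch_embed.num_patches + 1          # patches + CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, emb_dim))
        self.blocks = nn.Sequential(*[
            TransformerBlock(emb_dim, num_heads, mlp_ratio) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):                                       # x: (B, 1, 28, 28)
        tokens = self.patch_embed(x)                            # (B, N, emb_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)         # (B, 1, emb_dim)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.norm(self.blocks(tokens))
        return self.head(tokens[:, 0])                          # classify from the CLS token

The CLS token is prepended to the patch sequence, the learnable positional embedding covers all tokens including CLS, and the classification head reads only the final CLS representation.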

Dependencies

The project requires Python and the libraries listed in requirements.txt, including PyTorch, TorchVision, TorchMetrics, and TQDM.

To install the dependencies:

pip install -r requirements.txt

Usage

The training script train_mnist.py accepts various command-line arguments to configure hyperparameters and system settings. The only required argument is --out, which specifies the directory for saving the model checkpoint.

Basic Execution

To train the model with default hyperparameters and save the best checkpoint to a local folder:

python train_mnist.py --out ./checkpoints

Enable CUDA

To utilize GPU acceleration:

python train_mnist.py --out ./checkpoints --cuda

Verbose Logging

To print detailed validation metrics (Accuracy, Precision, Recall, F1) at the end of every epoch:

python train_mnist.py --out ./checkpoints --verbose

Configuration

The script supports the following arguments (a combined example invocation follows the lists):

General Configuration

  • --out: (Required) Path where the model will be saved.
  • --seed: Random seed for reproducibility (default: 42).
  • --cuda: Boolean flag to enable CUDA.
  • --verbose: Boolean flag to print validation results per epoch.

Training Hyperparameters

  • --epochs: Number of training epochs (default: 20).
  • --bsize: Batch size (default: 128).
  • --lr: Learning rate (default: 0.001).

Model Architecture

  • --patch_size: Size of the image patches (default: 7).
  • --emb_dim: Internal embedding dimension (default: 64).
  • --num_heads: Number of attention heads (default: 2).
  • --num_layers: Number of transformer layers (default: 2).
  • --mlp_ratio: Expansion ratio for the internal MLP (default: 4).

Regularization (Dropout)

  • --attn_drop: Dropout rate for attention weights (default: 0.1).
  • --mlp_drop: Dropout rate for the MLP layers (default: 0.1).
  • --out_drop: Dropout rate for the projection layers (default: 0.1).
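
The flags can be combined freely. The invocation below is only an example; the specific values are arbitrary and not tuned recommendations.

python train_mnist.py --out ./checkpoints --cuda --verbose --epochs 30 --bsize 256 --lr 0.0005 --emb_dim 128 --num_heads 4 --num_layers 4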

Metrics

The training loop tracks the cross-entropy loss. The validation loop uses torchmetrics to compute and report the following metrics (a minimal example follows the list):

  • Average Loss
  • Accuracy
  • Precision (Macro average)
  • Recall (Macro average)
  • F1 Score
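
A minimal sketch of how such a validation metric set can be assembled with torchmetrics, assuming the 10 MNIST classes; the exact metric configuration in train_mnist.py may differ.

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import (MulticlassAccuracy, MulticlassPrecision,
                                          MulticlassRecall, MulticlassF1Score)

# Accumulate metrics over the validation set, then compute once per epoch.
metrics = MetricCollection({
    "accuracy":  MulticlassAccuracy(num_classes=10),
    "precision": MulticlassPrecision(num_classes=10, average="macro"),
    "recall":    MulticlassRecall(num_classes=10, average="macro"),
    "f1":        MulticlassF1Score(num_classes=10, average="macro"),
})

logits = torch.randn(8, 10)            # stand-in for model outputs on one batch
targets = torch.randint(0, 10, (8,))   # stand-in for the batch labels
metrics.update(logits, targets)        # called once per validation batch
print(metrics.compute())               # reported at the end of each epoch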
