fardinayar/ssl_depth


SSL Depth

SSL Depth is a PyTorch framework for self-supervised depth estimation, with an emphasis on robust, type-safe interfaces.

Features

  • Self-supervised depth estimation from monocular or stereo videos
  • Modular, extensible codebase built around a component registry
  • Type-safe interfaces using dataclasses
  • Multi-frame depth estimation support
  • PyTorch Lightning integration for training
  • Comprehensive documentation and examples
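
The registry pattern mentioned above lets a configuration file refer to components by name. The following is a minimal sketch of the idea, not the code in `ssl_depth/registry.py`: the `Registry` class, the `MODELS` instance, and the stand-in `SimpleDepthModel` are all hypothetical names chosen for illustration.

```python
from typing import Callable, Dict, Type


class Registry:
    """Maps string names to classes so configs can refer to components by name."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self._entries: Dict[str, Type] = {}

    def register(self, name: str) -> Callable[[Type], Type]:
        """Return a decorator that stores the decorated class under `name`."""
        def decorator(cls: Type) -> Type:
            if name in self._entries:
                raise KeyError(f"{self.kind} '{name}' is already registered")
            self._entries[name] = cls
            return cls
        return decorator

    def build(self, name: str, **kwargs):
        """Instantiate the class registered under `name` with the given kwargs."""
        if name not in self._entries:
            known = ", ".join(sorted(self._entries)) or "<none>"
            raise KeyError(f"Unknown {self.kind} '{name}'. Known: {known}")
        return self._entries[name](**kwargs)


MODELS = Registry("model")  # hypothetical registry instance


@MODELS.register("SimpleDepthModel")
class SimpleDepthModel:
    # Stand-in class for illustration; a real model would be a torch.nn.Module.
    def __init__(self, backbone: str = "resnet18", num_frames: int = 3) -> None:
        self.backbone = backbone
        self.num_frames = num_frames


# A config entry such as `name: SimpleDepthModel` can then be resolved by name.
model = MODELS.build("SimpleDepthModel", backbone="resnet18", num_frames=3)
```

With this layout, adding a new component only requires decorating its class; the training script does not need to import it explicitly as long as the defining module is loaded.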

Installation

git clone https://github.com/fardinayar/ssl_depth.git
cd ssl_depth
pip install -e .

Quick Start

import torch
from ssl_depth.models import DepthModelWrapper
from ssl_depth.utils.data_types import DepthBatch

# Load a pretrained model
model = DepthModelWrapper.load_from_checkpoint("path/to/checkpoint.pth")
model.eval()

# Create a batch of images (type-safe interface with dataclasses)
batch = DepthBatch(
    rgb=torch.randn(1, 3, 3, 256, 256),  # [B, N, C, H, W]
    intrinsics=torch.eye(3).unsqueeze(0)  # [B, 3, 3]
)

# Get depth predictions (gradients are not needed for inference)
with torch.no_grad():
    predictions = model(batch)
depth_map = predictions.depth  # [B, 1, H, W]
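
The identity matrix passed as `intrinsics` above is only a placeholder. For real data, each sample needs a 3x3 pinhole camera matrix, and that matrix has to be rescaled whenever the images are resized. The sketch below shows one way to do this with plain Python lists; `make_intrinsics` and `scale_intrinsics` are hypothetical helpers, not part of ssl_depth, and the focal length and image sizes are made-up values.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def make_intrinsics(fx: float, fy: float, cx: float, cy: float) -> Matrix:
    """Build a 3x3 pinhole camera matrix K from focal lengths and principal point."""
    return [[fx, 0.0, cx],
            [0.0, fy, cy],
            [0.0, 0.0, 1.0]]


def scale_intrinsics(K: Matrix, src_hw: Tuple[int, int], dst_hw: Tuple[int, int]) -> Matrix:
    """Rescale K for an image resized from src (H, W) to dst (H, W).

    The first row (x axis) scales with the width ratio and the second row
    (y axis) with the height ratio; the last row stays [0, 0, 1].
    """
    sy = dst_hw[0] / src_hw[0]
    sx = dst_hw[1] / src_hw[1]
    return [[K[0][0] * sx, K[0][1] * sx, K[0][2] * sx],
            [K[1][0] * sy, K[1][1] * sy, K[1][2] * sy],
            [0.0, 0.0, 1.0]]


# Illustrative values: a 384x1280 image downscaled to 192x640.
K = make_intrinsics(fx=700.0, fy=700.0, cx=640.0, cy=192.0)
K_small = scale_intrinsics(K, src_hw=(384, 1280), dst_hw=(192, 640))
```

The result can be wrapped as `torch.tensor(K_small).unsqueeze(0)` to obtain the `[B, 3, 3]` layout used by `DepthBatch`.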

Training

  1. Prepare your configuration:
# configs/custom_video.yaml
model:
  name: SimpleDepthModel
  backbone: resnet18
  pretrained: true
  freeze_backbone: false
  num_frames: 3

train_dataset:
  name: MultiFrameDataset
  root_dir: /path/to/dataset
  split: train
  frame_idxs: [-1, 0, 1]
  img_height: 192
  img_width: 640

# More configuration options...
  2. Run training:
python scripts/train.py --config configs/custom_video.yaml --output_dir outputs/custom_video
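
The `frame_idxs: [-1, 0, 1]` entry in the configuration above presumably lists offsets relative to the target frame, so each sample would contain the previous, current, and next frame. The README does not confirm this convention, so the following is a sketch under that assumption; `select_frames` and `valid_centers` are hypothetical helpers rather than functions from `MultiFrameDataset`.

```python
from typing import Dict, List, Optional


def select_frames(center: int, frame_idxs: List[int], seq_len: int) -> Optional[Dict[int, int]]:
    """Map each relative offset to an absolute frame index.

    Returns None when any neighbour falls outside the sequence, so the
    caller can skip that sample instead of padding it.
    """
    selected: Dict[int, int] = {}
    for offset in frame_idxs:
        idx = center + offset
        if idx < 0 or idx >= seq_len:
            return None
        selected[offset] = idx
    return selected


def valid_centers(frame_idxs: List[int], seq_len: int) -> List[int]:
    """List every target index whose neighbours all exist in the sequence."""
    return [t for t in range(seq_len)
            if select_frames(t, frame_idxs, seq_len) is not None]


# For a 10-frame clip, frame 5 is paired with frames 4 and 6.
sample = select_frames(center=5, frame_idxs=[-1, 0, 1], seq_len=10)
```

Under this assumption, the first and last frames of each sequence cannot serve as targets, which slightly reduces the number of usable training samples.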

Evaluation

python scripts/evaluate.py --config configs/custom_video.yaml --checkpoint outputs/custom_video/model_final.pth
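
The README does not state which metrics `scripts/evaluate.py` reports. As a reference point, the sketch below computes three metrics that are widely used for depth estimation (absolute relative error, RMSE, and the fraction of pixels with ratio below 1.25), plus the median scaling often applied to monocular predictions because they are only defined up to scale. It works on flat lists of positive depth values for illustration, and `depth_metrics` and `median_scale` are hypothetical names.

```python
import math
import statistics
from typing import Dict, List, Sequence


def median_scale(pred: Sequence[float], gt: Sequence[float]) -> List[float]:
    """Rescale predictions so their median matches the ground-truth median."""
    valid = [(p, g) for p, g in zip(pred, gt) if g > 0]
    scale = statistics.median(g for _, g in valid) / statistics.median(p for p, _ in valid)
    return [p * scale for p in pred]


def depth_metrics(pred: Sequence[float], gt: Sequence[float]) -> Dict[str, float]:
    """Compute abs_rel, rmse and delta1, assuming strictly positive predictions."""
    # Pixels with gt <= 0 carry no ground truth and are ignored.
    pairs = [(p, g) for p, g in zip(pred, gt) if g > 0]
    if not pairs:
        raise ValueError("no valid ground-truth pixels")
    n = len(pairs)
    abs_rel = sum(abs(p - g) / g for p, g in pairs) / n
    rmse = math.sqrt(sum((p - g) ** 2 for p, g in pairs) / n)
    delta1 = sum(1 for p, g in pairs if max(p / g, g / p) < 1.25) / n
    return {"abs_rel": abs_rel, "rmse": rmse, "delta1": delta1}


# Predictions that are off by a constant factor become exact after scaling.
scaled = median_scale([1.0, 2.0, 4.0], [2.0, 4.0, 8.0])
metrics = depth_metrics(scaled, [2.0, 4.0, 8.0])
```

A tensor-based version would apply the same formulas to masked `[B, 1, H, W]` depth maps.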

Project Structure

ssl_depth/
├── configs/            # Configuration files
├── scripts/            # Training and evaluation scripts
├── ssl_depth/          # Main package
│   ├── datasets/       # Dataset implementations
│   ├── losses/         # Loss functions
│   ├── models/         # Model architectures
│   ├── trainers/       # Training logic
│   ├── utils/          # Utility functions
│   │   ├── data_types.py     # Dataclass definitions
│   │   └── evaluation.py     # Evaluation utilities
│   └── registry.py     # Component registry
├── tests/              # Unit tests
└── docs/               # Documentation

Type-Safe Interface

The framework uses Python dataclasses to provide a structured, type-safe interface:

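import torch
# DepthPredictions is assumed to live next to DepthBatch in data_types.py
from ssl_depth.utils.data_types import DepthBatch, DepthPredictions

B, N, H, W = 1, 3, 192, 640  # batch size, frames per sample, image height, image width

# Input batch structure
batch = DepthBatch(
    rgb=torch.randn(B, N, 3, H, W),           # RGB images [B, N, 3, H, W]
    intrinsics=torch.eye(3).repeat(B, 1, 1),  # Camera intrinsics [B, 3, 3]
    depth_gt=torch.rand(B, 1, H, W),          # Optional ground truth [B, 1, H, W]
)

# Model prediction structure
predictions = DepthPredictions(
    depth=torch.rand(B, 1, H, W),        # Predicted depth [B, 1, H, W]
    uncertainty=torch.rand(B, 1, H, W),  # Optional uncertainty [B, 1, H, W]
)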
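
Dataclasses document field names and types, but they do not check tensor shapes on their own. The sketch below shows how the shape conventions in the comments above could be enforced, for example from a `__post_init__` hook. It operates on plain shape tuples (such as `tuple(tensor.shape)`) so it runs without PyTorch; `check_shape` is a hypothetical helper and is not claimed to exist in ssl_depth.

```python
from typing import Dict, Sequence, Union

Dim = Union[int, str]  # an int is a fixed size, a str is a named dimension


def check_shape(name: str, shape: Sequence[int], pattern: Sequence[Dim],
                bound: Dict[str, int]) -> None:
    """Validate `shape` against `pattern`, keeping named dims consistent.

    `bound` collects the size of every named dimension seen so far, so that
    "B", "H" and "W" must agree across all fields of one batch.
    """
    if len(shape) != len(pattern):
        raise ValueError(
            f"{name}: expected {len(pattern)} dims {tuple(pattern)}, got {tuple(shape)}")
    for actual, expected in zip(shape, pattern):
        if isinstance(expected, int):
            if actual != expected:
                raise ValueError(f"{name}: expected size {expected}, got {actual}")
        elif bound.setdefault(expected, actual) != actual:
            # setdefault binds the name on first use and returns the bound size afterwards
            raise ValueError(
                f"{name}: dim '{expected}' is {actual}, but was {bound[expected]} earlier")


dims: Dict[str, int] = {}
check_shape("rgb", (2, 3, 3, 192, 640), ("B", "N", 3, "H", "W"), dims)
check_shape("intrinsics", (2, 3, 3), ("B", 3, 3), dims)
check_shape("depth_gt", (2, 1, 192, 640), ("B", 1, "H", "W"), dims)
```

Failing early with a message that names the offending field tends to be easier to debug than a broadcasting error deep inside the loss computation.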

Documentation

See the documentation for more detailed guides and API references.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgements

This project incorporates ideas and code from several open-source projects, including:
