SSL Depth is a deep learning framework for self-supervised depth estimation, focused on a robust, type-safe implementation built on PyTorch. Key features:
- Self-supervised depth estimation from monocular or stereo videos
- Modular, extendable codebase with registry pattern
- Type-safe interfaces using dataclasses
- Multi-frame depth estimation support
- PyTorch Lightning integration for training
- Comprehensive documentation and examples
Install from source:

```bash
git clone https://github.com/yourusername/ssl_depth.git
cd ssl_depth
pip install -e .
```

Quick start:

```python
import torch
from ssl_depth.models import DepthModelWrapper
from ssl_depth.utils.data_types import DepthBatch
# Load a pretrained model
model = DepthModelWrapper.load_from_checkpoint("path/to/checkpoint.pth")
model.eval()
# Create a batch of images (type-safe interface with dataclasses)
batch = DepthBatch(
    rgb=torch.randn(1, 3, 3, 256, 256),   # [B, N, C, H, W]
    intrinsics=torch.eye(3).unsqueeze(0),  # [B, 3, 3]
)
# Get depth predictions
predictions = model(batch)
depth_map = predictions.depth  # [B, 1, H, W]
```
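The predicted depth can then be inspected or saved however you prefer. As a quick visualization sketch (this assumes matplotlib is available in your environment; it is not a dependency of the framework itself):

```python
import matplotlib.pyplot as plt

# Move the first depth map in the batch to the CPU as a NumPy array.
depth_np = depth_map[0, 0].detach().cpu().numpy()

# Save a colormapped preview of the predicted depth.
plt.imsave("depth_preview.png", depth_np, cmap="magma")
```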
To train on your own videos:

- Prepare your configuration:

```yaml
# configs/custom_video.yaml
model:
  name: SimpleDepthModel
  backbone: resnet18
  pretrained: true
  freeze_backbone: false
  num_frames: 3

train_dataset:
  name: MultiFrameDataset
  root_dir: /path/to/dataset
  split: train
  frame_idxs: [-1, 0, 1]
  img_height: 192
  img_width: 640
# More configuration options...
```

- Run training:

```bash
python scripts/train.py --config configs/custom_video.yaml --output_dir outputs/custom_video
```

- Evaluate the trained model:

```bash
python scripts/evaluate.py --config configs/custom_video.yaml --checkpoint outputs/custom_video/model_final.pth
```
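If you prefer to drive experiments from a notebook or script, the same YAML configuration can be read directly with PyYAML. This is plain YAML parsing, not a project-specific API, and assumes `pyyaml` is installed:

```python
import yaml

# Parse the experiment configuration into nested dictionaries.
with open("configs/custom_video.yaml") as f:
    config = yaml.safe_load(f)

print(config["model"]["backbone"])            # resnet18
print(config["train_dataset"]["frame_idxs"])  # [-1, 0, 1]
```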
Project structure:

```text
ssl_depth/
├── configs/              # Configuration files
├── scripts/              # Training and evaluation scripts
├── ssl_depth/            # Main package
│   ├── datasets/         # Dataset implementations
│   ├── losses/           # Loss functions
│   ├── models/           # Model architectures
│   ├── trainers/         # Training logic
│   ├── utils/            # Utility functions
│   │   ├── data_types.py # Dataclass definitions
│   │   └── evaluation.py # Evaluation utilities
│   └── registry.py       # Component registry
├── tests/                # Unit tests
└── docs/                 # Documentation
```
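The `registry.py` module implements the registry pattern mentioned in the feature list: components such as models and datasets are presumably looked up by the `name` field in the configuration. The snippet below is an illustrative sketch of how such a registry typically works, not the project's actual API (the `MODEL_REGISTRY`, `register_model`, and dummy `SimpleDepthModel` names here are hypothetical):

```python
from typing import Callable, Dict, Type

# Hypothetical registry: maps a string name to a model class.
MODEL_REGISTRY: Dict[str, Type] = {}

def register_model(name: str) -> Callable[[Type], Type]:
    """Class decorator that records a model class under the given name."""
    def decorator(cls: Type) -> Type:
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator

@register_model("SimpleDepthModel")
class SimpleDepthModel:  # placeholder, not the real model
    def __init__(self, backbone: str = "resnet18", pretrained: bool = True):
        self.backbone = backbone
        self.pretrained = pretrained

# A config's `name` field then selects which class to instantiate.
model_cls = MODEL_REGISTRY["SimpleDepthModel"]
model = model_cls(backbone="resnet18", pretrained=True)
```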
The framework uses Python dataclasses to provide a structured, type-safe interface:
```python
# Input batch structure
batch = DepthBatch(
    rgb=torch.tensor(...),         # RGB images [B, N, 3, H, W]
    intrinsics=torch.tensor(...),  # Camera intrinsics [B, 3, 3]
    depth_gt=torch.tensor(...),    # Optional ground truth [B, 1, H, W]
)

# Model prediction structure
predictions = DepthPredictions(
    depth=torch.tensor(...),        # Predicted depth [B, 1, H, W]
    uncertainty=torch.tensor(...),  # Optional uncertainty [B, 1, H, W]
)
```

See the documentation for more detailed guides and API references.
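For orientation, dataclasses of this shape are typically declared roughly as follows. This is an illustrative sketch based on the fields shown above, not the exact contents of `ssl_depth/utils/data_types.py`:

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class DepthBatch:
    rgb: torch.Tensor                         # [B, N, 3, H, W]
    intrinsics: torch.Tensor                  # [B, 3, 3]
    depth_gt: Optional[torch.Tensor] = None   # [B, 1, H, W], optional

@dataclass
class DepthPredictions:
    depth: torch.Tensor                           # [B, 1, H, W]
    uncertainty: Optional[torch.Tensor] = None    # [B, 1, H, W], optional
```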
This project is licensed under the MIT License - see the LICENSE file for details.
This project incorporates ideas and code from several open-source projects, including: