This directory contains the source code for the project.
- `core`: Package containing core library components such as `MonoDepthModule`, `MonoDepthEvaluator` or `MonoDepthPredictor`. Can depend on any custom package. Only API scripts should depend on them.
- `datasets`: Package containing PyTorch datasets.
- `devkits`: Package containing basic loading tools for datasets.
- `external_libs`: Package containing external libraries from other developers.
- `losses`: Package containing training losses.
- `networks`: Package containing network architectures (including contribution decoders).
- `regularizers`: Package containing regularizer losses.
- `tools`: Package containing more advanced utilities. They should depend only on each other and `utils`.
- `utils`: Package containing basic utilities. They should not have any custom dependencies!
- `__init__.py`: `src` package init.
- `paths.py`: File containing path management tools.
- `registry.py`: File containing the tools for registering models & datasets for training.
- `typing.py`: File containing custom type hints.
Please take into account the notes regarding dependencies when deciding where to incorporate custom code.
Contains the core library components required to train and evaluate a monocular depth estimation network.
- `aspect_ratio`: Code to generate the proposed aspect ratio augmentation.
- `evaluator`: Tools for evaluating pre-computed predictions.
- `handlers`: Handlers that wrap multi-scale loss computation during training.
- `heavy_logger`: PyTorch Lightning callback for logging images after each epoch.
- `metrics`: Functions for computing the various sets of evaluation metrics, such as `eigen`, `benchmark`, `pointcloud` and `ibims`.
- `predictors`: Tools for computing dataset predictions using custom or external models.
- `trainer`: Main PyTorch Lightning module for training, `MonoDepthModule`.
The `MonoDepthModule` used for training is implemented using PyTorch Lightning, which wraps the optimization procedure and provides hooks to various steps.
See the PyTorch Lightning docs for background info about how the code is organized and what hooks are available.
Overall, the module forward pass is split into:
- `forward`: Computes the network predictions.
- `forward_postprocess`: Prepares the predictions for loss computation. E.g. upsampling to a common resolution & converting to depth.
- `forward_loss`: Computes the optimization loss and produces auxiliary outputs for logging.
- `compute_metrics`: Computes the metrics for logging and validation performance tracking.
- `log_dict`: Logs scalars every `n` steps.
- `image_logger`: Logs images at the end of each epoch.
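As an illustration of the depth conversion mentioned under `forward_postprocess`, the sketch below maps a sigmoid disparity in `[0, 1]` to a depth value using the scaling convention popularized by Monodepth2. Whether this library uses exactly this formula is an assumption. The function name `disp_to_depth` and its default bounds are hypothetical and chosen only for the example, and plain floats stand in for tensors.

```python
def disp_to_depth(disp: float, min_depth: float = 0.1, max_depth: float = 100.0) -> float:
    """Map a sigmoid disparity in [0, 1] to a depth in [min_depth, max_depth]."""
    if not 0.0 <= disp <= 1.0:
        raise ValueError(f'Expected a sigmoid disparity in [0, 1], got {disp}')
    min_disp = 1.0 / max_depth  # Smallest disparity <-> farthest depth.
    max_disp = 1.0 / min_depth  # Largest disparity <-> nearest depth.
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    return 1.0 / scaled_disp


far = disp_to_depth(0.0)   # Lowest disparity maps to `max_depth`.
near = disp_to_depth(1.0)  # Highest disparity maps to `min_depth`.
```

Scaling in disparity space rather than depth space keeps the output bounded away from zero, so the final inversion is always well defined.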
To add a new network/loss to the training procedure:
- Implement it in the respective module.
- Add it to the `registry`.
- Add a new `if` block to the corresponding forward step based on the `registry` key.
- If adding a loss, add the corresponding wrapper to `handlers`.
- Add auxiliary inputs to `fwd` or `loss_dict` for logging.
- Add logging to `image_logger` based on the auxiliary inputs.
Configs consist of: networks, losses, datasets, loaders, optimizers, schedulers and trainer. For an example covering most of the available options see this file.
- Networks and losses use dictionaries, where the keys correspond to the `registry` keys. Remaining parameters are `kwargs` to the respective class.
- Losses must add an additional parameter `weight`, which controls the scaling factor in the total loss.
- Datasets, optimizers and schedulers add a `type` argument corresponding to the `registry` keys.
- Datasets/loaders allow for different configs based on the `train` & `val` mode, overriding the original parameters.
```yaml
# -----------------------------------------------------------------------------
net:
  # Depth estimation network.
  depth:
    enc_name: 'convnext_base'  # Choose from `timm` encoders.
    pretrained: True
    dec_name: 'monodepth'  # Choose from custom decoders.

  # Pose estimation network (for use with purely monocular models).
  pose:
    enc_name: 'resnet18'  # Typically ResNet18 for efficiency.
    pretrained: True
# -----------------------------------------------------------------------------
loss:
  # Image-based reconstruction loss.
  img_recon:
    weight: 1
    loss_name: 'ssim'
# -----------------------------------------------------------------------------
dataset:
  kitti_lmdb:
    split: 'eigen_benchmark'
    datum: 'image support depth K'
    shape: [ 192, 640 ]
    supp_idxs: [ -1, 1 ]
    max_len: 10000
    randomize: True

    train: { mode: 'train', use_aug: True }
    val: { mode: 'val', use_aug: False }

  mannequin_lmdb:
    datum: 'image support K'
    shape: [ 384, 640 ]
    supp_idxs: [ -1, 1 ]
    max_len: 10000
    randomize: True

    train: { mode: 'train', use_aug: True }
    val: { mode: 'val', use_aug: False }
# -----------------------------------------------------------------------------
loader:
  batch_size: 8
  num_workers: 8
  drop_last: True

  train: { shuffle: True }
  val: { shuffle: False }
# -----------------------------------------------------------------------------
optimizer:
  type: 'adam'  # Choose from any optimizer available from `timm`.
  lr: 0.0001
# -----------------------------------------------------------------------------
scheduler:
  steplr:
    step_size: 15
    gamma: 0.1
# -----------------------------------------------------------------------------
trainer:
  max_epochs: 30
  resume_training: True  # Will begin training from scratch if no checkpoints are found. Otherwise resume.
  monitor: 'AbsRel'  # Monitor metric to save `best` checkpoint.

  min_depth: 0.1  # Min depth to scale sigmoid disparity.
  max_depth: 100  # Max depth to scale sigmoid disparity.

  benchmark: True  # Pytorch cudnn benchmark.
# -----------------------------------------------------------------------------
```
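The `train` & `val` overrides used by the `dataset` and `loader` entries above could be resolved along the following lines. This is a minimal sketch, not the library's parser: the helper name `resolve_mode_cfg` is hypothetical, and it assumes a shallow merge in which mode-specific values replace the shared ones.

```python
from copy import deepcopy

MODES = ('train', 'val')


def resolve_mode_cfg(cfg: dict, mode: str) -> dict:
    """Return a copy of `cfg` where the `mode` sub-dict overrides the shared parameters."""
    if mode not in MODES:
        raise ValueError(f'Unknown mode "{mode}", expected one of {MODES}')
    # Keep every shared parameter, dropping the mode-specific sub-dicts.
    resolved = {k: deepcopy(v) for k, v in cfg.items() if k not in MODES}
    resolved.update(deepcopy(cfg.get(mode, {})))  # Mode-specific values take priority.
    return resolved


loader_cfg = {
    'batch_size': 8, 'num_workers': 8, 'drop_last': True,
    'train': {'shuffle': True},
    'val': {'shuffle': False},
}
train_cfg = resolve_mode_cfg(loader_cfg, 'train')
val_cfg = resolve_mode_cfg(loader_cfg, 'val')
```

Here `train_cfg` and `val_cfg` share `batch_size`, `num_workers` and `drop_last`, and differ only in `shuffle`.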
Contains PyTorch datasets required for training and/or evaluating.
- `base`: `BaseDataset` that all other datasets should inherit from, provides utilities for logging, loading and visualizing.
- `base_mde`: `MdeBaseDataset` that provides a few extra utilities for depth estimation datasets.

All datasets should inherit from `BaseDataset` or `MdeBaseDataset` and implement/override the following methods.
```python
class MyDataset(MdeBaseDataset):
    # Data types that can be loaded by the dataset.
    # Can be provided as either a list or a single string.
    # Each datum must have a corresponding `load_{datum}` and `_load_{datum}` method.
    VALID_DATUM = 'image support depth K'

    # Determines the full-resolution size of the images.
    # SIZE, H and W are determined automatically based on this.
    SHAPE = 376, 1242

    def log_args(self):
        """(OVERRIDE) Log additional input arguments. Should call `super().log_args()` at the end."""

    def validate_args(self):
        """(OVERRIDE) Sanity check for input arguments. Should call `super().validate_args()` at the end."""

    def parse_items(self) -> tuple[Path, Sequence[Item]]:
        """(REQUIRED) Get file containing split items and the list of item data."""

    def add_metadata(self, data: Item, batch: BatchData) -> BatchData:
        """(OVERRIDE) Add item info to the batch metadata."""

    def get_supp_scale(self, data: Item) -> float:
        """(OVERRIDE) Return a random scaling factor for loading the support image."""

    def augment(self, x: dict, y: dict, m: dict) -> BatchData:
        """(OVERRIDE) Augment a loaded item. Default is a no-op."""
        return x, y, m

    def transform(self, x: dict, y: dict, m: dict) -> BatchData:
        """(OVERRIDE) Transform a loaded item. Default is a no-op."""
        return x, y, m

    def to_torch(self, x: dict, y: dict, m: dict) -> BatchData:
        """Convert (x, y, m) to torch Tensors. Default converts to torch and permutes >=3D tensors."""

    @classmethod
    def collate_fn(cls, batch: Sequence[BatchData]):
        """(OVERRIDE) Function to collate multiple dataset items. By default uses the PyTorch collator."""

    def load_datum(self, data: Item) -> NDArray:
        """(OVERRIDE) Load a single datum from the item data and place it in the corresponding batch dict.
        Should call the corresponding `_load_datum`. Implement for each datum.
        """

    def _load_datum(self, data: Item) -> NDArray:
        """(REQUIRED) Load a single datum from the item data. Implement for each datum."""

    def create_axs(self) -> Axes:
        """(OVERRIDE) Create axes for visualization."""

    def show(self, batch: BatchData, axs: Axes) -> None:
        """(OVERRIDE) Visualize a single dataset item."""
```
Datasets must return batches as three dictionaries:
- `x`: Contains data required for the network forward pass. E.g. images, indexes of support frames.
- `y`: Contains auxiliary data required for loss/metric computation. E.g. depth, edges, non-augmented images.
- `m`: Contains metadata about the loaded batch. E.g. loaded indexes, augmentations applied or errors while loading.
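The interplay between the `load_{datum}` / `_load_{datum}` naming convention and the `(x, y, m)` output can be sketched with a toy class. `ToyDataset` below is a hypothetical stand-in, not the library's `BaseDataset`: the method signatures, dictionary keys and placeholder data are all assumptions made for illustration.

```python
class ToyDataset:
    """Toy stand-in (not the library's `BaseDataset`) that builds an `(x, y, m)` item."""

    def __init__(self, datum: str = 'image depth'):
        self.datum = datum.split()  # Datums may be given as a single string.

    def __getitem__(self, idx: int) -> tuple[dict, dict, dict]:
        x, y, m = {}, {}, {'idx': idx}
        for d in self.datum:
            # Dispatch to the matching `load_{datum}` method by name.
            getattr(self, f'load_{d}')(idx, (x, y, m))
        return x, y, m

    def load_image(self, idx: int, batch: tuple[dict, dict, dict]) -> None:
        batch[0]['imgs'] = self._load_image(idx)  # Network input -> `x`.

    def load_depth(self, idx: int, batch: tuple[dict, dict, dict]) -> None:
        batch[1]['depth'] = self._load_depth(idx)  # Loss/metric target -> `y`.

    def _load_image(self, idx: int) -> list[list[float]]:
        return [[float(idx)] * 2 for _ in range(2)]  # Placeholder 2x2 "image".

    def _load_depth(self, idx: int) -> list[list[float]]:
        return [[1.0 + idx] * 2 for _ in range(2)]  # Placeholder 2x2 "depth map".


x, y, m = ToyDataset('image depth')[3]
```

Splitting each datum into a generic `load_*` (decides where the data goes) and a dataset-specific `_load_*` (decides how it is read) is what would let a base class provide the former while subclasses only implement the latter.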
Utilities provided by `BaseDataset`:
- `logger`: Used for logging instance arguments.
- `max_len`: Sets the max number of items to load.
- `randomize`: Randomize the order of items to load. (Useful when combined with `max_len`)
- `play`: Display each item in the dataset. (Useful for debugging)
- `timer`: `MultiLevelTimer` that logs the time required to load each datum and each processing step.
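The note that `randomize` is useful together with `max_len` can be illustrated as follows: shuffling before truncating yields a random subset rather than always the first items of the split. This is a sketch under that assumed ordering, and `select_items` and its `seed` argument are hypothetical names.

```python
import random


def select_items(items: list, max_len=None, randomize: bool = False, seed=None) -> list:
    """Optionally shuffle `items`, then keep at most `max_len` of them."""
    items = list(items)  # Copy to avoid mutating the caller's list.
    if randomize:
        random.Random(seed).shuffle(items)
    if max_len is not None:
        items = items[:max_len]
    return items


all_items = list(range(100))
first_ten = select_items(all_items, max_len=10)  # Always the first ten items.
random_ten = select_items(all_items, max_len=10, randomize=True, seed=42)  # A random subset.
```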
Utilities provided by `MdeBaseDataset`:
- Augmentations: Photometric, horizontal flipping...
- `collate`: Transposes the support frames & keeps one set of support indexes.
- `load_datum`: Provides default implementation for common datums. Each dataset still needs to implement the specific `_load_datum`.
- `show`: Provides default displaying for common setups.
Contains low-level tools for loading and interacting with the available datasets. It should be self-evident which dataset each devkit corresponds to.
Contains libraries from other developers.
- Databases: Tools for creating LMDB datasets.
- DGP: Tools for loading DDAD dataset.
- MiDaS: Pre-trained supervised scaleless depth estimation model.
- NeWCRFs: Pre-trained supervised metric depth estimation model.
The main available losses are:
- `ReconstructionLoss`: Base view synthesis loss. Additionally used for feature-based view synthesis and autoencoder image reconstruction.
- `RegressionLoss`: Proxy depth regression loss. Additionally used for virtual stereo consistency.
NOTE: Each of these incorporates multiple different contributions based on the available input configuration. Check out the respective documentation for additional details.
New losses should be added as per the instructions in the registry. Losses must return a tuple consisting of:

```python
"""
:return (tuple) (
    loss: (Tensor) (,) Scalar loss value.
    loss_dict: (TensorDict) Dictionary containing intermediate loss outputs used for TensorBoard logging.
)
"""
```
The main available networks are:
- `depth`: Predicts a dense disparity map from a single image.
- `pose`: Predicts the relative pose between two images in axis-angle format.
- `autoencoder`: Converts the input image into a compact feature representation, which can be used to reconstruct the image. Used primarily to learn a feature representation complementary to the image-based reconstruction loss.

These networks use any of the pretrained encoders available in `timm`.
New networks and decoders should be added as per the instructions in the registry.
Networks producing dense outputs (depth & autoencoder) additionally require a dense decoder:
- `cadepth`: Adds self-attention and channel-wise skip connections. From CA-Depth.
- `ddvnet`: Predicts depth as a discrete disparity volume. From Johnston.
- `diffnet`: Adds self-attention and channel-wise attention skip-connections. From DiffNet.
- `hrdepth`: Adds progressive skip connections & SqueezeExcitation. From HRDepth.
- `monodepth`: Default Conv+ELU+BilinearUpsample. From Monodepth.
- `superdepth`: Conv+ELU+PixelShuffle. From SuperDepth.
Currently, all decoders are required to have roughly the same argument structure.
This could probably be improved by using additional `**kwargs` in the main network initializers.

```python
"""
:param num_ch_enc: (Sequence[int]) List of channels per encoder stage.
:param enc_sc: (Sequence[int]) List of downsampling factor per encoder stage.
:param upsample_mode: (str) Torch upsampling mode. {'nearest', 'bilinear'...}
:param use_skip: (bool) If `True`, add skip connections from corresponding encoder stage.
:param out_sc: (Sequence[int]) List of multi-scale output downsampling factor as 2**s.
:param out_ch: (int) Number of output channels.
:param out_act: (str) Activation to apply to each output stage.
"""
```
Regularizers are meant to prevent suboptimal or degenerate representations, rather than driving the optimization. The main available regularizers are:
- `MaskReg`: Explainability mask regularization. From SfM-Learner.
- `OccReg`: Disparity occlusion regularization. From DVSO.
- `SmoothReg`: Disparity smoothness regularization. From multiple contributions.
- `FeatPeakReg`: First-order feature peakiness regularization. From FeatDepth.
- `FeatSmoothReg`: Second-order feature smoothness regularization. From FeatDepth.
New regularizers should be added as per the instructions in the registry. They must also follow the output format required by the losses.
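Since losses and regularizers share the `(loss, loss_dict)` output format, and each config entry carries a `weight`, the total optimization loss can be thought of as a weighted sum. The sketch below shows one way this could be aggregated. It uses plain floats in place of tensors, `total_loss` is a hypothetical helper rather than the library's handler, and the `disp_smooth` key is made up for the example.

```python
def total_loss(outputs: dict, weights: dict) -> tuple[float, dict]:
    """Combine per-loss `(loss, loss_dict)` outputs into a weighted total."""
    total, merged = 0.0, {}
    for name, (loss, loss_dict) in outputs.items():
        total += weights[name] * loss
        merged[name] = loss  # Keep the unweighted value for logging.
        # Prefix auxiliary outputs to avoid key clashes between losses.
        merged.update({f'{name}/{k}': v for k, v in loss_dict.items()})
    return total, merged


outputs = {
    'img_recon': (0.5, {'ssim': 0.4}),  # Loss with one auxiliary output.
    'disp_smooth': (0.2, {}),           # Regularizer without auxiliary outputs.
}
loss, loss_dict = total_loss(outputs, weights={'img_recon': 1.0, 'disp_smooth': 0.1})
```

Logging the unweighted values keeps the curves comparable across runs that only differ in their `weight` settings.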
A collection of more advanced utilities that only depend on each other or on `utils`.
- `geometry`: Depth scaling/conversion and view synthesis tools, such as `extract_edges`, `to_scaled`, `to_inv`, `T_from_AAt`, `ViewSynth`...
- `ops`: Collection of PyTorch operations, such as `to_torch`, `to_numpy`, `allow_np`, `interpolate_like`, `expand_dim`...
- `parsers`: Tools for instantiating classes from config dicts.
- `table_formatter`: `TableFormatter` to convert dataframes into LaTeX/MarkDown tables.
- `viz`: Visualization tools `rgb_from_disp` & `rgb_from_feat`.
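For context on the geometry tools, the pose network's axis-angle output has to be turned into a transformation matrix before view synthesis. The sketch below does this with Rodrigues' rotation formula in pure Python. It is not the library's `T_from_AAt`: the function names, argument order and list-based matrices are assumptions made to keep the example self-contained.

```python
import math


def rotation_from_axis_angle(aa):
    """Convert an axis-angle vector (axis * angle, in radians) to a 3x3 rotation matrix."""
    angle = math.sqrt(sum(v * v for v in aa))
    if angle < 1e-8:
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # No rotation.
    x, y, z = (v / angle for v in aa)  # Unit rotation axis.
    c, s = math.cos(angle), math.sin(angle)
    t = 1.0 - c
    # Rodrigues' formula: R = c*I + s*[k]_x + (1 - c)*k*k^T.
    return [
        [c + x * x * t, x * y * t - z * s, x * z * t + y * s],
        [y * x * t + z * s, c + y * y * t, y * z * t - x * s],
        [z * x * t - y * s, z * y * t + x * s, c + z * z * t],
    ]


def transform_from_axis_angle(aa, t):
    """Build a 4x4 homogeneous transform from an axis-angle rotation and a translation."""
    R = rotation_from_axis_angle(aa)
    return [R[i] + [float(t[i])] for i in range(3)] + [[0.0, 0.0, 0.0, 1.0]]


# Rotate 90 degrees about the z-axis and translate by 1 along x.
T = transform_from_axis_angle([0.0, 0.0, math.pi / 2], [1.0, 0.0, 0.0])
```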
A collection of basic utilities that do not depend on any other custom code from this library.
- `callbacks`: Custom PyTorch Lightning callbacks, including progress bars and anomaly detection.
- `collate`: `default_collate` from PyTorch, modified to accept `MultiLevelTimer`.
- `deco`: Custom decorators, including `opt_args_deco`, `delegates`, `map_container` & `retry_new_on_error`.
- `io`: YAML loading/writing tools and image conversion.
- `loader`: `ConcatDataLoader` that implements a round-robin loading for multiple datasets.
- `metrics`: PyTorch Lightning metrics for use during training.
- `misc`: Collection of random utilities, `flatten_dict`, `sort_dict`, `get_logger` & `apply_cmap`.
- `timers`: `MultiLevelTimer` to allow for nested timing blocks.
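To make the round-robin loading mentioned for `ConcatDataLoader` concrete, the generator below alternates between several iterables, taking one batch from each in turn. This is only a sketch of the general technique. Whether the library stops at the shortest loader or keeps drawing from the remaining ones (as done here) is an assumption, and `round_robin` is a hypothetical name.

```python
def round_robin(*loaders):
    """Yield one batch from each loader in turn, dropping loaders once they are exhausted."""
    iterators = [iter(loader) for loader in loaders]
    while iterators:
        for it in list(iterators):  # Iterate over a copy, since exhausted iterators are removed.
            try:
                yield next(it)
            except StopIteration:
                iterators.remove(it)


# Plain lists stand in for the per-dataset data loaders.
kitti = ['k0', 'k1', 'k2']
mannequin = ['m0', 'm1']
order = list(round_robin(kitti, mannequin))
```

Interleaving batches in this way means each optimization step sees a single dataset, which sidesteps collating items of different shapes (e.g. the `[192, 640]` and `[384, 640]` entries in the config above).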
Path management for datasets and storing/loading checkpoints is done based on predefined locations in `DATA_ROOTS` & `MODEL_ROOTS`.
This removes the need to provide long and repeated paths, while remaining flexible to datasets being stored in different locations (e.g. local scratch spaces).
Instructions for setting up custom roots can be found in the main README.
This file additionally provides some utilities for finding dataset & model paths within the available roots: `find_data_dir` & `find_model_file`.
These functions will return the input path if it is an absolute path to an existing file/directory.
Otherwise, they will search the available roots and return the first existing path.
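The lookup rule described above can be sketched as follows. This is not the library's `find_data_dir`: the helper name `find_dir`, the explicit `roots` argument and the choice to raise `FileNotFoundError` when nothing is found are assumptions.

```python
from pathlib import Path
import tempfile


def find_dir(path, roots):
    """Return `path` if absolute and existing, else the first existing `root / path`."""
    path = Path(path)
    if path.is_absolute():
        if path.is_dir():
            return path
        raise FileNotFoundError(f'Absolute directory does not exist: {path}')
    for root in roots:  # Roots are searched in order, so earlier roots take priority.
        candidate = Path(root) / path
        if candidate.is_dir():
            return candidate
    raise FileNotFoundError(f'"{path}" not found in any of the roots: {list(roots)}')


with tempfile.TemporaryDirectory() as tmp:
    scratch, shared = Path(tmp) / 'scratch', Path(tmp) / 'shared'
    (shared / 'kitti').mkdir(parents=True)
    scratch.mkdir()
    found = find_dir('kitti', roots=[scratch, shared])  # Only exists under `shared`.
    found_is_shared = found == shared / 'kitti'
    found_abs = find_dir(shared / 'kitti', roots=[]) == shared / 'kitti'
```

Ordering the roots from fastest to slowest storage would make a local scratch copy win over a shared network copy whenever both exist.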
New networks, losses or datasets should be added to the registry via the `register` decorator.
This makes these classes accessible to the `parsers`, and in turn to the config files.
```python
import torch.nn as nn
from src import register


@register(name='awesome', type='loss')
class MyAwesomeLoss(nn.Module):
    def forward(self, pred, target):
        err = (pred - target).abs().mean(dim=1, keepdim=True)
        loss = err.mean()
        return loss, {'l1_error': err}
```
- `type` selects the relevant registry, but can typically be omitted and guessed from the class name.
- `name` represents the identifier used in the configs and module forward pass. Multiple aliases can be registered by providing a `tuple`, useful when losses share the same underlying computations but require different inputs or preprocessing in `MonoDepthModule`. An example is the base `ReconstructionLoss`, which can be used with either images or dense feature maps.
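For readers unfamiliar with the pattern, a registry decorator with alias support can be sketched in a few lines. This is a simplified stand-in and not the implementation in `registry.py`: the `REGISTRY` dict, the `feat_recon` alias and `ToyReconstructionLoss` are hypothetical, and unlike the real decorator, `type` is required here rather than guessed from the class name.

```python
REGISTRY = {'loss': {}, 'net': {}}  # Hypothetical registries keyed by type.


def register(name, type: str):
    """Register a class under one or more names in the registry selected by `type`."""
    names = (name,) if isinstance(name, str) else tuple(name)

    def wrapper(cls):
        for n in names:
            if n in REGISTRY[type]:
                raise ValueError(f'"{n}" is already registered as a {type}')
            REGISTRY[type][n] = cls
        return cls  # Return the class unchanged so it can still be used directly.

    return wrapper


@register(name=('img_recon', 'feat_recon'), type='loss')
class ToyReconstructionLoss:
    def __init__(self, loss_name: str = 'l1'):
        self.loss_name = loss_name


# Instantiate from a config-like dict: key -> registered class, values -> kwargs.
cfg = {'img_recon': {'loss_name': 'ssim'}}
losses = {key: REGISTRY['loss'][key](**kwargs) for key, kwargs in cfg.items()}
```

Both aliases resolve to the same class, so the training module can branch on the key (`img_recon` vs. `feat_recon`) to decide which inputs to feed it.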