Source Code

This directory contains the source code for the project.


Structure

  • core: Package containing core library components such as MonoDepthModule, MonoDepthEvaluator or MonoDepthPredictor. Can depend on any custom package. Only API scripts should depend on them.
  • datasets: Package containing PyTorch datasets.
  • devkits: Package containing basic loading tools for datasets.
  • external_libs: Package containing external libraries from other developers.
  • losses: Package containing training losses.
  • networks: Package containing network architectures (including contribution decoders).
  • regularizers: Package containing regularizer losses.
  • tools: Package containing more advanced utilities. They should depend only on each other and utils.
  • utils: Package containing basic utilities. They should not have any custom dependencies!
  • __init__.py: src package init.
  • paths.py: File containing path management tools.
  • registry.py: File containing the tools for registering models & datasets for training.
  • typing.py: File containing custom type hints.

Please take into account the notes regarding dependencies when deciding where to incorporate custom code.


Core

Contains the core library components required to train and evaluate a monocular depth estimation network.

  • aspect_ratio: Code to generate the proposed aspect ratio augmentation.
  • evaluator: Tools for evaluating pre-computed predictions.
  • handlers: Handlers that wrap multi-scale loss computation during training.
  • heavy_logger: PyTorch Lightning callback for logging images after each epoch.
  • metrics: Functions for computing the various sets of evaluation metrics, such as eigen, benchmark, pointcloud and ibims.
  • predictors: Tools for computing dataset predictions using custom or external models.
  • trainer: Main PyTorch Lightning module for training, MonoDepthModule.

MonoDepthModule Structure

The MonoDepthModule used for training is implemented using PyTorch Lightning, which wraps the optimization procedure and provides hooks into the various steps. See their docs for background on how the code is organized and which hooks are available. Overall, the module forward pass is split into the following steps (sketched below the list):

  1. forward: Computes the network predictions.
  2. forward_postprocess: Prepares the predictions for loss computation. E.g. upsampling to common resolution & converting to depth.
  3. forward_loss: Computes the optimization loss and produces auxiliary outputs for logging.
  4. compute_metrics: Computes the metrics for logging and validation performance tracking.
  5. log_dict: Logs scalars every n steps.
  6. image_logger: Logs images at the end of each epoch.
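
As a rough illustration of this flow, here is a minimal Lightning-style skeleton. The method names come from the list above, but the signatures, batch structure and dict keys are assumptions, not the actual implementation:

import pytorch_lightning as pl

class MonoDepthModuleSketch(pl.LightningModule):
    """Skeleton only: mirrors the step order above, not the real module."""

    def forward(self, x):                     # 1. network predictions
        ...

    def forward_postprocess(self, fwd):       # 2. upsample to common resolution, convert to depth
        ...

    def forward_loss(self, fwd, y):           # 3. optimization loss + auxiliary outputs
        ...

    def compute_metrics(self, fwd, y):        # 4. metrics for validation tracking
        ...

    def training_step(self, batch, batch_idx):
        x, y, m = batch
        fwd = self.forward_postprocess(self.forward(x))
        loss, loss_dict = self.forward_loss(fwd, y)
        self.log_dict(loss_dict)              # 5. scalars every n steps
        return loss                           # 6. images are logged at epoch end by the heavy_logger callback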

To add a new network/loss to the training procedure (see the dispatch sketch after this list):

  1. Implement it in the respective module.
  2. Add it to the registry.
  3. Add a new if block to the corresponding forward step based on the registry key.
  4. If adding a loss, add the corresponding wrapper to handlers.
  5. Add auxiliary inputs to fwd or loss_dict for logging.
  6. Add logging to image_logger based on the auxiliary inputs.
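
For example, the if block in step 3 for a new loss registered under the key 'awesome' might be wired up along these lines. This is a sketch: the `self.losses` attribute and the dict keys are assumptions:

def forward_loss(self, fwd, y):
    """Sketch of registry-key dispatch in the loss forward step."""
    total, loss_dict = 0., {}
    for key in self.losses:              # keys mirror the `loss` config block
        if key == 'img_recon':
            loss, out = self.losses[key](fwd['disp'], y['imgs'])
        elif key == 'awesome':           # hypothetical loss added via the registry
            loss, out = self.losses[key](fwd['pred'], y['target'])
        else:
            raise KeyError(f'Missing forward logic for loss: {key}')
        total, loss_dict[key] = total + loss, out
    return total, loss_dict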

Configs consist of: networks, losses, datasets, loaders, optimizers, schedulers and trainer. For an example covering most of the available options see this file.

  • Networks and losses use dictionaries, where the keys correspond to the registry keys. Remaining parameters are kwargs to the respective class (a sketch of this mapping follows the example config below).
  • Losses must add an additional parameter weight, which controls the scaling factor in the total loss.
  • Datasets, optimizers and schedulers add a type argument corresponding to the registry keys.
  • Datasets/loaders allow for different configs based on the train & val mode, overriding the original parameters.
# -----------------------------------------------------------------------------
net:
  # Depth estimation network.
  depth:
    enc_name: 'convnext_base'  # Choose from `timm` encoders.
    pretrained: True
    dec_name: 'monodepth'  # Choose from custom decoders.
    
  # Pose estimation network (for use with purely monocular models).
  pose:
    enc_name: 'resnet18'  # Typically ResNet18 for efficiency.
    pretrained: True
# -----------------------------------------------------------------------------
loss:
  # Image-based reconstruction loss.
  img_recon:
    weight: 1
    loss_name: 'ssim'
# -----------------------------------------------------------------------------
dataset:
  kitti_lmdb:
    split: 'eigen_benchmark'
    datum: 'image support depth K'
    shape: [ 192, 640 ]
    supp_idxs: [ -1, 1 ]
    max_len: 10000
    randomize: True
  
    train: { mode: 'train', use_aug: True }
    val: { mode: 'val', use_aug: False }
    
  mannequin_lmdb:
    datum: 'image support K'
    shape: [ 384, 640 ]
    supp_idxs: [ -1, 1 ]
    max_len: 10000
    randomize: True
    
    train: { mode: 'train', use_aug: True }
    val: { mode: 'val', use_aug: False }
# -----------------------------------------------------------------------------
loader:
  batch_size: 8
  num_workers: 8
  drop_last: True

  train: { shuffle: True }
  val: { shuffle: False }
# -----------------------------------------------------------------------------
optimizer:
  type: 'adam'  # Choose from any optimizer available from `timm`.
  lr: 0.0001
# -----------------------------------------------------------------------------
scheduler:
  steplr: 
    step_size: 15
    gamma: 0.1
# -----------------------------------------------------------------------------
trainer:
  max_epochs: 30
  resume_training: True  # Will begin training from scratch if no checkpoints are found. Otherwise resume.
  monitor: 'AbsRel'  # Monitor metric to save `best` checkpoint.

  min_depth: 0.1  # Min depth to scale sigmoid disparity.
  max_depth: 100  # Max depth to scale sigmoid disparity.

  benchmark: True  # PyTorch cudnn benchmark.
# -----------------------------------------------------------------------------
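
To illustrate how these blocks are consumed by the parsers in tools, here is a minimal sketch of the mapping from a config entry to a class instance. DepthNet and NET_REGISTRY are stand-ins for illustration, not the library's real names:

class DepthNet:
    """Stand-in for a registered network class."""
    def __init__(self, enc_name, pretrained, dec_name):
        self.enc_name, self.pretrained, self.dec_name = enc_name, pretrained, dec_name

NET_REGISTRY = {'depth': DepthNet}  # assumed registry mapping

net_cfg = {'depth': {'enc_name': 'convnext_base', 'pretrained': True, 'dec_name': 'monodepth'}}
nets = {key: NET_REGISTRY[key](**kwargs) for key, kwargs in net_cfg.items()}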

Datasets

Contains PyTorch datasets required for training and/or evaluating.

  • base: BaseDataset that all other datasets should inherit from, provides utilities for logging, loading and visualizing.
  • base_mde: MdeBaseDataset that provides a few extra utilities for depth estimation datasets.

All datasets should inherit from BaseDataset or MdeBaseDataset and implement/override the following methods.

class MyDataset(MdeBaseDataset):
    # Data types that can be loaded by the dataset.
    # Can be provided as either a list or a single string. 
    # Each datum must have a corresponding `load_{datum}` and `_load_{datum}` method.
    VALID_DATUM = 'image support depth K'  
    
    # Determines the full-resolution size of the images.
    # SIZE, H and W are determined automatically based on this. 
    SHAPE = 376, 1242
    
    def log_args(self):
        """(OVERRIDE) Log additional input arguments. Should call `super().log_args()` at the end."""
        
    def validate_args(self):
        """(OVERRIDE) Sanity check for input arguments. Should call `super().validate_args()` at the end."""
        
    def parse_items(self) -> tuple[Path, Sequence[Item]]:
        """(REQUIRED) Get file containing split items and the list of item data."""
        
    def add_metadata(self, data: Item, batch: BatchData) -> BatchData:
        """(OVERRIDE) Add item info to the batch metadata."""
        
    def get_supp_scale(self, data: Item) -> float:
        """(OVERRIDE) Return a random scaling factor for loading the support image."""
        
    def augment(self, x: dict, y: dict, m: dict) -> BatchData:
        """(OVERRIDE) Augment a loaded item. Default is a no-op."""
        return x, y, m

    def transform(self, x: dict, y: dict, m: dict) -> BatchData:
        """(OVERRIDE) Transform a loaded item. Default is a no-op."""
        return x, y, m

    def to_torch(self, x: dict, y: dict, m: dict) -> BatchData:
        """Convert (x, y, m) to torch Tensors. Default converts to torch and permutes >=3D tensors."""

    @classmethod
    def collate_fn(cls, batch: Sequence[BatchData]):
        """(OVERRIDE) Function to collate multiple dataset items. By default uses the PyTorch collator."""

    def load_datum(self, data: Item) -> NDArray:
        """(OVERRIDE) Load a single datum from the item data and places it in the corresponding batch dict. 
        Should call the corresponding `_load_datum`. Implement for each datum.
        """
        
    def _load_datum(self, data: Item) -> NDArray:
        """(REQUIRED) Load a single datum from the item data. Implement for each datum."""
        
    def create_axs(self) -> Axes:
        """(OVERRIDE) Create axes for visualization."""
        
    def show(self, batch: BatchData, axs: Axes) -> None:
        """(OVERRIDE) Visualize a single dataset item."""

Datasets must return batches as three dictionaries (an illustrative example follows the list):

  • x: Contains data required for the network forward pass. E.g. images, indexes of support frames.
  • y: Contains auxiliary data required for loss/metric computation. E.g. depth, edges, non-augmented images.
  • m: Contains metadata about the loaded batch. E.g. loaded indexes, augmentations applied or errors while loading.
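
Concretely, a KITTI-style item might come back as something like the following. The key names and shapes are illustrative; the exact contract depends on the dataset and its datum config:

import numpy as np

x = {
    'imgs': np.zeros((192, 640, 3), np.float32),          # target image
    'supp_imgs': np.zeros((2, 192, 640, 3), np.float32),  # support frames at idxs [-1, 1]
    'supp_idxs': np.array([-1, 1]),
}
y = {
    'depth': np.zeros((192, 640, 1), np.float32),         # ground-truth depth
    'K': np.eye(4, dtype=np.float32),                     # camera intrinsics
}
m = {'idx': 0, 'augs': ['flip'], 'errors': []}            # loading metadata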

Utilities provided by BaseDataset (usage example below):

  • logger: Used for logging instance arguments.
  • max_len: Sets the max number of items to load.
  • randomize: Randomize the order of items to load. (Useful when combined with max_len)
  • play: Display each item in the dataset. (Useful for debugging)
  • timer: MultiLevelTimer that logs the time required to load each datum and each processing step.
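
For example, when debugging a new dataset you might do something along these lines. Constructor arguments other than max_len and randomize are illustrative:

ds = MyDataset(split='eigen_benchmark', max_len=100, randomize=True)
print(ds.timer)  # per-datum and per-step loading times
ds.play()        # step through and visualize each item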

Utilities provided by MdeBaseDataset:

  • Augmentations: Photometric, horizontal flipping...
  • collate: Transposes the support frames & keeps one set of support indexes.
  • load_datum: Provides default implementation for common datums. Each dataset still needs to implement the specific _load_datum.
  • show: Provides default displaying for common setups.

Devkits

Contains low-level tools for loading and interacting with the available datasets. It should be self-evident which dataset each devkit corresponds to.


External Libs

Contains libraries from other developers.

  • Databases: Tools for creating LMDB datasets.
  • DGP: Tools for loading DDAD dataset.
  • MiDaS: Pre-trained supervised scaleless depth estimation model.
  • NeWCRFs: Pre-trained supervised metric depth estimation model.

Losses

The main available losses are:

  • ReconstructionLoss: Base view synthesis loss. Additionally used for feature-based view synthesis and autoencoder image reconstruction.
  • RegressionLoss: Proxy depth regression loss. Additionally used for virtual stereo consistency.

NOTE: Each of these incorporates multiple different contributions based on the available input configuration. Check out the respective documentation for additional details.

New losses should be added as per the instructions in the registry. Losses must return a tuple consisting of:

"""
:return (tuple) (
    loss: (Tensor) (,) Scalar loss value.
    loss_dict: (TensorDict) Dictionary containing intermediate loss outputs used for TensorBoard logging.
)
"""

Networks

The main available networks are:

  • depth: Predicts a dense disparity map from a single image.
  • pose: Predicts the relative pose between two images in axis-angle format.
  • autoencoder: Converts the input image into a compact feature representation, which can be used to reconstruct the image. Used primarily to learn a feature representation complementary to the image-based reconstruction loss.

These networks use any of the pretrained encoders available in timm. New networks and decoders should be added as per the instructions in the registry.

Networks producing dense outputs (depth & autoencoder) additionally require a dense decoder:

Currently, all decoders are required to have roughly the same argument structure. This could probably be improved by using additional **kwargs in the main network initializers.

"""
:param num_ch_enc: (Sequence[int]) List of channels per encoder stage.
:param enc_sc: (Sequence[int]) List of downsampling factor per encoder stage.
:param upsample_mode: (str) Torch upsampling mode. {'nearest', 'bilinear'...}
:param use_skip: (bool) If `True`, add skip connections from corresponding encoder stage.
:param out_sc: (Sequence[int]) List of multi-scale output downsampling factor as 2**s.
:param out_ch: (int) Number of output channels.
:param out_act: (str) Activation to apply to each output stage.
"""

Regularizers

Regularizers are meant to prevent suboptimal or degenerate representations, rather than driving the optimization.

New regularizers should be added as per the instructions in the registry. They must also follow the output format required by the losses.


Tools

A collection of more advanced utilities that depend only on each other or on utils.

  • geometry: Depth scaling/conversion and view synthesis tools, such as extract_edges, to_scaled, to_inv, T_from_AAt, ViewSynth...
  • ops: Collection of PyTorch operations, such as to_torch, to_numpy, allow_np , interpolate_like, expand_dim...
  • parsers: Tools for instantiating classes from config dicts.
  • table_formatter: TableFormatter to convert dataframes into LaTeX/MarkDown tables.
  • viz: Visualization tools rgb_from_disp & rgb_from_feat.

Utils

A collection of basic utilities that do not depend on any other custom code from this library.

  • callbacks: Custom PyTorch Lightning callbacks, including progress bars and anomaly detection.
  • collate: default_collate from PyTorch, modified to accept MultiLevelTimer.
  • deco: Custom decorators, including opt_args_deco, delegates, map_container & retry_new_on_error.
  • io: YAML loading/writing tools and image conversion.
  • loader: ConcatDataLoader that implements a round-robin loading for multiple datasets.
  • metrics: PyTorch Lightning metrics for use during training.
  • misc: Collection of random utilities, flatten_dict, sort_dict, get_logger & apply_cmap.
  • timers: MultiLevelTimer to allow for nested timing blocks.

Paths

Path management for datasets and storing/loading checkpoints is done based on predefined locations in DATA_ROOTS & MODEL_ROOTS. This alleviates the need to provide long and repeated paths, while remaining flexible to datasets being stored in different locations (e.g. local scratch spaces). Instructions for setting up custom roots can be found in the main README.

This file additionally provides some utilities for finding dataset & model paths within the available roots: find_data_dir & find_model_file. These functions will return the input path if it is an absolute path to an existing file/directory. Otherwise, they will search the available roots and return the first existing path.
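
In pseudocode, the described resolution logic amounts to something like this simplified sketch (directory variant shown; the real helpers may differ in detail):

from pathlib import Path

def find_dir_sketch(path, roots):
    """Sketch of the lookup: an existing absolute path wins, else search the roots."""
    path = Path(path)
    if path.is_absolute() and path.exists():
        return path
    for root in roots:
        candidate = Path(root) / path
        if candidate.exists():
            return candidate
    raise FileNotFoundError(path)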


Registry

New networks, losses or datasets should be added to the registry via the register decorator. This makes these classes accessible to the parsers, and in turn to the config files.

import torch.nn as nn

from src import register

@register(name='awesome', type='loss')
class MyAwesomeLoss(nn.Module):
    def forward(self, pred, target):
        # Per-pixel error, kept as an intermediate output for TensorBoard logging.
        err = (pred - target).abs().mean(dim=1, keepdim=True)
        loss = err.mean()
        # Losses return (scalar loss, dict of intermediate outputs).
        return loss, {'l1_error': err}

  • type selects the relevant registry, but can typically be omitted and guessed from the class name.
  • name represents the identifier used in the configs and module forward pass. Multiple aliases can be registered by providing a tuple, useful when losses share the same underlying computations but require different inputs or preprocessing in MonoDepthModule. An example is the base ReconstructionLoss, which can be used with either images or dense feature maps.
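
Once registered, the loss can be referenced from a config by its name, following the loss conventions described above:

loss:
  awesome:
    weight: 0.1  # required: scaling factor in the total loss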