A PyTorch implementation of U-Net using a DenseNet-121 backbone for the encoding and decoding paths.
The DenseNet blocks are based on the implementation available in torchvision.
The input is restricted to RGB images and has shape `(N, 3, H, W)`. The output has shape `(N, C, H, W)`, where `C` is the number of output classes.
If the `downsample` option is set to `False`, the stride in `conv0` is set to 1 and `pool0` is removed.
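As an illustration, here is a minimal sketch comparing the two settings on a dummy batch. It is only a sketch: it assumes the import path shown in the usage example below and that `pretrained_encoder_uri` can be omitted when no pretrained weights are wanted; the 64x64 input size is an arbitrary choice.

```python
import torch

from dense_unet import DenseUNet

# Dummy RGB batch of shape (N, 3, H, W); 64x64 is an arbitrary size that
# the encoder can downsample evenly.
x = torch.randn(1, 3, 64, 64)

# downsample=False: conv0 uses stride 1 and pool0 is removed, so the
# encoder keeps the full input resolution in its first stage.
full_res = DenseUNet(3, downsample=False).eval()

# downsample=True: the stock DenseNet-121 stem (strided conv0 + pool0) is kept.
strided = DenseUNet(3, downsample=True).eval()

with torch.no_grad():
    print(full_res(x).shape)
    print(strided(x).shape)
```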
Optionally, a pretrained model can be used to initialize the encoder.
Requirements:

- pytorch
- torchvision
Usage:

from dense_unet import DenseUNet
# torchvision's pretrained DenseNet-121 weights, used to initialize the encoder
pretrained_encoder_uri = 'https://download.pytorch.org/models/densenet121-a639ec97.pth'
#
# for a local file use
#
# from pathlib import Path
# pretrained_encoder_uri = Path('/path/to/local/model.pth').resolve().as_uri()
#
num_output_classes = 3

# Construct the model; downsample=True keeps the stock DenseNet-121 stem
# (strided conv0 followed by pool0), and the encoder weights are loaded
# from pretrained_encoder_uri.
model = DenseUNet(num_output_classes, downsample=True, pretrained_encoder_uri=pretrained_encoder_uri)
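As a quick smoke test, the constructed model can be run on a dummy batch. This is a sketch, not part of the package: it assumes the forward pass takes a single image tensor, and the 224x224 spatial size is an arbitrary choice (matching the resolution the pretrained DenseNet-121 weights were trained at).

```python
import torch

# Dummy batch of two RGB images, shape (N, 3, H, W).
batch = torch.randn(2, 3, 224, 224)

model.eval()
with torch.no_grad():
    logits = model(batch)

# Expect one output channel per class (num_output_classes = 3 above).
print(logits.shape)
```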