Skip to content

Commit

Permalink
Merge pull request #15 from GabrielBG0/12-unet
Browse files Browse the repository at this point in the history
12 unet
  • Loading branch information
GabrielBG0 authored Feb 12, 2024
2 parents e2254d1 + 2caf28a commit dfefa1a
Show file tree
Hide file tree
Showing 6 changed files with 379 additions and 0 deletions.
Empty file added sslt/models/nets/__init__.py
Empty file.
136 changes: 136 additions & 0 deletions sslt/models/nets/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Dict
import torch
import lightning as L


class SimpleReconstructionNet(L.LightningModule):
"""Simple autoencoder pipeline for reconstruction tasks
This class implements a very common pipeline for autoencoder models, which
are used to reconstruct the input data. It consists in:
1. Make a forward pass with the input data on the backbone model;
2. Compute the loss between the output and the input data;
3. Optimize the model parameters with respect to the loss.
This reduces the code duplication for autoencoder models, and makes it
easier to implement new models by only changing the backbone model. More
complex models, that does not follow this pipeline, should not inherit from
this class.
Note that this class assumes that input data is a single tensor and not a
tuple of tensors (e.g., data and label).
"""

def __init__(
self,
backbone: torch.nn.Module,
learning_rate: float = 1e-3,
loss_fn: torch.nn.Module = None,
):
"""Simple autoencoder pipeline for reconstruction tasks.
Parameters
----------
backbone : torch.nn.Module
The backbone model that will be used to make the forward pass and
will be optimized with respect to the loss.
learning_rate : float, optional
The learning rate to Adam optimizer, by default 1e-3
loss_fn : torch.nn.Module, optional
The function used to compute the loss. If `None`, it will be used
the MSELoss, by default None.
"""
super().__init__()
self.backbone = backbone
self.learning_rate = learning_rate
self.loss_fn = loss_fn or torch.nn.MSELoss()

def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Calculate the loss between the output and the input data.
Parameters
----------
y_hat : torch.Tensor
The output data from the forward pass.
y : torch.Tensor
The input data/label.
Returns
-------
torch.Tensor
The loss value.
"""
loss = self.loss_fn(y_hat, y)
return loss

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform a forward pass with the input data on the backbone model.
Parameters
----------
x : torch.Tensor
The input data.
Returns
-------
torch.Tensor
The output data from the forward pass.
"""
return self.backbone(x)

def _single_step(
self, batch: torch.Tensor, batch_idx: int, step_name: str
) -> torch.Tensor:
"""Perform a single train/validation/test step. It consists in making a
forward pass with the input data on the backbone model, computing the
loss between the output and the input data, and logging the loss.
Parameters
----------
batch : torch.Tensor
The input data. It must be a single tensor and not a tuple of
tensors (e.g., data and label).
batch_idx : int
The index of the batch.
step_name : str
The name of the step. It will be used to log the loss.
Returns
-------
torch.Tensor
A tensor with the loss value.
"""
x = batch
y_hat = self.forward(x)
loss = self._loss_func(y_hat, x)
self.log(
f"{step_name}_loss",
loss,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss

def training_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, step_name="train")

def validation_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, step_name="val")

def test_step(self, batch: torch.Tensor, batch_idx: int):
return self._single_step(batch, batch_idx, step_name="test")

def predict_step(self, batch, batch_idx, dataloader_idx=None):
x, y = batch
y_hat = self.forward(x)
return y_hat

def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.learning_rate,
)
return optimizer
228 changes: 228 additions & 0 deletions sslt/models/nets/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
""" Full assembly of the parts to form the complete network """

from typing import Dict
import lightning as L
import torch.optim as optim
from torch.optim.lr_scheduler import CyclicLR
from torch.optim.lr_scheduler import StepLR
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from sslt.models.nets.base import SimpleReconstructionNet


""" -------------- Parts of the U-Net model --------------"""
class _DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""

"""
Performs two convolutions with the same number
of input and output channels, followed by batch normalization and ReLU activation
"""

def __init__(self, in_channels, out_channels, mid_channels=None):
"""
Parameters
----------
in_channels : int
Number of input channels, i.e. the number of channels in the input image (1 for grayscale, 3 for RGB)
out_channels : int
Number of output channels, i.e. the number of channels produced by the convolution
mid_channels : int, optional
Number of channels in the middle, by default None
"""
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(
in_channels, mid_channels, kernel_size=3, padding=1, bias=False
), # no need to add bias since BatchNorm2d will do that
nn.BatchNorm2d(
mid_channels
), # normalize the output of the previous layer
nn.ReLU(
inplace=True
), # inplace=True will modify the input directly instead of allocating new memory
nn.Conv2d(
mid_channels, out_channels, kernel_size=3, padding=1, bias=False
),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)

def forward(self, x):
return self.double_conv(x)


class _Down(nn.Module):
"""Downscaling with maxpool then double conv"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2), _DoubleConv(in_channels, out_channels)
)

def forward(self, x):
return self.maxpool_conv(x)


class _Up(nn.Module):
"""Upscaling then double conv"""

def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()

# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(
scale_factor=2, mode="bilinear", align_corners=True
)
self.conv = _DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(
in_channels, in_channels // 2, kernel_size=2, stride=2
)
self.conv = _DoubleConv(in_channels, out_channels)

def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW (channel, height, width)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]

# pad the input tensor on all sides with the given "pad" value
x1 = F.pad(
x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
)
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)


class _OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(_OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

def forward(self, x):
return self.conv(x)


class _UNet(torch.nn.Module):
"""Implementation of U-Net model.
"""
def __init__(
self,
n_channels: int = 1,
bilinear: bool = False,
):
"""Implementation of U-Net model.
Parameters
----------
n_channels : int, optional
Number of input channels, by default 1
bilinear : bool, optional
If `True` use bilinear interpolation for upsampling, by default
False.
"""
super().__init__()
factor = 2 if bilinear else 1

self.n_channels = n_channels
self.bilinear = bilinear

self.inc = _DoubleConv(n_channels, 64)
self.down1 = _Down(64, 128)
self.down2 = _Down(128, 256)
self.down3 = _Down(256, 512)
self.down4 = _Down(512, 1024 // factor)
self.up1 = _Up(1024, 512 // factor, bilinear)
self.up2 = _Up(512, 256 // factor, bilinear)
self.up3 = _Up(256, 128 // factor, bilinear)
self.up4 = _Up(128, 64, bilinear)
# self.outc = (OutConv(64, n_classes))
self.outc = _OutConv(64, 1)

def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits


class UNet(SimpleReconstructionNet):
""" This class is a simple implementation of the U-Net model, which is a
convolutional neural network used for image segmentation. The model consists
of a contracting path (encoder) and an expansive path (decoder). The
contracting path follows the typical architecture of a convolutional neural
network, with repeated applications of convolutions and max pooling layers.
The expansive path consists of up-convolutions and concatenation of feature
maps from the contracting path. The model also has skip connections, which
allows the expansive path to use information from the contracting path at
multiple resolutions. The U-Net model was originally proposed by
Ronneberger, Fischer, and Brox in 2015.
This architecture, handles arbitrary input sizes, and returns an output of
the same size as the input. The expected input size is (N, C, H, W), where N
is the batch size, C is the number of channels, H is the height of the input
image, and W is the width of the input image.
Note that, for this implementation, the input batch is a single tensor and
not a tuple of tensors (e.g., data and label).
Note that this class wrappers the `_UNet` class, which is the actual
implementation of the U-Net model, into a `SimpleReconstructionNet` class,
which is a simple autoencoder pipeline for reconstruction tasks.
References
----------
Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional
networks for biomedical image segmentation." Medical Image Computing and
Computer-Assisted Intervention-MICCAI 2015: 18th International Conference,
Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer
International Publishing, 2015.
"""

def __init__(
self,
n_channels: int = 1,
bilinear: bool = False,
learning_rate: float = 1e-3,
loss_fn: torch.nn.Module = None,
):
"""Wrapper implementation of the U-Net model.
Parameters
----------
n_channels : int, optional
The number of channels of the input, by default 1
bilinear : bool, optional
If `True` use bilinear interpolation for upsampling, by default
False.
learning_rate : float, optional
The learning rate to Adam optimizer, by default 1e-3
loss_fn : torch.nn.Module, optional
The function used to compute the loss. If `None`, it will be used
the MSELoss, by default None.
"""
super().__init__(
backbone=_UNet(n_channels=n_channels, bilinear=bilinear),
learning_rate=learning_rate,
loss_fn=loss_fn,
)
Empty file added tests/__init__.py
Empty file.
Empty file added tests/models/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions tests/models/nets/test_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
from sslt.models.nets.unet import UNet

def test_unet():
# Test the class instantiation
model = UNet()
assert model is not None

# Generate a random input tensor (B, C, H, W)
input_shape = (2, 1, 500, 500)
x = torch.rand(*input_shape)

# Test the forward method
output = model(x)
assert output.shape == input_shape

0 comments on commit dfefa1a

Please sign in to comment.