diff --git a/requirements.txt b/requirements.txt index 6d14a1f..6d62e21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,11 @@ -torch lightning -git+https://github.com/discovery-unicamp/hiaac-librep.git@0.0.4-dev -scipy -plotly numpy pandas +plotly PyYAML +scipy statsmodels -jsonargparse[all] +tifffile +torch zarr -rich \ No newline at end of file +torchmetrics \ No newline at end of file diff --git a/sslt/models/nets/__init__.py b/sslt/models/nets/__init__.py index e69de29..dd6875f 100644 --- a/sslt/models/nets/__init__.py +++ b/sslt/models/nets/__init__.py @@ -0,0 +1 @@ +from .setr import _SetR_PUP diff --git a/sslt/models/nets/base.py b/sslt/models/nets/base.py index 89dc4bc..9436c9c 100644 --- a/sslt/models/nets/base.py +++ b/sslt/models/nets/base.py @@ -1,6 +1,7 @@ from typing import Dict -import torch + import lightning as L +import torch class SimpleSupervisedModel(L.LightningModule): @@ -18,7 +19,7 @@ class SimpleSupervisedModel(L.LightningModule): easier to implement new models by only changing the backbone model. More complex models, that does not follow this pipeline, should not inherit from this class. - + Note that, for this class the input data is a tuple of tensors, where the first tensor is the input data and the second tensor is the mask or label. """ @@ -38,7 +39,7 @@ def __init__( backbone : torch.nn.Module The backbone model. Usually the encoder/decoder part of the model. fc : torch.nn.Module - The fully connected model, usually used to classification tasks. + The fully connected model, usually used to classification tasks. Use `torch.nn.Identity()` if no FC model is needed. loss_fn : torch.nn.Module The function used to compute the loss. diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py new file mode 100644 index 0000000..44fe6f5 --- /dev/null +++ b/sslt/models/nets/setr.py @@ -0,0 +1,537 @@ +import warnings +from typing import Optional, Tuple + +import lightning as L +import torch +from torch import nn + +from sslt.models.nets.vit import _VisionTransformerBackbone +from sslt.utils.upsample import Upsample, resize + + +class _SETRUPHead(nn.Module): + """Naive upsampling head and Progressive upsampling head of SETR. + + Naive or PUP head of `SETR `_. + + """ + + def __init__( + self, + channels: int, + in_channels: int, + num_classes: int, + norm_layer: nn.Module, + conv_norm: nn.Module, + conv_act: nn.Module, + num_convs: int, + up_scale: int, + kernel_size: int, + align_corners: bool, + dropout: float, + interpolate_mode: str, + ): + """ + Initializes the SETR model. + + Parameters + ---------- + channels : int + Number of output channels. + in_channels : int + Number of input channels. + num_classes : int + Number of output classes. + norm_layer : nn.Module + Normalization layer. + conv_norm : nn.Module + Convolutional normalization layer. + conv_act : nn.Module + Convolutional activation layer. + num_convs : int + Number of convolutional layers. + up_scale : int + Upsampling scale factor. + kernel_size : int + Kernel size for convolutional layers. + align_corners : bool + Whether to align corners during upsampling. + dropout : float + Dropout rate. + interpolate_mode : str + Interpolation mode for upsampling. + + Raises + ------ + AssertionError + If kernel_size is not 1 or 3. + """ + assert kernel_size in [1, 3], "kernel_size must be 1 or 3." 
+ + super().__init__() + + self.num_classes = num_classes + self.out_channels = channels + self.cls_seg = nn.Conv2d(channels, self.num_classes, 1) + self.norm = norm_layer + conv_norm = conv_norm + conv_act = conv_act + self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None + + self.up_convs = nn.ModuleList() + + for _ in range(num_convs): + self.up_convs.append( + nn.Sequential( + nn.Conv2d( + in_channels, + self.out_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, + ), + conv_norm, + conv_act, + Upsample( + scale_factor=up_scale, + mode=interpolate_mode, + align_corners=align_corners, + ), + ) + ) + in_channels = self.out_channels + + def forward(self, x): + n, c, h, w = x.shape + + x = x.reshape(n, c, h * w).transpose(1, 2).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + + for up_conv in self.up_convs: + x = up_conv(x) + + if self.dropout is not None: + x = self.dropout(x) + out = self.cls_seg(x) + + return out + + +class _SETRMLAHead(nn.Module): + """Multi level feature aggretation head of SETR. + + MLA head of `SETR `_. + """ + + def __init__( + self, + channels: int, + conv_norm: Optional[nn.Module], + conv_act: Optional[nn.Module], + in_channels: list[int], + out_channels: int, + num_classes: int, + mla_channels: int = 128, + up_scale: int = 4, + kernel_size: int = 3, + align_corners: bool = True, + dropout: float = 0.1, + threshold: Optional[float] = None, + ): + super().__init__() + + if out_channels is None: + if num_classes == 2: + warnings.warn( + "For binary segmentation, we suggest using" + "`out_channels = 1` to define the output" + "channels of segmentor, and use `threshold`" + "to convert `seg_logits` into a prediction" + "applying a threshold" + ) + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + "out_channels should be equal to num_classes," + "except binary segmentation set out_channels == 1 and" + f"num_classes == 2, but got out_channels={out_channels}" + f"and num_classes={num_classes}" + ) + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn("threshold is not defined for binary, and defaults to 0.3") + + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + conv_norm = ( + conv_norm if conv_norm is not None else nn.SyncBatchNorm(mla_channels) + ) + conv_act = conv_act if conv_act is not None else nn.ReLU() + self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None + self.cls_seg = nn.Conv2d(channels, out_channels, 1) + + num_inputs = len(in_channels) + + self.up_convs = nn.ModuleList() + for i in range(num_inputs): + self.up_convs.append( + nn.Sequential( + nn.Conv2d( + in_channels[i], + mla_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, + ), + conv_norm, + conv_act, + nn.Conv2d( + mla_channels, + mla_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, + ), + conv_norm, + conv_act, + Upsample( + scale_factor=up_scale, + mode="bilinear", + align_corners=align_corners, + ), + ) + ) + + def forward(self, x): + outs = [] + for x, up_conv in zip(x, self.up_convs): + outs.append(up_conv(x)) + out = torch.cat(outs, dim=1) + if self.dropout is not None: + out = self.dropout(out) + out = self.cls_seg(out) + return out + + +class _SetR_PUP(nn.Module): + + def __init__( + self, + image_size: int | tuple[int, int], + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + num_convs: int, + 
num_classes: int, + decoder_channels: int, + up_scale: int, + encoder_dropout: float, + kernel_size: int, + decoder_dropout: float, + norm_layer: nn.Module, + interpolate_mode: str, + conv_norm: nn.Module, + conv_act: nn.Module, + align_corners: bool, + ): + """ + Initializes the SETR PUP model. + + Parameters + ---------- + image_size : int or tuple[int, int] + The size of the input image. + patch_size : int + The size of each patch in the input image. + num_layers : int + The number of layers in the transformer encoder. + num_heads : int + The number of attention heads in the transformer encoder. + hidden_dim : int + The hidden dimension of the transformer encoder. + mlp_dim : int + The dimension of the feed-forward network in the transformer encoder. + num_convs : int + The number of convolutional layers in the decoder. + num_classes : int + The number of output classes. + decoder_channels : int + The number of channels in the decoder. + up_scale : int + The scale factor for upsampling in the decoder. + encoder_dropout : float + The dropout rate for the transformer encoder. + kernel_size : int + The kernel size for the convolutional layers in the decoder. + decoder_dropout : float + The dropout rate for the decoder. + norm_layer : nn.Module + The normalization layer to be used. + interpolate_mode : str + The mode for interpolation during upsampling. + conv_norm : nn.Module + The normalization layer to be used in the decoder convolutional layers. + conv_act : nn.Module + The activation function to be used in the decoder convolutional layers. + align_corners : bool + Whether to align corners during upsampling. + + """ + super().__init__() + self.encoder = _VisionTransformerBackbone( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + num_classes=num_classes, + dropout=encoder_dropout, + ) + + self.decoder = _SETRUPHead( + channels=decoder_channels, + in_channels=hidden_dim, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + align_corners=align_corners, + dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, + norm_layer=norm_layer, + ) + + self.aux_head1 = _SETRUPHead( + channels=decoder_channels, + in_channels=hidden_dim, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + align_corners=align_corners, + dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, + norm_layer=norm_layer, + ) + + self.aux_head2 = _SETRUPHead( + channels=decoder_channels, + in_channels=hidden_dim, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + align_corners=align_corners, + dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, + norm_layer=norm_layer, + ) + + self.aux_head3 = _SETRUPHead( + channels=decoder_channels, + in_channels=hidden_dim, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + align_corners=align_corners, + dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, + norm_layer=norm_layer, + ) + + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x = self.encoder(x) + # x_aux1 = self.aux_head1(x) + # x_aux2 = self.aux_head2(x) + # x_aux3 = 
self.aux_head3(x) + x = self.decoder(x) + return x, torch.zeros(1), torch.zeros(1), torch.zeros(1) + + +class SETR_PUP(L.LightningModule): + + def __init__( + self, + image_size: int | tuple[int, int] = 512, + patch_size: int = 16, + num_layers: int = 24, + num_heads: int = 16, + hidden_dim: int = 1024, + mlp_dim: int = 4096, + encoder_dropout: float = 0.1, + num_classes: int = 1000, + norm_layer: Optional[nn.Module] = None, + decoder_channels: int = 256, + num_convs: int = 4, + up_scale: int = 2, + kernel_size: int = 3, + align_corners: bool = False, + decoder_dropout: float = 0.1, + conv_norm: Optional[nn.Module] = None, + conv_act: Optional[nn.Module] = None, + interpolate_mode: str = "bilinear", + loss_fn: Optional[nn.Module] = None, + ): + """ + Initializes the SetR model. + + Parameters + ---------- + image_size : int or tuple[int, int] + The input image size. Defaults to 512. + patch_size : int + The size of each patch. Defaults to 16. + num_layers : int + The number of layers in the transformer encoder. Defaults to 24. + num_heads : int + The number of attention heads in the transformer encoder. Defaults to 16. + hidden_dim : int + The hidden dimension of the transformer encoder. Defaults to 1024. + mlp_dim : int + The dimension of the MLP layers in the transformer encoder. Defaults to 4096. + encoder_dropout : float + The dropout rate for the transformer encoder. Defaults to 0.1. + num_classes : int + The number of output classes. Defaults to 1000. + norm_layer : nn.Module, optional + The normalization layer to be used in the decoder. Defaults to None. + decoder_channels : int + The number of channels in the decoder. Defaults to 256. + num_convs : int + The number of convolutional layers in the decoder. Defaults to 4. + up_scale : int + The scale factor for upsampling in the decoder. Defaults to 2. + kernel_size : int + The kernel size for convolutional layers in the decoder. Defaults to 3. + align_corners : bool + Whether to align corners during interpolation in the decoder. Defaults to False. + decoder_dropout : float + The dropout rate for the decoder. Defaults to 0.1. + conv_norm : nn.Module, optional + The normalization layer to be used in the convolutional layers of the decoder. Defaults to None. + conv_act : nn.Module, optional + The activation function to be used in the convolutional layers of the decoder. Defaults to None. + interpolate_mode : str + The interpolation mode for upsampling in the decoder. Defaults to "bilinear". + loss_fn : nn.Module, optional + The loss function to be used during training. Defaults to None. 
+ + """ + super().__init__() + self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss() + norm_layer = norm_layer if norm_layer is not None else nn.LayerNorm(hidden_dim) + conv_norm = ( + conv_norm if conv_norm is not None else nn.SyncBatchNorm(decoder_channels) + ) + conv_act = conv_act if conv_act is not None else nn.ReLU() + self.num_classes = num_classes + + self.model = _SetR_PUP( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + conv_norm=conv_norm, + conv_act=conv_act, + decoder_channels=decoder_channels, + encoder_dropout=encoder_dropout, + decoder_dropout=decoder_dropout, + norm_layer=norm_layer, + interpolate_mode=interpolate_mode, + align_corners=align_corners, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Calculate the loss between the output and the input data. + + Parameters + ---------- + y_hat : torch.Tensor + The output data from the forward pass. + y : torch.Tensor + The input data/label. + + Returns + ------- + torch.Tensor + The loss value. + """ + loss = self.loss_fn(y_hat, y) + return loss + + def _single_step( + self, batch: torch.Tensor, batch_idx: int, step_name: str + ) -> torch.Tensor: + """Perform a single step of the training/validation loop. + + Parameters + ---------- + batch : torch.Tensor + The input data. + batch_idx : int + The index of the batch. + step_name : str + The name of the step, either "train" or "val". + + Returns + ------- + torch.Tensor + The loss value. + """ + x, y = batch + y_hat = self.model(x) + loss = self._loss_func(y_hat[0], y.squeeze(1)) + self.log( + f"{step_name}_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def training_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "train") + + def validation_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "val") + + def test_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "test") + + def predict_step( + self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None + ): + x, _ = batch + return self.model(x) + + def configure_optimizers(self): + return torch.optim.Adam(self.model.parameters(), lr=1e-3) diff --git a/sslt/models/nets/unet.py b/sslt/models/nets/unet.py index a7803eb..ee5decf 100644 --- a/sslt/models/nets/unet.py +++ b/sslt/models/nets/unet.py @@ -1,18 +1,17 @@ """ Full assembly of the parts to form the complete network """ -from typing import Dict -import lightning as L -import torch.optim as optim -from torch.optim.lr_scheduler import CyclicLR -from torch.optim.lr_scheduler import StepLR import time +from typing import Dict, Optional + +import lightning as L import torch import torch.nn as nn import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import CyclicLR, StepLR from sslt.models.nets.base import SimpleSupervisedModel - """ -------------- Parts of the U-Net model --------------""" @@ -44,15 +43,11 @@ def __init__(self, in_channels, out_channels, mid_channels=None): nn.Conv2d( in_channels, mid_channels, kernel_size=3, padding=1, bias=False ), # no need to add bias since BatchNorm2d will do that - 
nn.BatchNorm2d( - mid_channels - ), # normalize the output of the previous layer + nn.BatchNorm2d(mid_channels), # normalize the output of the previous layer nn.ReLU( inplace=True ), # inplace=True will modify the input directly instead of allocating new memory - nn.Conv2d( - mid_channels, out_channels, kernel_size=3, padding=1, bias=False - ), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) @@ -82,9 +77,7 @@ def __init__(self, in_channels, out_channels, bilinear=True): # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: - self.up = nn.Upsample( - scale_factor=2, mode="bilinear", align_corners=True - ) + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) self.conv = _DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d( @@ -99,9 +92,7 @@ def forward(self, x1, x2): diffX = x2.size()[3] - x1.size()[3] # pad the input tensor on all sides with the given "pad" value - x1 = F.pad( - x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2] - ) + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd @@ -209,7 +200,7 @@ def __init__( n_channels: int = 1, bilinear: bool = False, learning_rate: float = 1e-3, - loss_fn: torch.nn.Module = None, + loss_fn: Optional[torch.nn.Module] = None, ): """Wrapper implementation of the U-Net model. @@ -231,5 +222,5 @@ def __init__( fc=torch.nn.Identity(), loss_fn=loss_fn or torch.nn.MSELoss(), learning_rate=learning_rate, - flatten=False + flatten=False, ) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py new file mode 100644 index 0000000..bb613e4 --- /dev/null +++ b/sslt/models/nets/vit.py @@ -0,0 +1,237 @@ +import math +from collections import OrderedDict +from functools import partial +from typing import Callable, List, Optional + +import torch +from torch import nn +from torchvision.models.vision_transformer import ( + Conv2dNormActivation, + ConvStemConfig, + Encoder, + _log_api_usage_once, +) + + +class _VisionTransformerBackbone(nn.Module): + """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" + + def __init__( + self, + image_size: int | tuple[int, int], + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0.0, + attention_dropout: float = 0.0, + num_classes: int = 1000, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): + """ + Initializes a Vision Transformer (ViT) model. + + Parameters + ---------- + image_size : int or tuple[int, int] + The size of the input image. If an int is provided, it is assumed + to be a square image. If a tuple of ints is provided, it represents the height and width of the image. + patch_size : int + The size of each patch in the image. + num_layers : int + The number of transformer layers in the model. + num_heads : int + The number of attention heads in the transformer layers. + hidden_dim : int + The dimensionality of the hidden layers in the transformer. + mlp_dim : int + The dimensionality of the feed-forward MLP layers in the transformer. 
+ dropout : float, optional + The dropout rate to apply. Defaults to 0.0. + attention_dropout : float, optional + The dropout rate to apply to the attention weights. Defaults to 0.0. + num_classes : int, optional + The number of output classes. Defaults to 1000. + norm_layer : Callable[..., torch.nn.Module], optional + The normalization layer to use. Defaults to nn.LayerNorm with epsilon=1e-6. + conv_stem_configs : List[ConvStemConfig], optional + The configuration for the convolutional stem layers. + If provided, the input image will be processed by these convolutional layers before being passed to + the transformer. Defaults to None. + + """ + super().__init__() + _log_api_usage_once(self) + + if isinstance(image_size, int): + torch._assert( + image_size % patch_size == 0, "Input shape indivisible by patch size!" + ) + elif isinstance(image_size, tuple): + torch._assert( + image_size[0] % patch_size == 0 and image_size[1] % patch_size == 0, + "Input shape indivisible by patch size!", + ) + + self.image_size = image_size + self.patch_size = patch_size + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.attention_dropout = attention_dropout + self.dropout = dropout + self.num_classes = num_classes + self.norm_layer = norm_layer + + if conv_stem_configs is not None: + # As per https://arxiv.org/abs/2106.14881 + seq_proj = nn.Sequential() + prev_channels = 3 + for i, conv_stem_layer_config in enumerate(conv_stem_configs): + seq_proj.add_module( + f"conv_bn_relu_{i}", + Conv2dNormActivation( + in_channels=prev_channels, + out_channels=conv_stem_layer_config.out_channels, + kernel_size=conv_stem_layer_config.kernel_size, + stride=conv_stem_layer_config.stride, + norm_layer=conv_stem_layer_config.norm_layer, + activation_layer=conv_stem_layer_config.activation_layer, + ), + ) + prev_channels = conv_stem_layer_config.out_channels + seq_proj.add_module( + "conv_last", + nn.Conv2d( + in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1 + ), + ) + self.conv_proj: nn.Module = seq_proj + else: + self.conv_proj = nn.Conv2d( + in_channels=3, + out_channels=hidden_dim, + kernel_size=patch_size, + stride=patch_size, + ) + + if isinstance(image_size, int): + seq_length = (image_size // patch_size) ** 2 + elif isinstance(image_size, tuple): + seq_length = (image_size[0] // patch_size) * (image_size[1] // patch_size) + + # Add a class token + self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + seq_length += 1 + + self.encoder = Encoder( + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + ) + self.seq_length = seq_length + + if isinstance(self.conv_proj, nn.Conv2d): + # Init the patchify stem + fan_in = ( + self.conv_proj.in_channels + * self.conv_proj.kernel_size[0] + * self.conv_proj.kernel_size[1] + ) + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) + if self.conv_proj.bias is not None: + nn.init.zeros_(self.conv_proj.bias) + elif self.conv_proj.conv_last is not None and isinstance( + self.conv_proj.conv_last, nn.Conv2d + ): + # Init the last 1x1 conv of the conv stem + nn.init.normal_( + self.conv_proj.conv_last.weight, + mean=0.0, + std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels), + ) + if self.conv_proj.conv_last.bias is not None: + nn.init.zeros_(self.conv_proj.conv_last.bias) + + def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: + """Process the input 
tensor and return the reshaped tensor and dimensions. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, and number of columns. + """ + n, c, h, w = x.shape + p = self.patch_size + + if isinstance(self.image_size, int): + torch._assert( + h == self.image_size, + f"Wrong image height! Expected {self.image_size} but got {h}!", + ) + torch._assert( + w == self.image_size, + f"Wrong image width! Expected {self.image_size} but got {w}!", + ) + elif isinstance(self.image_size, tuple): + torch._assert( + h == self.image_size[0], + f"Wrong image height! Expected {self.image_size[0]} but got {h}!", + ) + torch._assert( + w == self.image_size[1], + f"Wrong image width! Expected {self.image_size[1]} but got {w}!", + ) + else: + raise ValueError("Invalid image size type!") + + n_h = h // p + n_w = w // p + + # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) + x = self.conv_proj(x) + # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) + x = x.reshape(n, self.hidden_dim, n_h * n_w) + + # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) + # The self attention layer expects inputs in the format (N, S, E) + # where S is the source sequence length, N is the batch size, E is the + # embedding dimension + x = x.permute(0, 2, 1) + + return x, n_h, n_w + + def forward(self, x: torch.Tensor): + """Forward pass of the Vision Transformer Backbone. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + # Reshape and permute the input tensor + x, n_h, n_w = self._process_input(x) + n = x.shape[0] + + # Expand the class token to the full batch + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) + + x = self.encoder(x) + + # Classifier "token" as used by standard language architectures + x = x[:, 1:] + + B, _, C = x.shape + + x = x.reshape(B, n_h, n_w, C).permute(0, 3, 1, 2).contiguous() + + return x diff --git a/sslt/models/nets/wisenet.py b/sslt/models/nets/wisenet.py index 4e8d186..ee4575f 100644 --- a/sslt/models/nets/wisenet.py +++ b/sslt/models/nets/wisenet.py @@ -1,4 +1,7 @@ +from typing import Optional + import torch + from sslt.models.nets.base import SimpleSupervisedModel @@ -98,13 +101,11 @@ def __init__( self, in_channels: int = 1, out_channels: int = 1, - loss_fn: torch.nn.Module = None, + loss_fn: Optional[torch.nn.Module] = None, learning_rate: float = 1e-3, ): super().__init__( - backbone=_WiseNet( - in_channels=in_channels, out_channels=out_channels - ), + backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels), fc=torch.nn.Identity(), loss_fn=loss_fn or torch.nn.MSELoss(), learning_rate=learning_rate, @@ -129,9 +130,8 @@ def _single_step( ) return loss - def predict_step(self, batch, batch_idx, dataloader_idx=None): x, y = batch y_hat = self.forward(x) y_hat = y_hat[:, :, : y.size(2), : y.size(3)] - return y_hat \ No newline at end of file + return y_hat diff --git a/sslt/utils/upsample.py b/sslt/utils/upsample.py new file mode 100644 index 0000000..522bc95 --- /dev/null +++ b/sslt/utils/upsample.py @@ -0,0 +1,54 @@ +import warnings +from typing import Optional, Tuple + +import torch.nn as nn +import torch.nn.functional as F + + +def resize( + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + warning=True, +): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for 
x in size) + if output_h > input_h or output_w > input_w: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would be more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) + + + class Upsample(nn.Module): + + def __init__( + self, size=None, scale_factor=None, mode="nearest", align_corners=None + ): + super().__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/tests/models/nets/test_setr.py b/tests/models/nets/test_setr.py new file mode 100644 index 0000000..3506f3f --- /dev/null +++ b/tests/models/nets/test_setr.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from sslt.models.nets.setr import SETR_PUP + + +def test_setr_pup_loss(): + model = SETR_PUP() + batch_size = 2 + x = torch.rand(2, 3, 512, 512) + mask = torch.rand(2, 1, 512, 512).long() + + # Do the training step + loss = model.training_step((x, mask), 0).item() + assert loss is not None + assert loss >= 0, f"Expected non-negative loss, but got {loss}" + + +def test_setr_pup_predict(): + model = SETR_PUP() + batch_size = 2 + mask_shape = (batch_size, 1000, 512, 512) # (batch_size, num_classes, 512, 512) + x = torch.rand(2, 3, 512, 512) + mask = torch.rand(2, 1, 512, 512).long() + + # Do the prediction step + preds = model.predict_step((x, mask), 0) + assert preds is not None + assert ( + preds[0].shape == mask_shape + ), f"Expected shape {mask_shape}, but got {preds[0].shape}"
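The Upsample module added in sslt/utils/upsample.py computes the target spatial size from scale_factor at call time and delegates to resize(), a thin wrapper around F.interpolate with an optional alignment warning. A minimal sketch of both call paths, assuming the package is importable as sslt:

import torch

from sslt.utils.upsample import Upsample, resize

feat = torch.rand(1, 16, 32, 32)  # (batch, channels, height, width)

# Doubling the spatial resolution, as each stage of the SETR PUP head does
# with its default up_scale=2.
up = Upsample(scale_factor=2, mode="bilinear", align_corners=False)
print(up(feat).shape)  # torch.Size([1, 16, 64, 64])

# resize() can also be called directly with an explicit output size.
out = resize(feat, size=(128, 128), mode="bilinear", align_corners=False)
print(out.shape)  # torch.Size([1, 16, 128, 128])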
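_VisionTransformerBackbone in sslt/models/nets/vit.py differs from torchvision's VisionTransformer in that it drops the class token after encoding and reshapes the patch tokens back into a 2D feature map, which is what the SETR heads consume. A small shape sketch with a deliberately tiny configuration (the values below are illustrative, not the ViT-Large defaults used by SETR_PUP):

import torch

from sslt.models.nets.vit import _VisionTransformerBackbone

backbone = _VisionTransformerBackbone(
    image_size=64,   # 64 x 64 input, divisible by the patch size
    patch_size=16,
    num_layers=1,
    num_heads=2,
    hidden_dim=32,
    mlp_dim=64,
)

x = torch.rand(2, 3, 64, 64)
feats = backbone(x)

# Class token removed, patch tokens reshaped to (batch, hidden_dim, H/16, W/16).
print(feats.shape)  # torch.Size([2, 32, 4, 4])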
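Putting the pieces together, SETR_PUP can be exercised much like the tests in tests/models/nets/test_setr.py above. The configuration below is a scaled-down sketch rather than the ViT-Large defaults, and it swaps the default nn.SyncBatchNorm for nn.BatchNorm2d so the example runs in a single process without a distributed backend; all other arguments follow the constructor signature.

import torch
from torch import nn

from sslt.models.nets.setr import SETR_PUP

model = SETR_PUP(
    image_size=128,
    patch_size=16,
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    mlp_dim=128,
    num_classes=10,
    decoder_channels=32,
    conv_norm=nn.BatchNorm2d(32),  # channel count must match decoder_channels
)

x = torch.rand(2, 3, 128, 128)                 # RGB images
mask = torch.randint(0, 10, (2, 1, 128, 128))  # integer class mask

# forward() returns the main logits plus three placeholder tensors standing
# in for the (currently disabled) auxiliary heads.
logits, *_ = model(x)
print(logits.shape)  # torch.Size([2, 10, 128, 128])

# The patch grid (128 / 16 = 8) is upsampled by 2 in each of the 4 decoder
# stages, so the logits recover the full 128 x 128 input resolution.
loss = model.training_step((x, mask), 0)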
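For an actual training run, the LightningModule can be handed to a Trainer together with any dataloader that yields (image, mask) tuples, since _single_step simply unpacks x, y = batch. A sketch with an in-memory dummy dataset (the tensors here are placeholders, not a real dataset), reusing the scaled-down model from the previous sketch:

import lightning as L
import torch
from torch.utils.data import DataLoader, TensorDataset

images = torch.rand(8, 3, 128, 128)
masks = torch.randint(0, 10, (8, 1, 128, 128))
train_loader = DataLoader(TensorDataset(images, masks), batch_size=2)

# `model` is the SETR_PUP instance constructed above.
trainer = L.Trainer(max_epochs=1, accelerator="cpu", logger=False)
trainer.fit(model, train_loader)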