From a81c50fece5f62bbd7e7f9a6948cc16e0ef662a9 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Thu, 22 Feb 2024 17:46:53 -0300 Subject: [PATCH 01/24] vit --- requirements.txt | 11 +- sslt/models/nets/vit.py | 406 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 411 insertions(+), 6 deletions(-) create mode 100644 sslt/models/nets/vit.py diff --git a/requirements.txt b/requirements.txt index 6d14a1f..6d62e21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,11 @@ -torch lightning -git+https://github.com/discovery-unicamp/hiaac-librep.git@0.0.4-dev -scipy -plotly numpy pandas +plotly PyYAML +scipy statsmodels -jsonargparse[all] +tifffile +torch zarr -rich \ No newline at end of file +torchmetrics \ No newline at end of file diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py new file mode 100644 index 0000000..965d22a --- /dev/null +++ b/sslt/models/nets/vit.py @@ -0,0 +1,406 @@ +import math +from functools import partial + +import torch +import torch.nn as nn + +# This implementation is based and addapted from Fudan Zhang Vision Group SETR implementation. +# You can find the original implementation here: https://github.com/fudan-zvg/SETR/blob/main/mmseg/models/backbones/vit.py#L3 + + +class DropPath(nn.Module): + + def __init__(self, drop_prob: float = 0) -> None: + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.drop_prob == 0.0 or not self.training: + return x + + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.dim() - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + output = x.div(keep_prob) * random_tensor + return output + + +class Mlp(nn.Module): + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop: float = 0.0, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + + def __init__( + self, + dim, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale=None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer: partial[nn.LayerNorm] | nn.LayerNorm = partial(nn.LayerNorm), 
+ ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768) -> None: + super().__init__() + img_size = (img_size, img_size) + patch_size = (patch_size, patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, C, H, W = x.shape + + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class HybridEmbed(nn.Module): + + def __init__( + self, + backbone: nn.Module, + img_size=224, + feature_size=None, + in_chans=3, + embed_dim=768, + ) -> None: + super().__init__() + assert isinstance(backbone, nn.Module), "backbone must be nn.Module" + self.backbone = backbone + self.img_size = (img_size, img_size) + + # FIXME (from original code) this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. + if feature_size is None: + with torch.no_grad(): + training = self.backbone.training + if training: + self.backbone.eval() + o = self.backbone( + torch.zeros(1, in_chans, self.img_size[0], self.img_size[1]) + )[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = (feature_size, feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class VisionTransformer(nn.Module): + """ + Vision Transformer model implementation. + + Parameters + ---------- + img_size: int + Size of the input image. Default is 384. + patch_size: int + Size of the image patch. Default is 16. + in_chans: int + Number of input channels. Default is 3. + embed_dim: int + Dimensionality of the token embeddings. Default is 1024. + depth: int + Number of transformer blocks. Default is 24. + num_heads: int + Number of attention heads. Default is 16. + num_classes: int + Number of output classes. Default is 19. + mlp_ratio: float + Ratio of MLP hidden dimension to embedding dimension. Default is 4.0. + qkv_bias: bool + Whether to include bias in the query, key, and value projections. Default is True. + qk_scale: float + Scale factor for query and key. Default is None. 
+ drop_rate: float + Dropout rate. Default is 0.1. + attn_drop_rate: float + Dropout rate for attention weights. Default is 0.0. + drop_path_rate: float + Dropout rate for stochastic depth. Default is 0.0. + hybrid_backbone: None | nn.Module + Hybrid backbone module. Default is None. + norm_layer: nn.Module + Normalization layer. Default is nn.LayerNorm with eps=1e-6. + norm_cfg: None | dict + Normalization configuration. Default is None. + pos_embed_interp: bool + Whether to interpolate positional embeddings. Default is False. + random_init: bool + Whether to initialize weights randomly. Default is False. + align_corners: bool + Whether to align corners in positional embeddings. Default is False. + + References + ---------- + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale + https://arxiv.org/abs/2010.11929 + + """ + + def __init__( + self, + img_size=384, + patch_size=16, + in_chans=3, + embed_dim=1024, + depth=24, + num_heads=16, + num_classes=19, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.1, + attn_drop_rate=0.0, + drop_path_rate=0.0, + hybrid_backbone=None, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + norm_cfg=None, + pos_embed_interp=False, + random_init=False, + align_corners=False, + **kwargs, + ) -> None: + super(VisionTransformer, self).__init__(**kwargs) + self.img_size = img_size + self.patch_size = patch_size + self.in_chans = in_chans + self.embed_dim = embed_dim + self.depth = depth + self.num_heads = num_heads + self.num_classes = num_classes + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.qk_scale = qk_scale + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.hybrid_backbone = hybrid_backbone + self.norm_layer = norm_layer + self.norm_cfg = norm_cfg + self.pos_embed_interp = pos_embed_interp + self.random_init = random_init + self.align_corners = align_corners + + self.num_stages = self.depth + self.out_indices = tuple(range(self.num_stages)) + + if self.hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + self.hybrid_backbone, + img_size=self.img_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + ) + else: + self.patch_embed = PatchEmbed( + img_size=self.img_size, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + ) + self.num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, self.embed_dim) + ) + self.pos_drop = nn.Dropout(p=self.drop_rate) + + dpr = [ + x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + Block( + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[i], + norm_layer=self.norm_layer, + ) + for i in range(self.depth) + ] + ) + + # NOTE (from original code) as per official impl, we could have a pre-logits representation dense layer + tanh here + # self.repr = nn.Linear(embed_dim, representation_size) + # self.repr_act = nn.Tanh() + + nn.init.trunc_normal_(self.pos_embed, std=0.02) + nn.init.trunc_normal_(self.cls_token, std=0.02) + + def init_weights(self, pretrained=None) -> None: + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and 
m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if not self.random_init: + raise NotImplementedError("Pretrained model is not supported yet") + else: + print("Initialize weight randomly") + + def _conv_filter(self, state_dict, patch_size=16) -> dict: + """convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if "patch_embed.proj.weight" in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + def to_2D(self, x: torch.Tensor) -> torch.Tensor: + n, hw, c = x.shape + h = w = int(math.sqrt(hw)) + x = x.transpose(1, 2).reshape(n, c, h, w) + return x + + def to_1D(self, x: torch.Tensor) -> torch.Tensor: + n, c, h, w = x.shape + x = x.reshape(n, c, -1).transpose(1, 2) + return x + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: + B = x.shape[0] + x = self.patch_embed(x) + + x = x.flatten(2).transpose(1, 2) + + # originaly credited to Phil Wang + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) From 1395c6cd05757c8a72ece5f852619d81162fd6f3 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Fri, 23 Feb 2024 13:51:56 -0300 Subject: [PATCH 02/24] MLAHead in progress --- sslt/models/nets/vit.py | 103 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index 965d22a..a304b90 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -404,3 +404,106 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: if i in self.out_indices: outs.append(x) return tuple(outs) + + +class MLAHead(nn.Module): + + def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None): + super(MLAHead, self).__init__() + self.head2 = nn.Sequential( + nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + ) + self.head3 = nn.Sequential( + nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + ) + self.head4 = nn.Sequential( + nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + ) + self.head5 = nn.Sequential( + nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), + build_norm_layer(norm_cfg, mlahead_channels)[1], + nn.ReLU(), + ) + + +""" +TODO figure wtf is going on here +def build_norm_layer( + cfg: Dict, num_features: int, postfix: Union[int, str] = "" +) -> Tuple[str, nn.Module]: + Build normalization layer. 
+ + Args: + cfg (dict): The norm layer config, which should contain: + + - type (str): Layer type. + - layer args: Args needed to instantiate a norm layer. + - requires_grad (bool, optional): Whether stop gradient updates. + num_features (int): Number of input channels. + postfix (int | str): The postfix to be appended into norm abbreviation + to create named layer. + + Returns: + tuple[str, nn.Module]: The first element is the layer name consisting + of abbreviation and postfix, e.g., bn1, gn. The second element is the + created norm layer. + + if not isinstance(cfg, dict): + raise TypeError("cfg must be a dict") + if "type" not in cfg: + raise KeyError('the cfg dict must contain the key "type"') + cfg_ = cfg.copy() + + layer_type = cfg_.pop("type") + + if inspect.isclass(layer_type): + norm_layer = layer_type + else: + # Switch registry to the target scope. If `norm_layer` cannot be found + # in the registry, fallback to search `norm_layer` in the + # mmengine.MODELS. + with MODELS.switch_scope_and_registry(None) as registry: + norm_layer = registry.get(layer_type) + if norm_layer is None: + raise KeyError( + f"Cannot find {norm_layer} in registry under " + f"scope name {registry.scope}" + ) + abbr = infer_abbr(norm_layer) + + assert isinstance(postfix, (int, str)) + name = abbr + str(postfix) + + requires_grad = cfg_.pop("requires_grad", True) + cfg_.setdefault("eps", 1e-5) + if norm_layer is not nn.GroupNorm: + layer = norm_layer(num_features, **cfg_) + if layer_type == "SyncBN" and hasattr(layer, "_specify_ddp_gpu_num"): + layer._specify_ddp_gpu_num(1) + else: + assert "num_groups" in cfg_ + layer = norm_layer(num_channels=num_features, **cfg_) + + for param in layer.parameters(): + param.requires_grad = requires_grad + + return name, layer +""" From 742cde17213bd0eac10ebda4386f64b7d36174b1 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Mon, 26 Feb 2024 15:18:11 -0300 Subject: [PATCH 03/24] Add SyncBatchNorm and interpolate in MLAHead --- sslt/models/nets/vit.py | 113 +++++++++++++++------------------------- 1 file changed, 41 insertions(+), 72 deletions(-) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index a304b90..dd298b9 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F # This implementation is based and addapted from Fudan Zhang Vision Group SETR implementation. 
# You can find the original implementation here: https://github.com/fudan-zvg/SETR/blob/main/mmseg/models/backbones/vit.py#L3 @@ -408,102 +409,70 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: class MLAHead(nn.Module): + def build_norm_layer(self, mlahead_channels): + layer = nn.SyncBatchNorm(mlahead_channels, eps=1e-5) + for param in layer.parameters(): + param.requires_grad = True + return layer + def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None): super(MLAHead, self).__init__() self.head2 = nn.Sequential( nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), ) self.head3 = nn.Sequential( nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), ) self.head4 = nn.Sequential( nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), ) self.head5 = nn.Sequential( nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - build_norm_layer(norm_cfg, mlahead_channels)[1], + self.build_norm_layer(mlahead_channels), nn.ReLU(), ) - -""" -TODO figure wtf is going on here -def build_norm_layer( - cfg: Dict, num_features: int, postfix: Union[int, str] = "" -) -> Tuple[str, nn.Module]: - Build normalization layer. - - Args: - cfg (dict): The norm layer config, which should contain: - - - type (str): Layer type. - - layer args: Args needed to instantiate a norm layer. - - requires_grad (bool, optional): Whether stop gradient updates. - num_features (int): Number of input channels. - postfix (int | str): The postfix to be appended into norm abbreviation - to create named layer. - - Returns: - tuple[str, nn.Module]: The first element is the layer name consisting - of abbreviation and postfix, e.g., bn1, gn. The second element is the - created norm layer. - - if not isinstance(cfg, dict): - raise TypeError("cfg must be a dict") - if "type" not in cfg: - raise KeyError('the cfg dict must contain the key "type"') - cfg_ = cfg.copy() - - layer_type = cfg_.pop("type") - - if inspect.isclass(layer_type): - norm_layer = layer_type - else: - # Switch registry to the target scope. If `norm_layer` cannot be found - # in the registry, fallback to search `norm_layer` in the - # mmengine.MODELS. 
- with MODELS.switch_scope_and_registry(None) as registry: - norm_layer = registry.get(layer_type) - if norm_layer is None: - raise KeyError( - f"Cannot find {norm_layer} in registry under " - f"scope name {registry.scope}" - ) - abbr = infer_abbr(norm_layer) - - assert isinstance(postfix, (int, str)) - name = abbr + str(postfix) - - requires_grad = cfg_.pop("requires_grad", True) - cfg_.setdefault("eps", 1e-5) - if norm_layer is not nn.GroupNorm: - layer = norm_layer(num_features, **cfg_) - if layer_type == "SyncBN" and hasattr(layer, "_specify_ddp_gpu_num"): - layer._specify_ddp_gpu_num(1) - else: - assert "num_groups" in cfg_ - layer = norm_layer(num_channels=num_features, **cfg_) - - for param in layer.parameters(): - param.requires_grad = requires_grad - - return name, layer -""" + def forward(self, x2, x3, x4, x5): + x2 = F.interpolate( + self.head2(x2), + 4 * x2.shape[-1], + mode="bilinear", + align_corners=True, + ) + x3 = F.interpolate( + self.head3(x3), + 8 * x3.shape[-1], + mode="bilinear", + align_corners=True, + ) + x4 = F.interpolate( + self.head4(x4), + 16 * x4.shape[-1], + mode="bilinear", + align_corners=True, + ) + x5 = F.interpolate( + self.head5(x5), + 32 * x5.shape[-1], + mode="bilinear", + align_corners=True, + ) + return torch.cat([x2, x3, x4, x5], dim=1) From 09a70df95a31e792dfa1d8f4de3e4c4569cb2801 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 27 Feb 2024 12:33:06 -0300 Subject: [PATCH 04/24] aux heads --- sslt/models/nets/base.py | 230 ++++++++++++++++++++++++++++++++++++++- sslt/models/nets/vit.py | 47 +++++++- 2 files changed, 270 insertions(+), 7 deletions(-) diff --git a/sslt/models/nets/base.py b/sslt/models/nets/base.py index 89dc4bc..9153b49 100644 --- a/sslt/models/nets/base.py +++ b/sslt/models/nets/base.py @@ -1,6 +1,11 @@ -from typing import Dict -import torch +import warnings +from typing import Dict, Iterable + import lightning as L +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchmetrics import Accuracy class SimpleSupervisedModel(L.LightningModule): @@ -18,7 +23,7 @@ class SimpleSupervisedModel(L.LightningModule): easier to implement new models by only changing the backbone model. More complex models, that does not follow this pipeline, should not inherit from this class. - + Note that, for this class the input data is a tuple of tensors, where the first tensor is the input data and the second tensor is the mask or label. """ @@ -38,7 +43,7 @@ def __init__( backbone : torch.nn.Module The backbone model. Usually the encoder/decoder part of the model. fc : torch.nn.Module - The fully connected model, usually used to classification tasks. + The fully connected model, usually used to classification tasks. Use `torch.nn.Identity()` if no FC model is needed. loss_fn : torch.nn.Module The function used to compute the loss. @@ -149,3 +154,220 @@ def configure_optimizers(self): lr=self.learning_rate, ) return optimizer + + +class BaseDecodeHead(nn.Module): + """Base class for BaseDecodeHead. + + Parameters + ---------- + in_channels (int|Sequence[int]): Input channels. + channels (int): Channels after modules, before conv_seg. + num_classes (int): Number of classes. + dropout_ratio (float): Ratio of dropout layer. Default: 0.1. + conv_cfg (dict|None): Config of conv layers. Default: None. + norm_cfg (dict|None): Config of norm layers. Default: None. + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU') + in_index (int|Sequence[int]): Input feature index. 
Default: -1 + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. + 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + Default: None. + loss_decode (dict): Config of decode loss. + Default: dict(type='CrossEntropyLoss'). + ignore_index (int): The label index to be ignored. Default: 255 + sampler (dict|None): The config of segmentation map sampler. + Default: None. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + """ + + def __init__( + self, + in_channels, + channels, + *, + num_classes, + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type="ReLU"), + in_index: int | Iterable = -1, + input_transform=None, + loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + ignore_index=255, + align_corners=False, + ): + self._init_inputs(in_channels, in_index, input_transform) + self.channels = channels + self.num_classes = num_classes + self.dropout_ratio = dropout_ratio + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.in_index = in_index + self.loss_decode = loss_decode + self.ignore_index = ignore_index + self.align_corners = align_corners + + self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) + if dropout_ratio > 0: + self.dropout = nn.Dropout2d(dropout_ratio) + else: + self.dropout = None + self.fp16_enabled = False + + def _normal_init( + self, module: nn.Module, mean: float = 0, std: float = 1, bias: float = 0 + ) -> None: + if hasattr(module, "weight") and module.weight is not None: + nn.init.normal_(module.weight, mean, std) + if hasattr(module, "bias") and module.bias is not None: + nn.init.constant_(module.bias, bias) + + def init_weights(self): + self._normal_init(self.conv_seg, std=0.01) + + def _transform_inputs(self, inputs: list[torch.Tensor]) -> torch.Tensor: + """Transform inputs for decoder. + + Args: + inputs (list[Tensor]): List of multi-level img features. + + Returns: + Tensor: The transformed inputs + """ + + if not isinstance(self.in_index, int): + if self.input_transform == "resize_concat": + inputs = [inputs[i] for i in self.in_index] + upsampled_inputs = [ + resize( + input=x, + size=inputs[0].shape[2:], + mode="bilinear", + align_corners=self.align_corners, + ) + for x in inputs + ] + inputs = torch.cat(upsampled_inputs, dim=1) + elif self.input_transform == "multiple_select": + inputs = [inputs[i] for i in self.in_index] + else: + inputs = inputs[self.in_index] + + return inputs + + def extra_repr(self): + """Extra repr.""" + s = ( + f"input_transform={self.input_transform}, " + f"ignore_index={self.ignore_index}, " + f"align_corners={self.align_corners}" + ) + return s + + def _init_inputs(self, in_channels, in_index, input_transform): + """Check and initialize input transforms. + + The in_channels, in_index and input_transform must match. + Specifically, when input_transform is None, only single feature map + will be selected. So in_channels and in_index must be of type int. + When input_transform + + Args: + in_channels (int|Sequence[int]): Input channels. + in_index (int|Sequence[int]): Input feature index. + input_transform (str|None): Transformation type of input features. + Options: 'resize_concat', 'multiple_select', None. 
+ 'resize_concat': Multiple feature maps will be resize to the + same size as first one and than concat together. + Usually used in FCN head of HRNet. + 'multiple_select': Multiple feature maps will be bundle into + a list and passed into decode head. + None: Only one select feature map is allowed. + """ + + if input_transform is not None: + assert input_transform in ["resize_concat", "multiple_select"] + self.input_transform = input_transform + self.in_index = in_index + if input_transform is not None: + assert isinstance(in_channels, (list, tuple)) + assert isinstance(in_index, (list, tuple)) + assert len(in_channels) == len(in_index) + if input_transform == "resize_concat": + self.in_channels = sum(in_channels) + else: + self.in_channels = in_channels + else: + assert isinstance(in_channels, int) + assert isinstance(in_index, int) + self.in_channels = in_channels + + def forward(self, inputs): + raise NotImplementedError("forward method must be implemented.") + + def cls_seg(self, feat): + """Classify each pixel.""" + if self.dropout is not None: + feat = self.dropout(feat) + output = self.conv_seg(feat) + return output + + def losses(self, seg_logit, seg_label): + """Compute segmentation loss.""" + loss = dict() + seg_logit = resize( + input=seg_logit, + size=seg_label.shape[2:], + mode="bilinear", + align_corners=self.align_corners, + ) + if self.sampler is not None: + seg_weight = self.sampler.sample(seg_logit, seg_label) + else: + seg_weight = None + seg_label = seg_label.squeeze(1) + loss["loss_seg"] = self.loss_decode( + seg_logit, seg_label, weight=seg_weight, ignore_index=self.ignore_index + ) + loss["acc_seg"] = Accuracy(task="multiclass", ignore_index=self.ignore_index)( + seg_logit, seg_label + ) + return loss + + +def resize( + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + warning=True, +): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + if isinstance(size, torch.Size): + size = tuple(int(x) for x in size) + return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index dd298b9..503be9c 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -5,6 +5,8 @@ import torch.nn as nn import torch.nn.functional as F +from sslt.models.nets.base import BaseDecodeHead + # This implementation is based and addapted from Fudan Zhang Vision Group SETR implementation. 
# You can find the original implementation here: https://github.com/fudan-zvg/SETR/blob/main/mmseg/models/backbones/vit.py#L3 @@ -407,15 +409,52 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: return tuple(outs) +class MLA_Aux_Head(BaseDecodeHead): + """Vision Transformer with support for patch or hybrid CNN input stage""" + + def __init__(self, img_size=768, **kwargs): + super(MLA_Aux_Head, self).__init__(**kwargs) + self.img_size = img_size + if self.in_channels == 1024: + self.aux_0 = nn.Conv2d(self.in_channels, 256, kernel_size=1, bias=False) + self.aux_1 = nn.Conv2d(256, self.num_classes, kernel_size=1, bias=False) + elif self.in_channels == 256: + self.aux = nn.Conv2d( + self.in_channels, self.num_classes, kernel_size=1, bias=False + ) + + def to_2D(self, x): + n, hw, c = x.shape + h = w = int(math.sqrt(hw)) + x = x.transpose(1, 2).reshape(n, c, h, w) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._transform_inputs(x) + if x.dim() == 3: + x = x[:, 1:] + x = self.to_2D(x) + + if self.in_channels == 1024: + x = self.aux_0(x) + x = self.aux_1(x) + elif self.in_channels == 256: + x = self.aux(x) + x = F.interpolate( + x, size=self.img_size, mode="bilinear", align_corners=self.align_corners + ) + return x + + class MLAHead(nn.Module): - def build_norm_layer(self, mlahead_channels): + def build_norm_layer(self, mlahead_channels: int) -> nn.SyncBatchNorm: layer = nn.SyncBatchNorm(mlahead_channels, eps=1e-5) for param in layer.parameters(): param.requires_grad = True return layer - def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None): + def __init__(self, mla_channels=256, mlahead_channels=128): super(MLAHead, self).__init__() self.head2 = nn.Sequential( nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), @@ -450,7 +489,9 @@ def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None): nn.ReLU(), ) - def forward(self, x2, x3, x4, x5): + def forward( + self, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, x5: torch.Tensor + ) -> torch.Tensor: x2 = F.interpolate( self.head2(x2), 4 * x2.shape[-1], From b771f167c812fd4ba3ae86b80236ff139b39b263 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 27 Feb 2024 17:47:56 -0300 Subject: [PATCH 05/24] agora o setr sai --- sslt/models/nets/setr.py | 139 +++++++++++++++++++++++++++++++++++++++ sslt/utils/upsample.py | 54 +++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 sslt/models/nets/setr.py create mode 100644 sslt/utils/upsample.py diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py new file mode 100644 index 0000000..7a95e42 --- /dev/null +++ b/sslt/models/nets/setr.py @@ -0,0 +1,139 @@ +from typing import Optional, Tuple + +import torch +from torch import nn + +from sslt.utils.upsample import Upsample + + +class SETRUPHead(nn.Module): + """Naive upsampling head and Progressive upsampling head of SETR. + + Naive or PUP head of `SETR `_. + + """ + + def __init__( + self, + norm_layer: nn.Module, + conv_norm: nn.Module, + conv_act: nn.Module, + in_channels: int, + out_channels: int, + size: Optional[Tuple[int, int]] = None, + num_convs: int = 1, + up_scale: int = 4, + kernel_size: int = 3, + align_corners: bool = False, + ): + + assert kernel_size in [1, 3], "kernel_size must be 1 or 3." 
+ + super().__init__() + + self.size = size + self.norm = norm_layer + self.conv_norm = conv_norm + self.conv_act = conv_act + self.in_channels = in_channels + self.channels = out_channels + self.align_corners = align_corners + + self.up_convs = nn.ModuleList() + in_channels = self.in_channels + out_channels = self.channels + for _ in range(num_convs): + self.up_convs.append( + nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, + ), + self.conv_norm, + self.conv_act, + Upsample( + scale_factor=up_scale, + mode="bilinear", + align_corners=self.align_corners, + ), + ) + ) + in_channels = out_channels + + def forward(self, x): + x = self._transform_inputs(x) + + n, c, h, w = x.shape + x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() + x = self.norm(x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + + for up_conv in self.up_convs: + x = up_conv(x) + out = self.cls_seg(x) + return out + + +class SETRMLAHead(nn.Module): + """Multi level feature aggretation head of SETR. + + MLA head of `SETR `_. + + Args: + mlahead_channels (int): Channels of conv-conv-4x of multi-level feature + aggregation. Default: 128. + up_scale (int): The scale factor of interpolate. Default:4. + """ + + def __init__( + self, + mla_channels=128, + up_scale=4, + ): + super().__init__(input_transform="multiple_select") + self.mla_channels = mla_channels + + num_inputs = len(self.in_channels) + + # Refer to self.cls_seg settings of BaseDecodeHead + assert self.channels == num_inputs * mla_channels + + self.up_convs = nn.ModuleList() + for i in range(num_inputs): + self.up_convs.append( + nn.Sequential( + ConvModule( + in_channels=self.in_channels[i], + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + ), + ConvModule( + in_channels=mla_channels, + out_channels=mla_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + ), + Upsample( + scale_factor=up_scale, + mode="bilinear", + align_corners=self.align_corners, + ), + ) + ) + + def forward(self, inputs): + inputs = self._transform_inputs(inputs) + outs = [] + for x, up_conv in zip(inputs, self.up_convs): + outs.append(up_conv(x)) + out = torch.cat(outs, dim=1) + out = self.cls_seg(out) + return out diff --git a/sslt/utils/upsample.py b/sslt/utils/upsample.py new file mode 100644 index 0000000..522bc95 --- /dev/null +++ b/sslt/utils/upsample.py @@ -0,0 +1,54 @@ +import warnings +from typing import Optional, Tuple + +import torch.nn as nn +import torch.nn.functional as F + + +def resize( + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + warning=True, +): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ( + (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) + and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1) + ): + warnings.warn( + f"When align_corners={align_corners}, " + "the output would more aligned if " + f"input size {(input_h, input_w)} is `x+1` and " + f"out size {(output_h, output_w)} is `nx+1`" + ) + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__( + self, size=None, scale_factor=None, mode="nearest", align_corners=None + ): + super().__init__() + self.size = size + if 
isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) From 77d43409dac0de8d0bace7ec4453e997d7627ea3 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Wed, 28 Feb 2024 15:18:48 -0300 Subject: [PATCH 06/24] bases done??? --- sslt/models/nets/base.py | 223 +------------ sslt/models/nets/setr.py | 173 +++++++--- sslt/models/nets/unet.py | 31 +- sslt/models/nets/vit.py | 630 +++++++++--------------------------- sslt/models/nets/wisenet.py | 12 +- 5 files changed, 286 insertions(+), 783 deletions(-) diff --git a/sslt/models/nets/base.py b/sslt/models/nets/base.py index 9153b49..9436c9c 100644 --- a/sslt/models/nets/base.py +++ b/sslt/models/nets/base.py @@ -1,11 +1,7 @@ -import warnings -from typing import Dict, Iterable +from typing import Dict import lightning as L import torch -import torch.nn as nn -import torch.nn.functional as F -from torchmetrics import Accuracy class SimpleSupervisedModel(L.LightningModule): @@ -154,220 +150,3 @@ def configure_optimizers(self): lr=self.learning_rate, ) return optimizer - - -class BaseDecodeHead(nn.Module): - """Base class for BaseDecodeHead. - - Parameters - ---------- - in_channels (int|Sequence[int]): Input channels. - channels (int): Channels after modules, before conv_seg. - num_classes (int): Number of classes. - dropout_ratio (float): Ratio of dropout layer. Default: 0.1. - conv_cfg (dict|None): Config of conv layers. Default: None. - norm_cfg (dict|None): Config of norm layers. Default: None. - act_cfg (dict): Config of activation layers. - Default: dict(type='ReLU') - in_index (int|Sequence[int]): Input feature index. Default: -1 - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. - Default: None. - loss_decode (dict): Config of decode loss. - Default: dict(type='CrossEntropyLoss'). - ignore_index (int): The label index to be ignored. Default: 255 - sampler (dict|None): The config of segmentation map sampler. - Default: None. - align_corners (bool): align_corners argument of F.interpolate. - Default: False. 
- """ - - def __init__( - self, - in_channels, - channels, - *, - num_classes, - dropout_ratio=0.1, - conv_cfg=None, - norm_cfg=None, - act_cfg=dict(type="ReLU"), - in_index: int | Iterable = -1, - input_transform=None, - loss_decode=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), - ignore_index=255, - align_corners=False, - ): - self._init_inputs(in_channels, in_index, input_transform) - self.channels = channels - self.num_classes = num_classes - self.dropout_ratio = dropout_ratio - self.conv_cfg = conv_cfg - self.norm_cfg = norm_cfg - self.act_cfg = act_cfg - self.in_index = in_index - self.loss_decode = loss_decode - self.ignore_index = ignore_index - self.align_corners = align_corners - - self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) - if dropout_ratio > 0: - self.dropout = nn.Dropout2d(dropout_ratio) - else: - self.dropout = None - self.fp16_enabled = False - - def _normal_init( - self, module: nn.Module, mean: float = 0, std: float = 1, bias: float = 0 - ) -> None: - if hasattr(module, "weight") and module.weight is not None: - nn.init.normal_(module.weight, mean, std) - if hasattr(module, "bias") and module.bias is not None: - nn.init.constant_(module.bias, bias) - - def init_weights(self): - self._normal_init(self.conv_seg, std=0.01) - - def _transform_inputs(self, inputs: list[torch.Tensor]) -> torch.Tensor: - """Transform inputs for decoder. - - Args: - inputs (list[Tensor]): List of multi-level img features. - - Returns: - Tensor: The transformed inputs - """ - - if not isinstance(self.in_index, int): - if self.input_transform == "resize_concat": - inputs = [inputs[i] for i in self.in_index] - upsampled_inputs = [ - resize( - input=x, - size=inputs[0].shape[2:], - mode="bilinear", - align_corners=self.align_corners, - ) - for x in inputs - ] - inputs = torch.cat(upsampled_inputs, dim=1) - elif self.input_transform == "multiple_select": - inputs = [inputs[i] for i in self.in_index] - else: - inputs = inputs[self.in_index] - - return inputs - - def extra_repr(self): - """Extra repr.""" - s = ( - f"input_transform={self.input_transform}, " - f"ignore_index={self.ignore_index}, " - f"align_corners={self.align_corners}" - ) - return s - - def _init_inputs(self, in_channels, in_index, input_transform): - """Check and initialize input transforms. - - The in_channels, in_index and input_transform must match. - Specifically, when input_transform is None, only single feature map - will be selected. So in_channels and in_index must be of type int. - When input_transform - - Args: - in_channels (int|Sequence[int]): Input channels. - in_index (int|Sequence[int]): Input feature index. - input_transform (str|None): Transformation type of input features. - Options: 'resize_concat', 'multiple_select', None. - 'resize_concat': Multiple feature maps will be resize to the - same size as first one and than concat together. - Usually used in FCN head of HRNet. - 'multiple_select': Multiple feature maps will be bundle into - a list and passed into decode head. - None: Only one select feature map is allowed. 
- """ - - if input_transform is not None: - assert input_transform in ["resize_concat", "multiple_select"] - self.input_transform = input_transform - self.in_index = in_index - if input_transform is not None: - assert isinstance(in_channels, (list, tuple)) - assert isinstance(in_index, (list, tuple)) - assert len(in_channels) == len(in_index) - if input_transform == "resize_concat": - self.in_channels = sum(in_channels) - else: - self.in_channels = in_channels - else: - assert isinstance(in_channels, int) - assert isinstance(in_index, int) - self.in_channels = in_channels - - def forward(self, inputs): - raise NotImplementedError("forward method must be implemented.") - - def cls_seg(self, feat): - """Classify each pixel.""" - if self.dropout is not None: - feat = self.dropout(feat) - output = self.conv_seg(feat) - return output - - def losses(self, seg_logit, seg_label): - """Compute segmentation loss.""" - loss = dict() - seg_logit = resize( - input=seg_logit, - size=seg_label.shape[2:], - mode="bilinear", - align_corners=self.align_corners, - ) - if self.sampler is not None: - seg_weight = self.sampler.sample(seg_logit, seg_label) - else: - seg_weight = None - seg_label = seg_label.squeeze(1) - loss["loss_seg"] = self.loss_decode( - seg_logit, seg_label, weight=seg_weight, ignore_index=self.ignore_index - ) - loss["acc_seg"] = Accuracy(task="multiclass", ignore_index=self.ignore_index)( - seg_logit, seg_label - ) - return loss - - -def resize( - input, - size=None, - scale_factor=None, - mode="nearest", - align_corners=None, - warning=True, -): - if warning: - if size is not None and align_corners: - input_h, input_w = tuple(int(x) for x in input.shape[2:]) - output_h, output_w = tuple(int(x) for x in size) - if output_h > input_h or output_w > output_h: - if ( - (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) - and (output_h - 1) % (input_h - 1) - and (output_w - 1) % (input_w - 1) - ): - warnings.warn( - f"When align_corners={align_corners}, " - "the output would more aligned if " - f"input size {(input_h, input_w)} is `x+1` and " - f"out size {(output_h, output_w)} is `nx+1`" - ) - if isinstance(size, torch.Size): - size = tuple(int(x) for x in size) - return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 7a95e42..6977d24 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional, Tuple import torch @@ -6,7 +7,7 @@ from sslt.utils.upsample import Upsample -class SETRUPHead(nn.Module): +class _SETRUPHead(nn.Module): """Naive upsampling head and Progressive upsampling head of SETR. Naive or PUP head of `SETR `_. @@ -15,33 +16,60 @@ class SETRUPHead(nn.Module): def __init__( self, - norm_layer: nn.Module, - conv_norm: nn.Module, - conv_act: nn.Module, + channels: int, + norm_layer: Optional[nn.Module], + conv_norm: Optional[nn.Module], + conv_act: Optional[nn.Module], in_channels: int, out_channels: int, - size: Optional[Tuple[int, int]] = None, + num_classes: int, num_convs: int = 1, up_scale: int = 4, kernel_size: int = 3, - align_corners: bool = False, + align_corners: bool = True, + dropout: float = 0.1, + threshold: Optional[float] = None, ): assert kernel_size in [1, 3], "kernel_size must be 1 or 3." 
super().__init__() - self.size = size - self.norm = norm_layer - self.conv_norm = conv_norm - self.conv_act = conv_act - self.in_channels = in_channels - self.channels = out_channels - self.align_corners = align_corners + if out_channels is None: + if num_classes == 2: + warnings.warn( + "For binary segmentation, we suggest using" + "`out_channels = 1` to define the output" + "channels of segmentor, and use `threshold`" + "to convert `seg_logits` into a prediction" + "applying a threshold" + ) + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + "out_channels should be equal to num_classes," + "except binary segmentation set out_channels == 1 and" + f"num_classes == 2, but got out_channels={out_channels}" + f"and num_classes={num_classes}" + ) + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn("threshold is not defined for binary, and defaults" "to 0.3") + + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + self.norm = norm_layer if norm_layer is not None else nn.SyncBatchNorm(channels) + conv_norm = ( + conv_norm if conv_norm is not None else nn.SyncBatchNorm(out_channels) + ) + conv_act = conv_act if conv_act is not None else nn.ReLU() + self.dropout = nn.Dropout2d(dropout) if dropout > 0 is not None else None + self.cls_seg = nn.Conv2d(channels, out_channels, 1) self.up_convs = nn.ModuleList() - in_channels = self.in_channels - out_channels = self.channels + for _ in range(num_convs): self.up_convs.append( nn.Sequential( @@ -52,19 +80,18 @@ def __init__( padding=kernel_size // 2, bias=False, ), - self.conv_norm, - self.conv_act, + conv_norm, + conv_act, Upsample( scale_factor=up_scale, mode="bilinear", - align_corners=self.align_corners, + align_corners=align_corners, ), ) ) in_channels = out_channels - def forward(self, x): - x = self._transform_inputs(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: n, c, h, w = x.shape x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() @@ -73,67 +100,111 @@ def forward(self, x): for up_conv in self.up_convs: x = up_conv(x) + + if self.dropout is not None: + x = self.dropout(x) out = self.cls_seg(x) + return out -class SETRMLAHead(nn.Module): +class _SETRMLAHead(nn.Module): """Multi level feature aggretation head of SETR. MLA head of `SETR `_. - - Args: - mlahead_channels (int): Channels of conv-conv-4x of multi-level feature - aggregation. Default: 128. - up_scale (int): The scale factor of interpolate. Default:4. 
""" def __init__( self, - mla_channels=128, - up_scale=4, + channels: int, + conv_norm: Optional[nn.Module], + conv_act: Optional[nn.Module], + in_channels: list[int], + out_channels: int, + num_classes: int, + mla_channels: int = 128, + up_scale: int = 4, + kernel_size: int = 3, + align_corners: bool = True, + dropout: float = 0.1, + threshold: Optional[float] = None, ): - super().__init__(input_transform="multiple_select") - self.mla_channels = mla_channels + super().__init__() - num_inputs = len(self.in_channels) + conv_norm = ( + conv_norm if conv_norm is not None else nn.SyncBatchNorm(mla_channels) + ) + conv_act = conv_act if conv_act is not None else nn.ReLU() - # Refer to self.cls_seg settings of BaseDecodeHead - assert self.channels == num_inputs * mla_channels + self.dropout = nn.Dropout2d(dropout) if dropout > 0 is not None else None + + self.cls_seg = nn.Conv2d(channels, out_channels, 1) + + if out_channels is None: + if num_classes == 2: + warnings.warn( + "For binary segmentation, we suggest using" + "`out_channels = 1` to define the output" + "channels of segmentor, and use `threshold`" + "to convert `seg_logits` into a prediction" + "applying a threshold" + ) + out_channels = num_classes + + if out_channels != num_classes and out_channels != 1: + raise ValueError( + "out_channels should be equal to num_classes," + "except binary segmentation set out_channels == 1 and" + f"num_classes == 2, but got out_channels={out_channels}" + f"and num_classes={num_classes}" + ) + + if out_channels == 1 and threshold is None: + threshold = 0.3 + warnings.warn("threshold is not defined for binary, and defaults" "to 0.3") + + self.num_classes = num_classes + self.out_channels = out_channels + self.threshold = threshold + + num_inputs = len(self.in_channels) self.up_convs = nn.ModuleList() for i in range(num_inputs): self.up_convs.append( nn.Sequential( - ConvModule( - in_channels=self.in_channels[i], - out_channels=mla_channels, - kernel_size=3, - padding=1, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, + nn.Conv2d( + in_channels[i], + mla_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, ), - ConvModule( - in_channels=mla_channels, - out_channels=mla_channels, - kernel_size=3, - padding=1, - norm_cfg=self.norm_cfg, - act_cfg=self.act_cfg, + conv_norm, + conv_act, + nn.Conv2d( + mla_channels, + mla_channels, + kernel_size, + padding=kernel_size // 2, + bias=False, ), + conv_norm, + conv_act, Upsample( scale_factor=up_scale, mode="bilinear", - align_corners=self.align_corners, + align_corners=align_corners, ), ) ) - def forward(self, inputs): - inputs = self._transform_inputs(inputs) + def forward(self, x): outs = [] - for x, up_conv in zip(inputs, self.up_convs): + for x, up_conv in zip(x, self.up_convs): outs.append(up_conv(x)) out = torch.cat(outs, dim=1) + if self.dropout is not None: + out = self.dropout(out) out = self.cls_seg(out) return out diff --git a/sslt/models/nets/unet.py b/sslt/models/nets/unet.py index a7803eb..ee5decf 100644 --- a/sslt/models/nets/unet.py +++ b/sslt/models/nets/unet.py @@ -1,18 +1,17 @@ """ Full assembly of the parts to form the complete network """ -from typing import Dict -import lightning as L -import torch.optim as optim -from torch.optim.lr_scheduler import CyclicLR -from torch.optim.lr_scheduler import StepLR import time +from typing import Dict, Optional + +import lightning as L import torch import torch.nn as nn import torch.nn.functional as F +import torch.optim as optim +from torch.optim.lr_scheduler import CyclicLR, StepLR 
from sslt.models.nets.base import SimpleSupervisedModel - """ -------------- Parts of the U-Net model --------------""" @@ -44,15 +43,11 @@ def __init__(self, in_channels, out_channels, mid_channels=None): nn.Conv2d( in_channels, mid_channels, kernel_size=3, padding=1, bias=False ), # no need to add bias since BatchNorm2d will do that - nn.BatchNorm2d( - mid_channels - ), # normalize the output of the previous layer + nn.BatchNorm2d(mid_channels), # normalize the output of the previous layer nn.ReLU( inplace=True ), # inplace=True will modify the input directly instead of allocating new memory - nn.Conv2d( - mid_channels, out_channels, kernel_size=3, padding=1, bias=False - ), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) @@ -82,9 +77,7 @@ def __init__(self, in_channels, out_channels, bilinear=True): # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: - self.up = nn.Upsample( - scale_factor=2, mode="bilinear", align_corners=True - ) + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) self.conv = _DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d( @@ -99,9 +92,7 @@ def forward(self, x1, x2): diffX = x2.size()[3] - x1.size()[3] # pad the input tensor on all sides with the given "pad" value - x1 = F.pad( - x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2] - ) + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd @@ -209,7 +200,7 @@ def __init__( n_channels: int = 1, bilinear: bool = False, learning_rate: float = 1e-3, - loss_fn: torch.nn.Module = None, + loss_fn: Optional[torch.nn.Module] = None, ): """Wrapper implementation of the U-Net model. @@ -231,5 +222,5 @@ def __init__( fc=torch.nn.Identity(), loss_fn=loss_fn or torch.nn.MSELoss(), learning_rate=learning_rate, - flatten=False + flatten=False, ) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index 503be9c..08b043b 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -1,519 +1,181 @@ import math +from collections import OrderedDict from functools import partial +from typing import Callable, List, Optional import torch -import torch.nn as nn -import torch.nn.functional as F +from torch import nn +from torch.nn import functional as F +from torchvision.models.vision_transformer import ( + Conv2dNormActivation, + ConvStemConfig, + Encoder, + _log_api_usage_once, +) -from sslt.models.nets.base import BaseDecodeHead -# This implementation is based and addapted from Fudan Zhang Vision Group SETR implementation. 
-# You can find the original implementation here: https://github.com/fudan-zvg/SETR/blob/main/mmseg/models/backbones/vit.py#L3 - - -class DropPath(nn.Module): - - def __init__(self, drop_prob: float = 0) -> None: - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.drop_prob == 0.0 or not self.training: - return x - - keep_prob = 1 - self.drop_prob - shape = (x.shape[0],) + (1,) * (x.dim() - 1) - random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() - output = x.div(keep_prob) * random_tensor - return output - - -class Mlp(nn.Module): - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop: float = 0.0, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): +class _VisionTransformer(nn.Module): + """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" def __init__( self, - dim, - num_heads: int = 8, - qkv_bias: bool = False, - qk_scale=None, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - ) -> None: + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + heads: nn.Sequential, + dropout: float = 0.0, + attention_dropout: float = 0.0, + num_classes: int = 1000, + representation_size: Optional[int] = None, + norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), + conv_stem_configs: Optional[List[ConvStemConfig]] = None, + ): super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B, N, 3, self.num_heads, C // self.num_heads) - .permute(2, 0, 3, 1, 4) + _log_api_usage_once(self) + torch._assert( + image_size % patch_size == 0, "Input shape indivisible by patch size!" 
) - q, k, v = qkv[0], qkv[1], qkv[2] - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - - def __init__( - self, - dim, - num_heads, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer: partial[nn.LayerNorm] | nn.LayerNorm = partial(nn.LayerNorm), - ) -> None: - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - -class PatchEmbed(nn.Module): - - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768) -> None: - super().__init__() - img_size = (img_size, img_size) - patch_size = (patch_size, patch_size) - num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches - - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, C, H, W = x.shape - - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - - x = self.proj(x).flatten(2).transpose(1, 2) - return x - - -class HybridEmbed(nn.Module): - - def __init__( - self, - backbone: nn.Module, - img_size=224, - feature_size=None, - in_chans=3, - embed_dim=768, - ) -> None: - super().__init__() - assert isinstance(backbone, nn.Module), "backbone must be nn.Module" - self.backbone = backbone - self.img_size = (img_size, img_size) - - # FIXME (from original code) this is hacky, but most reliable way of determining the exact dim of the output feature - # map for all networks, the feature metadata has reliable channel and stride info, but using - # stride to calc feature dim requires info about padding of each stage that isn't captured. - if feature_size is None: - with torch.no_grad(): - training = self.backbone.training - if training: - self.backbone.eval() - o = self.backbone( - torch.zeros(1, in_chans, self.img_size[0], self.img_size[1]) - )[-1] - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - else: - feature_size = (feature_size, feature_size) - feature_dim = self.backbone.feature_info.channels()[-1] - - self.num_patches = feature_size[0] * feature_size[1] - self.proj = nn.Linear(feature_dim, embed_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.backbone(x)[-1] - x = x.flatten(2).transpose(1, 2) - x = self.proj(x) - return x - - -class VisionTransformer(nn.Module): - """ - Vision Transformer model implementation. - - Parameters - ---------- - img_size: int - Size of the input image. Default is 384. - patch_size: int - Size of the image patch. Default is 16. - in_chans: int - Number of input channels. Default is 3. 
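For the defaults documented here (img_size=384, patch_size=16), the patch grid and token count work out as follows; this is a worked example only and not part of the original file:

# Worked example for the documented defaults (assumption: square images and patches).
img_size, patch_size = 384, 16
grid = img_size // patch_size       # 24 patches per side
num_patches = grid * grid           # 576 patches feed the transformer
seq_len = num_patches + 1           # 577 tokens once the class token is prepended
print(grid, num_patches, seq_len)   # -> 24 576 577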
- embed_dim: int - Dimensionality of the token embeddings. Default is 1024. - depth: int - Number of transformer blocks. Default is 24. - num_heads: int - Number of attention heads. Default is 16. - num_classes: int - Number of output classes. Default is 19. - mlp_ratio: float - Ratio of MLP hidden dimension to embedding dimension. Default is 4.0. - qkv_bias: bool - Whether to include bias in the query, key, and value projections. Default is True. - qk_scale: float - Scale factor for query and key. Default is None. - drop_rate: float - Dropout rate. Default is 0.1. - attn_drop_rate: float - Dropout rate for attention weights. Default is 0.0. - drop_path_rate: float - Dropout rate for stochastic depth. Default is 0.0. - hybrid_backbone: None | nn.Module - Hybrid backbone module. Default is None. - norm_layer: nn.Module - Normalization layer. Default is nn.LayerNorm with eps=1e-6. - norm_cfg: None | dict - Normalization configuration. Default is None. - pos_embed_interp: bool - Whether to interpolate positional embeddings. Default is False. - random_init: bool - Whether to initialize weights randomly. Default is False. - align_corners: bool - Whether to align corners in positional embeddings. Default is False. - - References - ---------- - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929 - - """ - - def __init__( - self, - img_size=384, - patch_size=16, - in_chans=3, - embed_dim=1024, - depth=24, - num_heads=16, - num_classes=19, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop_rate=0.1, - attn_drop_rate=0.0, - drop_path_rate=0.0, - hybrid_backbone=None, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - norm_cfg=None, - pos_embed_interp=False, - random_init=False, - align_corners=False, - **kwargs, - ) -> None: - super(VisionTransformer, self).__init__(**kwargs) - self.img_size = img_size + self.image_size = image_size self.patch_size = patch_size - self.in_chans = in_chans - self.embed_dim = embed_dim - self.depth = depth - self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.attention_dropout = attention_dropout + self.dropout = dropout self.num_classes = num_classes - self.mlp_ratio = mlp_ratio - self.qkv_bias = qkv_bias - self.qk_scale = qk_scale - self.drop_rate = drop_rate - self.attn_drop_rate = attn_drop_rate - self.drop_path_rate = drop_path_rate - self.hybrid_backbone = hybrid_backbone + self.representation_size = representation_size self.norm_layer = norm_layer - self.norm_cfg = norm_cfg - self.pos_embed_interp = pos_embed_interp - self.random_init = random_init - self.align_corners = align_corners - - self.num_stages = self.depth - self.out_indices = tuple(range(self.num_stages)) - if self.hybrid_backbone is not None: - self.patch_embed = HybridEmbed( - self.hybrid_backbone, - img_size=self.img_size, - in_chans=self.in_chans, - embed_dim=self.embed_dim, + if conv_stem_configs is not None: + # As per https://arxiv.org/abs/2106.14881 + seq_proj = nn.Sequential() + prev_channels = 3 + for i, conv_stem_layer_config in enumerate(conv_stem_configs): + seq_proj.add_module( + f"conv_bn_relu_{i}", + Conv2dNormActivation( + in_channels=prev_channels, + out_channels=conv_stem_layer_config.out_channels, + kernel_size=conv_stem_layer_config.kernel_size, + stride=conv_stem_layer_config.stride, + norm_layer=conv_stem_layer_config.norm_layer, + activation_layer=conv_stem_layer_config.activation_layer, + ), + ) + prev_channels = conv_stem_layer_config.out_channels + seq_proj.add_module( + 
"conv_last", + nn.Conv2d( + in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1 + ), ) + self.conv_proj: nn.Module = seq_proj else: - self.patch_embed = PatchEmbed( - img_size=self.img_size, - patch_size=self.patch_size, - in_chans=self.in_chans, - embed_dim=self.embed_dim, + self.conv_proj = nn.Conv2d( + in_channels=3, + out_channels=hidden_dim, + kernel_size=patch_size, + stride=patch_size, ) - self.num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - self.pos_embed = nn.Parameter( - torch.zeros(1, self.num_patches + 1, self.embed_dim) - ) - self.pos_drop = nn.Dropout(p=self.drop_rate) - - dpr = [ - x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth) - ] # stochastic depth decay rule - self.blocks = nn.ModuleList( - [ - Block( - dim=self.embed_dim, - num_heads=self.num_heads, - mlp_ratio=self.mlp_ratio, - qkv_bias=self.qkv_bias, - qk_scale=self.qk_scale, - drop=self.drop_rate, - attn_drop=self.attn_drop_rate, - drop_path=dpr[i], - norm_layer=self.norm_layer, - ) - for i in range(self.depth) - ] + seq_length = (image_size // patch_size) ** 2 + + # Add a class token + self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + seq_length += 1 + + self.encoder = Encoder( + seq_length, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + norm_layer, ) + self.seq_length = seq_length - # NOTE (from original code) as per official impl, we could have a pre-logits representation dense layer + tanh here - # self.repr = nn.Linear(embed_dim, representation_size) - # self.repr_act = nn.Tanh() + self.heads = heads - nn.init.trunc_normal_(self.pos_embed, std=0.02) - nn.init.trunc_normal_(self.cls_token, std=0.02) - - def init_weights(self, pretrained=None) -> None: - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - if not self.random_init: - raise NotImplementedError("Pretrained model is not supported yet") - else: - print("Initialize weight randomly") - - def _conv_filter(self, state_dict, patch_size=16) -> dict: - """convert patch embedding weight from manual patchify + linear proj to conv""" - out_dict = {} - for k, v in state_dict.items(): - if "patch_embed.proj.weight" in k: - v = v.reshape((v.shape[0], 3, patch_size, patch_size)) - out_dict[k] = v - return out_dict + if isinstance(self.conv_proj, nn.Conv2d): + # Init the patchify stem + fan_in = ( + self.conv_proj.in_channels + * self.conv_proj.kernel_size[0] + * self.conv_proj.kernel_size[1] + ) + nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in)) + if self.conv_proj.bias is not None: + nn.init.zeros_(self.conv_proj.bias) + elif self.conv_proj.conv_last is not None and isinstance( + self.conv_proj.conv_last, nn.Conv2d + ): + # Init the last 1x1 conv of the conv stem + nn.init.normal_( + self.conv_proj.conv_last.weight, + mean=0.0, + std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels), + ) + if self.conv_proj.conv_last.bias is not None: + nn.init.zeros_(self.conv_proj.conv_last.bias) + + if hasattr(self.heads, "pre_logits") and isinstance( + self.heads.pre_logits, nn.Linear + ): + fan_in = self.heads.pre_logits.in_features + nn.init.trunc_normal_( + self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in) + ) + nn.init.zeros_(self.heads.pre_logits.bias) - 
def to_2D(self, x: torch.Tensor) -> torch.Tensor: - n, hw, c = x.shape - h = w = int(math.sqrt(hw)) - x = x.transpose(1, 2).reshape(n, c, h, w) - return x + if isinstance(self.heads.head, nn.Linear): + nn.init.zeros_(self.heads.head.weight) + nn.init.zeros_(self.heads.head.bias) - def to_1D(self, x: torch.Tensor) -> torch.Tensor: + def _process_input(self, x: torch.Tensor) -> torch.Tensor: n, c, h, w = x.shape - x = x.reshape(n, c, -1).transpose(1, 2) - return x - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]: - B = x.shape[0] - x = self.patch_embed(x) - - x = x.flatten(2).transpose(1, 2) - - # originaly credited to Phil Wang - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embed - x = self.pos_drop(x) - - outs = [] - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in self.out_indices: - outs.append(x) - return tuple(outs) - + p = self.patch_size + torch._assert( + h == self.image_size, + f"Wrong image height! Expected {self.image_size} but got {h}!", + ) + torch._assert( + w == self.image_size, + f"Wrong image width! Expected {self.image_size} but got {w}!", + ) + n_h = h // p + n_w = w // p -class MLA_Aux_Head(BaseDecodeHead): - """Vision Transformer with support for patch or hybrid CNN input stage""" + # (n, c, h, w) -> (n, hidden_dim, n_h, n_w) + x = self.conv_proj(x) + # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)) + x = x.reshape(n, self.hidden_dim, n_h * n_w) - def __init__(self, img_size=768, **kwargs): - super(MLA_Aux_Head, self).__init__(**kwargs) - self.img_size = img_size - if self.in_channels == 1024: - self.aux_0 = nn.Conv2d(self.in_channels, 256, kernel_size=1, bias=False) - self.aux_1 = nn.Conv2d(256, self.num_classes, kernel_size=1, bias=False) - elif self.in_channels == 256: - self.aux = nn.Conv2d( - self.in_channels, self.num_classes, kernel_size=1, bias=False - ) + # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim) + # The self attention layer expects inputs in the format (N, S, E) + # where S is the source sequence length, N is the batch size, E is the + # embedding dimension + x = x.permute(0, 2, 1) - def to_2D(self, x): - n, hw, c = x.shape - h = w = int(math.sqrt(hw)) - x = x.transpose(1, 2).reshape(n, c, h, w) return x - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self._transform_inputs(x) - if x.dim() == 3: - x = x[:, 1:] - x = self.to_2D(x) - - if self.in_channels == 1024: - x = self.aux_0(x) - x = self.aux_1(x) - elif self.in_channels == 256: - x = self.aux(x) - x = F.interpolate( - x, size=self.img_size, mode="bilinear", align_corners=self.align_corners - ) - return x + def forward(self, x: torch.Tensor): + # Reshape and permute the input tensor + x = self._process_input(x) + n = x.shape[0] + # Expand the class token to the full batch + batch_class_token = self.class_token.expand(n, -1, -1) + x = torch.cat([batch_class_token, x], dim=1) -class MLAHead(nn.Module): + x = self.encoder(x) - def build_norm_layer(self, mlahead_channels: int) -> nn.SyncBatchNorm: - layer = nn.SyncBatchNorm(mlahead_channels, eps=1e-5) - for param in layer.parameters(): - param.requires_grad = True - return layer + # Classifier "token" as used by standard language architectures + x = x[:, 0] - def __init__(self, mla_channels=256, mlahead_channels=128): - super(MLAHead, self).__init__() - self.head2 = nn.Sequential( - nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - 
nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - ) - self.head3 = nn.Sequential( - nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - ) - self.head4 = nn.Sequential( - nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - ) - self.head5 = nn.Sequential( - nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False), - self.build_norm_layer(mlahead_channels), - nn.ReLU(), - ) + x = self.heads(x) - def forward( - self, x2: torch.Tensor, x3: torch.Tensor, x4: torch.Tensor, x5: torch.Tensor - ) -> torch.Tensor: - x2 = F.interpolate( - self.head2(x2), - 4 * x2.shape[-1], - mode="bilinear", - align_corners=True, - ) - x3 = F.interpolate( - self.head3(x3), - 8 * x3.shape[-1], - mode="bilinear", - align_corners=True, - ) - x4 = F.interpolate( - self.head4(x4), - 16 * x4.shape[-1], - mode="bilinear", - align_corners=True, - ) - x5 = F.interpolate( - self.head5(x5), - 32 * x5.shape[-1], - mode="bilinear", - align_corners=True, - ) - return torch.cat([x2, x3, x4, x5], dim=1) + return x diff --git a/sslt/models/nets/wisenet.py b/sslt/models/nets/wisenet.py index 4e8d186..ee4575f 100644 --- a/sslt/models/nets/wisenet.py +++ b/sslt/models/nets/wisenet.py @@ -1,4 +1,7 @@ +from typing import Optional + import torch + from sslt.models.nets.base import SimpleSupervisedModel @@ -98,13 +101,11 @@ def __init__( self, in_channels: int = 1, out_channels: int = 1, - loss_fn: torch.nn.Module = None, + loss_fn: Optional[torch.nn.Module] = None, learning_rate: float = 1e-3, ): super().__init__( - backbone=_WiseNet( - in_channels=in_channels, out_channels=out_channels - ), + backbone=_WiseNet(in_channels=in_channels, out_channels=out_channels), fc=torch.nn.Identity(), loss_fn=loss_fn or torch.nn.MSELoss(), learning_rate=learning_rate, @@ -129,9 +130,8 @@ def _single_step( ) return loss - def predict_step(self, batch, batch_idx, dataloader_idx=None): x, y = batch y_hat = self.forward(x) y_hat = y_hat[:, :, : y.size(2), : y.size(3)] - return y_hat \ No newline at end of file + return y_hat From 3ecab055a770c6a7fcccd4a25f875f2ed3651774 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Thu, 29 Feb 2024 07:29:15 -0300 Subject: [PATCH 07/24] Add SETRUPHead and SETRMLAHead classes to setr.py --- sslt/models/nets/setr.py | 88 +++++++++++++++++++++++++--------------- sslt/models/nets/vit.py | 23 +---------- 2 files changed, 56 insertions(+), 55 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 6977d24..ddf4883 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -4,6 +4,8 @@ import torch from torch import nn +from sslt.models.nets.base import SimpleSupervisedModel +from sslt.models.nets.vit import _VisionTransformerBackbone from sslt.utils.upsample import Upsample @@ -35,29 +37,6 @@ def __init__( super().__init__() - if out_channels is None: - if num_classes == 2: - warnings.warn( - "For binary segmentation, we suggest using" - "`out_channels = 1` to 
define the output" - "channels of segmentor, and use `threshold`" - "to convert `seg_logits` into a prediction" - "applying a threshold" - ) - out_channels = num_classes - - if out_channels != num_classes and out_channels != 1: - raise ValueError( - "out_channels should be equal to num_classes," - "except binary segmentation set out_channels == 1 and" - f"num_classes == 2, but got out_channels={out_channels}" - f"and num_classes={num_classes}" - ) - - if out_channels == 1 and threshold is None: - threshold = 0.3 - warnings.warn("threshold is not defined for binary, and defaults" "to 0.3") - self.num_classes = num_classes self.out_channels = out_channels self.threshold = threshold @@ -70,6 +49,8 @@ def __init__( self.cls_seg = nn.Conv2d(channels, out_channels, 1) self.up_convs = nn.ModuleList() + out_channels = channels + for _ in range(num_convs): self.up_convs.append( nn.Sequential( @@ -131,15 +112,6 @@ def __init__( ): super().__init__() - conv_norm = ( - conv_norm if conv_norm is not None else nn.SyncBatchNorm(mla_channels) - ) - conv_act = conv_act if conv_act is not None else nn.ReLU() - - self.dropout = nn.Dropout2d(dropout) if dropout > 0 is not None else None - - self.cls_seg = nn.Conv2d(channels, out_channels, 1) - if out_channels is None: if num_classes == 2: warnings.warn( @@ -161,11 +133,17 @@ def __init__( if out_channels == 1 and threshold is None: threshold = 0.3 - warnings.warn("threshold is not defined for binary, and defaults" "to 0.3") + warnings.warn("threshold is not defined for binary, and defaults to 0.3") self.num_classes = num_classes self.out_channels = out_channels self.threshold = threshold + conv_norm = ( + conv_norm if conv_norm is not None else nn.SyncBatchNorm(mla_channels) + ) + conv_act = conv_act if conv_act is not None else nn.ReLU() + self.dropout = nn.Dropout2d(dropout) if dropout > 0 is not None else None + self.cls_seg = nn.Conv2d(channels, out_channels, 1) num_inputs = len(self.in_channels) @@ -208,3 +186,47 @@ def forward(self, x): out = self.dropout(out) out = self.cls_seg(out) return out + + +class _SetR_PUP(nn.Module): + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + num_classes: int, + dropout: float = 0.1, + attention_dropout: float = 0.1, + norm_layer: Optional[nn.Module] = None, + interpolate_mode: str = "bilinear", + ): + super().__init__() + self.encoder = _VisionTransformerBackbone( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + num_classes=num_classes, + dropout=dropout, + ) + + self.aux_head1 = _SETRUPHead( + channels=1024, + in_channels=hidden_dim, + out_channels=hidden_dim, + num_classes=6, + num_convs=4, + up_scale=2, + kernel_size=3, + align_corners=False, + dropout=0, + norm_layer=norm_layer, + conv_norm=None, # Add default value for conv_norm + conv_act=None, # Add default value for conv_act + ) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index 08b043b..4dce6c8 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -5,7 +5,6 @@ import torch from torch import nn -from torch.nn import functional as F from torchvision.models.vision_transformer import ( Conv2dNormActivation, ConvStemConfig, @@ -14,7 +13,7 @@ ) -class _VisionTransformer(nn.Module): +class _VisionTransformerBackbone(nn.Module): """Vision Transformer as per https://arxiv.org/abs/2010.11929.""" def __init__( @@ -25,11 +24,9 @@ def __init__( num_heads: int, 
hidden_dim: int, mlp_dim: int, - heads: nn.Sequential, dropout: float = 0.0, attention_dropout: float = 0.0, num_classes: int = 1000, - representation_size: Optional[int] = None, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), conv_stem_configs: Optional[List[ConvStemConfig]] = None, ): @@ -45,7 +42,6 @@ def __init__( self.attention_dropout = attention_dropout self.dropout = dropout self.num_classes = num_classes - self.representation_size = representation_size self.norm_layer = norm_layer if conv_stem_configs is not None: @@ -98,8 +94,6 @@ def __init__( ) self.seq_length = seq_length - self.heads = heads - if isinstance(self.conv_proj, nn.Conv2d): # Init the patchify stem fan_in = ( @@ -122,19 +116,6 @@ def __init__( if self.conv_proj.conv_last.bias is not None: nn.init.zeros_(self.conv_proj.conv_last.bias) - if hasattr(self.heads, "pre_logits") and isinstance( - self.heads.pre_logits, nn.Linear - ): - fan_in = self.heads.pre_logits.in_features - nn.init.trunc_normal_( - self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in) - ) - nn.init.zeros_(self.heads.pre_logits.bias) - - if isinstance(self.heads.head, nn.Linear): - nn.init.zeros_(self.heads.head.weight) - nn.init.zeros_(self.heads.head.bias) - def _process_input(self, x: torch.Tensor) -> torch.Tensor: n, c, h, w = x.shape p = self.patch_size @@ -176,6 +157,4 @@ def forward(self, x: torch.Tensor): # Classifier "token" as used by standard language architectures x = x[:, 0] - x = self.heads(x) - return x From 9de7ac4de06a59e11bd2e69ef6385e9390d5f01c Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Thu, 29 Feb 2024 12:36:57 -0300 Subject: [PATCH 08/24] progress --- sslt/models/nets/setr.py | 64 ++++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index ddf4883..8d08ae6 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -23,7 +23,6 @@ def __init__( conv_norm: Optional[nn.Module], conv_act: Optional[nn.Module], in_channels: int, - out_channels: int, num_classes: int, num_convs: int = 1, up_scale: int = 4, @@ -38,25 +37,23 @@ def __init__( super().__init__() self.num_classes = num_classes - self.out_channels = out_channels + self.out_channels = channels self.threshold = threshold self.norm = norm_layer if norm_layer is not None else nn.SyncBatchNorm(channels) conv_norm = ( - conv_norm if conv_norm is not None else nn.SyncBatchNorm(out_channels) + conv_norm if conv_norm is not None else nn.SyncBatchNorm(self.out_channels) ) conv_act = conv_act if conv_act is not None else nn.ReLU() self.dropout = nn.Dropout2d(dropout) if dropout > 0 is not None else None - self.cls_seg = nn.Conv2d(channels, out_channels, 1) + self.cls_seg = nn.Conv2d(channels, self.out_channels, 1) self.up_convs = nn.ModuleList() - out_channels = channels - for _ in range(num_convs): self.up_convs.append( nn.Sequential( nn.Conv2d( in_channels, - out_channels, + self.out_channels, kernel_size, padding=kernel_size // 2, bias=False, @@ -70,7 +67,7 @@ def __init__( ), ) ) - in_channels = out_channels + in_channels = self.out_channels def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -219,7 +216,6 @@ def __init__( self.aux_head1 = _SETRUPHead( channels=1024, in_channels=hidden_dim, - out_channels=hidden_dim, num_classes=6, num_convs=4, up_scale=2, @@ -230,3 +226,53 @@ def __init__( conv_norm=None, # Add default value for conv_norm conv_act=None, # Add default value for conv_act ) + + self.aux_head2 = _SETRUPHead( + 
channels=1024, + in_channels=hidden_dim, + num_classes=6, + num_convs=4, + up_scale=2, + kernel_size=3, + align_corners=False, + dropout=0, + norm_layer=norm_layer, + conv_norm=None, # Add default value for conv_norm + conv_act=None, # Add default value for conv_act + ) + + self.aux_head3 = _SETRUPHead( + channels=1024, + in_channels=hidden_dim, + num_classes=6, + num_convs=4, + up_scale=2, + kernel_size=3, + align_corners=False, + dropout=0, + norm_layer=norm_layer, + conv_norm=None, # Add default value for conv_norm + conv_act=None, # Add default value for conv_act + ) + + self.decoder = _SETRUPHead( + channels=1024, + in_channels=hidden_dim, + num_classes=6, + num_convs=4, + up_scale=2, + kernel_size=3, + align_corners=False, + dropout=0, + norm_layer=norm_layer, + conv_norm=None, # Add default value for conv_norm + conv_act=None, # Add default value for conv_act + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.encoder(x) + # x = self.auto_head1(x) + # x = self.auto_head2(x) + # x = self.auto_head3(x) + # x = self.decoder(x) + return x From 1a469e88bc3c9c5cb73fa631e975be9ee134406d Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 5 Mar 2024 14:27:03 -0300 Subject: [PATCH 09/24] testing setr --- tests/models/nets/test_setr.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 tests/models/nets/test_setr.py diff --git a/tests/models/nets/test_setr.py b/tests/models/nets/test_setr.py new file mode 100644 index 0000000..141c920 --- /dev/null +++ b/tests/models/nets/test_setr.py @@ -0,0 +1,10 @@ +import torch + +from sslt.models.nets.setr import _SetR_PUP + +if __name__ == "__main__": + model = _SetR_PUP(2, 3, 4, 5, 6, 7, 8, 9, 10) + print(model) + result = model.forward(torch.zeros(1, 2, 3, 4)) + print(result.shape) + print(result) From c2b7885a733d395a6a5c0e29fc6c41a4878b87c9 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 5 Mar 2024 14:32:57 -0300 Subject: [PATCH 10/24] tests setr --- tests/models/nets/test_setr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/nets/test_setr.py b/tests/models/nets/test_setr.py index 141c920..9de22dd 100644 --- a/tests/models/nets/test_setr.py +++ b/tests/models/nets/test_setr.py @@ -3,7 +3,7 @@ from sslt.models.nets.setr import _SetR_PUP if __name__ == "__main__": - model = _SetR_PUP(2, 3, 4, 5, 6, 7, 8, 9, 10) + model = _SetR_PUP(512, 16, 24, 16, 1, 1, 3) print(model) result = model.forward(torch.zeros(1, 2, 3, 4)) print(result.shape) From 5a39e8629dc82cc06649adfbbe4c2401c84d8ff9 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 5 Mar 2024 14:37:24 -0300 Subject: [PATCH 11/24] testes setr --- tests/models/nets/test_setr.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/models/nets/test_setr.py b/tests/models/nets/test_setr.py index 9de22dd..17bfb4a 100644 --- a/tests/models/nets/test_setr.py +++ b/tests/models/nets/test_setr.py @@ -3,7 +3,15 @@ from sslt.models.nets.setr import _SetR_PUP if __name__ == "__main__": - model = _SetR_PUP(512, 16, 24, 16, 1, 1, 3) + model = _SetR_PUP( + image_size=512, + patch_size=16, + num_layers=24, + num_heads=16, + hidden_dim=2, + mlp_dim=1, + num_classes=3, + ) print(model) result = model.forward(torch.zeros(1, 2, 3, 4)) print(result.shape) From 04ca35d353b94ea6a735f4891bed90cf796cfdde Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 5 Mar 2024 21:18:08 +0000 Subject: [PATCH 12/24] sert integration in progress --- sslt/models/nets/setr.py | 20 ++++++++++---------- {tests 
=> sslt}/models/nets/test_setr.py | 6 ++---- sslt/models/nets/vit.py | 16 ++++++++-------- 3 files changed, 20 insertions(+), 22 deletions(-) rename {tests => sslt}/models/nets/test_setr.py (73%) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 8d08ae6..5c5e51f 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -44,7 +44,7 @@ def __init__( conv_norm if conv_norm is not None else nn.SyncBatchNorm(self.out_channels) ) conv_act = conv_act if conv_act is not None else nn.ReLU() - self.dropout = nn.Dropout2d(dropout) if dropout > 0 is not None else None + self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None self.cls_seg = nn.Conv2d(channels, self.out_channels, 1) self.up_convs = nn.ModuleList() @@ -139,10 +139,10 @@ def __init__( conv_norm if conv_norm is not None else nn.SyncBatchNorm(mla_channels) ) conv_act = conv_act if conv_act is not None else nn.ReLU() - self.dropout = nn.Dropout2d(dropout) if dropout > 0 is not None else None + self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None self.cls_seg = nn.Conv2d(channels, out_channels, 1) - num_inputs = len(self.in_channels) + num_inputs = len(in_channels) self.up_convs = nn.ModuleList() for i in range(num_inputs): @@ -269,10 +269,10 @@ def __init__( conv_act=None, # Add default value for conv_act ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.encoder(x) - # x = self.auto_head1(x) - # x = self.auto_head2(x) - # x = self.auto_head3(x) - # x = self.decoder(x) - return x + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.encoder(x) + # x = self.auto_head1(x) + # x = self.auto_head2(x) + # x = self.auto_head3(x) + # x = self.decoder(x) + return x diff --git a/tests/models/nets/test_setr.py b/sslt/models/nets/test_setr.py similarity index 73% rename from tests/models/nets/test_setr.py rename to sslt/models/nets/test_setr.py index 17bfb4a..4d08348 100644 --- a/tests/models/nets/test_setr.py +++ b/sslt/models/nets/test_setr.py @@ -1,6 +1,6 @@ import torch -from sslt.models.nets.setr import _SetR_PUP +from .setr import _SetR_PUP if __name__ == "__main__": model = _SetR_PUP( @@ -8,11 +8,9 @@ patch_size=16, num_layers=24, num_heads=16, - hidden_dim=2, + hidden_dim=768, mlp_dim=1, num_classes=3, ) - print(model) result = model.forward(torch.zeros(1, 2, 3, 4)) print(result.shape) - print(result) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index 4dce6c8..0ca9f43 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -83,14 +83,14 @@ def __init__( seq_length += 1 self.encoder = Encoder( - seq_length, - num_layers, - num_heads, - hidden_dim, - mlp_dim, - dropout, - attention_dropout, - norm_layer, + seq_length=seq_length, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, ) self.seq_length = seq_length From c5141e07a07a3b9cbe7fa1d6fff30b7bcd60f3ae Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Thu, 7 Mar 2024 10:26:09 -0300 Subject: [PATCH 13/24] progress i guess --- setr net graph.txt | 706 ++++++++++++++++++++++++++++++++++ sslt/models/nets/__init__.py | 1 + sslt/models/nets/setr.py | 11 +- sslt/models/nets/test_setr.py | 16 - sslt/models/nets/vit.py | 6 +- test_setr.py | 18 + 6 files changed, 734 insertions(+), 24 deletions(-) create mode 100644 setr net graph.txt delete mode 100644 sslt/models/nets/test_setr.py create mode 100644 test_setr.py diff --git a/setr net graph.txt b/setr 
net graph.txt new file mode 100644 index 0000000..16a4b5e --- /dev/null +++ b/setr net graph.txt @@ -0,0 +1,706 @@ +EncoderDecoder( + (backbone): VisionTransformer( + (patch_embed): PatchEmbed( + (adap_padding): AdaptivePadding() + (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16)) + ) + (drop_after_pos): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (1): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (2): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (3): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (4): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + 
(out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (5): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (6): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (7): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (8): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, 
inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (9): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (10): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (11): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (12): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (13): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): 
Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (14): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (15): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (16): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (17): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (18): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): 
Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (19): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (20): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (21): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (22): TransformerEncoderLayer( + (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + (23): TransformerEncoderLayer( + (ln1): 
LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (attn): MultiheadAttention( + (attn): MultiheadAttention( + (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) + ) + (proj_drop): Dropout(p=0.0, inplace=False) + (dropout_layer): Dropout(p=0.0, inplace=False) + ) + (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (ffn): FFN( + (activate): GELU() + (layers): Sequential( + (0): Sequential( + (0): Linear(in_features=1024, out_features=4096, bias=True) + (1): GELU() + (2): Dropout(p=0.0, inplace=False) + ) + (1): Linear(in_features=4096, out_features=1024, bias=True) + (2): Dropout(p=0.0, inplace=False) + ) + (dropout_layer): Identity() + ) + ) + ) + ) + init_cfg={'type': 'Pretrained', 'checkpoint': 'pretrain/vit_large_p16.pth'} + (decode_head): SETRUPHead( + input_transform=None, ignore_index=255, align_corners=False + (loss_decode): CrossEntropyLoss(avg_non_ignore=False) + (conv_seg): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) + (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (up_convs): ModuleList( + (0): Sequential( + (0): ConvModule( + (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + (1): Sequential( + (0): ConvModule( + (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + (2): Sequential( + (0): ConvModule( + (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + (3): Sequential( + (0): ConvModule( + (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + ) + ) + init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] + (auxiliary_head): ModuleList( + (0): SETRUPHead( + input_transform=None, ignore_index=255, align_corners=False + (loss_decode): CrossEntropyLoss(avg_non_ignore=False) + (conv_seg): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) + (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (up_convs): ModuleList( + (0): Sequential( + (0): ConvModule( + (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + (1): Sequential( + (0): ConvModule( + (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + ) + ) + init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] + (1): SETRUPHead( + input_transform=None, ignore_index=255, align_corners=False + (loss_decode): CrossEntropyLoss(avg_non_ignore=False) + (conv_seg): 
Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) + (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (up_convs): ModuleList( + (0): Sequential( + (0): ConvModule( + (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + (1): Sequential( + (0): ConvModule( + (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + ) + ) + init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] + (2): SETRUPHead( + input_transform=None, ignore_index=255, align_corners=False + (loss_decode): CrossEntropyLoss(avg_non_ignore=False) + (conv_seg): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) + (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) + (up_convs): ModuleList( + (0): Sequential( + (0): ConvModule( + (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + (1): Sequential( + (0): ConvModule( + (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (activate): ReLU(inplace=True) + ) + (1): Upsample() + ) + ) + ) + init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] + ) +) \ No newline at end of file diff --git a/sslt/models/nets/__init__.py b/sslt/models/nets/__init__.py index e69de29..dd6875f 100644 --- a/sslt/models/nets/__init__.py +++ b/sslt/models/nets/__init__.py @@ -0,0 +1 @@ +from .setr import _SetR_PUP diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 5c5e51f..509620a 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -6,7 +6,7 @@ from sslt.models.nets.base import SimpleSupervisedModel from sslt.models.nets.vit import _VisionTransformerBackbone -from sslt.utils.upsample import Upsample +from sslt.utils.upsample import Upsample, resize class _SETRUPHead(nn.Module): @@ -69,12 +69,9 @@ def __init__( ) in_channels = self.out_channels - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x): - n, c, h, w = x.shape - x = x.reshape(n, c, h * w).transpose(2, 1).contiguous() x = self.norm(x) - x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() for up_conv in self.up_convs: x = up_conv(x) @@ -214,7 +211,7 @@ def __init__( ) self.aux_head1 = _SETRUPHead( - channels=1024, + channels=16, in_channels=hidden_dim, num_classes=6, num_convs=4, @@ -271,7 +268,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.encoder(x) - # x = self.auto_head1(x) + x = self.aux_head1(x) # x = self.auto_head2(x) # x = self.auto_head3(x) # x = self.decoder(x) diff --git a/sslt/models/nets/test_setr.py b/sslt/models/nets/test_setr.py deleted file mode 100644 index 4d08348..0000000 --- a/sslt/models/nets/test_setr.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - -from .setr import _SetR_PUP - -if __name__ == "__main__": - model = _SetR_PUP( - image_size=512, - 
patch_size=16, - num_layers=24, - num_heads=16, - hidden_dim=768, - mlp_dim=1, - num_classes=3, - ) - result = model.forward(torch.zeros(1, 2, 3, 4)) - print(result.shape) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index 0ca9f43..cb6c015 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -155,6 +155,10 @@ def forward(self, x: torch.Tensor): x = self.encoder(x) # Classifier "token" as used by standard language architectures - x = x[:, 0] + x = x[:, 1:] + + x = x.permute(1, 0, 2) + + print(x.shape) return x diff --git a/test_setr.py b/test_setr.py new file mode 100644 index 0000000..b5a5e3d --- /dev/null +++ b/test_setr.py @@ -0,0 +1,18 @@ +import torch + +from sslt.models.nets.setr import _SetR_PUP + +if __name__ == "__main__": + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = _SetR_PUP( + image_size=16, + patch_size=16, + num_layers=24, + num_heads=16, + hidden_dim=16, + mlp_dim=1, + num_classes=3, + ) + model.to(device) + result = model.forward(torch.zeros(16, 3, 16, 16).to(device)) + print(result.shape) From caf20c3c283f85146b544ab99723840116a85c6e Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Fri, 8 Mar 2024 11:40:19 -0300 Subject: [PATCH 14/24] new netgraph and modifications --- new setr net graph.txt | Bin 0 -> 41638 bytes sslt/models/nets/setr.py | 32 +++++++++++++++++--------------- sslt/models/nets/vit.py | 4 ---- test_setr.py | 6 +++--- 4 files changed, 20 insertions(+), 22 deletions(-) create mode 100644 new setr net graph.txt diff --git a/new setr net graph.txt b/new setr net graph.txt new file mode 100644 index 0000000000000000000000000000000000000000..1193795260b751d067e3771e4eaf0299227f4798 GIT binary patch literal 41638 zcmeHQ+iu%N5S`}&{Rh952I?Y~6ZZn>q0Xg!aDvpAz6!EUJFX>Kk>n)y$G1JR!_jIf zQm$>Kppyea+9kQWoS8jyI9zhcfBx>tzT`5MJ?Y6?Im9oMZ{-_V!;)c7A;)qlC&-Ou zU4FsVp8P52GQsZ>dk?U^k8~oZSRczUl-$9#W4Vx{P$r*i)7sTG3;BSyMyPWsf0b=N zFSmDa#7ws2dmL|5o=8WYhWtfny}VUw2yYt8dmEhNI!zqS2&v4=S8wDgN9vb@B<`Nsm{^n2-2&4RIXRZ4Ipcd+HUvQ zzLvYjzA;Kz>)KoIaWw6jDO%Ioi;|giY47#IQLVKeB(nBW9d}UXGW1-1o7V0C$L>KN zYhiaL;~Ugf1Z1HlMx&_^xNLyM=+r zI>2%*f8eOlwR#>vA6V{veN&#~odsX1G?- ztK`phof(hAc-%)Cmc^8e$2B&uzZU)4<8i-!A&-;ComV|2k8@nbv5=-2M!4-tW%XkX zdL$i7_t9F8tC;PV$KyWsLmnrOr@pG*#CV+JDpgb%&yvU6HLmikJnxam$>UXzC66;6 zuVTX18IMm9<=p?t6}M{rNO=9$Z=@A;d6&oMEE*+ylp)G6M39GP9DbwT6Z^+|&ledq_Ij56z^Q-TIaGxcU zlgS+mX-Xzn4B@WVOH-+;%$>dEw%b1)o`QkZ_Om6QQ zZX1(xh9{YvOuoohjLDl&K_+h}lXH$InY_to8Iv<6Up&W=$=k@}9nSIOS34KJ66@Ye zCMT0S7Sfb=Jll51vorg8GM^{s^W>?wa%K=`22~N^>tyowF}dG^CX$lS6qczhh4y`?F+RH$$xfj@}Cr zoRu}7qoqPVKoc$B!Tcwa+x{=YF$(+-LcMdjLB1QbSDp2(Vwz5%+ZfuYOzGYzx#7Af zE$Cm4y^m#|J=n4ttUZE8YDd+6Qd;A6uo?ndE9kCK#d5pr-cdloNkG|IrC)UaQIJ?M`zlDT$V~f7B{%8) z{c-#7gaUTkL5(TWcXK`w(NpyW>#-rM>z;6hTo-n9@2>X;l~R$_$B{13U+N_V7WJF6 zfXHc(Lv^V4p{Ri~>R#yEL-bRVt{$@;29NGT(osM}U5&*ylo>W=-$veo2A}YII8vf@ zV~#j6-h6C0iKvOW3`o8LR(kL$<8Sx5Rk8a#*O1?r^ZQ0Px9QgbJ1g_xdHc`jAEy`3 zYw)%5ee~X#Tl8Z;^Q_8`BmdKjmwT~AnGZkr&SO1V|8e?ll>ck>+G<4m&=@{iH0Fqn tFF7*vD+Dt7-9|Kc-OcOnx~phoWcRB3Uw5zSd>?r{=E&$@2BTY~{x1x%gSh|z literal 0 HcmV?d00001 diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 509620a..1006da0 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -39,13 +39,14 @@ def __init__( self.num_classes = num_classes self.out_channels = channels self.threshold = threshold - self.norm = norm_layer if norm_layer is not None else nn.SyncBatchNorm(channels) + self.cls_seg = nn.Conv2d(channels, self.num_classes, 1) + self.norm = norm_layer if norm_layer is not None else 
nn.LayerNorm(in_channels) conv_norm = ( conv_norm if conv_norm is not None else nn.SyncBatchNorm(self.out_channels) ) conv_act = conv_act if conv_act is not None else nn.ReLU() self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None - self.cls_seg = nn.Conv2d(channels, self.out_channels, 1) + self.up_convs = nn.ModuleList() for _ in range(num_convs): @@ -74,6 +75,7 @@ def forward(self, x): x = self.norm(x) for up_conv in self.up_convs: + print(x.shape) x = up_conv(x) if self.dropout is not None: @@ -210,8 +212,8 @@ def __init__( dropout=dropout, ) - self.aux_head1 = _SETRUPHead( - channels=16, + self.decoder = _SETRUPHead( + channels=256, in_channels=hidden_dim, num_classes=6, num_convs=4, @@ -224,11 +226,11 @@ def __init__( conv_act=None, # Add default value for conv_act ) - self.aux_head2 = _SETRUPHead( + self.aux_head1 = _SETRUPHead( channels=1024, in_channels=hidden_dim, num_classes=6, - num_convs=4, + num_convs=2, up_scale=2, kernel_size=3, align_corners=False, @@ -238,11 +240,11 @@ def __init__( conv_act=None, # Add default value for conv_act ) - self.aux_head3 = _SETRUPHead( - channels=1024, + self.aux_head2 = _SETRUPHead( + channels=256, in_channels=hidden_dim, num_classes=6, - num_convs=4, + num_convs=2, up_scale=2, kernel_size=3, align_corners=False, @@ -252,11 +254,11 @@ def __init__( conv_act=None, # Add default value for conv_act ) - self.decoder = _SETRUPHead( - channels=1024, + self.aux_head3 = _SETRUPHead( + channels=256, in_channels=hidden_dim, num_classes=6, - num_convs=4, + num_convs=2, up_scale=2, kernel_size=3, align_corners=False, @@ -269,7 +271,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.encoder(x) x = self.aux_head1(x) - # x = self.auto_head2(x) - # x = self.auto_head3(x) - # x = self.decoder(x) + x = self.auto_head2(x) + x = self.auto_head3(x) + x = self.decoder(x) return x diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index cb6c015..cc1b288 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -157,8 +157,4 @@ def forward(self, x: torch.Tensor): # Classifier "token" as used by standard language architectures x = x[:, 1:] - x = x.permute(1, 0, 2) - - print(x.shape) - return x diff --git a/test_setr.py b/test_setr.py index b5a5e3d..99ea1dd 100644 --- a/test_setr.py +++ b/test_setr.py @@ -5,14 +5,14 @@ if __name__ == "__main__": device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = _SetR_PUP( - image_size=16, + image_size=512, patch_size=16, num_layers=24, num_heads=16, - hidden_dim=16, + hidden_dim=1024, mlp_dim=1, num_classes=3, ) model.to(device) - result = model.forward(torch.zeros(16, 3, 16, 16).to(device)) + result = model.forward(torch.zeros(1, 3, 512, 512).to(device)) print(result.shape) From de62da28ed3fa620d72a62510ad1e3098310ea17 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Fri, 8 Mar 2024 14:41:10 -0300 Subject: [PATCH 15/24] aux_heads fix --- sslt/models/nets/setr.py | 12 +++++++----- test_setr.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 1006da0..59e91f4 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -268,10 +268,12 @@ def __init__( conv_act=None, # Add default value for conv_act ) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: x = self.encoder(x) - x = self.aux_head1(x) - x = self.auto_head2(x) - x = 
self.auto_head3(x) + x_aux1 = self.aux_head1(x) + x_aux2 = self.aux_head2(x) + x_aux3 = self.aux_head3(x) x = self.decoder(x) - return x + return x, x_aux1, x_aux2, x_aux3 diff --git a/test_setr.py b/test_setr.py index 99ea1dd..c1063fe 100644 --- a/test_setr.py +++ b/test_setr.py @@ -15,4 +15,4 @@ ) model.to(device) result = model.forward(torch.zeros(1, 3, 512, 512).to(device)) - print(result.shape) + print(result[0].shape) From 8b8012c3f44626a10df7ef1eb300cc949901059e Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Sat, 9 Mar 2024 17:29:03 -0300 Subject: [PATCH 16/24] proregress i gues?? --- sslt/models/nets/setr.py | 10 +++++----- sslt/models/nets/vit.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 59e91f4..ca7f1a6 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -213,7 +213,7 @@ def __init__( ) self.decoder = _SETRUPHead( - channels=256, + channels=1024, in_channels=hidden_dim, num_classes=6, num_convs=4, @@ -272,8 +272,8 @@ def forward( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: x = self.encoder(x) - x_aux1 = self.aux_head1(x) - x_aux2 = self.aux_head2(x) - x_aux3 = self.aux_head3(x) + # x_aux1 = self.aux_head1(x) + # x_aux2 = self.aux_head2(x) + # x_aux3 = self.aux_head3(x) x = self.decoder(x) - return x, x_aux1, x_aux2, x_aux3 + return x, torch.zeros(1), torch.zeros(1), torch.zeros(1) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index cc1b288..ea4c2ec 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -116,7 +116,7 @@ def __init__( if self.conv_proj.conv_last.bias is not None: nn.init.zeros_(self.conv_proj.conv_last.bias) - def _process_input(self, x: torch.Tensor) -> torch.Tensor: + def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: n, c, h, w = x.shape p = self.patch_size torch._assert( @@ -141,11 +141,11 @@ def _process_input(self, x: torch.Tensor) -> torch.Tensor: # embedding dimension x = x.permute(0, 2, 1) - return x + return x, n_h, n_w def forward(self, x: torch.Tensor): # Reshape and permute the input tensor - x = self._process_input(x) + x, n_h, n_w = self._process_input(x) n = x.shape[0] # Expand the class token to the full batch From e660417ef2d91650362f6159d7b17ff2bc52bb97 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 12 Mar 2024 11:35:42 -0300 Subject: [PATCH 17/24] VAI [redacted]!!!! 
FORWARD PRONTO --- sslt/models/nets/setr.py | 5 ++++- sslt/models/nets/vit.py | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index ca7f1a6..71433ad 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -72,10 +72,13 @@ def __init__( def forward(self, x): + n, c, h, w = x.shape + + x = x.reshape(n, c, h * w).transpose(1, 2).contiguous() x = self.norm(x) + x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() for up_conv in self.up_convs: - print(x.shape) x = up_conv(x) if self.dropout is not None: diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index ea4c2ec..ee5bc03 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -157,4 +157,8 @@ def forward(self, x: torch.Tensor): # Classifier "token" as used by standard language architectures x = x[:, 1:] + B, _, C = x.shape + + x = x.reshape(B, n_h, n_w, C).permute(0, 3, 1, 2).contiguous() + return x From 0aadc12fecc95999175a041b8f3d68187aa7885b Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 12 Mar 2024 14:25:15 -0300 Subject: [PATCH 18/24] add independent height and width option --- sslt/models/nets/setr.py | 42 ++++++++++++++++++++++++++++++---- sslt/models/nets/vit.py | 49 +++++++++++++++++++++++++++++----------- test_setr.py | 4 ++-- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 71433ad..8582977 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -1,6 +1,7 @@ import warnings from typing import Optional, Tuple +import lightning as L import torch from torch import nn @@ -191,7 +192,7 @@ class _SetR_PUP(nn.Module): def __init__( self, - image_size: int, + image_size: int | tuple[int, int], patch_size: int, num_layers: int, num_heads: int, @@ -275,8 +276,39 @@ def forward( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: x = self.encoder(x) - # x_aux1 = self.aux_head1(x) - # x_aux2 = self.aux_head2(x) - # x_aux3 = self.aux_head3(x) + x_aux1 = self.aux_head1(x) + x_aux2 = self.aux_head2(x) + x_aux3 = self.aux_head3(x) x = self.decoder(x) - return x, torch.zeros(1), torch.zeros(1), torch.zeros(1) + return x, x_aux1, x_aux2, x_aux3 + + +class SETR_PUP(L.LightningDataModule): + + def __init__( + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + num_classes: int, + dropout: float = 0.1, + attention_dropout: float = 0.1, + norm_layer: Optional[nn.Module] = None, + interpolate_mode: str = "bilinear", + ): + self.model = _SetR_PUP( + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + num_classes=num_classes, + dropout=dropout, + attention_dropout=attention_dropout, + norm_layer=norm_layer, + interpolate_mode=interpolate_mode, + ) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index ee5bc03..94acbd6 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -18,7 +18,7 @@ class _VisionTransformerBackbone(nn.Module): def __init__( self, - image_size: int, + image_size: int | tuple[int, int], patch_size: int, num_layers: int, num_heads: int, @@ -32,9 +32,17 @@ def __init__( ): super().__init__() _log_api_usage_once(self) - torch._assert( - image_size % patch_size == 0, "Input shape indivisible by patch size!" 
- ) + + if isinstance(image_size, int): + torch._assert( + image_size % patch_size == 0, "Input shape indivisible by patch size!" + ) + elif isinstance(image_size, tuple): + torch._assert( + image_size[0] % patch_size == 0 and image_size[1] % patch_size == 0, + "Input shape indivisible by patch size!", + ) + self.image_size = image_size self.patch_size = patch_size self.hidden_dim = hidden_dim @@ -76,7 +84,10 @@ def __init__( stride=patch_size, ) - seq_length = (image_size // patch_size) ** 2 + if isinstance(image_size, int): + seq_length = (image_size // patch_size) ** 2 + elif isinstance(image_size, tuple): + seq_length = (image_size[0] // patch_size) * (image_size[1] // patch_size) # Add a class token self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim)) @@ -119,14 +130,26 @@ def __init__( def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: n, c, h, w = x.shape p = self.patch_size - torch._assert( - h == self.image_size, - f"Wrong image height! Expected {self.image_size} but got {h}!", - ) - torch._assert( - w == self.image_size, - f"Wrong image width! Expected {self.image_size} but got {w}!", - ) + + if isinstance(self.image_size, int): + torch._assert( + h == self.image_size, + f"Wrong image height! Expected {self.image_size} but got {h}!", + ) + torch._assert( + w == self.image_size, + f"Wrong image width! Expected {self.image_size} but got {w}!", + ) + elif isinstance(self.image_size, tuple): + torch._assert( + h == self.image_size[0], + f"Wrong image height! Expected {self.image_size[0]} but got {h}!", + ) + torch._assert( + w == self.image_size[1], + f"Wrong image width! Expected {self.image_size[1]} but got {w}!", + ) + n_h = h // p n_w = w // p diff --git a/test_setr.py b/test_setr.py index c1063fe..6123e17 100644 --- a/test_setr.py +++ b/test_setr.py @@ -5,7 +5,7 @@ if __name__ == "__main__": device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = _SetR_PUP( - image_size=512, + image_size=(16, 32), patch_size=16, num_layers=24, num_heads=16, @@ -14,5 +14,5 @@ num_classes=3, ) model.to(device) - result = model.forward(torch.zeros(1, 3, 512, 512).to(device)) + result = model.forward(torch.zeros(1, 3, 16, 32).to(device)) print(result[0].shape) From aa148543d216c2f42e630f537f7644e2f42a21d1 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 12 Mar 2024 17:55:41 -0300 Subject: [PATCH 19/24] Refactor SETR_PUP model and fix image size validation --- sslt/models/nets/setr.py | 131 ++++++++++++++++++++++++++++++--------- sslt/models/nets/vit.py | 2 + 2 files changed, 103 insertions(+), 30 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 8582977..1474d7a 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -5,7 +5,6 @@ import torch from torch import nn -from sslt.models.nets.base import SimpleSupervisedModel from sslt.models.nets.vit import _VisionTransformerBackbone from sslt.utils.upsample import Upsample, resize @@ -20,11 +19,11 @@ class _SETRUPHead(nn.Module): def __init__( self, channels: int, - norm_layer: Optional[nn.Module], - conv_norm: Optional[nn.Module], - conv_act: Optional[nn.Module], in_channels: int, num_classes: int, + norm_layer: Optional[nn.Module] = None, + conv_norm: Optional[nn.Module] = None, + conv_act: Optional[nn.Module] = None, num_convs: int = 1, up_scale: int = 4, kernel_size: int = 3, @@ -198,9 +197,13 @@ def __init__( num_heads: int, hidden_dim: int, mlp_dim: int, + num_convs: int, num_classes: int, - dropout: float = 0.1, - 
attention_dropout: float = 0.1, + decoder_channels: int, + up_scale: int = 2, + encoder_dropout: float = 0.1, + kernel_size: int = 3, + decoder_dropout: float = 0.1, norm_layer: Optional[nn.Module] = None, interpolate_mode: str = "bilinear", ): @@ -213,21 +216,19 @@ def __init__( hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_classes=num_classes, - dropout=dropout, + dropout=encoder_dropout, ) self.decoder = _SETRUPHead( - channels=1024, + channels=decoder_channels, in_channels=hidden_dim, - num_classes=6, - num_convs=4, - up_scale=2, - kernel_size=3, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, align_corners=False, - dropout=0, + dropout=decoder_dropout, norm_layer=norm_layer, - conv_norm=None, # Add default value for conv_norm - conv_act=None, # Add default value for conv_act ) self.aux_head1 = _SETRUPHead( @@ -240,8 +241,6 @@ def __init__( align_corners=False, dropout=0, norm_layer=norm_layer, - conv_norm=None, # Add default value for conv_norm - conv_act=None, # Add default value for conv_act ) self.aux_head2 = _SETRUPHead( @@ -254,8 +253,6 @@ def __init__( align_corners=False, dropout=0, norm_layer=norm_layer, - conv_norm=None, # Add default value for conv_norm - conv_act=None, # Add default value for conv_act ) self.aux_head3 = _SETRUPHead( @@ -268,22 +265,20 @@ def __init__( align_corners=False, dropout=0, norm_layer=norm_layer, - conv_norm=None, # Add default value for conv_norm - conv_act=None, # Add default value for conv_act ) def forward( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: x = self.encoder(x) - x_aux1 = self.aux_head1(x) - x_aux2 = self.aux_head2(x) - x_aux3 = self.aux_head3(x) + # x_aux1 = self.aux_head1(x) + # x_aux2 = self.aux_head2(x) + # x_aux3 = self.aux_head3(x) x = self.decoder(x) - return x, x_aux1, x_aux2, x_aux3 + return x, torch.zeros(1), torch.zeros(1), torch.zeros(1) -class SETR_PUP(L.LightningDataModule): +class SETR_PUP(L.LightningModule): def __init__( self, @@ -294,11 +289,16 @@ def __init__( hidden_dim: int, mlp_dim: int, num_classes: int, - dropout: float = 0.1, - attention_dropout: float = 0.1, + num_convs: int, + encoder_dropout: float = 0.1, + decoder_dropout: float = 0.1, + decoder_channels: int = 1024, norm_layer: Optional[nn.Module] = None, interpolate_mode: str = "bilinear", + loss_fn: Optional[nn.Module] = None, ): + super().__init__() + self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss() self.model = _SetR_PUP( image_size=image_size, patch_size=patch_size, @@ -307,8 +307,79 @@ def __init__( hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_classes=num_classes, - dropout=dropout, - attention_dropout=attention_dropout, + num_convs=num_convs, + decoder_channels=decoder_channels, + encoder_dropout=encoder_dropout, + decoder_dropout=decoder_dropout, norm_layer=norm_layer, interpolate_mode=interpolate_mode, ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + def _loss_func(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Calculate the loss between the output and the input data. + + Parameters + ---------- + y_hat : torch.Tensor + The output data from the forward pass. + y : torch.Tensor + The input data/label. + + Returns + ------- + torch.Tensor + The loss value. + """ + loss = self.loss_fn(y_hat, y) + return loss + + def _single_step( + self, batch: torch.Tensor, batch_idx: int, step_name: str + ) -> torch.Tensor: + """Perform a single step of the training/validation loop. 
+ + Parameters + ---------- + batch : torch.Tensor + The input data. + batch_idx : int + The index of the batch. + step_name : str + The name of the step, either "train" or "val". + + Returns + ------- + torch.Tensor + The loss value. + """ + x, y = batch + y_hat = self.model(x) + loss = self._loss_func(y_hat, y) + self.log( + f"{step_name}_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def training_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "train") + + def validation_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "val") + + def test_step(self, batch: torch.Tensor, batch_idx: int): + return self._single_step(batch, batch_idx, "test") + + def predict_step(self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int): + x, _ = batch + return self.model(x) + + def configure_optimizers(self): + return torch.optim.Adam(self.model.parameters(), lr=1e-3) diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index 94acbd6..9fc85e4 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -149,6 +149,8 @@ def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: w == self.image_size[1], f"Wrong image width! Expected {self.image_size[1]} but got {w}!", ) + else: + raise ValueError("Invalid image size type!") n_h = h // p n_w = w // p From 2e5eeaefbd9ac8b1a6ad5b0c0724c05f3a083e8d Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 12 Mar 2024 22:53:03 -0300 Subject: [PATCH 20/24] Update SETR_PUP model and add docstrings to Vision Transformer Backbone class --- sslt/models/nets/setr.py | 260 +++++++++++++++++++++++++++++++-------- sslt/models/nets/vit.py | 48 ++++++++ test_setr.py | 11 +- 3 files changed, 263 insertions(+), 56 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 1474d7a..034298d 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -21,30 +21,61 @@ def __init__( channels: int, in_channels: int, num_classes: int, - norm_layer: Optional[nn.Module] = None, - conv_norm: Optional[nn.Module] = None, - conv_act: Optional[nn.Module] = None, - num_convs: int = 1, - up_scale: int = 4, - kernel_size: int = 3, - align_corners: bool = True, - dropout: float = 0.1, - threshold: Optional[float] = None, + norm_layer: nn.Module, + conv_norm: nn.Module, + conv_act: nn.Module, + num_convs: int, + up_scale: int, + kernel_size: int, + align_corners: bool, + dropout: float, + interpolate_mode: str, ): + """ + Initializes the SETR model. + Parameters + ---------- + channels : int + Number of output channels. + in_channels : int + Number of input channels. + num_classes : int + Number of output classes. + norm_layer : nn.Module + Normalization layer. + conv_norm : nn.Module + Convolutional normalization layer. + conv_act : nn.Module + Convolutional activation layer. + num_convs : int + Number of convolutional layers. + up_scale : int + Upsampling scale factor. + kernel_size : int + Kernel size for convolutional layers. + align_corners : bool + Whether to align corners during upsampling. + dropout : float + Dropout rate. + interpolate_mode : str + Interpolation mode for upsampling. + + Raises + ------ + AssertionError + If kernel_size is not 1 or 3. + """ assert kernel_size in [1, 3], "kernel_size must be 1 or 3." 
super().__init__() self.num_classes = num_classes self.out_channels = channels - self.threshold = threshold self.cls_seg = nn.Conv2d(channels, self.num_classes, 1) - self.norm = norm_layer if norm_layer is not None else nn.LayerNorm(in_channels) - conv_norm = ( - conv_norm if conv_norm is not None else nn.SyncBatchNorm(self.out_channels) - ) - conv_act = conv_act if conv_act is not None else nn.ReLU() + self.norm = norm_layer + conv_norm = conv_norm + conv_act = conv_act self.dropout = nn.Dropout2d(dropout) if dropout > 0 != None else None self.up_convs = nn.ModuleList() @@ -63,7 +94,7 @@ def __init__( conv_act, Upsample( scale_factor=up_scale, - mode="bilinear", + mode=interpolate_mode, align_corners=align_corners, ), ) @@ -200,13 +231,59 @@ def __init__( num_convs: int, num_classes: int, decoder_channels: int, - up_scale: int = 2, - encoder_dropout: float = 0.1, - kernel_size: int = 3, - decoder_dropout: float = 0.1, - norm_layer: Optional[nn.Module] = None, - interpolate_mode: str = "bilinear", + up_scale: int, + encoder_dropout: float, + kernel_size: int, + decoder_dropout: float, + norm_layer: nn.Module, + interpolate_mode: str, + conv_norm: nn.Module, + conv_act: nn.Module, + align_corners: bool, ): + """ + Initializes the SETR PUP model. + + Parameters + ---------- + image_size : int or tuple[int, int] + The size of the input image. + patch_size : int + The size of each patch in the input image. + num_layers : int + The number of layers in the transformer encoder. + num_heads : int + The number of attention heads in the transformer encoder. + hidden_dim : int + The hidden dimension of the transformer encoder. + mlp_dim : int + The dimension of the feed-forward network in the transformer encoder. + num_convs : int + The number of convolutional layers in the decoder. + num_classes : int + The number of output classes. + decoder_channels : int + The number of channels in the decoder. + up_scale : int + The scale factor for upsampling in the decoder. + encoder_dropout : float + The dropout rate for the transformer encoder. + kernel_size : int + The kernel size for the convolutional layers in the decoder. + decoder_dropout : float + The dropout rate for the decoder. + norm_layer : nn.Module + The normalization layer to be used. + interpolate_mode : str + The mode for interpolation during upsampling. + conv_norm : nn.Module + The normalization layer to be used in the decoder convolutional layers. + conv_act : nn.Module + The activation function to be used in the decoder convolutional layers. + align_corners : bool + Whether to align corners during upsampling. 
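# [Editorial aside - illustrative sketch, not part of the patch series.]
# The PUP ("progressive upsampling") head rebuilds the input resolution in
# `num_convs` stages, each stage in _SETRUPHead.up_convs upsampling by
# `up_scale`. A stand-alone sanity check of that arithmetic, using a
# hypothetical helper and the values used later in this series
# (patch_size=16, up_scale=2, num_convs=4):
def pup_output_side(image_side: int, patch_size: int = 16,
                    up_scale: int = 2, num_convs: int = 4) -> int:
    # The ViT backbone yields one token per patch, i.e. a feature map of side
    # image_side // patch_size; every up-conv stage multiplies it by up_scale.
    return (image_side // patch_size) * up_scale ** num_convs

assert pup_output_side(512) == 512  # 32 -> 64 -> 128 -> 256 -> 512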
+ + """ super().__init__() self.encoder = _VisionTransformerBackbone( image_size=image_size, @@ -226,44 +303,56 @@ def __init__( num_convs=num_convs, up_scale=up_scale, kernel_size=kernel_size, - align_corners=False, + align_corners=align_corners, dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, norm_layer=norm_layer, ) self.aux_head1 = _SETRUPHead( - channels=1024, + channels=decoder_channels, in_channels=hidden_dim, - num_classes=6, - num_convs=2, - up_scale=2, - kernel_size=3, - align_corners=False, - dropout=0, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + align_corners=align_corners, + dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, norm_layer=norm_layer, ) self.aux_head2 = _SETRUPHead( - channels=256, + channels=decoder_channels, in_channels=hidden_dim, - num_classes=6, - num_convs=2, - up_scale=2, - kernel_size=3, - align_corners=False, - dropout=0, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + align_corners=align_corners, + dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, norm_layer=norm_layer, ) self.aux_head3 = _SETRUPHead( - channels=256, + channels=decoder_channels, in_channels=hidden_dim, - num_classes=6, - num_convs=2, - up_scale=2, - kernel_size=3, - align_corners=False, - dropout=0, + num_classes=num_classes, + num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + align_corners=align_corners, + dropout=decoder_dropout, + conv_norm=conv_norm, + conv_act=conv_act, + interpolate_mode=interpolate_mode, norm_layer=norm_layer, ) @@ -282,23 +371,81 @@ class SETR_PUP(L.LightningModule): def __init__( self, - image_size: int, - patch_size: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - num_classes: int, - num_convs: int, + image_size: int | tuple[int, int] = 512, + patch_size: int = 16, + num_layers: int = 24, + num_heads: int = 16, + hidden_dim: int = 1024, + mlp_dim: int = 4096, encoder_dropout: float = 0.1, - decoder_dropout: float = 0.1, - decoder_channels: int = 1024, + num_classes: int = 1000, norm_layer: Optional[nn.Module] = None, + decoder_channels: int = 256, + num_convs: int = 4, + up_scale: int = 2, + kernel_size: int = 3, + align_corners: bool = False, + decoder_dropout: float = 0.1, + conv_norm: Optional[nn.Module] = None, + conv_act: Optional[nn.Module] = None, interpolate_mode: str = "bilinear", loss_fn: Optional[nn.Module] = None, ): + """ + Initializes the SetR model. + + Parameters + ---------- + image_size : int or tuple[int, int] + The input image size. Defaults to 512. + patch_size : int + The size of each patch. Defaults to 16. + num_layers : int + The number of layers in the transformer encoder. Defaults to 24. + num_heads : int + The number of attention heads in the transformer encoder. Defaults to 16. + hidden_dim : int + The hidden dimension of the transformer encoder. Defaults to 1024. + mlp_dim : int + The dimension of the MLP layers in the transformer encoder. Defaults to 4096. + encoder_dropout : float + The dropout rate for the transformer encoder. Defaults to 0.1. + num_classes : int + The number of output classes. Defaults to 1000. + norm_layer : nn.Module, optional + The normalization layer to be used in the decoder. Defaults to None. + decoder_channels : int + The number of channels in the decoder. Defaults to 256. 
+ num_convs : int + The number of convolutional layers in the decoder. Defaults to 4. + up_scale : int + The scale factor for upsampling in the decoder. Defaults to 2. + kernel_size : int + The kernel size for convolutional layers in the decoder. Defaults to 3. + align_corners : bool + Whether to align corners during interpolation in the decoder. Defaults to False. + decoder_dropout : float + The dropout rate for the decoder. Defaults to 0.1. + conv_norm : nn.Module, optional + The normalization layer to be used in the convolutional layers of the decoder. Defaults to None. + conv_act : nn.Module, optional + The activation function to be used in the convolutional layers of the decoder. Defaults to None. + interpolate_mode : str + The interpolation mode for upsampling in the decoder. Defaults to "bilinear". + loss_fn : nn.Module, optional + The loss function to be used during training. Defaults to None. + + """ super().__init__() self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss() + norm_layer = ( + norm_layer if norm_layer is not None else nn.LayerNorm(decoder_channels) + ) + conv_norm = ( + conv_norm if conv_norm is not None else nn.SyncBatchNorm(self.out_channels) + ) + conv_act = conv_act if conv_act is not None else nn.ReLU() + self.model = _SetR_PUP( image_size=image_size, patch_size=patch_size, @@ -308,11 +455,16 @@ def __init__( mlp_dim=mlp_dim, num_classes=num_classes, num_convs=num_convs, + up_scale=up_scale, + kernel_size=kernel_size, + conv_norm=conv_norm, + conv_act=conv_act, decoder_channels=decoder_channels, encoder_dropout=encoder_dropout, decoder_dropout=decoder_dropout, norm_layer=norm_layer, interpolate_mode=interpolate_mode, + align_corners=align_corners, ) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/sslt/models/nets/vit.py b/sslt/models/nets/vit.py index 9fc85e4..bb613e4 100644 --- a/sslt/models/nets/vit.py +++ b/sslt/models/nets/vit.py @@ -30,6 +30,38 @@ def __init__( norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), conv_stem_configs: Optional[List[ConvStemConfig]] = None, ): + """ + Initializes a Vision Transformer (ViT) model. + + Parameters + ---------- + image_size : int or tuple[int, int] + The size of the input image. If an int is provided, it is assumed + to be a square image. If a tuple of ints is provided, it represents the height and width of the image. + patch_size : int + The size of each patch in the image. + num_layers : int + The number of transformer layers in the model. + num_heads : int + The number of attention heads in the transformer layers. + hidden_dim : int + The dimensionality of the hidden layers in the transformer. + mlp_dim : int + The dimensionality of the feed-forward MLP layers in the transformer. + dropout : float, optional + The dropout rate to apply. Defaults to 0.0. + attention_dropout : float, optional + The dropout rate to apply to the attention weights. Defaults to 0.0. + num_classes : int, optional + The number of output classes. Defaults to 1000. + norm_layer : Callable[..., torch.nn.Module], optional + The normalization layer to use. Defaults to nn.LayerNorm with epsilon=1e-6. + conv_stem_configs : List[ConvStemConfig], optional + The configuration for the convolutional stem layers. + If provided, the input image will be processed by these convolutional layers before being passed to + the transformer. Defaults to None. 
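# [Editorial aside - illustrative sketch, not part of the patch series.]
# The backbone turns an (N, C, H, W) image into one token per patch and, after
# dropping the class token, folds the sequence back into an
# (N, hidden_dim, H // p, W // p) feature map for the decoder heads. A minimal
# stand-alone rehearsal of that bookkeeping for a non-square input (shapes are
# the only point here; the values are zeros):
import torch

n, hidden_dim, h, w, p = 2, 1024, 16, 32, 16
n_h, n_w = h // p, w // p
tokens = torch.zeros(n, n_h * n_w, hidden_dim)  # encoder output, class token already removed
feature_map = tokens.reshape(n, n_h, n_w, hidden_dim).permute(0, 3, 1, 2).contiguous()
assert feature_map.shape == (2, 1024, 1, 2)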
+ + """ super().__init__() _log_api_usage_once(self) @@ -128,6 +160,14 @@ def __init__( nn.init.zeros_(self.conv_proj.conv_last.bias) def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: + """Process the input tensor and return the reshaped tensor and dimensions. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, and number of columns. + """ n, c, h, w = x.shape p = self.patch_size @@ -169,6 +209,14 @@ def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: return x, n_h, n_w def forward(self, x: torch.Tensor): + """Forward pass of the Vision Transformer Backbone. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ # Reshape and permute the input tensor x, n_h, n_w = self._process_input(x) n = x.shape[0] diff --git a/test_setr.py b/test_setr.py index 6123e17..372070f 100644 --- a/test_setr.py +++ b/test_setr.py @@ -1,10 +1,10 @@ import torch -from sslt.models.nets.setr import _SetR_PUP +from sslt.models.nets.setr import SETR_PUP if __name__ == "__main__": device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = _SetR_PUP( + model = SETR_PUP( image_size=(16, 32), patch_size=16, num_layers=24, @@ -12,6 +12,13 @@ hidden_dim=1024, mlp_dim=1, num_classes=3, + num_convs=4, + decoder_channels=256, + up_scale=4, + encoder_dropout=0.1, + kernel_size=3, + decoder_dropout=0.1, + conv_act=torch.nn.ReLU(inplace=True), ) model.to(device) result = model.forward(torch.zeros(1, 3, 16, 32).to(device)) From 091afbec8d71beb39b4da17dbdc2ce657195c376 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Tue, 12 Mar 2024 23:09:59 -0300 Subject: [PATCH 21/24] Refactor SETR_PUP model and update test_setr.py --- sslt/models/nets/setr.py | 7 +++++-- test_setr.py | 19 ++----------------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 034298d..b2ee018 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -102,12 +102,15 @@ def __init__( in_channels = self.out_channels def forward(self, x): - n, c, h, w = x.shape + print(x.shape) + x = x.reshape(n, c, h * w).transpose(1, 2).contiguous() + print(x.shape) x = self.norm(x) x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() + print(x.shape) for up_conv in self.up_convs: x = up_conv(x) @@ -442,7 +445,7 @@ def __init__( norm_layer if norm_layer is not None else nn.LayerNorm(decoder_channels) ) conv_norm = ( - conv_norm if conv_norm is not None else nn.SyncBatchNorm(self.out_channels) + conv_norm if conv_norm is not None else nn.SyncBatchNorm(decoder_channels) ) conv_act = conv_act if conv_act is not None else nn.ReLU() diff --git a/test_setr.py b/test_setr.py index 372070f..f59a8ad 100644 --- a/test_setr.py +++ b/test_setr.py @@ -4,22 +4,7 @@ if __name__ == "__main__": device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = SETR_PUP( - image_size=(16, 32), - patch_size=16, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=1, - num_classes=3, - num_convs=4, - decoder_channels=256, - up_scale=4, - encoder_dropout=0.1, - kernel_size=3, - decoder_dropout=0.1, - conv_act=torch.nn.ReLU(inplace=True), - ) + model = SETR_PUP() model.to(device) - result = model.forward(torch.zeros(1, 3, 16, 32).to(device)) + result = model.forward(torch.zeros(1, 3, 512, 512).to(device)) print(result[0].shape) From 10e013c321ffc512f193f177386eb6200d469ef4 Mon Sep 17 00:00:00 
2001 From: Gabriel Gutierrez Date: Tue, 12 Mar 2024 23:24:38 -0300 Subject: [PATCH 22/24] bug fixes --- sslt/models/nets/setr.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index b2ee018..8b36b13 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -104,13 +104,9 @@ def __init__( def forward(self, x): n, c, h, w = x.shape - print(x.shape) - x = x.reshape(n, c, h * w).transpose(1, 2).contiguous() - print(x.shape) x = self.norm(x) x = x.transpose(1, 2).reshape(n, c, h, w).contiguous() - print(x.shape) for up_conv in self.up_convs: x = up_conv(x) @@ -441,9 +437,7 @@ def __init__( """ super().__init__() self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss() - norm_layer = ( - norm_layer if norm_layer is not None else nn.LayerNorm(decoder_channels) - ) + norm_layer = norm_layer if norm_layer is not None else nn.LayerNorm(hidden_dim) conv_norm = ( conv_norm if conv_norm is not None else nn.SyncBatchNorm(decoder_channels) ) From 68860af165a66d5b9be7135ef5897b8c2b304c5a Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Wed, 13 Mar 2024 04:31:11 -0300 Subject: [PATCH 23/24] Tests and cleanup for PR --- new setr net graph.txt | Bin 41638 -> 0 bytes setr net graph.txt | 706 --------------------------------- sslt/models/nets/setr.py | 3 +- test_setr.py | 10 - tests/models/nets/test_setr.py | 120 ++++++ 5 files changed, 122 insertions(+), 717 deletions(-) delete mode 100644 new setr net graph.txt delete mode 100644 setr net graph.txt delete mode 100644 test_setr.py create mode 100644 tests/models/nets/test_setr.py diff --git a/new setr net graph.txt b/new setr net graph.txt deleted file mode 100644 index 1193795260b751d067e3771e4eaf0299227f4798..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 41638 zcmeHQ+iu%N5S`}&{Rh952I?Y~6ZZn>q0Xg!aDvpAz6!EUJFX>Kk>n)y$G1JR!_jIf zQm$>Kppyea+9kQWoS8jyI9zhcfBx>tzT`5MJ?Y6?Im9oMZ{-_V!;)c7A;)qlC&-Ou zU4FsVp8P52GQsZ>dk?U^k8~oZSRczUl-$9#W4Vx{P$r*i)7sTG3;BSyMyPWsf0b=N zFSmDa#7ws2dmL|5o=8WYhWtfny}VUw2yYt8dmEhNI!zqS2&v4=S8wDgN9vb@B<`Nsm{^n2-2&4RIXRZ4Ipcd+HUvQ zzLvYjzA;Kz>)KoIaWw6jDO%Ioi;|giY47#IQLVKeB(nBW9d}UXGW1-1o7V0C$L>KN zYhiaL;~Ugf1Z1HlMx&_^xNLyM=+r zI>2%*f8eOlwR#>vA6V{veN&#~odsX1G?- ztK`phof(hAc-%)Cmc^8e$2B&uzZU)4<8i-!A&-;ComV|2k8@nbv5=-2M!4-tW%XkX zdL$i7_t9F8tC;PV$KyWsLmnrOr@pG*#CV+JDpgb%&yvU6HLmikJnxam$>UXzC66;6 zuVTX18IMm9<=p?t6}M{rNO=9$Z=@A;d6&oMEE*+ylp)G6M39GP9DbwT6Z^+|&ledq_Ij56z^Q-TIaGxcU zlgS+mX-Xzn4B@WVOH-+;%$>dEw%b1)o`QkZ_Om6QQ zZX1(xh9{YvOuoohjLDl&K_+h}lXH$InY_to8Iv<6Up&W=$=k@}9nSIOS34KJ66@Ye zCMT0S7Sfb=Jll51vorg8GM^{s^W>?wa%K=`22~N^>tyowF}dG^CX$lS6qczhh4y`?F+RH$$xfj@}Cr zoRu}7qoqPVKoc$B!Tcwa+x{=YF$(+-LcMdjLB1QbSDp2(Vwz5%+ZfuYOzGYzx#7Af zE$Cm4y^m#|J=n4ttUZE8YDd+6Qd;A6uo?ndE9kCK#d5pr-cdloNkG|IrC)UaQIJ?M`zlDT$V~f7B{%8) z{c-#7gaUTkL5(TWcXK`w(NpyW>#-rM>z;6hTo-n9@2>X;l~R$_$B{13U+N_V7WJF6 zfXHc(Lv^V4p{Ri~>R#yEL-bRVt{$@;29NGT(osM}U5&*ylo>W=-$veo2A}YII8vf@ zV~#j6-h6C0iKvOW3`o8LR(kL$<8Sx5Rk8a#*O1?r^ZQ0Px9QgbJ1g_xdHc`jAEy`3 zYw)%5ee~X#Tl8Z;^Q_8`BmdKjmwT~AnGZkr&SO1V|8e?ll>ck>+G<4m&=@{iH0Fqn tFF7*vD+Dt7-9|Kc-OcOnx~phoWcRB3Uw5zSd>?r{=E&$@2BTY~{x1x%gSh|z diff --git a/setr net graph.txt b/setr net graph.txt deleted file mode 100644 index 16a4b5e..0000000 --- a/setr net graph.txt +++ /dev/null @@ -1,706 +0,0 @@ -EncoderDecoder( - (backbone): VisionTransformer( - (patch_embed): PatchEmbed( - (adap_padding): AdaptivePadding() - (projection): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16)) - ) - 
(drop_after_pos): Dropout(p=0.0, inplace=False) - (layers): ModuleList( - (0): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (1): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (2): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (3): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (4): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( 
- (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (5): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (6): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (7): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (8): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (9): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): 
Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (10): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (11): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (12): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (13): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (14): TransformerEncoderLayer( - (ln1): 
LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (15): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (16): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (17): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (18): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): 
Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (19): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (20): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (21): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (22): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - (23): TransformerEncoderLayer( - (ln1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (attn): MultiheadAttention( - (attn): MultiheadAttention( - (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True) - ) - (proj_drop): Dropout(p=0.0, inplace=False) - (dropout_layer): Dropout(p=0.0, inplace=False) - ) - (ln2): 
LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (ffn): FFN( - (activate): GELU() - (layers): Sequential( - (0): Sequential( - (0): Linear(in_features=1024, out_features=4096, bias=True) - (1): GELU() - (2): Dropout(p=0.0, inplace=False) - ) - (1): Linear(in_features=4096, out_features=1024, bias=True) - (2): Dropout(p=0.0, inplace=False) - ) - (dropout_layer): Identity() - ) - ) - ) - ) - init_cfg={'type': 'Pretrained', 'checkpoint': 'pretrain/vit_large_p16.pth'} - (decode_head): SETRUPHead( - input_transform=None, ignore_index=255, align_corners=False - (loss_decode): CrossEntropyLoss(avg_non_ignore=False) - (conv_seg): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) - (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (up_convs): ModuleList( - (0): Sequential( - (0): ConvModule( - (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - (1): Sequential( - (0): ConvModule( - (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - (2): Sequential( - (0): ConvModule( - (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - (3): Sequential( - (0): ConvModule( - (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - ) - ) - init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] - (auxiliary_head): ModuleList( - (0): SETRUPHead( - input_transform=None, ignore_index=255, align_corners=False - (loss_decode): CrossEntropyLoss(avg_non_ignore=False) - (conv_seg): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) - (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (up_convs): ModuleList( - (0): Sequential( - (0): ConvModule( - (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - (1): Sequential( - (0): ConvModule( - (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - ) - ) - init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] - (1): SETRUPHead( - input_transform=None, ignore_index=255, align_corners=False - (loss_decode): CrossEntropyLoss(avg_non_ignore=False) - (conv_seg): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) - (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (up_convs): ModuleList( - (0): Sequential( - (0): ConvModule( - (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - (1): Sequential( - (0): ConvModule( - (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - ) - ) - init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] - (2): SETRUPHead( - input_transform=None, ignore_index=255, align_corners=False - (loss_decode): CrossEntropyLoss(avg_non_ignore=False) - (conv_seg): Conv2d(256, 6, kernel_size=(1, 1), stride=(1, 1)) - (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True) - (up_convs): ModuleList( - (0): Sequential( - (0): ConvModule( - (conv): Conv2d(1024, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - (1): Sequential( - (0): ConvModule( - (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) - (bn): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - (activate): ReLU(inplace=True) - ) - (1): Upsample() - ) - ) - ) - init_cfg=[{'type': 'Constant', 'val': 1.0, 'bias': 0, 'layer': 'LayerNorm'}, {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}] - ) -) \ No newline at end of file diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 8b36b13..27a3ba6 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -442,6 +442,7 @@ def __init__( conv_norm if conv_norm is not None else nn.SyncBatchNorm(decoder_channels) ) conv_act = conv_act if conv_act is not None else nn.ReLU() + self.num_classes = num_classes self.model = _SetR_PUP( image_size=image_size, @@ -506,7 +507,7 @@ def _single_step( """ x, y = batch y_hat = self.model(x) - loss = self._loss_func(y_hat, y) + loss = self._loss_func(y_hat[0], y.squeeze(1)) self.log( f"{step_name}_loss", loss, diff --git a/test_setr.py b/test_setr.py deleted file mode 100644 index f59a8ad..0000000 --- a/test_setr.py +++ /dev/null @@ -1,10 +0,0 @@ -import torch - -from sslt.models.nets.setr import SETR_PUP - -if __name__ == "__main__": - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = SETR_PUP() - model.to(device) - result = model.forward(torch.zeros(1, 3, 512, 512).to(device)) - print(result[0].shape) diff --git a/tests/models/nets/test_setr.py b/tests/models/nets/test_setr.py new file mode 100644 index 0000000..9478b5e --- /dev/null +++ b/tests/models/nets/test_setr.py @@ -0,0 +1,120 @@ +import pytest +import torch + +from sslt.models.nets.setr import SETR_PUP + + +def test_setr_pup_forward(): + # Create a dummy input + x = torch.randn(1, 3, 512, 512) + + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Perform forward pass + output = model(x) + + # Check if the output has the expected shape + assert output.shape == (1, model.num_classes) + + +def test_setr_pup_loss_func(): + # Create dummy input and target tensors + y_hat = torch.randn(1, 1000) + y = torch.randint(0, 1000, (1,)) + + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Calculate the loss + loss = model._loss_func(y_hat, y) + + # Check if the loss is a scalar tensor + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 + + +def 
test_setr_pup_single_step(): + # Create dummy input and target tensors + x = torch.randn(1, 3, 512, 512) + y = torch.randint(0, 1000, (1,)) + + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Perform a single training step + loss = model._single_step((x, y), 0, "train") + + # Check if the loss is a scalar tensor + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 + + +def test_setr_pup_training_step(): + # Create dummy input and target tensors + x = torch.randn(1, 3, 512, 512) + y = torch.randint(0, 1000, (1,)) + + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Perform a training step + output = model.training_step((x, y), 0) + + # Check if the output is None + assert output is None + + +def test_setr_pup_validation_step(): + # Create dummy input and target tensors + x = torch.randn(1, 3, 512, 512) + y = torch.randint(0, 1000, (1,)) + + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Perform a validation step + output = model.validation_step((x, y), 0) + + # Check if the output is None + assert output is None + + +def test_setr_pup_test_step(): + # Create dummy input and target tensors + x = torch.randn(1, 3, 512, 512) + y = torch.randint(0, 1000, (1,)) + + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Perform a test step + output = model.test_step((x, y), 0) + + # Check if the output is None + assert output is None + + +def test_setr_pup_predict_step(): + # Create a dummy input + x = torch.randn(1, 3, 512, 512) + + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Perform a predict step + output = model.predict_step((x, None), 0, 0) + + # Check if the output has the expected shape + assert output.shape == (1, model.num_classes) + + +def test_setr_pup_configure_optimizers(): + # Initialize the SETR_PUP model + model = SETR_PUP() + + # Configure optimizers + optimizers = model.configure_optimizers() + + # Check if the optimizers are instances of torch.optim.Optimizer + assert isinstance(optimizers, torch.optim.Optimizer) From 1336318361ad4f13fe61d9ca6d052838874c77d0 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Wed, 13 Mar 2024 04:46:08 -0300 Subject: [PATCH 24/24] Refactor SETR_PUP test functions and minor fixes --- sslt/models/nets/setr.py | 4 +- tests/models/nets/test_setr.py | 127 +++++---------------------------- 2 files changed, 22 insertions(+), 109 deletions(-) diff --git a/sslt/models/nets/setr.py b/sslt/models/nets/setr.py index 27a3ba6..44fe6f5 100644 --- a/sslt/models/nets/setr.py +++ b/sslt/models/nets/setr.py @@ -527,7 +527,9 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int): def test_step(self, batch: torch.Tensor, batch_idx: int): return self._single_step(batch, batch_idx, "test") - def predict_step(self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int): + def predict_step( + self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None + ): x, _ = batch return self.model(x) diff --git a/tests/models/nets/test_setr.py b/tests/models/nets/test_setr.py index 9478b5e..3506f3f 100644 --- a/tests/models/nets/test_setr.py +++ b/tests/models/nets/test_setr.py @@ -4,117 +4,28 @@ from sslt.models.nets.setr import SETR_PUP -def test_setr_pup_forward(): - # Create a dummy input - x = torch.randn(1, 3, 512, 512) - - # Initialize the SETR_PUP model - model = SETR_PUP() - - # Perform forward pass - output = model(x) - - # Check if the output has the expected shape - assert output.shape == (1, model.num_classes) - - -def test_setr_pup_loss_func(): - # Create dummy input and 
target tensors - y_hat = torch.randn(1, 1000) - y = torch.randint(0, 1000, (1,)) - - # Initialize the SETR_PUP model - model = SETR_PUP() - - # Calculate the loss - loss = model._loss_func(y_hat, y) - - # Check if the loss is a scalar tensor - assert isinstance(loss, torch.Tensor) - assert loss.dim() == 0 - - -def test_setr_pup_single_step(): - # Create dummy input and target tensors - x = torch.randn(1, 3, 512, 512) - y = torch.randint(0, 1000, (1,)) - - # Initialize the SETR_PUP model +def test_wisenet_loss(): model = SETR_PUP() + batch_size = 2 + x = torch.rand(2, 3, 512, 512) + mask = torch.rand(2, 1, 512, 512).long() - # Perform a single training step - loss = model._single_step((x, y), 0, "train") - - # Check if the loss is a scalar tensor - assert isinstance(loss, torch.Tensor) - assert loss.dim() == 0 - + # Do the training step + loss = model.training_step((x, mask), 0).item() + assert loss is not None + assert loss >= 0, f"Expected non-negative loss, but got {loss}" -def test_setr_pup_training_step(): - # Create dummy input and target tensors - x = torch.randn(1, 3, 512, 512) - y = torch.randint(0, 1000, (1,)) - # Initialize the SETR_PUP model +def test_wisenet_predict(): model = SETR_PUP() + batch_size = 2 + mask_shape = (batch_size, 1000, 512, 512) # (2, 1, 500, 500) + x = torch.rand(2, 3, 512, 512) + mask = torch.rand(2, 1, 512, 512).long() - # Perform a training step - output = model.training_step((x, y), 0) - - # Check if the output is None - assert output is None - - -def test_setr_pup_validation_step(): - # Create dummy input and target tensors - x = torch.randn(1, 3, 512, 512) - y = torch.randint(0, 1000, (1,)) - - # Initialize the SETR_PUP model - model = SETR_PUP() - - # Perform a validation step - output = model.validation_step((x, y), 0) - - # Check if the output is None - assert output is None - - -def test_setr_pup_test_step(): - # Create dummy input and target tensors - x = torch.randn(1, 3, 512, 512) - y = torch.randint(0, 1000, (1,)) - - # Initialize the SETR_PUP model - model = SETR_PUP() - - # Perform a test step - output = model.test_step((x, y), 0) - - # Check if the output is None - assert output is None - - -def test_setr_pup_predict_step(): - # Create a dummy input - x = torch.randn(1, 3, 512, 512) - - # Initialize the SETR_PUP model - model = SETR_PUP() - - # Perform a predict step - output = model.predict_step((x, None), 0, 0) - - # Check if the output has the expected shape - assert output.shape == (1, model.num_classes) - - -def test_setr_pup_configure_optimizers(): - # Initialize the SETR_PUP model - model = SETR_PUP() - - # Configure optimizers - optimizers = model.configure_optimizers() - - # Check if the optimizers are instances of torch.optim.Optimizer - assert isinstance(optimizers, torch.optim.Optimizer) + # Do the prediction step + preds = model.predict_step((x, mask), 0) + assert preds is not None + assert ( + preds[0].shape == mask_shape + ), f"Expected shape {mask_shape}, but got {preds[0].shape}"