diff --git a/requirements.txt b/requirements.txt
index 6d62e21..6ab5049 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,6 @@ statsmodels
 tifffile
 torch
 zarr
-torchmetrics
\ No newline at end of file
+rich
+perlin-noise
+torchmetrics
diff --git a/sslt/models/nets/deeplabv3.py b/sslt/models/nets/deeplabv3.py
new file mode 100644
index 0000000..de3d091
--- /dev/null
+++ b/sslt/models/nets/deeplabv3.py
@@ -0,0 +1,63 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from .base import SimpleSupervisedModel
+
+
+class Resnet50Backbone(nn.Module):
+    """A ResNet-50 feature extractor with the classification layers removed."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.resnet50 = models.resnet50()
+        # Drop the final average-pool and fully-connected layers so the
+        # backbone outputs spatial feature maps suitable for dense prediction.
+        self.resnet50 = nn.Sequential(*list(self.resnet50.children())[:-2])
+
+    def forward(self, x):
+        return self.resnet50(x)
+
+
+class DeepLabV3_Head(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+
+    def forward(self, x):
+        raise NotImplementedError("DeepLabV3's head has not yet been implemented")
+
+
+class DeepLabV3(SimpleSupervisedModel):
+    """A DeepLabV3 model with a ResNet-50 backbone.
+
+    References
+    ----------
+    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
+    "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017.
+    """
+
+    def __init__(
+        self,
+        learning_rate: float = 1e-3,
+        loss_fn: Optional[torch.nn.Module] = None,
+    ):
+        """Wrapper implementation of the DeepLabV3 model.
+
+        Parameters
+        ----------
+        learning_rate : float, optional
+            The learning rate for the Adam optimizer, by default 1e-3.
+        loss_fn : torch.nn.Module, optional
+            The function used to compute the loss. If `None`, MSELoss is
+            used, by default None.
+        """
+        super().__init__(
+            backbone=Resnet50Backbone(),
+            fc=DeepLabV3_Head(),
+            loss_fn=loss_fn or torch.nn.MSELoss(),
+            learning_rate=learning_rate,
+        )
diff --git a/sslt/transforms/transform.py b/sslt/transforms/transform.py
index 2fb3776..107a72e 100644
--- a/sslt/transforms/transform.py
+++ b/sslt/transforms/transform.py
@@ -1,5 +1,7 @@
 from typing import Any, List, Sequence
-
+from perlin_noise import PerlinNoise
+from itertools import product
+import torch
 import numpy as np
@@ -76,3 +78,45 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
         x = np.flip(x, axis=axis)
 
         return x
+
+
+class PerlinMasker(_Transform):
+    """Zeroes entries of a tensor according to the sign of Perlin noise; the generator is re-seeded via torch.randint on every call."""
+
+    def __init__(self, octaves: int, scale: float = 1):
+        """Zeroes entries of a tensor according to the sign of Perlin noise.
+
+        Parameters
+        ----------
+        octaves: int
+            Level of detail for the Perlin noise generator.
+        scale: float = 1
+            Optionally rescale the Perlin noise. Default is 1 (no rescaling).
+        """
+        if octaves <= 0:
+            raise ValueError(f"Number of octaves must be positive, but got {octaves=}")
+        if scale == 0:
+            raise ValueError("Scale must be non-zero")
+        self.octaves = octaves
+        self.scale = scale
+        self.noise = None
+
+    def __call__(self, x: np.ndarray) -> np.ndarray:
+        """Zeroes entries of a tensor according to the sign of Perlin noise.
+
+        Parameters
+        ----------
+        x: np.ndarray
+            The tensor whose entries to zero.
+        """
+        # A fresh, randomly seeded generator per call; it is kept as an
+        # attribute so that callers can reproduce the mask afterwards.
+        self.noise = PerlinNoise(self.octaves, torch.randint(0, 2**32, (1,)).item())
+        mask = np.empty_like(x, dtype=bool)
+        denom = self.scale * max(x.shape)
+
+        # Sample the noise at every coordinate of the input, normalized by denom.
+        for pos in product(*[range(i) for i in mask.shape]):
+            mask[pos] = self.noise([i / denom for i in pos]) < 0
+
+        return x * mask
diff --git a/tests/transforms/test_transform.py b/tests/transforms/test_transform.py
index bbb9d5a..6b6916a 100644
--- a/tests/transforms/test_transform.py
+++ b/tests/transforms/test_transform.py
@@ -1,9 +1,9 @@
-import pytest
 from typing import Sequence
 
 import numpy as np
+import pytest
 
-from sslt.transforms import Flip, TransformPipeline, _Transform
+from sslt.transforms import Flip, PerlinMasker, TransformPipeline, _Transform
 
 
 def test_transform_pipeline():
@@ -59,3 +59,42 @@ def test_flip_invalid_axes():
     # Check if an AssertionError is raised when applying the transform
     with pytest.raises(AssertionError):
         flipped_x = flip_transform(x)
+
+
+def test_perlin_masker():
+    # Create a dummy input
+    x = np.random.rand(10, 20)
+
+    # Apply the PerlinMasker transform
+    perlin_masker = PerlinMasker(octaves=3, scale=2)
+    masked_x = perlin_masker(x)
+
+    # Check if the masked data has the same shape as the input
+    assert masked_x.shape == x.shape
+
+    # Check that entries are zeroed exactly where the Perlin noise is negative
+    noise = perlin_masker.noise
+    denom = perlin_masker.scale * max(x.shape)
+    for pos in np.ndindex(*x.shape):
+        expected_value = noise([i / denom for i in pos]) < 0
+        assert masked_x[pos] == pytest.approx(x[pos] * expected_value)
+
+
+def test_perlin_masker_invalid_octaves():
+    # Create a dummy input
+    x = np.random.rand(10, 20)
+
+    # Check if a ValueError is raised when using invalid octaves
+    with pytest.raises(ValueError):
+        perlin_masker = PerlinMasker(octaves=-1)
+        masked_x = perlin_masker(x)
+
+
+def test_perlin_masker_invalid_scale():
+    # Create a dummy input
+    x = np.random.rand(10, 20)
+
+    # Check if a ValueError is raised when using invalid scale
+    with pytest.raises(ValueError):
+        perlin_masker = PerlinMasker(octaves=3, scale=0)
+        masked_x = perlin_masker(x)
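
Since DeepLabV3_Head is still a stub, the eventual implementation will presumably be the ASPP (Atrous Spatial Pyramid Pooling) module from the referenced paper. The sketch below is illustrative only and not part of the patch: the ASPPHead name, channel counts, dilation rates, and num_classes default are all assumptions, and the BatchNorm/ReLU layers of the full design are omitted for brevity.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ASPPHead(nn.Module):
    """Minimal ASPP head: parallel atrous convolutions plus image-level pooling."""

    def __init__(self, in_channels: int = 2048, num_classes: int = 21):
        super().__init__()
        rates = (6, 12, 18)  # illustrative dilation rates
        # One 1x1 branch plus one 3x3 atrous branch per dilation rate.
        self.branches = nn.ModuleList(
            [nn.Conv2d(in_channels, 256, 1, bias=False)]
            + [nn.Conv2d(in_channels, 256, 3, padding=r, dilation=r, bias=False) for r in rates]
        )
        # Image-level features: global average pool, then a 1x1 projection.
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, 256, 1, bias=False),
        )
        self.project = nn.Conv2d(256 * 5, 256, 1, bias=False)
        self.classifier = nn.Conv2d(256, num_classes, 1)

    def forward(self, x):
        size = x.shape[-2:]
        feats = [branch(x) for branch in self.branches]
        # Upsample the pooled features back to the spatial size of the input.
        feats.append(F.interpolate(self.pool(x), size=size, mode="bilinear", align_corners=False))
        return self.classifier(self.project(torch.cat(feats, dim=1)))

The in_channels default of 2048 matches the output of Resnet50Backbone above, since only the average-pool and FC layers are stripped from ResNet-50.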
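
A quick usage sketch for the new transform (the array shape and parameter values here are illustrative):

import numpy as np
from sslt.transforms import PerlinMasker

masker = PerlinMasker(octaves=3, scale=2)
x = np.random.rand(64, 64)   # dummy 2-D input; any dimensionality works
masked = masker(x)           # entries zeroed wherever the Perlin noise is negative

assert masked.shape == x.shape

Because the mask is sampled pointwise in a Python loop, the cost grows with the number of entries, so large inputs will be slow.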