Merge pull request #44 from discovery-unicamp/perlin-noise
Perlin noise
GabrielBG0 authored Apr 11, 2024
2 parents b45dee2 + edc1bb4 commit 7927a74
Showing 4 changed files with 136 additions and 4 deletions.
4 changes: 3 additions & 1 deletion requirements.txt
@@ -8,4 +8,6 @@ statsmodels
 tifffile
 torch
 zarr
-torchmetrics
+rich
+perlin-noise
+torchmetrics
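
The new perlin-noise dependency provides the PerlinNoise generator used by the masking transform below. A minimal smoke test of the library, with illustrative octave, seed, and coordinate values (not taken from this PR):

from perlin_noise import PerlinNoise

noise = PerlinNoise(octaves=3, seed=42)  # octaves sets the level of detail; seed makes it reproducible
value = noise([0.1, 0.2])                # sample 2-D noise at normalized coordinates
print(value)                             # a float, roughly in (-1, 1)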
54 changes: 54 additions & 0 deletions sslt/models/nets/deeplabv3.py
@@ -0,0 +1,54 @@
from typing import Optional

import torch
import torch.nn as nn
from torchvision import models

from .base import SimpleSupervisedModel


class Resnet50Backbone(nn.Module):
    """A ResNet-50 feature extractor: the stock torchvision ResNet-50 with
    its final average-pooling and fully-connected layers removed."""

    def __init__(self) -> None:
        super().__init__()
        resnet50 = models.resnet50()
        # Drop the last two children (avgpool and fc) to keep only the
        # convolutional feature extractor.
        self.resnet50 = nn.Sequential(*list(resnet50.children())[:-2])

    def forward(self, x):
        return self.resnet50(x)


class DeepLabV3_Head(nn.Module):
    """Placeholder for the DeepLabV3 segmentation head (not yet implemented)."""

    def __init__(self) -> None:
        super().__init__()
        raise NotImplementedError("DeepLabV3's head has not yet been implemented")

    def forward(self, x):
        raise NotImplementedError("DeepLabV3's head has not yet been implemented")


class DeepLabV3(SimpleSupervisedModel):
    """A DeepLabV3 model with a ResNet-50 backbone.

    References
    ----------
    Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
    "Rethinking Atrous Convolution for Semantic Image Segmentation", 2017.
    """

    def __init__(
        self,
        learning_rate: float = 1e-3,
        loss_fn: Optional[torch.nn.Module] = None,
    ):
        """Wrapper implementation of the DeepLabV3 model.

        Parameters
        ----------
        learning_rate : float, optional
            The learning rate for the Adam optimizer, by default 1e-3.
        loss_fn : torch.nn.Module, optional
            The function used to compute the loss. If `None`, MSELoss is
            used, by default None.
        """
        super().__init__(
            backbone=Resnet50Backbone(),
            fc=DeepLabV3_Head(),
            loss_fn=loss_fn or torch.nn.MSELoss(),
            learning_rate=learning_rate,
        )
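
The head above is left as a stub; in the cited paper, DeepLabV3's head is an ASPP (Atrous Spatial Pyramid Pooling) module. Below is a minimal sketch of what such a head could look like, assuming the 2048-channel output of the ResNet-50 backbone; the class name, channel sizes, and atrous rates are illustrative assumptions, not the repository's implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ASPPHeadSketch(nn.Module):
    """Illustrative ASPP head (Chen et al., 2017); not this repository's code."""

    def __init__(self, in_channels: int = 2048, out_channels: int = 256,
                 rates=(6, 12, 18)) -> None:
        super().__init__()
        # One 1x1 branch plus one 3x3 atrous branch per dilation rate
        self.branches = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, 1, bias=False)]
            + [nn.Conv2d(in_channels, out_channels, 3, padding=r, dilation=r, bias=False)
               for r in rates]
        )
        # Image-level branch: global average pooling + 1x1 convolution
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
        )
        # Fuse the concatenated branches back down to out_channels
        self.project = nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        size = x.shape[-2:]
        feats = [branch(x) for branch in self.branches]
        # Upsample the pooled features back to the input's spatial size
        feats.append(F.interpolate(self.image_pool(x), size=size,
                                   mode="bilinear", align_corners=False))
        return self.project(torch.cat(feats, dim=1))

# Example: features from the backbone above are (N, 2048, H/32, W/32), so
# ASPPHeadSketch()(Resnet50Backbone()(torch.randn(1, 3, 224, 224)))
# would produce a (1, 256, 7, 7) tensor.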
41 changes: 40 additions & 1 deletion sslt/transforms/transform.py
@@ -1,5 +1,7 @@
 from typing import Any, List, Sequence
 
+from perlin_noise import PerlinNoise
+from itertools import product
 import torch
 import numpy as np


@@ -76,3 +78,40 @@ def __call__(self, x: np.ndarray) -> np.ndarray:
x = np.flip(x, axis=axis)

return x


class PerlinMasker(_Transform):
    """Zeroes entries of a tensor according to the sign of Perlin noise.
    The seed for the noise generator is drawn with torch.randint on every
    call.
    """

    def __init__(self, octaves: int, scale: float = 1):
        """
        Parameters
        ----------
        octaves: int
            Level of detail for the Perlin noise generator.
        scale: float = 1
            Optionally rescale the Perlin noise. Default is 1 (no rescaling).
        """
        if octaves <= 0:
            raise ValueError(f"Number of octaves must be positive, but got {octaves=}")
        if scale == 0:
            raise ValueError("Scale can't be 0")
        self.octaves = octaves
        self.scale = scale

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Zeroes entries of a tensor according to the sign of Perlin noise.

        Parameters
        ----------
        x: np.ndarray
            The tensor whose entries to zero.
        """
        mask = np.empty_like(x, dtype=bool)
        # Keep the generator as an attribute so the mask can be reproduced
        # afterwards (the test suite reads it back).
        self.noise = PerlinNoise(self.octaves, torch.randint(0, 2**32, (1,)).item())
        # Normalize coordinates by the largest dimension so the noise is
        # sampled over a comparable range for any input shape.
        denom = self.scale * max(x.shape)

        # Entries are kept where the noise is negative and zeroed elsewhere.
        for pos in product(*[range(i) for i in mask.shape]):
            mask[pos] = (self.noise([i / denom for i in pos]) < 0)

        return x * mask
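
A quick usage sketch of the new transform (array shape and parameter values are arbitrary examples):

import numpy as np
from sslt.transforms import PerlinMasker

masker = PerlinMasker(octaves=3, scale=2)  # scale > 1 stretches the noise, giving coarser masked regions
image = np.random.rand(32, 32)
masked = masker(image)                     # entries where the sampled noise is non-negative become 0
assert masked.shape == image.shape         # shape is preserved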
41 changes: 39 additions & 2 deletions tests/transforms/test_transform.py
@@ -1,9 +1,9 @@
-import pytest
 from typing import Sequence
 
 import numpy as np
+import pytest
 
-from sslt.transforms import Flip, TransformPipeline, _Transform
+from sslt.transforms import Flip, PerlinMasker, TransformPipeline, _Transform


def test_transform_pipeline():
@@ -59,3 +59,40 @@ def test_flip_invalid_axes():
# Check if an AssertionError is raised when applying the transform
with pytest.raises(AssertionError):
flipped_x = flip_transform(x)


def test_perlin_masker():
# Create a dummy input
x = np.random.rand(10, 20)

# Apply the PerlinMasker transform
perlin_masker = PerlinMasker(octaves=3, scale=2)
masked_x = perlin_masker(x)

# Check if the masked data has the same shape as the input
assert masked_x.shape == x.shape

    # Check the mask: entries are kept where the noise is negative and
    # zeroed elsewhere
    noise = perlin_masker.noise  # the generator saved by the last __call__
    denom = perlin_masker.scale * max(x.shape)
for pos in np.ndindex(*x.shape):
expected_value = (noise([i/denom for i in pos]) < 0)
assert masked_x[pos] == pytest.approx(x[pos] * expected_value)

def test_perlin_masker_invalid_octaves():
# Create a dummy input
x = np.random.rand(10, 20)

# Check if a ValueError is raised when using invalid octaves
with pytest.raises(ValueError):
perlin_masker = PerlinMasker(octaves=-1)
masked_x = perlin_masker(x)

def test_perlin_masker_invalid_scale():
# Create a dummy input
x = np.random.rand(10, 20)

# Check if a ValueError is raised when using invalid scale
with pytest.raises(ValueError):
perlin_masker = PerlinMasker(octaves=3, scale=0)
masked_x = perlin_masker(x)
