Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
GabrielBG0 committed Apr 11, 2024
1 parent 97ab781 commit edc1bb4
Showing 1 changed file with 39 additions and 2 deletions.
41 changes: 39 additions & 2 deletions tests/transforms/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest
from typing import Sequence

import numpy as np
import pytest

from sslt.transforms import Flip, TransformPipeline, _Transform
from sslt.transforms import Flip, PerlinMasker, TransformPipeline, _Transform


def test_transform_pipeline():
Expand Down Expand Up @@ -59,3 +59,40 @@ def test_flip_invalid_axes():
# Check if an AssertionError is raised when applying the transform
with pytest.raises(AssertionError):
flipped_x = flip_transform(x)


def test_perlin_masker():
# Create a dummy input
x = np.random.rand(10, 20)

# Apply the PerlinMasker transform
perlin_masker = PerlinMasker(octaves=3, scale=2)
masked_x = perlin_masker(x)

# Check if the masked data has the same shape as the input
assert masked_x.shape == x.shape

# Check if the masked data has zeros at the positions where Perlin noise is negative
noise = perlin_masker.noise
denom = perlin_masker.scale * max(x.shape)
for pos in np.ndindex(*x.shape):
expected_value = (noise([i/denom for i in pos]) < 0)
assert masked_x[pos] == pytest.approx(x[pos] * expected_value)

def test_perlin_masker_invalid_octaves():
# Create a dummy input
x = np.random.rand(10, 20)

# Check if a ValueError is raised when using invalid octaves
with pytest.raises(ValueError):
perlin_masker = PerlinMasker(octaves=-1)
masked_x = perlin_masker(x)

def test_perlin_masker_invalid_scale():
# Create a dummy input
x = np.random.rand(10, 20)

# Check if a ValueError is raised when using invalid scale
with pytest.raises(ValueError):
perlin_masker = PerlinMasker(octaves=3, scale=0)
masked_x = perlin_masker(x)

0 comments on commit edc1bb4

Please sign in to comment.