Skip to content

Commit

Permalink
Fix in GaussNoise (#1801)
Browse files Browse the repository at this point in the history
* Fix in GaussNoise

* Added tests for GaussNoise
  • Loading branch information
ternaus authored Jun 19, 2024
1 parent a6ceade commit a209d6a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
7 changes: 6 additions & 1 deletion albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import cv2
import numpy as np
import skimage
from albucore.functions import add, add_weighted, multiply, multiply_add
from albucore.functions import add, add_array, add_weighted, multiply, multiply_add
from albucore.utils import (
MAX_VALUES_BY_DTYPE,
clip,
Expand Down Expand Up @@ -1536,3 +1536,8 @@ def generate_approx_gaussian_noise(
# Upsample the noise to the original shape using OpenCV
result = cv2.resize(low_res_noise, (shape[1], shape[0]), interpolation=cv2.INTER_LINEAR)
return result.reshape(shape)


@clipped
def add_noise(img: np.ndarray, noise: np.ndarray) -> np.ndarray:
return add_array(img, noise)
4 changes: 2 additions & 2 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,7 +1650,7 @@ def __init__(
self.noise_scale_factor = noise_scale_factor

def apply(self, img: np.ndarray, gauss: np.ndarray, **params: Any) -> np.ndarray:
return albucore.add_array(img, gauss)
return fmain.add_noise(img, gauss)

def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, float]:
image = params["image"]
Expand All @@ -1673,7 +1673,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, f
if image.ndim > MONO_CHANNEL_DIMENSIONS:
gauss = np.expand_dims(gauss, -1)

return {"gauss": gauss.astype(image.dtype)}
return {"gauss": gauss}

@property
def targets_as_params(self) -> List[str]:
Expand Down
16 changes: 15 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1748,4 +1748,18 @@ def test_random_fog_initialization(params, expected):
])
def test_random_fog_invalid_input(params):
with pytest.raises(Exception):
img_fog = A.RandomFog(**params)
A.RandomFog(**params)


@pytest.mark.parametrize("image", IMAGES + [np.full((10, 10), 128, dtype=np.uint8)])
@pytest.mark.parametrize("mean", (0, 10, -10))
def test_gauss_noise(mean, image):
set_seed(42)
aug = A.GaussNoise(p=1, noise_scale_factor=1.0, mean=mean)

apply_params = aug.get_params_dependent_on_targets(params = {"image":image })

assert np.abs(mean - apply_params["gauss"].mean()) < 0.5
result = A.Compose([aug])(image=image)

assert not (result["image"] >= image).all()

0 comments on commit a209d6a

Please sign in to comment.