From 834268ba6e277e12b8a46532549775fb8500e304 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Mon, 17 Jun 2024 16:10:08 -0700 Subject: [PATCH] Pr 1789 (#1790) * PlanckianJitter implementation * Fix pre-commit errors * Fixes on top of pr_1789 * Docstring * Fix in const dict * Fix * Merge * refactoring * Added PlankianJitter * Added PlankianJitter to Readme * pre-commit fixes * Fix in readme * Fix in readme * Fix in readme * Fix in readme --------- Co-authored-by: Jamil Zakirov --- .pre-commit-config.yaml | 2 +- README.md | 75 +++++------ albumentations/augmentations/functional.py | 82 ++++++++++++ albumentations/augmentations/transforms.py | 139 +++++++++++++++++++-- albumentations/core/types.py | 2 + requirements-dev.txt | 2 +- tests/test_augmentations.py | 5 + tests/test_core.py | 6 +- tests/test_functional.py | 38 ++++++ tests/test_serialization.py | 3 +- tests/test_transforms.py | 4 +- 11 files changed, 307 insertions(+), 51 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c9a51c32c..2e64b998f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,7 @@ repos: types: [python] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.8 + rev: v0.4.9 hooks: # Run the linter. - id: ruff diff --git a/README.md b/README.md index 29bb438c8..e9b57ecdd 100644 --- a/README.md +++ b/README.md @@ -185,6 +185,7 @@ Pixel-level transforms will change just an input image and will leave any additi - [MultiplicativeNoise](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.MultiplicativeNoise) - [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) - [PixelDistributionAdaptation](https://albumentations.ai/docs/api_reference/augmentations/domain_adaptation/#albumentations.augmentations.domain_adaptation.PixelDistributionAdaptation) +- [PlanckianJitter](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.PlanckianJitter) - [Posterize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Posterize) - [RGBShift](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.RGBShift) - [RandomBrightnessContrast](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.RandomBrightnessContrast) @@ -288,43 +289,43 @@ To run the benchmark yourself, follow the instructions in [benchmark/README.md]( Results for running the benchmark on the first 2000 images from the ImageNet validation set using an AMD Ryzen Threadripper 3970X CPU. The table shows how many images per second can be processed on a single core; higher is better. -| Library | Version | -|---------|---------| -| Python | 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] | -| albumentations | 1.4.8 | -| imgaug | 0.4.0 | -| torchvision | 0.18.1+rocm6.0 | -| numpy | 1.26.4 | -| opencv-python-headless | 4.10.0.82 | -| scikit-image | 0.23.2 | -| scipy | 1.13.1 | -| pillow | 10.3.0 | -| kornia | 0.7.2 | -| augly | 1.0.0 | - -| |albumentations
1.4.8|torchvision
0.18.1+rocm6.0|kornia
0.7.2|augly
1.0.0|imgaug
0.4.0| -|-----------------|--------------------------------------|--------------------------------------------|------------------------------|-----------------------------|------------------------------| -|HorizontalFlip |**8084 ± 30** |2422 ± 16 |940 ± 10 |3633 ± 7 |4869 ± 10 | -|VerticalFlip |7330 ± 11 |2541 ± 2 |945 ± 4 |4807 ± 4 |**8400 ± 11** | -|Rotate |535 ± 5 |144 ± 2 |202 ± 1 |**572 ± 2** |494 ± 1 | -|Affine |**1504 ± 46** |153 ± 1 |197 ± 1 |- |671 ± 2 | -|Equalize |1005 ± 1 |328 ± 2 |76 ± 1 |- |**1165 ± 1** | -|RandomCrop64 |20880 ± 170 |15792 ± 23 |833 ± 1 |**21313 ± 603** |5547 ± 2 | -|RandomResizedCrop|**2272 ± 6** |1113 ± 5 |189 ± 1 |- |- | -|ShiftRGB |**1708 ± 2** |- |425 ± 1 |- |1480 ± 11 | -|Resize |**2209 ± 1** |1285 ± 3 |200 ± 1 |430 ± 1 |1690 ± 1 | -|RandomGamma |**3638 ± 7** |229 ± 1 |213 ± 1 |- |2307 ± 3 | -|Grayscale |**7234 ± 4** |1628 ± 6 |447 ± 1 |2535 ± 1 |1052 ± 6 | -|ColorJitter |**438 ± 1** |50 ± 1 |47 ± 1 |214 ± 1 |- | -|RandomPerspective|457 ± 2 |122 ± 1 |115 ± 1 |- |**460 ± 2** | -|GaussianBlur |**2073 ± 1** |110 ± 2 |74 ± 2 |162 ± 1 |1271 ± 1 | -|MedianBlur |536 ± 1 |- |3 ± 0 |- |**564 ± 1** | -|MotionBlur |**2156 ± 8** |- |98 ± 1 |- |503 ± 1 | -|Posterize |**3435 ± 2** |2574 ± 2 |312 ± 9 |- |1894 ± 1 | -|JpegCompression |**805 ± 1** |- |- |679 ± 16 |427 ± 1 | -|GaussianNoise |**239 ± 1** |- |- |67 ± 1 |124 ± 1 | -|Elastic |126 ± 1 |4 ± 0 |1 ± 0 |- |**128 ± 1** | -|Normalize |**1056 ± 1** |429 ± 1 |398 ± 1 |- |- | +| Library | Version | +| ---------------------- | -------------------------------------------------- | +| Python | 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] | +| albumentations | 1.4.8 | +| imgaug | 0.4.0 | +| torchvision | 0.18.1+rocm6.0 | +| numpy | 1.26.4 | +| opencv-python-headless | 4.10.0.82 | +| scikit-image | 0.23.2 | +| scipy | 1.13.1 | +| pillow | 10.3.0 | +| kornia | 0.7.2 | +| augly | 1.0.0 | + +| | albumentations
1.4.8 | torchvision
0.18.1+rocm6.0 | kornia
0.7.2 | augly
1.0.0 | imgaug
0.4.0 | +| ----------------- | -------------------------------------- | -------------------------------------------- | ------------------------------ | ----------------------------- | ------------------------------ | +| HorizontalFlip | **8084 ± 30** | 2422 ± 16 | 940 ± 10 | 3633 ± 7 | 4869 ± 10 | +| VerticalFlip | 7330 ± 11 | 2541 ± 2 | 945 ± 4 | 4807 ± 4 | **8400 ± 11** | +| Rotate | 535 ± 5 | 144 ± 2 | 202 ± 1 | **572 ± 2** | 494 ± 1 | +| Affine | **1504 ± 46** | 153 ± 1 | 197 ± 1 | - | 671 ± 2 | +| Equalize | 1005 ± 1 | 328 ± 2 | 76 ± 1 | - | **1165 ± 1** | +| RandomCrop64 | 20880 ± 170 | 15792 ± 23 | 833 ± 1 | **21313 ± 603** | 5547 ± 2 | +| RandomResizedCrop | **2272 ± 6** | 1113 ± 5 | 189 ± 1 | - | - | +| ShiftRGB | **1708 ± 2** | - | 425 ± 1 | - | 1480 ± 11 | +| Resize | **2209 ± 1** | 1285 ± 3 | 200 ± 1 | 430 ± 1 | 1690 ± 1 | +| RandomGamma | **3638 ± 7** | 229 ± 1 | 213 ± 1 | - | 2307 ± 3 | +| Grayscale | **7234 ± 4** | 1628 ± 6 | 447 ± 1 | 2535 ± 1 | 1052 ± 6 | +| ColorJitter | **438 ± 1** | 50 ± 1 | 47 ± 1 | 214 ± 1 | - | +| RandomPerspective | 457 ± 2 | 122 ± 1 | 115 ± 1 | - | **460 ± 2** | +| GaussianBlur | **2073 ± 1** | 110 ± 2 | 74 ± 2 | 162 ± 1 | 1271 ± 1 | +| MedianBlur | 536 ± 1 | - | 3 ± 0 | - | **564 ± 1** | +| MotionBlur | **2156 ± 8** | - | 98 ± 1 | - | 503 ± 1 | +| Posterize | **3435 ± 2** | 2574 ± 2 | 312 ± 9 | - | 1894 ± 1 | +| JpegCompression | **805 ± 1** | - | - | 679 ± 16 | 427 ± 1 | +| GaussianNoise | **239 ± 1** | - | - | 67 ± 1 | 124 ± 1 | +| Elastic | 126 ± 1 | 4 ± 0 | 1 ± 0 | - | **128 ± 1** | +| Normalize | **1056 ± 1** | 429 ± 1 | 398 ± 1 | - | - | ## Contributing diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index 234fb2429..f192cf4ea 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -29,6 +29,7 @@ ColorType, ImageMode, NumericType, + PlanckianJitterMode, SizeType, SpatterMode, ) @@ -1436,6 +1437,87 @@ def center(width: NumericType, height: NumericType) -> Tuple[float, float]: return width / 2 - 0.5, height / 2 - 0.5 +PLANCKIAN_COEFFS = { + "blackbody": { + 3_000: [0.6743, 0.4029, 0.0013], + 3_500: [0.6281, 0.4241, 0.1665], + 4_000: [0.5919, 0.4372, 0.2513], + 4_500: [0.5623, 0.4457, 0.3154], + 5_000: [0.5376, 0.4515, 0.3672], + 5_500: [0.5163, 0.4555, 0.4103], + 6_000: [0.4979, 0.4584, 0.4468], + 6_500: [0.4816, 0.4604, 0.4782], + 7_000: [0.4672, 0.4619, 0.5053], + 7_500: [0.4542, 0.4630, 0.5289], + 8_000: [0.4426, 0.4638, 0.5497], + 8_500: [0.4320, 0.4644, 0.5681], + 9_000: [0.4223, 0.4648, 0.5844], + 9_500: [0.4135, 0.4651, 0.5990], + 10_000: [0.4054, 0.4653, 0.6121], + 10_500: [0.3980, 0.4654, 0.6239], + 11_000: [0.3911, 0.4655, 0.6346], + 11_500: [0.3847, 0.4656, 0.6444], + 12_000: [0.3787, 0.4656, 0.6532], + 12_500: [0.3732, 0.4656, 0.6613], + 13_000: [0.3680, 0.4655, 0.6688], + 13_500: [0.3632, 0.4655, 0.6756], + 14_000: [0.3586, 0.4655, 0.6820], + 14_500: [0.3544, 0.4654, 0.6878], + 15_000: [0.3503, 0.4653, 0.6933], + }, + "cied": { + 4_000: [0.5829, 0.4421, 0.2288], + 4_500: [0.5510, 0.4514, 0.2948], + 5_000: [0.5246, 0.4576, 0.3488], + 5_500: [0.5021, 0.4618, 0.3941], + 6_000: [0.4826, 0.4646, 0.4325], + 6_500: [0.4654, 0.4667, 0.4654], + 7_000: [0.4502, 0.4681, 0.4938], + 7_500: [0.4364, 0.4692, 0.5186], + 8_000: [0.4240, 0.4700, 0.5403], + 8_500: [0.4127, 0.4705, 0.5594], + 9_000: [0.4023, 0.4709, 0.5763], + 9_500: [0.3928, 0.4713, 0.5914], + 10_000: [0.3839, 0.4715, 0.6049], + 10_500: [0.3757, 0.4716, 0.6171], + 11_000: [0.3681, 0.4717, 0.6281], + 11_500: [0.3609, 0.4718, 0.6380], + 12_000: [0.3543, 0.4719, 0.6472], + 12_500: [0.3480, 0.4719, 0.6555], + 13_000: [0.3421, 0.4719, 0.6631], + 13_500: [0.3365, 0.4719, 0.6702], + 14_000: [0.3313, 0.4719, 0.6766], + 14_500: [0.3263, 0.4719, 0.6826], + 15_000: [0.3217, 0.4719, 0.6882], + }, +} + + +@clipped +def planckian_jitter(img: np.ndarray, temperature: int, mode: PlanckianJitterMode = "blackbody") -> np.ndarray: + img = img.copy() + # Linearly interpolate between 2 closest temperatures + step = 500 + t_left = (temperature // step) * step + t_right = (temperature // step + 1) * step + + w_left = (t_right - temperature) / step + w_right = (temperature - t_left) / step + + coeffs = w_left * np.array(PLANCKIAN_COEFFS[mode][t_left]) + w_right * np.array(PLANCKIAN_COEFFS[mode][t_right]) + + image = img / 255.0 if img.dtype == np.uint8 else img + + image[:, :, 0] = image[:, :, 0] * (coeffs[0] / coeffs[1]) + image[:, :, 2] = image[:, :, 2] * (coeffs[2] / coeffs[1]) + image[image > 1] = 1 + + if img.dtype == np.uint8: + return image * 255.0 + + return image + + def generate_approx_gaussian_noise( shape: SizeType, mean: float = 0, diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index de383da0e..0431086b0 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -52,6 +52,7 @@ ImageMode, KeypointInternalType, MorphologyMode, + PlanckianJitterMode, RainMode, ScaleFloatType, ScaleIntType, @@ -106,6 +107,7 @@ "Spatter", "ChromaticAberration", "Morphological", + "PlanckianJitter", ] NUM_BITS_ARRAY_LENGTH = 3 @@ -2089,7 +2091,6 @@ class Downscale(ImageOnlyTransform): downscaling and upscaling. Should include keys 'downscale' and 'upscale' with cv2 interpolation flags as values. Example: {"downscale": cv2.INTER_NEAREST, "upscale": cv2.INTER_LINEAR}. - always_apply (bool): Deprecated. Defaults to None. Targets: image @@ -2097,10 +2098,6 @@ class Downscale(ImageOnlyTransform): Image types: uint8, float32 - Note: - Previous parameters `scale_min`, `scale_max`, and `interpolation` are deprecated. Use `scale_range` - and `interpolation_pair` for specifying scaling bounds and interpolation methods respectively. - Example: >>> transform = Downscale(scale_range=(0.5, 0.9), interpolation_pair={"downscale": cv2.INTER_AREA, "upscale": cv2.INTER_CUBIC}) @@ -2208,7 +2205,6 @@ class Lambda(NoOp): keypoint: Keypoint transformation function. bbox: BBox transformation function. global_label: Global label transformation function. - always_apply: Deprecated. Defaults to None. p: probability of applying the transform. Default: 1.0. Targets: @@ -3418,7 +3414,6 @@ class Morphological(DualTransform): and maximum sizes for the dilation kernel. operation (str, optional): The morphological operation to apply. Options are 'dilation' or 'erosion'. Default is 'dilation'. - always_apply (bool, optional): Deprecated. Default is None. p (float, optional): The probability of applying this transformation. Default is 0.5. Targets: @@ -3476,3 +3471,133 @@ def targets(self) -> Dict[str, Callable[..., Any]]: "mask": self.apply_to_mask, "masks": self.apply_to_masks, } + + +PLANKIAN_JITTER_CONST = { + "MAX_TEMP": 15000, + "MIN_BLACKBODY_TEMP": 3000, + "MIN_CIED_TEMP": 4000, + "WHITE_TEMP": 6000, + "SAMPLING_TEMP_PROB": 0.4, +} + + +class PlanckianJitter(ImageOnlyTransform): + r"""Randomly jitter the image illuminant along the Planckian locus. + + Physics-based color augmentation creates realistic variations in chromaticity, simulating illumination changes + in a scene. + + Args: + mode (Literal["blackbody", "cied"]): The mode of the transformation. `blackbody` simulates blackbody radiation, + and `cied` uses the CIED illuminant series. + temperature_limit (Tuple[int, int]): Temperature range to sample from. For `blackbody` mode, the range should + be within [3000K, 15000K]. For "cied" mode, the range should be within [4000K, 15000K]. + Higher temperatures produce cooler (bluish) images. + sampling_method (Literal["uniform", "gaussian"]): Method to sample the temperature. + "uniform" samples uniformly across the range, while "gaussian" samples from a Gaussian distribution. + p (float): Probability of applying the transform. Defaults to 0.5. + + Targets: + image + + Image types: + uint8, float32 + + References: + - https://github.com/TheZino/PlanckianJitter + - https://arxiv.org/pdf/2202.07993.pdf + + """ + + class InitSchema(BaseTransformInitSchema): + mode: PlanckianJitterMode = "blackbody" + temperature_limit: Annotated[Tuple[int, int], AfterValidator(nondecreasing)] = (3000, 15000) + sampling_method: Literal["uniform", "gaussian"] = "uniform" + + @model_validator(mode="after") + def validate_temperature(self) -> Self: + max_temp = PLANKIAN_JITTER_CONST["MAX_TEMP"] + + if self.mode == "blackbody" and ( + min(self.temperature_limit) < PLANKIAN_JITTER_CONST["MIN_BLACKBODY_TEMP"] + or max(self.temperature_limit) > max_temp + ): + raise ValueError("Temperature limits for blackbody should be in [3000, 15000] range") + if self.mode == "cied" and ( + min(self.temperature_limit) < PLANKIAN_JITTER_CONST["MIN_CIED_TEMP"] + or max(self.temperature_limit) > max_temp + ): + raise ValueError("Temperature limits for CIED should be in [4000, 15000] range") + + if not self.temperature_limit[0] <= PLANKIAN_JITTER_CONST["WHITE_TEMP"] <= self.temperature_limit[1]: + raise ValueError("White temperature should be within the temperature limits") + + return self + + def __init__( + self, + mode: PlanckianJitterMode = "blackbody", + temperature_limit: Tuple[int, int] = (3000, 15000), + sampling_method: Literal["uniform", "gaussian"] = "uniform", + always_apply: Optional[bool] = None, + p: float = 0.5, + ) -> None: + super().__init__(always_apply=always_apply, p=p) + + self.mode = mode + self.temperature_limit = temperature_limit + self.sampling_method = sampling_method + + def apply(self, img: np.ndarray, temperature: int, **params: Any) -> np.ndarray: + if not is_rgb_image(img): + raise TypeError("PlanckianJitter transformation expects 3-channel images.") + return fmain.planckian_jitter(img, temperature, mode=self.mode) + + def get_params(self) -> Dict[str, Any]: + sampling_prob_boundary = PLANKIAN_JITTER_CONST["SAMPLING_TEMP_PROB"] + sampling_temp_boundary = PLANKIAN_JITTER_CONST["WHITE_TEMP"] + + if self.sampling_method == "uniform": + # Split into 2 cases to avoid selecting cold temperatures (>6000) too often + if random.random() < sampling_prob_boundary: + temperature = ( + random.uniform( + self.temperature_limit[0], + sampling_temp_boundary, + ), + ) + else: + temperature = ( + random.uniform( + sampling_temp_boundary, + self.temperature_limit[1], + ), + ) + elif self.sampling_method == "gaussian": + # Sample values from asymmetric gaussian distribution + if random.random() < sampling_prob_boundary: + # Left side + shift = np.abs( + random.gauss( + 0, + np.abs(sampling_temp_boundary - self.temperature_limit[0]) / 3, + ), + ) + else: + # Right side + shift = -np.abs( + random.gauss( + 0, + np.abs(self.temperature_limit[1] - sampling_temp_boundary) / 3, + ), + ) + + temperature = sampling_temp_boundary - shift + else: + raise ValueError(f"Unknown sampling method: {self.sampling_method}") + + return {"temperature": int(np.clip(temperature, self.temperature_limit[0], self.temperature_limit[1]))} + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return "mode", "temperature_limit", "sampling_method" diff --git a/albumentations/core/types.py b/albumentations/core/types.py index 105126f77..ccb76b566 100644 --- a/albumentations/core/types.py +++ b/albumentations/core/types.py @@ -35,6 +35,8 @@ MorphologyMode = Literal["erosion", "dilation"] +PlanckianJitterMode = Literal["blackbody", "cied"] + d4_group_elements = ["e", "r90", "r180", "r270", "v", "hvt", "h", "t"] D4Type = Literal["e", "r90", "r180", "r270", "v", "hvt", "h", "t"] diff --git a/requirements-dev.txt b/requirements-dev.txt index 211862a54..2abca87a4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,7 +5,7 @@ pytest>=8.2.0 pytest_cov>=4.1.0 pytest_mock>=3.14.0 requests>=2.31.0 -ruff>=0.4.8 +ruff>=0.4.9 tomli>=2.0.1 types-pkg-resources types-PyYAML diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index e537c70d5..aabc34173 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -371,6 +371,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params): A.RandomCropFromBorders, A.Spatter, A.ChromaticAberration, + A.PlanckianJitter }, ), ) @@ -555,6 +556,7 @@ def test_mask_fill_value(augmentation_cls, params): A.PixelDistributionAdaptation, A.Spatter, A.ChromaticAberration, + A.PlanckianJitter }, ), ) @@ -634,6 +636,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params): A.PixelDistributionAdaptation, A.Spatter, A.ChromaticAberration, + A.PlanckianJitter }, ), ) @@ -704,6 +707,7 @@ def test_float_multichannel_image_augmentations(augmentation_cls, params): A.PixelDistributionAdaptation, A.Spatter, A.ChromaticAberration, + A.PlanckianJitter }, ), ) @@ -779,6 +783,7 @@ def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params A.PixelDistributionAdaptation, A.Spatter, A.ChromaticAberration, + A.PlanckianJitter }, ), ) diff --git a/tests/test_core.py b/tests/test_core.py index 89d2ef4da..1f7826b99 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -658,15 +658,15 @@ def test_compose_additional_targets_in_available_keys() -> None: image = np.ones((8, 8)) # non-empty `transforms` - augmentation = Compose([first, second], p=1, + augmentation = Compose([first, second], p=1, additional_targets={"additional_target_1": "image", "additional_target_2": "image"}) augmentation(image=image, additional_target_1=image, additional_target_2=image) # will raise exception if not # empty `transforms` - augmentation = Compose([], p=1, + augmentation = Compose([], p=1, additional_targets={"additional_target_1": "image", "additional_target_2": "image"}) augmentation(image=image, additional_target_1=image, additional_target_2=image) # will raise exception if not - + def test_transform_always_apply_warning() -> None: """Check that warning is raised if always_apply argument is used""" diff --git a/tests/test_functional.py b/tests/test_functional.py index 3596690b0..b0470b86c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1061,3 +1061,41 @@ def test_transpose(shape): assert np.array_equal(FGeometric.transpose(img), expected_main) transposed_axis1 = FGeometric.transpose(FGeometric.rot90(img, 2)) assert np.array_equal(transposed_axis1, expected_second) + + +def test_planckian_jitter_blackbody(): + img = np.array([[ + [0.4963, 0.6977, 0.1759], [0.7682, 0.8 , 0.2698], [0.0885, 0.161 , 0.1507], [0.132 , 0.2823, 0.0317]], + [[0.3074, 0.6816, 0.2081], [0.6341, 0.9152, 0.9298], [0.4901, 0.3971, 0.7231],[0.8964, 0.8742, 0.7423]], + [[0.4556, 0.4194, 0.5263], [0.6323, 0.5529, 0.2437], [0.3489, 0.9527, 0.5846], [0.4017, 0.0362, 0.0332]], + [[0.0223, 0.1852, 0.1387], [0.1689, 0.3734, 0.2422], [0.2939, 0.3051, 0.8155], [0.5185, 0.932 , 0.7932]]] + ) + + expected_blackbody_plankian_jitter = np.array([ + [[0.735 , 0.6977, 0.0691], [1. , 0.8 , 0.1059], [0.1311, 0.161 , 0.0592], [0.1955, 0.2823, 0.0124]], + [[0.4553, 0.6816, 0.0817], [0.9391, 0.9152, 0.365 ], [0.7258, 0.3971, 0.2839], [1. , 0.8742, 0.2914]], + [[0.6748, 0.4194, 0.2066], [0.9364, 0.5529, 0.0957], [0.5167, 0.9527, 0.2295], [0.5949, 0.0362, 0.013 ]], + [[0.033 , 0.1852, 0.0545], [0.2501, 0.3734, 0.0951], [0.4353, 0.3051, 0.3202], [0.7679, 0.932 , 0.3114]]] + ) + + blackbody_plankian_jitter = F.planckian_jitter(img, temperature=3500, mode="blackbody") + assert np.allclose(blackbody_plankian_jitter, expected_blackbody_plankian_jitter, atol=1e-4) + + +def test_planckian_jitter_cied(): + img = np.array([ + [[0.4963, 0.6977, 0.1759], [0.7682, 0.8 , 0.2698], [0.0885, 0.161 , 0.1507], [0.132 , 0.2823, 0.0317]], + [[0.3074, 0.6816, 0.2081], [0.6341, 0.9152, 0.9298], [0.4901, 0.3971, 0.7231], [0.8964, 0.8742, 0.7423]], + [[0.4556, 0.4194, 0.5263], [0.6323, 0.5529, 0.2437], [0.3489, 0.9527, 0.5846], [0.4017, 0.0362, 0.0332]], + [[0.0223, 0.1852, 0.1387], [0.1689, 0.3734, 0.2422], [0.2939, 0.3051, 0.8155], [0.5185, 0.932 , 0.7932]]] + ) + + expected_cied_plankian_jitter = np.array([ + [[0.6058, 0.6977, 0.1149], [0.9377, 0.8000, 0.1762], [0.1080, 0.1610, 0.0984], [0.1611, 0.2823, 0.0207]], + [[0.3752, 0.6816, 0.1359], [0.7740, 0.9152, 0.6072], [0.5982, 0.3971, 0.4722], [1.0000, 0.8742, 0.4848]], + [[0.5561, 0.4194, 0.3437], [0.7718, 0.5529, 0.1592], [0.4259, 0.9527, 0.3818], [0.4903, 0.0362, 0.0217]], + [[0.0272, 0.1852, 0.0906], [0.2062, 0.3734, 0.1582], [0.3587, 0.3051, 0.5326], [0.6329, 0.9320, 0.5180]]] + ) + + cied_plankian_jitter = F.planckian_jitter(img, temperature=4500, mode="cied") + assert np.allclose(cied_plankian_jitter, expected_cied_plankian_jitter, atol=1e-4) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ace546c26..f7acb812e 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -458,7 +458,8 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image): ) ], [A.Morphological, {}], - [A.D4, {}] + [A.D4, {}], + [A.PlanckianJitter, {}] ] AUGMENTATION_CLS_EXCEPT = { diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 1fefef847..d4b959c35 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1362,6 +1362,7 @@ def test_coarse_dropout_invalid_input(params): "n_segments": (10, 10), "max_size": 10 }, + A.ZoomBlur: {"max_factor": (1.05, 3)}, }, except_augmentations={ A.RandomCropNearBBox, @@ -1438,7 +1439,8 @@ def test_change_image(augmentation_cls, params): A.ChannelShuffle, A.ChromaticAberration, A.RandomRotate90, - A.FancyPCA + A.FancyPCA, + A.PlanckianJitter }, ), )