diff --git a/ivy/functional/frontends/paddle/vision/transforms.py b/ivy/functional/frontends/paddle/vision/transforms.py index 68cb4773981ee..b2a6147ba9f2f 100644 --- a/ivy/functional/frontends/paddle/vision/transforms.py +++ b/ivy/functional/frontends/paddle/vision/transforms.py @@ -18,6 +18,108 @@ def to_tensor(pic, data_format="CHW"): return Tensor(array) +# helpers +def _get_image_c_axis(data_format): + if data_format.lower() == "chw": + return -3 + elif data_format.lower() == "hwc": + return -1 + + +def _get_image_num_channels(img, data_format): + return ivy.shape(img)[_get_image_c_axis(data_format)] + + +def _rgb_to_hsv(img): + maxc = ivy.max(img, axis=-3) + minc = ivy.min(img, axis=-3) + + is_equal = ivy.equal(maxc, minc) + one_divisor = ivy.ones_like(maxc) + c_delta = maxc - minc + s = c_delta / ivy.where(is_equal, one_divisor, maxc) + + r, g, b = img[0], img[1], img[2] + c_delta_divisor = ivy.where(is_equal, one_divisor, c_delta) + + rc = (maxc - r) / c_delta_divisor + gc = (maxc - g) / c_delta_divisor + bc = (maxc - b) / c_delta_divisor + + hr = ivy.where((maxc == r), bc - gc, ivy.zeros_like(maxc)) + hg = ivy.where( + ((maxc == g) & (maxc != r)), + rc - bc + 2.0, + ivy.zeros_like(maxc), + ) + hb = ivy.where( + ((maxc != r) & (maxc != g)), + gc - rc + 4.0, + ivy.zeros_like(maxc), + ) + + h = (hr + hg + hb) / 6.0 + 1.0 + h = h - ivy.trunc(h) + + return ivy.stack([h, s, maxc], axis=-3) + + +def _hsv_to_rgb(img): + h, s, v = img[0], img[1], img[2] + f = h * 6.0 + i = ivy.floor(f) + f = f - i + i = ivy.astype(i, ivy.int32) % 6 + + p = ivy.clip(v * (1.0 - s), 0.0, 1.0) + q = ivy.clip(v * (1.0 - s * f), 0.0, 1.0) + t = ivy.clip(v * (1.0 - s * (1.0 - f)), 0.0, 1.0) + + mask = ivy.astype( + ivy.equal( + ivy.expand_dims(i, axis=-3), + ivy.reshape(ivy.arange(6, dtype=ivy.dtype(i)), (-1, 1, 1)), + ), + ivy.dtype(img), + ) + matrix = ivy.stack( + [ + ivy.stack([v, q, p, p, t, v], axis=-3), + ivy.stack([t, v, v, q, p, p], axis=-3), + ivy.stack([p, p, t, v, v, q], axis=-3), + ], + axis=-4, + ) + return ivy.einsum("...ijk, ...xijk -> ...xjk", mask, matrix) + + +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64", "uint8")}, "paddle") +@to_ivy_arrays_and_back +def adjust_hue(img, hue_factor): + assert -0.5 <= hue_factor <= 0.5, "hue_factor should be in range [-0.5, 0.5]" + + channels = _get_image_num_channels(img, "CHW") + + if channels == 1: + return img + elif channels == 3: + if ivy.dtype(img) == "uint8": + img = ivy.astype(img, "float32") / 255.0 + + img_hsv = _rgb_to_hsv(img) + h, s, v = img_hsv[0], img_hsv[1], img_hsv[2] + + h = h + hue_factor + h = h - ivy.floor(h) + + img_adjusted = _hsv_to_rgb(ivy.stack([h, s, v], axis=-3)) + + else: + raise ValueError("channels of input should be either 1 or 3.") + + return img_adjusted + + @with_unsupported_device_and_dtypes( { "2.5.1 and below": { diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py index c22194ace2884..418e0918d5856 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py @@ -31,6 +31,45 @@ def test_paddle_to_tensor( ) +# adjust_hue +@handle_frontend_test( + fn_tree="paddle.vision.transforms.adjust_hue", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=0, + max_value=255, + min_num_dims=3, + max_num_dims=3, + min_dim_size=3, + max_dim_size=3, + ), + hue_factor=helpers.floats(min_value=-0.5, max_value=0.5), +) +def test_paddle_adjust_hue( + *, + dtype_and_x, + hue_factor, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + rtol=1e-3, + atol=1e-3, + on_device=on_device, + img=x[0], + hue_factor=hue_factor, + ) + + @handle_frontend_test( fn_tree="paddle.vision.transforms.vflip", dtype_and_x=helpers.dtype_and_values(