From 4d15c73b55dc3ea34ac5be3b18c4e9ff55057afb Mon Sep 17 00:00:00 2001 From: Sameerk22 Date: Wed, 6 Sep 2023 14:58:31 +0500 Subject: [PATCH] Final update --- .../frontends/paddle/nn/functional/vision.py | 39 ++++++++++++++++++ .../test_nn/test_functional/test_vision.py | 41 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/ivy/functional/frontends/paddle/nn/functional/vision.py b/ivy/functional/frontends/paddle/nn/functional/vision.py index 653ed76189da5..92e5ff6e56b0a 100644 --- a/ivy/functional/frontends/paddle/nn/functional/vision.py +++ b/ivy/functional/frontends/paddle/nn/functional/vision.py @@ -8,6 +8,45 @@ from ivy.utils.assertions import check_equal +@to_ivy_arrays_and_back +def pixel_unshuffle(x, downscale_factor, data_format="NCHW"): + input_shape = ivy.shape(x) + + if data_format == "NCHW": + b, c, h, w = input_shape + else: + b, h, w, c = input_shape + + check_equal( + c % (downscale_factor ** 2), + 0, + message=( + "pixel unshuffle expects input channel to be divisible by square of downscale" + " factor, but got input with size {}, downscale factor={}, and" + " self.size(1)={}, is not divisible by {}".format( + input_shape, downscale_factor, c, downscale_factor ** 2 + ) + ), + as_array=False, + ) + + oc = c // (downscale_factor ** 2) + oh = h // downscale_factor + ow = w // downscale_factor + + if data_format == "NCHW": + x_reshaped = ivy.reshape(x, (b, oc, downscale_factor, downscale_factor, oh, ow)) + else: + x_reshaped = ivy.reshape(x, (b, oh, ow, downscale_factor, downscale_factor, oc)) + + if data_format == "NCHW": + return ivy.reshape( + ivy.permute_dims(x_reshaped, (0, 1, 4, 2, 5, 3)), (b, oc, oh, ow) + ) + return ivy.reshape( + ivy.permute_dims(x_reshaped, (0, 4, 1, 5, 2, 3)), (b, oh, ow, oc) + ) + @to_ivy_arrays_and_back @with_unsupported_dtypes({"2.5.1 and below": ("float16", "bfloat16")}, "paddle") def affine_grid(theta, out_shape, align_corners=True): diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py index 0659e33a95bf2..889a8aa68e3be 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_vision.py @@ -10,6 +10,47 @@ # --- Helpers --- # # --------------- # +#pixel_unshuffle +@handle_frontend_test( + fn_tree="paddle.nn.functional.pixel_unshuffle", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=["float32", "float64"], + min_value=0, + min_num_dims=4, + max_num_dims=4, + min_dim_size=3, + ), + downscale_factor=helpers.ints(min_value=1), + data_format=st.sampled_from(["NCHW", "NHWC"]), +) + +def test_paddle_pixel_unshuffle( + *, + dtype_and_x, + downscale_factor, + data_format, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + if data_format == "NCHW": + assume(ivy.shape(x[0])[1] % (downscale_factor**2) == 0) + else: + assume(ivy.shape(x[0])[3] % (downscale_factor**2) == 0) + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + downscale_factor=downscale_factor, + data_format=data_format, + backend_to_test=backend_fw, + ) @st.composite def _affine_grid_helper(draw):