From eda64539142fd0e4b50d69d99c939aca8c46cc62 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Thu, 12 Sep 2024 16:25:11 +0000 Subject: [PATCH] feat decorator_utils: adding a couple more decorators to detect convolution blocks at the functional level instead of just the stateful level. --- ivy/utils/decorator_utils.py | 41 +++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/ivy/utils/decorator_utils.py b/ivy/utils/decorator_utils.py index 7f2785ede136..b36bee9244d9 100644 --- a/ivy/utils/decorator_utils.py +++ b/ivy/utils/decorator_utils.py @@ -332,7 +332,10 @@ def apply_transpose(input, transpose, pt_to_tf=True): # -> TF: (kernel[0], kernel[1], kernel[2], num_in_channel, num_out_channel) axes = (0, 2, 3, 4, 1) if pt_to_tf else (0, 4, 1, 2, 3) - input = ivy.permute_dims(input, axes=axes).data + if ivy.is_array(input): + input = ivy.permute_dims(input, axes=axes).data + else: + input = tuple(input[i] for i in axes) return input @@ -419,6 +422,42 @@ def transpose_wrapper(self, *args, **kwargs): return transpose_wrapper +def handle_transpose_in_input_and_output_for_functions(fn): + @functools.wraps(fn) + def transpose_wrapper(*args, **kwargs): + DATA_FORMAT = os.environ.get("DATA_FORMAT", "channels_first") + if DATA_FORMAT == "channels_first": + value_map = {"channel_last": "channel_first", "NHWC": "NCHW", "NSC": "NCS"} + if "data_format" in kwargs and kwargs["data_format"] in value_map: + kwargs["data_format"] = value_map[kwargs["data_format"]] + if "filter_format" in kwargs and kwargs["filter_format"] in value_map: + kwargs["filter_format"] = value_map[kwargs["filter_format"]] + + return fn(*args, **kwargs) + + return transpose_wrapper + + +def handle_transpose_in_pad(fn): + @functools.wraps(fn) + def transpose_wrapper(input, pad_width, *args, **kwargs): + DATA_FORMAT = os.environ.get("DATA_FORMAT", "channels_first") + if DATA_FORMAT == "channels_last": + if len(input.shape) > 4: + transpose = TransposeType.CONV3D + elif len(input.shape) > 3: + transpose = TransposeType.CONV2D + elif len(input.shape) > 2: + transpose = TransposeType.CONV1D + else: + transpose = TransposeType.NO_TRANSPOSE + pad_width = apply_transpose(pad_width, transpose=transpose, pt_to_tf=True) + + return fn(input, pad_width, *args, **kwargs) + + return transpose_wrapper + + # TODO: temp fix for `ivy.inplace_update`. Dont quite understand the way this function # has been implemented in the backends as it seems to also have ivy.Array specific logic # In the case where both x, and val are arrays, it simply returns x (why??)