Skip to content

Commit

Permalink
feat decorator_utils: adding a couple more decorators to detect convo…
Browse files Browse the repository at this point in the history
…lution blocks at the functional level instead of just the stateful level.
  • Loading branch information
YushaArif99 committed Sep 12, 2024
1 parent fe60be9 commit eda6453
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion ivy/utils/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,10 @@ def apply_transpose(input, transpose, pt_to_tf=True):
# -> TF: (kernel[0], kernel[1], kernel[2], num_in_channel, num_out_channel)
axes = (0, 2, 3, 4, 1) if pt_to_tf else (0, 4, 1, 2, 3)

input = ivy.permute_dims(input, axes=axes).data
if ivy.is_array(input):
input = ivy.permute_dims(input, axes=axes).data
else:
input = tuple(input[i] for i in axes)
return input


Expand Down Expand Up @@ -419,6 +422,42 @@ def transpose_wrapper(self, *args, **kwargs):
return transpose_wrapper


def handle_transpose_in_input_and_output_for_functions(fn):
@functools.wraps(fn)
def transpose_wrapper(*args, **kwargs):
DATA_FORMAT = os.environ.get("DATA_FORMAT", "channels_first")
if DATA_FORMAT == "channels_first":
value_map = {"channel_last": "channel_first", "NHWC": "NCHW", "NSC": "NCS"}
if "data_format" in kwargs and kwargs["data_format"] in value_map:
kwargs["data_format"] = value_map[kwargs["data_format"]]
if "filter_format" in kwargs and kwargs["filter_format"] in value_map:
kwargs["filter_format"] = value_map[kwargs["filter_format"]]

return fn(*args, **kwargs)

return transpose_wrapper


def handle_transpose_in_pad(fn):
@functools.wraps(fn)
def transpose_wrapper(input, pad_width, *args, **kwargs):
DATA_FORMAT = os.environ.get("DATA_FORMAT", "channels_first")
if DATA_FORMAT == "channels_last":
if len(input.shape) > 4:
transpose = TransposeType.CONV3D
elif len(input.shape) > 3:
transpose = TransposeType.CONV2D
elif len(input.shape) > 2:
transpose = TransposeType.CONV1D
else:
transpose = TransposeType.NO_TRANSPOSE
pad_width = apply_transpose(pad_width, transpose=transpose, pt_to_tf=True)

return fn(input, pad_width, *args, **kwargs)

return transpose_wrapper


# TODO: temp fix for `ivy.inplace_update`. Dont quite understand the way this function
# has been implemented in the backends as it seems to also have ivy.Array specific logic
# In the case where both x, and val are arrays, it simply returns x (why??)
Expand Down

0 comments on commit eda6453

Please sign in to comment.