From c7e98b388fc3d3e82f2350c627623c60139d5401 Mon Sep 17 00:00:00 2001 From: Felix Hirwa Nshuti Date: Wed, 21 Feb 2024 12:12:15 +0530 Subject: [PATCH] fix: fixing tests for `Shape` `__mul__` and `__rmul__` methods (#28337) --- ivy/__init__.py | 13 +++++- ivy_tests/test_ivy/test_misc/test_shape.py | 46 +++++++++------------- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/ivy/__init__.py b/ivy/__init__.py index 00e5039fd4099..d40ab10bfb12d 100644 --- a/ivy/__init__.py +++ b/ivy/__init__.py @@ -287,11 +287,20 @@ def __radd__(self, other): return self def __mul__(self, other): - self._shape = self._shape * other + if ivy.current_backend_str() == "tensorflow": + shape_tup = builtins.tuple(self._shape) * other + self._shape = ivy.to_native_shape(shape_tup) + else: + self._shape = self._shape * other return self def __rmul__(self, other): - self._shape = other * self._shape + # handle tensorflow case as tf.TensorShape doesn't support multiplications + if ivy.current_backend_str() == "tensorflow": + shape_tup = other * builtins.tuple(self._shape) + self._shape = ivy.to_native_shape(shape_tup) + else: + self._shape = other * self._shape return self def __bool__(self): diff --git a/ivy_tests/test_ivy/test_misc/test_shape.py b/ivy_tests/test_ivy/test_misc/test_shape.py index c42f74180293f..72f60d5209cc9 100644 --- a/ivy_tests/test_ivy/test_misc/test_shape.py +++ b/ivy_tests/test_ivy/test_misc/test_shape.py @@ -424,18 +424,14 @@ def test_shape__mod__( @handle_method( + init_tree=CLASS_TREE, method_tree="Shape.__mul__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, - ), + shape=helpers.get_shape(), + other=st.integers(min_value=1, max_value=10), ) def test_shape__mul__( - dtype_and_x, + shape, + other, method_name, class_name, backend_fw, @@ -444,17 +440,16 @@ def test_shape__mul__( method_flags, on_device, ): - dtype, x = dtype_and_x helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + init_all_as_kwargs_np={"shape_tup": shape}, + init_input_dtypes=DUMMY_DTYPE, + method_input_dtypes=DUMMY_DTYPE, + method_all_as_kwargs_np={"other": other}, class_name=class_name, method_name=method_name, ) @@ -569,18 +564,14 @@ def test_shape__rmod__( @handle_method( + init_tree=CLASS_TREE, method_tree="Shape.__rmul__", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("numeric"), - num_arrays=2, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - shared_dtype=True, - ), + shape=helpers.get_shape(), + other=st.integers(min_value=1, max_value=10), ) def test_shape__rmul__( - dtype_and_x, + shape, + other, method_name, class_name, ground_truth_backend, @@ -589,17 +580,16 @@ def test_shape__rmul__( method_flags, on_device, ): - dtype, x = dtype_and_x helpers.test_method( on_device=on_device, ground_truth_backend=ground_truth_backend, backend_to_test=backend_fw, init_flags=init_flags, method_flags=method_flags, - init_all_as_kwargs_np={"data": x[0]}, - init_input_dtypes=dtype, - method_input_dtypes=dtype, - method_all_as_kwargs_np={"other": x[1]}, + init_all_as_kwargs_np={"shape_tup": shape}, + init_input_dtypes=DUMMY_DTYPE, + method_input_dtypes=DUMMY_DTYPE, + method_all_as_kwargs_np={"other": other}, class_name=class_name, method_name=method_name, )