Skip to content

Commit

Permalink
fix: fixing tests for Shape __mul__ and __rmul__ methods (ivy-l…
Browse files Browse the repository at this point in the history
  • Loading branch information
fnhirwa authored Feb 21, 2024
1 parent af48650 commit c7e98b3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 30 deletions.
13 changes: 11 additions & 2 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,20 @@ def __radd__(self, other):
return self

def __mul__(self, other):
self._shape = self._shape * other
if ivy.current_backend_str() == "tensorflow":
shape_tup = builtins.tuple(self._shape) * other
self._shape = ivy.to_native_shape(shape_tup)
else:
self._shape = self._shape * other
return self

def __rmul__(self, other):
self._shape = other * self._shape
# handle tensorflow case as tf.TensorShape doesn't support multiplications
if ivy.current_backend_str() == "tensorflow":
shape_tup = other * builtins.tuple(self._shape)
self._shape = ivy.to_native_shape(shape_tup)
else:
self._shape = other * self._shape
return self

def __bool__(self):
Expand Down
46 changes: 18 additions & 28 deletions ivy_tests/test_ivy/test_misc/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,18 +424,14 @@ def test_shape__mod__(


@handle_method(
init_tree=CLASS_TREE,
method_tree="Shape.__mul__",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
large_abs_safety_factor=2.5,
small_abs_safety_factor=2.5,
safety_factor_scale="log",
shared_dtype=True,
),
shape=helpers.get_shape(),
other=st.integers(min_value=1, max_value=10),
)
def test_shape__mul__(
dtype_and_x,
shape,
other,
method_name,
class_name,
backend_fw,
Expand All @@ -444,17 +440,16 @@ def test_shape__mul__(
method_flags,
on_device,
):
dtype, x = dtype_and_x
helpers.test_method(
on_device=on_device,
ground_truth_backend=ground_truth_backend,
backend_to_test=backend_fw,
init_flags=init_flags,
method_flags=method_flags,
init_all_as_kwargs_np={"data": x[0]},
init_input_dtypes=dtype,
method_input_dtypes=dtype,
method_all_as_kwargs_np={"other": x[1]},
init_all_as_kwargs_np={"shape_tup": shape},
init_input_dtypes=DUMMY_DTYPE,
method_input_dtypes=DUMMY_DTYPE,
method_all_as_kwargs_np={"other": other},
class_name=class_name,
method_name=method_name,
)
Expand Down Expand Up @@ -569,18 +564,14 @@ def test_shape__rmod__(


@handle_method(
init_tree=CLASS_TREE,
method_tree="Shape.__rmul__",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("numeric"),
num_arrays=2,
large_abs_safety_factor=2.5,
small_abs_safety_factor=2.5,
safety_factor_scale="log",
shared_dtype=True,
),
shape=helpers.get_shape(),
other=st.integers(min_value=1, max_value=10),
)
def test_shape__rmul__(
dtype_and_x,
shape,
other,
method_name,
class_name,
ground_truth_backend,
Expand All @@ -589,17 +580,16 @@ def test_shape__rmul__(
method_flags,
on_device,
):
dtype, x = dtype_and_x
helpers.test_method(
on_device=on_device,
ground_truth_backend=ground_truth_backend,
backend_to_test=backend_fw,
init_flags=init_flags,
method_flags=method_flags,
init_all_as_kwargs_np={"data": x[0]},
init_input_dtypes=dtype,
method_input_dtypes=dtype,
method_all_as_kwargs_np={"other": x[1]},
init_all_as_kwargs_np={"shape_tup": shape},
init_input_dtypes=DUMMY_DTYPE,
method_input_dtypes=DUMMY_DTYPE,
method_all_as_kwargs_np={"other": other},
class_name=class_name,
method_name=method_name,
)
Expand Down

0 comments on commit c7e98b3

Please sign in to comment.