fix: removed code modifying ndarray/torch tensor methods (caused an issue with torch.compile recognising dunder methods), but added the same functionality using `__torch_function__`/`__array_ufunc__`, which are the proper mechanisms to enable this behaviour. Also simplified the test (ivy-llc#28394)
mattbarrett98 authored and Kacper-W-Kozdon committed Feb 27, 2024
1 parent e96c7e9 commit 9dc2371
Showing 4 changed files with 6 additions and 127 deletions.
ivy/data_classes/array/array.py (4 changes: 2 additions & 2 deletions)
```diff
@@ -349,7 +349,7 @@ def data(self, data):
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs={}):
         args, kwargs = args_to_native(*args, **kwargs)
-        return func(*args, **kwargs)
+        return to_ivy(func(*args, **kwargs))
 
     def __ivy_array_function__(self, func, types, args, kwargs):
         # Cannot handle items that have __ivy_array_function__ other than those of
@@ -382,7 +382,7 @@ def __array_prepare__(self, *args, **kwargs):
 
     def __array_ufunc__(self, *args, **kwargs):
         args, kwargs = args_to_native(*args, **kwargs)
-        return self._data.__array_ufunc__(*args, **kwargs)
+        return to_ivy(self._data.__array_ufunc__(*args, **kwargs))
 
     def __array_wrap__(self, *args, **kwargs):
         args, kwargs = args_to_native(*args, **kwargs)
```
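For context, `__torch_function__` is PyTorch's documented override hook: when any argument of a torch function defines it, PyTorch dispatches to that hook instead of running the native operation. The change above additionally wraps the dispatched result in `to_ivy(...)`, so results come back as Ivy arrays rather than raw tensors. A minimal sketch of the same pattern, using a hypothetical `Wrapped` class rather than Ivy's actual implementation:

```python
import torch


class Wrapped:
    """Hypothetical wrapper type, standing in for ivy.Array."""

    def __init__(self, data):
        self.data = data  # the underlying torch.Tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # unwrap wrapper arguments to native tensors (cf. args_to_native)
        args = tuple(a.data if isinstance(a, Wrapped) else a for a in args)
        result = func(*args, **kwargs)
        # re-wrap the native result, mirroring the to_ivy(...) added above
        return cls(result) if isinstance(result, torch.Tensor) else result


x = Wrapped(torch.ones(2))
y = torch.add(x, torch.ones(2))  # dispatches through Wrapped.__torch_function__
assert isinstance(y, Wrapped)
```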
ivy/functional/backends/numpy/__init__.py (30 changes: 0 additions & 30 deletions)
```diff
@@ -16,36 +16,6 @@
 
 use = ivy.utils.backend.ContextManager(_module_in_memory)
 
-# wrap __array_ufunc__ method of ivy.Array to prioritize Ivy array methods when using numpu backend
-
-
-def wrap__array_ufunc__(func):
-    def rep_method(self, ufunc, method, *inputs, **kwargs):
-        methods = {
-            "not_equal": "not_equal",
-            "greater": "greater",
-            "less": "less",
-            "greater_equal": "greater_equal",
-            "less_equal": "less_equal",
-            "multiply": "multiply",
-            "divide": "divide",
-            "remainder": "remainder",
-            "equal": "equal",
-            "bitwise_and": "bitwise_and",
-            "matmul": "matmul",
-            "power": "pow",
-            "subtract": "subtract",
-            "add": "add",
-        }
-        if ufunc.__name__ in methods:
-            return eval("ivy." + methods[ufunc.__name__] + "(*inputs, **kwargs)")
-        return func(self, ufunc, method, *inputs, **kwargs)
-
-    return rep_method
-
-
-ivy.Array.__array_ufunc__ = wrap__array_ufunc__(ivy.Array.__array_ufunc__)
-
 NativeArray = np.ndarray
 NativeDevice = str
 NativeDtype = np.dtype
```
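The block removed here had monkey-patched `ivy.Array.__array_ufunc__` at import time (note the `eval`-based dispatch). After this commit, the same interception relies on NumPy's standard `__array_ufunc__` protocol implemented on the class itself, as in the array.py diff above. A minimal sketch of that protocol with a hypothetical `Wrapped` class, not Ivy's real one:

```python
import numpy as np


class Wrapped:
    """Hypothetical wrapper type, standing in for ivy.Array."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # unwrap wrapper inputs, run the real ufunc, then re-wrap the result
        inputs = tuple(i.data if isinstance(i, Wrapped) else i for i in inputs)
        return Wrapped(getattr(ufunc, method)(*inputs, **kwargs))


x = Wrapped([1, 2, 3])
y = np.add(x, 1)                # NumPy defers to x.__array_ufunc__
z = np.multiply(np.ones(3), x)  # dispatch works in either argument position
assert isinstance(y, Wrapped) and isinstance(z, Wrapped)
```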
ivy/functional/backends/torch/__init__.py (75 changes: 0 additions & 75 deletions)
```diff
@@ -20,81 +20,6 @@
 
 use = ivy.utils.backend.ContextManager(_module_in_memory)
 
-
-# wrap dunder methods of native tensors to return NotImplemented to prioritize Ivy array methods.
-def dunder_wrapper(func):
-    def rep_method(*args, **kwargs):
-        for arg in args:
-            if ivy.is_ivy_array(arg):
-                return NotImplemented
-        return func(*args, **kwargs)
-
-    return rep_method
-
-
-# check for previously imported torch module
-modules_to_patch = []
-tensors_to_patch = []
-tmp_globals = dict(globals())
-for name, value in tmp_globals.items():
-    if value == "torch.Tensor":
-        tensors_to_patch.append(name)
-    try:
-        if value.__name__ == "torch":
-            modules_to_patch.append(name)
-    except AttributeError:
-        pass
-
-methods_to_patch = [
-    "__add__",
-    "__and__",
-    "__div__",
-    "__eq__",
-    "__floordiv__",
-    "__ge__",
-    "__gt__",
-    "__iadd__",
-    "__iand__",
-    "__idiv__",
-    "__ifloordiv__",
-    "__ilshift__",
-    "__imul__",
-    "__ior__",
-    "__ipow__",
-    "__irshift__",
-    "__isub__",
-    "__itruediv__",
-    "__ixor__",
-    "__le__",
-    "__lshift__",
-    "__lt__",
-    "__matmul__",
-    "__mul__",
-    "__or__",
-    "__pow__",
-    "__truediv__",
-    "__xor__",
-    "__ne__",
-    "__mod__",
-]
-
-for module in modules_to_patch:
-    for method in methods_to_patch:
-        exec(
-            module
-            + ".Tensor."
-            + method
-            + " = dunder_wrapper("
-            + module
-            + ".Tensor."
-            + method
-            + ")"
-        )
-
-for tensor in tensors_to_patch:
-    for method in methods_to_patch:
-        exec(tensor + "." + method + " = dunder_wrapper(" + tensor + "." + method + ")")
-
 NativeArray = torch.Tensor
 NativeDevice = torch.device
 NativeDtype = torch.dtype
```
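The deleted wrapper leaned on Python's binary-operator protocol: a dunder that returns `NotImplemented` makes the interpreter retry the operation with the other operand's reflected method, which is how `ivy.Array` got priority over a patched `torch.Tensor`. Per the commit message, patching the dunders this way confused `torch.compile`, hence the removal. A self-contained sketch of the mechanism itself (toy classes, unrelated to Ivy's code):

```python
class Native:
    """Stands in for a patched torch.Tensor."""

    def __add__(self, other):
        if isinstance(other, Smart):
            return NotImplemented  # defer, as dunder_wrapper did for Ivy arrays
        return "Native.__add__"


class Smart:
    """Stands in for ivy.Array, which supplies the reflected method."""

    def __radd__(self, other):
        return "Smart.__radd__"


# NotImplemented from Native.__add__ makes Python fall back to
# Smart.__radd__, so the wrapper type handles the operation.
assert Native() + Smart() == "Smart.__radd__"
```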
ivy_tests/test_ivy/test_misc/test_array.py (24 changes: 4 additions & 20 deletions)
```diff
@@ -2543,30 +2543,14 @@ def test_array_property_strides(dtype_x, backend_fw):
 
 @handle_test(
     fn_tree="functional.ivy.native_array",  # dummy fn_tree
-    dtype_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("integer"),
-        min_dim_size=3,
-        max_dim_size=3,
-        min_num_dims=3,
-        max_num_dims=3,
-        num_arrays=2,
-        min_value=3.0,
-        max_value=10.0,
-    ),
     op=st.sampled_from(
-        ["!=", ">", "<", ">=", "<=", "*", "/", "%", "==", "&", "@", "**", "/"]
+        ["!=", ">", "<", ">=", "<=", "*", "/", "%", "==", "&", "@", "**", "/"],
     ),
 )
-def test_dunder_wrapping(
-    dtype_x,
-    backend_fw,
-    test_flags,
-    op,
-):
-    _, data = dtype_x
+def test_dunder_wrapping(backend_fw, op):
     ivy.set_backend(backend_fw)
-    x = ivy.to_native(ivy.array(data[0]))
-    y = ivy.array(data[1])
+    x = ivy.native_array([1])
+    y = ivy.array([1])
     assert ivy.is_ivy_array(y)
     assert ivy.is_native_array(x)
     res = eval(f"x {op} y")
```
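Roughly what the simplified test exercises, outside the hypothesis harness (a sketch assuming the torch backend; the final assertion is inferred, since the diff is truncated above):

```python
import ivy

ivy.set_backend("torch")
x = ivy.native_array([1])  # a raw torch.Tensor
y = ivy.array([1])         # an ivy.Array

# one of the sampled operators, e.g. "*"; the test builds this with eval()
res = x * y
# the native operand defers via __torch_function__, so the result should
# come back as an ivy.Array (assertion inferred, not shown in the diff)
assert ivy.is_ivy_array(res)
```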
