From 03dab2ba70e2750ad5c2ef30ca5ace9e054f5e27 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Thu, 5 Sep 2024 02:01:44 +0500 Subject: [PATCH] fix: Fixed the frontend torch.Tensor class such that all inplace methods now follow a consistent policy for inplace updating the underlying ivy array attr --- ivy/functional/frontends/torch/tensor.py | 333 ++++++++++++++--------- 1 file changed, 207 insertions(+), 126 deletions(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 1af5260ec5d4..942a427eece0 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -185,7 +185,8 @@ def all(self, dim=None, keepdim=False): @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def add_(self, other, *, alpha=1): - self.ivy_array = self.add(other, alpha=alpha).ivy_array + ret = self.add(other, alpha=alpha) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -194,7 +195,8 @@ def addmm(self, mat1, mat2, *, beta=1, alpha=1): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def addmm_(self, mat1, mat2, *, beta=1, alpha=1): - self.ivy_array = self.addmm(mat1, mat2, beta=beta, alpha=alpha).ivy_array + ret = self.addmm(mat1, mat2, beta=beta, alpha=alpha) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -203,9 +205,8 @@ def addmv(self, mat, vec, *, beta=1, alpha=1): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def addmv_(self, mat, vec, *, beta=1, alpha=1): - self.ivy_array = torch_frontend.addmv( - self, mat, vec, beta=beta, alpha=alpha - ).ivy_array + ret = torch_frontend.addmv(self, mat, vec, beta=beta, alpha=alpha) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -214,12 +215,14 @@ def addbmm(self, batch1, batch2, *, beta=1, alpha=1): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def addbmm_(self, batch1, batch2, *, beta=1, alpha=1): - self.ivy_array = self.addbmm(batch1, batch2, beta=beta, alpha=alpha).ivy_array + ret = self.addbmm(batch1, batch2, beta=beta, alpha=alpha) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def subtract_(self, other, *, alpha=1): - self.ivy_array = self.sub(other, alpha=alpha).ivy_array + ret = self.sub(other, alpha=alpha) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -228,7 +231,8 @@ def asin(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def asin_(self): - self.ivy_array = self.asin().ivy_array + ret = self.asin() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def float_power(self, exponent): @@ -245,7 +249,8 @@ def sin(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def sin_(self): - self.ivy_array = self.sin().ivy_array + ret = self.sin() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -254,7 +259,8 @@ def sinh(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def sinh_(self): - self.ivy_array = self.sinh().ivy_array + ret = self.sinh() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -263,7 +269,8 @@ def cos(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def cos_(self): - self.ivy_array = self.cos().ivy_array + ret = self.cos() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -272,7 +279,8 @@ def cosh(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def cosh_(self): - self.ivy_array = self.cosh().ivy_array + ret = self.cosh() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -281,7 +289,8 @@ def atan(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def atan_(self): - self.ivy_array = self.atan().ivy_array + ret = self.atan() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") @@ -336,7 +345,8 @@ def asinh(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def asinh_(self): - self.ivy_array = self.asinh().ivy_array + ret = self.asinh() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -345,7 +355,8 @@ def tan(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def tan_(self): - self.ivy_array = self.tan().ivy_array + ret = self.tan() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -354,7 +365,8 @@ def tanh(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def tanh_(self): - self.ivy_array = self.tanh().ivy_array + ret = self.tanh() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -363,7 +375,8 @@ def atanh(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def atanh_(self): - self.ivy_array = self.atanh().ivy_array + ret = self.atanh() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -372,7 +385,8 @@ def log(self): @with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch") def log2_(self): - self.ivy_array = self.log2().ivy_array + ret = self.log2() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") @@ -381,12 +395,14 @@ def logit(self): @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "uint16")}, "torch") def copy_(self, other, non_blocking=False): - self._ivy_array = torch_frontend.tensor(other).ivy_array + ret = torch_frontend.tensor(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def log_(self): - self.ivy_array = self.log().ivy_array + ret = self.log() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -416,7 +432,8 @@ def abs(self): return torch_frontend.abs(self) def abs_(self): - self.ivy_array = self.abs().ivy_array + ret = self.abs() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") @@ -427,7 +444,10 @@ def logical_not(self, *, out=None): return torch_frontend.logical_not(self, out=out) def logical_not_(self): - self.ivy_array = ivy.astype(self.logical_not().ivy_array, self.dtype) + ret = self.logical_not() + self.ivy_array = ivy.inplace_update( + self.ivy_array, ivy.astype(ret.ivy_array, self.dtype) + ) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") @@ -453,7 +473,8 @@ def bitwise_left_shift(self, other): @with_supported_dtypes({"2.2 and below": ("integer",)}, "torch") def bitwise_or_(self, other): - self.ivy_array = self.bitwise_or(other).ivy_array + ret = self.bitwise_or(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def contiguous(self, memory_format=None): @@ -519,7 +540,8 @@ def not_equal(self, other, *, out=None): "torch", ) def not_equal_(self, other, *, out=None): - self.ivy_array = self.not_equal(other).ivy_array + ret = self.not_equal(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def eq(self, other): @@ -536,7 +558,8 @@ def erf(self, *, out=None): {"2.2 and below": ("float32", "float64", "bfloat16")}, "torch" ) def erf_(self, *, out=None): - self.ivy_array = self.erf(out=out).ivy_array + ret = self.erf(out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_supported_device_and_dtypes( @@ -551,7 +574,8 @@ def erfc(self, *, out=None): "torch", ) def erfc_(self, *, out=None): - self.ivy_array = self.erfc(out=out).ivy_array + ret = self.erfc(out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def new_zeros( @@ -645,7 +669,8 @@ def acos(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def acos_(self): - self.ivy_array = self.acos().ivy_array + ret = self.acos() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def new_tensor( @@ -692,7 +717,8 @@ def detach(self): ) def detach_(self): - self.ivy_array = self.detach().ivy_array + ret = self.detach() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("uint16",)}, "torch") @@ -702,7 +728,8 @@ def unsqueeze(self, dim): @numpy_to_torch_style_args def unsqueeze_(self, dim): - self.ivy_array = self.unsqueeze(dim).ivy_array + ret = self.unsqueeze(dim) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def ravel(self): @@ -836,7 +863,8 @@ def unflatten(self, dim, sizes): @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def pow_(self, exponent): - self.ivy_array = self.pow(exponent).ivy_array + ret = self.pow(exponent) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def size(self, dim=None): @@ -923,14 +951,16 @@ def transpose(self, dim0, dim1): return torch_frontend.transpose(self, dim0=dim0, dim1=dim1) def transpose_(self, dim0, dim1): - self.ivy_array = self.transpose(dim0, dim1).ivy_array + ret = self.transpose(dim0, dim1) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def t(self): return torch_frontend.t(self) def t_(self): - self.ivy_array = self.t().ivy_array + ret = self.t() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def flatten(self, start_dim=0, end_dim=-1): @@ -944,7 +974,8 @@ def cumsum(self, dim, *, dtype=None): @numpy_to_torch_style_args @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def cumsum_(self, dim, *, dtype=None): - self.ivy_array = self.cumsum(dim, dtype=dtype).ivy_array + ret = self.cumsum(dim, dtype=dtype) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") @@ -957,7 +988,8 @@ def neg(self): @with_unsupported_dtypes({"2.2 and below": ("bool",)}, "torch") def neg_(self): - self.ivy_array = torch_frontend.negative(self).ivy_array + ret = torch_frontend.negative(self) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self __neg__ = neg @@ -968,7 +1000,8 @@ def negative(self): @with_unsupported_dtypes({"2.0.1 and below": ("bool", "bfloat16")}, "torch") def negative_(self): - self.ivy_array = torch_frontend.negative(self).ivy_array + ret = torch_frontend.negative(self) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def int(self, memory_format=None): @@ -1007,7 +1040,8 @@ def squeeze(self, dim=None): @numpy_to_torch_style_args @with_unsupported_dtypes({"2.2 and below": ("uint16",)}, "torch") def squeeze_(self, dim=None): - self.ivy_array = self.squeeze(dim).ivy_array + ret = self.squeeze(dim) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def flip(self, dims): @@ -1023,7 +1057,8 @@ def tril(self, diagonal=0): return torch_frontend.tril(self, diagonal=diagonal) def tril_(self, diagonal=0): - self.ivy_array = self.tril(diagonal=diagonal).ivy_array + ret = self.tril(diagonal=diagonal) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def index_select(self, dim, index): @@ -1035,7 +1070,8 @@ def clamp(self, min=None, max=None): @with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch") def clamp_(self, min=None, max=None): - self.ivy_array = self.clamp(min=min, max=max).ivy_array + ret = self.clamp(min=min, max=max) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes( @@ -1045,7 +1081,8 @@ def clamp_min(self, min=None): return torch_frontend.clamp(self, min=min) def clamp_min_(self, min=None): - self.ivy_array = self.clamp_min(min).ivy_array + ret = self.clamp_min(min) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") @@ -1058,12 +1095,14 @@ def rsqrt(self): @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def rsqrt_(self): - self.ivy_array = self.rsqrt().ivy_array + ret = self.rsqrt() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def sqrt_(self): - self.ivy_array = self.sqrt().ivy_array + ret = self.sqrt() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def where(self, condition, other): @@ -1083,7 +1122,8 @@ def masked_fill(self, mask, value): ) def masked_fill_(self, mask, value): - self.ivy_array = self.masked_fill(mask, value).ivy_array + ret = self.masked_fill(mask, value) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def masked_select(self, mask): @@ -1103,25 +1143,24 @@ def masked_scatter_(self, mask, source): flat_source = torch_frontend.flatten(source) indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1) flat_self.scatter_(0, indices, flat_source[: indices.shape[0]]) - self.ivy_array = flat_self.reshape(self.shape).ivy_array + ret = flat_self.reshape(self.shape) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add_(self, dim, index, source, *, alpha=1): - self.ivy_array = torch_frontend.index_add( - self, dim, index, source, alpha=alpha - ).ivy_array + ret = torch_frontend.index_add(self, dim, index, source, alpha=alpha) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def index_add(self, dim, index, source, *, alpha=1): - return torch_frontend.index_add( - self._ivy_array, dim, index, source, alpha=alpha - ) + return torch_frontend.index_add(self.ivy_array, dim, index, source, alpha=alpha) @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def acosh_(self): - self.ivy_array = self.acosh().ivy_array + ret = self.acosh() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") @@ -1134,7 +1173,8 @@ def sigmoid(self): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def sigmoid_(self): - self.ivy_array = self.sigmoid().ivy_array + ret = self.sigmoid() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -1172,24 +1212,29 @@ def remainder(self, other, *, out=None): {"2.2 and below": ("float16", "float32", "float64", "bfloat16")}, "torch" ) def reciprocal_(self): - self.ivy_array = torch_frontend.reciprocal(self).ivy_array + ret = torch_frontend.reciprocal(self) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def remainder_(self, other, *, out=None): - self.ivy_array = torch_frontend.remainder(self, other, out=out).ivy_array + ret = torch_frontend.remainder(self, other, out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def bitwise_not_(self): - self.ivy_array = self.bitwise_not().ivy_array + ret = self.bitwise_not() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def bitwise_and_(self, other): - self.ivy_array = self.bitwise_and(other).ivy_array + ret = self.bitwise_and(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def atan2_(self, other): - self.ivy_array = self.atan2(other).ivy_array + ret = self.atan2(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") @@ -1215,7 +1260,8 @@ def trunc(self): @with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch") def trunc_(self): - self.ivy_array = self.trunc().ivy_array + ret = self.trunc() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch") @@ -1224,28 +1270,30 @@ def fix(self): @with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch") def fix_(self): - self.ivy_array = self.fix().ivy_array + ret = self.fix() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def isinf(self): - return torch_frontend.isinf(self._ivy_array) + return torch_frontend.isinf(self.ivy_array) def is_complex(self): - return torch_frontend.is_complex(self._ivy_array) + return torch_frontend.is_complex(self.ivy_array) @with_unsupported_dtypes({"2.2 and below": ("uint16", "bfloat16")}, "torch") def is_floating_point(self): - return torch_frontend.is_floating_point(self._ivy_array) + return torch_frontend.is_floating_point(self.ivy_array) @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def isreal(self): - return torch_frontend.isreal(self._ivy_array) + return torch_frontend.isreal(self.ivy_array) def addr(self, vec1, vec2, *, beta=1, alpha=1, out=None): return torch_frontend.addr(self, vec1, vec2, beta=beta, alpha=alpha, out=out) def addr_(self, vec1, vec2, *, beta=1, alpha=1): - self.ivy_array = self.addr(vec1, vec2, beta=beta, alpha=alpha).ivy_array + ret = self.addr(vec1, vec2, beta=beta, alpha=alpha) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") @@ -1254,15 +1302,16 @@ def dot(self, tensor): @with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch") def bernoulli(self, *, generator=None, out=None): - return torch_frontend.bernoulli(self._ivy_array, generator=generator, out=out) + return torch_frontend.bernoulli(self.ivy_array, generator=generator, out=out) @with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch") def bernoulli_(self, p, *, generator=None, out=None): - self.ivy_array = torch_frontend.bernoulli( + ret = torch_frontend.bernoulli( torch_frontend.full(self.shape, p, dtype=torch_frontend.float64), generator=generator, out=out, - ).ivy_array + ) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def numel(self): @@ -1466,7 +1515,8 @@ def bitwise_xor(self, other): return torch_frontend.bitwise_xor(self, other) def bitwise_xor_(self, other): - self.ivy_array = self.bitwise_xor(other).ivy_array + ret = self.bitwise_xor(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def item(self): @@ -1511,13 +1561,15 @@ def expm1(self): {"2.2 and below": ("bfloat16", "float16", "complex")}, "torch" ) def expm1_(self): - self.ivy_array = torch_frontend.expm1(self).ivy_array + ret = torch_frontend.expm1(self) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self # fmt: off @with_unsupported_dtypes({"2.2 and below": ("int8", "int16", "int32", "int64", "uint8", "bool", "float16",)},"torch",) # noqa def exp_(self): - self.ivy_array = self.exp().ivy_array + ret = self.exp() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self # fmt: on @@ -1526,14 +1578,14 @@ def mul(self, other): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def ceil_(self): - self.ivy_array = torch_frontend.ceil(self).ivy_array + ret = torch_frontend.ceil(self) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def mul_(self, other): - self.ivy_array = self.mul(other).ivy_array - # the return dtype is the same as the input dtype - self.ivy_array = self.to(self.dtype).ivy_array + ret = self.mul(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, "torch") @@ -1542,7 +1594,8 @@ def round(self, *, decimals=0): @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, "torch") def round_(self, *, decimals=0): - self.ivy_array = self.round(decimals=decimals).ivy_array + ret = self.round(decimals=decimals) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @numpy_to_torch_style_args @@ -1561,7 +1614,7 @@ def fill_(self, value): ret = torch_frontend.full_like( self, value, dtype=self.dtype, device=self.device ) - self.ivy_array = ivy.inplace_update(self.ivy_array, ret) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def nonzero(self, as_tuple=False): @@ -1572,7 +1625,7 @@ def mm(self, mat2): @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "float16")}, "torch") def square(self): - return torch_frontend.square(self._ivy_array) + return torch_frontend.square(self.ivy_array) @with_supported_dtypes( { @@ -1592,22 +1645,24 @@ def square(self): "torch", ) def square_(self): - self.ivy_array = torch_frontend.square(self._ivy_array).ivy_array + ret = torch_frontend.square(self.ivy_array) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def log10(self): - return torch_frontend.log10(self._ivy_array) + return torch_frontend.log10(self.ivy_array) @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def log10_(self): - self.ivy_array = self.log10().ivy_array + ret = self.log10() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("uint16",)}, "torch") def zero_(self): ret = torch_frontend.zeros_like(self) - self.ivy_array = ivy.inplace_update(self.ivy_array, ret) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def short(self, memory_format=None): @@ -1623,24 +1678,27 @@ def div(self, other, *, rounding_mode=None): return torch_frontend.div(self, other, rounding_mode=rounding_mode) def div_(self, other, *, rounding_mode=None): - self.ivy_array = self.div(other, rounding_mode=rounding_mode).ivy_array + ret = self.div(other, rounding_mode=rounding_mode) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_supported_dtypes( {"2.2 and below": ("float16", "float32", "float64", "bfloat16")}, "torch" ) def true_divide_(self, other): - self.ivy_array = self.div(other, rounding_mode=None).ivy_array + ret = self.div(other, rounding_mode=None) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def normal_(self, mean=0, std=1, *, generator=None): - self.ivy_array = ivy.random_normal( + ret = ivy.random_normal( mean=mean, std=std, shape=self.ivy_array.shape, dtype=self.dtype, device=self.device, ) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") @@ -1653,18 +1711,20 @@ def addcmul(self, tensor1, tensor2, *, value=1): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def addcmul_(self, tensor1, tensor2, *, value=1): - self.ivy_array = self.addcmul(tensor1, tensor2, value=value).ivy_array + ret = self.addcmul(tensor1, tensor2, value=value) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self sign_decorator_dtypes = ("float16", "complex", "bool") @with_unsupported_dtypes({"2.2 and below": sign_decorator_dtypes}, "torch") def sign(self): - return torch_frontend.sign(self._ivy_array) + return torch_frontend.sign(self.ivy_array) @with_unsupported_dtypes({"2.2 and below": sign_decorator_dtypes}, "torch") def sign_(self): - self.ivy_array = self.sign().ivy_array + ret = self.sign() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @numpy_to_torch_style_args @@ -1679,14 +1739,15 @@ def fmod(self, other, *, out=None): @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") def fmod_(self, other): - self.ivy_array = self.fmod(other).ivy_array + ret = self.fmod(other) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def norm(self, p="fro", dim=None, keepdim=False, dtype=None): return torch_frontend.norm(self, p=p, dim=dim, keepdim=keepdim, dtype=dtype) def tolist(self): - return self._ivy_array.to_list() + return self.ivy_array.to_list() @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def multiply(self, other, *, out=None): @@ -1694,7 +1755,8 @@ def multiply(self, other, *, out=None): @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def multiply_(self, other, *, out=None): - self.ivy_array = torch_frontend.multiply(self, other, out=out).ivy_array + ret = torch_frontend.multiply(self, other, out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @numpy_to_torch_style_args @@ -1706,13 +1768,14 @@ def topk(self, k, dim=None, largest=True, sorted=True): @with_unsupported_dtypes({"2.2 and below": rshift_dtypes}, "torch") def bitwise_right_shift(self, other, *, out=None): - return torch_frontend.bitwise_right_shift(self._ivy_array, other) + return torch_frontend.bitwise_right_shift(self.ivy_array, other) @with_supported_dtypes( {"2.2 and below": ("uint8", "int8", "int32", "int64")}, "torch" ) def bitwise_right_shift_(self, other, *, out=None): - self.ivy_array = self.bitwise_right_shift(other, out=out).ivy_array + ret = self.bitwise_right_shift(other, out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("float16", "bfloat16")}, "torch") @@ -1730,7 +1793,8 @@ def copysign(self, other, *, out=None): {"2.2 and below": ("float16", "float32", "float64")}, "torch" ) def copysign_(self, other, *, out=None): - self.ivy_array = self.copysign(other, out=out).ivy_array + ret = self.copysign(other, out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes( @@ -1741,7 +1805,8 @@ def greater(self, other, *, out=None): @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "bool")}, "torch") def greater_(self, other): - self.ivy_array = ivy.astype(self.greater(other).ivy_array, self.dtype) + ret = ivy.astype(self.greater(other).ivy_array, self.dtype) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self @with_unsupported_dtypes( @@ -1752,7 +1817,10 @@ def greater_equal(self, other, *, out=None): @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "bool")}, "torch") def greater_equal_(self, other): - self.ivy_array = ivy.astype(self.greater_equal(other).ivy_array, self.dtype) + ret = self.greater_equal(other) + self.ivy_array = ivy.inplace_update( + self.ivy_array, ivy.astype(ret.ivy_array, self.dtype) + ) return self @with_unsupported_dtypes( @@ -1763,7 +1831,10 @@ def less(self, other, *, out=None): @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "bool")}, "torch") def less_(self, other): - self.ivy_array = ivy.astype(self.less(other).ivy_array, self.dtype) + ret = self.less(other) + self.ivy_array = ivy.inplace_update( + self.ivy_array, ivy.astype(ret.ivy_array, self.dtype) + ) return self @with_unsupported_dtypes( @@ -1774,13 +1845,17 @@ def less_equal(self, other, *, out=None): @with_unsupported_dtypes({"2.2 and below": ("bfloat16", "bool")}, "torch") def less_equal_(self, other): - self.ivy_array = ivy.astype(self.less_equal(other).ivy_array, self.dtype) + ret = self.less_equal(other) + self.ivy_array = ivy.inplace_update( + self.ivy_array, ivy.astype(ret.ivy_array, self.dtype) + ) return self @with_unsupported_dtypes({"2.2 and below": ("bfloat16",)}, "torch") def eq_(self, other): - self.ivy_array = ivy.astype( - torch_frontend.eq(self, other).ivy_array, self.dtype + ret = torch_frontend.eq(self, other) + self.ivy_array = ivy.inplace_update( + self.ivy_array, ivy.astype(ret.ivy_array, self.dtype) ) return self @@ -1816,8 +1891,9 @@ def log1p(self): @with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch") def log1p_(self): promoted_type = ivy.promote_types(self.dtype, "float32") - res = torch_frontend.log1p(self) - self.ivy_array = res.to(promoted_type).ivy_array + ret = torch_frontend.log1p(self) + ret = ret.to(promoted_type) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def baddbmm(self, batch1, batch2, *, beta=1, alpha=1): @@ -1826,9 +1902,10 @@ def baddbmm(self, batch1, batch2, *, beta=1, alpha=1): ) def baddbmm_(self, batch1, batch2, *, beta=1, alpha=1): - self.ivy_array = torch_frontend.baddbmm( + ret = torch_frontend.baddbmm( self, batch1=batch1, batch2=batch2, beta=beta, alpha=alpha - ).ivy_array + ) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self def bmm(self, mat2): @@ -1836,7 +1913,8 @@ def bmm(self, mat2): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def floor_(self): - self.ivy_array = self.floor().ivy_array + ret = self.floor() + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes( @@ -1868,7 +1946,8 @@ def gather(self, dim, index): {"2.2 and below": ("float32", "float64", "int32", "int64")}, "torch" ) def scatter_add_(self, dim, index, src): - self.ivy_array = ivy.put_along_axis(self.ivy_array, index, src, dim, mode="sum") + ret = ivy.put_along_axis(self.ivy_array, index, src, dim, mode="sum") + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self @with_supported_dtypes( @@ -1883,9 +1962,8 @@ def scatter_(self, dim, index, src, *, reduce=None): "multiply": "mul", } reduce = mode_mappings.get(reduce, reduce) - self.ivy_array = ivy.put_along_axis( - self.ivy_array, index, src, dim, mode=reduce - ) + ret = ivy.put_along_axis(self.ivy_array, index, src, dim, mode=reduce) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self @with_supported_dtypes( @@ -1894,9 +1972,8 @@ def scatter_(self, dim, index, src, *, reduce=None): def scatter_reduce_(self, dim, index, src, reduce, *, include_self=True): if reduce == "prod": reduce = "mul" - self.ivy_array = ivy.put_along_axis( - self.ivy_array, index, src, dim, mode=reduce - ) + ret = ivy.put_along_axis(self.ivy_array, index, src, dim, mode=reduce) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self @with_supported_dtypes( @@ -1925,9 +2002,8 @@ def movedim(self, source, destination): @with_unsupported_dtypes({"2.2 and below": ("float16",)}, "torch") def addcdiv_(self, tensor1, tensor2, *, value=1): - self.ivy_array = self.addcdiv( - tensor1=tensor1, tensor2=tensor2, value=value - ).ivy_array + ret = self.addcdiv(tensor1=tensor1, tensor2=tensor2, value=value) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_supported_dtypes( @@ -1957,7 +2033,8 @@ def tile(self, *reps): def apply_(self, callable, /): if self.device != "cpu": raise ValueError("apply_ is only supported on cpu tensors") - self.ivy_array = callable(self.ivy_array) + ret = callable(self.ivy_array) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self def requires_grad_(self, requires_grad=True): @@ -2079,7 +2156,8 @@ def lcm(self, other, *, out=None): "torch", ) def lcm_(self, other, *, out=None): - self.ivy_array = self.lcm(other, out=out).ivy_array + ret = self.lcm(other, out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes( @@ -2097,7 +2175,8 @@ def lcm_(self, other, *, out=None): "torch", ) def triu_(self, diagonal=0): - self.ivy_array = torch_frontend.triu(self, diagonal).ivy_array + ret = torch_frontend.triu(self, diagonal) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes( @@ -2136,9 +2215,10 @@ def random_( to = ivy.finfo(self.dtype).max else: to = ivy.iinfo(self.dtype).max - self.ivy_array = ivy.random_uniform( + ret = ivy.random_uniform( low=from_, high=to, shape=self.size(), dtype=self.dtype ) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self.ivy_array @with_unsupported_dtypes( @@ -2157,14 +2237,12 @@ def uniform_(self, from_=0, to=1, *, generator=None): ret = ivy.random_uniform( low=from_, high=to, shape=self.shape, dtype=self.dtype, seed=generator ) - self._ivy_array = ivy.inplace_update( - self._ivy_array, ivy.astype(ret, self._ivy_array.dtype) - ) + self.ivy_array = ivy.inplace_update(self.ivy_array, ivy.astype(ret, self.dtype)) return self @with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch") def frac(self, name=None): - return torch_frontend.frac(self._ivy_array) + return torch_frontend.frac(self.ivy_array) @with_unsupported_dtypes( { @@ -2189,7 +2267,8 @@ def sinc(self): "torch", ) def sinc_(self): - self.ivy_array = torch_frontend.sinc(self).ivy_array + ret = torch_frontend.sinc(self) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes({"2.2 and below": ("uint8",)}, "torch") @@ -2263,7 +2342,8 @@ def triu(self, diagonal=0): "torch", ) def xlogy_(self, *, other, out=None): - self.ivy_array = torch_frontend.xlogy(self, other, out=out).ivy_array + ret = torch_frontend.xlogy(self, other, out=out) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array) return self @with_unsupported_dtypes( @@ -2341,7 +2421,8 @@ def rad2deg(self, *, out=None): return torch_frontend.rad2deg(self, out=out) def fill_diagonal_(self, fill_value, wrap=False): - self._ivy_array = ivy.fill_diagonal(self._ivy_array, fill_value, wrap=wrap) + ret = ivy.fill_diagonal(self.ivy_array, fill_value, wrap=wrap) + self.ivy_array = ivy.inplace_update(self.ivy_array, ret) return self @with_supported_dtypes( @@ -2380,8 +2461,8 @@ def erfinv(self, *, out=None): @with_unsupported_dtypes({"2.2 and below": ("float16", "complex")}, "torch") def erfinv_(self, *, out=None): ret = self.erfinv(out=out) - self._ivy_array = ivy.inplace_update( - self._ivy_array, ivy.astype(ret.ivy_array, self._ivy_array.dtype) + self.ivy_array = ivy.inplace_update( + self.ivy_array, ivy.astype(ret.ivy_array, self.dtype) ) return self