From f07a1c854debcf32775300b2d9df3f22fba3fab2 Mon Sep 17 00:00:00 2001 From: Haris Mahmood Date: Fri, 6 Sep 2024 17:12:24 +0500 Subject: [PATCH] fix: Fix torch backend inplace update --- ivy/functional/backends/torch/general.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ivy/functional/backends/torch/general.py b/ivy/functional/backends/torch/general.py index 55187d9f2ceb7..aeb8b2ed45a7a 100644 --- a/ivy/functional/backends/torch/general.py +++ b/ivy/functional/backends/torch/general.py @@ -20,7 +20,7 @@ from ivy.func_wrapper import _update_torch_views, with_unsupported_dtypes from ...ivy.general import _broadcast_to -from . import backend_version, is_variable +from . import backend_version torch_scatter = None @@ -378,10 +378,8 @@ def inplace_update( if keep_input_dtype: val = ivy.astype(val, x.dtype) (x_native, val_native), _ = ivy.args_to_native(x, val) - if is_variable(x_native): - x_native.copy_ = val_native - else: - x_native[()] = val_native + with torch.no_grad(): + x_native.copy_(val_native) x_native = x_native.to(val_native.device) if ivy.is_native_array(x): return x_native