Skip to content

Commit

Permalink
fix: Fix torch backend inplace update
Browse files Browse the repository at this point in the history
  • Loading branch information
hmahmood24 committed Sep 6, 2024
1 parent 03dab2b commit f07a1c8
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions ivy/functional/backends/torch/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ivy.func_wrapper import _update_torch_views, with_unsupported_dtypes

from ...ivy.general import _broadcast_to
from . import backend_version, is_variable
from . import backend_version

torch_scatter = None

Expand Down Expand Up @@ -378,10 +378,8 @@ def inplace_update(
if keep_input_dtype:
val = ivy.astype(val, x.dtype)
(x_native, val_native), _ = ivy.args_to_native(x, val)
if is_variable(x_native):
x_native.copy_ = val_native
else:
x_native[()] = val_native
with torch.no_grad():
x_native.copy_(val_native)
x_native = x_native.to(val_native.device)
if ivy.is_native_array(x):
return x_native
Expand Down

0 comments on commit f07a1c8

Please sign in to comment.