Skip to content

Commit

Permalink
fix (torch frontends)(tensor.py): fixing the implementation of `torch…
Browse files Browse the repository at this point in the history
….Tensor.masked_scatter` and `torch.Tensor.masked_scatter_`. The previous implementation was not handling broadcasting of the inputs and was thus was producing incorrect results.
  • Loading branch information
YushaArif99 committed Oct 7, 2024
1 parent 3772d1a commit 6f9bf51
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,19 +1140,23 @@ def masked_select(self, mask):
return torch_frontend.masked_select(self, mask)

def masked_scatter(self, mask, source):
self = torch_frontend.broadcast_to(self, source.shape)
mask = torch_frontend.broadcast_to(mask, self.shape)
flat_self = torch_frontend.flatten(self.clone())
flat_mask = torch_frontend.flatten(mask)
flat_source = torch_frontend.flatten(source)
indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1)
flat_self.scatter_(0, indices, flat_source[: indices.shape[0]])
flat_self[indices] = flat_source[:indices.numel()]
return flat_self.reshape(self.shape)

def masked_scatter_(self, mask, source):
self = torch_frontend.broadcast_to(self, source.shape)
mask = torch_frontend.broadcast_to(mask, self.shape)
flat_self = torch_frontend.flatten(self.clone())
flat_mask = torch_frontend.flatten(mask)
flat_source = torch_frontend.flatten(source)
indices = torch_frontend.squeeze(torch_frontend.nonzero(flat_mask), -1)
flat_self.scatter_(0, indices, flat_source[: indices.shape[0]])
flat_self[indices] = flat_source[:indices.numel()]
ret = flat_self.reshape(self.shape)
self.ivy_array = ivy.inplace_update(self.ivy_array, ret.ivy_array)
return self
Expand Down

0 comments on commit 6f9bf51

Please sign in to comment.