From 65fd11c3702af55465d730de28610296254d3640 Mon Sep 17 00:00:00 2001 From: Ogban Ugot Date: Tue, 13 Feb 2024 23:29:04 +0100 Subject: [PATCH] fix torch frontend unique --- ivy/functional/backends/torch/set.py | 2 +- ivy/functional/frontends/torch/tensor.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/torch/set.py b/ivy/functional/backends/torch/set.py index 03e0a6077bb7c..a7eee8289918f 100644 --- a/ivy/functional/backends/torch/set.py +++ b/ivy/functional/backends/torch/set.py @@ -66,7 +66,7 @@ def unique_all( [i[0] for i in sorted(enumerate(values_), key=lambda x: tuple(x[1]))] ) ivy_torch = ivy.current_backend() - values = ivy_torch.gather(values, sort_idx, axis=axis) + values = values.index_select(dim=axis, index=sort_idx) counts = ivy_torch.gather(counts, sort_idx) indices = ivy_torch.gather(indices, sort_idx) inv_sort_idx = ivy_torch.invert_permutation(sort_idx) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 91c1f5b86db55..faef7db5d4e1c 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -2242,6 +2242,7 @@ def ne_(self, other): "float16", "complex128", "complex64", + "bool", ) }, "torch",