From f50e951960816bdd3063fddac70aed4c47eeb8ec Mon Sep 17 00:00:00 2001 From: Milad Mohammadi Date: Mon, 17 Jun 2024 14:04:59 -0700 Subject: [PATCH] `select` op support in torchxla2 (#7293) --- experimental/torch_xla2/test/test_ops.py | 1 - experimental/torch_xla2/torch_xla2/ops/jaten.py | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 5e8f3ea10b9..ad07275ef52 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -290,7 +290,6 @@ "scatter", "scatter_reduce", "searchsorted", - "select", "select_scatter", "signbit", "softmax", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index b28eed951f1..4da33b58da2 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -103,6 +103,9 @@ def _aten_index_copy(x, dim, indexes, source): @op(torch.ops.aten.select) +def _aten_select(x, dim, indexes): + return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False) + @op(torch.ops.aten.index_select) @op(torch.ops.aten.select_copy) def _aten_index_select(x, dim, indexes):