diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 5e8f3ea10b9..ad07275ef52 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -290,7 +290,6 @@ "scatter", "scatter_reduce", "searchsorted", - "select", "select_scatter", "signbit", "softmax", diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index b28eed951f1..4da33b58da2 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -103,6 +103,9 @@ def _aten_index_copy(x, dim, indexes, source): @op(torch.ops.aten.select) +def _aten_select(x, dim, indexes): + return jax.lax.index_in_dim(x, index=indexes, axis=dim, keepdims=False) + @op(torch.ops.aten.index_select) @op(torch.ops.aten.select_copy) def _aten_index_select(x, dim, indexes):