diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index 74244322840..d6b492dc694 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -5,7 +5,7 @@ BUILDTYPE="opt" VERB= FILTER= LOGFILE=/tmp/pytorch_cpp_test.log -XLA_EXPERIMENTAL="nonzero:masked_select" +XLA_EXPERIMENTAL="nonzero:masked_select:nms" BAZEL_REMOTE_CACHE="0" BAZEL_VERB="test" diff --git a/test/run_tests.sh b/test/run_tests.sh index 4d4bd530e27..8926318dc38 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -104,7 +104,7 @@ function run_xla_hlo_debug { function run_dynamic { echo "Running in DynamicShape mode: $@" - XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" run_test "$@" + XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@" } function run_eager_debug { diff --git a/test/test_operations.py b/test/test_operations.py index 7fb9f5bc3e3..ff32c268927 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -88,6 +88,12 @@ def onlyOnCUDA(fn): return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) +def onlyIfXLAExperimentalContains(feat): + experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":") + return unittest.skipIf(feat not in experimental, + f"XLA_EXPERIMENTAL={feat} required") + + def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) @@ -2454,6 +2460,7 @@ def test_dropout(self): # These tests were extracted and adapted from torchvision. # Source: vision/test/test_ops.py +@onlyIfXLAExperimentalContains("nms") class TestNMS(test_utils.XlaTestCase): def _reference_nms(self, boxes, scores, iou_threshold): diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 413951854d6..dc2f4e96dba 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -11,8 +11,8 @@ python3 test/spmd/test_xla_distributed_checkpoint.py python3 test/spmd/test_train_spmd_linear_model.py python3 test/spmd/test_xla_spmd_python_api_interaction.py python3 test/spmd/test_xla_auto_sharding.py -XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models.py -v -XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shapes.py -v +XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v +XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v python3 test/test_autocast.py python3 test/dynamo/test_dynamo.py python3 test/spmd/test_spmd_debugging.py diff --git a/torch_xla/csrc/xla_manual_registration.cpp b/torch_xla/csrc/xla_manual_registration.cpp index dc7df436ec7..6020ef6bc04 100644 --- a/torch_xla/csrc/xla_manual_registration.cpp +++ b/torch_xla/csrc/xla_manual_registration.cpp @@ -1,7 +1,9 @@ #include #include +#include "torch_xla/csrc/aten_cpu_fallback.h" #include "torch_xla/csrc/aten_xla_bridge.h" +#include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/ops/nms.h" #include "torch_xla/csrc/ops/ops.h" #include "torch_xla/csrc/tensor_methods.h" @@ -11,10 +13,22 @@ namespace torch_xla { namespace manual { namespace { +struct NmsOp { + using schema = at::Tensor(const at::Tensor&, const at::Tensor&, double); + using ptr_schema = schema*; + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "torchvision::nms") + STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "") +}; + at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores, double iou_threshold) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); + if (!DebugUtil::ExperimentEnabled("nms")) { + return at::native::call_fallback_fn<&xla_cpu_fallback, NmsOp>::call( + boxes, scores, iou_threshold); + } + XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor."; XLA_CHECK_EQ(boxes.size(1), 4) << "nms(): boxes should be a 2D tensor of shape [N, 4].";