Make nms fallback by default. (#6933)
ysiraichi committed Apr 20, 2024
1 parent 9ba844a commit b06c9c7
Showing 5 changed files with 25 additions and 4 deletions.
2 changes: 1 addition & 1 deletion test/cpp/run_tests.sh
@@ -5,7 +5,7 @@ BUILDTYPE="opt"
VERB=
FILTER=
LOGFILE=/tmp/pytorch_cpp_test.log
XLA_EXPERIMENTAL="nonzero:masked_select"
XLA_EXPERIMENTAL="nonzero:masked_select:nms"
BAZEL_REMOTE_CACHE="0"
BAZEL_VERB="test"

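For context: XLA_EXPERIMENTAL holds a colon-separated list of experiment names, and this commit appends nms to it wherever the test scripts want the experimental XLA lowering active. A minimal Python sketch of that convention (the real check is the C++ DebugUtil::ExperimentEnabled guard added further down; the helper name here is hypothetical):

import os

def experiment_enabled(name):
  # XLA_EXPERIMENTAL is a colon-separated list of experiment names,
  # e.g. XLA_EXPERIMENTAL="nonzero:masked_select:nms".
  return name in os.environ.get("XLA_EXPERIMENTAL", "").split(":")

print(experiment_enabled("nms"))  # False unless XLA_EXPERIMENTAL contains "nms"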
2 changes: 1 addition & 1 deletion test/run_tests.sh
@@ -104,7 +104,7 @@ function run_xla_hlo_debug {

function run_dynamic {
echo "Running in DynamicShape mode: $@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" run_test "$@"
XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter:nms" run_test "$@"
}

function run_eager_debug {
7 changes: 7 additions & 0 deletions test/test_operations.py
@@ -88,6 +88,12 @@ def onlyOnCUDA(fn):
  return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn)


+def onlyIfXLAExperimentalContains(feat):
+  experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
+  return unittest.skipIf(feat not in experimental,
+                         f"XLA_EXPERIMENTAL={feat} required")


def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)

@@ -2454,6 +2460,7 @@ def test_dropout(self):

# These tests were extracted and adapted from torchvision.
# Source: vision/test/test_ops.py
+@onlyIfXLAExperimentalContains("nms")
class TestNMS(test_utils.XlaTestCase):

  def _reference_nms(self, boxes, scores, iou_threshold):
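For reference, a self-contained sketch of how the new decorator gates a test (DemoTest is hypothetical; the decorator body is copied from the diff above). Running this module prints a skip unless the process was launched with XLA_EXPERIMENTAL containing "nms":

import os
import unittest


def onlyIfXLAExperimentalContains(feat):
  # Skip unless `feat` appears in the colon-separated XLA_EXPERIMENTAL list.
  experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
  return unittest.skipIf(feat not in experimental,
                         f"XLA_EXPERIMENTAL={feat} required")


@onlyIfXLAExperimentalContains("nms")
class DemoTest(unittest.TestCase):

  def test_only_runs_when_enabled(self):
    self.assertTrue(True)


if __name__ == "__main__":
  unittest.main()  # skipped unless XLA_EXPERIMENTAL contains "nms"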
4 changes: 2 additions & 2 deletions test/tpu/run_tests.sh
@@ -11,8 +11,8 @@ python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
-XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shape_models.py -v
-XLA_EXPERIMENTAL=nonzero:masked_select python3 test/ds/test_dynamic_shapes.py -v
+XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
+XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
python3 test/dynamo/test_dynamo.py
python3 test/spmd/test_spmd_debugging.py
14 changes: 14 additions & 0 deletions torch_xla/csrc/xla_manual_registration.cpp
@@ -1,7 +1,9 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "torch_xla/csrc/aten_cpu_fallback.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/debug_util.h"
#include "torch_xla/csrc/ops/nms.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/tensor_methods.h"
@@ -11,10 +13,22 @@ namespace torch_xla {
namespace manual {
namespace {

+struct NmsOp {
+  using schema = at::Tensor(const at::Tensor&, const at::Tensor&, double);
+  using ptr_schema = schema*;
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "torchvision::nms")
+  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "")
+};

at::Tensor nms_kernel(const at::Tensor& boxes, const at::Tensor& scores,
                      double iou_threshold) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

+  if (!DebugUtil::ExperimentEnabled("nms")) {
+    return at::native::call_fallback_fn<&xla_cpu_fallback, NmsOp>::call(
+        boxes, scores, iou_threshold);
+  }

  XLA_CHECK_EQ(boxes.dim(), 2) << "nms(): boxes should be a 2D tensor.";
  XLA_CHECK_EQ(boxes.size(1), 4)
      << "nms(): boxes should be a 2D tensor of shape [N, 4].";
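Two notes on the C++ change: the NmsOp struct hand-declares the schema of torchvision::nms so that at::native::call_fallback_fn can redispatch it (torchvision registers its ops outside ATen, so there is no generated ATen op struct to reuse), and the new guard means the XLA lowering now runs only when the "nms" experiment is opted into. From the user's side, the new default behavior looks roughly like this hedged sketch (assumes torch, torchvision, and torch_xla are installed and an XLA device is available; this snippet is not part of the commit):

import torch
import torchvision
import torch_xla.core.xla_model as xm

device = xm.xla_device()
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0]], device=device)
scores = torch.tensor([0.9, 0.8], device=device)

# By default this now routes through the XLA CPU fallback; launching the
# process with XLA_EXPERIMENTAL="nms" opts back into the experimental
# XLA lowering instead.
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)
print(keep.cpu())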
