Skip to content

Commit

Permalink
Check if nms dynamic shapes is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed May 28, 2024
1 parent ba10278 commit 5f985b3
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
import unittest


def XLAExperimentalContains(feat):
experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":")
return feat in experimental


class MetricsTest(unittest.TestCase):

def test_clear_counters(self):
Expand Down Expand Up @@ -240,18 +245,19 @@ def getAndAssertFallbackOpsLenEquals(count):
met.clear_all()
getAndAssertFallbackOpsLenEquals(0)

# Run torchvision operations as fallback.
import torchvision
scores = torch.rand(N).to(xm.xla_device())
# NMS doesn't have a PyTorch/XLA implementation without dynamic shapes.
torchvision.ops.nms(xys, scores, 0.5)
# remove_small_boxes is not implemented in C++. It calls other PyTorch
# operations. One of them, nonzero, is a fallback operation.
torchvision.ops.remove_small_boxes(
xys, torch.median(torch.stack((width, height))))
ops = getAndAssertFallbackOpsLenEquals(3)
self.assertEqual(
set(ops), {"aten::nonzero", "aten::median", "torchvision::nms"})
if not XLAExperimentalContains("nms"):
# Run torchvision operations as fallback.
import torchvision
scores = torch.rand(N).to(xm.xla_device())
# NMS doesn't have a PyTorch/XLA implementation without dynamic shapes.
torchvision.ops.nms(xys, scores, 0.5)
# remove_small_boxes is not implemented in C++. It calls other PyTorch
# operations. One of them, nonzero, is a fallback operation.
torchvision.ops.remove_small_boxes(
xys, torch.median(torch.stack((width, height))))
ops = getAndAssertFallbackOpsLenEquals(3)
self.assertEqual(
set(ops), {"aten::nonzero", "aten::median", "torchvision::nms"})


if __name__ == '__main__':
Expand Down

0 comments on commit 5f985b3

Please sign in to comment.