From 6718ded65fec9ab313d393ef7c247dc7c012786e Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 24 May 2024 20:47:49 -0300 Subject: [PATCH 1/5] Add test. --- test/test_metrics.py | 48 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/test_metrics.py b/test/test_metrics.py index 87c45949b32..d006fbf622b 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -205,6 +205,54 @@ def test_pybind_increment_counter(self): torch_xla._XLAC._xla_increment_counter('FakeCounter', 2) self.assertEqual(met.counter_value('FakeCounter'), 2) + def test_get_fallback_ops(self): + + def getAndAssertFallbackOpsLenEquals(count): + fallback_ops = met.executed_fallback_ops() + fallback_ops_number = len(fallback_ops) + self.assertEqual( + fallback_ops_number, + count, + msg=f"found {fallback_ops_number}: {fallback_ops}") + return fallback_ops + + # Reset all metrics, and make sure we don't start with any fallback ops. + met.clear_all() + getAndAssertFallbackOpsLenEquals(0) + + # Create N boxes, one per row, as (x, x + width, y, y - height). + # This should not run any fallback ops. + N = 10 + x = torch.rand(N, 1).to(xm.xla_device()) + y = torch.rand(N, 1).to(xm.xla_device()) + width = torch.rand(N, 1).to(xm.xla_device()) + height = torch.rand(N, 1).to(xm.xla_device()) + xys = torch.cat((x, x + width, y, y - height), dim=1) + getAndAssertFallbackOpsLenEquals(0) + + # tensor.item() is a fallback operation. + xys[0, 0].item() + ops = getAndAssertFallbackOpsLenEquals(1) + self.assertEqual(ops[0], "aten::_local_scalar_dense") + + # Reset all metrics, and make sure we also don't retrieve any + # fallback operations. + met.clear_all() + getAndAssertFallbackOpsLenEquals(0) + + # Run torchvision operations as fallback. + import torchvision + scores = torch.rand(N).to(xm.xla_device()) + # NMS doesn't have a PyTorch/XLA implementation without dynamic shapes. + torchvision.ops.nms(xys, scores, 0.5) + # remove_small_boxes is not implemented in C++. 
It calls other PyTorch + # operations. One of them, nonzero, is a fallback operation. + torchvision.ops.remove_small_boxes( + xys, torch.median(torch.stack((width, height)))) + ops = getAndAssertFallbackOpsLenEquals(3) + self.assertEqual( + set(ops), {"aten::nonzero", "aten::median", "torchvision::nms"}) + if __name__ == '__main__': test = unittest.main() From 56743cdad537cf9571467f7f884b30a3927f6c52 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 24 May 2024 20:48:00 -0300 Subject: [PATCH 2/5] Retrieve fallbacks from counters. --- torch_xla/csrc/aten_cpu_fallback.cpp | 12 ++++++++++++ torch_xla/csrc/aten_cpu_fallback.h | 4 +++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/torch_xla/csrc/aten_cpu_fallback.cpp b/torch_xla/csrc/aten_cpu_fallback.cpp index d664c60114f..5e84e61ba1a 100644 --- a/torch_xla/csrc/aten_cpu_fallback.cpp +++ b/torch_xla/csrc/aten_cpu_fallback.cpp @@ -16,6 +16,18 @@ namespace torch_xla { static std::unordered_map _cpu_fallback_counters; +// Get all the executed fallback operations. +// In other words, get all of them whose counters are not zero. 
+std::vector<std::string> GetFallbackOperations() { + std::vector<std::string> fallback; + for (auto const& pair : _cpu_fallback_counters) { + if (pair.second->Value() != 0) { + fallback.push_back(pair.first); + } + } + return fallback; +} + void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { XLA_FN_TRACK(3); const auto name = c10::toString(op.operator_name()); diff --git a/torch_xla/csrc/aten_cpu_fallback.h b/torch_xla/csrc/aten_cpu_fallback.h index 572d4e1009a..706c7aa40a5 100644 --- a/torch_xla/csrc/aten_cpu_fallback.h +++ b/torch_xla/csrc/aten_cpu_fallback.h @@ -7,6 +7,8 @@ namespace torch_xla { void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); +std::vector<std::string> GetFallbackOperations(); + } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_ATEN_CPU_FALLBACK_H_ From 9e9c6e8bd26e8a387c9fbf560ee1765865bfc652 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 24 May 2024 20:48:13 -0300 Subject: [PATCH 3/5] Add Python bindings. 
--- torch_xla/csrc/init_python_bindings.cpp | 2 ++ torch_xla/debug/metrics.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index a9240db692a..5d7707d5699 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -33,6 +33,7 @@ #include "pybind11/stl_bind.h" #include "torch_xla/csrc/XLANativeFunctions.h" #include "torch_xla/csrc/aten_autograd_ops.h" +#include "torch_xla/csrc/aten_cpu_fallback.h" #include "torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/dl_convertor.h" @@ -1781,6 +1782,7 @@ void InitXlaModuleBindings(py::module m) { } }, py::arg("devices")); + m.def("_get_executed_fallback_ops", []() { return GetFallbackOperations(); }); m.def("_xla_counter_names", []() { auto counter_names = torch::lazy::GetCounterNames(); auto xla_counter_names = runtime::metrics::GetCounterNames(); diff --git a/torch_xla/debug/metrics.py b/torch_xla/debug/metrics.py index 363c52a80da..11718e8376b 100644 --- a/torch_xla/debug/metrics.py +++ b/torch_xla/debug/metrics.py @@ -79,3 +79,8 @@ def short_metrics_report(counter_names: list = None, metric_names: list = None): 'TransferToDeviceTime', 'TransferFromDeviceTime' ] return torch_xla._XLAC._short_xla_metrics_report(counter_names, metric_names) + + +def executed_fallback_ops(): + """Retrieves a list of operations that were run in fallback mode.""" + return torch_xla._XLAC._get_executed_fallback_ops() From ba10278b1abce29d162718c78654e4d5afc049d7 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 24 May 2024 20:49:11 -0300 Subject: [PATCH 4/5] Make dynamo bridge call the new function. 
--- torch_xla/core/dynamo_bridge.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index 624acb9cb6f..f6d7a3b6e00 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -100,15 +100,7 @@ def __call__(self, args): def get_fallback_ops(): - fallback_ops = [] - for opname in metrics.counter_names(): - if "aten::" not in opname: - continue - val = int(metrics.counter_value(opname)) - if val > 0: - fallback_ops.append(f"{opname}={val}") - - return fallback_ops + return metrics.executed_fallback_ops() # Checks that all input args that are tensors are on the same device. From 5f985b34f8cb3362f2793f82d098c79bbe4e7f0b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 28 May 2024 20:40:36 -0300 Subject: [PATCH 5/5] Check if `nms` dynamic shapes is enabled --- test/test_metrics.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/test/test_metrics.py b/test/test_metrics.py index d006fbf622b..409876d8d9d 100644 --- a/test/test_metrics.py +++ b/test/test_metrics.py @@ -9,6 +9,11 @@ import unittest +def XLAExperimentalContains(feat): + experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":") + return feat in experimental + + class MetricsTest(unittest.TestCase): def test_clear_counters(self): @@ -240,18 +245,19 @@ def getAndAssertFallbackOpsLenEquals(count): met.clear_all() getAndAssertFallbackOpsLenEquals(0) - # Run torchvision operations as fallback. - import torchvision - scores = torch.rand(N).to(xm.xla_device()) - # NMS doesn't have a PyTorch/XLA implementation without dynamic shapes. - torchvision.ops.nms(xys, scores, 0.5) - # remove_small_boxes is not implemented in C++. It calls other PyTorch - # operations. One of them, nonzero, is a fallback operation. 
- torchvision.ops.remove_small_boxes( - xys, torch.median(torch.stack((width, height)))) - ops = getAndAssertFallbackOpsLenEquals(3) - self.assertEqual( - set(ops), {"aten::nonzero", "aten::median", "torchvision::nms"}) + if not XLAExperimentalContains("nms"): + # Run torchvision operations as fallback. + import torchvision + scores = torch.rand(N).to(xm.xla_device()) + # NMS doesn't have a PyTorch/XLA implementation without dynamic shapes. + torchvision.ops.nms(xys, scores, 0.5) + # remove_small_boxes is not implemented in C++. It calls other PyTorch + # operations. One of them, nonzero, is a fallback operation. + torchvision.ops.remove_small_boxes( + xys, torch.median(torch.stack((width, height)))) + ops = getAndAssertFallbackOpsLenEquals(3) + self.assertEqual( + set(ops), {"aten::nonzero", "aten::median", "torchvision::nms"}) if __name__ == '__main__':