diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index f0aee024f75..64141a92f45 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -7,6 +7,7 @@ import torch_xla.utils.utils as xu import torch_xla.debug.metrics as met from torch_xla import runtime as xr +import torch_xla.debug.profiler as xp import torch.optim as optim import torch.nn as nn import torch._dynamo as dynamo @@ -91,6 +92,20 @@ def test_mark_step_after_dynamo(self): self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0]) +class DynamoProfilerTest(unittest.TestCase): + + def dummy_fn(self, a): + return torch.sin(a) + a + + def test_dynamo_with_trace(self): + dynamo_dummy = torch.compile( + self.dummy_fn, backend="openxla", fullgraph=True) + t = torch.randn(2, 3, 4, device=xm.xla_device()) + for i in range(10): + with xp.Trace('build_graph'): + t = dynamo_dummy(t) + + class DynamoInferenceBasicTest(unittest.TestCase): @classmethod diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py index 379416ec73f..624acb9cb6f 100644 --- a/torch_xla/core/dynamo_bridge.py +++ b/torch_xla/core/dynamo_bridge.py @@ -429,7 +429,8 @@ def extract_internal(xla_model: torch.fx.GraphModule): for xla_arg in xla_model.xla_args: if isinstance(xla_arg, torch.Tensor): print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg)) - xm.mark_step() + # Don't reset the scope as we might be under some profiler trace scope. + xm.mark_step(reset_scope=False) (xla_args_sharding_spec, args_and_out, graph_hash, arg_index_to_need_update_index, none_remover, graph_input_matcher, dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model) @@ -614,8 +615,9 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a): torch._functionalize_sync(a) - # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids - xm.mark_step() + # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids. + # Don't reset the scope as we might be under some profiler trace scope. + xm.mark_step(reset_scope=False) # Find tensor constructor nodes that create CPU tensors, and make # them create XLA tensors, where possible, instead. i.e. replace the diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index e70409e885b..7591e13af29 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -1045,7 +1045,7 @@ def _run_step_closures(): return devctx -def mark_step(wait=False): +def mark_step(wait=False, reset_scope=True): if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False): print( 'torch_xla.core.xla_model::mark_step\n', @@ -1054,7 +1054,8 @@ def mark_step(wait=False): flush=True) torch_xla._XLAC._xla_step_marker( torch_xla._XLAC._xla_get_default_device(), [], - wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait)) + wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait), + reset_scope=reset_scope) # Only emit metrics from the first local device index, to avoid emitting the # same values from different threads. if is_master_ordinal(): diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 7a4a52ad1d2..805e467e44d 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -458,12 +458,13 @@ void SyncLiveTensors(const std::string& device_str, } void StepMarker(const std::string& device_str, - const std::vector& devices, bool wait) { + const std::vector& devices, bool wait, + bool reset_scope) { tsl::profiler::TraceMe activity("StepMarker", tsl::profiler::TraceMeLevel::kInfo); torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str); XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&device, devices, wait); - XLAGraphExecutor::Get()->MarkStep(device); + XLAGraphExecutor::Get()->MarkStep(device, reset_scope); bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false); if (TF_PREDICT_FALSE(debug_mode)) { std::string report = runtime::metrics::CreatePerformanceReport( @@ -1649,11 +1650,12 @@ void InitXlaModuleBindings(py::module m) { m.def( "_xla_step_marker", [](const std::string& device, const std::vector& devices, - bool wait) { + bool wait, bool reset_scope) { NoGilSection nogil; - StepMarker(device, devices, wait); + StepMarker(device, devices, wait, reset_scope); }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true, + py::arg("reset_scope") = true); m.def("_get_stablehlo", [](const std::vector& tensors, const std::string& device, const std::vector& devices, diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 0eddefc39f3..fe12e392ea4 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -404,14 +404,17 @@ void XLAGraphExecutor::SyncLiveTensorsGraph( SyncTensorsGraph(&tensors, devices, wait, /*sync_ltc_data=*/true); } -void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) { +void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device, + bool reset_scope) { // TODO(jwtan): Replace this with TORCH_LAZY_COUNTER. We need MarkStep to // remain as XLA_COUNTER to support // runtime::metrics::CreatePerformanceReport(). For more information, see // NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER]. XLA_COUNTER("MarkStep", 1); DeviceContextArena::Get()->MarkStep(device); - torch::lazy::ScopePusher::ResetScopes(); + if (reset_scope) { + torch::lazy::ScopePusher::ResetScopes(); + } ResetTrimCounter(); } diff --git a/torch_xla/csrc/xla_graph_executor.h b/torch_xla/csrc/xla_graph_executor.h index ca874274a98..b2b76b8ae33 100644 --- a/torch_xla/csrc/xla_graph_executor.h +++ b/torch_xla/csrc/xla_graph_executor.h @@ -134,7 +134,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor { // Marks an execution step, which allows the tensor framework to understand // the computation boundaries. // Override to use our own DeviceContextArena. - void MarkStep(const torch::lazy::BackendDevice& device) final; + void MarkStep(const torch::lazy::BackendDevice& device, bool reset_scope); // Waits for all the outstanding operations on all the supplied devices. // If devices is empty, the wait will happen for all local devices.