Fix runtime error when running dynamo with a profiler scope (#6913)
JackCaoG authored Apr 11, 2024
1 parent: a170ffe · commit: 7dfbef0
Showing 6 changed files with 36 additions and 13 deletions.
15 changes: 15 additions & 0 deletions test/dynamo/test_dynamo.py
@@ -7,6 +7,7 @@
 import torch_xla.utils.utils as xu
 import torch_xla.debug.metrics as met
 from torch_xla import runtime as xr
+import torch_xla.debug.profiler as xp
 import torch.optim as optim
 import torch.nn as nn
 import torch._dynamo as dynamo
@@ -91,6 +92,20 @@ def test_mark_step_after_dynamo(self):
     self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])


+class DynamoProfilerTest(unittest.TestCase):
+
+  def dummy_fn(self, a):
+    return torch.sin(a) + a
+
+  def test_dynamo_with_trace(self):
+    dynamo_dummy = torch.compile(
+        self.dummy_fn, backend="openxla", fullgraph=True)
+    t = torch.randn(2, 3, 4, device=xm.xla_device())
+    for i in range(10):
+      with xp.Trace('build_graph'):
+        t = dynamo_dummy(t)
+
+
 class DynamoInferenceBasicTest(unittest.TestCase):

   @classmethod
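For context, the new DynamoProfilerTest above captures exactly the pattern that used to fail: stepping a torch.compile'd function while a profiler trace scope is still open. A minimal standalone sketch of the same scenario, assuming a machine with an XLA device attached, looks like this:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def dummy_fn(a):
  return torch.sin(a) + a

# Compile through the openxla backend, then step the compiled function
# under an open trace scope. Before this commit, the mark_step calls made
# inside the dynamo bridge reset the profiler scope stack, so exiting the
# Trace scope raised a runtime error; with this fix the scope survives.
compiled = torch.compile(dummy_fn, backend="openxla", fullgraph=True)
t = torch.randn(2, 3, 4, device=xm.xla_device())
for _ in range(10):
  with xp.Trace('build_graph'):
    t = compiled(t)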
8 changes: 5 additions & 3 deletions torch_xla/core/dynamo_bridge.py
@@ -429,7 +429,8 @@ def extract_internal(xla_model: torch.fx.GraphModule):
     for xla_arg in xla_model.xla_args:
       if isinstance(xla_arg, torch.Tensor):
         print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
-  xm.mark_step()
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)
   (xla_args_sharding_spec, args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -614,8 +615,9 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     if isinstance(a, torch.Tensor) and torch._is_functional_tensor(a):
       torch._functionalize_sync(a)

-  # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
-  xm.mark_step()
+  # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids.
+  # Don't reset the scope as we might be under some profiler trace scope.
+  xm.mark_step(reset_scope=False)

   # Find tensor constructor nodes that create CPU tensors, and make
   # them create XLA tensors, where possible, instead. i.e. replace the
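Both hunks above make the same point: the bridge must flush pending lazy operations at moments when the user may still be inside an xp.Trace scope, and the old unconditional scope reset inside mark_step() would clear a scope that the enclosing Trace still expects to pop on exit. A toy Python model of that failure shape (purely illustrative; the real scope stack lives in C++ under torch::lazy::ScopePusher) is:

scopes = []

class Trace:
  # Toy stand-in for xp.Trace: push a scope on enter, pop it on exit.

  def __init__(self, name):
    self.name = name

  def __enter__(self):
    scopes.append(self.name)

  def __exit__(self, *exc):
    # The invariant the profiler relies on: our own entry must still be
    # on top of the stack when the scope closes.
    if not scopes or scopes[-1] != self.name:
      raise RuntimeError('scope stack corrupted')
    scopes.pop()

def mark_step(reset_scope=True):
  if reset_scope:
    scopes.clear()  # old behavior: unconditionally drop every open scope

try:
  with Trace('build_graph'):
    mark_step()  # old call site: wipes 'build_graph' out from under us
except RuntimeError as e:
  print('before the fix:', e)

with Trace('build_graph'):
  mark_step(reset_scope=False)  # fixed call site: stack is left intact
print('after the fix: scope exited cleanly')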
5 changes: 3 additions & 2 deletions torch_xla/core/xla_model.py
@@ -1045,7 +1045,7 @@ def _run_step_closures():
   return devctx


-def mark_step(wait=False):
+def mark_step(wait=False, reset_scope=True):
   if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
     print(
         'torch_xla.core.xla_model::mark_step\n',
@@ -1054,7 +1054,8 @@ def mark_step(wait=False):
         flush=True)
   torch_xla._XLAC._xla_step_marker(
       torch_xla._XLAC._xla_get_default_device(), [],
-      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait))
+      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait),
+      reset_scope=reset_scope)
   # Only emit metrics from the first local device index, to avoid emitting the
   # same values from different threads.
   if is_master_ordinal():
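The signature change here is backward compatible: reset_scope defaults to True, so existing callers keep the old reset-everything behavior, and only call sites that may run under an open profiler scope (such as the dynamo bridge above) opt out. A short usage sketch, assuming an XLA device is configured:

import torch_xla.core.xla_model as xm

# Default: flush pending work and reset any open lazy-tensor scopes,
# exactly as before this commit.
xm.mark_step()

# New option: flush pending work but leave the scope stack untouched,
# which is safe to call while a profiler trace scope is open.
xm.mark_step(reset_scope=False)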
12 changes: 7 additions & 5 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -458,12 +458,13 @@ void SyncLiveTensors(const std::string& device_str,
 }

 void StepMarker(const std::string& device_str,
-                const std::vector<std::string>& devices, bool wait) {
+                const std::vector<std::string>& devices, bool wait,
+                bool reset_scope) {
   tsl::profiler::TraceMe activity("StepMarker",
                                   tsl::profiler::TraceMeLevel::kInfo);
   torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
   XLAGraphExecutor::Get()->SyncLiveTensorsGraph(&device, devices, wait);
-  XLAGraphExecutor::Get()->MarkStep(device);
+  XLAGraphExecutor::Get()->MarkStep(device, reset_scope);
   bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false);
   if (TF_PREDICT_FALSE(debug_mode)) {
     std::string report = runtime::metrics::CreatePerformanceReport(
@@ -1649,11 +1650,12 @@ void InitXlaModuleBindings(py::module m) {
   m.def(
       "_xla_step_marker",
       [](const std::string& device, const std::vector<std::string>& devices,
-         bool wait) {
+         bool wait, bool reset_scope) {
         NoGilSection nogil;
-        StepMarker(device, devices, wait);
+        StepMarker(device, devices, wait, reset_scope);
       },
-      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
+      py::arg("device") = "", py::arg("devices"), py::arg("wait") = true,
+      py::arg("reset_scope") = true);
   m.def("_get_stablehlo",
         [](const std::vector<at::Tensor>& tensors, const std::string& device,
            const std::vector<std::string>& devices,
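At the binding layer the flag is threaded straight through, again defaulting to true so the Python-visible _xla_step_marker keeps its old behavior for existing callers. Per the diffs above, the updated mark_step wrapper effectively drives the binding like this (calling the private _XLAC API directly is shown for illustration only):

import torch_xla

torch_xla._XLAC._xla_step_marker(
    torch_xla._XLAC._xla_get_default_device(), [],
    wait=False,
    reset_scope=False)  # keep profiler trace scopes intact across the step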
7 changes: 5 additions & 2 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -404,14 +404,17 @@ void XLAGraphExecutor::SyncLiveTensorsGraph(
   SyncTensorsGraph(&tensors, devices, wait, /*sync_ltc_data=*/true);
 }

-void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
+void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device,
+                                bool reset_scope) {
   // TODO(jwtan): Replace this with TORCH_LAZY_COUNTER. We need MarkStep to
   // remain as XLA_COUNTER to support
   // runtime::metrics::CreatePerformanceReport(). For more information, see
   // NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
   XLA_COUNTER("MarkStep", 1);
   DeviceContextArena::Get()->MarkStep(device);
-  torch::lazy::ScopePusher::ResetScopes();
+  if (reset_scope) {
+    torch::lazy::ScopePusher::ResetScopes();
+  }
   ResetTrimCounter();
 }

2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_graph_executor.h
@@ -134,7 +134,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   // Marks an execution step, which allows the tensor framework to understand
   // the computation boundaries.
   // Override to use our own DeviceContextArena.
-  void MarkStep(const torch::lazy::BackendDevice& device) final;
+  void MarkStep(const torch::lazy::BackendDevice& device, bool reset_scope);

   // Waits for all the outstanding operations on all the supplied devices.
   // If devices is empty, the wait will happen for all local devices.
