diff --git a/test/test_pt_xla_debug.py b/test/test_pt_xla_debug.py index 72bea4e9ffb..14f0817a4c0 100644 --- a/test/test_pt_xla_debug.py +++ b/test/test_pt_xla_debug.py @@ -68,10 +68,11 @@ def toy_program(t1): with open(self.debug_file_name, 'rb') as f: lines = f.readlines() causes = extract_execution_cause(lines) - self.assertEqual(len(causes), 3) + self.assertEqual(len(causes), 4) self.assertIn('mark_step when dynamo processing input graphs', causes[0]) self.assertIn('mark_step when dynamo processing input graphs', causes[1]) - self.assertIn('dynamo compiles FX graph to HLO', causes[2]) + self.assertIn('dynamo is compiling a FX graph to HLO', causes[2]) + self.assertIn('dynamo is executing a compiled program', causes[3]) open(self.debug_file_name, 'w').close() def test_parallel_loader(self): diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index ece9ac15ff1..7723d6d95d9 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -216,7 +216,8 @@ static bool endsWith(const std::string& str, const std::string& suffix) { 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } -void DebugUtil::analyze_graph_execution_python_frame() { +void DebugUtil::analyze_graph_execution_python_frame( + bool from_dynamo_executation) { static bool is_master_process = (runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0); static std::string debug_file_name = @@ -237,7 +238,13 @@ void DebugUtil::analyze_graph_execution_python_frame() { "==========" << "\n"; ss << debug_output_prefix << "Execution Cause\n"; - if (frames[0].function == "mark_step") { + if (from_dynamo_executation) { + // when executation is from dynamo compiled graph, the python stack will not + // show any dynamo related python file since frame is already replaced. We + // can either analyze the C++ call stack or rely on caller to pass a boolean + // variable. + ss << debug_output_prefix << " dynamo is executing a compiled program\n"; + } else if (frames[0].function == "mark_step") { if (frames[1].function == "next" && endsWith(frames[1].file, "parallel_loader.py")) { ss << debug_output_prefix @@ -256,7 +263,7 @@ void DebugUtil::analyze_graph_execution_python_frame() { } } else if (frames[0].function == "extract_graph_helper" && endsWith(frames[0].file, "dynamo_bridge.py")) { - ss << debug_output_prefix << " dynamo compiles FX graph to HLO\n"; + ss << debug_output_prefix << " dynamo is compiling a FX graph to HLO\n"; } else { // TODO(JackCaoG): be more specific about exeuction caused by printing // tensor or fallback or some weird indexing. diff --git a/torch_xla/csrc/debug_util.h b/torch_xla/csrc/debug_util.h index ce063195d93..530a45fc83a 100644 --- a/torch_xla/csrc/debug_util.h +++ b/torch_xla/csrc/debug_util.h @@ -49,7 +49,8 @@ class DebugUtil { // warning, this function should only be called when a graph execution is // about to happen. - static void analyze_graph_execution_python_frame(); + static void analyze_graph_execution_python_frame( + bool from_dynamo_executation = false); }; } // namespace torch_xla diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 5e51aec1d3e..39d866358ac 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -629,6 +629,10 @@ XLAGraphExecutor::ExecuteComputationWithBarrier( tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier", tsl::profiler::TraceMeLevel::kInfo); MaybeDumpGraph("dynamo", hash); + if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) { + DebugUtil::analyze_graph_execution_python_frame( + /*from_dynamo_executation=*/true); + } auto cachedComputation = XLAGraphExecutor::Get()->GetComputationCache()->Get(hash); TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash)