diff --git a/test/run_tests.sh b/test/run_tests.sh index f60c7e0f72c1..fe2346d358ef 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -108,6 +108,11 @@ function run_save_tensor_hlo { XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@" } +function run_pt_xla_debug { + echo "Running in save tensor file mode: $@" + PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@" +} + function run_stablehlo_compile { echo "Running in StableHlo Compile mode: $@" XLA_STABLEHLO_COMPILE=1 run_test "$@" @@ -156,6 +161,7 @@ function run_xla_op_tests1 { run_test "$CDIR/test_grad_checkpoint.py" run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + run_pt_xla_debug "$CDIR/test_pt_xla_debug.py" run_test "$CDIR/test_async_closures.py" run_test "$CDIR/test_profiler.py" run_test "$CDIR/pjrt/test_runtime.py" diff --git a/test/test_pt_xla_debug.py b/test/test_pt_xla_debug.py new file mode 100644 index 000000000000..72bea4e9ffbb --- /dev/null +++ b/test/test_pt_xla_debug.py @@ -0,0 +1,119 @@ +import os + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.utils.utils as xu +import torch_xla.debug.profiler as xp +import torch_xla.utils.utils as xu +import torch_xla.distributed.parallel_loader as pl +import unittest + + +def check_env_flag(name, default=''): + return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + + +def extract_execution_cause(lines): + causes = [] + for i in range(len(lines)): + if 'Execution Cause' in lines[i].decode(): + causes.append(lines[i + 1].decode()) + return causes + + +class PtXLADebugTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if not check_env_flag('PT_XLA_DEBUG'): + assert False, "This test should be run with PT_XLA_DEBUG" + cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE') + if not cls.debug_file_name: + assert False, "This test should be run with PT_XLA_DEBUG_FILE" + open(cls.debug_file_name, 'w').close() + + def test_user_mark_step(self): + device = xm.xla_device() + t1 = torch.randn(2, 2, device=device) + xm.mark_step() + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 1) + self.assertIn('user mark_step', causes[0]) + open(self.debug_file_name, 'w').close() + + def test_step_trace(self): + device = xm.xla_device() + with xp.StepTrace('train_pt_xla_debug'): + t1 = torch.randn(2, 2, device=device) + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 1) + self.assertIn('mark_step when exiting a profiler StepTrace region', + causes[0]) + open(self.debug_file_name, 'w').close() + + def test_dynamo(self): + device = xm.xla_device() + t1 = torch.randn(2, 2, device=device) + + def toy_program(t1): + return t1 + t1 + + compiled = torch.compile(toy_program, backend="openxla") + res = compiled(t1) + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 3) + self.assertIn('mark_step when dynamo processing input graphs', causes[0]) + self.assertIn('mark_step when dynamo processing input graphs', causes[1]) + self.assertIn('dynamo compiles FX graph to HLO', causes[2]) + open(self.debug_file_name, 'w').close() + + def test_parallel_loader(self): + device = xm.xla_device() + + train_dataset_len = 100 + batch_size = 10 + train_loader = xu.SampleGenerator( + data=(torch.zeros(batch_size, 3, 128, + 128), torch.zeros(batch_size, dtype=torch.int64)), + sample_count=train_dataset_len // 10) + + train_device_loader = pl.MpDeviceLoader( + train_loader, + device, + loader_prefetch_size=8, + device_prefetch_size=4, + host_to_device_transfer_threads=1) + + for step, (data, target) in enumerate(train_device_loader): + pass + + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), batch_size + 2) + for cause in causes: + self.assertIn('mark_step in parallel loader at step end', cause) + open(self.debug_file_name, 'w').close() + + def test_print(self): + device = xm.xla_device() + t1 = torch.randn(2, 2, device=device) + print(t1) + with open(self.debug_file_name, 'rb') as f: + lines = f.readlines() + causes = extract_execution_cause(lines) + self.assertEqual(len(causes), 1) + self.assertIn('user code trying to access tensor value', causes[0]) + open(self.debug_file_name, 'w').close() + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 6b6c301b2a1b..ece9ac15ff1c 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -209,4 +210,86 @@ bool DebugUtil::ExperimentEnabled(const std::string& name) { return xset->find(name) != xset->end(); } +// helper function until we move to C++ 20 +static bool endsWith(const std::string& str, const std::string& suffix) { + return str.size() >= suffix.size() && + 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +void DebugUtil::analyze_graph_execution_python_frame() { + static bool is_master_process = + (runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0); + static std::string debug_file_name = + runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", ""); + static std::string debug_output_prefix = "Execution Analysis: "; + // TODO: Make this configurable. + if (!is_master_process) { + return; + } + std::vector frames = + torch::lazy::GetPythonFrames(); + // python frame must be > 1 + XLA_CHECK_GE(frames.size(), 1); + std::stringstream ss; + ss << "\n" + << debug_output_prefix + << "======================================================================" + "==========" + << "\n"; + ss << debug_output_prefix << "Execution Cause\n"; + if (frames[0].function == "mark_step") { + if (frames[1].function == "next" && + endsWith(frames[1].file, "parallel_loader.py")) { + ss << debug_output_prefix + << " mark_step in parallel loader at step end\n"; + } else if (frames[1].function == "__exit__" && + endsWith(frames[1].file, "profiler.py")) { + ss << debug_output_prefix + << " mark_step when exiting a profiler StepTrace region\n"; + } else if ((frames[1].function == "extract_compiled_graph" || + frames[1].function == "extract_internal") && + endsWith(frames[1].file, "dynamo_bridge.py")) { + ss << debug_output_prefix + << " mark_step when dynamo processing input graphs\n"; + } else { + ss << debug_output_prefix << " user mark_step\n"; + } + } else if (frames[0].function == "extract_graph_helper" && + endsWith(frames[0].file, "dynamo_bridge.py")) { + ss << debug_output_prefix << " dynamo compiles FX graph to HLO\n"; + } else { + // TODO(JackCaoG): be more specific about exeuction caused by printing + // tensor or fallback or some weird indexing. + ss << debug_output_prefix + << " most likely user code trying to access tensor value before " + "mark_step\n"; + } + + // TODO(JackCaoG): make number of frames printed configurable + ss << debug_output_prefix << "Python Frame Triggered Execution: \n"; + for (auto& location : frames) { + ss << debug_output_prefix << " " << location.function << " (" + << location.file << ":" << location.line << ")\n"; + } + ss << debug_output_prefix + << "----------------------------------------------------------------------" + "----------" + << "\n"; + ss << debug_output_prefix + << "======================================================================" + "==========" + << "\n"; + + // TODO(JackCaoG): print more information about the graph that is about to get + // executed. + if (debug_file_name == "") { + // print to stderr by default + std::cerr << ss.str(); + } else { + std::ofstream outFile; + outFile.open(debug_file_name, std::ios_base::app); + outFile << ss.rdbuf(); + } +} + } // namespace torch_xla diff --git a/torch_xla/csrc/debug_util.h b/torch_xla/csrc/debug_util.h index 2a687207b280..ce063195d938 100644 --- a/torch_xla/csrc/debug_util.h +++ b/torch_xla/csrc/debug_util.h @@ -46,6 +46,10 @@ class DebugUtil { absl::Span indices); static bool ExperimentEnabled(const std::string& name); + + // warning, this function should only be called when a graph execution is + // about to happen. + static void analyze_graph_execution_python_frame(); }; } // namespace torch_xla diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 45ed57b1b04c..5e51aec1d3ef 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -1322,6 +1322,9 @@ XLAGraphExecutor::SyncTensorsGraphInternal( const SyncTensorsConfig& config, bool warm_up_cache_only) { tsl::profiler::TraceMe activity("SyncTensorsGraphInternal", tsl::profiler::TraceMeLevel::kInfo); + if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) { + DebugUtil::analyze_graph_execution_python_frame(); + } SyncTensorCollection coll = CollectSyncTensors(*tensors, config); if (coll.indices.empty()) { // Enure previous execution is complete before exiting this