Skip to content

Commit

Permalink
Add tooling to explain why a graph execution happens (#5723)
Browse files Browse the repository at this point in the history
* Initial commit for debugging tool

* minor format tweak

* Only master process should print the execution frame info

* add execution cause

* handle dynamo and everything else

* add test

* linter

* add test to the script
  • Loading branch information
JackCaoG authored and bhavya01 committed Apr 22, 2024
1 parent 54308d1 commit 4882bbf
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ function run_save_tensor_hlo {
XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
}

function run_pt_xla_debug {
echo "Running in save tensor file mode: $@"
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
}

function run_stablehlo_compile {
echo "Running in StableHlo Compile mode: $@"
XLA_STABLEHLO_COMPILE=1 run_test "$@"
Expand Down Expand Up @@ -156,6 +161,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_grad_checkpoint.py"
run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_pt_xla_debug "$CDIR/test_pt_xla_debug.py"
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_profiler.py"
run_test "$CDIR/pjrt/test_runtime.py"
Expand Down
119 changes: 119 additions & 0 deletions test/test_pt_xla_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.debug.profiler as xp
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
import unittest


def check_env_flag(name, default=''):
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


def extract_execution_cause(lines):
causes = []
for i in range(len(lines)):
if 'Execution Cause' in lines[i].decode():
causes.append(lines[i + 1].decode())
return causes


class PtXLADebugTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
if not check_env_flag('PT_XLA_DEBUG'):
assert False, "This test should be run with PT_XLA_DEBUG"
cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE')
if not cls.debug_file_name:
assert False, "This test should be run with PT_XLA_DEBUG_FILE"
open(cls.debug_file_name, 'w').close()

def test_user_mark_step(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
xm.mark_step()
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('user mark_step', causes[0])
open(self.debug_file_name, 'w').close()

def test_step_trace(self):
device = xm.xla_device()
with xp.StepTrace('train_pt_xla_debug'):
t1 = torch.randn(2, 2, device=device)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('mark_step when exiting a profiler StepTrace region',
causes[0])
open(self.debug_file_name, 'w').close()

def test_dynamo(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)

def toy_program(t1):
return t1 + t1

compiled = torch.compile(toy_program, backend="openxla")
res = compiled(t1)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 3)
self.assertIn('mark_step when dynamo processing input graphs', causes[0])
self.assertIn('mark_step when dynamo processing input graphs', causes[1])
self.assertIn('dynamo compiles FX graph to HLO', causes[2])
open(self.debug_file_name, 'w').close()

def test_parallel_loader(self):
device = xm.xla_device()

train_dataset_len = 100
batch_size = 10
train_loader = xu.SampleGenerator(
data=(torch.zeros(batch_size, 3, 128,
128), torch.zeros(batch_size, dtype=torch.int64)),
sample_count=train_dataset_len // 10)

train_device_loader = pl.MpDeviceLoader(
train_loader,
device,
loader_prefetch_size=8,
device_prefetch_size=4,
host_to_device_transfer_threads=1)

for step, (data, target) in enumerate(train_device_loader):
pass

with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), batch_size + 2)
for cause in causes:
self.assertIn('mark_step in parallel loader at step end', cause)
open(self.debug_file_name, 'w').close()

def test_print(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
print(t1)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('user code trying to access tensor value', causes[0])
open(self.debug_file_name, 'w').close()


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
83 changes: 83 additions & 0 deletions torch_xla/csrc/debug_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <torch/csrc/lazy/python/python_util.h>

#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include <unordered_set>
Expand Down Expand Up @@ -209,4 +210,86 @@ bool DebugUtil::ExperimentEnabled(const std::string& name) {
return xset->find(name) != xset->end();
}

// helper function until we move to C++ 20
static bool endsWith(const std::string& str, const std::string& suffix) {
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

void DebugUtil::analyze_graph_execution_python_frame() {
static bool is_master_process =
(runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0);
static std::string debug_file_name =
runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", "");
static std::string debug_output_prefix = "Execution Analysis: ";
// TODO: Make this configurable.
if (!is_master_process) {
return;
}
std::vector<torch::lazy::SourceLocation> frames =
torch::lazy::GetPythonFrames();
// python frame must be > 1
XLA_CHECK_GE(frames.size(), 1);
std::stringstream ss;
ss << "\n"
<< debug_output_prefix
<< "======================================================================"
"=========="
<< "\n";
ss << debug_output_prefix << "Execution Cause\n";
if (frames[0].function == "mark_step") {
if (frames[1].function == "next" &&
endsWith(frames[1].file, "parallel_loader.py")) {
ss << debug_output_prefix
<< " mark_step in parallel loader at step end\n";
} else if (frames[1].function == "__exit__" &&
endsWith(frames[1].file, "profiler.py")) {
ss << debug_output_prefix
<< " mark_step when exiting a profiler StepTrace region\n";
} else if ((frames[1].function == "extract_compiled_graph" ||
frames[1].function == "extract_internal") &&
endsWith(frames[1].file, "dynamo_bridge.py")) {
ss << debug_output_prefix
<< " mark_step when dynamo processing input graphs\n";
} else {
ss << debug_output_prefix << " user mark_step\n";
}
} else if (frames[0].function == "extract_graph_helper" &&
endsWith(frames[0].file, "dynamo_bridge.py")) {
ss << debug_output_prefix << " dynamo compiles FX graph to HLO\n";
} else {
// TODO(JackCaoG): be more specific about exeuction caused by printing
// tensor or fallback or some weird indexing.
ss << debug_output_prefix
<< " most likely user code trying to access tensor value before "
"mark_step\n";
}

// TODO(JackCaoG): make number of frames printed configurable
ss << debug_output_prefix << "Python Frame Triggered Execution: \n";
for (auto& location : frames) {
ss << debug_output_prefix << " " << location.function << " ("
<< location.file << ":" << location.line << ")\n";
}
ss << debug_output_prefix
<< "----------------------------------------------------------------------"
"----------"
<< "\n";
ss << debug_output_prefix
<< "======================================================================"
"=========="
<< "\n";

// TODO(JackCaoG): print more information about the graph that is about to get
// executed.
if (debug_file_name == "") {
// print to stderr by default
std::cerr << ss.str();
} else {
std::ofstream outFile;
outFile.open(debug_file_name, std::ios_base::app);
outFile << ss.rdbuf();
}
}

} // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/debug_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class DebugUtil {
absl::Span<const size_t> indices);

static bool ExperimentEnabled(const std::string& name);

// warning, this function should only be called when a graph execution is
// about to happen.
static void analyze_graph_execution_python_frame();
};

} // namespace torch_xla
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,9 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
const SyncTensorsConfig& config, bool warm_up_cache_only) {
tsl::profiler::TraceMe activity("SyncTensorsGraphInternal",
tsl::profiler::TraceMeLevel::kInfo);
if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) {
DebugUtil::analyze_graph_execution_python_frame();
}
SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
if (coll.indices.empty()) {
// Enure previous execution is complete before exiting this
Expand Down

0 comments on commit 4882bbf

Please sign in to comment.