Add tooling to explain why a graph execution happens #5723

Merged · 8 commits · Oct 31, 2023
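As context for the diffs below, a minimal sketch of how this tooling is meant to be enabled. The two environment variables come from this PR; the toy computation is illustrative:

import os

# Set before the first graph execution; PT_XLA_DEBUG_FILE is read once on the
# C++ side and cached, so set it early. Omit it to print to stderr instead.
os.environ['PT_XLA_DEBUG'] = '1'
os.environ['PT_XLA_DEBUG_FILE'] = '/tmp/pt_xla_debug.txt'

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t = torch.randn(2, 2, device=device)
xm.mark_step()  # writes an 'Execution Cause' entry ('user mark_step') to the file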
6 changes: 6 additions & 0 deletions test/run_tests.sh
@@ -108,6 +108,11 @@ function run_save_tensor_hlo {
  XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
}

function run_pt_xla_debug {
  echo "Running in PT_XLA_DEBUG mode: $@"
  PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
}

function run_stablehlo_compile {
echo "Running in StableHlo Compile mode: $@"
XLA_STABLEHLO_COMPILE=1 run_test "$@"
@@ -156,6 +161,7 @@ function run_xla_op_tests1 {
  run_test "$CDIR/test_grad_checkpoint.py"
  run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
  run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
  run_pt_xla_debug "$CDIR/test_pt_xla_debug.py"
  run_test "$CDIR/test_async_closures.py"
  run_test "$CDIR/test_profiler.py"
  run_test "$CDIR/pjrt/test_runtime.py"
119 changes: 119 additions & 0 deletions test/test_pt_xla_debug.py
@@ -0,0 +1,119 @@
import os
import sys

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import unittest


def check_env_flag(name, default=''):
  return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


def extract_execution_cause(lines):
  # Each 'Execution Cause' header line in the debug file is immediately
  # followed by a line describing the cause itself.
  causes = []
  for i in range(len(lines)):
    if 'Execution Cause' in lines[i].decode():
      causes.append(lines[i + 1].decode())
  return causes
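
(Note: a hypothetical two-line fragment of the debug file this helper scans, assembled from the format strings in debug_util.cpp further down; the exact leading whitespace of the cause line is illustrative:)

sample = [
    b'Execution Analysis: Execution Cause\n',
    b'Execution Analysis:  user mark_step\n',
]
assert extract_execution_cause(sample) == ['Execution Analysis:  user mark_step\n']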


class PtXLADebugTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    if not check_env_flag('PT_XLA_DEBUG'):
      assert False, "This test should be run with PT_XLA_DEBUG"
    cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE')
    if not cls.debug_file_name:
      assert False, "This test should be run with PT_XLA_DEBUG_FILE"
    open(cls.debug_file_name, 'w').close()

  def test_user_mark_step(self):
    device = xm.xla_device()
    t1 = torch.randn(2, 2, device=device)
    xm.mark_step()
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
    causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 1)
    self.assertIn('user mark_step', causes[0])
    open(self.debug_file_name, 'w').close()

  def test_step_trace(self):
    device = xm.xla_device()
    with xp.StepTrace('train_pt_xla_debug'):
      t1 = torch.randn(2, 2, device=device)
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
    causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 1)
    self.assertIn('mark_step when exiting a profiler StepTrace region',
                  causes[0])
    open(self.debug_file_name, 'w').close()

  def test_dynamo(self):
    device = xm.xla_device()
    t1 = torch.randn(2, 2, device=device)

    def toy_program(t1):
      return t1 + t1

    compiled = torch.compile(toy_program, backend="openxla")
    res = compiled(t1)
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
    causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 3)
    self.assertIn('mark_step when dynamo processing input graphs', causes[0])
    self.assertIn('mark_step when dynamo processing input graphs', causes[1])
    self.assertIn('dynamo compiles FX graph to HLO', causes[2])
    open(self.debug_file_name, 'w').close()

  def test_parallel_loader(self):
    device = xm.xla_device()

    train_dataset_len = 100
    batch_size = 10
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 3, 128, 128),
              torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // 10)

    train_device_loader = pl.MpDeviceLoader(
        train_loader,
        device,
        loader_prefetch_size=8,
        device_prefetch_size=4,
        host_to_device_transfer_threads=1)

    for step, (data, target) in enumerate(train_device_loader):
      pass

    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
    causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), batch_size + 2)
    for cause in causes:
      self.assertIn('mark_step in parallel loader at step end', cause)
    open(self.debug_file_name, 'w').close()

  def test_print(self):
    device = xm.xla_device()
    t1 = torch.randn(2, 2, device=device)
    print(t1)
    with open(self.debug_file_name, 'rb') as f:
      lines = f.readlines()
    causes = extract_execution_cause(lines)
    self.assertEqual(len(causes), 1)
    self.assertIn('user code trying to access tensor value', causes[0])
    open(self.debug_file_name, 'w').close()


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
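
setUpClass hard-requires both environment variables, mirroring the run_pt_xla_debug wrapper added to run_tests.sh above; a minimal sketch for invoking the test file directly, using the same debug-file path as the wrapper:

import os
import subprocess

env = dict(
    os.environ, PT_XLA_DEBUG='1', PT_XLA_DEBUG_FILE='/tmp/pt_xla_debug.txt')
subprocess.run(['python', 'test/test_pt_xla_debug.py'], env=env, check=True)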
83 changes: 83 additions & 0 deletions torch_xla/csrc/debug_util.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/lazy/python/python_util.h>

#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include <unordered_set>
@@ -209,4 +210,86 @@ bool DebugUtil::ExperimentEnabled(const std::string& name) {
  return xset->find(name) != xset->end();
}

// Helper function until we move to C++20 (std::string::ends_with).
static bool endsWith(const std::string& str, const std::string& suffix) {
  return str.size() >= suffix.size() &&
         0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

void DebugUtil::analyze_graph_execution_python_frame() {
  static bool is_master_process =
      (runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0);
  static std::string debug_file_name =
      runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", "");
  static std::string debug_output_prefix = "Execution Analysis: ";
  // TODO: Make this configurable.
  if (!is_master_process) {
    return;
  }
  std::vector<torch::lazy::SourceLocation> frames =
      torch::lazy::GetPythonFrames();
  // There must be at least one python frame.
  XLA_CHECK_GE(frames.size(), 1);
  std::stringstream ss;
  ss << "\n"
     << debug_output_prefix
     << "======================================================================"
        "=========="
     << "\n";
  ss << debug_output_prefix << "Execution Cause\n";
  if (frames[0].function == "mark_step") {
Review comment (Collaborator): Wondering if it is worth logging frames[0].function for the cases other than mark_step.

Author reply (Collaborator): Yeah, I want to expand this a bit later to cover most of the common cases of execution.

    if (frames[1].function == "next" &&
        endsWith(frames[1].file, "parallel_loader.py")) {
      ss << debug_output_prefix
         << " mark_step in parallel loader at step end\n";
    } else if (frames[1].function == "__exit__" &&
               endsWith(frames[1].file, "profiler.py")) {
      ss << debug_output_prefix
         << " mark_step when exiting a profiler StepTrace region\n";
    } else if ((frames[1].function == "extract_compiled_graph" ||
                frames[1].function == "extract_internal") &&
               endsWith(frames[1].file, "dynamo_bridge.py")) {
      ss << debug_output_prefix
         << " mark_step when dynamo processing input graphs\n";
    } else {
      ss << debug_output_prefix << " user mark_step\n";
    }
  } else if (frames[0].function == "extract_graph_helper" &&
             endsWith(frames[0].file, "dynamo_bridge.py")) {
    ss << debug_output_prefix << " dynamo compiles FX graph to HLO\n";
  } else {
    // TODO(JackCaoG): be more specific about execution caused by printing
    // tensor or fallback or some weird indexing.
    ss << debug_output_prefix
       << " most likely user code trying to access tensor value before "
          "mark_step\n";
  }

  // TODO(JackCaoG): make number of frames printed configurable
  ss << debug_output_prefix << "Python Frame Triggered Execution: \n";
  for (auto& location : frames) {
    ss << debug_output_prefix << " " << location.function << " ("
       << location.file << ":" << location.line << ")\n";
  }
  ss << debug_output_prefix
     << "----------------------------------------------------------------------"
        "----------"
     << "\n";
  ss << debug_output_prefix
     << "======================================================================"
        "=========="
     << "\n";

  // TODO(JackCaoG): print more information about the graph that is about to
  // get executed.
  if (debug_file_name.empty()) {
    // Print to stderr by default.
    std::cerr << ss.str();
  } else {
    std::ofstream outFile;
    outFile.open(debug_file_name, std::ios_base::app);
    outFile << ss.rdbuf();
  }
}

} // namespace torch_xla
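
Putting the format strings above together, one analysis entry looks roughly like this (reconstructed from the code in this diff; the frame list and line number are placeholders):

Execution Analysis: ================================================================================
Execution Analysis: Execution Cause
Execution Analysis:  user mark_step
Execution Analysis: Python Frame Triggered Execution: 
Execution Analysis:  mark_step (torch_xla/core/xla_model.py:<line>)
Execution Analysis:  ...
Execution Analysis: --------------------------------------------------------------------------------
Execution Analysis: ================================================================================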
4 changes: 4 additions & 0 deletions torch_xla/csrc/debug_util.h
@@ -46,6 +46,10 @@ class DebugUtil {
                            absl::Span<const size_t> indices);

  static bool ExperimentEnabled(const std::string& name);

  // Warning: this function should only be called when a graph execution is
  // about to happen.
  static void analyze_graph_execution_python_frame();
};

} // namespace torch_xla
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -1322,6 +1322,9 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
    const SyncTensorsConfig& config, bool warm_up_cache_only) {
  tsl::profiler::TraceMe activity("SyncTensorsGraphInternal",
                                  tsl::profiler::TraceMeLevel::kInfo);
  if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) {
    DebugUtil::analyze_graph_execution_python_frame();
  }
  SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
  if (coll.indices.empty()) {
    // Ensure previous execution is complete before exiting this
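
Because the PT_XLA_DEBUG check above is re-read on every sync (unlike PT_XLA_DEBUG_FILE, which is a function-local static in debug_util.cpp), the analysis can in principle be toggled from Python to scope it to a region of interest; a hedged sketch:

import os
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

os.environ['PT_XLA_DEBUG'] = '1'  # analyze syncs from here on
t = torch.randn(2, 2, device=device)
xm.mark_step()  # this execution is analyzed

del os.environ['PT_XLA_DEBUG']  # stop analyzing
t = t + 1
xm.mark_step()  # this execution is not analyzed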