Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Oct 28, 2023
1 parent 659f1a0 commit 6dd0623
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 6 deletions.
5 changes: 5 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ function run_save_tensor_hlo {
XLA_SAVE_TENSORS_FILE="/tmp/xla_test_save_ir.txt" XLA_SAVE_TENSORS_FMT="hlo" run_test "$@"
}

function run_pt_xla_debug {
echo "Running in save tensor file mode: $@"
PT_XLA_DEBUG=1 PT_XLA_DEBUG_FILE="/tmp/pt_xla_debug.txt" run_test "$@"
}

function run_stablehlo_compile {
echo "Running in StableHlo Compile mode: $@"
XLA_STABLEHLO_COMPILE=1 run_test "$@"
Expand Down
119 changes: 119 additions & 0 deletions test/test_pt_xla_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.utils.utils as xu
import torch_xla.debug.profiler as xp
import torch_xla.utils.utils as xu
import torch_xla.distributed.parallel_loader as pl
import unittest


def check_env_flag(name, default=''):
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


def extract_execution_cause(lines):
causes = []
for i in range(len(lines)):
if 'Execution Cause' in lines[i].decode():
causes.append(lines[i + 1].decode())
return causes


class PtXLADebugTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
if not check_env_flag('PT_XLA_DEBUG'):
assert False, "This test should be run with PT_XLA_DEBUG"
cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE')
if not cls.debug_file_name:
assert False, "This test should be run with PT_XLA_DEBUG_FILE"
open(cls.debug_file_name, 'w').close()

def test_user_mark_step(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
xm.mark_step()
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('user mark_step', causes[0])
open(self.debug_file_name, 'w').close()

def test_step_trace(self):
device = xm.xla_device()
with xp.StepTrace('train_pt_xla_debug'):
t1 = torch.randn(2, 2, device=device)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('mark_step when exiting a profiler StepTrace region',
causes[0])
open(self.debug_file_name, 'w').close()

def test_dynamo(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)

def toy_program(t1):
return t1 + t1

compiled = torch.compile(toy_program, backend="openxla")
res = compiled(t1)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 3)
self.assertIn('mark_step when dynamo processing input graphs', causes[0])
self.assertIn('mark_step when dynamo processing input graphs', causes[1])
self.assertIn('dynamo compiles FX graph to HLO', causes[2])
open(self.debug_file_name, 'w').close()

def test_parallel_loader(self):
device = xm.xla_device()

train_dataset_len = 100
batch_size = 10
train_loader = xu.SampleGenerator(
data=(torch.zeros(batch_size, 3, 128,
128), torch.zeros(batch_size, dtype=torch.int64)),
sample_count=train_dataset_len // 10)

train_device_loader = pl.MpDeviceLoader(
train_loader,
device,
loader_prefetch_size=8,
device_prefetch_size=4,
host_to_device_transfer_threads=1)

for step, (data, target) in enumerate(train_device_loader):
pass

with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), batch_size + 2)
for cause in causes:
self.assertIn('mark_step in parallel loader at step end', cause)
open(self.debug_file_name, 'w').close()

def test_print(self):
device = xm.xla_device()
t1 = torch.randn(2, 2, device=device)
print(t1)
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 1)
self.assertIn('user code trying to access tensor value', causes[0])
open(self.debug_file_name, 'w').close()


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
22 changes: 16 additions & 6 deletions torch_xla/csrc/debug_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ static bool endsWith(const std::string& str, const std::string& suffix) {
void DebugUtil::analyze_graph_execution_python_frame() {
static bool is_master_process =
(runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0);
static std::string debug_file_name =
runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", "");
static std::string debug_output_prefix = "Execution Analysis: ";
// TODO: Make this configurable.
if (!is_master_process) {
Expand All @@ -231,16 +233,15 @@ void DebugUtil::analyze_graph_execution_python_frame() {
// python frame must be > 1
XLA_CHECK_GE(frames.size(), 1);
std::stringstream ss;
ss << debug_output_prefix
ss << "\n"
<< debug_output_prefix
<< "======================================================================"
"=========="
<< "\n";
ss << debug_output_prefix << "Execution Cause\n";
if (frames[0].function == "mark_step") {
if (frames.size() == 1) {
ss << debug_output_prefix << " user mark_step\n";
} else if (frames[1].function == "next" &&
endsWith(frames[1].file, "parallel_loader.py")) {
if (frames[1].function == "next" &&
endsWith(frames[1].file, "parallel_loader.py")) {
ss << debug_output_prefix
<< " mark_step in parallel loader at step end\n";
} else if (frames[1].function == "__exit__" &&
Expand All @@ -252,6 +253,8 @@ void DebugUtil::analyze_graph_execution_python_frame() {
endsWith(frames[1].file, "dynamo_bridge.py")) {
ss << debug_output_prefix
<< " mark_step when dynamo processing input graphs\n";
} else {
ss << debug_output_prefix << " user mark_step\n";
}
} else if (frames[0].function == "extract_graph_helper" &&
endsWith(frames[0].file, "dynamo_bridge.py")) {
Expand Down Expand Up @@ -281,7 +284,14 @@ void DebugUtil::analyze_graph_execution_python_frame() {

// TODO(JackCaoG): print more information about the graph that is about to get
// executed.
cerr << ss.str();
if (debug_file_name == "") {
// print to stderr by default
cerr << ss.str();
} else {
std::ofstream outFile;
outFile.open(debug_file_name, std::ios_base::app);
outFile << ss.rdbuf();
}
}

} // namespace torch_xla

0 comments on commit 6dd0623

Please sign in to comment.