-
Notifications
You must be signed in to change notification settings - Fork 467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tooling to explain why a graph execution happens #5723
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
fa407d8
Initial commit for debugging tool
JackCaoG 6fb077c
minor format tweak
JackCaoG b85dc01
Only master process should print the execution frame info
JackCaoG e2748da
add execution cause
JackCaoG 659f1a0
handle dynamo and everything else
JackCaoG 6dd0623
add test
JackCaoG 277bef3
linter
JackCaoG 7c96974
add test to the script
JackCaoG File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import os | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.utils.utils as xu | ||
import torch_xla.debug.profiler as xp | ||
import torch_xla.utils.utils as xu | ||
import torch_xla.distributed.parallel_loader as pl | ||
import unittest | ||
|
||
|
||
def check_env_flag(name, default=''): | ||
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] | ||
|
||
|
||
def extract_execution_cause(lines): | ||
causes = [] | ||
for i in range(len(lines)): | ||
if 'Execution Cause' in lines[i].decode(): | ||
causes.append(lines[i + 1].decode()) | ||
return causes | ||
|
||
|
||
class PtXLADebugTest(unittest.TestCase): | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
if not check_env_flag('PT_XLA_DEBUG'): | ||
assert False, "This test should be run with PT_XLA_DEBUG" | ||
cls.debug_file_name = os.getenv('PT_XLA_DEBUG_FILE') | ||
if not cls.debug_file_name: | ||
assert False, "This test should be run with PT_XLA_DEBUG_FILE" | ||
open(cls.debug_file_name, 'w').close() | ||
|
||
def test_user_mark_step(self): | ||
device = xm.xla_device() | ||
t1 = torch.randn(2, 2, device=device) | ||
xm.mark_step() | ||
with open(self.debug_file_name, 'rb') as f: | ||
lines = f.readlines() | ||
causes = extract_execution_cause(lines) | ||
self.assertEqual(len(causes), 1) | ||
self.assertIn('user mark_step', causes[0]) | ||
open(self.debug_file_name, 'w').close() | ||
|
||
def test_step_trace(self): | ||
device = xm.xla_device() | ||
with xp.StepTrace('train_pt_xla_debug'): | ||
t1 = torch.randn(2, 2, device=device) | ||
with open(self.debug_file_name, 'rb') as f: | ||
lines = f.readlines() | ||
causes = extract_execution_cause(lines) | ||
self.assertEqual(len(causes), 1) | ||
self.assertIn('mark_step when exiting a profiler StepTrace region', | ||
causes[0]) | ||
open(self.debug_file_name, 'w').close() | ||
|
||
def test_dynamo(self): | ||
device = xm.xla_device() | ||
t1 = torch.randn(2, 2, device=device) | ||
|
||
def toy_program(t1): | ||
return t1 + t1 | ||
|
||
compiled = torch.compile(toy_program, backend="openxla") | ||
res = compiled(t1) | ||
with open(self.debug_file_name, 'rb') as f: | ||
lines = f.readlines() | ||
causes = extract_execution_cause(lines) | ||
self.assertEqual(len(causes), 3) | ||
self.assertIn('mark_step when dynamo processing input graphs', causes[0]) | ||
self.assertIn('mark_step when dynamo processing input graphs', causes[1]) | ||
self.assertIn('dynamo compiles FX graph to HLO', causes[2]) | ||
open(self.debug_file_name, 'w').close() | ||
|
||
def test_parallel_loader(self): | ||
device = xm.xla_device() | ||
|
||
train_dataset_len = 100 | ||
batch_size = 10 | ||
train_loader = xu.SampleGenerator( | ||
data=(torch.zeros(batch_size, 3, 128, | ||
128), torch.zeros(batch_size, dtype=torch.int64)), | ||
sample_count=train_dataset_len // 10) | ||
|
||
train_device_loader = pl.MpDeviceLoader( | ||
train_loader, | ||
device, | ||
loader_prefetch_size=8, | ||
device_prefetch_size=4, | ||
host_to_device_transfer_threads=1) | ||
|
||
for step, (data, target) in enumerate(train_device_loader): | ||
pass | ||
|
||
with open(self.debug_file_name, 'rb') as f: | ||
lines = f.readlines() | ||
causes = extract_execution_cause(lines) | ||
self.assertEqual(len(causes), batch_size + 2) | ||
for cause in causes: | ||
self.assertIn('mark_step in parallel loader at step end', cause) | ||
open(self.debug_file_name, 'w').close() | ||
|
||
def test_print(self): | ||
device = xm.xla_device() | ||
t1 = torch.randn(2, 2, device=device) | ||
print(t1) | ||
with open(self.debug_file_name, 'rb') as f: | ||
lines = f.readlines() | ||
causes = extract_execution_cause(lines) | ||
self.assertEqual(len(causes), 1) | ||
self.assertIn('user code trying to access tensor value', causes[0]) | ||
open(self.debug_file_name, 'w').close() | ||
|
||
|
||
if __name__ == '__main__': | ||
test = unittest.main() | ||
sys.exit(0 if test.result.wasSuccessful() else 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if it worths logging
frames[0].function;
for the cases other thanmark_step
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea I want to expand this a bit later to cover most of the common cases of the execution.