Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include Dynamo executation in the executation cause analysis #5758

Merged
merged 1 commit into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions test/test_pt_xla_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ def toy_program(t1):
with open(self.debug_file_name, 'rb') as f:
lines = f.readlines()
causes = extract_execution_cause(lines)
self.assertEqual(len(causes), 3)
self.assertEqual(len(causes), 4)
self.assertIn('mark_step when dynamo processing input graphs', causes[0])
self.assertIn('mark_step when dynamo processing input graphs', causes[1])
self.assertIn('dynamo compiles FX graph to HLO', causes[2])
self.assertIn('dynamo is compiling a FX graph to HLO', causes[2])
self.assertIn('dynamo is executing a compiled program', causes[3])
open(self.debug_file_name, 'w').close()

def test_parallel_loader(self):
Expand Down
13 changes: 10 additions & 3 deletions torch_xla/csrc/debug_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ static bool endsWith(const std::string& str, const std::string& suffix) {
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

void DebugUtil::analyze_graph_execution_python_frame() {
void DebugUtil::analyze_graph_execution_python_frame(
bool from_dynamo_executation) {
static bool is_master_process =
(runtime::sys_util::GetEnvInt("PJRT_LOCAL_PROCESS_RANK", 0) == 0);
static std::string debug_file_name =
Expand All @@ -237,7 +238,13 @@ void DebugUtil::analyze_graph_execution_python_frame() {
"=========="
<< "\n";
ss << debug_output_prefix << "Execution Cause\n";
if (frames[0].function == "mark_step") {
if (from_dynamo_executation) {
// when executation is from dynamo compiled graph, the python stack will not
// show any dynamo related python file since frame is already replaced. We
// can either analyze the C++ call stack or rely on caller to pass a boolean
// variable.
ss << debug_output_prefix << " dynamo is executing a compiled program\n";
} else if (frames[0].function == "mark_step") {
if (frames[1].function == "next" &&
endsWith(frames[1].file, "parallel_loader.py")) {
ss << debug_output_prefix
Expand All @@ -256,7 +263,7 @@ void DebugUtil::analyze_graph_execution_python_frame() {
}
} else if (frames[0].function == "extract_graph_helper" &&
endsWith(frames[0].file, "dynamo_bridge.py")) {
ss << debug_output_prefix << " dynamo compiles FX graph to HLO\n";
ss << debug_output_prefix << " dynamo is compiling a FX graph to HLO\n";
} else {
// TODO(JackCaoG): be more specific about exeuction caused by printing
// tensor or fallback or some weird indexing.
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/debug_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class DebugUtil {

// warning, this function should only be called when a graph execution is
// about to happen.
static void analyze_graph_execution_python_frame();
static void analyze_graph_execution_python_frame(
bool from_dynamo_executation = false);
};

} // namespace torch_xla
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,10 @@ XLAGraphExecutor::ExecuteComputationWithBarrier(
tsl::profiler::TraceMe activity("ExecuteComputationWithBarrier",
tsl::profiler::TraceMeLevel::kInfo);
MaybeDumpGraph("dynamo", hash);
if (runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false)) {
DebugUtil::analyze_graph_execution_python_frame(
/*from_dynamo_executation=*/true);
}
auto cachedComputation =
XLAGraphExecutor::Get()->GetComputationCache()->Get(hash);
TF_VLOG(5) << "Cached computation (hash: " << torch::lazy::HashToString(hash)
Expand Down
Loading