Extend HLO metadata to include class hierarchy information (#5715)
* Add python binding to allow custom op_name metadata for lowered HLO
mrnikwaws authored and bhavya01 committed Apr 22, 2024
1 parent ba8b902 commit 1318547
Showing 10 changed files with 363 additions and 10 deletions.
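
The pieces fit together roughly as follows. This is a minimal sketch distilled from the new test (test/test_hlo_metadata.py), assuming it is run from the test/ directory so that custom_debug_lowering is importable:

import json

import torch
import torch_xla
import torch_xla.core.xla_model as xm

from custom_debug_lowering import CustomOpNameLowering

# Trace a small model while the dispatch mode tags each XLA tensor it
# produces with the class/object scope of the code that created it.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
with CustomOpNameLowering():
  model = model.to(device=xm.xla_device())
  out = model(torch.rand(4, 4, device=xm.xla_device()))

# Lower the traced graph and inspect per-instruction metadata through the
# new hlo_json() binding added in this commit.
ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([out])
for comp in json.loads(ctx.hlo_json()).get("computations", []):
  for instr in comp.get("instructions", []):
    if instr.get("metadata"):
      # op_name entries look like "Sequential[model]/Linear[0]/..."
      print(instr["metadata"])
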
4 changes: 2 additions & 2 deletions test/cpp/run_tests.sh
@@ -105,9 +105,9 @@ fi
for name in "${test_names[@]}"; do
echo "Running $name cpp test..."
if [ "$LOGFILE" != "" ]; then
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1000 ${FILTER:+"$FILTER"} 2> $LOGFILE
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"} 2> $LOGFILE
   else
-    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1000 ${FILTER:+"$FILTER"}
+    bazel $BAZEL_VERB $EXTRA_FLAGS //torch_xla/csrc/runtime:all //test/cpp:${name} --test_timeout 1200 ${FILTER:+"$FILTER"}
fi
done

213 changes: 213 additions & 0 deletions test/custom_debug_lowering.py
@@ -0,0 +1,213 @@
import torch
import torch_xla

import inspect
from collections import defaultdict

from torch.utils._python_dispatch import TorchDispatchMode

class_count = defaultdict(int)
instance_count = dict()


def GetInstancePlaceHolder(class_type, obj):
global class_count
global instance_count

if (class_type, id(obj)) not in instance_count:
class_count[class_type] += 1
instance_count[(class_type, id(obj))] = class_count[class_type]

place_holder = instance_count[(class_type, id(obj))]

return f".{place_holder}"


def CheckIgnored(key):
ignored_list = ("self", "_bootstrap", "_fix_up_module",
"_get_supported_file_loaders", "_setup", "_buffers",
"_parameters", "_non_persistent_buffers_set")

return (key.startswith("__") and key.endswith("__")) or key in ignored_list


def Prefix(prefix, val):
if len(prefix) > 0:
return f"{prefix}.{val}"
else:
return f"{val}"


def ReverseSearchBreadthFirst(container, obj, debug=False):
  # Breadth-first search through nested dicts, lists/tuples and object
  # attributes for a reference to `obj`, returning (found, dotted_name)
  if container is None:
    return False, None

queue = []
visited = set()
nested_name = ""
max_depth = 5
queue.append((0, nested_name, container))

while len(queue):
depth, prefix, candidate = queue.pop(0)

if depth > max_depth or id(candidate) in visited:
continue

visited.add(id(candidate))

if isinstance(candidate, dict):
for k, v in candidate.items():
if not isinstance(k, str):
if debug:
print(f"Found non string key {k}")
break
if CheckIgnored(k):
continue
nested_name = Prefix(prefix, k)
if v is obj:
if debug:
print(f"Found {nested_name}")
return True, nested_name
elif debug:
print(f"Miss {nested_name}")
if id(v) not in visited and depth < max_depth:
queue.append((depth + 1, nested_name, v))
elif isinstance(candidate, (list, tuple)):
for i, v in enumerate(candidate):
nested_name = Prefix(prefix, i)
if v is obj:
if debug:
print(f"Found {nested_name}")
return True, nested_name
elif debug:
print(f"Miss {nested_name}")
if id(v) not in visited and depth < max_depth:
queue.append((depth + 1, nested_name, v))
elif hasattr(candidate, "__class__"):
      # Ignore classes which override __getattr__ and
      # raise errors when their attributes are probed
if type(candidate).__name__ == "_ClassNamespace":
continue
for att in ("_modules", "__dict__"):
if hasattr(candidate, att):
v = getattr(candidate, att)
if id(v) not in visited and depth < max_depth:
queue.append((depth + 1, nested_name, v))
else:
print("No action")

return False, None


def FindMemberVariable(frame, obj):
parent_frame = frame.f_back
found = False
variable_name = None

for lframe in inspect.getouterframes(parent_frame):
if lframe.frame.f_code.co_nlocals <= 0:
continue
self_name = lframe.frame.f_code.co_varnames[0]
parent_obj = lframe.frame.f_locals[self_name]
found, variable_name = ReverseSearchBreadthFirst(parent_obj, obj)
if found:
break

return found, variable_name


def FindLocalVariable(frame, obj):
found = False
variable_name = None

for lframe in inspect.getouterframes(frame.f_back):
found, variable_name = ReverseSearchBreadthFirst(lframe.frame.f_locals, obj)
if found:
break

return found, variable_name


def GetClassNameAndObjFromFrame(frame):
class_obj_str = ""
if frame.f_code.co_argcount == 0:
return class_obj_str

likely_obj_name = frame.f_code.co_varnames[0]

obj = frame.f_locals[likely_obj_name]

if not hasattr(obj, "__class__") or likely_obj_name != "self":
return class_obj_str

name = type(obj).__name__
variable_name = None
found = False

found, variable_name = FindMemberVariable(frame, obj)

if not found:
found, variable_name = FindLocalVariable(frame, obj)

if not found:
variable_name = GetInstancePlaceHolder(name, obj)

name = name + "[" + variable_name + "]"

return name


def CleanNames(names):
last_name = ""
output = []
for name in names:
if name != last_name:
output.append(name)
last_name = name

  # Drop the last scope, which belongs to the CustomOpNameLowering
  # dispatch frame that performs the op_name tagging
  return output[:-1]


def GetAllObjectAndClassNames(frame):
  # Walk up the call stack and join each frame's "Class[variable]" name,
  # outermost first, into a scope string such as "Sequential[model]/Linear[0]/"
  names = []
while frame is not None:
name = GetClassNameAndObjFromFrame(frame)
if len(name) > 0:
names.append(name)
frame = frame.f_back

names.reverse()

names = CleanNames(names)

output = "/".join(names)

if len(output) > 0:
output += "/"

return output


class CustomOpNameLowering(TorchDispatchMode):
  # Dispatch mode that enables IR debug and tags every op result created
  # on an XLA device with the caller's class/object scope

def __init__(self):
super().__init__()

def __enter__(self):
self._old_ir_debug = torch_xla._XLAC._get_ir_debug()
torch_xla._XLAC._set_ir_debug(True)
return super().__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
torch_xla._XLAC._set_ir_debug(self._old_ir_debug)
super().__exit__(exc_type, exc_val, exc_tb)

def __torch_dispatch__(self, func, types, args=(), kwargs={}):
    res = func(*args, **kwargs)
    # Tag XLA results with the class/object scope of the calling code so it
    # is emitted as op_name metadata when the graph is lowered to HLO
    if 'xla' in str(res.device):
frame = inspect.currentframe()
prefix = GetAllObjectAndClassNames(frame)
torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
return res
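
A small standalone sketch (not part of the commit) of the scope strings this mode produces; the module and variable names below are made up, and the printed value should resemble the "Sequential[model]/Linear[0]" style strings checked in test/test_hlo_metadata.py:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

from custom_debug_lowering import CustomOpNameLowering


class Wrapper(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.proj = torch.nn.Linear(4, 4)

  def forward(self, x):
    return self.proj(x)


m = Wrapper().to(xm.xla_device())
with CustomOpNameLowering():
  y = m(torch.rand(4, 4, device=xm.xla_device()))

# The tag stored on the output tensor should look something like
# "Wrapper[m]/Linear[proj]/": the class hierarchy plus the referencing
# variable name (or an instance placeholder such as ".1" when none is found).
print(torch_xla._XLAC._get_xla_custom_op_name(y))
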
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -163,6 +163,7 @@ function run_xla_op_tests1 {
run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_pt_xla_debug "$CDIR/test_pt_xla_debug.py"
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_hlo_metadata.py"
run_test "$CDIR/test_profiler.py"
run_test "$CDIR/pjrt/test_runtime.py"
run_test "$CDIR/pjrt/test_runtime_gpu.py"
90 changes: 90 additions & 0 deletions test/test_hlo_metadata.py
@@ -0,0 +1,90 @@
import sys

# Normal imports section starts here.
import torch
import torch_xla
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest
import json
from custom_debug_lowering import CustomOpNameLowering


class TestHloMetaData(unittest.TestCase):

def setUp(self):
torch.manual_seed(42)
self.pre_test_tensor_type = torch.get_default_dtype()
self.pre_test_ir_debug = torch_xla._XLAC._get_ir_debug()
torch.set_default_tensor_type(torch.FloatTensor)
torch_xla._XLAC._set_ir_debug(True)
super(TestHloMetaData, self).setUp()

def tearDown(self):
super(TestHloMetaData, self).tearDown()
torch_xla._XLAC._set_ir_debug(self.pre_test_ir_debug)

def test_metadata(self):
layer1 = torch.nn.Linear(4, 4)
nl1 = torch.nn.ReLU()
layer2 = torch.nn.Linear(4, 2)
nl2 = torch.nn.Tanh()
model = torch.nn.Sequential(layer1, nl1, layer2, nl2)

with CustomOpNameLowering():
model = model.to(device=xm.xla_device())
inp = torch.rand(4, 4, device=xm.xla_device())
out = model(inp)

ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([out])
hlo_text = ctx.hlo_json()

# Strings to match in the lowering
bingo = {
"torch/_ops.py": False,
#"torch/nn/modules/linear.py": False,
#"torch/nn/modules/activation.py": False,
#"torch/nn/functional.py": False,
"Sequential[model]/Linear[0]": False,
"Sequential[model]/ReLU[1]": False,
"Sequential[model]/Linear[2]": False,
"Sequential[model]/Tanh[3]": False,
"aten__addmm": False,
"aten__relu": False,
"aten__tanh": False,
"aten__permute": False
}

non_zero_metadata = False

local_json = json.loads(hlo_text)
assert "computations" in local_json
for c in local_json["computations"]:
if "instructions" in c:
i = c["instructions"]
for op in i:
if 'metadata' in op:
meta = op["metadata"]
print(meta)
if len(meta) > 0:
non_zero_metadata = True
for km, vm in meta.items():
for k in bingo.keys():
if isinstance(vm, str) and k in vm:
bingo[k] = True

assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?"

for k, v in bingo.items():
assert v, f"Keyword {k} was not found as expected in HLO metadata for simple test"

print("All required metadata symbols matched")


if __name__ == '__main__':
test = unittest.main(exit=False)
if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
print(met.metrics_report())
sys.exit(0 if test.result.wasSuccessful() else 1)
18 changes: 18 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -854,6 +854,13 @@ class PyLoweringContext {
return result;
}

std::string GetHloJsonText() {
const xla::HloModuleProto& proto = computation.proto();
std::string result;
google::protobuf::util::MessageToJsonString(proto, &result);
return result;
}

private:
LoweringContext lowering_ctx;
xla::XlaComputation computation;
@@ -895,6 +902,7 @@ void BuildLoweringContextSubmodule(py::module* m) {
.def("build", &PyLoweringContext::Build)
.def("hlo", &PyLoweringContext::GetHlo)
.def("hlo_text", &PyLoweringContext::GetHloText)
.def("hlo_json", &PyLoweringContext::GetHloJsonText)
.def("parameter_id_tensor_mapping",
&PyLoweringContext::GetParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId);
@@ -1910,6 +1918,16 @@ void InitXlaModuleBindings(py::module m) {
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
});
m.def("_set_xla_custom_op_name",
[](const at::Tensor& input, const std::string& op_name) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->SetCustomOpName(op_name);
});
m.def("_get_xla_custom_op_name",
[](const at::Tensor& input) -> const std::string& {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return xtensor->GetCustomOpName();
});
m.def("_get_all_reduce_token",
[](const std::string& device_str) -> const torch::lazy::Value& {
auto device = GetDeviceOrCurrent(device_str);
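
For reference, a minimal sketch (assuming an XLA device is available) of driving the two new tensor bindings directly, without the dispatch mode; the scope string used here is made up:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Any lazy XLA tensor with a pending IR node will do.
t = torch.ones(2, 2, device=xm.xla_device()) * 3
torch_xla._XLAC._set_xla_custom_op_name(t, "MyScope[example]/")
print(torch_xla._XLAC._get_xla_custom_op_name(t))  # expected: MyScope[example]/
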
4 changes: 4 additions & 0 deletions torch_xla/csrc/ir.cpp
@@ -230,4 +230,8 @@ void XlaNode::UpdateShardingHash() {
}
}

void XlaNode::SetCustomOpName(const std::string& op_name) {
custom_op_name_ = op_name;
}

} // namespace torch_xla
5 changes: 5 additions & 0 deletions torch_xla/csrc/ir.h
@@ -146,6 +146,9 @@ class XlaNode : public torch::lazy::Node {
return unbounded_dynamic_dims_;
}

void SetCustomOpName(const std::string& op_name);
const std::string& custom_op_name() const { return custom_op_name_; }

protected:
std::unordered_set<uint32_t> unbounded_dynamic_dims_;

@@ -167,6 +170,8 @@

// Experimental sharding annotations attached to the IR node.
std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;

std::string custom_op_name_;
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
(Diffs for the remaining changed files are not shown.)
