Extend HLO metadata to include class hierarchy information (#5715)
* Add python binding to allow custom op_name metadata for lowered HLO
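The binding is exercised by the CustomOpNameLowering dispatch mode added below, but it can also be called directly. A minimal sketch of the intended flow (assuming an XLA device is available; "Example[scope]/" is an illustrative prefix, not one generated by the library):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla._XLAC._set_ir_debug(True)  # metadata is only lowered in IR debug mode
t = torch.ones(2, 2, device=xm.xla_device())
res = t + 1
# Attach a custom op_name scope to the IR node behind `res`; it is emitted
# into the op metadata when the graph is lowered to HLO.
torch_xla._XLAC._set_xla_custom_op_name(res, "Example[scope]/")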
Showing 10 changed files with 363 additions and 10 deletions.
New file: custom_debug_lowering.py (+213 lines)
import torch
import torch_xla

import inspect
from collections import defaultdict

from torch.utils._python_dispatch import TorchDispatchMode

class_count = defaultdict(int)
instance_count = dict()


def GetInstancePlaceHolder(class_type, obj):
  # Assign a stable per-class ordinal to an object for which no variable
  # name could be found.
  global class_count
  global instance_count

  if (class_type, id(obj)) not in instance_count:
    class_count[class_type] += 1
    instance_count[(class_type, id(obj))] = class_count[class_type]

  place_holder = instance_count[(class_type, id(obj))]

  return f".{place_holder}"


def CheckIgnored(key):
  # Skip dunder attributes and a small denylist of bookkeeping names that
  # never correspond to user-visible module or variable names.
  ignored_list = ("self", "_bootstrap", "_fix_up_module",
                  "_get_supported_file_loaders", "_setup", "_buffers",
                  "_parameters", "_non_persistent_buffers_set")

  return (key.startswith("__") and key.endswith("__")) or key in ignored_list


def Prefix(prefix, val):
  if len(prefix) > 0:
    return f"{prefix}.{val}"
  else:
    return f"{val}"


def ReverseSearchBreadthFirst(container, obj, debug=False):
  # Breadth-first search through nested dicts, lists/tuples and object
  # attributes for `obj`, returning (found, dotted_name_path).
  if container is None:
    return False, None

  queue = []
  visited = set()
  nested_name = ""
  max_depth = 5
  queue.append((0, nested_name, container))

  while len(queue):
    depth, prefix, candidate = queue.pop(0)

    if depth > max_depth or id(candidate) in visited:
      continue

    visited.add(id(candidate))

    if isinstance(candidate, dict):
      for k, v in candidate.items():
        if not isinstance(k, str):
          if debug:
            print(f"Found non string key {k}")
          break
        if CheckIgnored(k):
          continue
        nested_name = Prefix(prefix, k)
        if v is obj:
          if debug:
            print(f"Found {nested_name}")
          return True, nested_name
        elif debug:
          print(f"Miss {nested_name}")
        if id(v) not in visited and depth < max_depth:
          queue.append((depth + 1, nested_name, v))
    elif isinstance(candidate, (list, tuple)):
      for i, v in enumerate(candidate):
        nested_name = Prefix(prefix, i)
        if v is obj:
          if debug:
            print(f"Found {nested_name}")
          return True, nested_name
        elif debug:
          print(f"Miss {nested_name}")
        if id(v) not in visited and depth < max_depth:
          queue.append((depth + 1, nested_name, v))
    elif hasattr(candidate, "__class__"):
      # Ignore classes which override __getattr__ and generate errors.
      if type(candidate).__name__ == "_ClassNamespace":
        continue
      for att in ("_modules", "__dict__"):
        if hasattr(candidate, att):
          v = getattr(candidate, att)
          if id(v) not in visited and depth < max_depth:
            queue.append((depth + 1, nested_name, v))
    else:
      print("No action")

  return False, None
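
# Illustrative call (hypothetical container): the search returns the dotted
# path from the container down to the object.
#   found, name = ReverseSearchBreadthFirst({"model": {"encoder": layer}}, layer)
#   # -> (True, "model.encoder")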


def FindMemberVariable(frame, obj):
  # Walk outward through the calling frames looking for an enclosing `self`
  # whose attributes (directly or nested) reference `obj`.
  parent_frame = frame.f_back
  found = False
  variable_name = None

  for lframe in inspect.getouterframes(parent_frame):
    if lframe.frame.f_code.co_nlocals <= 0:
      continue
    self_name = lframe.frame.f_code.co_varnames[0]
    parent_obj = lframe.frame.f_locals[self_name]
    found, variable_name = ReverseSearchBreadthFirst(parent_obj, obj)
    if found:
      break

  return found, variable_name


def FindLocalVariable(frame, obj):
  # Walk outward through the calling frames looking for a local variable
  # that references `obj`.
  found = False
  variable_name = None

  for lframe in inspect.getouterframes(frame.f_back):
    found, variable_name = ReverseSearchBreadthFirst(lframe.frame.f_locals, obj)
    if found:
      break

  return found, variable_name


def GetClassNameAndObjFromFrame(frame):
  class_obj_str = ""
  if frame.f_code.co_argcount == 0:
    return class_obj_str

  likely_obj_name = frame.f_code.co_varnames[0]

  obj = frame.f_locals[likely_obj_name]

  # Only method frames (first argument named `self`) contribute a class scope.
  if not hasattr(obj, "__class__") or likely_obj_name != "self":
    return class_obj_str

  name = type(obj).__name__
  variable_name = None
  found = False

  found, variable_name = FindMemberVariable(frame, obj)

  if not found:
    found, variable_name = FindLocalVariable(frame, obj)

  if not found:
    variable_name = GetInstancePlaceHolder(name, obj)

  name = name + "[" + variable_name + "]"

  return name


def CleanNames(names):
  # Collapse immediate repeats that arise when several stack frames belong
  # to the same object.
  last_name = ""
  output = []
  for name in names:
    if name != last_name:
      output.append(name)
      last_name = name

  # Drop the last (innermost) scope, which is the op_name lowering helper
  # itself rather than user code.
  return output[:-1]


def GetAllObjectAndClassNames(frame):
  # Walk the Python stack from `frame` outward, collect a Class[variable]
  # scope for each method frame, and join them outermost-first.
  names = []
  while frame is not None:
    name = GetClassNameAndObjFromFrame(frame)
    if len(name) > 0:
      names.append(name)
    frame = frame.f_back

  names.reverse()

  names = CleanNames(names)

  output = "/".join(names)

  if len(output) > 0:
    output += "/"

  return output
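
# For the Sequential model in the unit test below this produces prefixes such
# as "Sequential[model]/Linear[0]/": outermost scope first, "/"-joined, with a
# trailing "/".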


class CustomOpNameLowering(TorchDispatchMode):

  def __init__(self):
    super().__init__()

  def __enter__(self):
    self._old_ir_debug = torch_xla._XLAC._get_ir_debug()
    torch_xla._XLAC._set_ir_debug(True)
    return super().__enter__()

  def __exit__(self, exc_type, exc_val, exc_tb):
    torch_xla._XLAC._set_ir_debug(self._old_ir_debug)
    super().__exit__(exc_type, exc_val, exc_tb)

  def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    res = func(*args, **kwargs)
    if 'xla' in str(res.device):
      # Tag the lowered IR node for `res` with the class/variable scope of
      # the calling stack.
      frame = inspect.currentframe()
      prefix = GetAllObjectAndClassNames(frame)
      torch_xla._XLAC._set_xla_custom_op_name(res, prefix)
    return res
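
A minimal usage sketch, mirroring the unit test below (assumes an XLA device and the LoweringContext binding added in this commit):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from custom_debug_lowering import CustomOpNameLowering

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

with CustomOpNameLowering():
  model = model.to(device=xm.xla_device())
  out = model(torch.rand(4, 4, device=xm.xla_device()))

# Ops traced inside the mode now carry scopes such as
# "Sequential[model]/Linear[0]/" in their HLO metadata op_name.
ctx = torch_xla._XLAC.lowering.LoweringContext()
ctx.build([out])
print(ctx.hlo_json())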
New test file (+90 lines)
import sys

# Normal imports section starts here.
import torch
import torch_xla
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import unittest
import json
from custom_debug_lowering import CustomOpNameLowering


class TestHloMetaData(unittest.TestCase):

  def setUp(self):
    torch.manual_seed(42)
    self.pre_test_tensor_type = torch.get_default_dtype()
    self.pre_test_ir_debug = torch_xla._XLAC._get_ir_debug()
    torch.set_default_tensor_type(torch.FloatTensor)
    torch_xla._XLAC._set_ir_debug(True)
    super(TestHloMetaData, self).setUp()

  def tearDown(self):
    super(TestHloMetaData, self).tearDown()
    torch_xla._XLAC._set_ir_debug(self.pre_test_ir_debug)
    # Restore the default dtype saved in setUp.
    torch.set_default_dtype(self.pre_test_tensor_type)

  def test_metadata(self):
    layer1 = torch.nn.Linear(4, 4)
    nl1 = torch.nn.ReLU()
    layer2 = torch.nn.Linear(4, 2)
    nl2 = torch.nn.Tanh()
    model = torch.nn.Sequential(layer1, nl1, layer2, nl2)

    with CustomOpNameLowering():
      model = model.to(device=xm.xla_device())
      inp = torch.rand(4, 4, device=xm.xla_device())
      out = model(inp)

    ctx = torch_xla._XLAC.lowering.LoweringContext()
    ctx.build([out])
    hlo_text = ctx.hlo_json()

    # Strings to match in the lowering
    bingo = {
        "torch/_ops.py": False,
        #"torch/nn/modules/linear.py": False,
        #"torch/nn/modules/activation.py": False,
        #"torch/nn/functional.py": False,
        "Sequential[model]/Linear[0]": False,
        "Sequential[model]/ReLU[1]": False,
        "Sequential[model]/Linear[2]": False,
        "Sequential[model]/Tanh[3]": False,
        "aten__addmm": False,
        "aten__relu": False,
        "aten__tanh": False,
        "aten__permute": False
    }

    non_zero_metadata = False

    local_json = json.loads(hlo_text)
    assert "computations" in local_json
    for c in local_json["computations"]:
      if "instructions" in c:
        i = c["instructions"]
        for op in i:
          if 'metadata' in op:
            meta = op["metadata"]
            print(meta)
            if len(meta) > 0:
              non_zero_metadata = True
            for km, vm in meta.items():
              for k in bingo.keys():
                if isinstance(vm, str) and k in vm:
                  bingo[k] = True

    assert non_zero_metadata, "No metadata was lowered - an issue with turning on IR DEBUG?"

    for k, v in bingo.items():
      assert v, f"Keyword {k} was not found as expected in HLO metadata for simple test"

    print("All required metadata symbols matched")


if __name__ == '__main__':
  test = unittest.main(exit=False)
  if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
    print(met.metrics_report())
  sys.exit(0 if test.result.wasSuccessful() else 1)