-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bootstrap python hooks for collecting DB information for op tests (#1296
) This commit adds new ttrt API `construct_op_stats_json` that can be used as a common method for all frontends to generate op test sweep data. The format isn't fully established, but this is just serving as the boilerplate to begin wiring up all of the pieces with frontends. Example FE usage: import ttrt.binary @op_test_harness("llama7b") def test_softmax(): return torch.softmax def op_test_harness(model): def wrapper(test_fn): flatbuffer_binary = compile(test_fn) stats = ttrt.binary.stats.construct_op_stats_json( "forge-fe", model, flatbuffer_binary ) aggregate_or_write_to_file(stats) return wrapper Example Op Data: { "op_name": "ttnn.softmax", "framework_op_name": "", "dialect_op_name": "", "ttir_op_name": "", "inputs": [ { "shape": [ 1, 10 ], "data_type": "Float32", "memory_space": "DeviceL1", "memory_layout": "Interleaved", "core_range_set": [ { "loc": { "y": 0, "x": 0 }, "size": { "y": 7, "x": 8 } } ] } ], "outputs": [ { "shape": [ 1, 10 ], "data_type": "Float32", "memory_space": "DeviceDRAM", "memory_layout": "Interleaved", "core_range_set": [ { "loc": { "y": 0, "x": 0 }, "size": { "y": 7, "x": 8 } } ] } ], "attributes": {} },
- Loading branch information
Showing
3 changed files
with
100 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
load_system_desc_from_path, | ||
Flatbuffer, | ||
) | ||
from . import stats | ||
|
||
import json | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from ._C import ( | ||
Binary, | ||
) | ||
import json | ||
import re | ||
|
||
|
||
def as_dict(bin: Binary): | ||
return json.loads(bin.as_json()) | ||
|
||
|
||
def _parse_tensor_ref(ref): | ||
if type(ref) == list: | ||
tensors = (_parse_tensor_ref(r) for r in ref) | ||
return [t for l in tensors for t in l] | ||
return [ | ||
{ | ||
"shape": ref["desc"]["shape"], | ||
"data_type": ref["desc"]["layout"]["memory_desc"]["data_type"], | ||
"memory_space": ref["desc"]["layout"]["memory_desc"]["memory_space"], | ||
"memory_layout": ref["desc"]["layout"]["memory_desc"]["memory_layout"], | ||
"core_range_set": ref["desc"]["layout"]["core_range_set"], | ||
} | ||
] | ||
|
||
|
||
def _parse_inputs_outputs(operation): | ||
inputs = [] | ||
outputs = [] | ||
if operation["type_type"] == "GetDeviceOp": | ||
return inputs, outputs | ||
for k, v in operation["type"].items(): | ||
if k.startswith("in"): | ||
inputs.extend(_parse_tensor_ref(v)) | ||
elif k.startswith("out"): | ||
outputs.extend(_parse_tensor_ref(v)) | ||
return inputs, outputs | ||
|
||
|
||
def _parse_attributes(operation): | ||
attributes = {} | ||
return attributes | ||
|
||
|
||
def collect_op_stats(bin: Binary): | ||
assert bin.file_identifier == "TTNN", "Only supports TTNN binary files" | ||
d = as_dict(bin) | ||
program_index = 0 | ||
operations = [] | ||
|
||
pattern = re.compile(r"(?<!^)(?=[A-Z])") | ||
|
||
def to_ttnn_name(name): | ||
return "ttnn." + pattern.sub("_", name).lower().strip("_op") | ||
|
||
for operation in d["programs"][program_index]["operations"]: | ||
inputs, outputs = _parse_inputs_outputs(operation) | ||
operations.append( | ||
{ | ||
"op_name": to_ttnn_name(operation["type_type"]), | ||
"framework_op_name": "", | ||
"dialect_op_name": "", | ||
"ttir_op_name": "", | ||
"inputs": inputs, | ||
"outputs": outputs, | ||
"attributes": _parse_attributes(operation), | ||
} | ||
) | ||
|
||
return operations | ||
|
||
|
||
def construct_op_stats_json(frontend: str, model: str, bin: Binary): | ||
op_stats = collect_op_stats(bin) | ||
return json.dumps( | ||
{ | ||
"frontend": frontend, | ||
"model": model, | ||
"operations": op_stats, | ||
}, | ||
indent=4, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters