Skip to content

Commit

Permalink
Bootstrap python hooks for collecting DB information for op tests (#1296
Browse files Browse the repository at this point in the history
)

This commit adds new ttrt API `construct_op_stats_json` that can be used
as a common method for all frontends to generate op test sweep data.
The format isn't fully established, but this is just serving as the
boilerplate to begin wiring up all of the pieces with frontends.

Example FE usage:

    import ttrt.binary

    @op_test_harness("llama7b")
    def test_softmax():
        return torch.softmax

    def op_test_harness(model):
        def wrapper(test_fn):
            flatbuffer_binary = compile(test_fn)
            stats = ttrt.binary.stats.construct_op_stats_json(
              "forge-fe",
              model,
              flatbuffer_binary
            )
            aggregate_or_write_to_file(stats)

        return wrapper

Example Op Data:

 {
    "op_name": "ttnn.softmax",
    "framework_op_name": "",
    "dialect_op_name": "",
    "ttir_op_name": "",
    "inputs": [
      {
        "shape": [ 1, 10 ],
        "data_type": "Float32",
        "memory_space": "DeviceL1",
        "memory_layout": "Interleaved",
        "core_range_set": [ { "loc": { "y": 0, "x": 0 }, "size": { "y": 7, "x": 8 } } ]
      }
    ],
    "outputs": [
      {
        "shape": [ 1, 10 ],
        "data_type": "Float32",
        "memory_space": "DeviceDRAM",
        "memory_layout": "Interleaved",
        "core_range_set": [ { "loc": { "y": 0, "x": 0 }, "size": { "y": 7, "x": 8 } } ]
      }
    ],
    "attributes": {}
  },
  • Loading branch information
nsmithtt authored Nov 19, 2024
1 parent f60d2f5 commit 79f0ab1
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
1 change: 1 addition & 0 deletions runtime/tools/python/ttrt/binary/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
load_system_desc_from_path,
Flatbuffer,
)
from . import stats

import json

Expand Down
86 changes: 86 additions & 0 deletions runtime/tools/python/ttrt/binary/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from ._C import (
Binary,
)
import json
import re


def as_dict(bin: Binary):
return json.loads(bin.as_json())


def _parse_tensor_ref(ref):
if type(ref) == list:
tensors = (_parse_tensor_ref(r) for r in ref)
return [t for l in tensors for t in l]
return [
{
"shape": ref["desc"]["shape"],
"data_type": ref["desc"]["layout"]["memory_desc"]["data_type"],
"memory_space": ref["desc"]["layout"]["memory_desc"]["memory_space"],
"memory_layout": ref["desc"]["layout"]["memory_desc"]["memory_layout"],
"core_range_set": ref["desc"]["layout"]["core_range_set"],
}
]


def _parse_inputs_outputs(operation):
inputs = []
outputs = []
if operation["type_type"] == "GetDeviceOp":
return inputs, outputs
for k, v in operation["type"].items():
if k.startswith("in"):
inputs.extend(_parse_tensor_ref(v))
elif k.startswith("out"):
outputs.extend(_parse_tensor_ref(v))
return inputs, outputs


def _parse_attributes(operation):
attributes = {}
return attributes


def collect_op_stats(bin: Binary):
assert bin.file_identifier == "TTNN", "Only supports TTNN binary files"
d = as_dict(bin)
program_index = 0
operations = []

pattern = re.compile(r"(?<!^)(?=[A-Z])")

def to_ttnn_name(name):
return "ttnn." + pattern.sub("_", name).lower().strip("_op")

for operation in d["programs"][program_index]["operations"]:
inputs, outputs = _parse_inputs_outputs(operation)
operations.append(
{
"op_name": to_ttnn_name(operation["type_type"]),
"framework_op_name": "",
"dialect_op_name": "",
"ttir_op_name": "",
"inputs": inputs,
"outputs": outputs,
"attributes": _parse_attributes(operation),
}
)

return operations


def construct_op_stats_json(frontend: str, model: str, bin: Binary):
op_stats = collect_op_stats(bin)
return json.dumps(
{
"frontend": frontend,
"model": model,
"operations": op_stats,
},
indent=4,
)
13 changes: 13 additions & 0 deletions runtime/tools/python/ttrt/common/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Read:
"cpp",
"inputs",
"outputs",
"op_stats",
]

@staticmethod
Expand Down Expand Up @@ -430,6 +431,18 @@ def outputs(self, binary):
except Exception as e:
raise Exception(f"failed to read outputs for binary={binary.file_path}")

def op_stats(self, binary):
try:
import ttrt.binary

op_stats = ttrt.binary.stats.collect_op_stats(binary.fbb)
self.logging.info(f"\n{json.dumps(op_stats, indent=2)}")

except Exception as e:
raise Exception(
f"failed to read operator_stats for binary={binary.file_path} with exception={str(e)}"
)

@staticmethod
def register_arg(name, type, default, choices, help):
Read.registered_args[name] = {
Expand Down

0 comments on commit 79f0ab1

Please sign in to comment.