[hexagon][testing] refactor benchmark-table code #11400

Merged (1 commit) on May 26, 2022
149 changes: 39 additions & 110 deletions tests/python/contrib/test_hexagon/benchmark_hexagon.py
@@ -17,17 +17,16 @@

import os
import os.path
import pathlib
import sys
import pytest
import numpy as np
import logging
import tempfile
import csv

import tvm.testing
from tvm import te
from tvm.contrib.hexagon.build import HexagonLauncherRPC
from .benchmark_util import BenchmarksTable

RPC_SERVER_PORT = 7070

@@ -58,112 +57,22 @@ def test_elemwise_add(hexagon_launcher: HexagonLauncherRPC):
print("-" * 80)
print()

# TODO: We should move this into a separate test fixture, to make it easier to write
# additional benchmarking functions. We'd just need to generalize the assumptions regarding
# the particular fields being tracked as independent variables.
class benchmark_results_collection:
def __init__(self):
self.row_dicts_ = []

def num_failures(self):
num = 0
for d in self.row_dicts_:
if d["status"] == "FAIL":
num += 1
return num

def num_skips(self):
num = 0
for d in self.row_dicts_:
if d["status"] == "SKIP":
num += 1
return num

def record_success(
self, dtype, sched_type, mem_scope, num_vecs_per_tensor, benchmark_result
):
median_usec = benchmark_result.median * 1000000
min_usec = benchmark_result.min * 1000000
max_usec = benchmark_result.max * 1000000

self.row_dicts_.append(
{
"dtype": dtype,
"sched_type": sched_type,
"mem_scope": mem_scope,
"num_vecs_per_tensor": num_vecs_per_tensor,
"status": "OK",
"median(µsec)": f"{median_usec:.3}",
"min(µsec)": f"{min_usec:.3}",
"max(µsec)": f"{max_usec:.3}",
}
)

def record_failure(self, dtype, sched_type, mem_scope, num_vecs_per_tensor, error_text):
self.row_dicts_.append(
{
"dtype": dtype,
"sched_type": sched_type,
"mem_scope": mem_scope,
"num_vecs_per_tensor": num_vecs_per_tensor,
"status": "FAIL",
"comment": error_text,
}
)

def record_skip(self, dtype, sched_type, mem_scope, num_vecs_per_tensor, comment_text):
self.row_dicts_.append(
{
"dtype": dtype,
"sched_type": sched_type,
"mem_scope": mem_scope,
"num_vecs_per_tensor": num_vecs_per_tensor,
"status": "SKIP",
"comment": comment_text,
}
)

def dump(self, f):
csv.register_dialect(
"benchmarks",
delimiter="\t",
quotechar='"',
quoting=csv.QUOTE_MINIMAL,
)

fieldnames = [
"dtype",
"sched_type",
"mem_scope",
"num_vecs_per_tensor",
"status",
"median(µsec)",
"min(µsec)",
"max(µsec)",
"comment",
]

writer = csv.DictWriter(f, fieldnames, dialect="benchmarks", restval="")

writer.writeheader()
for d in self.row_dicts_:
writer.writerow(d)

br = benchmark_results_collection()
bt = BenchmarksTable()

# Create and benchmark a single primfunc.
# If an unexpected problem occurs, raise an exception. Otherwise add a row of output to 'br'.
# If an unexpected problem occurs, raise an exception. Otherwise add a row of output to 'bt'.
def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
version_name = f"dtype:{dtype}-schedtype:{sched_type}-memscope:{mem_scope}-numvecs:{num_vectors_per_tensor}"
print()
print(f"CONFIGURATION: {version_name}")

if num_vectors_per_tensor == 2048 and mem_scope == "global.vtcm":
br.record_skip(
dtype,
sched_type,
mem_scope,
num_vectors_per_tensor,
f"Expect to exceed VTCM budget.",
bt.record_skip(
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
num_vectors_per_tensor=num_vectors_per_tensor,
comments="Expect to exceed VTCM budget.",
)
return

@@ -255,25 +164,45 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
timer = mod.time_evaluator("elemwise_add", sess.device, number=10, repeat=1)
timing_result = timer(A_data, B_data, C_data)

print("TIMING RESULT: {}".format(timing_result))

# Verify that the computation actually happened, and produced the correct result.
result = C_data.numpy()
tvm.testing.assert_allclose(host_numpy_C_data_expected, result)

br.record_success(
dtype, sched_type, mem_scope, num_vectors_per_tensor, timing_result
bt.record_success(
timing_result,
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
num_vectors_per_tensor=num_vectors_per_tensor,
)

except Exception as err:
f.write("ERROR:\n")
f.write("{}\n".format(err))
br.record_failure(
dtype, sched_type, mem_scope, num_vectors_per_tensor, f"See {report_path}"
bt.record_fail(
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
num_vectors_per_tensor=num_vectors_per_tensor,
comments=f"See {report_path}",
)

# -----------------------------------------------------------------------------------------------

csv_column_order = [
"dtype",
"sched_type",
"mem_scope",
"num_vectors_per_tensor",
"row_status",
"timings_min_usecs",
"timings_max_usecs",
"timings_median_usecs",
"timings_mean_usecs",
"timings_stddev_usecs",
"comments",
]

# Hexagon v69 allows more dtypes, but we're sticking with v68 for now.
for dtype in [
"int8",
@@ -300,7 +229,7 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor)

# Report our progress.
br.dump(sys.stdout)
bt.print_csv(sys.stdout, csv_column_order)

print("-" * 80)
print(f"OUTPUT DIRECTORY: {host_output_dir}")
@@ -309,8 +238,8 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):

tabular_output_filename = os.path.join(host_output_dir, "benchmark-results.csv")
with open(tabular_output_filename, "w") as csv_file:
br.dump(csv_file)
bt.print_csv(csv_file, csv_column_order)
print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}")

if br.num_failures() > 0:
if bt.has_fail():
pytest.fail("At least one benchmark configuration failed", pytrace=False)
141 changes: 141 additions & 0 deletions tests/python/contrib/test_hexagon/benchmark_util.py
@@ -0,0 +1,141 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import csv


class BenchmarksTable:
"""
Stores/reports the result of benchmark runs.

Each line item has a status: success, fail, or skip.

Each 'success' line item must include benchmark data,
in the form provided by TVM's `time_evaluator` mechanism.

Each line item may also supply values for arbitrary user-defined
columns, passed as keyword arguments to the `record_*` methods;
`print_csv` decides which of those columns are actually emitted.
"""

BUILTIN_COLUMN_NAMES = set(
[
"row_status",
"timings_min_usecs",
"timings_max_usecs",
"timings_median_usecs",
"timings_mean_usecs",
"timings_stddev_usecs",
]
)

def __init__(self):
self._line_items = []

def validate_user_supplied_kwargs(self, kwarg_dict):
name_conflicts = set(kwarg_dict).intersection(self.BUILTIN_COLUMN_NAMES)

if name_conflicts:
name_list = ", ".join(name_conflicts)
raise Exception(f"Attempting to supply values for built-in column names: {name_list}")

def record_success(self, timings, **kwargs):
"""
`timings` : Assumed to have the structure and meaning of
the timing results provided by TVM's `time_evaluator`
mechanism.

`kwargs` : Optional values for any of the other columns
defined for this benchmark table.
"""
self.validate_user_supplied_kwargs(kwargs)
line_item = kwargs

line_item["row_status"] = "SUCCESS"

line_item["timings_min_usecs"] = timings.min * 1000000
line_item["timings_max_usecs"] = timings.max * 1000000
line_item["timings_median_usecs"] = timings.median * 1000000
line_item["timings_stddev_usecs"] = timings.std * 1000000
line_item["timings_mean_usecs"] = timings.mean * 1000000

self._line_items.append(line_item)

def record_skip(self, **kwargs):
self.validate_user_supplied_kwargs(kwargs)

line_item = dict(kwargs)
line_item["row_status"] = "SKIP"
self._line_items.append(line_item)

def record_fail(self, **kwargs):
self.validate_user_supplied_kwargs(kwargs)

line_item = dict(kwargs)
line_item["row_status"] = "FAIL"
self._line_items.append(line_item)

def has_fail(self):
"""
Returns True if the table contains at least one 'fail' line item,
otherwise returns False.
"""
return any(item["row_status"] == "FAIL" for item in self._line_items)

def print_csv(self, f, column_name_order, timing_decimal_places=3):
"""
Print the benchmark results as a csv.

`f` : The output stream.

`column_name_order`: an iterable sequence of column names, indicating the
left-to-right ordering of columns in the CSV output.

The CSV output will contain only those columns that are mentioned in
this list.

`timing_decimal_places`: for the numeric timing values, this is the
number of decimal places to provide in the printed output.
For example, a value of 3 is equivalent to the Python formatting string
`'{:.3f}'`
"""
writer = csv.DictWriter(
f, column_name_order, dialect="excel-tab", restval="", extrasaction="ignore"
)

writer.writeheader()

for line_item_dict in self._line_items:
# Use a copy of the line-item dictionary, because we might do some modifications
# for the sake of rendering...
csv_line_dict = dict(line_item_dict)

for col_name in [
"timings_min_usecs",
"timings_max_usecs",
"timings_median_usecs",
"timings_stddev_usecs",
"timings_mean_usecs",
]:
if col_name in csv_line_dict:
old_value = csv_line_dict[col_name]
assert isinstance(
old_value, float
), f"Formatting code assumes that column {col_name} is some col_nameind of float, but its actual type is {type(old_value)}"
str_value = f"{old_value:>0.{timing_decimal_places}f}"
csv_line_dict[col_name] = str_value

writer.writerow(csv_line_dict)
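A small rendering sketch for the class above, assuming this module is importable as `benchmark_util` (in the TVM tree it lives under tests/python/contrib/test_hexagon/). Values are illustrative; only the columns listed in `column_name_order` appear in the output, and cells with no value are left blank.

    import sys
    from types import SimpleNamespace

    from benchmark_util import BenchmarksTable

    bt = BenchmarksTable()

    # A successful row; the timings object mimics a time_evaluator result (seconds).
    bt.record_success(
        SimpleNamespace(min=1.0e-5, max=2.0e-5, median=1.4e-5, mean=1.5e-5, std=1.0e-6),
        dtype="int8",
        mem_scope="global",
        num_vectors_per_tensor=64,
    )

    # A failed row carries no timing columns, only user-defined ones.
    bt.record_fail(
        dtype="float16",
        mem_scope="global.vtcm",
        num_vectors_per_tensor=512,
        comments="See the per-configuration report for details.",
    )

    column_order = [
        "dtype",
        "mem_scope",
        "num_vectors_per_tensor",
        "row_status",
        "timings_median_usecs",
        "comments",
    ]

    # Tab-separated output; timing columns are rendered with 3 decimal places.
    bt.print_csv(sys.stdout, column_order, timing_decimal_places=3)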