From f985e9d110c778a2a901b3caf0281f0e6752dae8 Mon Sep 17 00:00:00 2001
From: Janani Sriram
Date: Thu, 20 Jun 2024 14:46:24 -0700
Subject: [PATCH] Add unit tests on CPU for TritonBench features (#2323)

Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2323

Add unit tests that run on the CPU to verify the behavior of the following:
- `x_only = True` for metric registration in [`register_metric()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=337)
- custom `label` argument for benchmark registration in [`register_benchmark()`](https://www.internalfb.com/code/fbsource/[731f07681fbbb38750aee3b165137e39fa6cee50]/fbcode/pytorch/benchmark/torchbenchmark/util/triton_op.py?lines=316)

Reviewed By: xuzhao9

Differential Revision: D58558868
---
 torchbenchmark/operators/test_op/__init__.py |  1 +
 torchbenchmark/operators/test_op/operator.py | 44 ++++++++++++++++++++
 torchbenchmark/util/triton_op.py             |  1 +
 3 files changed, 46 insertions(+)
 create mode 100644 torchbenchmark/operators/test_op/__init__.py
 create mode 100644 torchbenchmark/operators/test_op/operator.py

diff --git a/torchbenchmark/operators/test_op/__init__.py b/torchbenchmark/operators/test_op/__init__.py
new file mode 100644
index 0000000000..a77a295cc4
--- /dev/null
+++ b/torchbenchmark/operators/test_op/__init__.py
@@ -0,0 +1 @@
+from .operator import Operator
diff --git a/torchbenchmark/operators/test_op/operator.py b/torchbenchmark/operators/test_op/operator.py
new file mode 100644
index 0000000000..bf149aaeb8
--- /dev/null
+++ b/torchbenchmark/operators/test_op/operator.py
@@ -0,0 +1,44 @@
+from typing import Generator, List, Optional
+
+import torch
+
+from torchbenchmark.util.triton_op import (
+    BenchmarkOperator,
+    BenchmarkOperatorMetrics,
+    register_benchmark,
+    register_metric,
+)
+
+
+class Operator(BenchmarkOperator):
+
+    DEFAULT_METRICS = ["test_metric"]
+
+    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
+        super().__init__(mode=mode, device=device, extra_args=extra_args)
+
+    @register_benchmark(label="new_op_label")
+    def test_op(self, x: torch.Tensor):
+        return lambda: x
+
+    def get_x_val(self, example_inputs):
+        return example_inputs[0].shape
+
+    def get_x_vals(self) -> List[int]:
+        return [2**n for n in [1, 2, 3]]
+
+    def get_input_iter(self) -> Generator:
+        for x in self.get_x_vals():
+            yield (torch.Tensor(torch.randn(x, device=self.device, dtype=self.dtype)),)
+
+    @register_metric(x_only=True)
+    def test_metric(
+        self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
+    ):
+        return [ex.shape[0] + 2 for ex in example_inputs]
+
+    @register_metric()
+    def test_metric_per_benchmark(
+        self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
+    ):
+        return [ex.shape[0] + 3 for ex in example_inputs]
diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index fdcd14729a..1e3d1783f2 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -468,6 +468,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None)
         self._only = _split_params_by_comma(self.tb_args.only)
         self._input_id = self.tb_args.input_id
         self._num_inputs = self.tb_args.num_inputs
+        self.device = device
 
     # Run the post initialization
     def __post__init__(self):