forked from pytorch/benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add sum reduction operator to TritonBench (pytorch#2282)
Summary: Add a Triton reduction kernel for the `sum` operator where `dim=None` to TritonBench, following the [TritonBench guide](https://fb.workplace.com/notes/953949486404240). This implementation works for all matrices being reduced to a scalar value. To measure accuracy of Triton reduction kernel, add accuracy metric to sum kernel in TritonBench in order to test accuracy of Triton implementation against baseline PyTorch implementation, referencing [`torchbenchmark/operators/gemm/operator.py`](https://www.internalfb.com/code/fbsource/[767bb6faa353685b84f08a39f36fdcf6ca170c85]/fbcode/pytorch/benchmark/torchbenchmark/operators/gemm/operator.py?lines=236). Reset output registers per run of the Triton kernel for accurate Triton output. Referenced the existing [vector_add](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/vector_add/) and [grouped_gemm](https://www.internalfb.com/code/fbsource/fbcode/pytorch/benchmark/torchbenchmark/operators/grouped_gemm/) TritonBench operators as frameworks for implementation. See the [TritonBench Operator Coverage Tracker](https://docs.google.com/spreadsheets/d/1091POOPSPsUnlNVEKaz2X_DQXdIwFv-fGOH_g9by-Zo/edit#gid=0) for current operator coverage in TritonBench. Reviewed By: xuzhao9, davidberard98 Differential Revision: D58048782
- Loading branch information
1 parent
f7b4bcc
commit c331f9c
Showing
3 changed files
with
125 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .operator import Operator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
|
||
|
||
@triton.jit | ||
def triton_sum_kernel_scalar( | ||
input_ptr, | ||
output_ptr, | ||
M, # number of elements | ||
BLOCK_SIZE_M: tl.constexpr, # number of elements per block | ||
): | ||
pid = tl.program_id(axis=0) # i-th block of input | ||
|
||
block_start = pid * BLOCK_SIZE_M | ||
# offsets have shape equal to input shape | ||
offsets = block_start + tl.arange(0, BLOCK_SIZE_M) # create 1D vector (input shape) ranging from beginning to end of this program's block | ||
|
||
# mask has shape equal to input shape | ||
mask = offsets < M # mask out offsets that are out of bounds for input | ||
|
||
# loaded pointers have shape equal to input shape | ||
x = tl.load(input_ptr + offsets, mask=mask, other=mask) # load input, where the loaded pointers are in the desired input shape | ||
|
||
output = tl.sum(x) | ||
|
||
# output_offsets have shape equal to output shape | ||
output_offsets = tl.arange(0, 1) # create offsets for scalar output pointer (output shape == (1,)) | ||
|
||
# stored pointers have shape equal to output shape | ||
tl.store(output_ptr + output_offsets, output) # store output, where the stored pointers are in the desired output shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import argparse | ||
from typing import Callable, Generator, List, Optional, Tuple | ||
|
||
import torch | ||
import triton | ||
import triton.language as tl | ||
from torchbenchmark.util.triton_op import ( | ||
BenchmarkOperator, | ||
BenchmarkOperatorMetrics, | ||
register_benchmark, | ||
register_metric, | ||
) | ||
|
||
from .kernels import triton_sum_kernel_scalar | ||
|
||
|
||
class Operator(BenchmarkOperator): | ||
|
||
DEFAULT_METRICS = ["latency", "accuracy"] | ||
|
||
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None): | ||
super().__init__(mode=mode, device=device, extra_args=extra_args) | ||
self.sizes = range(1, 17) | ||
|
||
@register_benchmark() | ||
def triton_sum(self, x: torch.Tensor): | ||
x_1d = x.view(-1) | ||
M = x_1d.shape[0] | ||
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),) | ||
BLOCK_SIZE_M = triton.next_power_of_2(M) # race condition in cases where BLOCK_SIZE < n_elements^2 | ||
|
||
def _inner(): | ||
output = torch.zeros(1, device=x.device, dtype=x.dtype) | ||
|
||
triton_sum_kernel_scalar[grid]( | ||
x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
) | ||
|
||
return output | ||
|
||
return _inner | ||
|
||
@register_benchmark(baseline=True) | ||
def torch_sum(self, x: torch.Tensor): | ||
result = torch.sum(x) | ||
return lambda: result | ||
|
||
def get_x_val(self, example_inputs): | ||
return len(example_inputs[0]) | ||
|
||
def get_x_vals(self) -> List[int]: | ||
x_vals = [] | ||
|
||
x_vals.extend([2**n for n in self.sizes]) | ||
x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0]) | ||
|
||
return x_vals | ||
|
||
def get_input_iter(self) -> Generator: | ||
# reduce to a scalar value | ||
for size in self.get_x_vals(): # 1D matrix | ||
input_1d = torch.randn(size, device=self.device, dtype=self.dtype) | ||
yield (input_1d, ) | ||
|
||
for size in self.get_x_vals(): # 2D matrix | ||
if size < pow(2, 8): # ensure we don't exceed floating point limitations | ||
input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype) | ||
yield (input_2d, ) | ||
|
||
for size in self.get_x_vals(): # 3D matrix | ||
if size < pow(2, 4): # ensure we don't exceed floating point limitations | ||
input_2d = torch.randn((size, size, size), device=self.device, dtype=self.dtype) | ||
yield (input_2d, ) | ||
|
||
def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool: | ||
output = fn() | ||
baseline_output = baseline_fn() | ||
return torch.allclose(output, baseline_output, atol=1e-4) | ||
|
||
@register_metric(skip_baseline=True) | ||
def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics): | ||
return [ex.dim() for ex in example_inputs] | ||
|
||
@register_metric() | ||
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics): | ||
gbps = ( | ||
lambda ms: 3 | ||
* example_inputs[0].element_size() | ||
* example_inputs[0].numel() | ||
/ ms | ||
* 1e-6 | ||
) | ||
return list(map(gbps, metrics.latency if metrics.latency else [0])) |