Add jagged_sum operator for padded nested tensors to TritonBench #2305

Closed (wants to merge 1 commit)
Changes from all commits
101 changes: 60 additions & 41 deletions torchbenchmark/operators/jagged_sum/operator.py
@@ -1,4 +1,5 @@
 import argparse
+import itertools
 import math
 import random
 from typing import Callable, Generator, List, Optional, Tuple
@@ -14,19 +15,21 @@
     register_metric,
 )
 
-random.seed(16)
-torch.manual_seed(16)
+seed = 16
+random.seed(seed)
+torch.manual_seed(seed)
 
 GIGABYTES_PER_BYTE = 1e-6
 RANDOM_CHOICE_MARGIN = 0.3
+ABSOLUTE_TOLERANCE = 1e-3
 
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--seqlen",
         type=int,
-        default=100,
+        default=500,
         help="Maximum sequence length on ragged dimension (integer)",
     )
     parser.add_argument(
@@ -40,6 +43,9 @@ def parse_op_args(args: List[str]):
 
 class Operator(BenchmarkOperator):
 
+    DEFAULT_METRICS = ["latency", "accuracy"]
+    use_cuda_graphs = False  # enables GPU/CPU sync (for methods like NestedTensor unbind)
+
     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
         self.sizes = range(4, 10, 2)
@@ -58,6 +64,17 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor):
             dtype=self.dtype,
         )
 
+    @register_benchmark()
+    def torch_jagged_sum_pad(self, x: torch.Tensor):
+        return lambda: torch.sum(
+            torch.ops.aten._jagged_to_padded_dense_forward(
+                x.values(),
+                [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
+                max_lengths=[self.seqlen],  # max length of ragged dimension
+            ),
+            dim=1,
+        )  # sum along ragged dimension (dim == 1)
+
     def get_x_val(self, example_inputs):
         return len(example_inputs[0])
 
@@ -90,50 +107,52 @@ def get_input_iter(self) -> Generator:
         """
 
         B_vals, M_vals = self.get_x_vals()
-
-        for B in B_vals:
-            for M in M_vals:
-                tensors = []
-
-                # greater sparsity --> shorter sequence lengths on ragged dimension
-                seqlen_avg = math.floor(
-                    self.seqlen * (1 - self.sparsity)
-                )  # average sequence length across all tensors in nested tensor
-                seqlen_margin = math.floor(
-                    self.seqlen * RANDOM_CHOICE_MARGIN
-                )  # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
-
-                for _ in range(B):
-                    seqlen_randint = random.randint(
-                        max(seqlen_avg - seqlen_margin, 1),
-                        min(seqlen_avg + seqlen_margin, self.seqlen),
-                    )
-                    tensor_2d = torch.randn(
-                        (seqlen_randint, M), device=self.device, dtype=self.dtype
-                    )
-                    tensors.append(tensor_2d)
-
-                nt = torch.nested.nested_tensor(
-                    tensors,
-                    layout=torch.jagged,
-                    device=self.device,
-                    dtype=self.dtype,
+        B_M_vals = itertools.product(B_vals, M_vals)
+
+        for B, M in B_M_vals:
+            tensors = []
+
+            # greater sparsity --> shorter sequence lengths on ragged dimension
+            seqlen_avg = math.floor(
+                self.seqlen * (1 - self.sparsity)
+            )  # average sequence length across all tensors in nested tensor
+            seqlen_margin = math.floor(
+                self.seqlen * RANDOM_CHOICE_MARGIN
+            )  # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
+
+            for _ in range(B):
+                seqlen_randint = random.randint(
+                    max(
+                        seqlen_avg - seqlen_margin, 1
+                    ),  # seqlen_randint must be at least 1
+                    min(
+                        seqlen_avg + seqlen_margin, self.seqlen
+                    ),  # seqlen_randint must not exceed self.seqlen
                 )
+                tensor_2d = torch.randn(
+                    (seqlen_randint, M), device=self.device, dtype=self.dtype
+                )
+                tensors.append(tensor_2d)
 
-                yield (nt,)
+            nt = torch.nested.nested_tensor(
+                tensors,
+                layout=torch.jagged,
+                device=self.device,
+                dtype=self.dtype,
+            )
 
-    @register_metric()
-    def B_M(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
-        return tuple([(ex.size(0), ex.size(2)) for ex in example_inputs])[
-            0
-        ]  # return (B, M) for each example input
+            yield (nt,)
+
+    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
+        output = fn()
+        baseline_output = baseline_fn()
+        return torch.allclose(output, baseline_output, atol=ABSOLUTE_TOLERANCE)
 
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
-        gbps = (
-            lambda ms: example_inputs[0].element_size()
+        return (
+            example_inputs[0].element_size()
             * example_inputs[0].numel()
-            / ms
+            / metrics.latency
            * GIGABYTES_PER_BYTE
         )
-        return list(map(gbps, metrics.latency if metrics.latency else [0]))
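
For reviewers who want to sanity-check the new benchmark path outside the harness, below is a minimal standalone sketch (not part of this PR) of what torch_jagged_sum_pad computes, compared against an unbind-based reference in the spirit of _get_accuracy. The batch shapes, max length, and CUDA device are illustrative assumptions; the operator call mirrors the one in the diff above.

import torch

# Illustrative values only; the benchmark derives these from its input iterator.
device, dtype, max_seqlen = "cuda", torch.float32, 8

# Build a small jagged NestedTensor: three 2D tensors, ragged along dim 0.
tensors = [torch.randn(n, 4, device=device, dtype=dtype) for n in (3, 5, 2)]
nt = torch.nested.nested_tensor(
    tensors, layout=torch.jagged, device=device, dtype=dtype
)

# Padded path (what torch_jagged_sum_pad benchmarks): pad each ragged row to
# max_seqlen with zeros, then sum along the now-dense ragged dimension.
padded = torch.ops.aten._jagged_to_padded_dense_forward(
    nt.values(), [nt.offsets()], max_lengths=[max_seqlen]
)
out_pad = torch.sum(padded, dim=1)  # shape (B, M)

# Reference path: sum each component tensor along its ragged dimension.
out_ref = torch.stack([t.sum(dim=0) for t in nt.unbind()])

# Zero padding does not change the sums, so both paths should agree.
assert torch.allclose(out_pad, out_ref, atol=1e-3)  # ABSOLUTE_TOLERANCE above

The gbps metric converts the bytes of the jagged input (element_size() * numel()) and the measured latency into GB/s; assuming latency is reported in milliseconds, bytes / ms * 1e-6 gives GB/s, which is consistent with GIGABYTES_PER_BYTE being set to 1e-6.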