Add batch_index_select_dim0 (w/ TBE backend) (#1897)
Summary:
Pull Request resolved: #1897

This diff introduces `batch_index_select_dim0` using the `SplitTBE`
implementation (it shares the same code generator as TBE).  The new
operator is designed to address limitations of
`group_index_select_dim0`.  Both operators are designed to operate on
multiple inputs.  However, `batch_index_select_dim0` requires all
inputs to be contiguous in memory, while `group_index_select_dim0` can
operate on inputs with a discrete memory layout.  Implementation-wise,
they are different; we plan to merge their backends in the future.

Since `batch_index_select_dim0` is backed by TBE, it inherits TBE's
limitations, including:
- The column sizes must be multiples of 4 and must not exceed 1024.
  Moreover, the underlying buffer of the inputs tensor must be 16-byte
  aligned, because the TBE kernel uses vector loads/stores that require
  16-byte alignment.  The kernel raises an error if this assumption is
  violated (see the sketch after this list).
- Due to the 16-byte alignment requirement, if the output gradient is
  not 16-byte aligned during the backward pass, the operator copies it
  into a new 16-byte-aligned buffer.  This can be expensive if the
  output gradient is large.
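
As a concrete illustration of the column-size and alignment constraints, the sketch below pads a 2D FP32 input so its column count is a multiple of 4 (16 bytes per row) before it is flattened and concatenated.  The helper name is hypothetical (it is not part of FBGEMM), and the sketch assumes FP32 inputs; a freshly allocated buffer (e.g., the result of `torch.cat`) is typically sufficiently aligned, whereas a view into a larger tensor may not be.

```
import torch

def pad_columns_to_multiple_of_4(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: right-pad the column dimension of a 2D FP32 tensor
    # with zeros so its width is a multiple of 4 (i.e., 16 bytes per row),
    # which is one way to satisfy the TBE column-size constraint above.
    # The padded columns appear in the output and can be sliced off there.
    rows, cols = x.shape
    padded_cols = (cols + 3) // 4 * 4
    if padded_cols == cols:
        return x
    out = x.new_zeros(rows, padded_cols)
    out[:, :cols] = x
    return out
```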

Usage:

```
# This target might change in the future
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")

...

output = torch.ops.fbgemm.batch_index_select_dim0(
            inputs, # Tensor - 1D tensor (concatenated flattened inputs)
            indices, # Tensor - 1D tensor (concatenated indices)
            input_num_indices, # List[int]
            input_rows, # List[int]
            input_columns, # List[int]
         )
```
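
For additional context, below is a hedged end-to-end sketch (variable names are illustrative) of assembling the arguments above from a list of 2D tensors and splitting the flat output back into per-input results.  It assumes the `load_library` call above has already run and that the output is the per-input `[num_indices_i, columns_i]` results, flattened and concatenated in input order, consistent with how the benchmark added in this diff pairs gradients with outputs.

```
import torch

# Illustrative inputs: four 2D FP32 tensors and per-input index tensors.
inputs = [torch.rand(100, 128, device="cuda") for _ in range(4)]
indices = [torch.randint(0, 100, (2048,), device="cuda") for _ in range(4)]

input_rows = [inp.size(0) for inp in inputs]
input_columns = [inp.size(1) for inp in inputs]
input_num_indices = [idx.numel() for idx in indices]

# Flatten and concatenate, mirroring the benchmark added below.
concat_inputs = torch.concat([inp.flatten() for inp in inputs])
concat_indices = torch.concat(indices)

flat_out = torch.ops.fbgemm.batch_index_select_dim0(
    concat_inputs,
    concat_indices,
    input_num_indices,
    input_rows,
    input_columns,
)

# Assumption: split the flat output into per-input [num_indices_i, columns_i]
# blocks, in the same order as the inputs.
sizes = [n * c for n, c in zip(input_num_indices, input_columns)]
outputs = [
    chunk.view(n, c)
    for chunk, n, c in zip(flat_out.split(sizes), input_num_indices, input_columns)
]
```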

Differential Revision: D46084590

fbshipit-source-id: ef02e43cff10311b29bff3d351839ac9fde13ddf
sryap authored and facebook-github-bot committed Jul 28, 2023
1 parent 3579b4d commit 8967bf7
Showing 18 changed files with 1,358 additions and 194 deletions.
11 changes: 9 additions & 2 deletions fbgemm_gpu/CMakeLists.txt
@@ -217,9 +217,16 @@ set(gen_gpu_kernel_source_files
"gen_embedding_forward_split_unweighted_codegen_cuda.cu"
"gen_embedding_backward_dense_indice_weights_codegen_cuda.cu"
"gen_embedding_backward_split_indice_weights_codegen_cuda.cu"
"gen_embedding_backward_split_grad.cu"
"gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu"
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu")
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu"
"gen_batch_index_select_dim0_forward_codegen_cuda.cu"
"gen_batch_index_select_dim0_forward_kernel.cu"
"gen_batch_index_select_dim0_forward_kernel_small.cu"
"gen_batch_index_select_dim0_backward_codegen_cuda.cu"
"gen_batch_index_select_dim0_backward_kernel_cta.cu"
"gen_batch_index_select_dim0_backward_kernel_warp.cu"
"gen_embedding_backward_split_grad.cu"
)

if(NOT USE_ROCM)
list(APPEND gen_gpu_kernel_source_files
114 changes: 114 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -4,15 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import functools
import logging
import random
from typing import List

import click
import fbgemm_gpu
import numpy as np
import torch

from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
@@ -26,6 +30,7 @@

torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")


@click.group()
@@ -352,5 +357,114 @@ def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor:
logging.info(f"fbgemm_gpu time: {time:.5f} sec")


@cli.command()
@click.option("--num-inputs", default=1024)
@click.option("--rows", default=100)
@click.option("--columns", default=128)
@click.option("--num-indices", default=2048)
@click.option("--timeline", is_flag=True, default=False)
def index_select_bench(
num_inputs: int, rows: int, columns: int, num_indices: int, timeline: bool
) -> None:
input_rows = [rows] * num_inputs
input_columns = [columns] * num_inputs
input_num_indices = [num_indices] * num_inputs
inputs = [
torch.rand(rows, cols, dtype=torch.float, device="cuda")
for rows, cols in zip(input_rows, input_columns)
]
for i in range(len(inputs)):
inputs[i].requires_grad = True
indices = [
torch.randint(low=0, high=rows, size=(num,), dtype=torch.long, device="cuda")
for num, rows in zip(input_num_indices, input_rows)
]

concat_inputs = torch.concat([input.flatten().clone().detach() for input in inputs])
concat_inputs.requires_grad = True
concat_indices = torch.concat(indices)

gis_inputs = [input.clone().detach() for input in inputs]
for i in range(len(gis_inputs)):
gis_inputs[i].requires_grad = True

def index_select_fwd_ref(
inputs: List[torch.Tensor], indices: List[torch.Tensor]
) -> List[torch.Tensor]:
outputs = []
for input, index in zip(inputs, indices):
outputs.append(torch.index_select(input, 0, index))
return outputs

def index_select_bwd_ref(
outputs: List[torch.Tensor], grads: List[torch.Tensor]
) -> None:
for output, grad in zip(outputs, grads):
output.backward(grad, retain_graph=True)

bench_kwargs = {"num_warmups": 10, "iters": 10 if timeline else 100}
profile_ctx = profile if timeline else contextlib.nullcontext

with profile_ctx() as prof:
time_pyt, out_pyt = benchmark_torch_function(
index_select_fwd_ref,
(inputs, indices),
**bench_kwargs,
)

time_bis, out_bis = benchmark_torch_function(
torch.ops.fbgemm.batch_index_select_dim0,
(
concat_inputs,
concat_indices,
input_num_indices,
input_rows,
input_columns,
),
**bench_kwargs,
)

time_gis, out_gis = benchmark_torch_function(
torch.ops.fbgemm.group_index_select_dim0,
(gis_inputs, indices),
**bench_kwargs,
)

if timeline:
prof.export_chrome_trace("index_select_fwd_trace.json")

grads = [torch.rand_like(out) for out in out_pyt]
concat_grads = torch.concat([grad.flatten() for grad in grads])
concat_out_gis = torch.concat([out.flatten() for out in out_gis])

with profile_ctx() as prof:
time_bwd_pyt, _ = benchmark_torch_function(
index_select_bwd_ref,
(out_pyt, grads),
**bench_kwargs,
)

time_bwd_bis, _ = benchmark_torch_function(
functools.partial(out_bis.backward, retain_graph=True),
(concat_grads,),
**bench_kwargs,
)

time_bwd_gis, _ = benchmark_torch_function(
functools.partial(concat_out_gis.backward, retain_graph=True),
(concat_grads,),
**bench_kwargs,
)

if timeline:
prof.export_chrome_trace("index_select_bwd_trace.json")

logging.info(
f"torch.index_select forward {time_pyt * 1e6:.2f} us, backward {time_bwd_pyt * 1e6:.2f} us\n"
f"torch.ops.fbgemm.batch_index_select forward {time_bis * 1e6:.2f} us, backward {time_bwd_bis * 1e6:.2f} us\n"
f"torch.ops.fbgemm.group_index_select_dim0 forward {time_gis * 1e6:.2f} us, backward {time_bwd_gis * 1e6:.2f} us"
)


if __name__ == "__main__":
cli()
