Skip to content

Commit

Permalink
Add batch_index_select_dim0 (w/ TBE backend) (#1897)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1897

This diff introduces `batch_index_select_dim0` using the `SplitTBE`
implementation (it shares the same code generator as TBE).  The new
operator is designed to address limitations of
`group_index_select_dim0`.  Both operators are designed to operate
multiple inputs.  However, `batch_index_select_dim0` requires all
inputs to be contiguous in memory, while `batch_index_select_dim0` can
operate on inputs with a discrete memory layout.  Implementation-wise,
they are different.  We plan to merge their backends in the future.

Since `batch_index_select_dim0` is backed by TBE, it inherits TBE
limitations including:
- The column sizes must be a multiple of 4 and not exceed 1024.
  Moreover, the underlying buffer of the inputs tensor must be 16-byte
  aligned.  This is because the TBE kernel uses a vector load/store
  which requires the buffer to be 16-byte aligned.  The kernel will
  raise an error if this assumption is violated.
- Due to the 16-byte aligned enforcement, during the backward pass, if
  the output gradient is not 16-byte aligned, the operator will copy
  the output gradient into a new 16-byte aligned buffer.  This can be
  expensive if the output gradient size is large.

Usage:

```
# This target might change in the future
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")

...

output = torch.ops.fbgemm.batch_index_select_dim0(
            inputs, # Tensor - 1D tensor (concatenated flatten inputs)
            indices, # Tensor - 1D tensor (concatenated indices)
            input_num_indices, # List[int]
            input_rows, # List[int]
            input_columns, # List[int]
         )
```

Differential Revision: D46084590

fbshipit-source-id: 59f99f5c2bc5c5424205bd668a6c7777ecf53f7b
  • Loading branch information
sryap authored and facebook-github-bot committed Jul 28, 2023
1 parent 3579b4d commit 1140128
Show file tree
Hide file tree
Showing 19 changed files with 1,582 additions and 194 deletions.
11 changes: 9 additions & 2 deletions fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,16 @@ set(gen_gpu_kernel_source_files
"gen_embedding_forward_split_unweighted_codegen_cuda.cu"
"gen_embedding_backward_dense_indice_weights_codegen_cuda.cu"
"gen_embedding_backward_split_indice_weights_codegen_cuda.cu"
"gen_embedding_backward_split_grad.cu"
"gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu"
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu")
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu"
"gen_batch_index_select_dim0_forward_codegen_cuda.cu"
"gen_batch_index_select_dim0_forward_kernel.cu"
"gen_batch_index_select_dim0_forward_kernel_small.cu"
"gen_batch_index_select_dim0_backward_codegen_cuda.cu"
"gen_batch_index_select_dim0_backward_kernel_cta.cu"
"gen_batch_index_select_dim0_backward_kernel_warp.cu"
"gen_embedding_backward_split_grad.cu"
)

if(NOT USE_ROCM)
list(APPEND gen_gpu_kernel_source_files
Expand Down
114 changes: 114 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import functools
import logging
import random
from typing import List

import click
import fbgemm_gpu
import numpy as np
import torch

from torch.profiler import profile

logging.basicConfig(level=logging.DEBUG)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
Expand All @@ -26,6 +30,7 @@

torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")


@click.group()
Expand Down Expand Up @@ -352,5 +357,114 @@ def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor:
logging.info(f"fbgemm_gpu time: {time:.5f} sec")


@cli.command()
@click.option("--num-inputs", default=1024)
@click.option("--rows", default=100)
@click.option("--columns", default=128)
@click.option("--num-indices", default=2048)
@click.option("--timeline", is_flag=True, default=False)
def index_select_bench(
num_inputs: int, rows: int, columns: int, num_indices: int, timeline: bool
) -> None:
input_rows = [rows] * num_inputs
input_columns = [columns] * num_inputs
input_num_indices = [num_indices] * num_inputs
inputs = [
torch.rand(rows, cols, dtype=torch.float, device="cuda")
for rows, cols in zip(input_rows, input_columns)
]
for i in range(len(inputs)):
inputs[i].requires_grad = True
indices = [
torch.randint(low=0, high=rows, size=(num,), dtype=torch.long, device="cuda")
for num, rows in zip(input_num_indices, input_rows)
]

concat_inputs = torch.concat([input.flatten().clone().detach() for input in inputs])
concat_inputs.requires_grad = True
concat_indices = torch.concat(indices)

gis_inputs = [input.clone().detach() for input in inputs]
for i in range(len(gis_inputs)):
gis_inputs[i].requires_grad = True

def index_select_fwd_ref(
inputs: List[torch.Tensor], indices: List[torch.Tensor]
) -> List[torch.Tensor]:
outputs = []
for input, index in zip(inputs, indices):
outputs.append(torch.index_select(input, 0, index))
return outputs

def index_select_bwd_ref(
outputs: List[torch.Tensor], grads: List[torch.Tensor]
) -> None:
for output, grad in zip(outputs, grads):
output.backward(grad, retain_graph=True)

bench_kwargs = {"num_warmups": 10, "iters": 10 if timeline else 100}
profile_ctx = profile if timeline else contextlib.nullcontext

with profile_ctx() as prof:
time_pyt, out_pyt = benchmark_torch_function(
index_select_fwd_ref,
(inputs, indices),
**bench_kwargs,
)

time_bis, out_bis = benchmark_torch_function(
torch.ops.fbgemm.batch_index_select_dim0,
(
concat_inputs,
concat_indices,
input_num_indices,
input_rows,
input_columns,
),
**bench_kwargs,
)

time_gis, out_gis = benchmark_torch_function(
torch.ops.fbgemm.group_index_select_dim0,
(gis_inputs, indices),
**bench_kwargs,
)

if timeline:
prof.export_chrome_trace("index_select_fwd_trace.json")

grads = [torch.rand_like(out) for out in out_pyt]
concat_grads = torch.concat([grad.flatten() for grad in grads])
concat_out_gis = torch.concat([out.flatten() for out in out_gis])

with profile_ctx() as prof:
time_bwd_pyt, _ = benchmark_torch_function(
index_select_bwd_ref,
(out_pyt, grads),
**bench_kwargs,
)

time_bwd_bis, _ = benchmark_torch_function(
functools.partial(out_bis.backward, retain_graph=True),
(concat_grads,),
**bench_kwargs,
)

time_bwd_gis, _ = benchmark_torch_function(
functools.partial(concat_out_gis.backward, retain_graph=True),
(concat_grads,),
**bench_kwargs,
)

if timeline:
prof.export_chrome_trace("index_select_bwd_trace.json")

logging.info(
f"torch.index_select forward {time_pyt * 1e6:.2f} us, backward {time_bwd_pyt * 1e6:.2f} us\n"
f"torch.ops.fbgemm.batch_index_select forward {time_bis * 1e6:.2f} us, backward {time_bwd_bis * 1e6:.2f} us\n"
f"torch.ops.fbgemm.group_index_select_dim0 forward {time_gis * 1e6:.2f} us, backward {time_bwd_gis * 1e6:.2f} us"
)


if __name__ == "__main__":
cli()
237 changes: 237 additions & 0 deletions fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>

#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

class BatchIndexSelectDim0CPUOp
: public torch::autograd::Function<BatchIndexSelectDim0CPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& inputs,
const Tensor& indices,
const std::vector<int64_t>& input_num_indices,
const std::vector<int64_t>& input_rows,
const std::vector<int64_t>& input_columns,
const bool permute_output_dim_0_1) {
const int64_t num_inputs = input_num_indices.size();
ctx->save_for_backward({indices});

ctx->saved_data["input_numel"] = inputs.numel();
ctx->saved_data["input_num_indices"] = input_num_indices;
ctx->saved_data["input_rows"] = input_rows;
ctx->saved_data["input_columns"] = input_columns;
ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1;

// Early exit
if (inputs.numel() == 0) {
return {at::empty({0}, inputs.options())};
}

// Compute section sizes for splitting tensors
std::vector<int64_t> input_numels;
std::vector<int64_t> indices_numels;
input_numels.reserve(num_inputs);
indices_numels.reserve(num_inputs);
for (auto i = 0; i < num_inputs; i++) {
input_numels.push_back(input_rows[i] * input_columns[i]);
indices_numels.push_back(input_num_indices[i]);
}

ctx->saved_data["indices_numels"] = indices_numels;

// Split tensors into vectors
const auto inputs_ = at::split_with_sizes(inputs, input_numels, 0);
const auto indices_ = at::split_with_sizes(indices, indices_numels, 0);

std::vector<Tensor> outputs;
outputs.reserve(num_inputs);
for (auto i = 0; i < num_inputs; i++) {
const auto input = inputs_[i].view({input_rows[i], input_columns[i]});
const auto index = indices_[i];
const auto output = at::index_select(input, 0, index);
if (permute_output_dim_0_1) {
outputs.push_back(output);
} else {
outputs.push_back(output.flatten());
}
}

// permute_output_dim_0_1 = true shape: (batch_size, num_inputs, cols)
// permute_output_dim_0_1 = false shape: (num_inputs, batch_size cols)
return {at::concat(outputs, permute_output_dim_0_1 ? 1 : 0).flatten()};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
using torch::autograd::Variable;
const auto grad_output = grad_outputs[0];
const auto input_numel = ctx->saved_data["input_numel"].toInt();

// Early exit
if (input_numel == 0) {
return {
at::empty({0}, grad_output.options()),
Variable(), // indices
Variable(), // input_num_indices
Variable(), // input_rows
Variable(), // input_columns
Variable() // permute_output_dim_0_1
};
}

const auto saved = ctx->get_saved_variables();
auto indices = *std::begin(saved);

const auto input_num_indices =
ctx->saved_data["input_num_indices"].toIntVector();
const auto input_rows = ctx->saved_data["input_rows"].toIntVector();
const auto input_cols = ctx->saved_data["input_columns"].toIntVector();
const auto permute_output_dim_0_1 =
ctx->saved_data["permute_output_dim_0_1"].toBool();
const auto indices_numels = ctx->saved_data["indices_numels"].toIntVector();

const int64_t num_inputs = input_num_indices.size();

std::vector<Tensor> grads;
if (permute_output_dim_0_1) {
grads = at::split_with_sizes(
grad_output.view({input_num_indices[0], -1}), input_cols, 1);
} else {
std::vector<int64_t> grad_numels;
grad_numels.reserve(num_inputs);
for (auto i = 0; i < num_inputs; i++) {
grad_numels.push_back(input_num_indices[i] * input_cols[i]);
}
grads = at::split_with_sizes(grad_output, grad_numels, 0);
}

const auto indices_ = at::split_with_sizes(indices, indices_numels, 0);

std::vector<Tensor> grad_inputs;
grad_inputs.reserve(num_inputs);
int64_t indices_offset = 0;
for (auto i = 0; i < num_inputs; i++) {
const auto num_indices = input_num_indices[i];
const auto grad_input =
at::zeros({input_rows[i], input_cols[i]}, grad_output.options());
indices_offset += num_indices;
const auto grad =
permute_output_dim_0_1 ? grads[i] : grads[i].view({num_indices, -1});
grad_inputs.push_back(
at::index_add(grad_input, 0, indices_[i], grad).flatten());
}

return {
at::concat(grad_inputs, 0),
Variable(), // indices
Variable(), // input_num_indices
Variable(), // input_rows
Variable(), // input_columns
Variable() // permute_output_dim_0_1
};
}
};

Tensor batch_index_select_dim0_cpu(
Tensor inputs,
Tensor indices,
std::vector<int64_t> input_num_indices,
std::vector<int64_t> input_rows,
std::vector<int64_t> input_columns,
// Permute dim 0 and 1 of the output tensor
const bool permute_output_dim_0_1) {
const int64_t num_inputs = input_num_indices.size();
TORCH_CHECK(
num_inputs == static_cast<int64_t>(input_rows.size()),
"[batch_index_select_dim0] input_rows must have the same length as "
"input_num_indices.");
TORCH_CHECK(
num_inputs == static_cast<int64_t>(input_columns.size()),
"[batch_index_select_dim0] input_columns must have the same length as "
"input_num_indices.");

TORCH_CHECK(
reinterpret_cast<uint64_t>(inputs.data_ptr()) % 16 == 0,
"Currently batch_index_select only supports 16-byte align input tensors");

const auto int_opts = torch::TensorOptions().dtype(torch::kInt64);
const auto num_cols =
torch::from_blob(input_columns.data(), {num_inputs}, int_opts);
const auto max_col = num_inputs > 0 ? num_cols.max().item<int64_t>() : 0;
const auto input_num_rows =
torch::from_blob(input_rows.data(), {num_inputs}, int_opts);
const auto output_num_rows =
torch::from_blob(input_num_indices.data(), {num_inputs}, int_opts);

if (num_inputs > 0) {
TORCH_CHECK(
torch::all(torch::gt(num_cols, 0)).item<bool>(),
"[batch_index_select_dim0] All input_columns must be the same.");
TORCH_CHECK(
torch::all(torch::gt(input_num_rows, 0)).item<bool>(),
"[batch_index_select_dim0] All input_rows must be the same.");
if (permute_output_dim_0_1) {
// All output rows must be the same
TORCH_CHECK(input_num_indices[0] > 0);
TORCH_CHECK(
torch::all(torch::eq(output_num_rows, input_num_indices[0]))
.item<bool>(),
"[batch_index_select_dim0] All input_num_indices must be the same if "
"permute_output_dim_0_1 is true.");
} else {
TORCH_CHECK(
torch::all(torch::gt(output_num_rows, 0)).item<bool>(),
"[batch_index_select_dim0] All input_num_indices must be greater than zero.");
}
}

return BatchIndexSelectDim0CPUOp::apply(
inputs,
indices,
input_num_indices,
input_rows,
input_columns,
permute_output_dim_0_1)[0];
}

// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def(
"batch_index_select_dim0("
" Tensor inputs,"
" Tensor indices,"
" int[] input_num_indices,"
" int[] input_rows,"
" int[] input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor");
DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"batch_index_select_dim0("
" Tensor inputs,"
" Tensor indices,"
" int[] input_num_indices,"
" int[] input_rows,"
" int[] input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor");
DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu);
}
Loading

0 comments on commit 1140128

Please sign in to comment.