-
Notifications
You must be signed in to change notification settings - Fork 524
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add batch_index_select_dim0 (w/ TBE backend) (#1897)
Summary: Pull Request resolved: #1897 This diff introduces `batch_index_select_dim0` using the `SplitTBE` implementation (it shares the same code generator as TBE). The new operator is designed to address limitations of `group_index_select_dim0`. Both operators are designed to operate on multiple inputs. However, `batch_index_select_dim0` requires all inputs to be contiguous in memory, while `group_index_select_dim0` can operate on inputs with a discrete memory layout. Implementation-wise, they are different. We plan to merge their backends in the future. Since `batch_index_select_dim0` is backed by TBE, it inherits TBE limitations including: - The column sizes must be a multiple of 4 and not exceed 1024. Moreover, the underlying buffer of the inputs tensor must be 16-byte aligned. This is because the TBE kernel uses a vector load/store which requires the buffer to be 16-byte aligned. The kernel will raise an error if this assumption is violated. - Due to the 16-byte alignment enforcement, during the backward pass, if the output gradient is not 16-byte aligned, the operator will copy the output gradient into a new 16-byte aligned buffer. This can be expensive if the output gradient size is large. Usage: ``` # This target might change in the future torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") ... output = torch.ops.fbgemm.batch_index_select_dim0( inputs, # Tensor - 1D tensor (concatenated flattened inputs) indices, # Tensor - 1D tensor (concatenated indices) input_num_indices, # List[int] input_rows, # List[int] input_columns, # List[int] ) ``` Differential Revision: D46084590 fbshipit-source-id: 59f99f5c2bc5c5424205bd668a6c7777ecf53f7b
- Loading branch information
1 parent
3579b4d
commit 1140128
Showing
19 changed files
with
1,582 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
237 changes: 237 additions & 0 deletions
237
fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <ATen/ATen.h> | ||
#include <ATen/TypeDefault.h> | ||
#include <ATen/core/op_registration/op_registration.h> | ||
#include <torch/script.h> | ||
|
||
#include "fbgemm_gpu/embedding_common.h" | ||
#include "fbgemm_gpu/sparse_ops.h" | ||
#include "fbgemm_gpu/sparse_ops_utils.h" | ||
|
||
using Tensor = at::Tensor; | ||
using namespace fbgemm_gpu; | ||
|
||
class BatchIndexSelectDim0CPUOp | ||
: public torch::autograd::Function<BatchIndexSelectDim0CPUOp> { | ||
public: | ||
static torch::autograd::variable_list forward( | ||
torch::autograd::AutogradContext* ctx, | ||
const Tensor& inputs, | ||
const Tensor& indices, | ||
const std::vector<int64_t>& input_num_indices, | ||
const std::vector<int64_t>& input_rows, | ||
const std::vector<int64_t>& input_columns, | ||
const bool permute_output_dim_0_1) { | ||
const int64_t num_inputs = input_num_indices.size(); | ||
ctx->save_for_backward({indices}); | ||
|
||
ctx->saved_data["input_numel"] = inputs.numel(); | ||
ctx->saved_data["input_num_indices"] = input_num_indices; | ||
ctx->saved_data["input_rows"] = input_rows; | ||
ctx->saved_data["input_columns"] = input_columns; | ||
ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; | ||
|
||
// Early exit | ||
if (inputs.numel() == 0) { | ||
return {at::empty({0}, inputs.options())}; | ||
} | ||
|
||
// Compute section sizes for splitting tensors | ||
std::vector<int64_t> input_numels; | ||
std::vector<int64_t> indices_numels; | ||
input_numels.reserve(num_inputs); | ||
indices_numels.reserve(num_inputs); | ||
for (auto i = 0; i < num_inputs; i++) { | ||
input_numels.push_back(input_rows[i] * input_columns[i]); | ||
indices_numels.push_back(input_num_indices[i]); | ||
} | ||
|
||
ctx->saved_data["indices_numels"] = indices_numels; | ||
|
||
// Split tensors into vectors | ||
const auto inputs_ = at::split_with_sizes(inputs, input_numels, 0); | ||
const auto indices_ = at::split_with_sizes(indices, indices_numels, 0); | ||
|
||
std::vector<Tensor> outputs; | ||
outputs.reserve(num_inputs); | ||
for (auto i = 0; i < num_inputs; i++) { | ||
const auto input = inputs_[i].view({input_rows[i], input_columns[i]}); | ||
const auto index = indices_[i]; | ||
const auto output = at::index_select(input, 0, index); | ||
if (permute_output_dim_0_1) { | ||
outputs.push_back(output); | ||
} else { | ||
outputs.push_back(output.flatten()); | ||
} | ||
} | ||
|
||
// permute_output_dim_0_1 = true shape: (batch_size, num_inputs, cols) | ||
// permute_output_dim_0_1 = false shape: (num_inputs, batch_size cols) | ||
return {at::concat(outputs, permute_output_dim_0_1 ? 1 : 0).flatten()}; | ||
} | ||
|
||
static torch::autograd::variable_list backward( | ||
torch::autograd::AutogradContext* ctx, | ||
torch::autograd::variable_list grad_outputs) { | ||
using torch::autograd::Variable; | ||
const auto grad_output = grad_outputs[0]; | ||
const auto input_numel = ctx->saved_data["input_numel"].toInt(); | ||
|
||
// Early exit | ||
if (input_numel == 0) { | ||
return { | ||
at::empty({0}, grad_output.options()), | ||
Variable(), // indices | ||
Variable(), // input_num_indices | ||
Variable(), // input_rows | ||
Variable(), // input_columns | ||
Variable() // permute_output_dim_0_1 | ||
}; | ||
} | ||
|
||
const auto saved = ctx->get_saved_variables(); | ||
auto indices = *std::begin(saved); | ||
|
||
const auto input_num_indices = | ||
ctx->saved_data["input_num_indices"].toIntVector(); | ||
const auto input_rows = ctx->saved_data["input_rows"].toIntVector(); | ||
const auto input_cols = ctx->saved_data["input_columns"].toIntVector(); | ||
const auto permute_output_dim_0_1 = | ||
ctx->saved_data["permute_output_dim_0_1"].toBool(); | ||
const auto indices_numels = ctx->saved_data["indices_numels"].toIntVector(); | ||
|
||
const int64_t num_inputs = input_num_indices.size(); | ||
|
||
std::vector<Tensor> grads; | ||
if (permute_output_dim_0_1) { | ||
grads = at::split_with_sizes( | ||
grad_output.view({input_num_indices[0], -1}), input_cols, 1); | ||
} else { | ||
std::vector<int64_t> grad_numels; | ||
grad_numels.reserve(num_inputs); | ||
for (auto i = 0; i < num_inputs; i++) { | ||
grad_numels.push_back(input_num_indices[i] * input_cols[i]); | ||
} | ||
grads = at::split_with_sizes(grad_output, grad_numels, 0); | ||
} | ||
|
||
const auto indices_ = at::split_with_sizes(indices, indices_numels, 0); | ||
|
||
std::vector<Tensor> grad_inputs; | ||
grad_inputs.reserve(num_inputs); | ||
int64_t indices_offset = 0; | ||
for (auto i = 0; i < num_inputs; i++) { | ||
const auto num_indices = input_num_indices[i]; | ||
const auto grad_input = | ||
at::zeros({input_rows[i], input_cols[i]}, grad_output.options()); | ||
indices_offset += num_indices; | ||
const auto grad = | ||
permute_output_dim_0_1 ? grads[i] : grads[i].view({num_indices, -1}); | ||
grad_inputs.push_back( | ||
at::index_add(grad_input, 0, indices_[i], grad).flatten()); | ||
} | ||
|
||
return { | ||
at::concat(grad_inputs, 0), | ||
Variable(), // indices | ||
Variable(), // input_num_indices | ||
Variable(), // input_rows | ||
Variable(), // input_columns | ||
Variable() // permute_output_dim_0_1 | ||
}; | ||
} | ||
}; | ||
|
||
/// CPU entry point for `batch_index_select_dim0`.
///
/// Validates the arguments and then dispatches to
/// `BatchIndexSelectDim0CPUOp`.
///
/// @param inputs 1D tensor holding all inputs, concatenated and flattened.
/// @param indices 1D tensor holding all indices, concatenated.
/// @param input_num_indices Number of indices for each input.
/// @param input_rows Number of rows of each input.
/// @param input_columns Number of columns of each input.
/// @param permute_output_dim_0_1 Permute dim 0 and 1 of the output tensor.
///        Requires every entry of `input_num_indices` to be identical.
/// @return Flat tensor with the selected rows of all inputs.
/// @throws c10::Error (via TORCH_CHECK) if the list lengths differ, the
///         inputs buffer is not 16-byte aligned, or any row/column/index
///         count is not strictly positive.
Tensor batch_index_select_dim0_cpu(
    Tensor inputs,
    Tensor indices,
    std::vector<int64_t> input_num_indices,
    std::vector<int64_t> input_rows,
    std::vector<int64_t> input_columns,
    // Permute dim 0 and 1 of the output tensor
    const bool permute_output_dim_0_1) {
  const int64_t num_inputs = input_num_indices.size();
  TORCH_CHECK(
      num_inputs == static_cast<int64_t>(input_rows.size()),
      "[batch_index_select_dim0] input_rows must have the same length as "
      "input_num_indices.");
  TORCH_CHECK(
      num_inputs == static_cast<int64_t>(input_columns.size()),
      "[batch_index_select_dim0] input_columns must have the same length as "
      "input_num_indices.");

  TORCH_CHECK(
      reinterpret_cast<uint64_t>(inputs.data_ptr()) % 16 == 0,
      "Currently batch_index_select only supports 16-byte align input tensors");

  // Non-owning tensor views over the argument vectors, used only for the
  // vectorized validity checks below. They must not outlive this function.
  const auto int_opts = torch::TensorOptions().dtype(torch::kInt64);
  const auto num_cols =
      torch::from_blob(input_columns.data(), {num_inputs}, int_opts);
  const auto input_num_rows =
      torch::from_blob(input_rows.data(), {num_inputs}, int_opts);
  const auto output_num_rows =
      torch::from_blob(input_num_indices.data(), {num_inputs}, int_opts);

  if (num_inputs > 0) {
    TORCH_CHECK(
        torch::all(torch::gt(num_cols, 0)).item<bool>(),
        "[batch_index_select_dim0] All input_columns must be greater than zero.");
    TORCH_CHECK(
        torch::all(torch::gt(input_num_rows, 0)).item<bool>(),
        "[batch_index_select_dim0] All input_rows must be greater than zero.");
    if (permute_output_dim_0_1) {
      // All output rows must be the same (and therefore all positive once the
      // first one is).
      TORCH_CHECK(
          input_num_indices[0] > 0,
          "[batch_index_select_dim0] All input_num_indices must be greater than zero.");
      TORCH_CHECK(
          torch::all(torch::eq(output_num_rows, input_num_indices[0]))
              .item<bool>(),
          "[batch_index_select_dim0] All input_num_indices must be the same if "
          "permute_output_dim_0_1 is true.");
    } else {
      TORCH_CHECK(
          torch::all(torch::gt(output_num_rows, 0)).item<bool>(),
          "[batch_index_select_dim0] All input_num_indices must be greater than zero.");
    }
  }

  return BatchIndexSelectDim0CPUOp::apply(
      inputs,
      indices,
      input_num_indices,
      input_rows,
      input_columns,
      permute_output_dim_0_1)[0];
}
|
||
namespace {

// Operator schema shared by the `fb` and `fbgemm` registrations below, so the
// two namespaces always expose the same signature.
constexpr char kBatchIndexSelectDim0Schema[] =
    "batch_index_select_dim0("
    " Tensor inputs,"
    " Tensor indices,"
    " int[] input_num_indices,"
    " int[] input_rows,"
    " int[] input_columns,"
    " bool permute_output_dim_0_1=False) -> Tensor";

} // namespace

// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
  m.def(kBatchIndexSelectDim0Schema);
  DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu);
}

// Registers the operator schema and its CPU implementation under the
// `fbgemm` namespace (the preferred one).
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(kBatchIndexSelectDim0Schema);
  DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu);
}
Oops, something went wrong.