From d8bbfa6e6cfb608e49f49169c1383252522d20f7 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Fri, 28 Jul 2023 16:35:04 -0700
Subject: [PATCH] Add batch_index_select_dim0 (w/ TBE backend) (#1897)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1897

This diff introduces `batch_index_select_dim0` using the `SplitTBE` implementation (it shares the same code generator as TBE). The new operator is designed to address limitations of `group_index_select_dim0`. Both operators are designed to operate on multiple inputs. However, `batch_index_select_dim0` requires all inputs to be contiguous in memory, while `group_index_select_dim0` can operate on inputs with a discrete memory layout. Implementation-wise, they are different. We plan to merge their backends in the future.

Since `batch_index_select_dim0` is backed by TBE, it inherits TBE limitations, including:

- The column sizes must be a multiple of 4 and not exceed 1024. Moreover, the underlying buffer of the inputs tensor must be 16-byte aligned. This is because the TBE kernel uses a vector load/store which requires the buffer to be 16-byte aligned. The kernel will raise an error if this assumption is violated.
- Due to the 16-byte alignment enforcement, during the backward pass, if the output gradient is not 16-byte aligned, the operator will copy the output gradient into a new 16-byte aligned buffer. This can be expensive if the output gradient size is large.

Usage:
```
# This target might change in the future
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")
...
output = torch.ops.fbgemm.batch_index_select_dim0(
    inputs, # Tensor - 1D tensor (concatenated flattened inputs)
    indices, # Tensor - 1D tensor (concatenated indices)
    input_num_indices, # List[int]
    input_rows, # List[int]
    input_columns, # List[int]
)
```

Differential Revision: D46084590

fbshipit-source-id: 6dd9d2bc6cf86832b301a855337a28fed76d748a
---
 fbgemm_gpu/CMakeLists.txt | 17 +-
 fbgemm_gpu/bench/sparse_ops_benchmark.py | 114 +++++++
 .../batch_index_select_dim0_cpu_host.cpp | 237 +++++++++++++
 .../codegen/batch_index_select_dim0_host.cpp | 312 ++++++++++++++++++
 .../embedding_backward_code_generator.py | 63 +++-
 .../embedding_backward_split_grad_template.cu | 6 +-
 ...embedding_backward_split_host_template.cpp | 6 +-
 ..._backward_split_indice_weights_template.cu | 8 +-
 ...ding_backward_split_kernel_cta_template.cu | 62 +++-
 ...ing_backward_split_kernel_warp_template.cu | 63 +++-
 .../embedding_backward_split_template.cu | 218 ++++++++----
 ...rward_split_kernel_nobag_small_template.cu | 87 ++++-
 ...embedding_forward_split_kernel_template.cu | 114 +++++--
 .../embedding_forward_split_template.cu | 191 +++++++++--
 .../include/fbgemm_gpu/embedding_common.h | 2 +-
 .../fbgemm_gpu/split_embeddings_utils.cuh | 7 +-
 fbgemm_gpu/src/split_embeddings_utils.cpp | 8 +-
 fbgemm_gpu/src/split_embeddings_utils.cu | 113 ++++++-
 fbgemm_gpu/test/sparse_ops_test.py | 154 +++++++++
 19 files changed, 1586 insertions(+), 196 deletions(-)
 create mode 100644 fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp
 create mode 100644 fbgemm_gpu/codegen/batch_index_select_dim0_host.cpp

diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt
index effaeb2f5f..f60c633de0 100644
--- a/fbgemm_gpu/CMakeLists.txt
+++ b/fbgemm_gpu/CMakeLists.txt
@@ -217,9 +217,16 @@ set(gen_gpu_kernel_source_files
     "gen_embedding_forward_split_unweighted_codegen_cuda.cu"
    "gen_embedding_backward_dense_indice_weights_codegen_cuda.cu" 
"gen_embedding_backward_split_indice_weights_codegen_cuda.cu" - "gen_embedding_backward_split_grad.cu" "gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu" - "gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu") + "gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu" + "gen_batch_index_select_dim0_forward_codegen_cuda.cu" + "gen_batch_index_select_dim0_forward_kernel.cu" + "gen_batch_index_select_dim0_forward_kernel_small.cu" + "gen_batch_index_select_dim0_backward_codegen_cuda.cu" + "gen_batch_index_select_dim0_backward_kernel_cta.cu" + "gen_batch_index_select_dim0_backward_kernel_warp.cu" + "gen_embedding_backward_split_grad.cu" +) if(NOT USE_ROCM) list(APPEND gen_gpu_kernel_source_files @@ -559,7 +566,8 @@ set(fbgemm_gpu_sources_static_cpu src/quantize_ops/quantize_ops_meta.cpp src/sparse_ops/sparse_ops_cpu.cpp src/sparse_ops/sparse_ops_meta.cpp - src/embedding_inplace_update_cpu.cpp) + src/embedding_inplace_update_cpu.cpp + codegen/batch_index_select_dim0_cpu_host.cpp) if(NOT FBGEMM_CPU_ONLY) list(APPEND fbgemm_gpu_sources_static_cpu @@ -576,7 +584,8 @@ if(NOT FBGEMM_CPU_ONLY) src/split_table_batched_embeddings.cpp src/metric_ops_host.cpp src/embedding_inplace_update_gpu.cpp - src/input_combine_gpu.cpp) + src/input_combine_gpu.cpp + codegen/batch_index_select_dim0_host.cpp) if(NVML_LIB_PATH) message(STATUS "Found NVML_LIB_PATH: ${NVML_LIB_PATH}") diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py index f97adc0ba8..4111585333 100644 --- a/fbgemm_gpu/bench/sparse_ops_benchmark.py +++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py @@ -4,15 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import functools import logging import random +from typing import List import click import fbgemm_gpu import numpy as np import torch +from torch.profiler import profile + logging.basicConfig(level=logging.DEBUG) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
@@ -26,6 +30,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") @click.group() @@ -352,5 +357,114 @@ def asynchronous_complete_cumsum_2d_ref(lengths: torch.Tensor) -> torch.Tensor: logging.info(f"fbgemm_gpu time: {time:.5f} sec") +@cli.command() +@click.option("--num-inputs", default=1024) +@click.option("--rows", default=100) +@click.option("--columns", default=128) +@click.option("--num-indices", default=2048) +@click.option("--timeline", is_flag=True, default=False) +def index_select_bench( + num_inputs: int, rows: int, columns: int, num_indices: int, timeline: bool +) -> None: + input_rows = [rows] * num_inputs + input_columns = [columns] * num_inputs + input_num_indices = [num_indices] * num_inputs + inputs = [ + torch.rand(rows, cols, dtype=torch.float, device="cuda") + for rows, cols in zip(input_rows, input_columns) + ] + for i in range(len(inputs)): + inputs[i].requires_grad = True + indices = [ + torch.randint(low=0, high=rows, size=(num,), dtype=torch.long, device="cuda") + for num, rows in zip(input_num_indices, input_rows) + ] + + concat_inputs = torch.concat([input.flatten().clone().detach() for input in inputs]) + concat_inputs.requires_grad = True + concat_indices = torch.concat(indices) + + gis_inputs = [input.clone().detach() for input in inputs] + for i in range(len(gis_inputs)): + gis_inputs[i].requires_grad = True + + def index_select_fwd_ref( + inputs: List[torch.Tensor], indices: List[torch.Tensor] + ) -> List[torch.Tensor]: + outputs = [] + for input, index in zip(inputs, indices): + outputs.append(torch.index_select(input, 0, index)) + return outputs + + def index_select_bwd_ref( + outputs: List[torch.Tensor], grads: List[torch.Tensor] + ) -> None: + for output, grad in zip(outputs, grads): + output.backward(grad, retain_graph=True) + + bench_kwargs = {"num_warmups": 10, "iters": 10 if timeline else 100} + profile_ctx = profile if timeline else contextlib.nullcontext + + with profile_ctx() as prof: + time_pyt, out_pyt = benchmark_torch_function( + index_select_fwd_ref, + (inputs, indices), + **bench_kwargs, + ) + + time_bis, out_bis = benchmark_torch_function( + torch.ops.fbgemm.batch_index_select_dim0, + ( + concat_inputs, + concat_indices, + input_num_indices, + input_rows, + input_columns, + ), + **bench_kwargs, + ) + + time_gis, out_gis = benchmark_torch_function( + torch.ops.fbgemm.group_index_select_dim0, + (gis_inputs, indices), + **bench_kwargs, + ) + + if timeline: + prof.export_chrome_trace("index_select_fwd_trace.json") + + grads = [torch.rand_like(out) for out in out_pyt] + concat_grads = torch.concat([grad.flatten() for grad in grads]) + concat_out_gis = torch.concat([out.flatten() for out in out_gis]) + + with profile_ctx() as prof: + time_bwd_pyt, _ = benchmark_torch_function( + index_select_bwd_ref, + (out_pyt, grads), + **bench_kwargs, + ) + + time_bwd_bis, _ = benchmark_torch_function( + functools.partial(out_bis.backward, retain_graph=True), + (concat_grads,), + **bench_kwargs, + ) + + time_bwd_gis, _ = benchmark_torch_function( + functools.partial(concat_out_gis.backward, retain_graph=True), + (concat_grads,), + **bench_kwargs, + ) + + if timeline: + prof.export_chrome_trace("index_select_bwd_trace.json") + + logging.info( + f"torch.index_select forward {time_pyt * 1e6:.2f} us, backward {time_bwd_pyt * 1e6:.2f} us\n" + f"torch.ops.fbgemm.batch_index_select forward 
{time_bis * 1e6:.2f} us, backward {time_bwd_bis * 1e6:.2f} us\n" + f"torch.ops.fbgemm.group_index_select_dim0 forward {time_gis * 1e6:.2f} us, backward {time_bwd_gis * 1e6:.2f} us" + ) + + if __name__ == "__main__": cli() diff --git a/fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp b/fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp new file mode 100644 index 0000000000..9d73b2de7b --- /dev/null +++ b/fbgemm_gpu/codegen/batch_index_select_dim0_cpu_host.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" + +using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +class BatchIndexSelectDim0CPUOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& inputs, + const Tensor& indices, + const std::vector& input_num_indices, + const std::vector& input_rows, + const std::vector& input_columns, + const bool permute_output_dim_0_1) { + const int64_t num_inputs = input_num_indices.size(); + ctx->save_for_backward({indices}); + + ctx->saved_data["input_numel"] = inputs.numel(); + ctx->saved_data["input_num_indices"] = input_num_indices; + ctx->saved_data["input_rows"] = input_rows; + ctx->saved_data["input_columns"] = input_columns; + ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; + + // Early exit + if (inputs.numel() == 0) { + return {at::empty({0}, inputs.options())}; + } + + // Compute section sizes for splitting tensors + std::vector input_numels; + std::vector indices_numels; + input_numels.reserve(num_inputs); + indices_numels.reserve(num_inputs); + for (auto i = 0; i < num_inputs; i++) { + input_numels.push_back(input_rows[i] * input_columns[i]); + indices_numels.push_back(input_num_indices[i]); + } + + ctx->saved_data["indices_numels"] = indices_numels; + + // Split tensors into vectors + const auto inputs_ = at::split_with_sizes(inputs, input_numels, 0); + const auto indices_ = at::split_with_sizes(indices, indices_numels, 0); + + std::vector outputs; + outputs.reserve(num_inputs); + for (auto i = 0; i < num_inputs; i++) { + const auto input = inputs_[i].view({input_rows[i], input_columns[i]}); + const auto index = indices_[i]; + const auto output = at::index_select(input, 0, index); + if (permute_output_dim_0_1) { + outputs.push_back(output); + } else { + outputs.push_back(output.flatten()); + } + } + + // permute_output_dim_0_1 = true shape: (batch_size, num_inputs, cols) + // permute_output_dim_0_1 = false shape: (num_inputs, batch_size cols) + return {at::concat(outputs, permute_output_dim_0_1 ? 
1 : 0).flatten()}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + using torch::autograd::Variable; + const auto grad_output = grad_outputs[0]; + const auto input_numel = ctx->saved_data["input_numel"].toInt(); + + // Early exit + if (input_numel == 0) { + return { + at::empty({0}, grad_output.options()), + Variable(), // indices + Variable(), // input_num_indices + Variable(), // input_rows + Variable(), // input_columns + Variable() // permute_output_dim_0_1 + }; + } + + const auto saved = ctx->get_saved_variables(); + auto indices = *std::begin(saved); + + const auto input_num_indices = + ctx->saved_data["input_num_indices"].toIntVector(); + const auto input_rows = ctx->saved_data["input_rows"].toIntVector(); + const auto input_cols = ctx->saved_data["input_columns"].toIntVector(); + const auto permute_output_dim_0_1 = + ctx->saved_data["permute_output_dim_0_1"].toBool(); + const auto indices_numels = ctx->saved_data["indices_numels"].toIntVector(); + + const int64_t num_inputs = input_num_indices.size(); + + std::vector grads; + if (permute_output_dim_0_1) { + grads = at::split_with_sizes( + grad_output.view({input_num_indices[0], -1}), input_cols, 1); + } else { + std::vector grad_numels; + grad_numels.reserve(num_inputs); + for (auto i = 0; i < num_inputs; i++) { + grad_numels.push_back(input_num_indices[i] * input_cols[i]); + } + grads = at::split_with_sizes(grad_output, grad_numels, 0); + } + + const auto indices_ = at::split_with_sizes(indices, indices_numels, 0); + + std::vector grad_inputs; + grad_inputs.reserve(num_inputs); + int64_t indices_offset = 0; + for (auto i = 0; i < num_inputs; i++) { + const auto num_indices = input_num_indices[i]; + const auto grad_input = + at::zeros({input_rows[i], input_cols[i]}, grad_output.options()); + indices_offset += num_indices; + const auto grad = + permute_output_dim_0_1 ? grads[i] : grads[i].view({num_indices, -1}); + grad_inputs.push_back( + at::index_add(grad_input, 0, indices_[i], grad).flatten()); + } + + return { + at::concat(grad_inputs, 0), + Variable(), // indices + Variable(), // input_num_indices + Variable(), // input_rows + Variable(), // input_columns + Variable() // permute_output_dim_0_1 + }; + } +}; + +Tensor batch_index_select_dim0_cpu( + Tensor inputs, + Tensor indices, + std::vector input_num_indices, + std::vector input_rows, + std::vector input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + const int64_t num_inputs = input_num_indices.size(); + TORCH_CHECK( + num_inputs == static_cast(input_rows.size()), + "[batch_index_select_dim0] input_rows must have the same length as " + "input_num_indices."); + TORCH_CHECK( + num_inputs == static_cast(input_columns.size()), + "[batch_index_select_dim0] input_columns must have the same length as " + "input_num_indices."); + + TORCH_CHECK( + reinterpret_cast(inputs.data_ptr()) % 16 == 0, + "Currently batch_index_select only supports 16-byte align input tensors"); + + const auto int_opts = torch::TensorOptions().dtype(torch::kInt64); + const auto num_cols = + torch::from_blob(input_columns.data(), {num_inputs}, int_opts); + const auto max_col = num_inputs > 0 ? 
num_cols.max().item() : 0; + const auto input_num_rows = + torch::from_blob(input_rows.data(), {num_inputs}, int_opts); + const auto output_num_rows = + torch::from_blob(input_num_indices.data(), {num_inputs}, int_opts); + + if (num_inputs > 0) { + TORCH_CHECK( + torch::all(torch::gt(num_cols, 0)).item(), + "[batch_index_select_dim0] All input_columns must be the same."); + TORCH_CHECK( + torch::all(torch::gt(input_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_rows must be the same."); + if (permute_output_dim_0_1) { + // All output rows must be the same + TORCH_CHECK(input_num_indices[0] > 0); + TORCH_CHECK( + torch::all(torch::eq(output_num_rows, input_num_indices[0])) + .item(), + "[batch_index_select_dim0] All input_num_indices must be the same if " + "permute_output_dim_0_1 is true."); + } else { + TORCH_CHECK( + torch::all(torch::gt(output_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_num_indices must be greater than zero."); + } + } + + return BatchIndexSelectDim0CPUOp::apply( + inputs, + indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1)[0]; +} + +// Deprecated for fb namespace! Please use fbgemm namespace instead! +TORCH_LIBRARY_FRAGMENT(fb, m) { + m.def( + "batch_index_select_dim0(" + " Tensor inputs," + " Tensor indices," + " int[] input_num_indices," + " int[] input_rows," + " int[] input_columns," + " bool permute_output_dim_0_1=False) -> Tensor"); + DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "batch_index_select_dim0(" + " Tensor inputs," + " Tensor indices," + " int[] input_num_indices," + " int[] input_rows," + " int[] input_columns," + " bool permute_output_dim_0_1=False) -> Tensor"); + DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu); +} diff --git a/fbgemm_gpu/codegen/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/batch_index_select_dim0_host.cpp new file mode 100644 index 0000000000..79966ace6d --- /dev/null +++ b/fbgemm_gpu/codegen/batch_index_select_dim0_host.cpp @@ -0,0 +1,312 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/sparse_ops.h" +#include "fbgemm_gpu/sparse_ops_utils.h" + +using Tensor = at::Tensor; +using namespace fbgemm_gpu; + +Tensor batch_index_select_dim0_codegen_forward_cuda( + Tensor dev_weights, + Tensor weights_offsets, + Tensor D_offsets, + int64_t max_D, + Tensor indices, + int64_t output_dtype, + const Tensor& output_offsets, + const Tensor& total_L_offsets, + const int64_t output_size, + const int32_t fixed_L_per_warp, + const int32_t num_warps_per_feature, + const bool permute_output_dim_0_1); + +Tensor batch_index_select_dim0_codegen_backward_cuda( + Tensor grad_output, + Tensor dev_weights, + Tensor weights_offsets, + Tensor D_offsets, + int64_t max_D, + Tensor hash_size_cumsum, + int64_t total_hash_size_bits, + Tensor indices, + int64_t max_segment_length_per_warp, + const Tensor& grad_offsets, + const Tensor& total_L_offsets, + const int32_t fixed_L_per_warp, + const int32_t num_warps_per_feature, + const bool permute_output_dim_0_1); + +class BatchIndexSelectDim0GPUOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const int64_t output_dtype, + const Tensor& dev_weights, + const Tensor& weights_offsets, + const Tensor& hash_size_cumsum, + const int64_t total_hash_size_bits, + const Tensor& indices, + const Tensor& D_offsets, + const int64_t max_D, + const Tensor& output_offsets, + const Tensor& total_L_offsets, + const int64_t output_size, + const int64_t fixed_L_per_warp, + const int64_t num_warps_per_feature, + const bool permute_output_dim_0_1) { + ctx->save_for_backward( + {dev_weights, + weights_offsets, + hash_size_cumsum, + indices, + D_offsets, + output_offsets, + total_L_offsets}); + + ctx->saved_data["max_D"] = max_D; + ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits; + ctx->saved_data["fixed_L_per_warp"] = fixed_L_per_warp; + ctx->saved_data["num_warps_per_feature"] = num_warps_per_feature; + ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; + + // Early exit + if (dev_weights.numel() == 0) { + return {at::empty({0}, dev_weights.options())}; + } + + return {batch_index_select_dim0_codegen_forward_cuda( + dev_weights, + weights_offsets, + D_offsets, + max_D, + indices, + output_dtype, + output_offsets, + total_L_offsets, + output_size, + fixed_L_per_warp, + num_warps_per_feature, + permute_output_dim_0_1)}; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + const auto saved = ctx->get_saved_variables(); + auto savedItr = std::begin(saved); + auto dev_weights = *savedItr++; + auto weights_offsets = *savedItr++; + auto hash_size_cumsum = *savedItr++; + auto indices = *savedItr++; + auto D_offsets = *savedItr++; + auto grad_offsets = *savedItr++; + auto total_L_offsets = *savedItr++; + + const auto max_D = ctx->saved_data["max_D"].toInt(); + const auto total_hash_size_bits = + ctx->saved_data["total_hash_size_bits"].toInt(); + const auto fixed_L_per_warp = ctx->saved_data["fixed_L_per_warp"].toInt(); + const auto num_warps_per_feature = + ctx->saved_data["num_warps_per_feature"].toInt(); + const auto permute_output_dim_0_1 = + ctx->saved_data["permute_output_dim_0_1"].toBool(); + + using torch::autograd::Variable; + + Tensor grad_dev_weights; + if (dev_weights.numel() == 0) { + grad_dev_weights = at::empty({0}, dev_weights.options()); + } else 
{ + TORCH_CHECK_EQ(grad_outputs.size(), 1); + + constexpr int32_t max_segment_length_per_warp = 32; + + auto grad_output = grad_outputs[0]; + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { + grad_output = at::empty_like(grad_output).copy_(grad_output); + } + + grad_dev_weights = batch_index_select_dim0_codegen_backward_cuda( + grad_output, + dev_weights, + weights_offsets, + D_offsets, + max_D, + hash_size_cumsum, + total_hash_size_bits, + indices, + max_segment_length_per_warp, + grad_offsets, + total_L_offsets, + fixed_L_per_warp, + num_warps_per_feature, + permute_output_dim_0_1); + } + + return { + Variable(), // output_dtype + grad_dev_weights, // grad_dev_weights + Variable(), // weights_offsets + Variable(), // hash_size_cumsum + Variable(), // total_hash_size_bits + Variable(), // indices + Variable(), // D_offsets + Variable(), // max_D + Variable(), // output_offsets + Variable(), // total_L_offsets + Variable(), // output_size + Variable(), // fixed_L_per_warp + Variable(), // num_warps_per_feature + Variable(), // permute_output_dim_0_1 + }; + } +}; + +Tensor batch_index_select_dim0_gpu( + Tensor inputs, + Tensor indices, + std::vector input_num_indices, + std::vector input_rows, + std::vector input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + // From the empirical study, this value provides the best perf + constexpr int64_t ROWS_PER_WARP = 1; + const int64_t num_inputs = input_num_indices.size(); + TORCH_CHECK( + num_inputs == static_cast(input_rows.size()), + "[batch_index_select_dim0] input_rows must have the same length as " + "input_num_indices."); + TORCH_CHECK( + num_inputs == static_cast(input_columns.size()), + "[batch_index_select_dim0] input_columns must have the same length as " + "input_num_indices."); + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(inputs, indices); + + TORCH_CHECK( + reinterpret_cast(inputs.data_ptr()) % 16 == 0, + "Currently batch_index_select only supports 16-byte align input tensors"); + + const auto int_opts = torch::TensorOptions().dtype(torch::kInt64); + const auto num_cols = + torch::from_blob(input_columns.data(), {num_inputs}, int_opts); + const auto max_col = num_inputs > 0 ? num_cols.max().item() : 0; + const auto input_num_rows = + torch::from_blob(input_rows.data(), {num_inputs}, int_opts); + const auto output_num_rows = + torch::from_blob(input_num_indices.data(), {num_inputs}, int_opts); + + if (num_inputs > 0) { + TORCH_CHECK( + torch::all(torch::gt(num_cols, 0)).item(), + "[batch_index_select_dim0] All input_columns must be the same."); + TORCH_CHECK( + torch::all(torch::gt(input_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_rows must be the same."); + if (permute_output_dim_0_1) { + // All output rows must be the same + TORCH_CHECK(input_num_indices[0] > 0); + TORCH_CHECK( + torch::all(torch::eq(output_num_rows, input_num_indices[0])) + .item(), + "[batch_index_select_dim0] All input_num_indices must be the same if " + "permute_output_dim_0_1 is true."); + } else { + TORCH_CHECK( + torch::all(torch::gt(output_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_num_indices must be greater than zero."); + } + } + + const auto max_output_num_rows = + num_inputs > 0 ? output_num_rows.max().item() : 0; + + const auto input_numels = input_num_rows * num_cols; + const auto output_numels = + permute_output_dim_0_1 ? 
Tensor() : (output_num_rows * num_cols); + + // Takes ~1.2 ms for num_inputs = 1024 on CPU + auto D_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(num_cols).to(torch::kInt32); + auto input_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_numels); + auto input_row_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_num_rows); + auto total_L_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_num_rows); + int64_t total_hash_size_bits = + std::log2(static_cast(input_row_offsets[-1].item())) + 1; + input_offsets = torch::narrow(input_offsets, 0, 0, input_offsets.numel() - 1); + + const int64_t num_warps_per_input = + (max_output_num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + + // Transfer helper tensors to GPU + const auto device = inputs.device(); + constexpr bool non_blocking = true; + D_offsets = D_offsets.to(device, non_blocking); + input_offsets = input_offsets.to(device, non_blocking); + input_row_offsets = input_row_offsets.to(device, non_blocking); + total_L_offsets = total_L_offsets.to(device, non_blocking); + + int64_t output_size; + Tensor output_offsets; + if (permute_output_dim_0_1) { + // output_offsets is not required because the output tensor is not jagged + output_offsets = at::empty({0}, inputs.options().dtype(at::kLong)); + output_size = num_inputs > 0 + ? (input_num_indices[0] * D_offsets[-1].item()) + : 0; + } else { + output_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_numels); + output_size = output_offsets[-1].item(); + output_offsets = output_offsets.to(device, non_blocking); + } + + const auto sparse_type = fbgemm_gpu::getSparseType(inputs.scalar_type()); + TORCH_CHECK( + sparse_type == SparseType::FP32 || sparse_type == SparseType::FP16, + "batch_index_select_dim0 supports only either float or half") + + // Call TBE + return BatchIndexSelectDim0GPUOp::apply( + static_cast(fbgemm_gpu::getSparseType(inputs.scalar_type())), + inputs, + input_offsets, + input_row_offsets, + total_hash_size_bits, + indices, + D_offsets, + max_col, + output_offsets, + total_L_offsets, + output_size, + ROWS_PER_WARP, // fixed_L_per_warp + num_warps_per_input, + permute_output_dim_0_1)[0]; +} + +// Deprecated for fb namespace! Please use fbgemm namespace instead! 
+TORCH_LIBRARY_FRAGMENT(fb, m) { + DISPATCH_TO_CUDA("batch_index_select_dim0", batch_index_select_dim0_gpu); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + DISPATCH_TO_CUDA("batch_index_select_dim0", batch_index_select_dim0_gpu); +} diff --git a/fbgemm_gpu/codegen/embedding_backward_code_generator.py b/fbgemm_gpu/codegen/embedding_backward_code_generator.py index 8bab9d50f2..4256a6e02d 100644 --- a/fbgemm_gpu/codegen/embedding_backward_code_generator.py +++ b/fbgemm_gpu/codegen/embedding_backward_code_generator.py @@ -41,7 +41,11 @@ def generate_backward_embedding_cuda( write( filename, template.render( - weighted=weighted, nobag=nobag, vbe=vbe, **kwargs + weighted=weighted, + nobag=nobag, + vbe=vbe, + is_index_select=False, + **kwargs, ), ) print(f"[Backward Split] [{optimizer}]: {filename}") @@ -214,7 +218,11 @@ def generate_forward_embedding_cuda( write( filename, template.render( - dense=dense, weighted=weighted, nobag=nobag, vbe=vbe + dense=dense, + weighted=weighted, + nobag=nobag, + vbe=vbe, + is_index_select=False, ), ) print(f"[Forward Split]: {filename}") @@ -255,10 +263,57 @@ def forward_split() -> None: for dense in [True, False]: wdesc = f"{ 'dense' if dense else 'split' }" filename = f"gen_embedding_forward_{wdesc}_unweighted_nobag_kernel_small.cu" - write(filename, template.render(dense=dense)) + write(filename, template.render(dense=dense, is_index_select=False)) print(f"[Forward Split]: {filename}") +# TODO: Separate this function into another codegen script +def index_select() -> None: + kwargs = make_args([(FLOAT, "unused")]) + kwargs["args"] = kwargs["cuda"] + for templ_file, gen_file in [ + ( + "embedding_forward_split_template.cu", + "gen_batch_index_select_dim0_forward_codegen_cuda.cu", + ), + ( + "embedding_forward_split_kernel_template.cu", + "gen_batch_index_select_dim0_forward_kernel.cu", + ), + ( + "embedding_forward_split_kernel_nobag_small_template.cu", + "gen_batch_index_select_dim0_forward_kernel_small.cu", + ), + ( + "embedding_backward_split_template.cu", + "gen_batch_index_select_dim0_backward_codegen_cuda.cu", + ), + ( + "embedding_backward_split_kernel_cta_template.cu", + "gen_batch_index_select_dim0_backward_kernel_cta.cu", + ), + ( + "embedding_backward_split_kernel_warp_template.cu", + "gen_batch_index_select_dim0_backward_kernel_warp.cu", + ), + ]: + template = env.get_template(templ_file) + write( + gen_file, + template.render( + weighted=False, + dense=True, + vbe=False, + nobag=True, + is_index_select=True, + **kwargs, + ), + ) + + template = env.get_template("embedding_backward_split_grad_template.cu") + write("gen_embedding_backward_split_grad.cu", template.render()) + + def forward_quantized() -> None: @dataclass class template_instance_params: @@ -463,6 +518,8 @@ def emb_codegen( generate(**(approx_sgd())) generate(**(none_optimizer())) + # Generate index_select ops using TBE backend + index_select() gen__init__py() diff --git a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu index ff5658b986..5367f292b7 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_grad_template.cu @@ -68,7 +68,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel( const pta::PackedTensorAccessor32 D_offsets, const pta::PackedTensorAccessor32 offsets, {% if vbe %} - const pta::PackedTensorAccessor32 grad_offsets, + const pta::PackedTensorAccessor32 row_grad_offsets, const 
pta::PackedTensorAccessor32 b_t_map, const int32_t info_B_num_bits, const uint32_t info_B_mask @@ -102,7 +102,7 @@ __global__ __launch_bounds__(kMaxThreads) void grad_mean{{ vbe_desc }}_kernel( int32_t L = indices_end - indices_start; {% if vbe %} - const auto grad_offset = grad_offsets[b_t]; + const auto grad_offset = row_grad_offsets[b_t]; const auto grad_outer_offset = 0; {% else %} const auto grad_offset = D_start; @@ -141,7 +141,7 @@ void grad_mean{{ vbe_desc }}_kernel const pta::PackedTensorAccessor32 D_offsets, const pta::PackedTensorAccessor32 offsets, {% if vbe %} - const pta::PackedTensorAccessor32 grad_offsets, + const pta::PackedTensorAccessor32 row_grad_offsets, const pta::PackedTensorAccessor32 b_t_map, const int32_t info_B_num_bits, const uint32_t info_B_mask diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp index ee6333475d..b60b380285 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp @@ -285,7 +285,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio lxu_cache_locations, {% if vbe %} vbe_metadata.B_offsets, - vbe_metadata.output_offsets, + vbe_metadata.row_output_offsets, vbe_metadata.b_t_map, {% endif %} {{ args.split_saved_tensors | join(", ") }} @@ -410,7 +410,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio auto lxu_cache_locations = *savedItr++; {% if vbe %} auto B_offsets = *savedItr++; - auto vbe_output_offsets = *savedItr++; + auto vbe_row_output_offsets = *savedItr++; auto vbe_b_t_map = *savedItr++; {% endif %} @@ -470,7 +470,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio {% if vbe %} struct VBEMetadata vbe_metadata = { .B_offsets = B_offsets, - .output_offsets = vbe_output_offsets, + .row_output_offsets = vbe_row_output_offsets, .b_t_map = vbe_b_t_map, }; {% endif %} diff --git a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu index 04db67ad76..1e147ad9a8 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu @@ -43,7 +43,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void pta::PackedTensorAccessor32 feature_requires_grad, // [T], pta::PackedTensorAccessor32, 1, at::RestrictPtrTraits> grad_indice_weights, {%- if vbe %} - const pta::PackedTensorAccessor32 grad_offsets, + const pta::PackedTensorAccessor32 row_grad_offsets, const pta::PackedTensorAccessor32 b_t_map, const int32_t info_B_num_bits, const uint32_t info_B_mask @@ -100,7 +100,7 @@ __global__ __launch_bounds__(kForwardMaxThreads) void {%- endif %} {%- if vbe %} - const grad_t* grad_output_ = &grad_output[0][grad_offsets[b_t]]; + const grad_t* grad_output_ = &grad_output[0][row_grad_offsets[b_t]]; {%- else %} const grad_t* grad_output_ = &grad_output[b][D_start]; {%- endif %} @@ -229,7 +229,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights lxu_cache_locations, {%- endif %} {%- if vbe %} - vbe_metadata.output_offsets, + vbe_metadata.row_output_offsets, vbe_metadata.b_t_map, {%- endif %} grad_output @@ -300,7 +300,7 @@ Tensor {{ "dense" if dense else "split" }}_embedding_codegen_grad_indice_weights MAKE_PTA_WITH_NAME(func_name, feature_requires_grad, int32_t, 1, 32), 
MAKE_PTA_ACC_WITH_NAME(func_name, grad_indice_weights, grad_t, 1, 32), {%- if vbe %} - MAKE_PTA_WITH_NAME(func_name, vbe_metadata.output_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, vbe_metadata.row_output_offsets, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, vbe_metadata.b_t_map, int32_t, 1, 32), info_B_num_bits, info_B_mask diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu index 546e1a95b7..ce0a5db17c 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_cta_template.cu @@ -30,8 +30,12 @@ template < size_t kMaxVecsPerThread, int32_t kThreadGroupSize > __global__ __launch_bounds__(kMaxThreads) void +{%- if is_index_select %} +batch_index_select_dim0_codegen_backward_kernel_cta_per_row( +{%- else %} split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1( - const pta::PackedTensorAccessor64 grad_output, +{%- endif %} + const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -41,7 +45,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} {%- endif %} // if optimizer != "none" const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, @@ -73,7 +77,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} // if not dense and optimizer != "none" {%- if not nobag and vbe %} const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, {%- endif %} {%- if not nobag %} const int32_t info_B_num_bits, @@ -84,7 +88,13 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ pta::PackedTensorAccessor32 grad_accum_counter, const int32_t max_segment_length_per_cta, const bool use_deterministic_algorithms, - {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}) { + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = ((1L << kThreadGroupSize) - 1) << @@ -136,8 +146,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} int64_t hash_size = hash_size_cumsum[t_0]; - {%- if not nobag %} - int32_t D = D_offsets[t_0 + 1] - D_offsets[t_0]; + {%- if not nobag or is_index_select %} + const int32_t D_start_t0 = D_offsets[t_0]; + // D can be hoisted here because D is the same if features share the + // same table, but D_start is different + const int32_t D = D_offsets[t_0 + 1] - D_start_t0; + {%- if is_index_select %} + // grad_offset can be hoisted here for batch_index_select because it + // does not allow multiple features to share a single embedding table + const auto grad_offset = permute_output_dim_0_1 ? D_start_t0 : grad_offsets[t_0]; + const auto grad_stride = permute_output_dim_0_1 ? 
D_offsets[T] : D; + {%- endif %} {%- endif %} int64_t idx = linear_index - hash_size; @@ -152,7 +171,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ const auto b = b_t & info_B_mask; const auto t = b_t >> info_B_num_bits; {%- if vbe %} - const auto grad_offset = output_offsets[B_offsets[t] + b]; + const auto grad_offset = row_output_offsets[B_offsets[t] + b]; {%- else %} // if vbe int32_t D_start = sl_j < sl_end ? D_offsets[t] : 0; {%- endif %} // if vbe @@ -183,13 +202,16 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ ++i) { int32_t d = (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH; Vec4T> grad_out_vec( - {%- if nobag %} + {%- if nobag and is_index_select %} + // grad_output is 1d + &grad_output[grad_offset + l_j * grad_stride + d] + {%- elif nobag %} &grad_output[l_j][d] {%- elif vbe %} &grad_output[0][grad_offset_j + d] {%- else %} &grad_output[b_j][0] + D_start_j + d - {%- endif %} + {%- endif %} // if nobag ); {%- if weighted %} @@ -399,15 +421,19 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ */ {%- macro template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) %} -template __global__ __launch_bounds__(kMaxThreads) -void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1 +template __global__ __launch_bounds__(kMaxThreads) void +{%- if is_index_select %} +batch_index_select_dim0_codegen_backward_kernel_cta_per_row +{%- else %} +split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1 +{%- endif %} < {{ emb_type }}, {{ grad_type }}, {{ cache_type }}, {{ kMaxVecsPerThread }}, {{ kThreadGroupSize }} > ( - const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output, + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, {%- if optimizer != "none" %} pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, {%- if not dense %} @@ -418,7 +444,7 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize {%- endif %} {%- endif %} // if optimizer != "none" const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, @@ -451,7 +477,7 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize {%- endif %} // if not dense and optimizer != "none" {%- if not nobag and vbe %} const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, {%- endif %} {%- if not nobag %} const int32_t info_B_num_bits, @@ -462,7 +488,13 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize pta::PackedTensorAccessor32 grad_accum_counter, const int32_t max_segment_length_per_cta, const bool use_deterministic_algorithms, - {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}); + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }} + {%- endif %} +); {%- endmacro %} {%- 
macro bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) %} diff --git a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu index b6d0908528..ad81b31b1b 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_kernel_warp_template.cu @@ -30,8 +30,12 @@ template < size_t kMaxVecsPerThread, int32_t kThreadGroupSize > __global__ __launch_bounds__(kBackwardMaxThreads) void +{%- if is_index_select %} +batch_index_select_dim0_codegen_backward_kernel_warp_per_row( +{%- else %} split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1( - const pta::PackedTensorAccessor64 grad_output, +{%- endif %} + const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -41,7 +45,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, @@ -73,14 +77,19 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} // if not dense and optimizer != "none" {%- if not nobag and vbe %} const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, {%- endif %} {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, {%- endif %} - {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}) { - + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +) { {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; {%- else %} @@ -123,8 +132,17 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} int64_t hash_size = hash_size_cumsum[t_0]; - {%- if not nobag %} - int32_t D = D_offsets[t_0 + 1] - D_offsets[t_0]; + {%- if not nobag or is_index_select %} + const auto D_start_t0 = D_offsets[t_0]; + // D can be hoisted here because D is the same if features share the + // same table, but D_start is different + const int32_t D = D_offsets[t_0 + 1] - D_start_t0; + {%- if is_index_select %} + // grad_offset can be hoisted here for batch_index_select because it + // does not allow multiple features to share a single embedding table + const auto grad_offset = permute_output_dim_0_1 ? D_start_t0 : grad_offsets[t_0]; + const auto grad_stride = permute_output_dim_0_1 ? D_offsets[T] : D; + {%- endif %} {%- endif %} int64_t idx = linear_index - hash_size; @@ -139,8 +157,8 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ const auto b = b_t & info_B_mask; const auto t = b_t >> info_B_num_bits; {%- if vbe %} - const auto grad_offset = output_offsets[B_offsets[t] + b]; - {%- else %} // if vbe + const auto grad_offset = row_output_offsets[B_offsets[t] + b]; + {% else %} int32_t D_start = sl_j < sl_end ? 
D_offsets[t] : 0; {%- endif %} // if vbe {%- else %} // if not nobag @@ -171,7 +189,10 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ ++i) { int32_t d = (i * kThreadGroupSize + threadIdx.x) * VEC_WIDTH; Vec4T> grad_out_vec( - {%- if nobag %} + {%- if nobag and is_index_select %} + // grad_output is 1d + &grad_output[grad_offset + l_j * grad_stride + d] + {%- elif nobag %} &grad_output[l_j][d] {%- elif vbe %} &grad_output[0][grad_offset_j + d] @@ -243,15 +264,19 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ */ {%- macro template_instantiation(emb_type, grad_type, cache_type, kMaxVecsPerThread, kThreadGroupSize) %} -template __global__ __launch_bounds__(kBackwardMaxThreads) -void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1 +template __global__ __launch_bounds__(kBackwardMaxThreads) void +{%- if is_index_select %} +batch_index_select_dim0_codegen_backward_kernel_warp_per_row +{%- else %} +split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1 +{%- endif %} < {{ emb_type }}, {{ grad_type }}, {{ cache_type }}, {{ kMaxVecsPerThread }}, {{ kThreadGroupSize }} > ( - const pta::PackedTensorAccessor64<{{ grad_type }}, 2, at::RestrictPtrTraits> grad_output, + const pta::PackedTensorAccessor64<{{ grad_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> grad_output, {%- if optimizer != "none" %} pta::PackedTensorAccessor64<{{ emb_type }}, 1, at::RestrictPtrTraits> dev_weights, {%- if not dense %} @@ -261,7 +286,7 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize {%- endif %} {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, @@ -293,13 +318,19 @@ void split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimize {%- endif %} // if not dense and optimizer != "none" {%- if not nobag and vbe %} const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, {%- endif %} {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, {%- endif %} - {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }}); + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args_no_defaults | replace_pta_namespace() | join(",\n ") | replace("cache_t", cache_type) }} + {%- endif %} +); {%- endmacro %} {%- macro bulk_template_instantiations(kMaxVecsPerThread, kThreadGroupSize) %} diff --git a/fbgemm_gpu/codegen/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/embedding_backward_split_template.cu index 7345be06ee..c75d8809ca 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_backward_split_template.cu @@ -8,7 +8,8 @@ // clang-format off {%- set wdesc = "weighted" if weighted else "unweighted" %} -{%- set vbe_desc = "_vbe" if vbe else "" %} +{%- set vdesc = "_vbe" if vbe else "" %} +{%- set ndesc = "_nobag" if nobag else "" %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/fbgemm_tensor_accessor.h" #include 
"fbgemm_gpu/split_embeddings_utils.cuh" @@ -27,8 +28,12 @@ template < size_t kMaxVecsPerThread, int32_t kThreadGroupSize> __global__ __launch_bounds__(kMaxThreads) void -split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1( - const pta::PackedTensorAccessor64 grad_output, +{% if is_index_select %} +batch_index_select_dim0_codegen_backward_kernel_cta_per_row( +{% else %} +split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_cta_per_row_1( +{% endif %} + const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -38,7 +43,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} {%- endif %} // if optimizer != "none" const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, @@ -70,7 +75,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} // if not dense and optimizer != "none" {%- if vbe %} const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, {%- endif %} {%- if not nobag %} const int32_t info_B_num_bits, @@ -81,7 +86,14 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ pta::PackedTensorAccessor32 grad_accum_counter, const int32_t max_segment_length_per_cta, const bool use_deterministic_algorithms, - {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}); + {%- if is_index_select %} + const at::PackedTensorAccessor32 grad_offsets, + const bool permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {%- endif %} +); + template < typename emb_t, @@ -90,8 +102,12 @@ template < size_t kMaxVecsPerThread, int32_t kThreadGroupSize = kWarpSize> __global__ __launch_bounds__(kBackwardMaxThreads) void -split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1( - const pta::PackedTensorAccessor64 grad_output, +{% if is_index_select %} +batch_index_select_dim0_codegen_backward_kernel_warp_per_row( +{% else %} +split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_kernel_warp_per_row_1( +{% endif %} + const pta::PackedTensorAccessor64 grad_output, {%- if optimizer != "none" %} pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} @@ -101,7 +117,7 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, @@ -133,13 +149,19 @@ split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_ {%- endif %} // if not dense and optimizer != "none" {%- if vbe %} const pta::PackedTensorAccessor32 B_offsets, - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, {%- endif %} {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}); + {% endif %} + {% if is_index_select %} + const at::PackedTensorAccessor32 
grad_offsets, + const bool permute_output_dim_0_1 + {% else %} + {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} + {% endif %} +); __global__ __launch_bounds__(kMaxThreads) void split_embedding_backward_codegen_find_long_segments( @@ -157,7 +179,7 @@ split_embedding_backward_codegen_find_long_segments( template __global__ __launch_bounds__(kMaxThreads) void -grad_mean{{ vbe_desc }}_kernel( +grad_mean{{ vdesc }}_kernel( pta::PackedTensorAccessor64 grad_output_mean, const pta::PackedTensorAccessor64 grad_output, const pta::PackedTensorAccessor32 D_offsets, @@ -247,13 +269,17 @@ grad_mean{{ vbe_desc }}_kernel( //////////////////////////////////////////////////////////////////////////////// {%- set func_name0 = "split_embedding{}_backward_codegen_{}_{}_exact{}_cuda".format( - "_nobag" if nobag else "", + ndesc, optimizer, wdesc, - vbe_desc) + vdesc) %} -Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vbe_desc }}_cuda( +{% if is_index_select %} +Tensor batch_index_select_dim0_codegen_backward_cuda( +{% else %} +Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_exact{{ vdesc }}_cuda( +{% endif %} Tensor grad_output, Tensor dev_weights, {%- if not dense %} @@ -262,7 +288,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi Tensor weights_placements, {%- endif %} Tensor weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} Tensor D_offsets, int64_t max_D, {%- else %} @@ -271,7 +297,9 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi Tensor hash_size_cumsum, int64_t total_hash_size_bits, Tensor indices, + {% if not is_index_select %} Tensor offsets, + {%- endif %} {%- if not nobag %} int64_t pooling_mode, {%- endif %} @@ -281,7 +309,9 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- if not dense %} Tensor lxu_cache_locations, {%- endif %} + {%- if not is_index_select %} int64_t unused_, + {% endif %} int64_t max_segment_length_per_warp, {%- if not dense %} {%- if optimizer != "none" %} @@ -293,7 +323,13 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- if vbe %} const VBEMetadata& vbe_metadata, {%- endif %} - {%- if optimizer != "none" %} + {% if is_index_select %} + const Tensor& grad_offsets, + const Tensor& total_L_offsets, + const int32_t fixed_L_per_warp, + const int32_t num_warps_per_feature, + const bool permute_output_dim_0_1 + {%- elif optimizer != "none" %} {{ args.split_function_args | join(", ") }} {%- else %} // This is acutally passed via args.split_function_args but explicitly list @@ -301,7 +337,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi int64_t total_hash_size, int64_t total_unique_indices {%- endif %} - ) { +) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( {%- if optimizer != "none" %} @@ -314,16 +350,18 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- endif %} {%- if vbe %} vbe_metadata.B_offsets, - vbe_metadata.output_offsets, + vbe_metadata.row_output_offsets, vbe_metadata.b_t_map, {%- endif %} weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} D_offsets, {%- endif %} hash_size_cumsum, indices, + {%- if not is_index_select %} offsets, + {%- endif %} {%- if weighted %} indice_weights, {%- endif %} @@ -332,13 +370,24 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi 
{%- endif %} grad_output); + {%- if is_index_select %} + if (!permute_output_dim_0_1) { + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( + grad_offsets, + dev_weights + ); + } + {%- endif %} + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dev_weights.get_device()); - {%- if nobag %} + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} + {% if not is_index_select %} TORCH_CHECK(max_D <= {{ max_embedding_dim }}); + {%- endif %} {%- if optimizer == "none" %} // grad_dev_weights has emb_t type @@ -375,14 +424,18 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi TORCH_CHECK(T > 0); // offsets = [B x T + 1] + {% if is_index_select %} + const auto total_B = num_warps_per_feature * T; + {% else %} const auto total_B = offsets.size(0) - 1; + {% endif %} TORCH_CHECK(total_B > 0); auto BT_block_size = kMaxThreads / kWarpSize; TORCH_CHECK(BT_block_size * kWarpSize <= kMaxThreads); {%- if vbe %} TORCH_CHECK(vbe_metadata.B_offsets.numel() == T + 1); - TORCH_CHECK(vbe_metadata.output_offsets.numel() == total_B); + TORCH_CHECK(vbe_metadata.row_output_offsets.numel() == total_B); TORCH_CHECK(vbe_metadata.b_t_map.numel() == total_B); {%- endif %} @@ -429,12 +482,21 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi hash_size_cumsum, total_hash_size_bits, indices, - offsets, + {{ "offsets" if not is_index_select else "Tensor()" }}, {{ "true" if nobag else "false" }}, {{ "c10::optional(vbe_metadata.b_t_map)" if vbe else "c10::optional()" }}, info_B_num_bits, info_B_mask, - total_unique_indices); + total_unique_indices, + {% if is_index_select %} + true, // is_index_select + c10::optional(total_L_offsets), + fixed_L_per_warp, + num_warps_per_feature + {% else %} + false // is_index_select + {% endif %} + ); {%- if not dense %} auto lxu_cache_locations_sorted = at::empty_like(lxu_cache_locations); @@ -522,7 +584,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi grad_output = grad_output.reshape({1, -1}); {%- endif %} - auto grad_output_accessor = MAKE_PTA_WITH_NAME("{{ func_name0 }}.1", grad_output, grad_t, 2, 64); + auto grad_output_accessor = MAKE_PTA_WITH_NAME("{{ func_name0 }}.1", grad_output, grad_t, {{ "1" if is_index_select else "2" }}, 64); {%- if not nobag %} Tensor grad_output_mean; @@ -531,10 +593,10 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- if not dense or not vbe %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name1 = "grad_mean{{ vbe_desc }}_kernel"; + const auto func_name1 = "grad_mean{{ vdesc }}_kernel"; #endif - grad_mean{{ vbe_desc }}_kernel<<< + grad_mean{{ vdesc }}_kernel<<< div_round_up(total_B, kMaxThreads / kWarpSize), dim3(kWarpSize, kMaxThreads / kWarpSize), 0, @@ -545,7 +607,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi MAKE_PTA_WITH_NAME(func_name1, D_offsets, int32_t, 1, 32), MAKE_PTA_WITH_NAME(func_name1, offsets, int64_t, 1, 32), {%- if vbe %} - MAKE_PTA_WITH_NAME(func_name1, vbe_metadata.output_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name1, vbe_metadata.row_output_offsets, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name1, vbe_metadata.b_t_map, int32_t, 1, 32), info_B_num_bits, info_B_mask @@ -644,30 +706,39 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi // must use dynamic shared memory (rather than statically sized // arrays) and require an explicit opt-in using cudaFuncSetAttribute()". 
+ {%- set cta_kernel = + "batch_index_select_dim0_codegen_backward_kernel_cta_per_row" + if is_index_select else + "split_embedding{}_backward_codegen_{}_{}{}_kernel_cta_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} + + const auto backward_cta_per_row_kernel = + {{ cta_kernel }} + ; + #ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1< - emb_t, - grad_t, - cache_t, - kMaxVecsPerThread, - kThreadGroupSize>, + backward_cta_per_row_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB -#endif C10_CUDA_KERNEL_LAUNCH_CHECK(); +#endif #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name3 = "split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1"; + const auto func_name3 = "{{ cta_kernel }}"; #endif // dividing by kMaxThreads is a heuristic to avoid num of blocks far exceeding num_long_run_ids[0] - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_cta_per_row_1< - emb_t, - grad_t, - cache_t, - kMaxVecsPerThread, - kThreadGroupSize> + backward_cta_per_row_kernel <<) * 4 * kWarpSize * @@ -685,7 +756,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- endif %} {%- endif %} // if optimizer != "none" MAKE_PTA_WITH_NAME(func_name3, weights_offsets, int64_t, 1, 32), - {%- if not nobag %} + {%- if not nobag or is_index_select %} MAKE_PTA_WITH_NAME(func_name3, D_offsets, int32_t, 1, 32), {%- else %} D, @@ -717,7 +788,7 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi {%- endif %} // if not dense and optimizer != "none" {%- if vbe %} MAKE_PTA_WITH_NAME(func_name3, vbe_metadata.B_offsets, int32_t, 1, 32), - MAKE_PTA_WITH_NAME(func_name3, vbe_metadata.output_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name3, vbe_metadata.row_output_offsets, int64_t, 1, 32), {%- endif %} {%- if not nobag %} info_B_num_bits, @@ -728,13 +799,38 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi MAKE_PTA_WITH_NAME(func_name3, grad_accum_counter, int32_t, 1, 32), max_segment_length_per_cta, use_deterministic_algorithms, - {{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name3") | join(",\n ") }}); + {%- if is_index_select %} + grad_offsets.packed_accessor32(), + permute_output_dim_0_1 + {%- else %} + {{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name3") | join(",\n ") }} + {%- endif %} + ); C10_CUDA_KERNEL_LAUNCH_CHECK(); grid_size = std::min( div_round_up(total_unique_indices, kBackwardMaxThreads / kThreadGroupSize), get_max_thread_blocks_()); + {%- set warp_kernel = + "batch_index_select_dim0_codegen_backward_kernel_warp_per_row" + if is_index_select else + "split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( + ndesc, + optimizer, + wdesc, + vdesc, + ) + %} + + const auto backward_warp_per_row_kernel = + {{ warp_kernel }} + ; + // Shared memory is not needed for non uint8_t weights size_t shmem_bytes = 0; if (std::is_same::value) { @@ -742,28 +838,18 @@ Tensor split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimi at::acc_type) * 4 * kWarpSize * kMaxVecsPerThread; #ifndef __HIP_PLATFORM_HCC__ cudaFuncSetAttribute( - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ 
wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1< - emb_t, - grad_t, - cache_t, - kMaxVecsPerThread, - kThreadGroupSize>, + backward_warp_per_row_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB + C10_CUDA_KERNEL_LAUNCH_CHECK(); #endif } #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name4 = "split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1"; + const auto func_name4 = "{{ warp_kernel }}"; #endif - C10_CUDA_KERNEL_LAUNCH_CHECK(); - split_embedding{{ "_nobag" if nobag else "" }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vbe_desc }}_kernel_warp_per_row_1< - emb_t, - grad_t, - cache_t, - kMaxVecsPerThread, - kThreadGroupSize> + backward_warp_per_row_kernel <<(), + permute_output_dim_0_1 + {% else %} + {{ args.split_kernel_arg_constructors | make_pta_acc_format("func_name4") | join(",\n ") }} + {% endif %} + ); C10_CUDA_KERNEL_LAUNCH_CHECK(); return; }); diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu index 6c2a4d84c2..d7164f02e4 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_nobag_small_template.cu @@ -28,7 +28,11 @@ template < size_t kThreadGroupSize > __launch_bounds__(kForwardMaxThreads) __global__ void +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_small_kernel( +{%- else %} {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel( +{%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} const pta::PackedTensorAccessor64 uvm_weights, @@ -36,29 +40,72 @@ __launch_bounds__(kForwardMaxThreads) __global__ void const pta::PackedTensorAccessor32 weights_placements, {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} int64_t D, + {%- endif %} FixedDivisor fd_B, const pta::PackedTensorAccessor32 indices, + {%- if not is_index_select %} const pta::PackedTensorAccessor32 offsets, + {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, {%- endif %} - pta::PackedTensorAccessor64 output // [B][total_D], + {%- if is_index_select %} + const at::PackedTensorAccessor32 output_offsets, + const at::PackedTensorAccessor32 total_L_offsets, + const int32_t fixed_L_per_warp, + const bool permute_output_dim_0_1, + {%- endif %} // if dense + // If 2D, shape is [B][total_D] + pta::PackedTensorAccessor64 output ) { int32_t T = weights_offsets.size(0); int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; + {%- if not is_index_select %} if (b_t >= offsets.size(0) - 1) { return; } + {%- endif %} int32_t t; int32_t b; fd_B.DivMod(b_t, &t, &b); - int64_t weights_offset = weights_offsets[t]; + {%- if is_index_select %} + index_t indices_start; + int32_t L; + int32_t L_start; + if (t >= T) { + return; + } + const auto total_L_start = total_L_offsets[t]; + const auto total_L = total_L_offsets[t + 1] - total_L_start; + L_start = b * fixed_L_per_warp; + if (L_start >= total_L) { + return; + } + indices_start = total_L_start + L_start; + L = (total_L - L_start >= fixed_L_per_warp) ? 
fixed_L_per_warp : (total_L - L_start); + {%- else %} index_t indices_start = offsets[b_t]; - index_t indices_end = offsets[b_t + 1]; - int32_t L = indices_end - indices_start; + int32_t L = offsets[b_t + 1] - indices_start; + {%- endif %} + + {%- if is_index_select %} + const int32_t D_start = D_offsets[t]; + const int32_t D_end = D_offsets[t + 1]; + const int32_t D = D_end - D_start; + + // Check D in the kernel to avoid iterating through the list on host + CUDA_KERNEL_ASSERT(D % 4 == 0 && "The column size must be multiple of 4"); + const auto output_offset = permute_output_dim_0_1 ? D_start : output_offsets[t]; + const auto output_stride = permute_output_dim_0_1 ? D_offsets[T] : D; + {%- endif %} // dense + + int64_t weights_offset = weights_offsets[t]; const emb_t* __restrict__ weights; {%- if not dense %} const auto placement = static_cast(weights_placements[t]); @@ -88,7 +135,11 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- endif %} for (auto j = group_start; j < group_end && l_start + j < L; ++j) { int64_t idx_j = shfl_sync(idx, j); + {%- if is_index_select %} + int64_t output_j = L_start + l_start + j; + {%- else %} int64_t output_j = indices_start + l_start + j; + {%- endif %} {%- if not dense %} int32_t cache_idx_j = shfl_sync(cache_idx, j); {%- endif %} @@ -125,7 +176,13 @@ __launch_bounds__(kForwardMaxThreads) __global__ void } {%- else %} Vec4T weight = weight_row_emb.load(d, qparams_emb); + {%- if is_index_select %} + // output is 1D (because the stride can be irregular) + weight.store(&output[output_offset + output_j * output_stride + d]); + {%- else %} + // output is 2D weight.store(&output[output_j][d]); + {%- endif %} {%- endif %} } } @@ -144,8 +201,12 @@ __launch_bounds__(kForwardMaxThreads) __global__ void {%- for kEmbeddingSize in [4, 8, 16, 32] %} {%- set index_type = 'int64_t' %} -template __launch_bounds__(kForwardMaxThreads) __global__ -void {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel +template __launch_bounds__(kForwardMaxThreads) __global__ void +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_small_kernel +{%- else %} +{{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweighted_small_kernel +{%- endif %} < {{ emb_type }}, {{ cache_type }}, @@ -160,14 +221,26 @@ void {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forward_unweigh const pta::PackedTensorAccessor32 weights_placements, {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} int64_t D, + {%- endif %} FixedDivisor fd_B, const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> indices, + {%- if not is_index_select %} const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> offsets, + {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, {%- endif %} - pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output); + {%- if is_index_select %} + const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 total_L_offsets, + const int32_t fixed_L_per_warp, + const bool permute_output_dim_0_1, + {%- endif %} + pta::PackedTensorAccessor64<{{ output_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> output); {%- endfor %} {%- endfor %} diff --git a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu 
b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu index f36b7f6ae8..7f93a786ed 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu @@ -97,7 +97,13 @@ using namespace fbgemm_gpu; if (d < D) { // Since there is no pooling, simply copy the weights to output const auto weights_slice = weights_row.load(d, qparams); + {%- if is_index_select %} + // output is 1D (because the stride can be irregular) + weights_slice.store(&output[output_offset + output_j * output_stride + d]); + {%- else %} + // output is 2D weights_slice.store(&output[output_j][d]); + {%- endif %} } } {%- endif %} @@ -116,8 +122,12 @@ template < size_t kMaxVecsPerThread, {%- endif %} size_t kThreadGroupSize > -__launch_bounds__(kForwardMaxThreads) __global__ -void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel( +__launch_bounds__(kForwardMaxThreads) __global__ void +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_kernel( +{%- else %} +{{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel( +{%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} const pta::PackedTensorAccessor64 uvm_weights, @@ -125,13 +135,13 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } const pta::PackedTensorAccessor32 weights_placements, {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, - {%- endif %} + {%- endif %} // if nobag {%- if vbe %} - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, const pta::PackedTensorAccessor32 b_t_map, const int32_t info_B_num_bits, const uint32_t info_B_mask, @@ -139,7 +149,9 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } FixedDivisor fd_B, {%- endif %} const pta::PackedTensorAccessor32 indices, + {%- if not is_index_select %} const pta::PackedTensorAccessor32 offsets, + {%- endif %} {%- if not nobag %} int64_t pooling_mode, {%- endif %} @@ -149,7 +161,14 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, {%- endif %} - pta::PackedTensorAccessor64 output // [B][total_D] + {%- if is_index_select %} + const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 total_L_offsets, + const int32_t fixed_L_per_warp, + const bool permute_output_dim_0_1, + {%- endif %} + // If 2D, shape is [B][total_D] + pta::PackedTensorAccessor64 output ) { // shfl_sync_mask is implicitly used by SHFL_SYNC @@ -166,9 +185,11 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } // Determine the linearized warp ID, and exit early if needed int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; + {%- if not is_index_select %} if (b_t >= offsets.size(0) - 1) { return; } + {%- endif %} // Determine the Table and Training Example IDs int32_t t; // Table ID @@ -181,10 +202,48 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } fd_B.DivMod(b_t, &t, &b); {%- endif %} + // Get total number of tables + int32_t T = weights_offsets.size(0); + + {%- if is_index_select %} + index_t indices_start; + int32_t L; + 
int32_t L_start; + if (t >= T) { + return; + } + const auto total_L_start = total_L_offsets[t]; + const auto total_L = total_L_offsets[t + 1] - total_L_start; + L_start = b * fixed_L_per_warp; + if (L_start >= total_L) { + return; + } + indices_start = total_L_start + L_start; + L = (total_L - L_start >= fixed_L_per_warp) ? fixed_L_per_warp : (total_L - L_start); + {%- else %} + // Determine the number of indices (pooling factor) to look up within the bag + index_t indices_start = offsets[b_t]; + int32_t L = offsets[b_t + 1] - indices_start; + {%- endif %} + + // Get the offsets of the embedding dimensions of the tables and determine D + {%- if not nobag or is_index_select %} + const auto D_start = D_offsets[t]; + const auto D_end = D_offsets[t + 1]; + const auto D = D_end - D_start; + {%- endif %} + + {%- if is_index_select %} + // Check D in the kernel to avoid iterating through the list on host + CUDA_KERNEL_ASSERT(D % 4 == 0 && "The column size must be multiple of 4"); + const auto output_offset = permute_output_dim_0_1 ? D_start : output_offsets[t]; + const auto output_stride = permute_output_dim_0_1 ? D_offsets[T] : D; + {%- endif %} + // From the Table ID, fetch its weight tensor offset, locate that position // in the input weights tensor, and set the weights table pointer - const emb_t* __restrict__ weights; int64_t weights_offset = weights_offsets[t]; + const emb_t* __restrict__ weights; {%- if not dense %} const auto placement = static_cast(weights_placements[t]); if (placement == PlacementType::DEVICE) { @@ -196,21 +255,6 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } weights = &dev_weights[weights_offset]; {%- endif %} - // Get total number of tables - int32_t T = weights_offsets.size(0); - - // Determine the number of indices (pooling factor) to look up within the bag - index_t indices_start = offsets[b_t]; - index_t indices_end = offsets[b_t + 1]; - int32_t L = indices_end - indices_start; - - // Get the offsets of the embedding dimensions of the tables and determine D - {%- if not nobag %} - int32_t D_start = D_offsets[t]; - int32_t D_end = D_offsets[t + 1]; - int32_t D = D_end - D_start; - {%- endif %} - // D is computed in the bag case or provided as function arg in the nobag case // (nobag only supports the case where the embedding dimensions are the same for all tables) int32_t D_emb = D; @@ -251,7 +295,9 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } // Load index from thread j in the group int64_t idx_j = SHFL_SYNC(idx, j); - {%- if nobag %} + {%- if is_index_select %} + int64_t output_j = L_start + l_start + j; + {%- elif nobag %} int64_t output_j = indices_start + l_start + j; {%- endif %} @@ -300,7 +346,7 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } // If weight type is FP32/16 if constexpr (!std::is_same_v) { {%- if vbe %} - output_t* output_ = &output[0][output_offsets[b_t]]; + output_t* output_ = &output[0][row_output_offsets[b_t]]; {%- else %} output_t* output_ = &output[b][D_start]; {%- endif %} @@ -367,8 +413,12 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } */ {%- macro template_instantiation(emb_type, cache_type, output_type, use_cache, kMaxVecsPerThread, kThreadGroupSize) %} -template __launch_bounds__(kForwardMaxThreads) __global__ -void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel +template __launch_bounds__(kForwardMaxThreads) 
__global__ void +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_kernel +{%- else %} +{{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}{{ vbe_desc }}_kernel +{%- endif %} < {{ emb_type }}, {{ cache_type }}, @@ -389,7 +439,7 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } const pta::PackedTensorAccessor32 weights_placements, {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, @@ -403,7 +453,9 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } FixedDivisor fd_B, {%- endif %} const pta::PackedTensorAccessor32 indices, + {%- if not is_index_select %} const pta::PackedTensorAccessor32 offsets, + {%- endif %} {%- if not nobag %} int64_t pooling_mode, {%- endif %} @@ -413,7 +465,13 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" } {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, {%- endif %} - pta::PackedTensorAccessor64<{{ output_type }}, 2, at::RestrictPtrTraits> output); + {%- if is_index_select %} + const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 total_L_offsets, + const int32_t fixed_L_per_warp, + const bool permute_output_dim_0_1, + {%- endif %} + pta::PackedTensorAccessor64<{{ output_type }}, {{ "1" if is_index_select else "2" }}, at::RestrictPtrTraits> output); {%- endmacro %} {%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %} diff --git a/fbgemm_gpu/codegen/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/embedding_forward_split_template.cu index 7c3324b51c..5d3f3b7a8e 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/embedding_forward_split_template.cu @@ -22,7 +22,6 @@ using Tensor = at::Tensor; using namespace fbgemm_gpu; - //////////////////////////////////////////////////////////////////////////////// // External Function Declarations //////////////////////////////////////////////////////////////////////////////// @@ -35,8 +34,12 @@ template < typename index_t, size_t kThreadGroupSize > -__launch_bounds__(kForwardMaxThreads) -__global__ void {{ ddesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( +__launch_bounds__(kForwardMaxThreads) __global__ void +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_small_kernel( +{%- else %} +{{ ddesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel( +{%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} const pta::PackedTensorAccessor64 uvm_weights, @@ -44,14 +47,26 @@ __global__ void {{ ddesc }}_embedding_nobag_codegen_forward_unweighted_small_ker const pta::PackedTensorAccessor32 weights_placements, {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, + {%- if is_index_select %} + const pta::PackedTensorAccessor32 D_offsets, + {%- else %} int64_t D, + {%- endif %} FixedDivisor fd_B, const pta::PackedTensorAccessor32 indices, + {%- if not is_index_select %} const pta::PackedTensorAccessor32 offsets, + {%- endif %} {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, {%- endif %} - pta::PackedTensorAccessor64 output // [B][total_D], + {%- if is_index_select %} + const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 total_L_offsets, + const int32_t 
fixed_L_per_warp, + const bool permute_output_dim_0_1, + {%- endif %} + pta::PackedTensorAccessor64 output ); {%- endif %} @@ -90,9 +105,9 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel( {%- for nobag in [True, False] %} -{%- if not nobag or (not weighted and not vbe) %} {%- set ndesc = "_nobag" if nobag else "" %} - +{%- if (not nobag or (not weighted and not vbe)) and (nobag or (not is_index_select)) %} +{%- set has_experimental = (not dense and not nobag and not vbe and not is_index_select) %} template < typename emb_t, typename cache_t, @@ -106,8 +121,12 @@ template < {%- endif %} size_t kThreadGroupSize = kWarpSize > -__launch_bounds__(kForwardMaxThreads) -__global__ void {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel( +__launch_bounds__(kForwardMaxThreads) __global__ void +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_kernel( +{%- else %} +{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_kernel( +{%- endif %} const pta::PackedTensorAccessor64 dev_weights, {%- if not dense %} const pta::PackedTensorAccessor64 uvm_weights, @@ -115,13 +134,13 @@ __global__ void {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ v const pta::PackedTensorAccessor32 weights_placements, {%- endif %} const pta::PackedTensorAccessor32 weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} const pta::PackedTensorAccessor32 D_offsets, {%- else %} int64_t D, {%- endif %} {%- if vbe %} - const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 row_output_offsets, const pta::PackedTensorAccessor32 b_t_map, const int32_t info_B_num_bits, const uint32_t info_B_mask, @@ -129,7 +148,9 @@ __global__ void {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ v FixedDivisor fd_B, {%- endif %} const pta::PackedTensorAccessor32 indices, + {%- if not is_index_select %} const pta::PackedTensorAccessor32 offsets, + {%- endif %} {%- if not nobag %} int64_t pooling_mode, {%- endif %} @@ -139,7 +160,13 @@ __global__ void {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ v {%- if not dense %} const pta::PackedTensorAccessor32 lxu_cache_locations, {%- endif %} - pta::PackedTensorAccessor64 output // [B][total_D] + {%- if is_index_select %} + const pta::PackedTensorAccessor32 output_offsets, + const pta::PackedTensorAccessor32 total_L_offsets, + const int32_t fixed_L_per_warp, + const bool permute_output_dim_0_1, + {%- endif %} // if dense + pta::PackedTensorAccessor64 output ); {%- endif %} @@ -224,11 +251,16 @@ __global__ void {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ v //////////////////////////////////////////////////////////////////////////////// {%- for nobag in [True, False] %} -{%- if not nobag or (not weighted and not vbe) %} {%- set ndesc = "_nobag" if nobag else "" %} -{%- set has_experimental = (not dense and not nobag and not vbe) %} - -Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_cuda( +{%- if (not nobag or (not weighted and not vbe)) and (nobag or (not is_index_select)) %} +{%- set has_experimental = (not dense and not nobag and not vbe and not is_index_select) %} + +Tensor +{%- if is_index_select %} +batch_index_select_dim0_codegen_forward_cuda( +{%- else %} +{{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_cuda( +{%- endif %} Tensor dev_weights, {%- if not dense %} Tensor uvm_weights, @@ -236,15 +268,21 @@ Tensor {{ ddesc }}_embedding{{ ndesc 
}}_codegen_forward_{{ wdesc }}{{ vdesc }}_c Tensor weights_placements, {%- endif %} Tensor weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} Tensor D_offsets, - int64_t total_D, - int64_t max_D, {%- else %} int64_t D, {%- endif %} + {%- if not nobag %} + int64_t total_D, + {%- endif %} + {%- if not nobag or is_index_select %} + int64_t max_D, + {% endif %} Tensor indices, + {%- if not is_index_select %} Tensor offsets, + {%- endif %} {%- if not nobag %} int64_t pooling_mode, {%- endif %} @@ -255,12 +293,21 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c Tensor lxu_cache_locations, {%- endif %} int64_t output_dtype, + {%- if is_index_select %} + const Tensor& output_offsets, + const Tensor& total_L_offsets, + const int64_t output_size, + const int32_t fixed_L_per_warp, + const int32_t num_warps_per_feature, + const bool permute_output_dim_0_1 + {%- else %} {%- if vbe %} const VBEMetadata& vbe_metadata, const int32_t info_B_num_bits, const uint32_t info_B_mask, {%- endif %} bool is_experimental + {%- endif %} ) { TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( {%- if not dense %} @@ -269,11 +316,13 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c weights_placements, {%- endif %} weights_offsets, - {%- if not nobag %} + {%- if not nobag or is_index_select %} D_offsets, {%- endif %} indices, + {%- if not is_index_select %} offsets, + {%- endif %} {%- if weighted %} indice_weights, {%- endif %} @@ -281,12 +330,24 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c lxu_cache_locations, {%- endif %} {%- if vbe %} - vbe_metadata.output_offsets, + vbe_metadata.row_output_offsets, vbe_metadata.b_t_map, {%- endif %} + {%- if is_index_select %} + total_L_offsets, + {%- endif %} dev_weights ); + {%- if is_index_select %} + if (!permute_output_dim_0_1) { + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL( + output_offsets, + dev_weights + ); + } + {%- endif %} + at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(dev_weights.get_device()); @@ -298,19 +359,26 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c {%- endif %} TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] + {%- if is_index_select %} + const auto total_B = num_warps_per_feature * T; + const int32_t B = num_warps_per_feature; + {%- else %} const auto total_B = offsets.size(0) - 1; - const int32_t B = (total_B) / T; + const int32_t B = total_B / T; + {%- endif %} TORCH_CHECK_GE(B, 0); + {%- if not nobag or is_index_select %} {%- if not nobag %} TORCH_CHECK_GT(total_D, 0); TORCH_CHECK_EQ(total_D % 4, 0); + {%- endif %} TORCH_CHECK_LE(max_D, {{ max_embedding_dim }}); - {%- else %} + {%- elif not is_index_select %} TORCH_CHECK_GT(D, 0); TORCH_CHECK_EQ(D % 4, 0); {%- endif %} {%- if vbe %} - TORCH_CHECK(vbe_metadata.output_offsets.numel() == total_B); + TORCH_CHECK(vbe_metadata.row_output_offsets.numel() == total_B); TORCH_CHECK(vbe_metadata.b_t_map.numel() == total_B); TORCH_CHECK(vbe_metadata.output_size >= 0); {%- endif %} @@ -318,13 +386,31 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c Tensor output; {%- if nobag %} SparseType o_dtype = static_cast(output_dtype); + {%- if is_index_select %} + TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || + o_dtype == SparseType::BF16); + + TORCH_CHECK(fixed_L_per_warp > 0); + TORCH_CHECK(num_warps_per_feature > 0); + if (!permute_output_dim_0_1) { + TORCH_CHECK(output_size >= 0); + 
TORCH_CHECK(output_offsets.numel() > 0); + } + + // If permute_output_dim_0_1 is true, output shape is (batch_size * total_D) + // Else, output shape is (output_size) + output = at::empty({output_size}, dev_weights.options().dtype(getScalarType(o_dtype))); + {%- else %} TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8); + int64_t adjusted_D = D; if (o_dtype == SparseType::INT8) { adjusted_D += T * kINT8QparamsBytes; } + output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype))); + {%- endif %} {%- else %} SparseType o_dtype = static_cast(output_dtype); TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || @@ -385,11 +471,16 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c {#-/* Sequence TBE Case (nobag=True) ****************************************************/#} {%- if nobag %} - DISPATCH_OPTIMAL_NOBAG_FORWARD_KERNEL(D, [&] { + DISPATCH_OPTIMAL_NOBAG_FORWARD_KERNEL({{ "D" if not is_index_select else "max_D" }}, [&] { + {%- set nobag_small_kernel = + "batch_index_select_dim0_codegen_forward_small_kernel" + if is_index_select else + "{}_embedding_nobag_codegen_forward_unweighted_small_kernel".format(ddesc) + %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = "{{ ddesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel"; + const auto func_name = "{{ nobag_small_kernel }}"; #endif - {{ ddesc }}_embedding_nobag_codegen_forward_unweighted_small_kernel + {{ nobag_small_kernel }} <<< div_round_up(total_B, kForwardMaxThreads / kWarpSize), dim3(kWarpSize, kForwardMaxThreads / kWarpSize), @@ -403,14 +494,26 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32), {%- endif %} MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32), + {%- if is_index_select %} + MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), + {%- else %} D, + {%- endif %} FixedDivisor(B), MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32), + {%- if not is_index_select %} MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), + {%- endif %} {%- if not dense %} MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), {%- endif %} - MAKE_PTA_WITH_NAME(func_name, output, output_t, 2, 64) + {%- if is_index_select %} + MAKE_PTA_WITH_NAME(func_name, output_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, total_L_offsets, int64_t, 1, 32), + fixed_L_per_warp, + permute_output_dim_0_1, + {%- endif %} + MAKE_PTA_WITH_NAME(func_name, output, output_t, {{ "1" if is_index_select else "2" }}, 64) ); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -418,15 +521,21 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c }); DISPATCH_KERNEL_FOR_CACHE_CASE(use_lxu_cache, [&] { + {%- set nobag_kernel = + "batch_index_select_dim0_codegen_forward_kernel" + if is_index_select else + "{}_embedding_nobag_codegen_forward_unweighted_kernel".format(ddesc) + %} #ifdef FBGEMM_GPU_MEMCHECK - const auto func_name = "{{ ddesc }}_embedding_nobag_codegen_forward_unweighted_kernel"; + const auto func_name = "{{ nobag_kernel }}"; #endif - {%- if dense %} - {{ ddesc }}_embedding_nobag_codegen_forward_unweighted_kernel - {%- else %} - {{ ddesc }}_embedding_nobag_codegen_forward_unweighted_kernel - {%- endif %} + {{ nobag_kernel }} + {%- if dense or is_index_select %} + + {%- else %} + + {%- endif %} <<< div_round_up(total_B, 
kForwardMaxThreads / kWarpSize), dim3(kWarpSize, kForwardMaxThreads / kWarpSize), @@ -440,14 +549,26 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c MAKE_PTA_WITH_NAME(func_name, weights_placements, int32_t, 1, 32), {%- endif %} MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32), + {%- if is_index_select %} + MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), + {%- else %} D, + {%- endif %} FixedDivisor(B), MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32), + {%- if not is_index_select %} MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32), + {%- endif %} {%- if not dense %} MAKE_PTA_WITH_NAME(func_name, lxu_cache_locations, int32_t, 1, 32), {%- endif %} - MAKE_PTA_WITH_NAME(func_name, output, output_t, 2, 64) + {%- if is_index_select %} + MAKE_PTA_WITH_NAME(func_name, output_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, total_L_offsets, int64_t, 1, 32), + fixed_L_per_warp, + permute_output_dim_0_1, + {%- endif %} + MAKE_PTA_WITH_NAME(func_name, output, output_t, {{ "1" if is_index_select else "2" }}, 64) ); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -483,7 +604,7 @@ Tensor {{ ddesc }}_embedding{{ ndesc }}_codegen_forward_{{ wdesc }}{{ vdesc }}_c MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32), {%- if vbe %} - MAKE_PTA_WITH_NAME(func_name, vbe_metadata.output_offsets, int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, vbe_metadata.row_output_offsets, int64_t, 1, 32), MAKE_PTA_WITH_NAME(func_name, vbe_metadata.b_t_map, int32_t, 1, 32), info_B_num_bits, info_B_mask, diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h index 7770557053..44b2ee7f91 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h @@ -86,7 +86,7 @@ struct VBEMetadata { at::Tensor B_offsets; // torch.int at::Tensor output_offsets_feature_rank; // torch.long at::Tensor B_offsets_rank_per_feature; // torch.int - at::Tensor output_offsets; // torch.long + at::Tensor row_output_offsets; // torch.long at::Tensor b_t_map; // torch.int int32_t max_B_feature_rank; int64_t output_size; diff --git a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh index f85e63586c..67ae6fe53a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh @@ -43,7 +43,12 @@ transpose_embedding_input( const c10::optional& vbe_b_t_map = c10::optional(), const int64_t info_B_num_bits = 26, const int64_t info_B_mask = 0x2FFFFFF, - const int64_t total_unique_indices = -1); + const int64_t total_unique_indices = -1, + const bool is_index_select = false, + const c10::optional& total_L_offsets = + c10::optional(), + const int64_t fixed_L_per_warp = 0, + const int64_t num_warps_per_feature = 0); std::tuple get_infos_metadata(at::Tensor unused, int64_t B, int64_t T); diff --git a/fbgemm_gpu/src/split_embeddings_utils.cpp b/fbgemm_gpu/src/split_embeddings_utils.cpp index 873c670c01..f5e066f390 100644 --- a/fbgemm_gpu/src/split_embeddings_utils.cpp +++ b/fbgemm_gpu/src/split_embeddings_utils.cpp @@ -22,8 +22,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "bool nobag=False, " "Tensor? 
vbe_b_t_map=None, " "int info_B_num_bits=26, " - "int info_B_mask=0x2FFFFFF," - "int total_unique_indices=-1) " + "int info_B_mask=0x2FFFFFF, " + "int total_unique_indices=-1, " + "bool is_index_select=False, " + "Tensor? total_L_offsets=None, " + "int fixed_L_per_warp=0, " + "int num_warps_per_feature=0) " "-> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("get_infos_metadata(Tensor unused, int B, int T) -> (int, int)"); DISPATCH_TO_CUDA("transpose_embedding_input", transpose_embedding_input); diff --git a/fbgemm_gpu/src/split_embeddings_utils.cu b/fbgemm_gpu/src/split_embeddings_utils.cu index b0e33bb41e..f546e89880 100644 --- a/fbgemm_gpu/src/split_embeddings_utils.cu +++ b/fbgemm_gpu/src/split_embeddings_utils.cu @@ -138,6 +138,62 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel( } } +template +__global__ +__launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel( + const at::PackedTensorAccessor32 + hash_size_cumsum, + const at::PackedTensorAccessor32 indices, + const at::PackedTensorAccessor32 + total_L_offsets, + at::PackedTensorAccessor32 infos, + at::PackedTensorAccessor32 + linear_indices, + FixedDivisor fd, + int32_t fixed_L_per_warp) { + const int32_t T = hash_size_cumsum.size(0) - 1; + auto b_t = blockIdx.x * blockDim.x + threadIdx.x; + int32_t b; + int32_t t; + + fd.DivMod(b_t, &t, &b); + + const int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize; + + index_t hash_offset = -1; + index_t indices_start = -1; + int32_t L = 0; + int32_t L_start = 0; + if (t < T) { + const auto total_L_start = total_L_offsets[t]; + const auto total_L = total_L_offsets[t + 1] - total_L_start; + L_start = b * fixed_L_per_warp; + if (L_start < total_L) { + hash_offset = hash_size_cumsum[t]; + indices_start = total_L_start + L_start; + L = (total_L - L_start >= fixed_L_per_warp) ? fixed_L_per_warp + : (total_L - L_start); + } + } + + // Compile-time conditional + for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) { + const index_t indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j); + const auto t_warp = fbgemm_gpu::shfl_sync(t, j); + const auto L_warp = fbgemm_gpu::shfl_sync(L, j); + const auto L_start_warp = fbgemm_gpu::shfl_sync(L_start, j); + const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j); + for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) { + const index_t idx = __ldg(&indices[indices_start_warp + i]); + // l is the relative l in the feature (i.e., the first l in the feature + // is 0) + const int64_t l_t = (L_start_warp + i) * T + t_warp; + infos[indices_start_warp + i] = l_t; + linear_indices[indices_start_warp + i] = hash_offset_warp + idx; + } + } +} + DLL_PUBLIC std::tuple< Tensor /*linear_indices*/, Tensor /*linear_indices_sorted*/, @@ -155,16 +211,30 @@ transpose_embedding_input( const c10::optional& vbe_b_t_map, const int64_t info_B_num_bits, const int64_t info_B_mask, - const int64_t total_unique_indices) { + const int64_t total_unique_indices, + const bool is_index_select, + const c10::optional& total_L_offsets, + const int64_t fixed_L_per_warp, + const int64_t num_warps_per_feature) { const bool vbe = vbe_b_t_map.has_value(); TORCH_CHECK(nobag || !vbe || info_B_num_bits > 0); TORCH_CHECK(!vbe || info_B_mask > 0); + TORCH_CHECK( + !is_index_select || (fixed_L_per_warp > 0 && num_warps_per_feature > 0)); - const auto total_B = offsets.size(0) - 1; const auto T = hash_size_cumsum.size(0) - 1; + const auto total_B = + !is_index_select ? 
(offsets.size(0) - 1) : (num_warps_per_feature * T); + + TORCH_CHECK( + !is_index_select || + (total_L_offsets.has_value() && + total_L_offsets.value().numel() == T + 1)); auto infos = at::empty_like( - indices, indices.options().dtype(nobag ? at::kLong : at::kInt)); + indices, + indices.options().dtype( + (nobag || is_index_select) ? at::kLong : at::kInt)); auto infos_sorted = at::empty_like(infos); auto linear_indices = at::empty_like(indices); auto linear_indices_sorted = at::empty_like(indices); @@ -203,10 +273,31 @@ transpose_embedding_input( using info_t = index_t; AT_DISPATCH_INDEX_TYPES( indices.scalar_type(), "transpose_embedding_input2", [&] { - if (!nobag) { - INVOKE_LINEARIZE_INDEX_KERNEL(int32_t, false); + if (!is_index_select) { + if (!nobag) { + INVOKE_LINEARIZE_INDEX_KERNEL(int32_t, false); + } else { + INVOKE_LINEARIZE_INDEX_KERNEL(int64_t, true); + } } else { - INVOKE_LINEARIZE_INDEX_KERNEL(int64_t, true); + // index_select is a special case of TBE (dense, nobag, with + // fixed_L_per_warp) + linearize_index_index_select_kernel<<< + div_round_up(total_B, kMaxThreads), + kMaxThreads, + 0, + at::cuda::getCurrentCUDAStream()>>>( + hash_size_cumsum + .packed_accessor32(), + indices.packed_accessor32(), + total_L_offsets.value() + .packed_accessor32(), + infos.packed_accessor32(), + linear_indices + .packed_accessor32(), + FixedDivisor(total_B / T), + fixed_L_per_warp); + C10_CUDA_KERNEL_LAUNCH_CHECK(); } { size_t temp_storage_bytes = 0; @@ -386,7 +477,7 @@ DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t); __global__ __launch_bounds__(kMaxThreads) void populate_vbe_metadata_foreach_sample_inplace_kernel( at::PackedTensorAccessor32 - output_offsets, + row_output_offsets, at::PackedTensorAccessor32 b_t_map, const at::PackedTensorAccessor32 B_offsets, @@ -430,7 +521,7 @@ __launch_bounds__(kMaxThreads) void populate_vbe_metadata_foreach_sample_inplace // Update b_t b_t = B_start_t + B_start_r_t + b; const auto D_ = nobag ? 
D : D_offsets[t + 1] - D_offsets[t]; - output_offsets[b_t] = output_offsets_feature_rank[r * T + t] + b * D_; + row_output_offsets[b_t] = output_offsets_feature_rank[r * T + t] + b * D_; // Relative sample ID in the table const auto b_ = B_start_r_t + b; @@ -492,7 +583,7 @@ DLL_PUBLIC void populate_vbe_metadata_foreach_sample_inplace( at::cuda::OptionalCUDAGuard device_guard; device_guard.set_index(vbe_metadata.B_offsets.get_device()); - Tensor output_offsets = + Tensor row_output_offsets = at::empty({total_B}, vbe_metadata.output_offsets_feature_rank.options()); Tensor b_t_map = at::empty({total_B}, vbe_metadata.B_offsets.options()); @@ -504,7 +595,7 @@ DLL_PUBLIC void populate_vbe_metadata_foreach_sample_inplace( kMaxThreads, 0, at::cuda::getCurrentCUDAStream()>>>( - output_offsets.packed_accessor32(), + row_output_offsets.packed_accessor32(), b_t_map.packed_accessor32(), vbe_metadata.B_offsets .packed_accessor32(), @@ -520,6 +611,6 @@ DLL_PUBLIC void populate_vbe_metadata_foreach_sample_inplace( info_B_num_bits); C10_CUDA_KERNEL_LAUNCH_CHECK(); - vbe_metadata.output_offsets = std::move(output_offsets); + vbe_metadata.row_output_offsets = std::move(row_output_offsets); vbe_metadata.b_t_map = std::move(b_t_map); } diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py index 22dcfe333d..8d390dad0a 100644 --- a/fbgemm_gpu/test/sparse_ops_test.py +++ b/fbgemm_gpu/test/sparse_ops_test.py @@ -8,7 +8,9 @@ # pyre-ignore-all-errors[56] import contextlib +import functools import itertools +import logging import random import unittest from itertools import accumulate @@ -28,6 +30,7 @@ except Exception: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable, skipIfRocm @@ -1960,6 +1963,157 @@ def test_bottom_unique_k_per_row( all_indices_deduped_ref = torch.as_tensor(all_indices[:, :, :L]) torch.testing.assert_close(all_indices_deduped, all_indices_deduped_ref) + @given( + num_inputs=st.integers(0, 100), + max_input_rows=st.integers(2, 32), + max_cols_factor=st.integers(2, 256), + max_output_rows=st.integers(2, 32), + permute_output_dim_0_1=st.booleans(), + dtype=st.sampled_from([torch.float, torch.half]), + use_cpu=st.booleans() if gpu_available else st.just(True), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_batch_index_select_dim0( + self, + num_inputs: int, + max_input_rows: int, + max_cols_factor: int, + max_output_rows: int, + permute_output_dim_0_1: bool, + dtype: torch.dtype, + use_cpu: bool, + ) -> None: + device = "cpu" if use_cpu else "cuda" + input_rows = torch.randint( + low=1, high=max_input_rows, size=(num_inputs,) + ).tolist() + input_columns = ( + torch.randint(low=1, high=max_cols_factor, size=(num_inputs,)) * 4 + ).tolist() + if permute_output_dim_0_1: + # All num_indices must be the same if permute_output_dim_0_1 is + # True + num_indices = torch.randint(low=1, high=max_output_rows, size=(1,)).item() + input_num_indices = [num_indices] * num_inputs + else: + input_num_indices = torch.randint( + low=1, high=max_output_rows, size=(num_inputs,) + ).tolist() + + def validate( + test_list: List[torch.Tensor], + ref_list: List[torch.Tensor], + rows: List[int], + val_fn: Callable[[torch.Tensor, torch.Tensor], bool], + name: str, + ) -> None: + test_passed_all = True + 
error_msg = "" + for i, (test, ref) in enumerate(zip(test_list, ref_list)): + test = test.float() + ref = ref.float() + test_passed = val_fn(test, ref) + test_passed_all = test_passed & test_passed_all + if not test_passed: + test = test.reshape(rows[i], -1) + ref = ref.reshape(rows[i], -1) + for r in range(rows[i]): + test_row = test[r] + ref_row = ref[r] + if not val_fn(test_row, ref_row): + error_msg += f"ERROR: {name} {i} row {r} are different, test {test_row}, ref {ref_row}\n" + assert test_passed_all, error_msg + logging.info(f"{name} test passed") + + if num_inputs == 0: + inputs = [torch.empty(0, dtype=dtype, device=device)] + indices = [torch.empty(0, dtype=torch.long, device=device)] + else: + inputs = [ + torch.rand(rows, cols, dtype=dtype, device=device) + for rows, cols in zip(input_rows, input_columns) + ] + indices = [ + torch.randint( + low=0, high=rows, size=(num,), dtype=torch.long, device=device + ) + for num, rows in zip(input_num_indices, input_rows) + ] + + for i in range(len(inputs)): + inputs[i].requires_grad = True + + output_ref = [ + input.index_select(dim=0, index=index).flatten() + for input, index in zip(inputs, indices) + ] + + concat_inputs = torch.concat( + [input.flatten().clone().detach() for input in inputs] + ) + concat_indices = torch.concat(indices) + + concat_inputs.requires_grad = True + + output_test = torch.ops.fbgemm.batch_index_select_dim0( + concat_inputs, + concat_indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1, + ) + + if permute_output_dim_0_1 and num_inputs > 0: + output_list = output_test.view(input_num_indices[0], -1).split( + input_columns, + dim=1, + ) + output_list = [out.flatten() for out in output_list] + else: + output_list = output_test.split( + [rows * cols for rows, cols in zip(input_num_indices, input_columns)] + ) + + validate(output_list, output_ref, input_num_indices, torch.equal, "output") + + if num_inputs == 0: + grads = [torch.empty(0, dtype=dtype, device=device)] + else: + grads = [torch.rand_like(output) for output in output_ref] + for out_ref, grad in zip(output_ref, grads): + out_ref.backward(grad) + + if permute_output_dim_0_1 and num_inputs > 0: + concat_grads = torch.concat( + [grad.view(input_num_indices[0], -1) for grad in grads], dim=1 + ).flatten() + else: + concat_grads = torch.concat(grads) + + assert concat_grads.shape == output_test.shape + output_test.backward(concat_grads) + + assert concat_inputs.grad is not None + grad_list = concat_inputs.grad.split( + [rows * cols for rows, cols in zip(input_rows, input_columns)] + ) + + grad_ref = [] + for input in inputs: + assert input.grad is not None + grad_ref.append(input.grad.flatten()) + + tol = 1.0e-4 if dtype == torch.float else 1.0e-2 + + validate( + grad_list, + grad_ref, + input_rows, + functools.partial(torch.allclose, atol=tol, rtol=tol), + "grad", + ) + if __name__ == "__main__": unittest.main()
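
The layout contract checked by `test_batch_index_select_dim0` can be summarized with a small eager-mode reference. The sketch below only reproduces the reference path of the test (it never calls the fused TBE kernels), and the helper name `batch_index_select_dim0_ref` is invented here for illustration; it is not part of the library:

```
# Eager-mode sketch of the output layout checked by test_batch_index_select_dim0.
# batch_index_select_dim0_ref is a hypothetical helper, not a library API.
# Empty-input handling (num_inputs == 0) is omitted for brevity.
from typing import List

import torch


def batch_index_select_dim0_ref(
    inputs: torch.Tensor,           # 1D, concatenated flattened inputs
    indices: torch.Tensor,          # 1D, concatenated torch.long indices
    input_num_indices: List[int],
    input_rows: List[int],
    input_columns: List[int],
    permute_output_dim_0_1: bool,
) -> torch.Tensor:
    input_splits = [rows * cols for rows, cols in zip(input_rows, input_columns)]
    outputs = []
    for inp, idx, rows, cols in zip(
        inputs.split(input_splits),
        indices.split(input_num_indices),
        input_rows,
        input_columns,
    ):
        # Each input is selected along dim 0 independently.
        outputs.append(inp.view(rows, cols).index_select(0, idx))
    if permute_output_dim_0_1:
        # All inputs must share the same number of indices; the output is laid
        # out as [num_indices, sum(input_columns)] and then flattened.
        return torch.cat(outputs, dim=1).flatten()
    # Otherwise the per-input selections are flattened and concatenated.
    return torch.cat([out.flatten() for out in outputs])
```

The forward output of the fused operator is expected to match this reference exactly (the test compares outputs with `torch.equal`), while gradients are compared with `torch.allclose` under a dtype-dependent tolerance.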
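
Similarly, the indexing math performed by `linearize_index_index_select_kernel` (added to `split_embeddings_utils.cu` above) can be sketched in plain Python. This is an illustration of the per-warp partitioning only, reusing the same argument names as the CUDA kernel; the real kernel runs one warp per (feature, chunk) pair and uses warp shuffles instead of Python loops:

```
# Illustrative sketch of the index linearization used by the backward pass
# when is_index_select=True. Not a library API.
import torch


def linearize_index_index_select_ref(
    hash_size_cumsum: torch.Tensor,   # [T + 1], torch.long
    indices: torch.Tensor,            # 1D, concatenated torch.long indices
    total_L_offsets: torch.Tensor,    # [T + 1], torch.long
    fixed_L_per_warp: int,
    num_warps_per_feature: int,
):
    T = hash_size_cumsum.numel() - 1
    infos = torch.empty_like(indices)
    linear_indices = torch.empty_like(indices)
    for t in range(T):
        total_L_start = int(total_L_offsets[t])
        total_L = int(total_L_offsets[t + 1]) - total_L_start
        for b in range(num_warps_per_feature):
            L_start = b * fixed_L_per_warp
            if L_start >= total_L:
                continue
            L = min(fixed_L_per_warp, total_L - L_start)
            indices_start = total_L_start + L_start
            for i in range(L):
                idx = int(indices[indices_start + i])
                # Position of the index within its feature, interleaved with t.
                infos[indices_start + i] = (L_start + i) * T + t
                linear_indices[indices_start + i] = int(hash_size_cumsum[t]) + idx
    return infos, linear_indices
```

Each feature `t` is split into `num_warps_per_feature` chunks of at most `fixed_L_per_warp` indices, every index is mapped to the linear row id `hash_size_cumsum[t] + idx`, and the accompanying `info` value `(L_start + i) * T + t` encodes both the feature and the index's position within it.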