From a179a90bea90a13e9eb73722172c79fd7dcbbf5a Mon Sep 17 00:00:00 2001
From: Xing Liu
Date: Wed, 19 Oct 2022 16:18:34 -0700
Subject: [PATCH] Add inverse_permute argument to permute_1D_sparse_data

Summary:
Variable batch size EC requires a permute_1D backward pass. Add an
inverse_permute argument so we don't need to generate a backward recat.

Reviewed By: YazhiGao

Differential Revision: D40521813

LaMa Project: L1138451

fbshipit-source-id: cab0e79e3c5cc4cfe60653d541156354ee804d00
---
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h |  6 ++-
 fbgemm_gpu/src/sparse_ops.cu               | 50 +++++++++++++++----
 fbgemm_gpu/src/sparse_ops_cpu.cpp          | 58 ++++++++++++++++++----
 3 files changed, 91 insertions(+), 23 deletions(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 444a31f218..10591120c5 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -88,7 +88,8 @@ permute_1D_sparse_data_cuda(
     const at::Tensor& lengths,
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& weights,
-    const c10::optional<int64_t>& permuted_lengths_sum);
+    const c10::optional<int64_t>& permuted_lengths_sum,
+    const c10::optional<bool>& inverse_permute);
 #endif
 
 /// @ingroup sparse-data-cuda
@@ -198,7 +199,8 @@ permute_1D_sparse_data_cpu(
     const at::Tensor& lengths,
     const at::Tensor& indices,
     const c10::optional<at::Tensor>& weights,
-    const c10::optional<int64_t>& permuted_lengths_sum);
+    const c10::optional<int64_t>& permuted_lengths_sum,
+    const c10::optional<bool>& inverse_permute);
 
 at::Tensor _float_to_fused8bitrowwise_gpu(const at::Tensor& input);
 at::Tensor _half_to_fused8bitrowwise_gpu(const at::Tensor& input);
diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu
index 694b4fe679..e201f49f75 100644
--- a/fbgemm_gpu/src/sparse_ops.cu
+++ b/fbgemm_gpu/src/sparse_ops.cu
@@ -507,6 +507,19 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_lengths_kernel(
   }
 }
 
+// Kernel for inverse-permuting 1D lengths of sparse features.
+template <typename index_t>
+__global__
+__launch_bounds__(kMaxThreads) void inverse_permute_1D_lengths_kernel(
+    const index_t* __restrict__ lengths,
+    int32_t permuted_lengths_size,
+    const int32_t* __restrict__ permute,
+    index_t* __restrict__ permuted_lengths) {
+  CUDA_KERNEL_LOOP(i, permuted_lengths_size) {
+    permuted_lengths[permute[i]] = lengths[i];
+  }
+}
+
 // Kernel for permuting the indices and weights. Used for permutation of sparse
 // data
 template <
@@ -549,7 +562,8 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
     const Tensor& lengths,
     const Tensor& indices,
     const c10::optional<Tensor>& weights,
-    const c10::optional<int64_t>& permuted_lengths_sum) {
+    const c10::optional<int64_t>& permuted_lengths_sum,
+    const c10::optional<bool>& inverse_permute) {
   TENSOR_ON_CUDA_GPU(permute);
   TENSOR_ON_CUDA_GPU(lengths);
   TENSOR_ON_CUDA_GPU(indices);
@@ -576,19 +590,33 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
   Tensor permuted_weights;
   permuted_lengths = at::empty({permuted_lengths_size}, lengths.options());
+  bool use_inverse_permute = inverse_permute ? *inverse_permute : false;
 
   constexpr int32_t threads_1 = kMaxThreads;
   const auto blocks_1 =
       cuda_calc_xblock_count(permuted_lengths_size, threads_1);
-  AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
-        permute_1D_lengths_kernel<index_t>
-            <<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
-                lengths_contig.data_ptr<index_t>(),
-                permuted_lengths_size,
-                permute.data_ptr<int32_t>(),
-                permuted_lengths.data_ptr<index_t>());
-        C10_CUDA_KERNEL_LAUNCH_CHECK();
-      });
+  if (use_inverse_permute) {
+    AT_DISPATCH_INDEX_TYPES(
+        lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
+          inverse_permute_1D_lengths_kernel<index_t>
+              <<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
+                  lengths_contig.data_ptr<index_t>(),
+                  permuted_lengths_size,
+                  permute.data_ptr<int32_t>(),
+                  permuted_lengths.data_ptr<index_t>());
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+  } else {
+    AT_DISPATCH_INDEX_TYPES(
+        lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
+          permute_1D_lengths_kernel<index_t>
+              <<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
+                  lengths_contig.data_ptr<index_t>(),
+                  permuted_lengths_size,
+                  permute.data_ptr<int32_t>(),
+                  permuted_lengths.data_ptr<index_t>());
+          C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+  }
 
   // convert lengths to offsets
   const auto input_offsets = asynchronous_exclusive_cumsum_gpu(lengths_contig);
diff --git a/fbgemm_gpu/src/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops_cpu.cpp
index 8cf7ad48f1..9829e2155d 100644
--- a/fbgemm_gpu/src/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -624,12 +624,37 @@ void _permute_1D_lengths_cpu_kernel(
       });
 }
 
+// Specialization for variable B and T;
+// the permute here maps to all items in lengths.
+template <typename index_t>
+void _inverse_permute_1D_lengths_cpu_kernel(
+    const index_t* const __restrict__ lengths,
+    int64_t permuted_lengths_size,
+    const int32_t* const __restrict__ permute,
+    index_t* const __restrict__ permuted_lengths) {
+  at::parallel_for(
+      0,
+      permuted_lengths_size,
+      FALSE_SHARING_PAD,
+      [&](int64_t tb_begin, int64_t tb_end) {
+        // Have a separate loop for summing up lengths
+        index_t current_output_offset = 0;
+        for (int tb = tb_begin; tb < std::min(tb_end, permuted_lengths_size);
+             ++tb) {
+          auto permuted_length = lengths[tb];
+          permuted_lengths[permute[tb]] = permuted_length;
+          current_output_offset += permuted_length;
+        }
+      });
+}
+
 std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
     const Tensor& permute,
     const Tensor& lengths,
     const Tensor& indices,
     const c10::optional<Tensor>& weights,
-    const c10::optional<int64_t>& permuted_lengths_sum) {
+    const c10::optional<int64_t>& permuted_lengths_sum,
+    const c10::optional<bool>& inverse_permute) {
   TENSOR_ON_CPU(permute);
   TENSOR_ON_CPU(lengths);
   TENSOR_ON_CPU(indices);
@@ -651,14 +676,27 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
   std::vector<int64_t> output_offsets_per_thread_cumsum(
       (num_threads + 1) * FALSE_SHARING_PAD, 0);
 
-  AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "permute_1D_lengths_cpu_kernel", [&] {
-        _permute_1D_lengths_cpu_kernel(
-            lengths_contig->data_ptr<index_t>(),
-            permuted_lengths_size,
-            permute.data_ptr<int32_t>(),
-            permuted_lengths.data_ptr<index_t>());
-      }); // for each scalar_t
+  bool use_inverse_permute = inverse_permute ? *inverse_permute : false;
+
+  if (use_inverse_permute) {
+    AT_DISPATCH_INDEX_TYPES(
+        lengths.scalar_type(), "permute_1D_lengths_cpu_kernel", [&] {
+          _inverse_permute_1D_lengths_cpu_kernel(
+              lengths_contig->data_ptr<index_t>(),
+              permuted_lengths_size,
+              permute.data_ptr<int32_t>(),
+              permuted_lengths.data_ptr<index_t>());
+        }); // for each scalar_t
+  } else {
+    AT_DISPATCH_INDEX_TYPES(
+        lengths.scalar_type(), "permute_1D_lengths_cpu_kernel", [&] {
+          _permute_1D_lengths_cpu_kernel(
+              lengths_contig->data_ptr<index_t>(),
+              permuted_lengths_size,
+              permute.data_ptr<int32_t>(),
+              permuted_lengths.data_ptr<index_t>());
+        }); // for each scalar_t
+  }
 
   const auto input_offsets = asynchronous_exclusive_cumsum_cpu(lengths);
   const auto output_offsets =
@@ -2369,7 +2407,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
   m.def(
-      "permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
+      "permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None, bool? inverse_permute=None) -> (Tensor, Tensor, Tensor?)");
   m.def(
       "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, int output_size) -> Tensor");
   m.def(
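
For readers new to these ops, the following is a minimal standalone sketch (not
part of the patch) of what the two length kernels compute. The existing forward
kernel gathers lengths[permute[i]] into position i, while the new inverse
kernel scatters lengths[i] into position permute[i]; applying the inverse path
to a forward-permuted output therefore recovers the original lengths, which is
why the backward pass no longer needs a separately generated recat. The
std::vector stand-ins and sample values are illustrative only.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Sample permutation and lengths; values are illustrative only.
  const std::vector<int32_t> permute = {2, 0, 1};
  const std::vector<int64_t> lengths = {3, 1, 4};

  // Forward path (permute_1D_lengths_kernel): out[i] = lengths[permute[i]].
  std::vector<int64_t> forward(lengths.size());
  for (std::size_t i = 0; i < lengths.size(); ++i) {
    forward[i] = lengths[permute[i]];
  }

  // Inverse path (inverse_permute_1D_lengths_kernel):
  // out[permute[i]] = lengths[i]. Applied to the forward result, it restores
  // the original lengths.
  std::vector<int64_t> recovered(lengths.size());
  for (std::size_t i = 0; i < forward.size(); ++i) {
    recovered[permute[i]] = forward[i];
  }

  for (const auto v : recovered) {
    std::cout << v << ' '; // prints: 3 1 4
  }
  std::cout << std::endl;
  return 0;
}

At the operator level, the same inverse path is what the updated schema
exposes: passing inverse_permute=True to fbgemm's permute_1D_sparse_data
selects the inverse kernels shown in the diff.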