Add inverse_permute argument to permute_1D_sparse_data
Summary: Variable batch size EC requires permute_1D backward. Add inverse_permute so we don't need to generate backward recat.

Reviewed By: YazhiGao

Differential Revision: D40521813

LaMa Project: L1138451

fbshipit-source-id: cab0e79e3c5cc4cfe60653d541156354ee804d00
xing-liu authored and facebook-github-bot committed Oct 19, 2022
1 parent 79c8f8a commit a179a90
Showing 3 changed files with 91 additions and 23 deletions.
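
For orientation: the new inverse_permute flag flips the lengths permutation from a gather into a scatter, which is what the backward direction of permute_1D needs. A minimal PyTorch sketch of the two mappings (tensor names and values below are illustrative, not taken from this diff):

import torch

lengths = torch.tensor([2, 0, 3, 1])  # one length per sparse item
permute = torch.tensor([2, 0, 3, 1])  # recat order used by the forward pass

# Forward permute (existing behavior): a gather, out[i] = lengths[permute[i]]
fwd = lengths[permute]                # tensor([3, 2, 1, 0])

# Inverse permute (new flag): a scatter, out[permute[i]] = fwd[i]
restored = torch.empty_like(fwd)
restored[permute] = fwd               # tensor([2, 0, 3, 1]) == lengths

Because the scatter undoes the gather, the backward pass can reuse the forward permute tensor instead of generating a separate backward recat.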
6 changes: 4 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -88,7 +88,8 @@ permute_1D_sparse_data_cuda(
const at::Tensor& lengths,
const at::Tensor& indices,
const c10::optional<at::Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum);
const c10::optional<int64_t>& permuted_lengths_sum,
const c10::optional<bool>& inverse_permute);
#endif

/// @ingroup sparse-data-cuda
@@ -198,7 +199,8 @@ permute_1D_sparse_data_cpu(
const at::Tensor& lengths,
const at::Tensor& indices,
const c10::optional<at::Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum);
const c10::optional<int64_t>& permuted_lengths_sum,
const c10::optional<bool>& inverse_permute);

at::Tensor _float_to_fused8bitrowwise_gpu(const at::Tensor& input);
at::Tensor _half_to_fused8bitrowwise_gpu(const at::Tensor& input);
50 changes: 39 additions & 11 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -507,6 +507,19 @@ __global__ __launch_bounds__(kMaxThreads) void permute_1D_lengths_kernel(
}
}
// Kernel for inverse permuting 1D lengths. Used for the backward permutation
// of sparse features.
template <typename index_t>
__global__
__launch_bounds__(kMaxThreads) void inverse_permute_1D_lengths_kernel(
const index_t* __restrict__ lengths,
int32_t permuted_lengths_size,
const int32_t* __restrict__ permute,
index_t* __restrict__ permuted_lengths) {
CUDA_KERNEL_LOOP(i, permuted_lengths_size) {
permuted_lengths[permute[i]] = lengths[i];
}
}
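
A host-side reference for this kernel pair, handy for reasoning about (or unit-testing) the round-trip property; a sketch in plain PyTorch, not part of this diff:

import torch

def permute_lengths_ref(lengths, permute):
    # Mirrors permute_1D_lengths_kernel: gather, out[i] = lengths[permute[i]].
    return lengths[permute]

def inverse_permute_lengths_ref(lengths, permute):
    # Mirrors inverse_permute_1D_lengths_kernel: scatter, out[permute[i]] = lengths[i].
    out = torch.empty_like(lengths)
    out[permute] = lengths
    return out

lengths = torch.tensor([4, 1, 0, 2, 3])
permute = torch.randperm(lengths.numel())
assert torch.equal(
    inverse_permute_lengths_ref(permute_lengths_ref(lengths, permute), permute),
    lengths)

As long as permute is a true permutation of [0, permuted_lengths_size), the scatter is the exact inverse of the gather.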
// Kernel for permuting the indices and weights. Used for permutation of sparse
// data
template <
@@ -549,7 +562,8 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
const Tensor& lengths,
const Tensor& indices,
const c10::optional<Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum) {
const c10::optional<int64_t>& permuted_lengths_sum,
const c10::optional<bool>& inverse_permute) {
TENSOR_ON_CUDA_GPU(permute);
TENSOR_ON_CUDA_GPU(lengths);
TENSOR_ON_CUDA_GPU(indices);
@@ -576,19 +590,33 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
Tensor permuted_weights;
permuted_lengths = at::empty({permuted_lengths_size}, lengths.options());
bool use_inverse_permute = inverse_permute ? *inverse_permute : false;
constexpr int32_t threads_1 = kMaxThreads;
const auto blocks_1 =
cuda_calc_xblock_count(permuted_lengths_size, threads_1);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
permute_1D_lengths_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
lengths_contig.data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
if (use_inverse_permute) {
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
inverse_permute_1D_lengths_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
lengths_contig.data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
} else {
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_1D_lengths_kernel", [&] {
permute_1D_lengths_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
lengths_contig.data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
// convert lengths to offsets
const auto input_offsets = asynchronous_exclusive_cumsum_gpu(lengths_contig);
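Downstream of the lengths permutation, asynchronous_exclusive_cumsum_gpu converts lengths to offsets via an exclusive prefix sum, so segment i of the flattened indices/weights occupies [offsets[i], offsets[i] + lengths[i]). A small illustrative sketch (values are made up):

import torch

lengths = torch.tensor([2, 0, 3, 1])
offsets = torch.cumsum(lengths, dim=0) - lengths  # exclusive cumsum -> [0, 2, 2, 5]
# Segment i of the flattened values spans offsets[i] .. offsets[i] + lengths[i].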
58 changes: 48 additions & 10 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -624,12 +624,37 @@ void _permute_1D_lengths_cpu_kernel(
});
}

// Inverse-permute specialization for variable B and T; the permute here maps
// over all items in lengths and is applied as a scatter.
template <typename index_t>
void _inverse_permute_1D_lengths_cpu_kernel(
const index_t* const __restrict__ lengths,
int64_t permuted_lengths_size,
const int32_t* const __restrict__ permute,
index_t* const __restrict__ permuted_lengths) {
at::parallel_for(
0,
permuted_lengths_size,
FALSE_SHARING_PAD,
[&](int64_t tb_begin, int64_t tb_end) {
// Scatter each length to its inverse-permuted position.
for (int64_t tb = tb_begin; tb < std::min(tb_end, permuted_lengths_size);
     ++tb) {
  permuted_lengths[permute[tb]] = lengths[tb];
}
});
}

std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
const Tensor& permute,
const Tensor& lengths,
const Tensor& indices,
const c10::optional<Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum) {
const c10::optional<int64_t>& permuted_lengths_sum,
const c10::optional<bool>& inverse_permute) {
TENSOR_ON_CPU(permute);
TENSOR_ON_CPU(lengths);
TENSOR_ON_CPU(indices);
@@ -651,14 +676,27 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
std::vector<int64_t> output_offsets_per_thread_cumsum(
(num_threads + 1) * FALSE_SHARING_PAD, 0);

AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_1D_lengths_cpu_kernel", [&] {
_permute_1D_lengths_cpu_kernel(
lengths_contig->data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
}); // for each scalar_t
bool use_inverse_permute = inverse_permute ? *inverse_permute : false;

if (use_inverse_permute) {
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_1D_lengths_cpu_kernel", [&] {
_inverse_permute_1D_lengths_cpu_kernel<index_t>(
lengths_contig->data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
}); // for each scalar_t
} else {
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_1D_lengths_cpu_kernel", [&] {
_permute_1D_lengths_cpu_kernel<index_t>(
lengths_contig->data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
}); // for each scalar_t
}

const auto input_offsets = asynchronous_exclusive_cumsum_cpu(lengths);
const auto output_offsets =
@@ -2369,7 +2407,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def(
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None, bool? inverse_permute=None) -> (Tensor, Tensor, Tensor?)");
m.def(
"expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, int output_size) -> Tensor");
m.def(
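A hypothetical call site for the updated schema registered above; it assumes fbgemm_gpu is installed (importing it registers the torch.ops.fbgemm operators) and uses made-up values:

import torch
import fbgemm_gpu  # noqa: F401  (registers torch.ops.fbgemm.*)

permute = torch.tensor([2, 0, 1], dtype=torch.int32)
lengths = torch.tensor([1, 2, 3], dtype=torch.int64)
indices = torch.tensor([10, 20, 30, 40, 50, 60], dtype=torch.int64)

# Forward permute: existing behavior (inverse_permute left at its default).
p_lengths, p_indices, _ = torch.ops.fbgemm.permute_1D_sparse_data(
    permute, lengths, indices, None, None)

# Backward direction: reuse the same permute tensor with inverse_permute=True
# instead of generating a separate backward recat.
b_lengths, b_indices, _ = torch.ops.fbgemm.permute_1D_sparse_data(
    permute, p_lengths, p_indices, None, None, True)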
