Add invert_permute (#1403)
Summary:
Pull Request resolved: #1403

Variable-batch-size EC requires a backward for permute_1D. Add invert_permute so we don't need to generate a backward recat.

Reviewed By: YazhiGao

Differential Revision: D40521813

LaMa Project: L1138451

fbshipit-source-id: 2c28a3af86492b4db823c1f2cbfe2d52121d7470
xing-liu authored and facebook-github-bot committed Oct 24, 2022
1 parent afb21a6 commit 9dc70f2
Showing 5 changed files with 132 additions and 39 deletions.
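
The new operator's contract is inverse[permute[i]] == i. A minimal usage sketch of the semantics (assuming the fbgemm_gpu Python package is installed so the torch.ops.fbgemm operators are registered; the tensor values are illustrative):

import torch
import fbgemm_gpu  # noqa: F401  # registers the torch.ops.fbgemm operators

permute = torch.tensor([2, 0, 3, 1], dtype=torch.int32)
inverse = torch.ops.fbgemm.invert_permute(permute)
# inverse[permute[i]] == i, so inverse is [1, 3, 0, 2].
assert torch.equal(inverse, torch.tensor([1, 3, 0, 2], dtype=torch.int32))

# Applying `permute` and then `inverse` is the identity, which is why the
# backward of permute_1D can reuse the forward kernel with `inverse`
# instead of a separately generated backward recat.
x = torch.arange(4)
assert torch.equal(x[permute.long()][inverse.long()], x)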
2 changes: 2 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -89,6 +89,8 @@ permute_1D_sparse_data_cuda(
const at::Tensor& indices,
const c10::optional<at::Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum);

at::Tensor invert_permute_cuda(const at::Tensor& permute);
#endif

/// @ingroup sparse-data-cuda
33 changes: 32 additions & 1 deletion fbgemm_gpu/src/sparse_ops.cu
@@ -585,7 +585,7 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
lengths_contig.data_ptr<index_t>(),
permuted_lengths_size,
-permute.data_ptr<int32_t>(),
+permute_contig.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@@ -674,6 +674,37 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
return {permuted_lengths, permuted_indices, permuted_weights};
}
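// Computes the inverse of an index permutation by scattering: writes
// inversed_permute[permute[i]] = i for every i. Each write target is unique
// because `permute` contains every index exactly once, so no atomics are
// needed.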
template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void invert_permute_kernel(
int32_t permute_size,
const index_t* __restrict__ permute,
index_t* __restrict__ inversed_permute) {
CUDA_KERNEL_LOOP(i, permute_size) {
inversed_permute[permute[i]] = i;
}
}
Tensor invert_permute_cuda(const Tensor& permute) {
TENSOR_ON_CUDA_GPU(permute);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(permute.get_device());
const auto permute_contig = permute.contiguous();
const auto permute_size = permute.numel();
Tensor inversed_permute = at::empty_like(permute);
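// One thread per element; CUDA_KERNEL_LOOP grid-strides, so any
// permute_size is handled even if it exceeds blocks_1 * threads_1.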
constexpr int32_t threads_1 = kMaxThreads;
const auto blocks_1 = cuda_calc_xblock_count(permute_size, threads_1);
AT_DISPATCH_INDEX_TYPES(permute.scalar_type(), "invert_permute_kernel", [&] {
invert_permute_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
permute_size,
permute_contig.data_ptr<index_t>(),
inversed_permute.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return inversed_permute;
}
// Kernel for generating a 1D data permute from a dimension permute index.
// Used for permutation of sparse features.
template <typename index_t, typename offsets_t>
108 changes: 70 additions & 38 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -114,43 +114,6 @@ void _permute_2D_indices_weights_kernel_cpu(
}); // parallel_for T * B
}

// specialization for variable B and T,
// the permute here maps to all items in length.
template <
bool has_weight,
typename offsets_t,
typename indices_t,
typename weights_t>
void _permute_1D_indices_weights_kernel_cpu(
const offsets_t* const __restrict__ input_offsets,
const indices_t* const __restrict__ indices,
const weights_t* const __restrict__ weights,
const int64_t permuted_lengths_size,
const int32_t* const __restrict__ permute,
const offsets_t* const __restrict__ permuted_lengths,
const offsets_t* const __restrict__ output_offsets,
indices_t* const __restrict__ permuted_indices,
weights_t* const __restrict__ permuted_weights) {
at::parallel_for(
0,
permuted_lengths_size,
FALSE_SHARING_PAD,
[&](int64_t tb_begin, int64_t tb_end) {
for (int tb = tb_begin; tb < std::min(tb_end, permuted_lengths_size);
++tb) {
offsets_t permuted_length = permuted_lengths[tb];
const offsets_t input_start = input_offsets[permute[tb]];
const offsets_t output_start = output_offsets[tb];
for (const auto i : c10::irange(permuted_length)) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
}
}
}
}); // parallel_for T x B, different B across T
}

template <typename index_t>
void _permute_2D_lengths_cpu_kernel(
const int32_t T,
@@ -624,6 +587,43 @@ void _permute_1D_lengths_cpu_kernel(
});
}

// specialization for variable B and T,
// the permute here maps to all items in length.
template <
bool has_weight,
typename offsets_t,
typename indices_t,
typename weights_t>
void _permute_1D_indices_weights_kernel_cpu(
const offsets_t* const __restrict__ input_offsets,
const indices_t* const __restrict__ indices,
const weights_t* const __restrict__ weights,
const int64_t permuted_lengths_size,
const int32_t* const __restrict__ permute,
const offsets_t* const __restrict__ permuted_lengths,
const offsets_t* const __restrict__ output_offsets,
indices_t* const __restrict__ permuted_indices,
weights_t* const __restrict__ permuted_weights) {
at::parallel_for(
0,
permuted_lengths_size,
FALSE_SHARING_PAD,
[&](int64_t tb_begin, int64_t tb_end) {
for (int tb = tb_begin; tb < std::min(tb_end, permuted_lengths_size);
++tb) {
offsets_t permuted_length = permuted_lengths[tb];
const offsets_t input_start = input_offsets[permute[tb]];
const offsets_t output_start = output_offsets[tb];
for (const auto i : c10::irange(permuted_length)) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
}
}
}
}); // parallel_for T x B, different B across T
}

std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
const Tensor& permute,
const Tensor& lengths,
@@ -656,7 +656,7 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
_permute_1D_lengths_cpu_kernel(
lengths_contig->data_ptr<index_t>(),
permuted_lengths_size,
-permute.data_ptr<int32_t>(),
+permute_contig->data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
}); // for each scalar_t

@@ -779,6 +779,36 @@ Tensor expand_into_jagged_permute_cpu(
return output_permute;
}

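// CPU counterpart of invert_permute_cuda: scatters each index t into slot
// permute[t]. The parallel_for chunks are disjoint in t, and the scattered
// writes are disjoint because `permute` is a permutation, so the parallel
// stores do not race.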
template <typename index_t>
void _invert_permute_cpu_kernel(
const int64_t permute_size,
const index_t* const __restrict__ permute,
index_t* const __restrict__ inversed_permute) {
at::parallel_for(
0, permute_size, FALSE_SHARING_PAD, [&](int64_t t_begin, int64_t t_end) {
for (int t = t_begin; t < std::min(t_end, permute_size); ++t) {
inversed_permute[permute[t]] = t;
}
});
}

Tensor invert_permute_cpu(const Tensor& permute) {
TENSOR_ON_CPU(permute);
const auto permute_contig = permute.expect_contiguous();
const auto permute_size = permute.numel();
Tensor inversed_permute = at::empty_like(permute);

AT_DISPATCH_INDEX_TYPES(
permute.scalar_type(), "invert_permute_cpu_kernel", [&] {
_invert_permute_cpu_kernel<index_t>(
permute_size,
permute_contig->data_ptr<index_t>(),
inversed_permute.data_ptr<index_t>());
}); // for each scalar_t

return inversed_permute;
}

std::tuple<
Tensor,
Tensor,
@@ -2370,6 +2400,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def(
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def("invert_permute(Tensor permute) -> Tensor");
m.def(
"expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, int output_size) -> Tensor");
m.def(
@@ -2442,6 +2473,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"permute_2D_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cpu);
DISPATCH_TO_CPU(
"permute_1D_sparse_data", fbgemm_gpu::permute_1D_sparse_data_cpu);
DISPATCH_TO_CPU("invert_permute", fbgemm_gpu::invert_permute_cpu);
DISPATCH_TO_CPU(
"expand_into_jagged_permute", fbgemm_gpu::expand_into_jagged_permute_cpu);
DISPATCH_TO_CPU(
1 change: 1 addition & 0 deletions fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -217,6 +217,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
"permute_2D_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cuda);
DISPATCH_TO_CUDA(
"permute_1D_sparse_data", fbgemm_gpu::permute_1D_sparse_data_cuda);
DISPATCH_TO_CUDA("invert_permute", fbgemm_gpu::invert_permute_cuda);
DISPATCH_TO_CUDA(
"expand_into_jagged_permute",
fbgemm_gpu::expand_into_jagged_permute_cuda);
27 changes: 27 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -245,6 +245,33 @@ def test_permute_indices(
else:
assert permuted_weights_gpu is None

# pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
@given(
permute_size=st.integers(min_value=30, max_value=1000),
long_index=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_invert_permute(
self,
permute_size: int,
long_index: bool,
) -> None:
index_dtype = torch.int64 if long_index else torch.int32
permute_list = list(range(permute_size))
random.shuffle(permute_list)
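# Build the reference inverse by direct scatter: index i goes to slot permute_list[i].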
inversed_permute_list = [0] * len(permute_list)
for i in range(permute_size):
inversed_permute_list[permute_list[i]] = i
permute = torch.IntTensor(permute_list).type(index_dtype)
inverse_permute_ref = torch.IntTensor(inversed_permute_list).type(index_dtype)

inverse_permute_cpu = torch.ops.fbgemm.invert_permute(permute)
torch.testing.assert_close(inverse_permute_cpu, inverse_permute_ref)

if gpu_available:
inverse_permute_gpu = torch.ops.fbgemm.invert_permute(permute.cuda())
torch.testing.assert_close(inverse_permute_gpu.cpu(), inverse_permute_cpu)

# pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
@given(
B=st.integers(min_value=1, max_value=20),
