Add inverse_permute argument to permute_1D_sparse_data #1403

Closed · wants to merge 1 commit
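The diff below adds a new `invert_permute` operator: a CUDA kernel and a CPU kernel that compute the inverse of a permutation tensor, the `fbgemm::invert_permute` schema and CPU/CUDA dispatch registrations, and a unit test. A minimal usage sketch follows; the setup around the op call is illustrative only and is not part of this PR (it assumes an installed `fbgemm_gpu` whose import registers the `torch.ops.fbgemm.*` operators).

```python
# Minimal usage sketch (illustrative, not part of this PR).
import torch
import fbgemm_gpu  # noqa: F401  # assumed to register torch.ops.fbgemm.*

permute = torch.tensor([2, 0, 3, 1], dtype=torch.int32)
inverse = torch.ops.fbgemm.invert_permute(permute)
# inverse[permute[i]] == i for every i, so inverse is [1, 3, 0, 2]
print(inverse)
```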
2 changes: 2 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -89,6 +89,8 @@ permute_1D_sparse_data_cuda(
const at::Tensor& indices,
const c10::optional<at::Tensor>& weights,
const c10::optional<int64_t>& permuted_lengths_sum);

at::Tensor invert_permute_cuda(const at::Tensor& permute);
#endif

/// @ingroup sparse-data-cuda
33 changes: 32 additions & 1 deletion fbgemm_gpu/src/sparse_ops.cu
@@ -585,7 +585,7 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
lengths_contig.data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permute_contig.data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@@ -674,6 +674,37 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
return {permuted_lengths, permuted_indices, permuted_weights};
}

template <typename index_t>
__global__ __launch_bounds__(kMaxThreads) void invert_permute_kernel(
int32_t permute_size,
const index_t* __restrict__ permute,
index_t* __restrict__ inversed_permute) {
CUDA_KERNEL_LOOP(i, permute_size) {
inversed_permute[permute[i]] = i;
}
}

Tensor invert_permute_cuda(const Tensor& permute) {
TENSOR_ON_CUDA_GPU(permute);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(permute.get_device());
const auto permute_contig = permute.contiguous();
const auto permute_size = permute.numel();
Tensor inversed_permute = at::empty_like(permute);

constexpr int32_t threads_1 = kMaxThreads;
const auto blocks_1 = cuda_calc_xblock_count(permute_size, threads_1);
AT_DISPATCH_INDEX_TYPES(permute.scalar_type(), "invert_permute_kernel", [&] {
invert_permute_kernel<index_t>
<<<blocks_1, threads_1, 0, at::cuda::getCurrentCUDAStream()>>>(
permute_size,
permute_contig.data_ptr<index_t>(),
inversed_permute.data_ptr<index_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
return inversed_permute;
}

// Kernel for generating 1D data permute from dimension permute index.
// Used for permutation of sparse features.
template <typename index_t, typename offsets_t>
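For reference, the CUDA kernel above scatters `i` into `inversed_permute[permute[i]]` in parallel. A sequential Python sketch of the same relation (an illustration only, not FBGEMM code) looks like this:

```python
import torch

def invert_permute_reference(permute: torch.Tensor) -> torch.Tensor:
    # Writes inverse[permute[i]] = i, mirroring invert_permute_kernel's scatter.
    inverse = torch.empty_like(permute)
    for i in range(permute.numel()):
        inverse[permute[i]] = i
    return inverse
```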
108 changes: 70 additions & 38 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -114,43 +114,6 @@ void _permute_2D_indices_weights_kernel_cpu(
}); // parallel_for T * B
}

// specialization for variable B and T,
// the permute here maps to all items in length.
template <
bool has_weight,
typename offsets_t,
typename indices_t,
typename weights_t>
void _permute_1D_indices_weights_kernel_cpu(
const offsets_t* const __restrict__ input_offsets,
const indices_t* const __restrict__ indices,
const weights_t* const __restrict__ weights,
const int64_t permuted_lengths_size,
const int32_t* const __restrict__ permute,
const offsets_t* const __restrict__ permuted_lengths,
const offsets_t* const __restrict__ output_offsets,
indices_t* const __restrict__ permuted_indices,
weights_t* const __restrict__ permuted_weights) {
at::parallel_for(
0,
permuted_lengths_size,
FALSE_SHARING_PAD,
[&](int64_t tb_begin, int64_t tb_end) {
for (int tb = tb_begin; tb < std::min(tb_end, permuted_lengths_size);
++tb) {
offsets_t permuted_length = permuted_lengths[tb];
const offsets_t input_start = input_offsets[permute[tb]];
const offsets_t output_start = output_offsets[tb];
for (const auto i : c10::irange(permuted_length)) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
}
}
}
}); // parallel_for T x B, different B across T
}

template <typename index_t>
void _permute_2D_lengths_cpu_kernel(
const int32_t T,
@@ -624,6 +587,43 @@ void _permute_1D_lengths_cpu_kernel(
});
}

// specialization for variable B and T,
// the permute here maps to all items in length.
template <
bool has_weight,
typename offsets_t,
typename indices_t,
typename weights_t>
void _permute_1D_indices_weights_kernel_cpu(
const offsets_t* const __restrict__ input_offsets,
const indices_t* const __restrict__ indices,
const weights_t* const __restrict__ weights,
const int64_t permuted_lengths_size,
const int32_t* const __restrict__ permute,
const offsets_t* const __restrict__ permuted_lengths,
const offsets_t* const __restrict__ output_offsets,
indices_t* const __restrict__ permuted_indices,
weights_t* const __restrict__ permuted_weights) {
at::parallel_for(
0,
permuted_lengths_size,
FALSE_SHARING_PAD,
[&](int64_t tb_begin, int64_t tb_end) {
for (int tb = tb_begin; tb < std::min(tb_end, permuted_lengths_size);
++tb) {
offsets_t permuted_length = permuted_lengths[tb];
const offsets_t input_start = input_offsets[permute[tb]];
const offsets_t output_start = output_offsets[tb];
for (const auto i : c10::irange(permuted_length)) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
}
}
}
}); // parallel_for T x B, different B across T
}

std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
const Tensor& permute,
const Tensor& lengths,
@@ -656,7 +656,7 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
_permute_1D_lengths_cpu_kernel(
lengths_contig->data_ptr<index_t>(),
permuted_lengths_size,
permute.data_ptr<int32_t>(),
permute_contig->data_ptr<int32_t>(),
permuted_lengths.data_ptr<index_t>());
}); // for each scalar_t

@@ -779,6 +779,36 @@ Tensor expand_into_jagged_permute_cpu(
return output_permute;
}

template <typename index_t>
void _invert_permute_cpu_kernel(
const int64_t permute_size,
const index_t* const __restrict__ permute,
index_t* const __restrict__ inversed_permute) {
at::parallel_for(
0, permute_size, FALSE_SHARING_PAD, [&](int64_t t_begin, int64_t t_end) {
for (int t = t_begin; t < std::min(t_end, permute_size); ++t) {
inversed_permute[permute[t]] = t;
}
});
}

Tensor invert_permute_cpu(const Tensor& permute) {
TENSOR_ON_CPU(permute);
const auto permute_contig = permute.expect_contiguous();
const auto permute_size = permute.numel();
Tensor inversed_permute = at::empty_like(permute);

AT_DISPATCH_INDEX_TYPES(
permute.scalar_type(), "invert_permute_cpu_kernel", [&] {
_invert_permute_cpu_kernel<index_t>(
permute_size,
permute_contig->data_ptr<index_t>(),
inversed_permute.data_ptr<index_t>());
}); // for each scalar_t

return inversed_permute;
}

std::tuple<
Tensor,
Tensor,
@@ -2370,6 +2400,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def(
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def("invert_permute(Tensor permute) -> Tensor");
m.def(
"expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, int output_size) -> Tensor");
m.def(
@@ -2442,6 +2473,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"permute_2D_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cpu);
DISPATCH_TO_CPU(
"permute_1D_sparse_data", fbgemm_gpu::permute_1D_sparse_data_cpu);
DISPATCH_TO_CPU("invert_permute", fbgemm_gpu::invert_permute_cpu);
DISPATCH_TO_CPU(
"expand_into_jagged_permute", fbgemm_gpu::expand_into_jagged_permute_cpu);
DISPATCH_TO_CPU(
1 change: 1 addition & 0 deletions fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -217,6 +217,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
"permute_2D_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cuda);
DISPATCH_TO_CUDA(
"permute_1D_sparse_data", fbgemm_gpu::permute_1D_sparse_data_cuda);
DISPATCH_TO_CUDA("invert_permute", fbgemm_gpu::invert_permute_cuda);
DISPATCH_TO_CUDA(
"expand_into_jagged_permute",
fbgemm_gpu::expand_into_jagged_permute_cuda);
27 changes: 27 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -245,6 +245,33 @@ def test_permute_indices(
else:
assert permuted_weights_gpu is None

# pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
@given(
permute_size=st.integers(min_value=30, max_value=1000),
long_index=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_invert_permute(
self,
permute_size: int,
long_index: bool,
) -> None:
index_dtype = torch.int64 if long_index else torch.int32
permute_list = list(range(permute_size))
random.shuffle(permute_list)
inversed_permute_list = [0] * len(permute_list)
for i in range(permute_size):
inversed_permute_list[permute_list[i]] = i
permute = torch.IntTensor(permute_list).type(index_dtype)
inverse_permute_ref = torch.IntTensor(inversed_permute_list).type(index_dtype)

inverse_permute_cpu = torch.ops.fbgemm.invert_permute(permute)
torch.testing.assert_close(inverse_permute_cpu, inverse_permute_ref)

if gpu_available:
inverse_permute_gpu = torch.ops.fbgemm.invert_permute(permute.cuda())
torch.testing.assert_close(inverse_permute_gpu.cpu(), inverse_permute_cpu)

# pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
@given(
B=st.integers(min_value=1, max_value=20),
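One way to read the new op's purpose (a hedged end-to-end sketch, not taken from this PR's tests): permuting sparse data with `permute_1D_sparse_data` and then permuting the result with the inverse permutation from `invert_permute` restores the original segment order.

```python
# Round-trip sketch under the registered schemas; all tensors here are made up for
# illustration. Assumes importing fbgemm_gpu registers torch.ops.fbgemm.*.
import torch
import fbgemm_gpu  # noqa: F401

lengths = torch.tensor([2, 0, 3, 1], dtype=torch.int32)  # four variable-length segments
indices = torch.tensor([10, 11, 20, 21, 22, 30], dtype=torch.int32)
permute = torch.tensor([2, 0, 3, 1], dtype=torch.int32)

# Permute the segments, then permute back with the inverse permutation.
p_lengths, p_indices, _ = torch.ops.fbgemm.permute_1D_sparse_data(permute, lengths, indices)
inverse = torch.ops.fbgemm.invert_permute(permute)
r_lengths, r_indices, _ = torch.ops.fbgemm.permute_1D_sparse_data(inverse, p_lengths, p_indices)

assert torch.equal(r_lengths, lengths) and torch.equal(r_indices, indices)
```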