Optimize the index_select operation for dim=0 #1113

Closed
51 changes: 51 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -3,11 +3,13 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import logging
import random

import click
import fbgemm_gpu
import numpy as np
import torch

logging.basicConfig(level=logging.DEBUG)
@@ -69,5 +71,54 @@ def device(
logging.info(f"expand_into_jagged_permute {time} sec {num_bytes / time / 1e9} GB/s")


@cli.command()
@click.option("--row-size", default=25600)
@click.option("--batch-size", default=4096)
@click.option("--unique-batch-size", default=1024)
@click.option("--input-precision", type=str, default="fp32")
def batch_reuse_index_select_device(
    row_size: int, batch_size: int, unique_batch_size: int, input_precision: str
) -> None:
    # A helper that generates the (partially duplicated) indices used in batch_reuse:
    # every unique row index appears at least once, the rest are random repeats.
    def gen_inverse_index(curr_size: int, final_size: int) -> np.ndarray:
        inverse_index = list(range(curr_size))
        for _ in range(final_size - curr_size):
            inverse_index.append(np.random.randint(0, curr_size))
        np_arr = np.array(inverse_index)
        np.random.shuffle(np_arr)
        return np_arr

    dtype = torch.float
    if input_precision == "fp32":
        dtype = torch.float
    elif input_precision == "fp16":
        dtype = torch.half
    else:
        raise RuntimeError(f"Does not support data type {input_precision}")

    indices = torch.cuda.IntTensor(gen_inverse_index(unique_batch_size, batch_size))

    input = torch.rand(unique_batch_size, row_size, dtype=dtype, device="cuda")
    input.requires_grad = True
    num_bytes = 2 * batch_size * row_size * input.element_size()
    time, output = benchmark_torch_function(
        torch.ops.fbgemm.index_select_dim0, (input, indices, 0, unique_batch_size)
    )
    logging.info(
        f"index_select_dim0 forward: {dtype}, {num_bytes} bytes read/write, {time * 1e3} ms, {num_bytes / time / 1e9} GB/s"
    )

    grad = torch.rand_like(output, dtype=dtype, device="cuda")
    num_bytes = (input.numel() + output.numel()) * input.element_size()
    time, _ = benchmark_torch_function(
        functools.partial(output.backward, retain_graph=True), (grad,)
    )
    logging.info(
        f"index_select_dim0 backward: {dtype}, {num_bytes} bytes read/write, {time * 1e3} ms, {num_bytes / time / 1e9} GB/s"
    )


if __name__ == "__main__":
    cli()
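As a quick sanity check of the operator being timed here, the op can be compared against plain dim-0 indexing in both directions. This is a hedged, illustrative snippet, not part of the patch; it assumes a CUDA build of fbgemm_gpu that registers torch.ops.fbgemm.index_select_dim0 with the same four-argument call used by the benchmark above (input, indices, range start, range length).

import torch
import fbgemm_gpu  # noqa: F401  # loads the extension and registers torch.ops.fbgemm.*

rows, cols, batch = 1024, 256, 4096
weights = torch.rand(rows, cols, device="cuda", requires_grad=True)
baseline = weights.detach().clone().requires_grad_(True)
idx = torch.randint(0, rows, (batch,), dtype=torch.int32, device="cuda")

# Forward: should match a plain dim-0 gather.
out = torch.ops.fbgemm.index_select_dim0(weights, idx, 0, rows)
ref = baseline[idx.long()]
torch.testing.assert_close(out, ref)

# Backward: repeated indices should accumulate gradients identically.
grad = torch.rand_like(out)
out.backward(grad)
ref.backward(grad)
torch.testing.assert_close(weights.grad, baseline.grad)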
15 changes: 15 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -590,4 +590,19 @@ at::Tensor pack_segments_backward_cuda(
int64_t total_length,
int64_t max_length);

at::Tensor index_select_with_sorted_indices_cuda(
const at::Tensor& input,
const at::Tensor& sorted_indices,
const at::Tensor& orig_indices,
const int consecutive_range_start,
const int consecutive_range_length);

at::Tensor index_add_with_unique_indices_cuda(
const at::Tensor& grad_output,
const at::Tensor& sorted_indices,
const at::Tensor& orig_indices,
std::vector<int64_t>& input_shape,
const int consecutive_range_start,
const int consecutive_range_length);

} // namespace fbgemm_gpu
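In plain PyTorch terms, the two entry points declared above amount to a dim-0 gather and its scatter-add gradient. A minimal reference formulation, useful when reading the kernels in sparse_ops.cu below (the function names here are illustrative, not FBGEMM API):

import torch

def index_select_dim0_ref(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Forward: gather the selected rows (what index_select_with_sorted_indices_cuda produces).
    return x[indices.long()]

def index_add_dim0_ref(grad_out: torch.Tensor, indices: torch.Tensor, input_shape) -> torch.Tensor:
    # Backward: scatter-add each gradient row back onto the input row it was read from
    # (what index_add_with_unique_indices_cuda produces); repeated indices accumulate.
    grad_in = torch.zeros(input_shape, dtype=grad_out.dtype, device=grad_out.device)
    return grad_in.index_add_(0, indices.long(), grad_out)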
297 changes: 297 additions & 0 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -2514,4 +2514,301 @@ Tensor pack_segments_backward_cuda(
return unpacked_tensor;
}

constexpr int MAX_ELEMENTS_PER_THREAD = 4;

template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
__global__
__launch_bounds__(kMaxThreads) void index_select_2d_with_sorted_indices_kernel(
const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits> input,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
sorted_indices,
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
orig_indices,
at::PackedTensorAccessor32<scalar_t, 2> output) {
const int N = sorted_indices.size(0);
const int input_size = input.size(0);
const int D = input.size(1);
CUDA_KERNEL_ASSERT(output.size(0) == N)

for (int row = blockIdx.x; row < N; row += gridDim.x) {
const index_t src_idx = sorted_indices[row];
const int64_t dst_idx = orig_indices[row];
CUDA_KERNEL_ASSERT(src_idx < input_size)
int col;
for (col = threadIdx.x * UNROLL_FACTOR;
col < D / UNROLL_FACTOR * UNROLL_FACTOR;
col += blockDim.x * UNROLL_FACTOR) {
#pragma unroll
for (int i = 0; i < UNROLL_FACTOR; i++) {
output[dst_idx][col + i] = __ldg(&input[src_idx][col + i]);
}
}
for (; col < D; ++col) {
output[dst_idx][col] = __ldg(&input[src_idx][col]);
}
}
}
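The kernel above writes row i of the result as output[orig_indices[i]] = input[sorted_indices[i]], so even though rows are visited in sorted-index order, the output comes back in the caller's original index order. A small PyTorch model of that mapping, assuming orig_indices is the permutation returned by sorting the raw indices (illustrative only):

import torch

indices = torch.tensor([3, 1, 3, 0])
sorted_indices, orig_indices = torch.sort(indices)  # [0, 1, 3, 3] and [3, 1, 0, 2]

input_2d = torch.arange(8.0).reshape(4, 2)
output = torch.empty(indices.numel(), input_2d.size(1))
output[orig_indices] = input_2d[sorted_indices]     # what the kernel computes
assert torch.equal(output, input_2d[indices])       # identical to a plain dim-0 gather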

template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
__global__
__launch_bounds__(kMaxThreads) void index_add_2d_with_unique_indices_kernel(
const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
out_grad,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
unique_indices,
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
orig_indices,
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
at::PackedTensorAccessor32<scalar_t, 2> in_deduped_grad,
const int stride_D,
const int rounded_D,
const int remaining_D,
const bool consecutive_indices,
const int consecutive_range_start) {
const int start_offset = blockIdx.x == 0 ? 0 : offsets[blockIdx.x - 1];
const int end_offset = offsets[blockIdx.x];
index_t dst_idx = consecutive_indices ? blockIdx.x + consecutive_range_start
: unique_indices[blockIdx.x];
// The tail columns beyond rounded_D are handled by the last thread block along y.
const bool has_remainder = blockIdx.y == gridDim.y - 1 && remaining_D > 0 &&
threadIdx.x < remaining_D;

// Buffer for storing temporary results
scalar_t sum[MAX_ELEMENTS_PER_THREAD];
for (int i = 0; i < MAX_ELEMENTS_PER_THREAD; i++) {
sum[i] = 0;
}

scalar_t sum_remainder = 0;

// Each thread block processes max of stride_D elements
int start_D = (blockIdx.y * stride_D) + (threadIdx.x * UNROLL_FACTOR);

// For each row
for (int row = start_offset; row < end_offset; row++) {
int64_t src_idx = orig_indices[row];
int col, i;
for (col = start_D, i = 0; col < start_D + stride_D && col < rounded_D;
col += blockDim.x * UNROLL_FACTOR, i += UNROLL_FACTOR) {
#pragma unroll
for (int j = 0; j < UNROLL_FACTOR; j++) {
sum[i + j] += __ldg(&out_grad[src_idx][col + j]);
}
}
if (has_remainder) {
sum_remainder += __ldg(&out_grad[src_idx][rounded_D + threadIdx.x]);
}
} // for each row

// Write results to global memory
int col, i;
for (col = start_D, i = 0; col < start_D + stride_D && col < rounded_D;
col += blockDim.x * UNROLL_FACTOR, i += UNROLL_FACTOR) {
#pragma unroll
for (int j = 0; j < UNROLL_FACTOR; j++) {
in_deduped_grad[dst_idx][col + j] = sum[i + j];
}
}
if (has_remainder) {
in_deduped_grad[dst_idx][rounded_D + threadIdx.x] += sum_remainder;
}
}
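Each x-block of this backward kernel owns one unique destination row: it iterates over the segment [offsets[blockIdx.x - 1], offsets[blockIdx.x]) of the sorted order, accumulates the corresponding grad_output rows (located through orig_indices) in registers, and writes the result back without atomics on the gradient tensor. The same reduction written densely in PyTorch (illustrative, not the FBGEMM code path):

import torch

def index_add_unique_ref(grad_out_2d, sorted_indices, orig_indices, num_input_rows):
    grad_in = torch.zeros(num_input_rows, grad_out_2d.size(1),
                          dtype=grad_out_2d.dtype, device=grad_out_2d.device)
    # Rows that share a sorted index form one contiguous segment; index_add_
    # performs the equivalent segmented sum in a single call.
    grad_in.index_add_(0, sorted_indices.long(), grad_out_2d[orig_indices])
    return grad_in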

template <typename index_t>
__global__
__launch_bounds__(kMaxThreads) void compute_frequency_sequence_kernel(
index_t* input,
int64_t* output,
const int start_input,
const int input_size) {
const int i = blockDim.x * blockIdx.x + threadIdx.x;

if (i >= input_size) {
return;
}
// Shift by start_input so values drawn from a consecutive range land in [0, output size).
// The atomic could become a bottleneck if the frequencies are very skewed.
atomicAdd(&output[input[i] - start_input], 1);
}

void compute_frequency_sequence(
const Tensor& input,
Tensor& output,
const int start_input,
const int output_size) {
output = at::zeros({output_size}, input.options().dtype(at::kLong));

AT_DISPATCH_INDEX_TYPES(
input.scalar_type(), "compute_frequency_sequence_kernel_1", [&] {
compute_frequency_sequence_kernel<index_t>
<<<cuda_calc_xblock_count(input.numel(), kWarpSize),
kWarpSize,
0,
at::cuda::getCurrentCUDAStream()>>>(
input.data_ptr<index_t>(),
output.data_ptr<int64_t>(),
start_input,
input.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
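compute_frequency_sequence builds a histogram of the sorted indices; after the cumsum(0) applied at the call site below, entry b holds the end of unique value b's segment in the sorted order, which is exactly what index_add_2d_with_unique_indices_kernel consumes. The same computation in plain PyTorch (illustrative):

import torch

sorted_indices = torch.tensor([0, 0, 1, 3, 3, 3])
output_size = 4  # length of the consecutive index range

counts = torch.bincount(sorted_indices, minlength=output_size)  # [2, 1, 0, 3]
offsets = counts.cumsum(0)                                      # [2, 3, 3, 6]
# Block b of the backward kernel then reduces grad rows in [offsets[b - 1], offsets[b]),
# starting from 0 for b == 0; a range that starts above zero is first shifted down
# by consecutive_range_start (the start_input argument above).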

template <
typename scalar_t,
int ndim,
template <typename U> class PtrTraits = at::DefaultPtrTraits>
at::PackedTensorAccessor32<scalar_t, ndim, PtrTraits>
dummy_packed_accessor32() {
std::array<int64_t, ndim> zeros{};
return {nullptr, zeros.data(), zeros.data()};
}

Tensor index_select_with_sorted_indices_cuda(
const Tensor& input,
const Tensor& sorted_indices,
const Tensor& orig_indices,
const int consecutive_range_start,
const int consecutive_range_length) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(input.get_device());

const int N = sorted_indices.size(0);
auto output_shape = input.sizes().vec();
output_shape[0] = N;

if (input.numel() == 0 || N == 0) {
return at::empty(output_shape, input.options());
}

Tensor input_reshaped = input.reshape({input.size(0), -1});
const int D = input_reshaped.size(1);

Tensor output = at::empty({N, D}, input_reshaped.options());

const int UNROLL_FACTOR = 2;

AT_DISPATCH_INDEX_TYPES(
sorted_indices.scalar_type(), "index_select_2d_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input_reshaped.scalar_type(), "index_select_2d_kernel_2", [&] {
index_select_2d_with_sorted_indices_kernel<
index_t,
scalar_t,
UNROLL_FACTOR><<<
cuda_calc_xblock_count(N, 1),
std::min(div_round_up(D, UNROLL_FACTOR), kMaxThreads),
0,
at::cuda::getCurrentCUDAStream()>>>(
input_reshaped
.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
sorted_indices
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
orig_indices
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
output.packed_accessor32<scalar_t, 2>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});

return output.reshape(output_shape);
}
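This host function expects indices that are already sorted, plus the permutation that maps each sorted position back to its original position; a caller would typically obtain both from a single torch.sort before invoking the op (a hedged sketch, since that wrapper glue is not part of this diff):

import torch

raw_indices = torch.tensor([5, 2, 5, 0], dtype=torch.int32, device="cuda")
sorted_indices, orig_indices = torch.sort(raw_indices)
# sorted_indices -> int32 [0, 2, 5, 5]; orig_indices -> int64 [3, 1, 0, 2],
# matching the index_t and int64_t accessor types the kernels above expect.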

Tensor index_add_with_unique_indices_cuda(
const Tensor& grad_output,
const Tensor& sorted_indices,
const Tensor& orig_indices,
std::vector<int64_t>& input_shape,
const int consecutive_range_start,
const int consecutive_range_length) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_output.get_device());

const int N = grad_output.size(0);

if (grad_output.numel() == 0) {
return at::zeros(input_shape, grad_output.options());
}

const Tensor grad_output_reshaped = grad_output.reshape({N, -1});
const int D = grad_output_reshaped.size(1);

TORCH_CHECK(sorted_indices.size(0) == N);

Tensor input_grad = at::zeros({input_shape[0], D}, grad_output.options());
bool consecutive_indices =
consecutive_range_start >= 0 && consecutive_range_length > 0;

AT_DISPATCH_INDEX_TYPES(
sorted_indices.scalar_type(), "index_add_2d_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_output.scalar_type(), "index_add_2d_kernel_2", [&] {
// UNROLL_FACTOR was chosen empirically
const int UNROLL_FACTOR = std::is_same<scalar_t, float>() ? 4 : 2;
const int rounded_D = D / UNROLL_FACTOR * UNROLL_FACTOR;
const int remaining_D = D - rounded_D;
int block_size =
std::min(div_round_up(D, UNROLL_FACTOR), kMaxThreads);
block_size = std::max(remaining_D, block_size);
// Maximum number of columns processed by each thread block
const int stride_D = MAX_ELEMENTS_PER_THREAD * block_size;

int num_unique_indices;
Tensor unique_indices, offsets;
if (consecutive_indices) {
TORCH_CHECK(
consecutive_range_start < input_shape[0] &&
consecutive_range_start + consecutive_range_length - 1 <
input_shape[0]);

// Since indices are selected from consecutive range, we can
// infer the number of unique indices from
// consecutive_range_length
num_unique_indices = consecutive_range_length;
compute_frequency_sequence(
sorted_indices, offsets, consecutive_range_start, num_unique_indices);
offsets = offsets.cumsum(0);
} else {
Tensor unique_count;
// Unique consecutive does D->H transfer internally
// (enforcing synchronization between host and device)
std::tie(unique_indices, std::ignore, unique_count) =
at::unique_consecutive(sorted_indices, false, true, 0);

// This does D->H transfer
num_unique_indices = unique_indices.numel();
offsets = unique_count.cumsum(0);
}

const dim3 grid_size(
cuda_calc_xblock_count(num_unique_indices, 1),
(D + stride_D - 1) / stride_D,
1);

index_add_2d_with_unique_indices_kernel<
index_t,
scalar_t,
UNROLL_FACTOR><<<
grid_size,
block_size,
0,
at::cuda::getCurrentCUDAStream()>>>(
grad_output_reshaped
.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
consecutive_indices ? dummy_packed_accessor32<
index_t,
1,
at::RestrictPtrTraits>()
: unique_indices.packed_accessor32<
index_t,
1,
at::RestrictPtrTraits>(),
orig_indices
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
offsets
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
input_grad.packed_accessor32<scalar_t, 2>(),
stride_D, // Pass constants as kernel args
rounded_D,
remaining_D,
consecutive_indices,
consecutive_range_start);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return input_grad.reshape(input_shape);
}

} // namespace fbgemm_gpu
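Taken together, the intended autograd flow (inferred from the two entry points and the benchmark; the actual wrapper that dispatches to them is not part of this diff) is: sort the indices once, run the gather forward with the sorted order plus the inverse permutation, and reuse those same tensors in backward so the gradient can be reduced segment by segment without atomics. A hypothetical, pure-PyTorch sketch of that wiring:

import torch

class IndexSelectDim0Sketch(torch.autograd.Function):
    # Illustrative stand-in: plain PyTorch ops take the place of the FBGEMM kernels.

    @staticmethod
    def forward(ctx, input, indices):
        sorted_indices, orig_indices = torch.sort(indices)
        ctx.save_for_backward(sorted_indices, orig_indices)
        ctx.input_shape = input.shape
        # Stand-in for index_select_with_sorted_indices_cuda.
        out = torch.empty(indices.numel(), *input.shape[1:],
                          dtype=input.dtype, device=input.device)
        out[orig_indices] = input[sorted_indices.long()]
        return out

    @staticmethod
    def backward(ctx, grad_output):
        sorted_indices, orig_indices = ctx.saved_tensors
        # Stand-in for index_add_with_unique_indices_cuda.
        grad_input = torch.zeros(ctx.input_shape, dtype=grad_output.dtype,
                                 device=grad_output.device)
        grad_input.index_add_(0, sorted_indices.long(), grad_output[orig_indices])
        return grad_input, None

# Usage: out = IndexSelectDim0Sketch.apply(weights, indices); out.sum().backward()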