Add masked_index_select and refactor masked_index_put (pytorch#2910)
Summary:
Pull Request resolved: pytorch#2910

This diff adds `torch.ops.fbgemm.masked_index_select` (to be used in
subsequent diffs) and refactors `torch.ops.fbgemm.masked_index_put`.

- Add unit tests
- Add docstrings
- Add contiguous tensor checks

Differential Revision: D60362812
sryap authored and facebook-github-bot committed Jul 29, 2024
1 parent e588dfa commit b905f14
Showing 3 changed files with 252 additions and 11 deletions.
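
For readers skimming the diff: the two ops share one semantic skeleton, differing only in the direction of the copy. A minimal PyTorch sketch of that semantics, assuming 2D `values`, a 1D `indices`, and a single-element `count` as in the docstrings below (helper names are illustrative):

```python
import torch

def masked_index_put_ref(self_t, indices, values, count):
    # Scatter: write rows of `values` to `self_t` at the non-negative indices.
    indices = indices[: int(count.item())]
    keep = indices >= 0
    self_t[indices[keep]] = values[keep.nonzero().flatten()]
    return self_t

def masked_index_select_ref(self_t, indices, values, count):
    # Gather: read rows of `values` at the non-negative indices into `self_t`.
    indices = indices[: int(count.item())]
    keep = indices >= 0
    self_t[keep.nonzero().flatten()] = values[indices[keep]]
    return self_t
```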
@@ -49,8 +49,8 @@ vec4_copy(uint8_t* dst, const uint8_t* src, const int32_t D) {
   }
 }
 
-template <typename scalar_t>
-__global__ __launch_bounds__(kMaxThreads) void masked_index_put_kernel(
+template <typename scalar_t, bool is_index_put>
+__global__ __launch_bounds__(kMaxThreads) void masked_index_kernel(
     pta::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> self,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         indices,
@@ -67,20 +67,27 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_put_kernel(
   if (n >= count_) {
     return;
   }
+  // idx == -1 if it is a conflict miss
   const auto idx = indices[n];
   if (idx < 0) {
     return;
   }
   const auto D = self.size(1);
-  vec4_copy(&self[idx][0], &values[n][0], D);
+  const auto self_idx = is_index_put ? idx : n;
+  const auto values_idx = is_index_put ? n : idx;
+  vec4_copy(&self[self_idx][0], &values[values_idx][0], D);
 }
 
-Tensor masked_index_put_cuda(
-    Tensor self,
-    Tensor indices,
-    Tensor values,
-    Tensor count) {
+template <bool is_index_put>
+Tensor masked_index_impl(
+    const Tensor& self,
+    const Tensor& indices,
+    const Tensor& values,
+    const Tensor& count) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(self, indices, values, count);
+  TENSOR_CONTIGUOUS(self);
+  TENSOR_CONTIGUOUS(indices);
+  TENSOR_CONTIGUOUS(values);
 
   CUDA_DEVICE_GUARD(self);
 
@@ -93,17 +100,18 @@ Tensor masked_index_put_cuda(
 
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
       self.scalar_type(),
-      "masked_index_put",
+      is_index_put ? "masked_index_put" : "masked_index_select",
       [&] {
        const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
        const dim3 threads(tx, kMaxThreads / tx);
 #ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name = "masked_index_put_kernel";
+        const auto func_name = is_index_put ? "masked_index_put_kernel"
+                                            : "masked_index_select_kernel";
 #endif
        if (std::is_same_v<scalar_t, uint8_t>) {
          TORCH_CHECK(D % 16 == 0, "D needs to be padded to be multiple of 16")
        }
-        masked_index_put_kernel<scalar_t>
+        masked_index_kernel<scalar_t, is_index_put>
            <<<div_round_up(N, kMaxThreads / tx),
               dim3(tx, kMaxThreads / tx),
               0,
@@ -119,6 +127,23 @@ Tensor masked_index_put_cuda(
   return self;
 }
 
+Tensor masked_index_put_cuda(
+    Tensor self,
+    Tensor indices,
+    Tensor values,
+    Tensor count) {
+  return masked_index_impl</*is_index_put=*/true>(self, indices, values, count);
+}
+
+Tensor masked_index_select_cuda(
+    Tensor self,
+    Tensor indices,
+    Tensor values,
+    Tensor count) {
+  return masked_index_impl</*is_index_put=*/false>(
+      self, indices, values, count);
+}
+
 __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
     pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
         lxu_cache_state,
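
The refactor folds both kernels into one template parameterized by `is_index_put`: the flag only swaps whether `idx` addresses the destination (put) or the source (select), so the bounds check, the `idx < 0` skip, and the `vec4_copy` stay shared. A Python sketch of that shared control flow, mirroring `masked_index_kernel` row by row (illustrative only, not the CUDA code):

```python
def masked_index_impl_ref(self_t, indices, values, count, is_index_put):
    # One path for both ops; the flag picks the direction of the copy.
    for n in range(int(count.item())):
        idx = int(indices[n])
        if idx < 0:  # idx == -1 marks a conflict miss; skip this row
            continue
        self_idx, values_idx = (idx, n) if is_index_put else (n, idx)
        self_t[self_idx] = values[values_idx]  # vec4_copy in the kernel
    return self_t
```

The header changes below add the host-side declarations, docstrings, and operator registration.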
@@ -28,9 +28,61 @@ ssd_cache_populate_actions_cuda(
     int64_t prefetch_dist,
     Tensor lru_state);
 
+/// @ingroup embedding-ssd
+///
+/// @brief Similar to `torch.Tensor.index_put` but ignores `indices < 0`
+///
+/// `masked_index_put_cuda` only supports 2D input `values`. It puts
+/// `count` rows in `values` into `self` using the row indices that
+/// are >= 0 in `indices`.
+///
+/// ```python
+/// # Equivalent PyTorch Python code
+/// indices = indices[:count]
+/// filter_ = indices >= 0
+/// indices_ = indices[filter_]
+/// self[indices_] = values[filter_.nonzero().flatten()]
+/// ```
+///
+/// @param self The 2D output tensor (the tensor that is indexed)
+/// @param indices The 1D index tensor
+/// @param values The 2D input tensor
+/// @param count The tensor that contains the length of `indices` to
+///              process
+///
+/// @return The `self` tensor
 Tensor
 masked_index_put_cuda(Tensor self, Tensor indices, Tensor values, Tensor count);
 
+/// @ingroup embedding-ssd
+///
+/// @brief Similar to `torch.index_select` but ignores `indices < 0`
+///
+/// `masked_index_select_cuda` only supports 2D input `values`. It
+/// copies the rows of `values` at the non-negative entries of
+/// `indices[:count]` into the corresponding rows of `self`.
+///
+/// ```python
+/// # Equivalent PyTorch Python code
+/// indices = indices[:count]
+/// filter_ = indices >= 0
+/// indices_ = indices[filter_]
+/// self[filter_.nonzero().flatten()] = values[indices_]
+/// ```
+///
+/// @param self The 2D output tensor
+/// @param indices The 1D index tensor
+/// @param values The 2D input tensor (the tensor that is indexed)
+/// @param count The tensor that contains the length of `indices` to
+///              process
+///
+/// @return The `self` tensor
+Tensor masked_index_select_cuda(
+    Tensor self,
+    Tensor indices,
+    Tensor values,
+    Tensor count);
+
 Tensor masked_index_put_byte_cuda(
     Tensor self,
     Tensor indices,
@@ -197,6 +249,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " Tensor count"
       ") -> Tensor");
   DISPATCH_TO_CUDA("masked_index_put", masked_index_put_cuda);
+  m.def(
+      "masked_index_select("
+      " Tensor self, "
+      " Tensor indices, "
+      " Tensor values, "
+      " Tensor count"
+      ") -> Tensor");
+  DISPATCH_TO_CUDA("masked_index_select", masked_index_select_cuda);
   m.def(
       "ssd_cache_populate_actions("
       " Tensor linear_indices, "
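
Once the schema is registered, both ops are callable through `torch.ops.fbgemm`. A usage sketch, assuming a CUDA build of `fbgemm_gpu` (shapes are illustrative; the row width is kept a multiple of 4, which the `vec4_copy` path expects):

```python
import torch
import fbgemm_gpu.tbe.ssd  # noqa: F401  (loads the fbgemm operators)

D = 16  # row width, a multiple of 4
cache = torch.zeros(8, D, device="cuda")
rows = torch.rand(4, D, device="cuda")
indices = torch.tensor([2, -1, 5, 0], dtype=torch.long, device="cuda")
count = torch.tensor([4], dtype=torch.int, device="cuda")

# Put: rows 0, 2, 3 of `rows` land at cache rows 2, 5, 0; the -1 is skipped.
torch.ops.fbgemm.masked_index_put(cache, indices, rows, count)

# Select: out[i] = cache[indices[i]] for each non-negative indices[i].
out = torch.zeros(4, D, device="cuda")
torch.ops.fbgemm.masked_index_select(out, indices, cache, count)
```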
fbgemm_gpu/test/tbe/ssd/ssd_utils_test.py (new file, 156 additions)
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from typing import Callable

import fbgemm_gpu.tbe.ssd  # noqa F401
import hypothesis.strategies as st
import torch
from hypothesis import given, settings, Verbosity

from .. import common  # noqa E402
from ..common import open_source

if open_source:
    # pyre-ignore[21]
    from test_utils import gpu_unavailable, running_on_github
else:
    from fbgemm_gpu.test.test_utils import gpu_unavailable, running_on_github


MAX_EXAMPLES = 20


@unittest.skipIf(*running_on_github)
@unittest.skipIf(*gpu_unavailable)
class SSDUtilsTest(unittest.TestCase):
    def execute_masked_index_test(
        self,
        D: int,
        max_index: int,
        num_indices: int,
        num_value_rows: int,
        num_output_rows: int,
        dtype: torch.dtype,
        test_fn: Callable[
            [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
        ],
        is_index_put: bool,
    ) -> None:
        """
        A helper function that generates inputs/outputs, runs
        torch.ops.fbgemm.masked_index_* against the PyTorch counterpart, and
        compares the output results
        """
        device = "cuda"

        # Number of columns must be a multiple of 4 (embedding requirement)
        D = D * 4

        # Generate indices
        indices = torch.randint(
            low=0, high=max_index, size=(num_indices,), dtype=torch.long, device=device
        )

        # Compute/set unique indices (indices have to be unique to avoid a
        # race condition)
        indices_unique = indices.unique()
        count_val = indices_unique.numel()
        indices[:count_val] = indices_unique

        # Permute unique indices
        rand_pos = torch.randperm(indices_unique.numel(), device=device)
        indices[:count_val] = indices[rand_pos]

        # Set some indices to -1
        indices[rand_pos[: max(count_val // 2, 1)]] = -1

        # Generate count tensor
        count = torch.as_tensor([count_val], dtype=torch.int, device=device)

        # Generate values
        values = torch.rand(num_value_rows, D, dtype=dtype, device=device)

        # Allocate output and output_ref
        output = torch.zeros(num_output_rows, D, dtype=dtype, device=device)
        output_ref = torch.zeros(num_output_rows, D, dtype=dtype, device=device)

        # Run test
        output = test_fn(output, indices, values, count)

        # Run reference
        indices = indices[:count_val]
        filter_ = indices >= 0
        indices_ = indices[filter_]
        filter_locs = filter_.nonzero().flatten()
        if is_index_put:
            output_ref[indices_] = values[filter_locs]
        else:
            output_ref[filter_locs] = values[indices_]

        # Compare results
        assert torch.equal(output_ref, output)

    # pyre-ignore [56]
    @given(
        num_indices=st.integers(min_value=10, max_value=100),
        D=st.integers(min_value=2, max_value=256),
        num_output_rows=st.integers(min_value=10, max_value=100),
        dtype=st.sampled_from([torch.float, torch.half]),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
    def test_masked_index_put(
        self,
        num_indices: int,
        D: int,
        num_output_rows: int,
        dtype: torch.dtype,
    ) -> None:
        """
        Test correctness of torch.ops.fbgemm.masked_index_put against PyTorch's
        index_put
        """
        self.execute_masked_index_test(
            D=D,
            max_index=num_output_rows,
            num_indices=num_indices,
            num_value_rows=num_indices,
            num_output_rows=num_output_rows,
            dtype=dtype,
            test_fn=torch.ops.fbgemm.masked_index_put,
            is_index_put=True,
        )

    # pyre-ignore [56]
    @given(
        num_indices=st.integers(min_value=10, max_value=100),
        D=st.integers(min_value=2, max_value=256),
        num_value_rows=st.integers(min_value=10, max_value=100),
        dtype=st.sampled_from([torch.float, torch.half]),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
    def test_masked_index_select(
        self,
        num_indices: int,
        D: int,
        num_value_rows: int,
        dtype: torch.dtype,
    ) -> None:
        """
        Test correctness of torch.ops.fbgemm.masked_index_select against
        PyTorch's index_select
        """
        self.execute_masked_index_test(
            D=D,
            max_index=num_value_rows,
            num_indices=num_indices,
            num_value_rows=num_value_rows,
            num_output_rows=num_indices,
            dtype=dtype,
            test_fn=torch.ops.fbgemm.masked_index_select,
            is_index_put=False,
        )
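
One design note on the helper above: it deduplicates `indices` because a masked put with duplicate destinations is a write race between threads, and which duplicate wins is unspecified, just as with plain advanced-indexing assignment. A small illustration of the hazard the test avoids (values here are hypothetical):

```python
import torch

out = torch.zeros(3, 4, device="cuda")
idx = torch.tensor([1, 1], device="cuda")  # duplicate destination row
vals = torch.rand(2, 4, device="cuda")
out[idx] = vals  # either vals[0] or vals[1] may end up in out[1]
```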
