pack_segments support fp16/bf16 (#1708)
Summary:
Pull Request resolved: #1708

As titled: extend the pack_segments operators to accept fp16 (Half) inputs on both CPU and CUDA, and additionally bf16 (BFloat16) on the CUDA path.

Differential Revision: D44562662

fbshipit-source-id: 9fa25e002aff3ad43132b5aa657640512a991127
brad-mengchi authored and facebook-github-bot committed Apr 19, 2023
1 parent 138685b commit 46bd800
Showing 3 changed files with 50 additions and 16 deletions.
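For context before the diffs, a minimal usage sketch of the operator this commit extends. It is a sketch under assumptions, not part of the commit: it assumes an fbgemm_gpu build whose registration matches the signature checked below, fbgemm::pack_segments(Tensor t_in, Tensor lengths, int max_length), with segments laid out contiguously along dim 0 as in the tests.

import torch
import fbgemm_gpu  # noqa: F401  # assumption: importing registers torch.ops.fbgemm.*

# Three segments of 2, 3, and 1 rows laid out contiguously along dim 0.
lengths = torch.tensor([2, 3, 1], dtype=torch.int)
t_in = torch.randn(6, 4, dtype=torch.half)  # fp16 input, newly accepted by this commit
packed = torch.ops.fbgemm.pack_segments(t_in, lengths, 3)
print(packed.shape)  # expected (3, 3, 4): each segment zero-padded to max_length rows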
12 changes: 8 additions & 4 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -2415,8 +2415,10 @@ Tensor pack_segments_forward_cuda(
   TENSOR_NDIM_EQUALS(lengths, 1);
   TORCH_CHECK(
       t_in.dtype() == at::ScalarType::Float ||
-          t_in.dtype() == at::ScalarType::Double,
-      "t_in must be of type float or double");
+          t_in.dtype() == at::ScalarType::Double ||
+          t_in.dtype() == at::ScalarType::Half ||
+          t_in.dtype() == at::ScalarType::BFloat16,
+      "t_in must be of type float or double or half or bfloat16");
   TORCH_CHECK(max_length > 0, "max_length must be a positive number");
   at::cuda::OptionalCUDAGuard device_guard;
@@ -2518,8 +2520,10 @@ Tensor pack_segments_backward_cuda(
       "LENGTHS and DATA must match in dimension 0");
   TORCH_CHECK(
       data.dtype() == at::ScalarType::Float ||
-          data.dtype() == at::ScalarType::Double,
-      "data must be of type float or double");
+          data.dtype() == at::ScalarType::Double ||
+          data.dtype() == at::ScalarType::Half ||
+          data.dtype() == at::ScalarType::BFloat16,
+      "data must be of type float or double or half or bfloat16");
   TORCH_CHECK(
       max_length == data.size(1),
       "max_length should be equal to the second dimension of the packed segments");
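Note the asymmetry between the backends after this commit: the CUDA checks above admit both Half and BFloat16, while the CPU checks in the next file admit only Half. A hedged sketch of what that implies for callers (same op-registration assumption as the sketch above; the CPU fallback branch is illustrative, not from this commit):

import torch
import fbgemm_gpu  # noqa: F401  # assumption: registers torch.ops.fbgemm.pack_segments

lengths = torch.tensor([2, 3, 1], dtype=torch.int)
x_bf16 = torch.randn(6, 4, dtype=torch.bfloat16)

if torch.cuda.is_available():
    # bf16 is accepted on the CUDA path after this commit.
    packed = torch.ops.fbgemm.pack_segments(x_bf16.cuda(), lengths.cuda(), 3)
else:
    # The CPU checks still reject bf16 ("float or double or half"),
    # so a CPU caller would need an upcast first.
    packed = torch.ops.fbgemm.pack_segments(x_bf16.float(), lengths, 3).bfloat16()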
24 changes: 16 additions & 8 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -2202,8 +2202,9 @@ Tensor pack_segments_forward_cpu(
   TENSOR_NDIM_EQUALS(lengths, 1);
   TORCH_CHECK(
       t_in.dtype() == at::ScalarType::Float ||
-          t_in.dtype() == at::ScalarType::Double,
-      "t_in must be of type float or double");
+          t_in.dtype() == at::ScalarType::Double ||
+          t_in.dtype() == at::ScalarType::Half,
+      "t_in must be of type float or double or half");
   TORCH_CHECK(max_length > 0, "max_length must be a positive number");

   const auto t_in_cont = t_in.expect_contiguous();
@@ -2224,8 +2225,11 @@ Tensor pack_segments_forward_cpu(
     return; // Return empty output (with the proper shape)
   }

-  AT_DISPATCH_FLOATING_TYPES(
-      t_in_cont->scalar_type(), "pack_segments_cpu-packing", ([&]() {
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half,
+      t_in_cont->scalar_type(),
+      "pack_segments_cpu-packing",
+      ([&]() {
         const auto sizes =
             t_in_cont->sizes().slice(1, t_in_cont->sizes().size() - 1);
         const auto block_size = c10::multiply_integers(sizes);
@@ -2269,8 +2273,9 @@ Tensor pack_segments_backward_cpu(
       "LENGTHS and DATA must match in dimension 0");
   TORCH_CHECK(
       data.dtype() == at::ScalarType::Float ||
-          data.dtype() == at::ScalarType::Double,
-      "data must be of type float or double");
+          data.dtype() == at::ScalarType::Double ||
+          data.dtype() == at::ScalarType::Half,
+      "data must be of type float or double or half");
   TORCH_CHECK(
       max_length == data.sizes()[1],
       "max_length should be equal to the second dimension of the packed segments");
@@ -2292,8 +2297,11 @@ Tensor pack_segments_backward_cpu(
     return;
   }

-  AT_DISPATCH_FLOATING_TYPES(
-      data.scalar_type(), "unpack_segments_cpu-unpacking", ([&]() {
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half,
+      data.scalar_type(),
+      "unpack_segments_cpu-unpacking",
+      ([&]() {
         const auto sizes = data.sizes().slice(2, data.sizes().size() - 2);
         const auto block_size = c10::multiply_integers(sizes);
         const auto block_bytesize = data.itemsize() * block_size;
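The tests in the next file compare against a _pack_segments_ref helper that this diff does not show. As a rough mental model only, here is a hypothetical stand-in (not the repo's helper) that pads each segment to max_length rows:

import torch

def pack_segments_ref(t_in, lengths, max_length):
    # Hypothetical reference: slice each segment out of t_in along dim 0,
    # truncate or zero-pad it to max_length rows, then stack the segments.
    out, offset = [], 0
    for n in lengths.tolist():
        seg = t_in[offset : offset + min(n, max_length)]
        pad = torch.zeros(
            (max_length - seg.shape[0], *t_in.shape[1:]), dtype=t_in.dtype
        )
        out.append(torch.cat([seg, pad]))
        offset += n
    return torch.stack(out)

t = torch.arange(12, dtype=torch.half).reshape(6, 2)
print(pack_segments_ref(t, torch.tensor([2, 3, 1]), 3).shape)  # torch.Size([3, 3, 2])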
30 changes: 26 additions & 4 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -1435,6 +1435,12 @@ def _pack_segments_ref(
         k=st.integers(2, 10),
         batch_size=st.integers(1, 30),
         divisions=st.integers(1, 10),
+        dtype=st.sampled_from(
+            [
+                torch.float,
+                torch.half,
+            ]
+        ),
     )
     @settings(deadline=None)
     def test_pack_segments(
@@ -1443,9 +1449,10 @@ def test_pack_segments(
         k: int,
         batch_size: int,
         divisions: int,
+        dtype: torch.dtype,
     ) -> None:
         input_raw = np.random.rand(batch_size, n, k)
-        input_data = torch.tensor(input_raw, dtype=torch.float32, requires_grad=True)
+        input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True)
         lengths = torch.tensor(
             get_n_rand_num_summing_to_k(divisions, batch_size), dtype=torch.int
         )
@@ -1457,14 +1464,15 @@

         packed_ref = self._pack_segments_ref(lengths, input_raw)

+        # pyre-fixme[6]: For 2nd param expected `Tensor` but got `ndarray`.
+        packed_ref = torch.Tensor(packed_ref).to(dtype)
+        print(f"packed_tensor {packed_tensor}\npacked_ref {packed_ref}")
         self.assertTrue(torch.equal(packed_tensor, packed_ref))

         grad_cpu = torch.tensor(
             np.random.uniform(low=0.01, high=0.5, size=packed_ref.shape).astype(
                 np.float32
             )
-        )
+        ).to(dtype)
         # CPU backward
         packed_tensor.backward(grad_cpu)

@@ -1486,6 +1494,12 @@ def test_pack_segments(
         batch_size=st.integers(1, 30),
         divisions=st.integers(1, 10),
         max_length=st.integers(1, 20),
+        dtype=st.sampled_from(
+            [
+                torch.float,
+                torch.half,
+            ]
+        ),
     )
     @settings(deadline=None)
     def test_pack_segments_smaller_max_len(
@@ -1495,8 +1509,9 @@ def test_pack_segments_smaller_max_len(
         batch_size: int,
         divisions: int,
         max_length: int,
+        dtype: torch.dtype,
    ) -> None:
-        input_data = torch.tensor(np.random.rand(batch_size, n, k), dtype=torch.float32)
+        input_data = torch.tensor(np.random.rand(batch_size, n, k), dtype=dtype)
         lengths = torch.tensor(
             get_n_rand_num_summing_to_k(divisions, batch_size), dtype=torch.int
         )
@@ -1530,6 +1545,12 @@ def test_pack_segments_smaller_max_len(
         k=st.integers(2, 10),
         batch_size=st.integers(1, 30),
         divisions=st.integers(1, 10),
+        dtype=st.sampled_from(
+            [
+                torch.float,
+                torch.half,
+            ]
+        ),
     )
     @settings(deadline=None)
     def test_pack_segments_meta_backend(
@@ -1538,6 +1559,7 @@ def test_pack_segments_meta_backend(
         k: int,
         batch_size: int,
         divisions: int,
+        dtype: torch.dtype,
     ) -> None:
         input_raw = np.random.rand(batch_size, n, k)
         input_data = torch.tensor(
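The backward changes mirror the forward ones, and the updated test drives them through autograd with half-precision inputs. A minimal sketch of that flow (same op-registration assumption as the sketches above):

import torch
import fbgemm_gpu  # noqa: F401  # assumption: registers torch.ops.fbgemm.pack_segments

lengths = torch.tensor([2, 3, 1], dtype=torch.int)
x = torch.randn(6, 4, dtype=torch.half, requires_grad=True)

packed = torch.ops.fbgemm.pack_segments(x, lengths, 3)
packed.backward(torch.ones_like(packed))  # exercises the new fp16 backward path
print(x.grad.dtype)  # torch.float16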
