fbgemm dequantize to bf16 #2241

Closed · wants to merge 1 commit
fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu (33 changes: 21 additions & 12 deletions)
@@ -145,7 +145,7 @@ Tensor _float_to_fusednbitrowwise_gpu_t(
   const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);
   // think unsigned as we use 0, 255

-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+  FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
       input.scalar_type(), "_float_to_fusednbitrowwise_cuda_kernel", [&] {
         _float_to_fusednbitrowwise_cuda_kernel<scalar_t>
             <<<num_blocks,
@@ -201,11 +201,11 @@ DLL_PUBLIC at::Tensor _half_to_fusednbitrowwise_gpu(
 ///
 /// @return A new tensor with values from the input tensor converted to
 /// fused N-bit rowwise.
-DLL_PUBLIC Tensor _float_or_half_to_fusednbitrowwise_gpu(
+DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
     const Tensor& input,
     const int64_t bit_rate) {
   Tensor output;
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+  FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
       input.scalar_type(),
       "float_or_half_to_fusednbitrowwise_cuda_kernel",
       [&] {
@@ -240,10 +240,16 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
     output = at::empty(
         {nrows, output_columns}, // 4 = sizeof(float)
         input.options().dtype(at::kFloat));
-  } else { // T = at::Half
+  } else if constexpr (std::is_same_v<output_t, at::Half>) {
     output = at::empty(
-        {nrows, output_columns}, // 4 = sizeof(float)
+        {nrows, output_columns}, // 2 = sizeof(half)
         input.options().dtype(at::kHalf));
+  } else if constexpr (std::is_same_v<output_t, at::BFloat16>) {
+    output = at::empty(
+        {nrows, output_columns}, // 2 = sizeof(bfloat16)
+        input.options().dtype(at::kBFloat16));
+  } else {
+    TORCH_CHECK(false, "Unsupported output dtype");
   }

   if (nrows == 0 || output_columns == 0) {
@@ -258,7 +264,7 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
   const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
   const dim3 gridDim(gridDim_x, gridDim_y);

-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+  FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
       output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
         _fusednbitrowwise_to_float_cuda_kernel<scalar_t>
             <<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -304,19 +310,19 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(

 /// @ingroup quantize-ops-cuda
 /// Converts a tensor of fused N-bit rowwise values into a tensor of `float` or
-/// `at::Half` values.
+/// `at::Half` or `at::Bf16` values.
 ///
 /// @param input A tensor of fused N-bit rowwise values
 /// @param bit_rate
 /// @param output_dtype The target floating point type, specified as integer
 ///        representation of `SparseType` enum
 ///
 /// @return A new tensor with values from the input tensor converted to `float`
-///         or `at::Half`, depending on `output_dtype`.
+///         or `at::Half` or `at::Bf16`, depending on `output_dtype`.
 ///
 /// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32` or
-/// `SparseType::FP16`).
-DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_or_half_gpu(
+/// `SparseType::FP16` or `SparseType::BF16`).
+DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
     const at::Tensor& input,
     const int64_t bit_rate,
     const int64_t output_dtype) {
@@ -330,6 +336,9 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_or_half_gpu(
     case SparseType::FP16:
       output = _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
       break;
+    case SparseType::BF16:
+      output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(input, bit_rate);
+      break;
     default:
       TORCH_CHECK(false);
   }
@@ -350,7 +359,7 @@ FBGEMM_OP_DISPATCH(
 FBGEMM_OP_DISPATCH(
     CUDA,
     "FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
-    fbgemm_gpu::_float_or_half_to_fusednbitrowwise_gpu);
+    fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_gpu);
 FBGEMM_OP_DISPATCH(
     CUDA,
     "FusedNBitRowwiseQuantizedSBHalfToFloat",
@@ -362,4 +371,4 @@ FBGEMM_OP_DISPATCH(
 FBGEMM_OP_DISPATCH(
     CUDA,
     "FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
-    fbgemm_gpu::_fusednbitrowwise_to_float_or_half_gpu);
+    fbgemm_gpu::_fusednbitrowwise_to_single_or_half_precision_gpu);
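The net effect of this file's changes: the quantize and dequantize dispatch macros now cover float, half, and bfloat16, and the generic dequantize op gains a SparseType::BF16 case. A minimal sketch of driving the round trip from Python on a CUDA build (the SparseType import path and the as_int() helper are taken from the test file below; treat them as assumptions elsewhere):

import torch
from fbgemm_gpu.split_embedding_configs import SparseType  # assumed import path

# fp32 -> fused 4-bit rowwise -> bf16, via the generic ops registered above.
input_data = torch.rand(2, 16, dtype=torch.float32, device="cuda")
quantized = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
    input_data, 4  # bit_rate
)
dequantized = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
    quantized,
    4,  # bit_rate
    SparseType.BF16.as_int(),  # output_dtype, routed to the new BF16 case
)
assert dequantized.dtype == torch.bfloat16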
fbgemm_gpu/test/quantize_ops_test.py (152 changes: 99 additions & 53 deletions)
@@ -532,81 +532,101 @@ def test_quantize_op(
     @given(
         nrows=st.integers(min_value=0, max_value=100),
         ncols=st.integers(min_value=0, max_value=100),
-        bit_rate=st.sampled_from([2, 4]),
-        is_output_half=st.booleans(),
-        test_float_or_half_op=st.booleans(),
+        bit_rate=st.sampled_from([2, 4, 8]),
+        output_dtype=st.sampled_from(
+            [SparseType.FP16, SparseType.FP32, SparseType.BF16]
+        ),
+        test_generic_op=st.booleans(),
+        test_cuda=st.booleans(),
     )
     @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
     def test_quantize_and_dequantize_op(
         self,
         nrows: int,
         ncols: int,
         bit_rate: int,
-        is_output_half: bool,
-        test_float_or_half_op: bool,
+        output_dtype: SparseType,
+        test_generic_op: bool,
+        test_cuda: bool,
     ) -> None:
         assert 8 % bit_rate == 0
         num_elem_per_byte = 8 // bit_rate
         input_data = torch.rand(nrows, ncols).float()
-        if is_output_half:
+        if output_dtype == SparseType.FP16:
             input_data = input_data.half()
+        elif output_dtype == SparseType.BF16:
+            input_data = input_data.bfloat16()

         assume(ncols % (2 * num_elem_per_byte) == 0)

-        if test_float_or_half_op:
-            quantized_data = (
-                torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
-                    input_data, bit_rate
-                )
-            )
-            dequantized_data = (
-                torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
-                    quantized_data,
-                    bit_rate,
-                    output_dtype=1 if is_output_half else 0,
-                )
-            )
-        else:
-            if not is_output_half:
+        if not test_cuda:
+            # cpu path does not support bf16
+            if output_dtype == SparseType.BF16:
+                return
+            if test_generic_op:
                 quantized_data = (
-                    torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
+                    torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
                         input_data, bit_rate
                     )
                 )
                 dequantized_data = (
-                    torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(
-                        quantized_data, bit_rate
+                    torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
+                        quantized_data,
+                        bit_rate,
+                        output_dtype.as_int(),
                     )
                 )
             else:
-                quantized_data = torch.ops.fbgemm.HalfToFusedNBitRowwiseQuantizedSBHalf(
-                    input_data, bit_rate
-                )
-                dequantized_data = (
-                    torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToHalf(
-                        quantized_data, bit_rate
-                    )
-                )
+                if output_dtype == SparseType.FP32:
+                    quantized_data = (
+                        torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
+                            input_data, bit_rate
+                        )
+                    )
+                    dequantized_data = (
+                        torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(
+                            quantized_data, bit_rate
+                        )
+                    )
+                elif output_dtype == SparseType.FP16:
+                    quantized_data = (
+                        torch.ops.fbgemm.HalfToFusedNBitRowwiseQuantizedSBHalf(
+                            input_data, bit_rate
+                        )
+                    )
+                    dequantized_data = (
+                        torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToHalf(
+                            quantized_data, bit_rate
+                        )
+                    )
+                else:
+                    raise NotImplementedError(
+                        f"Unsupported output dtype {output_dtype} for cpu ops"
+                    )
-        if nrows == 0 or ncols == 0:
-            assert dequantized_data.numel() == 0
-            return
-        if not is_output_half:
-            reference = torch.from_numpy(
-                fused_rowwise_nbit_quantize_dequantize_reference(
-                    input_data.float().numpy(), bit_rate
-                )
-            )
-        else:
-            reference = torch.from_numpy(
-                fused_rowwise_nbit_quantize_dequantize_reference(
-                    input_data.float().numpy(), bit_rate
-                )
-            ).half()
-        torch.testing.assert_close(dequantized_data, reference)
+            if nrows == 0 or ncols == 0:
+                assert dequantized_data.numel() == 0
+                return
+            if output_dtype == SparseType.FP32:
+                reference = torch.from_numpy(
+                    fused_rowwise_nbit_quantize_dequantize_reference(
+                        input_data.float().numpy(), bit_rate
+                    )
+                )
+            elif output_dtype == SparseType.FP16:
+                reference = torch.from_numpy(
+                    fused_rowwise_nbit_quantize_dequantize_reference(
+                        input_data.float().numpy(), bit_rate
+                    )
+                ).half()
+            else:
+                raise NotImplementedError(
+                    f"Unsupported output dtype {output_dtype} for cpu ops"
+                )
+            torch.testing.assert_close(dequantized_data, reference)

-        if gpu_available:
+        if test_cuda and gpu_available:
             input_data_gpu = input_data.cuda()
-            if test_float_or_half_op:
+            if test_generic_op:
                 quantized_data_gpu = (
                     torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
                         input_data_gpu, bit_rate
@@ -616,11 +636,14 @@ def test_quantize_and_dequantize_op(
                     torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
                         quantized_data_gpu,
                         bit_rate,
-                        output_dtype=1 if is_output_half else 0,
+                        output_dtype.as_int(),
                     )
                 )
             else:
-                if not is_output_half:
+                # legacy path does not support bf16
+                if SparseType.BF16 == output_dtype:
+                    return
+                if output_dtype == SparseType.FP32:
                     quantized_data_gpu = (
                         torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
                             input_data_gpu, bit_rate
@@ -631,7 +654,7 @@ def test_quantize_and_dequantize_op(
                             quantized_data_gpu, bit_rate
                         )
                     )
-                else:
+                elif output_dtype == SparseType.FP16:
                     quantized_data_gpu = (
                         torch.ops.fbgemm.HalfToFusedNBitRowwiseQuantizedSBHalf(
                             input_data_gpu, bit_rate
@@ -642,10 +665,33 @@ def test_quantize_and_dequantize_op(
                             quantized_data_gpu, bit_rate
                         )
                     )
             if nrows == 0 or ncols == 0:
                 assert dequantized_data_gpu.numel() == 0
                 return
             # compare quantized data
             torch.testing.assert_close(
                 dequantized_data_gpu.cpu().float(), dequantized_data.float()
             )
+            if output_dtype == SparseType.FP32:
+                reference = torch.from_numpy(
+                    fused_rowwise_nbit_quantize_dequantize_reference(
+                        input_data.float().numpy(), bit_rate
+                    )
+                )
+            elif output_dtype == SparseType.FP16:
+                reference = torch.from_numpy(
+                    fused_rowwise_nbit_quantize_dequantize_reference(
+                        input_data.float().numpy(), bit_rate
+                    )
+                ).half()
+            elif output_dtype == SparseType.BF16:
+                reference = torch.from_numpy(
+                    fused_rowwise_nbit_quantize_dequantize_reference(
+                        input_data.float().numpy(), bit_rate
+                    )
+                ).bfloat16()
+            else:
+                raise NotImplementedError(
+                    f"Unsupported output dtype for gpu ops {output_dtype}"
+                )
+            torch.testing.assert_close(dequantized_data_gpu.cpu(), reference)

     @unittest.skipIf(no_long_tests, "Slow test, requires buck build to run.") # noqa
     def test_quantize_and_dequantize_op_cuda_large_nrows(self) -> None:
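For context on what the reference comparison checks: the "SBHalf" format stores, per row, ncols n-bit values packed low-bits-first into bytes, followed by an fp16 scale and an fp16 bias. A NumPy sketch of dequantizing that layout to bfloat16, matching the new BF16 branch above (the exact byte layout and little-endian scale/bias encoding are assumptions based on the op names, not taken from this diff):

import numpy as np
import torch

def dequantize_nbit_rowwise_to_bf16(q: np.ndarray, bit_rate: int) -> torch.Tensor:
    # q: uint8 array of shape (nrows, packed_cols + 4); the last 4 bytes of
    # each row hold the fp16 scale and fp16 bias (assumed layout).
    num_elem_per_byte = 8 // bit_rate
    mask = (1 << bit_rate) - 1
    packed = q[:, :-4]
    scale_bias = q[:, -4:].copy().view(np.float16)  # [:, 0] scale, [:, 1] bias
    out = np.empty((q.shape[0], packed.shape[1] * num_elem_per_byte), np.float32)
    for i in range(num_elem_per_byte):
        # element i of every byte sits at bit offset i * bit_rate
        out[:, i::num_elem_per_byte] = (packed >> (i * bit_rate)) & mask
    out = out * scale_bias[:, :1].astype(np.float32) + scale_bias[:, 1:].astype(np.float32)
    return torch.from_numpy(out).bfloat16()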