fbgemm dequantize to bf16 (pytorch#2241)
Summary:

Enable a new path for BF16 output in the fused N-bit rowwise dequantize ops.

Differential Revision: D52403556
jiayisuse authored and facebook-github-bot committed Jan 10, 2024
1 parent 005cd43 commit e3b5f6d
Showing 2 changed files with 120 additions and 65 deletions.
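
For reference, a minimal sketch of how the new BF16 output path is exercised from Python, mirroring the updated test below. It assumes fbgemm_gpu is installed (importing it registers the torch.ops.fbgemm operators), that SparseType is importable from fbgemm_gpu.split_embedding_configs, and that a CUDA device is available, since the CPU path does not support BF16:

import torch
import fbgemm_gpu  # noqa: F401  -- assumed to register the torch.ops.fbgemm.* operators
from fbgemm_gpu.split_embedding_configs import SparseType  # assumed import path

bit_rate = 4                     # the updated test covers 2-, 4-, and 8-bit rates
num_elem_per_byte = 8 // bit_rate
ncols = 32                       # must be divisible by 2 * num_elem_per_byte
input_data = torch.rand(8, ncols, device="cuda")

# Row-wise quantize to the fused N-bit format (per-row scale/bias stored as fp16).
quantized = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(input_data, bit_rate)

# New in this commit: dequantize directly to bfloat16 by passing SparseType.BF16 as an int.
dequantized = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
    quantized, bit_rate, SparseType.BF16.as_int()
)
assert dequantized.dtype == torch.bfloat16
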
33 changes: 21 additions & 12 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -145,7 +145,7 @@ Tensor _float_to_fusednbitrowwise_gpu_t(
const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);
// think unsigned as we use 0, 255

FBGEMM_DISPATCH_FLOAT_AND_HALF(
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
input.scalar_type(), "_float_to_fusednbitrowwise_cuda_kernel", [&] {
_float_to_fusednbitrowwise_cuda_kernel<scalar_t>
<<<num_blocks,
@@ -201,11 +201,11 @@ DLL_PUBLIC at::Tensor _half_to_fusednbitrowwise_gpu(
///
/// @return A new tensor with values from the input tensor converted to
/// fused N-bit rowwise.
DLL_PUBLIC Tensor _float_or_half_to_fusednbitrowwise_gpu(
DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
const Tensor& input,
const int64_t bit_rate) {
Tensor output;
FBGEMM_DISPATCH_FLOAT_AND_HALF(
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
input.scalar_type(),
"float_or_half_to_fusednbitrowwise_cuda_kernel",
[&] {
@@ -240,10 +240,16 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
output = at::empty(
{nrows, output_columns}, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
} else { // T = at::Half
} else if constexpr (std::is_same_v<output_t, at::Half>) {
output = at::empty(
{nrows, output_columns}, // 4 = sizeof(float)
{nrows, output_columns}, // 2 = sizeof(half)
input.options().dtype(at::kHalf));
} else if constexpr (std::is_same_v<output_t, at::BFloat16>) {
output = at::empty(
{nrows, output_columns}, // 2 = sizeof(bfloat16)
input.options().dtype(at::kBFloat16));
} else {
TORCH_CHECK(false, "Unsupported output dtype");
}
if (nrows == 0 || output_columns == 0) {
@@ -258,7 +264,7 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);
FBGEMM_DISPATCH_FLOAT_AND_HALF(
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
_fusednbitrowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -304,19 +310,19 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
/// @ingroup quantize-ops-cuda
/// Converts a tensor of fused N-bit rowwise values into a tensor of `float` or
/// `at::Half` values.
/// `at::Half` or `at::Bf16` values.
///
/// @param input A tensor of fused N-bit rowwise values
/// @param bit_rate
/// @param output_dtype The target floating point type, specified as integer
/// representation of `SparseType` enum
///
/// @return A new tensor with values from the input tensor converted to `float`
/// or `at::Half`, depending on `output_dtype`.
/// or `at::Half` or `at::Bf16`, depending on `output_dtype`.
///
/// @throw c10::Error if `output_dtype` is not one of (`SparseType::FP32` or
/// `SparseType::FP16`).
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_or_half_gpu(
/// `SparseType::FP16` or `SparseType::BF16`).
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
const at::Tensor& input,
const int64_t bit_rate,
const int64_t output_dtype) {
@@ -330,6 +336,9 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_or_half_gpu(
case SparseType::FP16:
output = _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
break;
case SparseType::BF16:
output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(input, bit_rate);
break;
default:
TORCH_CHECK(false);
}
@@ -350,7 +359,7 @@ FBGEMM_OP_DISPATCH(
FBGEMM_OP_DISPATCH(
CUDA,
"FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf",
fbgemm_gpu::_float_or_half_to_fusednbitrowwise_gpu);
fbgemm_gpu::_single_or_half_precision_to_fusednbitrowwise_gpu);
FBGEMM_OP_DISPATCH(
CUDA,
"FusedNBitRowwiseQuantizedSBHalfToFloat",
@@ -362,4 +371,4 @@ FBGEMM_OP_DISPATCH(
FBGEMM_OP_DISPATCH(
CUDA,
"FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
fbgemm_gpu::_fusednbitrowwise_to_float_or_half_gpu);
fbgemm_gpu::_fusednbitrowwise_to_single_or_half_precision_gpu);
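
To make the output_dtype contract in the doc comment above concrete, here is a hedged sketch (not part of the commit) of the SparseType-to-dtype mapping that the switch statement implements; the same assumptions about the fbgemm_gpu and SparseType imports apply, and any other SparseType value is expected to raise a c10::Error:

import torch
import fbgemm_gpu  # noqa: F401  -- assumed to register the torch.ops.fbgemm.* operators
from fbgemm_gpu.split_embedding_configs import SparseType  # assumed import path

bit_rate = 2
quantized = torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
    torch.rand(4, 16, device="cuda"), bit_rate
)

# FP32 and FP16 were already supported; BF16 is the case added by this commit.
expected = {
    SparseType.FP32: torch.float32,
    SparseType.FP16: torch.float16,
    SparseType.BF16: torch.bfloat16,
}
for sparse_type, torch_dtype in expected.items():
    out = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
        quantized, bit_rate, sparse_type.as_int()
    )
    assert out.dtype == torch_dtype
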
152 changes: 99 additions & 53 deletions fbgemm_gpu/test/quantize_ops_test.py
@@ -532,81 +532,101 @@ def test_quantize_op(
@given(
nrows=st.integers(min_value=0, max_value=100),
ncols=st.integers(min_value=0, max_value=100),
bit_rate=st.sampled_from([2, 4]),
is_output_half=st.booleans(),
test_float_or_half_op=st.booleans(),
bit_rate=st.sampled_from([2, 4, 8]),
output_dtype=st.sampled_from(
[SparseType.FP16, SparseType.FP32, SparseType.BF16]
),
test_generic_op=st.booleans(),
test_cuda=st.booleans(),
)
@settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
def test_quantize_and_dequantize_op(
self,
nrows: int,
ncols: int,
bit_rate: int,
is_output_half: bool,
test_float_or_half_op: bool,
output_dtype: SparseType,
test_generic_op: bool,
test_cuda: bool,
) -> None:
assert 8 % bit_rate == 0
num_elem_per_byte = 8 // bit_rate
input_data = torch.rand(nrows, ncols).float()
if is_output_half:
if output_dtype == SparseType.FP16:
input_data = input_data.half()
elif output_dtype == SparseType.BF16:
input_data = input_data.bfloat16()

assume(ncols % (2 * num_elem_per_byte) == 0)

if test_float_or_half_op:
quantized_data = (
torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
input_data, bit_rate
)
)
dequantized_data = (
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
quantized_data,
bit_rate,
output_dtype=1 if is_output_half else 0,
)
)
else:
if not is_output_half:
if not test_cuda:
# cpu path does not support bf16
if output_dtype == SparseType.BF16:
return
if test_generic_op:
quantized_data = (
torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
input_data, bit_rate
)
)
dequantized_data = (
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(
quantized_data, bit_rate
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
quantized_data,
bit_rate,
output_dtype.as_int(),
)
)
else:
quantized_data = torch.ops.fbgemm.HalfToFusedNBitRowwiseQuantizedSBHalf(
input_data, bit_rate
)
dequantized_data = (
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToHalf(
quantized_data, bit_rate
if output_dtype == SparseType.FP32:
quantized_data = (
torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
input_data, bit_rate
)
)
dequantized_data = (
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(
quantized_data, bit_rate
)
)
elif output_dtype == SparseType.FP16:
quantized_data = (
torch.ops.fbgemm.HalfToFusedNBitRowwiseQuantizedSBHalf(
input_data, bit_rate
)
)
dequantized_data = (
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToHalf(
quantized_data, bit_rate
)
)
else:
raise NotImplementedError(
f"Unsupported output dtype {output_dtype} for cpu ops"
)
if nrows == 0 or ncols == 0:
assert dequantized_data.numel() == 0
return
if output_dtype == SparseType.FP32:
reference = torch.from_numpy(
fused_rowwise_nbit_quantize_dequantize_reference(
input_data.float().numpy(), bit_rate
)
)
if nrows == 0 or ncols == 0:
assert dequantized_data.numel() == 0
return
if not is_output_half:
reference = torch.from_numpy(
fused_rowwise_nbit_quantize_dequantize_reference(
input_data.float().numpy(), bit_rate
)
)
else:
reference = torch.from_numpy(
fused_rowwise_nbit_quantize_dequantize_reference(
input_data.float().numpy(), bit_rate
elif output_dtype == SparseType.FP16:
reference = torch.from_numpy(
fused_rowwise_nbit_quantize_dequantize_reference(
input_data.float().numpy(), bit_rate
)
).half()
else:
raise NotImplementedError(
f"Unsupported output dtype {output_dtype} for cpu ops"
)
).half()
torch.testing.assert_close(dequantized_data, reference)
torch.testing.assert_close(dequantized_data, reference)

if gpu_available:
if test_cuda and gpu_available:
input_data_gpu = input_data.cuda()
if test_float_or_half_op:
if test_generic_op:
quantized_data_gpu = (
torch.ops.fbgemm.FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
input_data_gpu, bit_rate
@@ -616,11 +636,14 @@ def test_quantize_and_dequantize_op(
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
quantized_data_gpu,
bit_rate,
output_dtype=1 if is_output_half else 0,
output_dtype.as_int(),
)
)
else:
if not is_output_half:
# legacy path does not support bf16
if SparseType.BF16 == output_dtype:
return
if output_dtype == SparseType.FP32:
quantized_data_gpu = (
torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
input_data_gpu, bit_rate
@@ -631,7 +654,7 @@ def test_quantize_and_dequantize_op(
quantized_data_gpu, bit_rate
)
)
else:
elif output_dtype == SparseType.FP16:
quantized_data_gpu = (
torch.ops.fbgemm.HalfToFusedNBitRowwiseQuantizedSBHalf(
input_data_gpu, bit_rate
@@ -642,10 +665,33 @@ def test_quantize_and_dequantize_op(
quantized_data_gpu, bit_rate
)
)
if nrows == 0 or ncols == 0:
assert dequantized_data_gpu.numel() == 0
return
# compare quantized data
torch.testing.assert_close(
dequantized_data_gpu.cpu().float(), dequantized_data.float()
)
if output_dtype == SparseType.FP32:
reference = torch.from_numpy(
fused_rowwise_nbit_quantize_dequantize_reference(
input_data.float().numpy(), bit_rate
)
)
elif output_dtype == SparseType.FP16:
reference = torch.from_numpy(
fused_rowwise_nbit_quantize_dequantize_reference(
input_data.float().numpy(), bit_rate
)
).half()
elif output_dtype == SparseType.BF16:
reference = torch.from_numpy(
fused_rowwise_nbit_quantize_dequantize_reference(
input_data.float().numpy(), bit_rate
)
).bfloat16()
else:
raise NotImplementedError(
f"Unsupported output dtype for gpu ops {output_dtype}"
)
torch.testing.assert_close(dequantized_data_gpu.cpu(), reference)

@unittest.skipIf(no_long_tests, "Slow test, requires buck build to run.") # noqa
def test_quantize_and_dequantize_op_cuda_large_nrows(self) -> None:
