From 6de87ac52e616cc21cf1e2effc7ec45f2883dfb8 Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Tue, 13 Dec 2022 14:03:12 -0800
Subject: [PATCH] Add BF16 output support for inference TBE (#1503)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1503

Follow-up from D41835847 (https://github.com/pytorch/FBGEMM/commit/39a423e4ad1a04b77fea81c7d09c3e6f8984fae9).

BF16 output will work on sm80+ (A100+) devices, i.e. compute capability 8.0 or higher.

Differential Revision: D41865889

fbshipit-source-id: e8be57d601a43a0f7d6b680b1e0fb7d23d2f1aa9
---
 .../include/fbgemm_gpu/dispatch_macros.h      | 28 +++++++++++++++++++
 .../include/fbgemm_gpu/fbgemm_cuda_utils.cuh  |  8 +++---
 .../split_table_batched_embeddings_test.py    |  6 ++++
 3 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h b/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
index 94f4ecb37f..f6551d20e9 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
@@ -103,6 +103,32 @@
     return __VA_ARGS__();                                                    \
   }
 
+#if !(                                                  \
+    defined(USE_ROCM) ||                                \
+    ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
+     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
+
+#define DISPATCH_OUTPUT_TYPES(OUTPUT_TYPE, NAME, ...)                        \
+  [&] {                                                                      \
+    const auto& output_type = OUTPUT_TYPE;                                   \
+    at::ScalarType _output_t = ::detail::scalar_type(output_type);           \
+    switch (_output_t) {                                                     \
+      PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Half, at::Half, __VA_ARGS__) \
+      PRIVATE_CASE_TYPE_OUTPUT2(                                             \
+          at::ScalarType::BFloat16, at::BFloat16, __VA_ARGS__)               \
+      PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Float, float, __VA_ARGS__)   \
+      PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Byte, uint8_t, __VA_ARGS__)  \
+      default:                                                               \
+        AT_ERROR(                                                            \
+            #NAME,                                                           \
+            " not implemented for output_t '",                               \
+            toString(_output_t),                                             \
+            "'");                                                            \
+    }                                                                        \
+  }()
+
+#else
+
 #define DISPATCH_OUTPUT_TYPES(OUTPUT_TYPE, NAME, ...)                        \
   [&] {                                                                      \
     const auto& output_type = OUTPUT_TYPE;                                   \
@@ -120,6 +146,8 @@
     }                                                                        \
   }()
 
+#endif
+
 #define PRIVATE_CASE_TYPE_CACHE_EMB(                                         \
     grad_enum_type, _cache_t, _emb_t, grad_cxx_type, NAME, ...)              \
   case grad_enum_type: {                                                     \
diff --git a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
index f826598af7..fcba8f3784 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -1477,10 +1477,10 @@ DEVICE_INLINE __nv_bfloat162 to_bfloat16_2(float2 v) {
     __nv_bfloat162 raw;
     __nv_bfloat16 x;
     __nv_bfloat16 y;
-  } tmp;
-  tmp.x = __float2bfloat16_rn(v.x);
-  tmp.y = __float2bfloat16_rn(v.y);
-  return tmp.raw;
+  } t;
+  t.x = __float2bfloat16_rn(v.x);
+  t.y = __float2bfloat16_rn(v.y);
+  return t.raw;
 #endif
 }
diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
index 096f9f74ca..671b0635da 100644
--- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py
+++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -831,6 +831,7 @@ def test_forward_fused_pooled_emb_quant(
         output_dtype=st.sampled_from(
             [
                 SparseType.FP16,
+                SparseType.BF16,
                 SparseType.INT8,
                 # SparseType.INT4,
             ]
@@ -856,6 +857,11 @@ def test_nbit_forward_fused_pooled_emb_quant(
         D_alignment = max(D_alignment, output_dtype.align_size())
         D = round_up(D, D_alignment)
 
+        if (
+            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 0)
+        ) and output_dtype == SparseType.BF16:
+            output_dtype = SparseType.FP16
+
         Ds = [
             round_up(
                 np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)),
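
For context, a minimal sketch of how a caller could apply the same sm80+ guard the updated test uses before requesting BF16 output from the inference TBE. This is illustrative only: it assumes torch is installed, assumes SparseType is importable from fbgemm_gpu.split_embedding_configs, and the helper name resolve_output_dtype is hypothetical.

    # Illustrative sketch: downgrade a BF16 output request to FP16 on devices
    # below compute capability 8.0, mirroring the guard added to
    # test_nbit_forward_fused_pooled_emb_quant above.
    import torch

    # Assumption: SparseType is provided by fbgemm_gpu.split_embedding_configs.
    from fbgemm_gpu.split_embedding_configs import SparseType

    def resolve_output_dtype(requested: SparseType) -> SparseType:
        # BF16 TBE output needs sm80+ (A100 or newer); otherwise fall back to FP16.
        # The is_available() check runs first so get_device_capability() is only
        # called when a CUDA device is actually present.
        if requested == SparseType.BF16 and (
            not torch.cuda.is_available()
            or torch.cuda.get_device_capability() < (8, 0)
        ):
            return SparseType.FP16
        return requested

    # Example: pick the effective output dtype before constructing the TBE op.
    output_dtype = resolve_output_dtype(SparseType.BF16)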