Add BF16 output support for inference TBE (pytorch#1503)
Summary:
Pull Request resolved: pytorch#1503

Follow-up to D41835847 (pytorch@39a423e)

BF16 output requires an sm80+ (A100 or newer) device, i.e. compute capability 8.0 or higher; a minimal host-side guard is sketched below the changed-files summary.

Differential Revision: D41865889

fbshipit-source-id: e8be57d601a43a0f7d6b680b1e0fb7d23d2f1aa9
jianyuh authored and facebook-github-bot committed Dec 13, 2022
1 parent 39a423e commit 6de87ac
Showing 3 changed files with 38 additions and 4 deletions.
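For context, a minimal host-side sketch (not part of this commit; the helper name pick_output_dtype is hypothetical) of how a caller could fall back from BF16 to FP16 output on devices older than sm80, mirroring the fallback added to the Python test in this commit:

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Return the requested output dtype, downgrading BF16 to FP16 when the
// current GPU has compute capability below 8.0 (pre-A100).
inline at::ScalarType pick_output_dtype(at::ScalarType requested) {
  const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (requested == at::ScalarType::BFloat16 && prop->major < 8) {
    return at::ScalarType::Half;
  }
  return requested;
}
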
28 changes: 28 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/dispatch_macros.h
@@ -103,6 +103,32 @@
return __VA_ARGS__(); \
}

#if !( \
defined(USE_ROCM) || \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))

#define DISPATCH_OUTPUT_TYPES(OUTPUT_TYPE, NAME, ...) \
[&] { \
const auto& output_type = OUTPUT_TYPE; \
at::ScalarType _output_t = ::detail::scalar_type(output_type); \
switch (_output_t) { \
PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Half, at::Half, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2( \
at::ScalarType::BFloat16, at::BFloat16, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Float, float, __VA_ARGS__) \
PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
default: \
AT_ERROR( \
#NAME, \
" not implemented for output_t '", \
toString(_output_t), \
"'"); \
} \
}()

#else

#define DISPATCH_OUTPUT_TYPES(OUTPUT_TYPE, NAME, ...) \
[&] { \
const auto& output_type = OUTPUT_TYPE; \
@@ -120,6 +146,8 @@
} \
}()

#endif

#define PRIVATE_CASE_TYPE_CACHE_EMB( \
grad_enum_type, _cache_t, _emb_t, grad_cxx_type, NAME, ...) \
case grad_enum_type: { \
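A minimal usage sketch of the dispatch macro above (an assumption, not taken from this diff; run_kernel_for is a hypothetical launcher, and the type alias output_t bound inside each case is inferred from the error message in the switch):

#include <ATen/ATen.h>
#include "fbgemm_gpu/dispatch_macros.h"

// Hypothetical templated kernel launcher, declared only for illustration.
template <typename output_t>
void run_kernel_for(const at::Tensor& output);

void forward_dispatch(const at::Tensor& output) {
  DISPATCH_OUTPUT_TYPES(output.scalar_type(), "tbe_forward", [&] {
    // output_t is at::Half, at::BFloat16, float, or uint8_t here; the
    // BFloat16 case is only compiled in for CUDA 11+ and sm80+ targets.
    run_kernel_for<output_t>(output);
  });
}

With this change, dispatching on a BF16 output tensor on an A100-class device selects the BFloat16 case instead of hitting the "not implemented for output_t" error.
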
8 changes: 4 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -1477,10 +1477,10 @@ DEVICE_INLINE __nv_bfloat162 to_bfloat16_2(float2 v) {
__nv_bfloat162 raw;
__nv_bfloat16 x;
__nv_bfloat16 y;
} tmp;
tmp.x = __float2bfloat16_rn(v.x);
tmp.y = __float2bfloat16_rn(v.y);
return tmp.raw;
} t;
t.x = __float2bfloat16_rn(v.x);
t.y = __float2bfloat16_rn(v.y);
return t.raw;
#endif
}

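For reference, a minimal sketch (an assumption, not taken from this commit) of the sm80+ fast path that the elided #if branch of to_bfloat16_2 presumably takes: a single CUDA intrinsic rounds both float lanes to BF16 at once, which the union-based fallback shown above emulates on older targets.

#include <cuda_bf16.h>

// Pack a float2 into a __nv_bfloat162 with round-to-nearest-even on both
// lanes. Requires CUDA 11+ and __CUDA_ARCH__ >= 800 at compile time.
__device__ __forceinline__ __nv_bfloat162 to_bfloat16_2_sm80(float2 v) {
  return __float22bfloat162_rn(v);
}
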
6 changes: 6 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -831,6 +831,7 @@ def test_forward_fused_pooled_emb_quant(
        output_dtype=st.sampled_from(
            [
                SparseType.FP16,
                SparseType.BF16,
                SparseType.INT8,
                # SparseType.INT4,
            ]
@@ -856,6 +857,11 @@ def test_nbit_forward_fused_pooled_emb_quant(
        D_alignment = max(D_alignment, output_dtype.align_size())
        D = round_up(D, D_alignment)

        if (
            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 0)
        ) and output_dtype == SparseType.BF16:
            output_dtype = SparseType.FP16

        Ds = [
            round_up(
                np.random.randint(low=int(max(0.25 * D, 1)), high=int(1.0 * D)),
