Add BF16 output support for inference TBE (pytorch#1498)
Summary:
Pull Request resolved: pytorch#1498

As title

Reviewed By: jiecaoyu

Differential Revision: D41835847

fbshipit-source-id: d9a8af08345d83c6f0912f9bc98022632f03764f
jianyuh authored and facebook-github-bot committed Dec 13, 2022
1 parent fe0f7ab commit 39a423e
Showing 2 changed files with 327 additions and 4 deletions.
@@ -197,7 +197,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
         return;
     }
     static_assert(
-        std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, uint8_t>::value,
+        std::is_same<output_t, float>::value || std::is_same<output_t, at::BFloat16>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, uint8_t>::value,
         "output_t can only be float or half or bytes now"
     );
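For readers skimming the diff: the change above extends a compile-time whitelist of kernel output types. The following is a minimal standalone sketch (not the FBGEMM kernel itself) of that std::is_same gating pattern; half_t, bfloat16_t, and store_row are hypothetical stand-ins for at::Half, at::BFloat16, and the kernel's output-writing code.

    #include <cstdint>
    #include <type_traits>

    struct half_t     { std::uint16_t bits; };   // stand-in for at::Half
    struct bfloat16_t { std::uint16_t bits; };   // stand-in for at::BFloat16

    template <typename output_t>
    void store_row() {
      // Whitelist of supported output types, mirroring the static_assert in the
      // diff (now extended with the BF16 stand-in).
      static_assert(
          std::is_same<output_t, float>::value ||
              std::is_same<output_t, bfloat16_t>::value ||
              std::is_same<output_t, half_t>::value ||
              std::is_same<output_t, std::uint8_t>::value,
          "output_t can only be float, half, bfloat16, or bytes");

      if (std::is_same<output_t, float>::value ||
          std::is_same<output_t, half_t>::value ||
          std::is_same<output_t, bfloat16_t>::value) {
        // Floating-point output path: rows are dequantized and written directly.
      } else {
        // uint8 output path: quantized rows are written along with per-row qparams.
      }
    }

    // Instantiating with an unsupported type fails at compile time.
    template void store_row<float>();
    template void store_row<bfloat16_t>();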

@@ -331,7 +331,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
     }
     {% else %}
     const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
-    if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value) {
+    if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, at::BFloat16>::value) {
       #pragma unroll MaxNum128BRows
       for (uint32_t j = 0; j < MaxNum128BRows; ++j) {
         // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later:
@@ -388,7 +388,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
     const uint32_t b = min(static_cast<uint32_t>(bb * OutputRowsPerThread + i), static_cast<uint32_t>(B - 1));
     const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast<float>(1.0) / Ls[i] : static_cast<float>(1.0);

-    if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value) {
+    if (std::is_same<output_t, float>::value || std::is_same<output_t, at::Half>::value || std::is_same<output_t, at::BFloat16>::value) {
       #pragma unroll MaxNum128BRows
       for (uint32_t j = 0; j < MaxNum128BRows; ++j) {
         const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
@@ -625,7 +625,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
     Tensor output;
     const int kINT8QparamsBytes = 8;
     SparseType o_dtype = static_cast<SparseType>(output_dtype);
-    TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8);
+    TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
     {% if not nobag %}
     int64_t total_adjusted_D = total_D;
     if (o_dtype == SparseType::INT8) {
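On the host side, the widened TORCH_CHECK above admits SparseType::BF16 as a requested output dtype. Below is a hedged sketch, under stated assumptions, of how a checked o_dtype could map to the output tensor's ATen ScalarType; the SparseType enum literal and the make_output helper are illustrative only, not the actual FBGEMM host code, and the INT8 per-row qparams columns are deliberately elided.

    #include <ATen/ATen.h>

    // Illustrative mirror of fbgemm_gpu's SparseType; values/order are assumptions.
    enum class SparseType { FP32, FP16, BF16, INT8 };

    // Hypothetical helper: pick the ATen dtype for the requested output type and
    // allocate a [B, total_D] output on the same device as the weights.
    at::Tensor make_output(
        int64_t B,
        int64_t total_D,
        const at::Tensor& dev_weights,
        SparseType o_dtype) {
      TORCH_CHECK(
          o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 ||
          o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
      const at::ScalarType st =
          o_dtype == SparseType::FP32 ? at::kFloat :
          o_dtype == SparseType::FP16 ? at::kHalf :
          o_dtype == SparseType::BF16 ? at::kBFloat16 :
                                        at::kByte;
      // The real INT8 path additionally widens each row by kINT8QparamsBytes to
      // hold per-row quantization parameters; omitted in this sketch.
      return at::empty({B, total_D}, dev_weights.options().dtype(st));
    }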
