Replace cvt instructions with bitwise operations in s8->bf16 conversions
Hopper has very low throughput for conversion instructions, which causes these
operations to quickly become an ALU bottleneck. Restating them in terms of
bitwise ops and SIMD bf16 instructions increases the throughput significantly
and translates to meaningful speedups (e.g. 10% end-to-end on one matmul I was
looking at).
apaszke committed Aug 22, 2024
1 parent 8c5e33c commit b53d56b
Showing 1 changed file with 13 additions and 10 deletions.
@@ -223,17 +223,20 @@ static const Fp8ConversionDesc Fp32_to_Fp8E5M2 = {
"cvt.rn.satfinite.e5m2x2.f32 $0, $2, $1; \n", 32, 16, 2};

/* ----- Packed integer to BF16 ------ */
+// Hopper has very low throughput of most conversions, so we rely on bit
+// tricks instead of cvt instructions.
static const std::string S8_to_Bf16 =
-"{ \n"
-".reg .s8 s<4>; \n"
-".reg .f32 f<4>; \n"
-"mov.b32 {s0, s1, s2, s3}, $2; \n" // unpack
-"cvt.rn.f32.s8 f0, s0; \n" // no s8->bf16 pre-Hopper
-"cvt.rn.f32.s8 f1, s1; \n" // fi[0:15] is always 0
-"cvt.rn.f32.s8 f2, s2; \n" //
-"cvt.rn.f32.s8 f3, s3; \n" //
-"prmt.b32 $0, f0, f1, 0x7632; \n" // f32->bf16 + pack
-"prmt.b32 $1, f2, f3, 0x7632; \n" //
+"{ \n"
+".reg .b32 l<3>; \n"
+".reg .b32 h<3>; \n"
+"prmt.b32 l0, $2, 0x43, 0x4140; \n" // Unpack to shifted bf16.
+"prmt.b32 h0, $2, 0x43, 0x4342; \n"
+"and.b32 l1, l0, 0xff7fff7f; \n" // Zero the least exp bit.
+"and.b32 h1, h0, 0xff7fff7f; \n"
+"and.b32 l2, l0, 0xff80ff80; \n" // Zero the mantissa.
+"and.b32 h2, h0, 0xff80ff80; \n"
+"sub.bf16x2 $0, l1, l2; \n" // Subtract the offset.
+"sub.bf16x2 $1, h1, h2; \n"
"}";

typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
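
To see why the new sequence produces the right values, it can help to model it on the host. The sketch below is not part of the commit: it is a plain C++ emulation of the three PTX instructions as I understand them, and the helper names (`prmt`, `bf16_to_float`, `sub_bf16x2`, `s8x4_to_bf16`, `check_all_s8`) are hypothetical. The idea it illustrates is that each s8 byte is placed under a `0x43` high byte, giving a bf16 whose value is `128 + (x & 0x7f)`; subtracting the same word with its mantissa cleared (128 for non-negative inputs, 256 for negative ones) leaves exactly `x`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Host-side model of PTX prmt.b32 in its default mode: each selector nibble
// picks one byte from the 8-byte pool where `a` holds bytes 0-3 and `b`
// holds bytes 4-7. Sign replication (bit 3 of a nibble) is not modeled.
static uint32_t prmt(uint32_t a, uint32_t b, uint32_t sel) {
  assert((sel & 0x8888) == 0 && "sign-replication mode is not modeled");
  uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t idx = (sel >> (4 * i)) & 0x7;
    out |= static_cast<uint32_t>((pool >> (8 * idx)) & 0xff) << (8 * i);
  }
  return out;
}

// A bf16 value is the upper 16 bits of an IEEE-754 binary32 value.
static float bf16_to_float(uint16_t bits) {
  uint32_t word = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &word, sizeof f);
  return f;
}

struct Pair { float lo, hi; };

// Model of sub.bf16x2 returning the two lanes as floats. Rounding back to
// bf16 is skipped: for the operands produced below, every difference is an
// integer in [-128, 127], which bf16 represents exactly.
static Pair sub_bf16x2(uint32_t x, uint32_t y) {
  return {bf16_to_float(static_cast<uint16_t>(x & 0xffff)) -
              bf16_to_float(static_cast<uint16_t>(y & 0xffff)),
          bf16_to_float(static_cast<uint16_t>(x >> 16)) -
              bf16_to_float(static_cast<uint16_t>(y >> 16))};
}

// Mirrors the instruction sequence from the diff for four packed s8 values.
static void s8x4_to_bf16(uint32_t packed, float out[4]) {
  uint32_t l0 = prmt(packed, 0x43, 0x4140);  // {0x43, s1, 0x43, s0}
  uint32_t h0 = prmt(packed, 0x43, 0x4342);  // {0x43, s3, 0x43, s2}
  uint32_t l1 = l0 & 0xff7fff7f;  // clear the s8 sign bit (lowest exp bit)
  uint32_t h1 = h0 & 0xff7fff7f;
  uint32_t l2 = l0 & 0xff80ff80;  // clear the mantissa: 128.0 or 256.0
  uint32_t h2 = h0 & 0xff80ff80;
  Pair l = sub_bf16x2(l1, l2);
  Pair h = sub_bf16x2(h1, h2);
  out[0] = l.lo;
  out[1] = l.hi;
  out[2] = h.lo;
  out[3] = h.hi;
}

// Checks every s8 value in every lane, using different values per lane so
// that lane-ordering mistakes would also be caught.
static bool check_all_s8() {
  for (uint32_t base = 0; base < 256; ++base) {
    uint32_t packed = 0;
    int expected[4];
    for (int i = 0; i < 4; ++i) {
      uint32_t byte = (base + 64u * i) & 0xff;
      packed |= byte << (8 * i);
      expected[i] = byte < 128 ? static_cast<int>(byte)
                               : static_cast<int>(byte) - 256;
    }
    float out[4];
    s8x4_to_bf16(packed, out);
    for (int i = 0; i < 4; ++i)
      if (out[i] != static_cast<float>(expected[i])) return false;
  }
  return true;
}
```

Calling `check_all_s8()` from a test or a small `main` exercises all 256 byte values in each of the four lanes. This only models the arithmetic; it says nothing about instruction throughput, which is the actual motivation for the change.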