Skip to content

Commit

Permalink
Extend Float16 demote pass to BFloat16.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 2, 2023
1 parent d893ef1 commit 6022476
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

// This pass finds floating-point operations on 16-bit (half precision) values, and replaces
// them by equivalent operations on 32-bit (single precision) values surrounded by a fpext
// and fptrunc. This ensures that the exact semantics of IEEE floating-point are preserved.
// This pass finds floating-point operations on 16-bit values (half precision and bfloat),
// and replaces them by equivalent operations on 32-bit (single precision) values surrounded
// by a fpext and fptrunc. This ensures that the exact semantics of IEEE floating-point are
// preserved.
//
// Without this pass, back-ends that do not natively support half-precision (e.g. x86_64)
// similarly pattern-match half-precision operations with single-precision equivalents, but
Expand Down Expand Up @@ -71,25 +72,35 @@ static bool have_fp16(Function &caller, const Triple &TT) {
return false;
}

// Intended to report whether BFloat16 operations can be left as-is for `caller`
// on the target described by `TT` (the BFloat16 counterpart of have_fp16).
//
// Detection is not implemented yet: both arguments are currently ignored and
// the answer is unconditionally "no", so demoteFloat16 treats every BFloat16
// operation as a candidate for demotion to Float32.
static bool have_bf16(Function &caller, const Triple &TT) {
    // TODO: implement actual BFloat16 support detection
    const bool native_bf16 = false;
    return native_bf16;
}

static bool demoteFloat16(Function &F)
{
auto TT = Triple(F.getParent()->getTargetTriple());
if (have_fp16(F, TT))
auto has_fp16 = have_fp16(F, TT);
auto has_bf16 = have_bf16(F, TT);
if (has_fp16 && has_bf16)
return false;

auto &ctx = F.getContext();
auto T_float32 = Type::getFloatTy(ctx);
SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
// check whether there's any 16-bit floating point operands to extend
bool Float16 = I.getType()->getScalarType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
bool BFloat16 = I.getType()->getScalarType()->isBFloatTy();
for (size_t i = 0; !BFloat16 && !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy())
if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy())
Float16 = true;
else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy())
BFloat16 = true;
}
if (!Float16)
if (!Float16 && !BFloat16)
continue;

switch (I.getOpcode()) {
Expand All @@ -113,19 +124,24 @@ static bool demoteFloat16(Function &F)

IRBuilder<> builder(&I);

// extend Float16 operands to Float32
// extend 16-bit floating point operands
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->getScalarType()->isHalfTy()) {
if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy()) {
// extend Float16 to Float32
++TotalExt;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
} else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy()) {
// extend BFloat16 to Float32
++TotalExt;
Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32));
}
Operands[i] = Op;
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
// truncating the result back to the original type
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
Expand Down

0 comments on commit 6022476

Please sign in to comment.