diff --git a/src/llvm-demote-float16.cpp b/src/llvm-demote-float16.cpp index 740055730fb90..7eb591fd07d94 100644 --- a/src/llvm-demote-float16.cpp +++ b/src/llvm-demote-float16.cpp @@ -1,8 +1,9 @@ // This file is a part of Julia. License is MIT: https://julialang.org/license -// This pass finds floating-point operations on 16-bit (half precision) values, and replaces -// them by equivalent operations on 32-bit (single precision) values surrounded by a fpext -// and fptrunc. This ensures that the exact semantics of IEEE floating-point are preserved. +// This pass finds floating-point operations on 16-bit values (half precision and bfloat), +// and replaces them by equivalent operations on 32-bit (single precision) values surrounded +// by a fpext and fptrunc. This ensures that the exact semantics of IEEE floating-point are +// preserved. // // Without this pass, back-ends that do not natively support half-precision (e.g. x86_64) // similarly pattern-match half-precision operations with single-precision equivalents, but @@ -71,10 +72,17 @@ static bool have_fp16(Function &caller, const Triple &TT) { return false; } +static bool have_bf16(Function &caller, const Triple &TT) { + // TODO + return false; +} + static bool demoteFloat16(Function &F) { auto TT = Triple(F.getParent()->getTargetTriple()); - if (have_fp16(F, TT)) + auto has_fp16 = have_fp16(F, TT); + auto has_bf16 = have_bf16(F, TT); + if (has_fp16 && has_bf16) return false; auto &ctx = F.getContext(); @@ -82,14 +90,17 @@ static bool demoteFloat16(Function &F) SmallVector erase; for (auto &BB : F) { for (auto &I : BB) { - // extend Float16 operands to Float32 + // check whether there's any 16-bit floating point operands to extend bool Float16 = I.getType()->getScalarType()->isHalfTy(); - for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) { + bool BFloat16 = I.getType()->getScalarType()->isBFloatTy(); + for (size_t i = 0; !BFloat16 && !Float16 && i < I.getNumOperands(); i++) { Value *Op = I.getOperand(i); - if (Op->getType()->getScalarType()->isHalfTy()) + if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy()) Float16 = true; + else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy()) + BFloat16 = true; } - if (!Float16) + if (!Float16 && !BFloat16) continue; switch (I.getOpcode()) { @@ -113,11 +124,16 @@ static bool demoteFloat16(Function &F) IRBuilder<> builder(&I); - // extend Float16 operands to Float32 + // extend 16-bit floating point operands SmallVector Operands(I.getNumOperands()); for (size_t i = 0; i < I.getNumOperands(); i++) { Value *Op = I.getOperand(i); - if (Op->getType()->getScalarType()->isHalfTy()) { + if (!has_fp16 && Op->getType()->getScalarType()->isHalfTy()) { + // extend Float16 to Float32 + ++TotalExt; + Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32)); + } else if (!has_bf16 && Op->getType()->getScalarType()->isBFloatTy()) { + // extend BFloat16 to Float32 ++TotalExt; Op = builder.CreateFPExt(Op, Op->getType()->getWithNewType(T_float32)); } @@ -125,7 +141,7 @@ static bool demoteFloat16(Function &F) } // recreate the instruction if any operands changed, - // truncating the result back to Float16 + // truncating the result back to the original type Value *NewI; ++TotalChanged; switch (I.getOpcode()) {