From e4f9554f841b8c636f76db9c2399299adbc3a0e5 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 13 Oct 2020 15:32:35 +0200 Subject: [PATCH] Explicitly demote Float16 operations to Float32. --- src/Makefile | 2 +- src/aotcompile.cpp | 14 ++++- src/jitlayers.cpp | 1 + src/jitlayers.h | 2 + src/llvm-demote-float16.cpp | 122 ++++++++++++++++++++++++++++++++++++ 5 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 src/llvm-demote-float16.cpp diff --git a/src/Makefile b/src/Makefile index 835c2bf60b55a9..00d7d3fc0de8bd 100644 --- a/src/Makefile +++ b/src/Makefile @@ -56,7 +56,7 @@ RUNTIME_SRCS += jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering \ llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \ llvm-multiversioning llvm-alloc-opt cgmemmgr llvm-api llvm-remove-addrspaces \ - llvm-remove-ni llvm-julia-licm + llvm-remove-ni llvm-julia-licm llvm-demote-float16 FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir) LLVM_LIBS := all ifeq ($(USE_POLLY),1) diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp index b0539786315b56..f49535ad9a7681 100644 --- a/src/aotcompile.cpp +++ b/src/aotcompile.cpp @@ -508,8 +508,10 @@ void jl_dump_native(void *native_code, if (unopt_bc_fname) PM.add(createBitcodeWriterPass(unopt_bc_OS)); - if (bc_fname || obj_fname || asm_fname) + if (bc_fname || obj_fname || asm_fname) { addOptimizationPasses(&PM, jl_options.opt_level, true, true); + addMachinePasses(&PM, TM.get()); + } if (bc_fname) PM.add(createBitcodeWriterPass(bc_OS)); if (obj_fname) @@ -604,6 +606,14 @@ void addTargetPasses(legacy::PassManagerBase *PM, TargetMachine *TM) PM->add(createTargetTransformInfoWrapperPass(TM->getTargetIRAnalysis())); } + +void addMachinePasses(legacy::PassManagerBase *PM, TargetMachine *TM) +{ + // TODO: don't do this on CPUs that natively support Float16 + PM->add(createDemoteFloat16Pass()); +} + + // this defines the set of optimization passes defined for Julia at various optimization levels. // it assumes that the TLI and TTI wrapper passes have already been added. void addOptimizationPasses(legacy::PassManagerBase *PM, int opt_level, @@ -809,6 +819,7 @@ class JuliaPipeline : public Pass { TPMAdapter Adapter(TPM); addTargetPasses(&Adapter, jl_TargetMachine); addOptimizationPasses(&Adapter, OptLevel); + addMachinePasses(&Adapter, jl_TargetMachine); } JuliaPipeline() : Pass(PT_PassManager, ID) {} Pass *createPrinterPass(raw_ostream &O, const std::string &Banner) const override { @@ -846,6 +857,7 @@ void *jl_get_llvmf_defn(jl_method_instance_t *mi, size_t world, char getwrapper, PM = new legacy::PassManager(); addTargetPasses(PM, jl_TargetMachine); addOptimizationPasses(PM, jl_options.opt_level); + addMachinePasses(PM, jl_TargetMachine); } // get the source code for this function diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index 6658a26e92d52e..3481db683a95c7 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -523,6 +523,7 @@ static void addPassesForOptLevel(legacy::PassManager &PM, TargetMachine &TM, raw { addTargetPasses(&PM, &TM); addOptimizationPasses(&PM, optlevel); + addMachinePasses(&PM, &TM); if (TM.addPassesToEmitMC(PM, Ctx, ObjStream)) llvm_unreachable("Target does not support MC emission."); } diff --git a/src/jitlayers.h b/src/jitlayers.h index 8dd45c1f939f52..10f371f610cb30 100644 --- a/src/jitlayers.h +++ b/src/jitlayers.h @@ -24,6 +24,7 @@ extern bool imaging_mode; void addTargetPasses(legacy::PassManagerBase *PM, TargetMachine *TM); void addOptimizationPasses(legacy::PassManagerBase *PM, int opt_level, bool lower_intrinsics=true, bool dump_native=false); +void addMachinePasses(legacy::PassManagerBase *PM, TargetMachine *TM); void jl_finalize_module(std::unique_ptr m); void jl_merge_module(Module *dest, std::unique_ptr src); Module *jl_create_llvm_module(StringRef name); @@ -241,6 +242,7 @@ Pass *createRemoveNIPass(); Pass *createJuliaLICMPass(); Pass *createMultiVersioningPass(); Pass *createAllocOptPass(); +Pass *createDemoteFloat16Pass(); // Whether the Function is an llvm or julia intrinsic. static inline bool isIntrinsicFunction(Function *F) { diff --git a/src/llvm-demote-float16.cpp b/src/llvm-demote-float16.cpp new file mode 100644 index 00000000000000..a5f0a37cedf35e --- /dev/null +++ b/src/llvm-demote-float16.cpp @@ -0,0 +1,122 @@ +// This file is a part of Julia. License is MIT: https://julialang.org/license + +#include "llvm-version.h" + +#define DEBUG_TYPE "demote_float16" + +#include "support/dtypes.h" + +#include +#include +#include +#include + +using namespace llvm; + +namespace { + +struct DemoteFloat16Pass : public FunctionPass { + static char ID; + DemoteFloat16Pass() : FunctionPass(ID){}; + +private: + bool runOnFunction(Function &F) override; +}; + +bool DemoteFloat16Pass::runOnFunction(Function &F) +{ + auto &ctx = F.getContext(); + auto T_float16 = Type::getHalfTy(ctx); + auto T_float32 = Type::getFloatTy(ctx); + + SmallVector erase; + for (auto &BB : F) { + for (auto &I : BB) { + switch (I.getOpcode()) { + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + break; + default: + continue; + } + + IRBuilder<> builder(&I); + + // extend Float16 operands to Float32 + bool OperandsChanged = false; + SmallVector Operands(I.getNumOperands()); + for (size_t i = 0; i < I.getNumOperands(); i++) { + Value *Op = I.getOperand(i); + if (Op->getType() == T_float16) { + Op = builder.CreateFPExt(Op, T_float32); + OperandsChanged = true; + } + Operands[i] = (Op); + } + + // recreate the instruction if any operands changed, + // truncating the result back to Float16 + if (OperandsChanged) { + Value *NewI; + switch (I.getOpcode()) { + case Instruction::FAdd: + assert(Operands.size() == 2); + NewI = builder.CreateFAddFMF(Operands[0], Operands[1], &I); + break; + case Instruction::FSub: + assert(Operands.size() == 2); + NewI = builder.CreateFSubFMF(Operands[0], Operands[1], &I); + break; + case Instruction::FMul: + assert(Operands.size() == 2); + NewI = builder.CreateFMulFMF(Operands[0], Operands[1], &I); + break; + case Instruction::FDiv: + assert(Operands.size() == 2); + NewI = builder.CreateFDivFMF(Operands[0], Operands[1], &I); + break; + case Instruction::FRem: + assert(Operands.size() == 2); + NewI = builder.CreateFRemFMF(Operands[0], Operands[1], &I); + break; + default: + abort(); + } + ((Instruction *)NewI)->copyMetadata(I); + if (NewI->getType() != I.getType()) + NewI = builder.CreateFPTrunc(NewI, I.getType()); + I.replaceAllUsesWith(NewI); + erase.push_back(&I); + } + } + } + + if (erase.size() > 0) { + for (auto V : erase) + V->eraseFromParent(); + return true; + } + else + return false; +} + +char DemoteFloat16Pass::ID = 0; +static RegisterPass + Y("DemoteFloat16", + "Demote Float16 operations to Float32 equivalents.", + false, + false); +} + +Pass *createDemoteFloat16Pass() +{ + return new DemoteFloat16Pass(); +} + +extern "C" JL_DLLEXPORT void LLVMExtraAddDemoteFloat16Pass(LLVMPassManagerRef PM) +{ + unwrap(PM)->add(createDemoteFloat16Pass()); +}