Skip to content

Commit

Permalink
Explicitly demote Float16 operations to Float32.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Oct 14, 2020
1 parent 1e69b0a commit e4f9554
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ RUNTIME_SRCS += jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering \
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
llvm-multiversioning llvm-alloc-opt cgmemmgr llvm-api llvm-remove-addrspaces \
llvm-remove-ni llvm-julia-licm
llvm-remove-ni llvm-julia-licm llvm-demote-float16
FLAGS += -I$(shell $(LLVM_CONFIG_HOST) --includedir)
LLVM_LIBS := all
ifeq ($(USE_POLLY),1)
Expand Down
14 changes: 13 additions & 1 deletion src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,10 @@ void jl_dump_native(void *native_code,

if (unopt_bc_fname)
PM.add(createBitcodeWriterPass(unopt_bc_OS));
if (bc_fname || obj_fname || asm_fname)
if (bc_fname || obj_fname || asm_fname) {
addOptimizationPasses(&PM, jl_options.opt_level, true, true);
addMachinePasses(&PM, TM.get());
}
if (bc_fname)
PM.add(createBitcodeWriterPass(bc_OS));
if (obj_fname)
Expand Down Expand Up @@ -604,6 +606,14 @@ void addTargetPasses(legacy::PassManagerBase *PM, TargetMachine *TM)
PM->add(createTargetTransformInfoWrapperPass(TM->getTargetIRAnalysis()));
}


void addMachinePasses(legacy::PassManagerBase *PM, TargetMachine *TM)
{
// TODO: don't do this on CPUs that natively support Float16
PM->add(createDemoteFloat16Pass());
}


// this defines the set of optimization passes defined for Julia at various optimization levels.
// it assumes that the TLI and TTI wrapper passes have already been added.
void addOptimizationPasses(legacy::PassManagerBase *PM, int opt_level,
Expand Down Expand Up @@ -809,6 +819,7 @@ class JuliaPipeline : public Pass {
TPMAdapter Adapter(TPM);
addTargetPasses(&Adapter, jl_TargetMachine);
addOptimizationPasses(&Adapter, OptLevel);
addMachinePasses(&Adapter, jl_TargetMachine);
}
JuliaPipeline() : Pass(PT_PassManager, ID) {}
Pass *createPrinterPass(raw_ostream &O, const std::string &Banner) const override {
Expand Down Expand Up @@ -846,6 +857,7 @@ void *jl_get_llvmf_defn(jl_method_instance_t *mi, size_t world, char getwrapper,
PM = new legacy::PassManager();
addTargetPasses(PM, jl_TargetMachine);
addOptimizationPasses(PM, jl_options.opt_level);
addMachinePasses(PM, jl_TargetMachine);
}

// get the source code for this function
Expand Down
1 change: 1 addition & 0 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ static void addPassesForOptLevel(legacy::PassManager &PM, TargetMachine &TM, raw
{
addTargetPasses(&PM, &TM);
addOptimizationPasses(&PM, optlevel);
addMachinePasses(&PM, &TM);
if (TM.addPassesToEmitMC(PM, Ctx, ObjStream))
llvm_unreachable("Target does not support MC emission.");
}
Expand Down
2 changes: 2 additions & 0 deletions src/jitlayers.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ extern bool imaging_mode;

void addTargetPasses(legacy::PassManagerBase *PM, TargetMachine *TM);
void addOptimizationPasses(legacy::PassManagerBase *PM, int opt_level, bool lower_intrinsics=true, bool dump_native=false);
void addMachinePasses(legacy::PassManagerBase *PM, TargetMachine *TM);
void jl_finalize_module(std::unique_ptr<Module> m);
void jl_merge_module(Module *dest, std::unique_ptr<Module> src);
Module *jl_create_llvm_module(StringRef name);
Expand Down Expand Up @@ -241,6 +242,7 @@ Pass *createRemoveNIPass();
Pass *createJuliaLICMPass();
Pass *createMultiVersioningPass();
Pass *createAllocOptPass();
Pass *createDemoteFloat16Pass();
// Whether the Function is an llvm or julia intrinsic.
static inline bool isIntrinsicFunction(Function *F)
{
Expand Down
122 changes: 122 additions & 0 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

#include "llvm-version.h"

#define DEBUG_TYPE "demote_float16"

#include "support/dtypes.h"

#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Debug.h>

using namespace llvm;

namespace {

struct DemoteFloat16Pass : public FunctionPass {
static char ID;
DemoteFloat16Pass() : FunctionPass(ID){};

private:
bool runOnFunction(Function &F) override;
};

bool DemoteFloat16Pass::runOnFunction(Function &F)
{
auto &ctx = F.getContext();
auto T_float16 = Type::getHalfTy(ctx);
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
switch (I.getOpcode()) {
case Instruction::FAdd:
case Instruction::FSub:
case Instruction::FMul:
case Instruction::FDiv:
case Instruction::FRem:
break;
default:
continue;
}

IRBuilder<> builder(&I);

// extend Float16 operands to Float32
bool OperandsChanged = false;
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType() == T_float16) {
Op = builder.CreateFPExt(Op, T_float32);
OperandsChanged = true;
}
Operands[i] = (Op);
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
if (OperandsChanged) {
Value *NewI;
switch (I.getOpcode()) {
case Instruction::FAdd:
assert(Operands.size() == 2);
NewI = builder.CreateFAddFMF(Operands[0], Operands[1], &I);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
NewI = builder.CreateFSubFMF(Operands[0], Operands[1], &I);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
NewI = builder.CreateFMulFMF(Operands[0], Operands[1], &I);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
NewI = builder.CreateFDivFMF(Operands[0], Operands[1], &I);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
NewI = builder.CreateFRemFMF(Operands[0], Operands[1], &I);
break;
default:
abort();
}
((Instruction *)NewI)->copyMetadata(I);
if (NewI->getType() != I.getType())
NewI = builder.CreateFPTrunc(NewI, I.getType());
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
}
}

if (erase.size() > 0) {
for (auto V : erase)
V->eraseFromParent();
return true;
}
else
return false;
}

char DemoteFloat16Pass::ID = 0;
static RegisterPass<DemoteFloat16Pass>
Y("DemoteFloat16",
"Demote Float16 operations to Float32 equivalents.",
false,
false);
}

Pass *createDemoteFloat16Pass()
{
return new DemoteFloat16Pass();
}

extern "C" JL_DLLEXPORT void LLVMExtraAddDemoteFloat16Pass(LLVMPassManagerRef PM)
{
unwrap(PM)->add(createDemoteFloat16Pass());
}

0 comments on commit e4f9554

Please sign in to comment.