Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen][LLVM] Add ability to turn on fast math flags #9223

Merged
merged 16 commits into from
Oct 19, 2021
25 changes: 23 additions & 2 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm,
this->InitTarget(tm);
}

void CodeGenLLVM::SetFastMathFlag(llvm::FastMathFlags fmf) { builder_->setFastMathFlags(fmf); }

void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
module_->setTargetTriple(tm->getTargetTriple().str());
module_->setDataLayout(tm->createDataLayout());
Expand Down Expand Up @@ -343,7 +345,26 @@ void CodeGenLLVM::Optimize() {

// place optimization pass
llvm::PassManagerBuilder builder;
builder.OptLevel = 3;

// Use the same opt-level as specified in TargetMachine for running passes
llvm::CodeGenOpt::Level opt_level = target_machine_->getOptLevel();

switch (opt_level) {
case llvm::CodeGenOpt::Level::None:
builder.OptLevel = 0;
break;
case llvm::CodeGenOpt::Level::Less:
builder.OptLevel = 1;
break;

case llvm::CodeGenOpt::Level::Default:
builder.OptLevel = 2;
masahi marked this conversation as resolved.
Show resolved Hide resolved
break;

default:
// CodeGenOpt::Level::Aggressive
builder.OptLevel = 3;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A related comment: This code path should hit by default, otherwise this change would introduce regression.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I see fair enough, I think OptLevel 3 should not be the default but let me get some data to support this first. Changed to make OptLevel 3 the default.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

#if TVM_LLVM_VERSION >= 50
builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0, false);
Expand Down Expand Up @@ -410,7 +431,7 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
} else {
return etype;
}
}
} // namespace codegen

llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
if (auto* ptr = type.as<PrimTypeNode>()) {
Expand Down
7 changes: 7 additions & 0 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*/
virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx,
bool system_lib, bool dynamic_lookup, bool target_c_runtime);

/*!
* \brief Turn on fast math flags for floating point operations.
* \param fmf FastMathFlags to use for code generation.
*/
void SetFastMathFlag(llvm::FastMathFlags fmf);

/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
Expand Down
20 changes: 18 additions & 2 deletions src/target/llvm/llvm_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::stri
#if TVM_LLVM_VERSION < 50
opt.LessPreciseFPMADOption = true;
#endif
// In clang, these are fed from LangOpts which describe language specific features
// TODO(AndrewZhaoLuo): figure out how these relate to fast math flags
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
Expand Down Expand Up @@ -139,8 +141,22 @@ std::unique_ptr<llvm::TargetMachine> GetLLVMTargetMachine(const Target& target,
ICHECK(allow_null) << err << " target_triple=" << target_triple;
return nullptr;
}
llvm::TargetMachine* tm =
llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_);

Integer llvm_opt_level = target->GetAttr<Integer>("opt-level").value_or(Integer(3));
llvm::CodeGenOpt::Level llvm_opt;
if (llvm_opt_level <= 0) {
llvm_opt = llvm::CodeGenOpt::None;
} else if (llvm_opt_level == 1) {
llvm_opt = llvm::CodeGenOpt::Less;
} else if (llvm_opt_level == 2) {
llvm_opt = llvm::CodeGenOpt::Default;
} else {
// llvm_opt_level >= 3
llvm_opt = llvm::CodeGenOpt::Aggressive;
}

llvm::TargetMachine* tm = llvm_target->createTargetMachine(
target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_, llvm::CodeModel::Small, llvm_opt);
return std::unique_ptr<llvm::TargetMachine>(tm);
}

Expand Down
47 changes: 46 additions & 1 deletion src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,53 @@ class LLVMModuleNode final : public runtime::ModuleNode {
// makes sense when we start to use multiple modules.
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib, target_c_runtime);

cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
// See https://llvm.org/docs/LangRef.html#fast-math-flags for details
Bool fast_math_all = target->GetAttr<Bool>("fast-math").value_or(Bool(false));
Bool fast_math_nnan = target->GetAttr<Bool>("fast-math-nnan").value_or(Bool(false));
Bool fast_math_ninf = target->GetAttr<Bool>("fast-math-ninf").value_or(Bool(false));
Bool fast_math_nsz = target->GetAttr<Bool>("fast-math-nsz").value_or(Bool(false));
Bool fast_math_arcp = target->GetAttr<Bool>("fast-math-arcp").value_or(Bool(false));

llvm::FastMathFlags fmf;
if (fast_math_all) {
#if TVM_LLVM_VERSION >= 60
fmf.setFast();
#else
fmf.setUnsafeAlgebra();
#endif
}

if (fast_math_nnan) {
fmf.setNoNaNs();
}
if (fast_math_ninf) {
fmf.setNoInfs();
}
if (fast_math_nsz) {
fmf.setNoSignedZeros();
}
if (fast_math_arcp) {
fmf.setAllowReciprocal();
}

#if TVM_LLVM_VERSION >= 60
Bool fast_math_contract = target->GetAttr<Bool>("fast-math-contract").value_or(Bool(false));
Bool fast_math_afn = target->GetAttr<Bool>("fast-math-afn").value_or(Bool(false));
Bool fast_math_reassoc = target->GetAttr<Bool>("fast-math-reassoc").value_or(Bool(false));
if (fast_math_contract) {
fmf.setAllowContract();
}
if (fast_math_afn) {
fmf.setApproxFunc();
}
if (fast_math_reassoc) {
fmf.setAllowReassoc();
}
#endif

cg->SetFastMathFlag(fmf);

cg->AddFunctionsOrdered(funcs.begin(), funcs.end());
if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
}
Expand Down
9 changes: 9 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
.add_attr_option<Bool>("link-params", Bool(false))
.add_attr_option<Bool>("unpacked-api")
.add_attr_option<String>("interface-api")
// Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags
.add_attr_option<Bool>("fast-math") // implies all the below
.add_attr_option<Bool>("fast-math-nnan")
.add_attr_option<Bool>("fast-math-ninf")
.add_attr_option<Bool>("fast-math-nsz")
.add_attr_option<Bool>("fast-math-arcp")
.add_attr_option<Bool>("fast-math-contract")
.add_attr_option<Bool>("fast-math-reassoc")
.add_attr_option<Integer>("opt-level")
.set_default_keys({"cpu"});

TVM_REGISTER_TARGET_KIND("c", kDLCPU)
Expand Down