Skip to content

Commit

Permalink
Revert "Remove llvm-muladd pass and move it's functionality to to llv…
Browse files Browse the repository at this point in the history
…m-simdloop (#55802)"

This reverts commit 69ed5fd.
  • Loading branch information
KristofferC committed Oct 24, 2024
1 parent 69ed5fd commit 71838b9
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 96 deletions.
12 changes: 12 additions & 0 deletions doc/src/devdocs/llvm-passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ This pass is used to verify Julia's invariants about LLVM IR. This includes thin

These passes are used to perform transformations on LLVM IR that LLVM will not perform itself, e.g. fast math flag propagation, escape analysis, and optimizations on Julia-specific internal functions. They use knowledge about Julia's semantics to perform these optimizations.

### CombineMulAdd

* Filename: `llvm-muladd.cpp`
* Class Name: `CombineMulAddPass`
* Opt Name: `function(CombineMulAdd)`

This pass serves to optimize the particular combination of a regular `fmul` with a fast `fadd` into a contract `fmul` with a fast `fadd`. This is later optimized by the backend to a [fused multiply-add](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add) instruction, which can provide significantly faster operations at the cost of more [unpredictable semantics](https://simonbyrne.github.io/notes/fastmath/).

!!! note

This optimization only occurs when the `fmul` has a single use, which is the fast `fadd`.

### AllocOpt

* Filename: `llvm-alloc-opt.cpp`
Expand Down
1 change: 1 addition & 0 deletions doc/src/devdocs/llvm.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The code for lowering Julia AST to LLVM IR or interpreting it directly is in dir
| `llvm-julia-licm.cpp` | Custom LLVM pass to hoist/sink Julia-specific intrinsics |
| `llvm-late-gc-lowering.cpp` | Custom LLVM pass to root GC-tracked values |
| `llvm-lower-handlers.cpp` | Custom LLVM pass to lower try-catch blocks |
| `llvm-muladd.cpp` | Custom LLVM pass for fast-math FMA |
| `llvm-multiversioning.cpp` | Custom LLVM pass to generate sysimg code on multiple architectures |
| `llvm-propagate-addrspaces.cpp` | Custom LLVM pass to canonicalize addrspaces |
| `llvm-ptls.cpp` | Custom LLVM pass to lower TLS operations |
Expand Down
2 changes: 1 addition & 1 deletion src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ RT_LLVMLINK :=
CG_LLVMLINK :=

ifeq ($(JULIACODEGEN),LLVM)
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop \
CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd \
llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \
llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \
llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \
Expand Down
117 changes: 117 additions & 0 deletions src/llvm-muladd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// This file is a part of Julia. License is MIT: https://julialang.org/license

#include "llvm-version.h"
#include "passes.h"

#include <llvm-c/Core.h>
#include <llvm-c/Types.h>

#include <llvm/ADT/Statistic.h>
#include <llvm/Analysis/OptimizationRemarkEmitter.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/IntrinsicInst.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Operator.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Pass.h>
#include <llvm/Support/Debug.h>

#include "julia.h"
#include "julia_assert.h"

#define DEBUG_TYPE "combine-muladd"
#undef DEBUG

using namespace llvm;
STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");

#ifndef __clang_gcanalyzer__
#define REMARK(remark) ORE.emit(remark)
#else
#define REMARK(remark) (void) 0;
#endif

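/**
* Rewrite
* ```
* %v0 = fmul ... %a, %b
* %v = fadd contract ... %v0, %c
* ```
* to
* ```
* %v0 = fmul contract ... %a, %b
* %v = fadd contract ... %v0, %c
* ```
* when `%v0` has no other use, so that the backend is free to fuse the
* pair into a single multiply-add. The same applies when `%v` is an `fsub`.
*/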

// Try to make `maybeMul` eligible for fusion with its (contractable) user.
//
// `maybeMul` is one operand of an fadd/fsub that already allows contraction.
// If it is an `fmul` instruction with exactly one use and without the
// `contract` fast-math flag, the flag is added so the backend may fuse the
// pair. Optimization remarks are emitted through `ORE` for both the missed
// (multi-use) and the applied case.
//
// Returns true only when the instruction's fast-math flags were modified.
static bool checkCombine(Value *maybeMul, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT
{
    Instruction *mulOp = dyn_cast<Instruction>(maybeMul);
    // Anything that is not an fmul instruction is not a candidate.
    if (mulOp == nullptr)
        return false;
    if (mulOp->getOpcode() != Instruction::FMul)
        return false;

    // A product consumed elsewhere must stay as written, so report and bail.
    if (!mulOp->hasOneUse()) {
        LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n");
        REMARK([&](){
            return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp)
                << "fmul had multiple uses " << ore::NV("fmul", mulOp);
        });
        return false;
    }

    // Already contractable: nothing to change.
    auto flags = mulOp->getFastMathFlags();
    if (flags.allowContract())
        return false;

    // Setting `contract` on the multiply is enough; the backend performs
    // the actual fusion.
    LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n");
    REMARK([&](){
        return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp)
            << "marked for fma " << ore::NV("fmul", mulOp);
    });
    ++TotalContracted;
    flags.setAllowContract(true);
    mulOp->copyFastMathFlags(flags);
    return true;
}

// Walk every instruction of `F` and, for each fadd/fsub that allows
// contraction, try to mark a multiply operand as contractable too.
//
// Returns true if any instruction in `F` had its flags changed.
static bool combineMulAdd(Function &F) JL_NOTSAFEPOINT
{
    OptimizationRemarkEmitter ORE(&F);
    bool changed = false;
    for (BasicBlock &block : F) {
        // Instructions are only re-flagged, never inserted or erased, so a
        // plain range-for over the block is safe.
        for (Instruction &inst : block) {
            const unsigned opcode = inst.getOpcode();
            if (opcode != Instruction::FAdd && opcode != Instruction::FSub)
                continue;
            if (!inst.hasAllowContract())
                continue;
            // Operand 1 is only examined when operand 0 was left unchanged.
            if (checkCombine(inst.getOperand(0), ORE) || checkCombine(inst.getOperand(1), ORE))
                changed = true;
        }
    }
#ifdef JL_VERIFY_PASSES
    assert(!verifyLLVMIR(F));
#endif
    return changed;
}

// New-pass-manager entry point: runs the muladd combine on `F` and reports
// which analyses remain valid afterwards.
PreservedAnalyses CombineMulAddPass::run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT
{
    const bool changed = combineMulAdd(F);
    if (!changed) {
        // Nothing was touched, so every analysis is still valid.
        return PreservedAnalyses::all();
    }
    // Something changed: report only the CFG analyses as preserved.
    return PreservedAnalyses::allInSet<CFGAnalyses>();
}
66 changes: 0 additions & 66 deletions src/llvm-simdloop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ STATISTIC(ReductionChainLength, "Total sum of instructions folded from reduction
STATISTIC(MaxChainLength, "Max length of reduction chain");
STATISTIC(AddChains, "Addition reduction chains");
STATISTIC(MulChains, "Multiply reduction chains");
STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");

#ifndef __clang_gcanalyzer__
#define REMARK(remark) ORE.emit(remark)
Expand All @@ -50,49 +49,6 @@ STATISTIC(TotalContracted, "Total number of multiplies marked for FMA");
#endif
namespace {

/**
* Combine
* ```
* %v0 = fmul ... %a, %b
* %v = fadd contract ... %v0, %c
* ```
* to
* %v0 = fmul contract ... %a, %b
* %v = fadd contract ... %v0, %c
* when `%v0` has no other use
*/

static bool checkCombine(Value *maybeMul, Loop &L, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT
{
auto mulOp = dyn_cast<Instruction>(maybeMul);
if (!mulOp || mulOp->getOpcode() != Instruction::FMul)
return false;
if (!L.contains(mulOp))
return false;
if (!mulOp->hasOneUse()) {
LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n");
REMARK([&](){
return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp)
<< "fmul had multiple uses " << ore::NV("fmul", mulOp);
});
return false;
}
// On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us.
auto fmf = mulOp->getFastMathFlags();
if (!fmf.allowContract()) {
LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n");
REMARK([&](){
return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp)
<< "marked for fma " << ore::NV("fmul", mulOp);
});
++TotalContracted;
fmf.setAllowContract(true);
mulOp->copyFastMathFlags(fmf);
return true;
}
return false;
}

static unsigned getReduceOpcode(Instruction *J, Instruction *operand) JL_NOTSAFEPOINT
{
switch (J->getOpcode()) {
Expand Down Expand Up @@ -194,28 +150,6 @@ static void enableUnsafeAlgebraIfReduction(PHINode *Phi, Loop &L, OptimizationRe
});
(*K)->setHasAllowReassoc(true);
(*K)->setHasAllowContract(true);
switch ((*K)->getOpcode()) {
case Instruction::FAdd: {
if (!(*K)->hasAllowContract())
continue;
// (*K)->getOperand(0)->print(dbgs());
// (*K)->getOperand(1)->print(dbgs());
checkCombine((*K)->getOperand(0), L, ORE);
checkCombine((*K)->getOperand(1), L, ORE);
break;
}
case Instruction::FSub: {
if (!(*K)->hasAllowContract())
continue;
// (*K)->getOperand(0)->print(dbgs());
// (*K)->getOperand(1)->print(dbgs());
checkCombine((*K)->getOperand(0), L, ORE);
checkCombine((*K)->getOperand(1), L, ORE);
break;
}
default:
break;
}
if (SE)
SE->forgetValue(*K);
++length;
Expand Down
11 changes: 4 additions & 7 deletions src/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@ struct DemoteFloat16Pass : PassInfoMixin<DemoteFloat16Pass> {
static bool isRequired() { return true; }
};

struct LateLowerGCPass : PassInfoMixin<LateLowerGCPass> {
struct CombineMulAddPass : PassInfoMixin<CombineMulAddPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
static bool isRequired() { return true; }
};

struct CombineMulAddPass : PassInfoMixin<CombineMulAddPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT {
// no-op
return PreservedAnalyses::all();
}
struct LateLowerGCPass : PassInfoMixin<LateLowerGCPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT;
static bool isRequired() { return true; }
};

struct AllocOptPass : PassInfoMixin<AllocOptPass> {
Expand Down
1 change: 1 addition & 0 deletions src/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ static void buildCleanupPipeline(ModulePassManager &MPM, PassBuilder *PB, Optimi
if (options.cleanup) {
if (O.getSpeedupLevel() >= 2) {
FunctionPassManager FPM;
JULIA_PASS(FPM.addPass(CombineMulAddPass()));
FPM.addPass(DivRemPairsPass());
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}
Expand Down
21 changes: 0 additions & 21 deletions test/llvmpasses/julia-simdloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,6 @@ loopdone:
ret double %nextv
}

; CHECK-LABEL: @simd_test_sub4(
define double @simd_test_sub4(double *%a) {
top:
br label %loop
loop:
%i = phi i64 [0, %top], [%nexti, %loop]
%v = phi double [0.000000e+00, %top], [%nextv, %loop]
%aptr = getelementptr double, double *%a, i64 %i
%aval = load double, double *%aptr
%nextv2 = fmul double %aval, %aval
; CHECK: fmul contract double %aval, %aval
%nextv = fsub double %v, %nextv2
; CHECK: fsub reassoc contract double %v, %nextv2
%nexti = add i64 %i, 1
%done = icmp sgt i64 %nexti, 500
br i1 %done, label %loopdone, label %loop, !llvm.loop !0
loopdone:
ret double %nextv
}

; Tests if we correctly pass through other metadata
; CHECK-LABEL: @disabled(
define i32 @disabled(i32* noalias nocapture %a, i32* noalias nocapture readonly %b, i32 %N) {
Expand All @@ -104,7 +84,6 @@ for.end: ; preds = %for.body
ret i32 %1
}


!0 = distinct !{!0, !"julia.simdloop"}
!1 = distinct !{!1, !"julia.simdloop", !"julia.ivdep"}
!2 = distinct !{!2, !"julia.simdloop", !"julia.ivdep", !3}
Expand Down
64 changes: 64 additions & 0 deletions test/llvmpasses/muladd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
; This file is a part of Julia. License is MIT: https://julialang.org/license

; RUN: opt -enable-new-pm=1 --opaque-pointers=0 --load-pass-plugin=libjulia-codegen%shlibext -passes='CombineMulAdd' -S %s | FileCheck %s

; RUN: opt -enable-new-pm=1 --opaque-pointers=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='CombineMulAdd' -S %s | FileCheck %s


; COM: A single-use fmul feeding a fast fadd should end up contractable
; COM: (either flagged `contract` or already turned into an fmuladd call)
; CHECK-LABEL: @fast_muladd1
define double @fast_muladd1(double %a, double %b, double %c) {
top:
; CHECK: {{contract|fmuladd}}
%v1 = fmul double %a, %b
%v2 = fadd fast double %v1, %c
; CHECK: ret double
ret double %v2
}

; COM: Same as above, but the single user of the fmul is a fast fsub and the
; COM: product is its first operand
; CHECK-LABEL: @fast_mulsub1
define double @fast_mulsub1(double %a, double %b, double %c) {
top:
; CHECK: {{contract|fmuladd}}
%v1 = fmul double %a, %b
%v2 = fsub fast double %v1, %c
; CHECK: ret double
ret double %v2
}

; COM: Vector variant: the fmul is the second operand of a fast fsub and
; COM: should still be picked up
; CHECK-LABEL: @fast_mulsub_vec1
define <2 x double> @fast_mulsub_vec1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
top:
; CHECK: {{contract|fmuladd}}
%v1 = fmul <2 x double> %a, %b
%v2 = fsub fast <2 x double> %c, %v1
; CHECK: ret <2 x double>
ret <2 x double> %v2
}

; COM: Should not mark fmul as contract when multiple uses of fmul exist
; COM: (%v1 below feeds both %v2 and %v3, so every line must come out unchanged)
; CHECK-LABEL: @slow_muladd1
define double @slow_muladd1(double %a, double %b, double %c) {
top:
; CHECK: %v1 = fmul double %a, %b
%v1 = fmul double %a, %b
; CHECK: %v2 = fadd fast double %v1, %c
%v2 = fadd fast double %v1, %c
; CHECK: %v3 = fadd fast double %v1, %b
%v3 = fadd fast double %v1, %b
; CHECK: %v4 = fadd fast double %v3, %v2
%v4 = fadd fast double %v3, %v2
; CHECK: ret double %v4
ret double %v4
}

; COM: Should not mark fadd->fadd fast as contract
; COM: (the operand %v1 is an fadd rather than an fmul, so it must keep its flags)
; CHECK-LABEL: @slow_addadd1
define double @slow_addadd1(double %a, double %b, double %c) {
top:
; CHECK: %v1 = fadd double %a, %b
%v1 = fadd double %a, %b
; CHECK: %v2 = fadd fast double %v1, %c
%v2 = fadd fast double %v1, %c
; CHECK: ret double %v2
ret double %v2
}
2 changes: 1 addition & 1 deletion test/llvmpasses/parsing.ll
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
; COM: NewPM-only test, tests for ability to parse Julia passes

; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='module(CPUFeatures,RemoveNI,JuliaMultiVersioning,RemoveJuliaAddrspaces,LowerPTLSPass,function(DemoteFloat16,LateLowerGCFrame,FinalLowerGC,AllocOpt,PropagateJuliaAddrspaces,LowerExcHandlers,GCInvariantVerifier,loop(LowerSIMDLoop,JuliaLICM),GCInvariantVerifier<strong>,GCInvariantVerifier<no-strong>),LowerPTLSPass<imaging>,LowerPTLSPass<no-imaging>,JuliaMultiVersioning<external>,JuliaMultiVersioning<no-external>)' -S %s -o /dev/null
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='module(CPUFeatures,RemoveNI,JuliaMultiVersioning,RemoveJuliaAddrspaces,LowerPTLSPass,function(DemoteFloat16,CombineMulAdd,LateLowerGCFrame,FinalLowerGC,AllocOpt,PropagateJuliaAddrspaces,LowerExcHandlers,GCInvariantVerifier,loop(LowerSIMDLoop,JuliaLICM),GCInvariantVerifier<strong>,GCInvariantVerifier<no-strong>),LowerPTLSPass<imaging>,LowerPTLSPass<no-imaging>,JuliaMultiVersioning<external>,JuliaMultiVersioning<no-external>)' -S %s -o /dev/null
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia<level=3;llvm_only>" -S %s -o /dev/null
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia<level=3;no_llvm_only>" -S %s -o /dev/null
; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia<level=3;no_enable_vector_pipeline>" -S %s -o /dev/null
Expand Down

0 comments on commit 71838b9

Please sign in to comment.