-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathStrengthReduction.cpp
79 lines (71 loc) · 2.64 KB
/
StrengthReduction.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include "llvm/Pass.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
#include <llvm-15/llvm/IR/Constants.h>
#include <llvm-15/llvm/IR/DerivedTypes.h>
#include <llvm-15/llvm/IR/InstrTypes.h>
#include <llvm-15/llvm/IR/Operator.h>
#include <llvm-15/llvm/Support/MathExtras.h>
#include <llvm-15/llvm/Support/raw_ostream.h>
using namespace llvm;
// Page 240 of Hsu, 2021 helped me with this
namespace {
/// Module pass that strength-reduces integer multiplications by a
/// power-of-two constant into left shifts: `mul X, 2^k` -> `shl X, k`.
struct StrengthReductionPass : public PassInfoMixin<StrengthReductionPass> {
  /// Rewrite `mul X, C` (or `mul C, X`) into `shl X, log2(C)` when the
  /// constant operand C is a power of two. The left operand is tried first.
  ///
  /// The power-of-two test and log2 are done on the APInt itself, so integer
  /// types wider than 64 bits (e.g. i128) are handled. APInt::getZExtValue()
  /// would assert on those.
  ///
  /// Signedness is irrelevant: `X * 2^k` and `X << k` yield the same bit
  /// pattern modulo 2^bitwidth under both signed and unsigned readings.
  /// The mul's nsw/nuw flags are not copied to the shl. Dropping them loses
  /// information but is always sound.
  ///
  /// @param mul a `mul` instruction; it is erased when rewritten.
  /// @return true if `mul` was replaced and erased, false if left untouched.
  bool rewriteMul2Shl(BinaryOperator &mul) {
    // Returns V as a ConstantInt if it is a power-of-two constant, else null.
    auto asPow2Const = [](Value *v) -> const ConstantInt * {
      const auto *c = dyn_cast<ConstantInt>(v);
      return (c != nullptr && c->getValue().isPowerOf2()) ? c : nullptr;
    };

    Value *operand = nullptr;
    const ConstantInt *factor = asPow2Const(mul.getOperand(0));
    if (factor != nullptr) {
      operand = mul.getOperand(1);
    } else {
      factor = asPow2Const(mul.getOperand(1));
      if (factor == nullptr)
        return false;
      operand = mul.getOperand(0);
    }

    const unsigned shiftAmt = factor->getValue().logBase2();
    IRBuilder<> builder(&mul);
    Value *shl =
        builder.CreateShl(operand, ConstantInt::get(mul.getType(), shiftAmt));
    mul.replaceAllUsesWith(shl);
    mul.eraseFromParent();
    return true;
  }

  /// Rewrite every eligible `mul` in the module.
  /// @return PreservedAnalyses::all() if nothing changed, otherwise none().
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
    // Collect first and rewrite afterwards. Rewriting erases instructions,
    // which would invalidate the iterators of the traversal below.
    SmallVector<BinaryOperator *> mul2Rewrite;
    llvm::errs() << "let's go\n";
    for (Function &F : M)
      for (BasicBlock &BB : F)
        for (Instruction &I : BB)
          if (I.getOpcode() == Instruction::Mul)
            mul2Rewrite.push_back(cast<BinaryOperator>(&I));

    bool changed = false;
    for (BinaryOperator *mul : mul2Rewrite)
      if (rewriteMul2Shl(*mul))
        changed = true;
    llvm::errs() << "done\n";

    // The IR was mutated, so cached analyses must not be reported as valid.
    return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
  }
};
} // namespace
/// Pass-plugin entry point: the symbol LLVM's plugin loader looks up to
/// register this plugin's passes with a PassBuilder.
extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo
llvmGetPassPluginInfo() {
  return {.APIVersion = LLVM_PLUGIN_API_VERSION,
          .PluginName = "Strength Reduction Pass for CS6120",
          .PluginVersion = "v0.1",
          .RegisterPassBuilderCallbacks = [](PassBuilder &PB) {
            // Insert the pass automatically at the pipeline-start extension
            // point of default pipelines.
            PB.registerPipelineStartEPCallback(
                [](ModulePassManager &MPM, OptimizationLevel Level) {
                  MPM.addPass(StrengthReductionPass());
                });
            // Also make the pass addressable by name in textual pipelines,
            // e.g. `opt -load-pass-plugin=<lib> -passes=strength-reduction`.
            PB.registerPipelineParsingCallback(
                [](StringRef Name, ModulePassManager &MPM,
                   ArrayRef<PassBuilder::PipelineElement>) {
                  if (Name != "strength-reduction")
                    return false;
                  MPM.addPass(StrengthReductionPass());
                  return true;
                });
          }};
}