Skip to content

Commit

Permalink
Speed up and fix type analysis merges
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 18, 2023
1 parent 9a087e6 commit fd06730
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 200 deletions.
119 changes: 30 additions & 89 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,34 +436,8 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA,
analysis[Val] = analysis[CE->getOperand(0)];
return;
}
if (CE->getOpcode() == Instruction::GetElementPtr &&
llvm::all_of(CE->operand_values(),
[](Value *v) { return isa<ConstantInt>(v); })) {
auto g2 = cast<GetElementPtrInst>(CE->getAsInstruction());
APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
g2->accumulateConstantOffset(DL, ai);
// Using destructor rather than eraseFromParent
// as g2 has no parent
delete g2;

int off = (int)ai.getLimitedValue();

// TODO also allow negative offsets
if (off < 0) {
analysis[Val] = TypeTree(BaseType::Pointer).Only(-1, nullptr);
return;
}

getConstantAnalysis(CE->getOperand(0), TA, analysis);
auto gepData0 = analysis[CE->getOperand(0)].Data0();

TypeTree result =
gepData0
.ShiftIndices(DL, /*init offset*/ off, /*max size*/ -1,
/*new offset*/ 0)
.Only(-1, nullptr);
result.insert({-1}, BaseType::Pointer);
analysis[Val] = result;
if (CE->getOpcode() == Instruction::GetElementPtr) {
TA.visitGEPOperator(*cast<GEPOperator>(CE));
return;
}

Expand Down Expand Up @@ -678,8 +652,11 @@ void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) {

// Attempt to update the underlying analysis
bool LegalOr = true;
if (analysis.find(Val) == analysis.end() && isa<Constant>(Val))
getConstantAnalysis(cast<Constant>(Val), *this, analysis);
if (analysis.find(Val) == analysis.end() && isa<Constant>(Val)) {
if (!isa<ConstantExpr>(Val) ||
cast<ConstantExpr>(Val)->getOpcode() != Instruction::GetElementPtr)
getConstantAnalysis(cast<Constant>(Val), *this, analysis);
}

TypeTree prev = analysis[Val];

Expand Down Expand Up @@ -1200,53 +1177,8 @@ void TypeAnalyzer::visitConstantExpr(ConstantExpr &CE) {
updateAnalysis(CE.getOperand(0), getAnalysis(&CE), &CE);
return;
}
if (CE.getOpcode() == Instruction::GetElementPtr &&
llvm::all_of(CE.operand_values(),
[](Value *v) { return isa<ConstantInt>(v); })) {

auto &DL = fntypeinfo.Function->getParent()->getDataLayout();
auto g2 = cast<GetElementPtrInst>(CE.getAsInstruction());
APInt ai(DL.getIndexSizeInBits(g2->getPointerAddressSpace()), 0);
g2->accumulateConstantOffset(DL, ai);
// Using destructor rather than eraseFromParent
// as g2 has no parent

int maxSize = -1;
if (cast<ConstantInt>(CE.getOperand(1))->getLimitedValue() == 0) {
maxSize = DL.getTypeAllocSizeInBits(g2->getResultElementType()) / 8;
}

delete g2;

int off = (int)ai.getLimitedValue();

// TODO also allow negative offsets
if (off < 0) {
if (direction & DOWN)
updateAnalysis(&CE, TypeTree(BaseType::Pointer).Only(-1, nullptr), &CE);
if (direction & UP)
updateAnalysis(CE.getOperand(0),
TypeTree(BaseType::Pointer).Only(-1, nullptr), &CE);
return;
}

if (direction & DOWN) {
auto gepData0 = getAnalysis(CE.getOperand(0)).Data0();
TypeTree result =
gepData0.ShiftIndices(DL, /*init offset*/ off,
/*max size*/ maxSize, /*newoffset*/ 0);
result.insert({}, BaseType::Pointer);
updateAnalysis(&CE, result.Only(-1, nullptr), &CE);
}
if (direction & UP) {
auto pointerData0 = getAnalysis(&CE).Data0();

TypeTree result =
pointerData0.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1,
/*new offset*/ off);
result.insert({}, BaseType::Pointer);
updateAnalysis(CE.getOperand(0), result.Only(-1, nullptr), &CE);
}
if (CE.getOpcode() == Instruction::GetElementPtr) {
visitGEPOperator(*cast<GEPOperator>(&CE));
return;
}
auto I = CE.getAsInstruction();
Expand Down Expand Up @@ -1375,14 +1307,20 @@ std::set<SmallVector<T, 4>> getSet(ArrayRef<std::set<T>> todo, size_t idx) {
}

void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
visitGEPOperator(*cast<GEPOperator>(&gep));
}

void TypeAnalyzer::visitGEPOperator(GEPOperator &gep) {
auto inst = dyn_cast<Instruction>(&gep);
if (isa<UndefValue>(gep.getPointerOperand())) {
updateAnalysis(&gep, TypeTree(BaseType::Anything).Only(-1, &gep), &gep);
updateAnalysis(&gep, TypeTree(BaseType::Anything).Only(-1, inst), &gep);
return;
}
if (isa<ConstantPointerNull>(gep.getPointerOperand())) {
bool nonZero = false;
bool legal = true;
for (auto &ind : gep.indices()) {
for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
auto ind = I->get();
if (auto CI = dyn_cast<ConstantInt>(ind)) {
if (!CI->isZero()) {
nonZero = true;
Expand All @@ -1397,12 +1335,12 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
break;
}
if (legal && nonZero) {
updateAnalysis(&gep, TypeTree(BaseType::Integer).Only(-1, &gep), &gep);
updateAnalysis(&gep, TypeTree(BaseType::Integer).Only(-1, inst), &gep);
return;
}
}

if (gep.indices().begin() == gep.indices().end()) {
if (gep.idx_begin() == gep.idx_end()) {
if (direction & DOWN)
updateAnalysis(&gep, getAnalysis(gep.getPointerOperand()), &gep);
if (direction & UP)
Expand All @@ -1428,8 +1366,9 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
if (gep.isInBounds() || (!EnzymeStrictAliasing &&
pointerAnalysis.Inner0() == BaseType::Pointer &&
getAnalysis(&gep).Inner0() == BaseType::Pointer)) {
for (auto &ind : gep.indices()) {
updateAnalysis(ind, TypeTree(BaseType::Integer).Only(-1, &gep), &gep);
for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
auto ind = I->get();
updateAnalysis(ind, TypeTree(BaseType::Integer).Only(-1, inst), &gep);
}
}
}
Expand All @@ -1439,7 +1378,8 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
bool pointerPropagate = gep.isInBounds();
if (!pointerPropagate) {
bool allIntegral = true;
for (auto &ind : gep.indices()) {
for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
auto ind = I->get();
auto CT = getAnalysis(ind).Inner0();
if (CT != BaseType::Integer && CT != BaseType::Anything) {
allIntegral = false;
Expand Down Expand Up @@ -1468,16 +1408,17 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
}
}
updateAnalysis(&gep, keepMinus, &gep);
updateAnalysis(&gep, TypeTree(pointerAnalysis.Inner0()).Only(-1, &gep),
updateAnalysis(&gep, TypeTree(pointerAnalysis.Inner0()).Only(-1, inst),
&gep);
}
if (direction & UP)
updateAnalysis(gep.getPointerOperand(),
TypeTree(getAnalysis(&gep).Inner0()).Only(-1, &gep), &gep);
TypeTree(getAnalysis(&gep).Inner0()).Only(-1, inst), &gep);

SmallVector<std::set<Value *>, 4> idnext;

for (auto &a : gep.indices()) {
for (auto I = gep.idx_begin(), E = gep.idx_end(); I != E; I++) {
auto a = I->get();
auto iset = fntypeinfo.knownIntegralValues(a, DT, intseen, SE);
std::set<Value *> vset;
for (auto i : iset) {
Expand Down Expand Up @@ -1546,9 +1487,9 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
seenIdx = true;
}
if (direction & DOWN)
updateAnalysis(&gep, downTree.Only(-1, &gep), &gep);
updateAnalysis(&gep, downTree.Only(-1, inst), &gep);
if (direction & UP)
updateAnalysis(gep.getPointerOperand(), upTree.Only(-1, &gep), &gep);
updateAnalysis(gep.getPointerOperand(), upTree.Only(-1, inst), &gep);
}

void TypeAnalyzer::visitPHINode(PHINode &phi) {
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ class TypeAnalyzer : public llvm::InstVisitor<TypeAnalyzer> {

void visitGetElementPtrInst(llvm::GetElementPtrInst &gep);

void visitGEPOperator(llvm::GEPOperator &gep);

void visitPHINode(llvm::PHINode &phi);

void visitTruncInst(llvm::TruncInst &I);
Expand Down
Loading

0 comments on commit fd06730

Please sign in to comment.