Skip to content

Commit

Permalink
WIP: activity fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 13, 2024
1 parent 068ad9c commit a09f308
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 52 deletions.
106 changes: 56 additions & 50 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,12 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR,
std::unique_ptr<ActivityAnalyzer>(new ActivityAnalyzer(*this, UP));
UpHypothesis->ConstantInstructions.insert(I);
assert(directions & UP);
if (UpHypothesis->isInstructionInactiveFromOrigin(TR, I, false)) {

if (Value *invalidOrigin = UpHypothesis->isInstructionPossibleActiveFromOrigin(TR, I, false)) {
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateInstIfInactiveValue[invalidOrigin].insert(I);
}
} else {
if (EnzymePrintActivity)
llvm::errs() << " constant instruction from origin "
"instruction "
Expand All @@ -979,15 +984,6 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR,
if (DownHypothesis)
insertConstantsFrom(TR, *DownHypothesis);
return true;
} else if (directions == 3) {
if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<BinaryOperator>(I)) {
for (auto &op : I->operands()) {
if (!UpHypothesis->isConstantValue(TR, op) &&
EnzymeEnableRecursiveHypotheses) {
ReEvaluateInstIfInactiveValue[op].insert(I);
}
}
}
}
}

Expand Down Expand Up @@ -1776,12 +1772,20 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
if (auto inst = dyn_cast<Instruction>(Val)) {
if (!inst->mayReadFromMemory() && !isa<AllocaInst>(Val)) {
if (directions == UP && !isa<PHINode>(inst)) {
if (isInstructionInactiveFromOrigin(TR, inst, true)) {
if (Value *invalidOrigin = isInstructionPossibleActiveFromOrigin(TR, inst, true)) {
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateValIfInactiveValue[invalidOrigin].insert(Val);
}
} else {
InsertConstantValue(TR, Val);
return true;
}
} else {
if (UpHypothesis->isInstructionInactiveFromOrigin(TR, inst, true)) {
if (auto invalidOrigin = UpHypothesis->isInstructionPossibleActiveFromOrigin(TR, inst, true)) {
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateValIfInactiveValue[invalidOrigin].insert(Val);
}
} else {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
Expand Down Expand Up @@ -1818,14 +1822,17 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
new ActivityAnalyzer(*this, directions));
Hypothesis->ActiveValues.insert(Val);
if (auto VI = dyn_cast<Instruction>(Val)) {
if (UpHypothesis->isInstructionInactiveFromOrigin(TR, VI, true)) {
Hypothesis->DeducingPointers.insert(Val);
if (EnzymePrintActivity)
llvm::errs() << " constant instruction hypothesis: " << *VI << "\n";
} else {
if (auto invalidOrigin = UpHypothesis->isInstructionPossibleActiveFromOrigin(TR, VI, true)) {
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateInstIfInactiveValue[invalidOrigin].insert(Val);
}
if (EnzymePrintActivity)
llvm::errs() << " cannot show constant instruction hypothesis: "
<< *VI << "\n";
} else {
Hypothesis->DeducingPointers.insert(Val);
if (EnzymePrintActivity)
llvm::errs() << " constant instruction hypothesis: " << *VI << "\n";
}
}

Expand Down Expand Up @@ -2152,9 +2159,13 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
if (DeducingPointers.size() == 0)
UpHypothesis->insertConstantsFrom(TR, *Hypothesis);
assert(directions & UP);
bool ActiveUp =
!isa<Argument>(Val) &&
!UpHypothesis->isInstructionInactiveFromOrigin(TR, Val, true);

bool ActiveUp = !isa<Argument>(Val);
if (auto invalidOrigin = UpHypothesis->isInstructionPossibleActiveFromOrigin(TR, Val, true)) {
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[invalidOrigin].insert(Val);
}
}

// Case b) can occur if:
// 1) this memory is used as part of an active return
Expand Down Expand Up @@ -2260,34 +2271,24 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
UpHypothesis =
std::unique_ptr<ActivityAnalyzer>(new ActivityAnalyzer(*this, UP));
if (directions == UP && !isa<PHINode>(Val)) {
if (isInstructionInactiveFromOrigin(TR, Val, true)) {
if (auto invalidOrigin = isInstructionPossibleActiveFromOrigin(TR, Val, true)) {
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[invalidOrigin].insert(Val);
}
} else {
InsertConstantValue(TR, Val);
return true;
} else if (auto I = dyn_cast<Instruction>(Val)) {
if (directions == 3) {
for (auto &op : I->operands()) {
if (!UpHypothesis->isConstantValue(TR, op) &&
EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[op].insert(I);
}
}
}
}
} else {
UpHypothesis->ConstantValues.insert(Val);
if (UpHypothesis->isInstructionInactiveFromOrigin(TR, Val, true)) {
if (auto invalidOrigin = UpHypothesis->isInstructionPossibleActiveFromOrigin(TR, Val, true)) {
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[invalidOrigin].insert(Val);
}
} else {
insertConstantsFrom(TR, *UpHypothesis);
InsertConstantValue(TR, Val);
return true;
} else if (auto I = dyn_cast<Instruction>(Val)) {
if (directions == 3) {
for (auto &op : I->operands()) {
if (!UpHypothesis->isConstantValue(TR, op) &&
EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[op].insert(I);
}
}
}
}
}
}
Expand Down Expand Up @@ -2332,7 +2333,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
}

/// Is the instruction guaranteed to be inactive because of its operands
bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,
Value ActivityAnalyzer::isInstructionPossibleActiveFromOrigin(TypeResults const &TR,
llvm::Value *val,
bool considerValue) {
// Must be an analyzer only searching up
Expand All @@ -2345,7 +2346,7 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,
llvm::errs() << "unknown pointer source: " << *val << "\n";
assert(0 && "unknown pointer source");
llvm_unreachable("unknown pointer source");
return false;
return nullptr;
}

Instruction *inst = cast<Instruction>(val);
Expand All @@ -2359,27 +2360,32 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,
if (EnzymePrintActivity)
llvm::errs() << " constant instruction from known cpuid instruction "
<< *inst << "\n";
return true;
return nullptr;
}
}
}

if (auto SI = dyn_cast<StoreInst>(inst)) {
// if either src or dst is inactive, there cannot be a transfer of active
// values and thus the store is inactive
if (isConstantValue(TR, SI->getValueOperand()) ||
isConstantValue(TR, SI->getPointerOperand())) {
if (EnzymePrintActivity)
llvm::errs() << " constant instruction as store operand is inactive "
for (auto V : {SI->getValueOperand(), SI->getPointerOperand()}) {
if (isConstantValue(TR, V)) {
if (EnzymePrintActivity)
llvm::errs() << " constant instruction as store operand (" << *V << ") is inactive "
<< *inst << "\n";
return true;
return nullptr;
}
}
// TODO to be more precise for the recompute analysis we should return both operands as
// either would allow the analysis to be strengthened
return SI->getValueOperand();
}

if (!considerValue) {
if (auto IEI = dyn_cast<InsertElementInst>(inst)) {
if ((!TR.anyFloat(IEI->getOperand(0)) ||
isConstantValue(TR, IEI->getOperand(0))) &&
for (auto V : {IEI->getOperand(0), IEI->getOperand(1)}) {
if ((!TR.anyFloat(V) ||
isConstantValue(TR, V)) &&
(!TR.anyFloat(IEI->getOperand(1)) ||
isConstantValue(TR, IEI->getOperand(1)))) {
if (EnzymePrintActivity)
Expand Down
9 changes: 7 additions & 2 deletions enzyme/Enzyme/ActivityAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,13 @@ class ActivityAnalyzer {
void insertConstantsFrom(TypeResults const &TR,
ActivityAnalyzer &Hypothesis) {
for (auto I : Hypothesis.ConstantInstructions) {
if (EnzymePrintActivity)
llvm::errs() << " inserting constant instruction " << *I << " " << " due to hypothesis\n";
InsertConstantInstruction(TR, I);
}
for (auto V : Hypothesis.ConstantValues) {
if (EnzymePrintActivity)
llvm::errs() << " inserting constant value " << *V << " " << " due to hypothesis\n";
InsertConstantValue(TR, V);
}
}
Expand Down Expand Up @@ -223,10 +227,11 @@ class ActivityAnalyzer {
/// Is the use of value val as an argument of call CI known to be inactive
bool isFunctionArgumentConstant(llvm::CallInst *CI, llvm::Value *val);

/// Is the instruction guaranteed to be inactive because of its operands.
/// Is the instruction guaranteed to be inactive because of its operands, return
/// null. Otherwise return the value which causes this assumption to break.
/// \p considerValue specifies that we ask whether the returned value, rather
/// than the instruction itself is active.
bool isInstructionInactiveFromOrigin(TypeResults const &TR, llvm::Value *val,
Value* isInstructionPossibleActiveFromOrigin(TypeResults const &TR, llvm::Value *val,
bool considerValue);

public:
Expand Down
30 changes: 30 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/mallocuse.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,early-cse,sroa,instsimplify,%simplifycfg,adce)" -enzyme-preopt=false -opaque-pointers -S | FileCheck %s; fi

declare ptr @__enzyme_virtualreverse(...)

declare ptr @malloc(i64)

define void @my_model.fullgrad1() {
%z = call ptr (...) @__enzyme_virtualreverse(ptr nonnull @_take)
ret void
}

define double @_take(ptr %a0, i1 %a1) {
%a3 = tail call ptr @malloc(i64 10)
%a4 = tail call ptr @malloc(i64 10)
%a5 = ptrtoint ptr %a4 to i64
%a6 = or i64 %a5, 1
%a7 = inttoptr i64 %a6 to ptr
%a8 = load double, ptr %a7, align 8
store double %a8, ptr %a0, align 8
br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next

.lr.ph1.peel.next: ; preds = %2
%.pre = load double, ptr %a4, align 8
ret double %.pre

.lr.ph: ; preds = %.lr.ph, %2
%a9 = load double, ptr %a3, align 4
store double %a9, ptr %a4, align 8
br label %.lr.ph
}

0 comments on commit a09f308

Please sign in to comment.