From 04a8bffdf7b1d6e30616561de1734373375cfef5 Mon Sep 17 00:00:00 2001 From: vporpo Date: Tue, 8 Oct 2024 16:18:57 -0700 Subject: [PATCH] [SandboxVec][DAG] Build actual dependencies (#111094) This patch implements actual dependencies checking using BatchAA. This adds memory dep edges between MemDGNodes. --- llvm/include/llvm/SandboxIR/Utils.h | 8 + .../SandboxVectorizer/DependencyGraph.h | 42 ++- .../Vectorize/SandboxVectorizer/Interval.h | 12 +- .../SandboxVectorizer/DependencyGraph.cpp | 133 +++++++- .../SandboxVectorizer/DependencyGraphTest.cpp | 314 +++++++++++++++++- 5 files changed, 485 insertions(+), 24 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/Utils.h b/llvm/include/llvm/SandboxIR/Utils.h index e4156c6af9a2208..4ff4509b7086c50 100644 --- a/llvm/include/llvm/SandboxIR/Utils.h +++ b/llvm/include/llvm/SandboxIR/Utils.h @@ -12,6 +12,7 @@ #ifndef LLVM_SANDBOXIR_UTILS_H #define LLVM_SANDBOXIR_UTILS_H +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -99,6 +100,13 @@ class Utils { return false; return *Diff > 0; } + + /// Equivalent to BatchAA::getModRefInfo(). + static ModRefInfo + aliasAnalysisGetModRefInfo(BatchAAResults &BatchAA, const Instruction *I, + const std::optional &OptLoc) { + return BatchAA.getModRefInfo(cast(I->Val), OptLoc); + } }; } // namespace llvm::sandboxir diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h index 7bc920537faf418..ab49c3aa27143c0 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h @@ -24,6 +24,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/SandboxIR/Instruction.h" #include "llvm/SandboxIR/IntrinsicInst.h" #include "llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h" @@ -47,6 +48,7 @@ class DGNode { // TODO: Use a PointerIntPair for SubclassID and I. /// For isa/dyn_cast etc. DGNodeID SubclassID; + // TODO: Move MemPreds to MemDGNode. /// Memory predecessors. DenseSet MemPreds; @@ -86,13 +88,20 @@ class DGNode { (!(II = dyn_cast(I)) || isMemIntrinsic(II)); } + /// \Returns true if \p I is fence like. It excludes non-mem intrinsics. + static bool isFenceLike(Instruction *I) { + IntrinsicInst *II; + return I->isFenceLike() && + (!(II = dyn_cast(I)) || isMemIntrinsic(II)); + } + /// \Returns true if \p I is a memory dependency candidate instruction. static bool isMemDepNodeCandidate(Instruction *I) { AllocaInst *Alloca; return isMemDepCandidate(I) || ((Alloca = dyn_cast(I)) && Alloca->isUsedWithInAlloca()) || - isStackSaveOrRestoreIntrinsic(I); + isStackSaveOrRestoreIntrinsic(I) || isFenceLike(I); } Instruction *getInstruction() const { return I; } @@ -159,8 +168,37 @@ class DependencyGraph { /// The DAG spans across all instructions in this interval. Interval DAGInterval; + std::unique_ptr BatchAA; + + enum class DependencyType { + RAW, ///> Read After Write + WAW, ///> Write After Write + RAR, ///> Read After Read + WAR, ///> Write After Read + CTRL, ///> Control-related dependencies, like with PHIs/Terminators + OTHER, ///> Currently used for stack related instrs + NONE, ///> No memory/other dependency + }; + /// \Returns the dependency type depending on whether instructions may + /// read/write memory or whether they are some specific opcode-related + /// restrictions. + /// Note: It does not check whether a memory dependency is actually correct, + /// as it won't call AA. Therefore it returns the worst-case dep type. + static DependencyType getRoughDepType(Instruction *FromI, Instruction *ToI); + + // TODO: Implement AABudget. + /// \Returns true if there is a memory/other dependency \p SrcI->DstI. + bool alias(Instruction *SrcI, Instruction *DstI, DependencyType DepType); + + bool hasDep(sandboxir::Instruction *SrcI, sandboxir::Instruction *DstI); + + /// Go through all mem nodes in \p SrcScanRange and try to add dependencies to + /// \p DstN. + void scanAndAddDeps(DGNode &DstN, const Interval &SrcScanRange); + public: - DependencyGraph() {} + DependencyGraph(AAResults &AA) + : BatchAA(std::make_unique(AA)) {} DGNode *getNode(Instruction *I) const { auto It = InstrToNodeMap.find(I); diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h index b05294d70a3e0c4..e0c581f1d50b406 100644 --- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h +++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Interval.h @@ -58,7 +58,7 @@ template class IntervalIterator { } IntervalIterator &operator--() { // `I` is nullptr for end() when To is the BB terminator. - I = I != nullptr ? I->getPrevNode() : R.To; + I = I != nullptr ? I->getPrevNode() : R.bottom(); return *this; } IntervalIterator operator--(int) { @@ -110,14 +110,16 @@ template class Interval { T *bottom() const { return To; } using iterator = IntervalIterator; - using const_iterator = IntervalIterator; iterator begin() { return iterator(From, *this); } iterator end() { return iterator(To != nullptr ? To->getNextNode() : nullptr, *this); } - const_iterator begin() const { return const_iterator(From, *this); } - const_iterator end() const { - return const_iterator(To != nullptr ? To->getNextNode() : nullptr, *this); + iterator begin() const { + return iterator(From, const_cast(*this)); + } + iterator end() const { + return iterator(To != nullptr ? To->getNextNode() : nullptr, + const_cast(*this)); } /// Equality. bool operator==(const Interval &Other) const { diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp index 10da07c24940dde..845fadefc9bf03c 100644 --- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp +++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp @@ -8,8 +8,9 @@ #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/SandboxIR/Utils.h" -using namespace llvm::sandboxir; +namespace llvm::sandboxir { #ifndef NDEBUG void DGNode::print(raw_ostream &OS, bool PrintDeps) const { @@ -50,19 +51,119 @@ MemDGNodeIntervalBuilder::make(const Interval &Instrs, cast(DAG.getNode(MemBotI))); } +DependencyGraph::DependencyType +DependencyGraph::getRoughDepType(Instruction *FromI, Instruction *ToI) { + // TODO: Perhaps compile-time improvement by skipping if neither is mem? + if (FromI->mayWriteToMemory()) { + if (ToI->mayReadFromMemory()) + return DependencyType::RAW; + if (ToI->mayWriteToMemory()) + return DependencyType::WAW; + } else if (FromI->mayReadFromMemory()) { + if (ToI->mayWriteToMemory()) + return DependencyType::WAR; + if (ToI->mayReadFromMemory()) + return DependencyType::RAR; + } + if (isa(FromI) || isa(ToI)) + return DependencyType::CTRL; + if (ToI->isTerminator()) + return DependencyType::CTRL; + if (DGNode::isStackSaveOrRestoreIntrinsic(FromI) || + DGNode::isStackSaveOrRestoreIntrinsic(ToI)) + return DependencyType::OTHER; + return DependencyType::NONE; +} + +static bool isOrdered(Instruction *I) { + auto IsOrdered = [](Instruction *I) { + if (auto *LI = dyn_cast(I)) + return !LI->isUnordered(); + if (auto *SI = dyn_cast(I)) + return !SI->isUnordered(); + if (DGNode::isFenceLike(I)) + return true; + return false; + }; + bool Is = IsOrdered(I); + assert((!Is || DGNode::isMemDepCandidate(I)) && + "An ordered instruction must be a MemDepCandidate!"); + return Is; +} + +bool DependencyGraph::alias(Instruction *SrcI, Instruction *DstI, + DependencyType DepType) { + std::optional DstLocOpt = + Utils::memoryLocationGetOrNone(DstI); + if (!DstLocOpt) + return true; + // Check aliasing. + assert((SrcI->mayReadFromMemory() || SrcI->mayWriteToMemory()) && + "Expected a mem instr"); + // TODO: Check AABudget + ModRefInfo SrcModRef = + isOrdered(SrcI) + ? ModRefInfo::Mod + : Utils::aliasAnalysisGetModRefInfo(*BatchAA, SrcI, *DstLocOpt); + switch (DepType) { + case DependencyType::RAW: + case DependencyType::WAW: + return isModSet(SrcModRef); + case DependencyType::WAR: + return isRefSet(SrcModRef); + default: + llvm_unreachable("Expected only RAW, WAW and WAR!"); + } +} + +bool DependencyGraph::hasDep(Instruction *SrcI, Instruction *DstI) { + DependencyType RoughDepType = getRoughDepType(SrcI, DstI); + switch (RoughDepType) { + case DependencyType::RAR: + return false; + case DependencyType::RAW: + case DependencyType::WAW: + case DependencyType::WAR: + return alias(SrcI, DstI, RoughDepType); + case DependencyType::CTRL: + // Adding actual dep edges from PHIs/to terminator would just create too + // many edges, which would be bad for compile-time. + // So we ignore them in the DAG formation but handle them in the + // scheduler, while sorting the ready list. + return false; + case DependencyType::OTHER: + return true; + case DependencyType::NONE: + return false; + } +} + +void DependencyGraph::scanAndAddDeps(DGNode &DstN, + const Interval &SrcScanRange) { + assert(isa(DstN) && + "DstN is the mem dep destination, so it must be mem"); + Instruction *DstI = DstN.getInstruction(); + // Walk up the instruction chain from ScanRange bottom to top, looking for + // memory instrs that may alias. + for (MemDGNode &SrcN : reverse(SrcScanRange)) { + Instruction *SrcI = SrcN.getInstruction(); + if (hasDep(SrcI, DstI)) + DstN.addMemPred(&SrcN); + } +} + Interval DependencyGraph::extend(ArrayRef Instrs) { if (Instrs.empty()) return {}; - // TODO: For now create a chain of dependencies. - Interval Interval(Instrs); - auto *TopI = Interval.top(); - auto *BotI = Interval.bottom(); - DGNode *LastN = getOrCreateNode(TopI); + + Interval InstrInterval(Instrs); + + DGNode *LastN = getOrCreateNode(InstrInterval.top()); + // Create DGNodes for all instrs in Interval to avoid future Instruction to + // DGNode lookups. MemDGNode *LastMemN = dyn_cast(LastN); - for (Instruction *I = TopI->getNextNode(), *E = BotI->getNextNode(); I != E; - I = I->getNextNode()) { - auto *N = getOrCreateNode(I); - N->addMemPred(LastMemN); + for (Instruction &I : drop_begin(InstrInterval)) { + auto *N = getOrCreateNode(&I); // Build the Mem node chain. if (auto *MemN = dyn_cast(N)) { MemN->setPrevNode(LastMemN); @@ -70,9 +171,15 @@ Interval DependencyGraph::extend(ArrayRef Instrs) { LastMemN->setNextNode(MemN); LastMemN = MemN; } - LastN = N; } - return Interval; + // Create the dependencies. + auto DstRange = MemDGNodeIntervalBuilder::make(InstrInterval, *this); + for (MemDGNode &DstN : drop_begin(DstRange)) { + auto SrcRange = Interval(DstRange.top(), DstN.getPrevNode()); + scanAndAddDeps(DstN, SrcRange); + } + + return InstrInterval; } #ifndef NDEBUG @@ -95,3 +202,5 @@ void DependencyGraph::dump() const { dbgs() << "\n"; } #endif // NDEBUG + +} // namespace llvm::sandboxir diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp index fb8d3780684f808..e2f16919a5cddd3 100644 --- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp +++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp @@ -7,7 +7,13 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/AsmParser/Parser.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/SandboxIR/Context.h" #include "llvm/SandboxIR/Function.h" #include "llvm/SandboxIR/Instruction.h" @@ -20,6 +26,10 @@ using namespace llvm; struct DependencyGraphTest : public testing::Test { LLVMContext C; std::unique_ptr M; + std::unique_ptr AC; + std::unique_ptr DT; + std::unique_ptr BAA; + std::unique_ptr AA; void parseIR(LLVMContext &C, const char *IR) { SMDiagnostic Err; @@ -27,6 +37,24 @@ struct DependencyGraphTest : public testing::Test { if (!M) Err.print("DependencyGraphTest", errs()); } + + AAResults &getAA(llvm::Function &LLVMF) { + TargetLibraryInfoImpl TLII; + TargetLibraryInfo TLI(TLII); + AA = std::make_unique(TLI); + AC = std::make_unique(LLVMF); + DT = std::make_unique(LLVMF); + BAA = std::make_unique(M->getDataLayout(), LLVMF, TLI, *AC, + DT.get()); + AA->addAAResult(*BAA); + return *AA; + } + /// \Returns true if there is a dependency: SrcN->DstN. + bool dependency(sandboxir::DGNode *SrcN, sandboxir::DGNode *DstN) { + const auto &Preds = DstN->memPreds(); + auto It = find(Preds, SrcN); + return It != Preds.end(); + } }; TEST_F(DependencyGraphTest, isStackSaveOrRestoreIntrinsic) { @@ -151,6 +179,7 @@ define void @foo(i8 %v1, ptr %ptr) { )IR"); llvm::Function *LLVMF = &*M->getFunction("foo"); sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); auto *BB = &*F->begin(); auto It = BB->begin(); @@ -165,7 +194,7 @@ define void @foo(i8 %v1, ptr %ptr) { auto *Call = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG; + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); DAG.extend({&*BB->begin(), BB->getTerminator()}); EXPECT_TRUE(isa(DAG.getNode(Store))); EXPECT_TRUE(isa(DAG.getNode(Load))); @@ -195,7 +224,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S0 = cast(&*It++); auto *S1 = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG; + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()}); // Check extend(). EXPECT_EQ(Span.top(), &*BB->begin()); @@ -214,7 +243,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { // Check memPreds(). EXPECT_TRUE(N0->memPreds().empty()); EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0)); - EXPECT_THAT(N2->memPreds(), testing::ElementsAre(N1)); + EXPECT_TRUE(N2->memPreds().empty()); } TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) { @@ -236,7 +265,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S1 = cast(&*It++); [[maybe_unused]] auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG; + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto *S0N = cast(DAG.getNode(S0)); @@ -270,7 +299,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { auto *S1 = cast(&*It++); auto *Ret = cast(&*It++); - sandboxir::DependencyGraph DAG; + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); DAG.extend({&*BB->begin(), BB->getTerminator()}); auto *S0N = cast(DAG.getNode(S0)); @@ -313,3 +342,278 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) { getPtrVec(sandboxir::MemDGNodeIntervalBuilder::make({Add0, Add0}, DAG)), testing::ElementsAre()); } + +TEST_F(DependencyGraphTest, AliasingStores) { + parseIR(C, R"IR( +define void @foo(ptr %ptr, i8 %v0, i8 %v1) { + store i8 %v0, ptr %ptr + store i8 %v1, ptr %ptr + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()}); + auto It = BB->begin(); + auto *Store0N = DAG.getNode(cast(&*It++)); + auto *Store1N = DAG.getNode(cast(&*It++)); + auto *RetN = DAG.getNode(cast(&*It++)); + EXPECT_TRUE(Store0N->memPreds().empty()); + EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N)); + EXPECT_TRUE(RetN->memPreds().empty()); +} + +TEST_F(DependencyGraphTest, NonAliasingStores) { + parseIR(C, R"IR( +define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) { + store i8 %v0, ptr %ptr0 + store i8 %v1, ptr %ptr1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()}); + auto It = BB->begin(); + auto *Store0N = DAG.getNode(cast(&*It++)); + auto *Store1N = DAG.getNode(cast(&*It++)); + auto *RetN = DAG.getNode(cast(&*It++)); + // We expect no dependencies because the stores don't alias. + EXPECT_TRUE(Store0N->memPreds().empty()); + EXPECT_TRUE(Store1N->memPreds().empty()); + EXPECT_TRUE(RetN->memPreds().empty()); +} + +TEST_F(DependencyGraphTest, VolatileLoads) { + parseIR(C, R"IR( +define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) { + %ld0 = load volatile i8, ptr %ptr0 + %ld1 = load volatile i8, ptr %ptr1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()}); + auto It = BB->begin(); + auto *Ld0N = DAG.getNode(cast(&*It++)); + auto *Ld1N = DAG.getNode(cast(&*It++)); + auto *RetN = DAG.getNode(cast(&*It++)); + EXPECT_TRUE(Ld0N->memPreds().empty()); + EXPECT_THAT(Ld1N->memPreds(), testing::ElementsAre(Ld0N)); + EXPECT_TRUE(RetN->memPreds().empty()); +} + +TEST_F(DependencyGraphTest, VolatileSotres) { + parseIR(C, R"IR( +define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) { + store volatile i8 %v, ptr %ptr0 + store volatile i8 %v, ptr %ptr1 + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()}); + auto It = BB->begin(); + auto *Store0N = DAG.getNode(cast(&*It++)); + auto *Store1N = DAG.getNode(cast(&*It++)); + auto *RetN = DAG.getNode(cast(&*It++)); + EXPECT_TRUE(Store0N->memPreds().empty()); + EXPECT_THAT(Store1N->memPreds(), testing::ElementsAre(Store0N)); + EXPECT_TRUE(RetN->memPreds().empty()); +} + +TEST_F(DependencyGraphTest, Call) { + parseIR(C, R"IR( +declare void @bar1() +declare void @bar2() +define void @foo(float %v1, float %v2) { + call void @bar1() + %add = fadd float %v1, %v2 + call void @bar2() + ret void +} +)IR"); + Function *LLVMF = M->getFunction("foo"); + + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); + + auto It = BB->begin(); + auto *Call1N = DAG.getNode(&*It++); + auto *AddN = DAG.getNode(&*It++); + auto *Call2N = DAG.getNode(&*It++); + + EXPECT_THAT(Call1N->memPreds(), testing::ElementsAre()); + EXPECT_THAT(AddN->memPreds(), testing::ElementsAre()); + EXPECT_THAT(Call2N->memPreds(), testing::ElementsAre(Call1N)); +} + +// Check that there is a dependency: stacksave -> alloca -> stackrestore. +TEST_F(DependencyGraphTest, StackSaveRestoreInAlloca) { + parseIR(C, R"IR( +declare ptr @llvm.stacksave() +declare void @llvm.stackrestore(ptr %ptr) + +define void @foo() { + %stack0 = call ptr @llvm.stacksave() ; Should depend on store + %alloca0 = alloca inalloca i8 ; Should depend on stacksave + call void @llvm.stackrestore(ptr %stack0) ; Should depend transiently on %alloca0 + ret void +} +)IR"); + Function *LLVMF = M->getFunction("foo"); + + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); + + auto It = BB->begin(); + auto *StackSaveN = DAG.getNode(&*It++); + auto *AllocaN = DAG.getNode(&*It++); + auto *StackRestoreN = DAG.getNode(&*It++); + + EXPECT_TRUE(dependency(AllocaN, StackRestoreN)); + EXPECT_TRUE(dependency(StackSaveN, AllocaN)); +} + +// Checks that stacksave and stackrestore depend on other mem instrs. +TEST_F(DependencyGraphTest, StackSaveRestoreDependOnOtherMem) { + parseIR(C, R"IR( +declare ptr @llvm.stacksave() +declare void @llvm.stackrestore(ptr %ptr) + +define void @foo(i8 %v0, i8 %v1, ptr %ptr) { + store volatile i8 %v0, ptr %ptr, align 4 + %stack0 = call ptr @llvm.stacksave() ; Should depend on store + call void @llvm.stackrestore(ptr %stack0) ; Should depend on stacksave + store volatile i8 %v1, ptr %ptr, align 4 ; Should depend on stackrestore + ret void +} +)IR"); + Function *LLVMF = M->getFunction("foo"); + + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); + + auto It = BB->begin(); + auto *Store0N = DAG.getNode(&*It++); + auto *StackSaveN = DAG.getNode(&*It++); + auto *StackRestoreN = DAG.getNode(&*It++); + auto *Store1N = DAG.getNode(&*It++); + + EXPECT_TRUE(dependency(Store0N, StackSaveN)); + EXPECT_TRUE(dependency(StackSaveN, StackRestoreN)); + EXPECT_TRUE(dependency(StackRestoreN, Store1N)); +} + +// Make sure there is a dependency between a stackrestore and an alloca. +TEST_F(DependencyGraphTest, StackRestoreAndInAlloca) { + parseIR(C, R"IR( +declare void @llvm.stackrestore(ptr %ptr) + +define void @foo(ptr %ptr) { + call void @llvm.stackrestore(ptr %ptr) + %alloca0 = alloca inalloca i8 ; Should depend on stackrestore + ret void +} +)IR"); + Function *LLVMF = M->getFunction("foo"); + + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); + + auto It = BB->begin(); + auto *StackRestoreN = DAG.getNode(&*It++); + auto *AllocaN = DAG.getNode(&*It++); + + EXPECT_TRUE(dependency(StackRestoreN, AllocaN)); +} + +// Make sure there is a dependency between the alloca and stacksave +TEST_F(DependencyGraphTest, StackSaveAndInAlloca) { + parseIR(C, R"IR( +declare ptr @llvm.stacksave() + +define void @foo(ptr %ptr) { + %alloca0 = alloca inalloca i8 ; Should depend on stackrestore + %stack0 = call ptr @llvm.stacksave() ; Should depend on alloca0 + ret void +} +)IR"); + Function *LLVMF = M->getFunction("foo"); + + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); + + auto It = BB->begin(); + auto *AllocaN = DAG.getNode(&*It++); + auto *StackSaveN = DAG.getNode(&*It++); + + EXPECT_TRUE(dependency(AllocaN, StackSaveN)); +} + +// A non-InAlloca in a stacksave-stackrestore region does not need extra +// dependencies. +TEST_F(DependencyGraphTest, StackSaveRestoreNoInAlloca) { + parseIR(C, R"IR( +declare ptr @llvm.stacksave() +declare void @llvm.stackrestore(ptr %ptr) +declare void @use(ptr %ptr) + +define void @foo() { + %stack = call ptr @llvm.stacksave() + %alloca1 = alloca i8 ; No dependency + call void @llvm.stackrestore(ptr %stack) + ret void +} +)IR"); + Function *LLVMF = M->getFunction("foo"); + + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + auto *BB = &*F->begin(); + + sandboxir::DependencyGraph DAG(getAA(*LLVMF)); + DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()}); + + auto It = BB->begin(); + auto *StackSaveN = DAG.getNode(&*It++); + auto *AllocaN = DAG.getNode(&*It++); + auto *StackRestoreN = DAG.getNode(&*It++); + + EXPECT_FALSE(dependency(StackSaveN, AllocaN)); + EXPECT_FALSE(dependency(AllocaN, StackRestoreN)); +}