From 0d7c720e67a0213565f0e7c141c4ffa1b91fc5b9 Mon Sep 17 00:00:00 2001 From: Mircea Trofin Date: Tue, 20 Aug 2024 21:09:16 -0700 Subject: [PATCH] [ctx_prof] API to get the instrumentation of a BB --- llvm/include/llvm/Analysis/CtxProfAnalysis.h | 5 +++++ llvm/lib/Analysis/CtxProfAnalysis.cpp | 7 ++++++ .../Analysis/CtxProfAnalysisTest.cpp | 22 +++++++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h index 23abcbe2c6e9d2e..0b4dd8ae3a0dc70 100644 --- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h +++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h @@ -95,7 +95,12 @@ class CtxProfAnalysis : public AnalysisInfoMixin { PGOContextualProfile run(Module &M, ModuleAnalysisManager &MAM); + /// Get the instruction instrumenting a callsite, or nullptr if that cannot be + /// found. static InstrProfCallsite *getCallsiteInstrumentation(CallBase &CB); + + /// Get the instruction instrumenting a BB, or nullptr if not present. + static InstrProfIncrementInst *getBBInstrumentation(BasicBlock &BB); }; class CtxProfAnalysisPrinterPass diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp index ceebb2cf06d235b..3fc1bc34afb97e8 100644 --- a/llvm/lib/Analysis/CtxProfAnalysis.cpp +++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp @@ -202,6 +202,13 @@ InstrProfCallsite *CtxProfAnalysis::getCallsiteInstrumentation(CallBase &CB) { return nullptr; } +InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) { + for (auto &I : BB) + if (auto *Incr = dyn_cast(&I)) + return Incr; + return nullptr; +} + static void preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles, function_ref Visitor) { diff --git a/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp index 5f9bf3ec540eb32..fbe3a6e45109cc9 100644 --- a/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp +++ b/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp @@ -132,4 +132,26 @@ TEST_F(CtxProfAnalysisTest, GetCallsiteIDNegativeTest) { EXPECT_EQ(IndIns, nullptr); } +TEST_F(CtxProfAnalysisTest, GetBBIDTest) { + ModulePassManager MPM; + MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF)); + EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved()); + auto *F = M->getFunction("foo"); + ASSERT_NE(F, nullptr); + std::map BBNameAndID; + + for (auto &BB : *F) { + auto *Ins = CtxProfAnalysis::getBBInstrumentation(BB); + if (Ins) + BBNameAndID[BB.getName().str()] = + static_cast(Ins->getIndex()->getZExtValue()); + else + BBNameAndID[BB.getName().str()] = -1; + } + + EXPECT_THAT(BBNameAndID, + testing::UnorderedElementsAre( + testing::Pair("", 0), testing::Pair("yes", 1), + testing::Pair("no", -1), testing::Pair("exit", -1))); +} } // namespace