diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h index 0b4dd8ae3a0dc70..10aef6f6067b6f0 100644 --- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h +++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h @@ -73,6 +73,12 @@ class PGOContextualProfile { return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++; } + using ConstVisitor = function_ref; + using Visitor = function_ref; + + void update(Visitor, const Function *F = nullptr); + void visit(ConstVisitor, const Function *F = nullptr) const; + const CtxProfFlatProfile flatten() const; bool invalidate(Module &, const PreservedAnalyses &PA, @@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin { class CtxProfAnalysisPrinterPass : public PassInfoMixin { - raw_ostream &OS; - public: - explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {} + enum class PrintMode { Everything, JSON }; + explicit CtxProfAnalysisPrinterPass(raw_ostream &OS, + PrintMode Mode = PrintMode::Everything) + : OS(OS), Mode(Mode) {} PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); static bool isRequired() { return true; } + +private: + raw_ostream &OS; + const PrintMode Mode; }; /// Assign a GUID to functions as metadata. GUID calculation takes linkage into diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h index b45c89cadb0fde6..71a96e0671c2f1b 100644 --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -1535,6 +1535,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase { ConstantInt *getNumCounters() const; // The index of the counter that this instruction acts on. ConstantInt *getIndex() const; + void setIndex(uint32_t Idx); }; /// This represents the llvm.instrprof.cover intrinsic. 
@@ -1585,6 +1586,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase { return isa(V) && classof(cast(V)); } Value *getCallee() const; + void setCallee(Value *Callee); }; /// This represents the llvm.instrprof.timestamp intrinsic. diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h index 190deaeeacd085f..f7f88966f7573f9 100644 --- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h +++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h @@ -57,9 +57,25 @@ class PGOCtxProfContext final { GlobalValue::GUID guid() const { return GUID; } const SmallVectorImpl &counters() const { return Counters; } + SmallVectorImpl &counters() { return Counters; } + + uint64_t getEntrycount() const { + assert(!Counters.empty() && + "Functions are expected to have at their entry BB instrumented, so " + "there should always be at least 1 counter."); + return Counters[0]; + } + const CallsiteMapTy &callsites() const { return Callsites; } CallsiteMapTy &callsites() { return Callsites; } + void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) { + auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy()); + Iter->second.emplace(Other.guid(), std::move(Other)); + } + + void resizeCounters(uint32_t Size) { Counters.resize(Size); } + bool hasCallsite(uint32_t I) const { return Callsites.find(I) != Callsites.end(); } @@ -68,6 +84,12 @@ class PGOCtxProfContext final { assert(hasCallsite(I) && "Callsite not found"); return Callsites.find(I)->second; } + + CallTargetMapTy &callsite(uint32_t I) { + assert(hasCallsite(I) && "Callsite not found"); + return Callsites.find(I)->second; + } + void getContainedGuids(DenseSet &Guids) const; }; diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h index 385831f457038d4..58af26f31417b00 100644 --- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h +++ 
b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h @@ -14,6 +14,7 @@ #ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H #define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H +#include "llvm/Analysis/CtxProfAnalysis.h" namespace llvm { template class ArrayRef; class Constant; @@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee, CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee, MDNode *BranchWeights = nullptr); +CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee, + PGOContextualProfile &CtxProf); + /// This is similar to `promoteCallWithIfThenElse` except that the condition to /// promote a virtual call is that \p VPtr is the same as any of \p /// AddressPoints. diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp index 3fc1bc34afb97e8..2cd3f2114397e5b 100644 --- a/llvm/lib/Analysis/CtxProfAnalysis.cpp +++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp @@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M, return PreservedAnalyses::all(); } - OS << "Function Info:\n"; - for (const auto &[Guid, FuncInfo] : C.FuncInfo) - OS << Guid << " : " << FuncInfo.Name - << ". MaxCounterID: " << FuncInfo.NextCounterIndex - << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n"; + if (Mode == PrintMode::Everything) { + OS << "Function Info:\n"; + for (const auto &[Guid, FuncInfo] : C.FuncInfo) + OS << Guid << " : " << FuncInfo.Name + << ". MaxCounterID: " << FuncInfo.NextCounterIndex + << ". 
MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n"; + } const auto JSONed = ::llvm::json::toJSON(C.profiles()); - OS << "\nCurrent Profile:\n"; + if (Mode == PrintMode::Everything) + OS << "\nCurrent Profile:\n"; OS << formatv("{0:2}", JSONed); + if (Mode == PrintMode::JSON) + return PreservedAnalyses::all(); + OS << "\n"; OS << "\nFlat Profile:\n"; auto Flat = C.flatten(); @@ -209,34 +215,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) { return nullptr; } -static void -preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles, - function_ref Visitor) { - std::function Traverser = - [&](const auto &Ctx) { - Visitor(Ctx); - for (const auto &[_, SubCtxSet] : Ctx.callsites()) - for (const auto &[__, Subctx] : SubCtxSet) - Traverser(Subctx); - }; - for (const auto &[_, P] : Profiles) +template +static void preorderVisit(ProfilesTy &Profiles, + function_ref Visitor, + GlobalValue::GUID Match = 0) { + std::function Traverser = [&](auto &Ctx) { + if (!Match || Ctx.guid() == Match) + Visitor(Ctx); + for (auto &[_, SubCtxSet] : Ctx.callsites()) + for (auto &[__, Subctx] : SubCtxSet) + Traverser(Subctx); + }; + for (auto &[_, P] : Profiles) Traverser(P); } +void PGOContextualProfile::update(Visitor V, const Function *F) { + GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U; + preorderVisit( + *Profiles, V, G); +} + +void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const { + GlobalValue::GUID G = F ? 
getDefinedFunctionGUID(*F) : 0U; + preorderVisit(*Profiles, V, G); +} + const CtxProfFlatProfile PGOContextualProfile::flatten() const { assert(Profiles.has_value()); CtxProfFlatProfile Flat; - preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) { - auto [It, Ins] = Flat.insert({Ctx.guid(), {}}); - if (Ins) { - llvm::append_range(It->second, Ctx.counters()); - return; - } - assert(It->second.size() == Ctx.counters().size() && - "All contexts corresponding to a function should have the exact " - "same number of counters."); - for (size_t I = 0, E = It->second.size(); I < E; ++I) - It->second[I] += Ctx.counters()[I]; - }); + preorderVisit( + *Profiles, [&](const PGOCtxProfContext &Ctx) { + auto [It, Ins] = Flat.insert({Ctx.guid(), {}}); + if (Ins) { + llvm::append_range(It->second, Ctx.counters()); + return; + } + assert(It->second.size() == Ctx.counters().size() && + "All contexts corresponding to a function should have the exact " + "same number of counters."); + for (size_t I = 0, E = It->second.size(); I < E; ++I) + It->second[I] += Ctx.counters()[I]; + }); return Flat; } diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp index db3b0196f66fd69..0eadd0f980c15b3 100644 --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const { return cast(const_cast(getArgOperand(3))); } +void InstrProfCntrInstBase::setIndex(uint32_t Idx) { + assert(isa(this)); + setArgOperand(3, ConstantInt::get(Type::getInt32Ty(getContext()), Idx)); +} + Value *InstrProfIncrementInst::getStep() const { if (InstrProfIncrementInstStep::classof(this)) { return const_cast(getArgOperand(4)); @@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const { return nullptr; } +void InstrProfCallsite::setCallee(Value *V) { + assert(isa(this)); + setArgOperand(4, V); +} + std::optional ConstrainedFPIntrinsic::getRoundingMode() const { unsigned NumOperands = arg_size(); Metadata *MD = 
nullptr; diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index 90dc727cde16d77..71e888c1970f95f 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -13,13 +13,16 @@ #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/CtxProfAnalysis.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/AttributeMask.h" #include "llvm/IR/Constant.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/ProfileData/PGOCtxProfReader.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -572,6 +575,89 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee, return promoteCall(NewInst, Callee); } +CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee, + PGOContextualProfile &CtxProf) { + assert(CB.isIndirectCall()); + if (!CtxProf.isFunctionKnown(Callee)) + return nullptr; + auto &Caller = *CB.getParent()->getParent(); + auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB); + if (!CSInstr) + return nullptr; + const auto CSIndex = CSInstr->getIndex()->getZExtValue(); + + CallBase &DirectCall = promoteCall( + versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee); + CSInstr->moveBefore(&CB); + const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller); + auto *NewCSInstr = cast(CSInstr->clone()); + NewCSInstr->setIndex(NewCSID); + NewCSInstr->setCallee(&Callee); + NewCSInstr->insertBefore(&DirectCall); + auto &DirectBB = *DirectCall.getParent(); + auto &IndirectBB = *CB.getParent(); + + assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) && + "The ICP indirect BB is new, it shouldn't have instrumentation"); + 
assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) && + "The ICP direct BB is new, it shouldn't have instrumentation"); + + // Make the 2 new BBs have counters. + const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller); + const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller); + const uint32_t NewCountersSize = IndirectID + 1; + auto *EntryBBIns = + CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock()); + auto *DirectBBIns = cast(EntryBBIns->clone()); + DirectBBIns->setIndex(DirectID); + DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt()); + + auto *IndirectBBIns = cast(EntryBBIns->clone()); + IndirectBBIns->setIndex(IndirectID); + IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt()); + + const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee); + + auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) { + assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller)); + assert(NewCountersSize - 2 == Ctx.counters().size()); + // Whatever happens next, all the contexts belonging to a function must + // have the same number of counters. + Ctx.resizeCounters(NewCountersSize); + + // Maybe in this context, the indirect callsite wasn't observed at all. + if (!Ctx.hasCallsite(CSIndex)) + return; + auto &CSData = Ctx.callsite(CSIndex); + auto It = CSData.find(CalleeGUID); + + // Maybe we did notice the indirect callsite, but to other targets. 
+ if (It == CSData.end()) + return; + + assert(CalleeGUID == It->second.guid()); + + uint64_t DirectCount = It->second.getEntrycount(); + uint64_t TotalCount = 0; + for (const auto &[_, V] : CSData) + TotalCount += V.getEntrycount(); + assert(TotalCount >= DirectCount); + uint64_t IndirectCount = TotalCount - DirectCount; + // The ICP's effect is as if the direct BB were taken DirectCount times and + // the indirect BB IndirectCount times (counters are 64-bit; don't narrow). + Ctx.counters()[DirectID] = DirectCount; + Ctx.counters()[IndirectID] = IndirectCount; + + // This particular indirect target needs to be moved to this caller under + // the newly-allocated callsite index. + assert(Ctx.callsites().count(NewCSID) == 0); + Ctx.ingestContext(NewCSID, std::move(It->second)); + CSData.erase(CalleeGUID); + }; + CtxProf.update(ProfileUpdater, &Caller); + return &DirectCall; +} + CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, Function *Callee, ArrayRef AddressPoints, diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp index 2d457eb3b678aac..aff603de2a2bd55 100644 --- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp +++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/Analysis/CtxProfAnalysis.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -14,7 +15,12 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/ProfileData/PGOCtxProfReader.h" +#include "llvm/Support/JSON.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/SupportHelpers.h" #include "gtest/gtest.h" using namespace llvm; @@ -456,3 
+462,175 @@ declare void @_ZN5Base35func3Ev(ptr) // 1 call instruction from the entry block. EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4); } + +using namespace llvm::ctx_profile; + +class ContextManager final { + std::vector> Nodes; + std::map Roots; + +public: + ContextNode *createNode(GUID Guid, uint32_t NrCounters, uint32_t NrCallsites, + ContextNode *Next = nullptr) { + auto AllocSize = ContextNode::getAllocSize(NrCounters, NrCallsites); + auto *Mem = Nodes.emplace_back(std::make_unique(AllocSize)).get(); + std::memset(Mem, 0, AllocSize); + auto *Ret = new (Mem) ContextNode(Guid, NrCounters, NrCallsites, Next); + return Ret; + } +}; + +TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) { + LLVMContext C; + std::unique_ptr M = parseIR(C, + R"IR( +define i32 @testfunc1(ptr %d) !guid !0 { + call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0) + call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr %d) + %call = call i32 %d() + ret i32 %call +} + +define i32 @f1() !guid !1 { + call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0) + ret i32 2 +} + +define i32 @f2() !guid !2 { + call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0) + call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @f4) + %r = call i32 @f4() + ret i32 %r +} + +define i32 @testfunc2(ptr %p) !guid !4 { + call void @llvm.instrprof.increment(ptr null, i64 0, i32 1, i32 0) + call void @llvm.instrprof.callsite(ptr null, i64 0, i32 1, i32 0, ptr @testfunc1) + %r = call i32 @testfunc1(ptr %p) + ret i32 %r +} + +declare i32 @f3() + +define i32 @f4() !guid !3 { + ret i32 3 +} + +!0 = !{i64 1000} +!1 = !{i64 1001} +!2 = !{i64 1002} +!3 = !{i64 1004} +!4 = !{i64 1005} +)IR"); + + // Synthesize a profile. The profile is nonsensical, but the goal is to check + // that new BBs are created with IDs and the right counter values. 
+ ContextManager Mgr; + auto BuildTree = [&](const std::vector &CalleeEntrycounts) { + auto *Entry = Mgr.createNode(1000, 1, 1); + // Set the entrycount to 1 so it's not 0. We don't care about it, really, + // for this test but we generally assume it's not 0. + Entry->counters()[0] = 1; + auto *F1 = Mgr.createNode(1001, 1, 0); + auto *F2 = Mgr.createNode(1002, 1, 1, F1); + auto *F3 = Mgr.createNode(1003, 1, 0, F2); + auto *F4 = Mgr.createNode(1004, 1, 0); + + F1->counters()[0] = CalleeEntrycounts[0]; + F2->counters()[0] = CalleeEntrycounts[1]; + F3->counters()[0] = CalleeEntrycounts[2]; + F4->counters()[0] = CalleeEntrycounts[3]; + F2->subContexts()[0] = F4; + Entry->subContexts()[0] = F3; // which chains F2 and F1 + return Entry; + }; + // We'll be interested in f2. The entry counts for it are: 11 in the first + // context; and 102 in the second. + // The total number of times the indirect callsite is exercised is: + // 10+11+12 = 33 in the first case; and 101+102+103 = 306 in the + // second. + // This means that the direct/indirect call counters will be: 11/22 in the + // first case and 102/204 in the second. 
Meaning, the "Counters" for the + // GUID=1000 (caller) context will look like [1, 11, 22] and [1, 102, 204], + // respectively (the first "1" being the entrycount which we set to 1 above) + auto *Entry1 = BuildTree({10, 11, 12, 13}); + auto *SubTree2 = BuildTree({101, 102, 103, 104}); + auto *Entry2 = Mgr.createNode(1005, 1, 1); + Entry2->counters()[0] = 2; + Entry2->subContexts()[0] = SubTree2; + + llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique*/ true); + { + std::error_code EC; + raw_fd_stream Out(ProfileFile.path(), EC); + ASSERT_FALSE(EC); + { + PGOCtxProfileWriter Writer(Out); + Writer.write(*Entry1); + Writer.write(*Entry2); + } + } + + ModuleAnalysisManager MAM; + MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); }); + MAM.registerPass([&]() { return PassInstrumentationAnalysis(); }); + auto &CtxProf = MAM.getResult(*M); + auto *Caller = M->getFunction("testfunc1"); + ASSERT_TRUE(!!Caller); + auto *Callee = M->getFunction("f2"); + ASSERT_TRUE(!!Callee); + auto *IndirectCS = [&]() -> CallBase * { + for (auto &BB : *Caller) + for (auto &I : BB) + if (auto *CB = dyn_cast(&I); CB && CB->isIndirectCall()) + return CB; + return nullptr; + }(); + ASSERT_TRUE(!!IndirectCS); + promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf); + + std::string Str; + raw_string_ostream OS(Str); + CtxProfAnalysisPrinterPass Printer( + OS, CtxProfAnalysisPrinterPass::PrintMode::JSON); + Printer.run(*M, MAM); + const char *Expected = R"( + [ + { + "Guid": 1000, + "Counters": [1, 11, 22], + "Callsites": [ + [{ "Guid": 1001, + "Counters": [10]}, + { "Guid": 1003, + "Counters": [12] + }], + [{ "Guid": 1002, + "Counters": [11], + "Callsites": [ + [{ "Guid": 1004, + "Counters": [13] }]]}]] + }, + { + "Guid": 1005, + "Counters": [2], + "Callsites": [ + [{ "Guid": 1000, + "Counters": [1, 102, 204], + "Callsites": [ + [{ "Guid": 1001, + "Counters": [101]}, + { "Guid": 1003, + "Counters": [103]}], + [{ "Guid": 1002, + "Counters": [102], + "Callsites": [ + 
[{ "Guid": 1004, + "Counters": [104]}]]}]]}]]} +])"; + auto ExpectedJSON = json::parse(Expected); + ASSERT_TRUE(!!ExpectedJSON); + auto ProducedJSON = json::parse(Str); + ASSERT_TRUE(!!ProducedJSON); + EXPECT_EQ(*ProducedJSON, *ExpectedJSON); +}