diff --git a/llvm/include/llvm/Analysis/CtxProfAnalysis.h b/llvm/include/llvm/Analysis/CtxProfAnalysis.h index 0b4dd8ae3a0dc7..10aef6f6067b6f 100644 --- a/llvm/include/llvm/Analysis/CtxProfAnalysis.h +++ b/llvm/include/llvm/Analysis/CtxProfAnalysis.h @@ -73,6 +73,12 @@ class PGOContextualProfile { return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++; } + using ConstVisitor = function_ref; + using Visitor = function_ref; + + void update(Visitor, const Function *F = nullptr); + void visit(ConstVisitor, const Function *F = nullptr) const; + const CtxProfFlatProfile flatten() const; bool invalidate(Module &, const PreservedAnalyses &PA, @@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin { class CtxProfAnalysisPrinterPass : public PassInfoMixin { - raw_ostream &OS; - public: - explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {} + enum class PrintMode { Everything, JSON }; + explicit CtxProfAnalysisPrinterPass(raw_ostream &OS, + PrintMode Mode = PrintMode::Everything) + : OS(OS), Mode(Mode) {} PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); static bool isRequired() { return true; } + +private: + raw_ostream &OS; + const PrintMode Mode; }; /// Assign a GUID to functions as metadata. GUID calculation takes linkage into diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h index 0c8c07654d2259..a9cb78483e7831 100644 --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -1597,6 +1597,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase { ConstantInt *getNumCounters() const; // The index of the counter that this instruction acts on. ConstantInt *getIndex() const; + void setIndex(uint32_t Idx); }; /// This represents the llvm.instrprof.cover intrinsic. 
@@ -1647,6 +1648,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase { return isa(V) && classof(cast(V)); } Value *getCallee() const; + void setCallee(Value *Callee); }; /// This represents the llvm.instrprof.timestamp intrinsic. diff --git a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h index 190deaeeacd085..f7f88966f7573f 100644 --- a/llvm/include/llvm/ProfileData/PGOCtxProfReader.h +++ b/llvm/include/llvm/ProfileData/PGOCtxProfReader.h @@ -57,9 +57,25 @@ class PGOCtxProfContext final { GlobalValue::GUID guid() const { return GUID; } const SmallVectorImpl &counters() const { return Counters; } + SmallVectorImpl &counters() { return Counters; } + + uint64_t getEntrycount() const { + assert(!Counters.empty() && + "Functions are expected to have their entry BB instrumented, so " + "there should always be at least 1 counter."); + return Counters[0]; + } + const CallsiteMapTy &callsites() const { return Callsites; } CallsiteMapTy &callsites() { return Callsites; } + void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) { + auto [Iter, _] = callsites().try_emplace(CSId, CallTargetMapTy()); + Iter->second.emplace(Other.guid(), std::move(Other)); + } + + void resizeCounters(uint32_t Size) { Counters.resize(Size); } + bool hasCallsite(uint32_t I) const { return Callsites.find(I) != Callsites.end(); } @@ -68,6 +84,12 @@ class PGOCtxProfContext final { assert(hasCallsite(I) && "Callsite not found"); return Callsites.find(I)->second; } + + CallTargetMapTy &callsite(uint32_t I) { + assert(hasCallsite(I) && "Callsite not found"); + return Callsites.find(I)->second; + } + void getContainedGuids(DenseSet &Guids) const; }; diff --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h index 385831f457038d..58af26f31417b0 100644 --- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h +++ 
b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h @@ -14,6 +14,7 @@ #ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H #define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H +#include "llvm/Analysis/CtxProfAnalysis.h" namespace llvm { template class ArrayRef; class Constant; @@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee, CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee, MDNode *BranchWeights = nullptr); +CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee, + PGOContextualProfile &CtxProf); + /// This is similar to `promoteCallWithIfThenElse` except that the condition to /// promote a virtual call is that \p VPtr is the same as any of \p /// AddressPoints. diff --git a/llvm/lib/Analysis/CtxProfAnalysis.cpp b/llvm/lib/Analysis/CtxProfAnalysis.cpp index e60eb10fa300c7..6f8455d5077066 100644 --- a/llvm/lib/Analysis/CtxProfAnalysis.cpp +++ b/llvm/lib/Analysis/CtxProfAnalysis.cpp @@ -176,16 +176,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M, return PreservedAnalyses::all(); } - OS << "Function Info:\n"; - for (const auto &[Guid, FuncInfo] : C.FuncInfo) - OS << Guid << " : " << FuncInfo.Name - << ". MaxCounterID: " << FuncInfo.NextCounterIndex - << ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n"; + if (Mode == PrintMode::Everything) { + OS << "Function Info:\n"; + for (const auto &[Guid, FuncInfo] : C.FuncInfo) + OS << Guid << " : " << FuncInfo.Name + << ". MaxCounterID: " << FuncInfo.NextCounterIndex + << ". 
MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n"; + } const auto JSONed = ::llvm::json::toJSON(C.profiles()); - OS << "\nCurrent Profile:\n"; + if (Mode == PrintMode::Everything) + OS << "\nCurrent Profile:\n"; OS << formatv("{0:2}", JSONed); + if (Mode == PrintMode::JSON) + return PreservedAnalyses::all(); + OS << "\n"; OS << "\nFlat Profile:\n"; auto Flat = C.flatten(); @@ -212,34 +218,53 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) { return nullptr; } -static void -preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles, - function_ref Visitor) { - std::function Traverser = - [&](const auto &Ctx) { - Visitor(Ctx); - for (const auto &[_, SubCtxSet] : Ctx.callsites()) - for (const auto &[__, Subctx] : SubCtxSet) - Traverser(Subctx); - }; - for (const auto &[_, P] : Profiles) +template +static void preorderVisit(ProfilesTy &Profiles, + function_ref Visitor, + GlobalValue::GUID Match = 0) { + std::function Traverser = [&](auto &Ctx) { + if (!Match || Ctx.guid() == Match) + Visitor(Ctx); + for (auto &[_, SubCtxSet] : Ctx.callsites()) + for (auto &[__, Subctx] : SubCtxSet) + Traverser(Subctx); + }; + for (auto &[_, P] : Profiles) Traverser(P); } +void PGOContextualProfile::update(Visitor V, const Function *F) { + // *Profiles requires a loaded profile; mirror flatten()'s precondition. + assert(Profiles.has_value()); + GlobalValue::GUID G = F ? getDefinedFunctionGUID(*F) : 0U; + preorderVisit( + *Profiles, V, G); +} + +void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const { + // *Profiles requires a loaded profile; mirror flatten()'s precondition. + assert(Profiles.has_value()); + GlobalValue::GUID G = F ? 
getDefinedFunctionGUID(*F) : 0U; + preorderVisit(*Profiles, V, G); +} + const CtxProfFlatProfile PGOContextualProfile::flatten() const { assert(Profiles.has_value()); CtxProfFlatProfile Flat; - preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) { - auto [It, Ins] = Flat.insert({Ctx.guid(), {}}); - if (Ins) { - llvm::append_range(It->second, Ctx.counters()); - return; - } - assert(It->second.size() == Ctx.counters().size() && - "All contexts corresponding to a function should have the exact " - "same number of counters."); - for (size_t I = 0, E = It->second.size(); I < E; ++I) - It->second[I] += Ctx.counters()[I]; - }); + preorderVisit( + *Profiles, [&](const PGOCtxProfContext &Ctx) { + auto [It, Ins] = Flat.insert({Ctx.guid(), {}}); + if (Ins) { + llvm::append_range(It->second, Ctx.counters()); + return; + } + assert(It->second.size() == Ctx.counters().size() && + "All contexts corresponding to a function should have the exact " + "same number of counters."); + for (size_t I = 0, E = It->second.size(); I < E; ++I) + It->second[I] += Ctx.counters()[I]; + }); return Flat; } diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp index fdb3784675cef7..8e40189c7e8560 100644 --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -286,6 +286,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const { return cast(const_cast(getArgOperand(3))); } +void InstrProfCntrInstBase::setIndex(uint32_t Idx) { + assert(isa(this)); + setArgOperand(3, ConstantInt::get(Type::getInt32Ty(getContext()), Idx)); +} + Value *InstrProfIncrementInst::getStep() const { if (InstrProfIncrementInstStep::classof(this)) { return const_cast(getArgOperand(4)); @@ -301,6 +306,11 @@ Value *InstrProfCallsite::getCallee() const { return nullptr; } +void InstrProfCallsite::setCallee(Value *Callee) { + assert(isa(this)); + setArgOperand(4, Callee); +} + std::optional ConstrainedFPIntrinsic::getRoundingMode() const { unsigned NumOperands = arg_size(); Metadata 
*MD = nullptr; diff --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index 90dc727cde16d7..5f872c352429c1 100644 --- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -12,14 +12,16 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CallPromotionUtils.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/CtxProfAnalysis.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/AttributeMask.h" #include "llvm/IR/Constant.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/ProfileData/PGOCtxProfReader.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -572,6 +574,89 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee, return promoteCall(NewInst, Callee); } +CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee, + PGOContextualProfile &CtxProf) { + assert(CB.isIndirectCall()); + if (!CtxProf.isFunctionKnown(Callee)) + return nullptr; + auto &Caller = *CB.getFunction(); + auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB); + if (!CSInstr) + return nullptr; + const uint64_t CSIndex = CSInstr->getIndex()->getZExtValue(); + + CallBase &DirectCall = promoteCall( + versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee); + CSInstr->moveBefore(&CB); + const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller); + auto *NewCSInstr = cast(CSInstr->clone()); + NewCSInstr->setIndex(NewCSID); + NewCSInstr->setCallee(&Callee); + NewCSInstr->insertBefore(&DirectCall); + auto &DirectBB = *DirectCall.getParent(); + auto &IndirectBB = *CB.getParent(); + + assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) && + "The ICP indirect BB is new, it shouldn't 
have instrumentation"); + assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) && + "The ICP direct BB is new, it shouldn't have instrumentation"); + + // Allocate counters for the new basic blocks. + const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller); + const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller); + auto *EntryBBIns = + CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock()); + auto *DirectBBIns = cast(EntryBBIns->clone()); + DirectBBIns->setIndex(DirectID); + DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt()); + + auto *IndirectBBIns = cast(EntryBBIns->clone()); + IndirectBBIns->setIndex(IndirectID); + IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt()); + + const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee); + const uint32_t NewCountersSize = IndirectID + 1; + + auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) { + assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller)); + assert(NewCountersSize - 2 == Ctx.counters().size()); + // All the ctx-es belonging to a function must have the same size counters. + Ctx.resizeCounters(NewCountersSize); + + // Maybe in this context, the indirect callsite wasn't observed at all + if (!Ctx.hasCallsite(CSIndex)) + return; + auto &CSData = Ctx.callsite(CSIndex); + auto It = CSData.find(CalleeGUID); + + // Maybe we did notice the indirect callsite, but to other targets. 
+ if (It == CSData.end()) + return; + + assert(CalleeGUID == It->second.guid()); + + // Counters are 64-bit (getEntrycount() returns uint64_t); avoid narrowing. + uint64_t DirectCount = It->second.getEntrycount(); + uint64_t TotalCount = 0; + for (const auto &[_, V] : CSData) + TotalCount += V.getEntrycount(); + assert(TotalCount >= DirectCount); + uint64_t IndirectCount = TotalCount - DirectCount; + // The ICP's effect is as-if the direct BB would have been taken DirectCount + // times, and the indirect BB, IndirectCount times + Ctx.counters()[DirectID] = DirectCount; + Ctx.counters()[IndirectID] = IndirectCount; + + // This particular indirect target needs to be moved to this caller under + // the newly-allocated callsite index. + assert(Ctx.callsites().count(NewCSID) == 0); + Ctx.ingestContext(NewCSID, std::move(It->second)); + CSData.erase(CalleeGUID); + }; + CtxProf.update(ProfileUpdater, &Caller); + return &DirectCall; +} + CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, Function *Callee, ArrayRef AddressPoints, diff --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp index 2d457eb3b678aa..36c64b9f333d7c 100644 --- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp +++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/Analysis/CtxProfAnalysis.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -14,7 +15,13 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/PassInstrumentation.h" +#include "llvm/ProfileData/PGOCtxProfReader.h" +#include "llvm/ProfileData/PGOCtxProfWriter.h" +#include "llvm/Support/JSON.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/SupportHelpers.h" #include 
"gtest/gtest.h" using namespace llvm; @@ -456,3 +463,153 @@ declare void @_ZN5Base35func3Ev(ptr) // 1 call instruction from the entry block. EXPECT_EQ(F->front().size(), OrigEntryBBSize + 4); } + +TEST(CallPromotionUtilsTest, PromoteWithIcmpAndCtxProf) { + LLVMContext C; + std::unique_ptr M = parseIR(C, + R"IR( +define i32 @testfunc1(ptr %d) !guid !0 { + call void @llvm.instrprof.increment(ptr @testfunc1, i64 0, i32 1, i32 0) + call void @llvm.instrprof.callsite(ptr @testfunc1, i64 0, i32 1, i32 0, ptr %d) + %call = call i32 %d() + ret i32 %call +} + +define i32 @f1() !guid !1 { + call void @llvm.instrprof.increment(ptr @f1, i64 0, i32 1, i32 0) + ret i32 2 +} + +define i32 @f2() !guid !2 { + call void @llvm.instrprof.increment(ptr @f2, i64 0, i32 1, i32 0) + call void @llvm.instrprof.callsite(ptr @f2, i64 0, i32 1, i32 0, ptr @f4) + %r = call i32 @f4() + ret i32 %r +} + +define i32 @testfunc2(ptr %p) !guid !4 { + call void @llvm.instrprof.increment(ptr @testfunc2, i64 0, i32 1, i32 0) + call void @llvm.instrprof.callsite(ptr @testfunc2, i64 0, i32 1, i32 0, ptr @testfunc1) + %r = call i32 @testfunc1(ptr %p) + ret i32 %r +} + +declare i32 @f3() + +define i32 @f4() !guid !3 { + ret i32 3 +} + +!0 = !{i64 1000} +!1 = !{i64 1001} +!2 = !{i64 1002} +!3 = !{i64 1004} +!4 = !{i64 1005} +)IR"); + + const char *Profile = R"json( + [ + { + "Guid": 1000, + "Counters": [1], + "Callsites": [ + [{ "Guid": 1001, + "Counters": [10]}, + { "Guid": 1002, + "Counters": [11], + "Callsites": [[{"Guid": 1004, "Counters":[13]}]] + }, + { "Guid": 1003, + "Counters": [12] + }]] + }, + { + "Guid": 1005, + "Counters": [2], + "Callsites": [ + [{ "Guid": 1000, + "Counters": [1], + "Callsites": [ + [{ "Guid": 1001, + "Counters": [101]}, + { "Guid": 1002, + "Counters": [102], + "Callsites": [[{"Guid": 1004, "Counters":[104]}]] + }, + { "Guid": 1003, + "Counters": [103] + }]]}]]}] + )json"; + + llvm::unittest::TempFile ProfileFile("ctx_profile", "", "", /*Unique=*/true); + { + std::error_code EC; 
+ raw_fd_stream Out(ProfileFile.path(), EC); + ASSERT_FALSE(EC); + // "False" means no error. + ASSERT_FALSE(llvm::createCtxProfFromJSON(Profile, Out)); + } + + ModuleAnalysisManager MAM; + MAM.registerPass([&]() { return CtxProfAnalysis(ProfileFile.path()); }); + MAM.registerPass([&]() { return PassInstrumentationAnalysis(); }); + auto &CtxProf = MAM.getResult(*M); + auto *Caller = M->getFunction("testfunc1"); + ASSERT_NE(Caller, nullptr); + auto *Callee = M->getFunction("f2"); + ASSERT_NE(Callee, nullptr); + auto *IndirectCS = [&]() -> CallBase * { + for (auto &BB : *Caller) + for (auto &I : BB) + if (auto *CB = dyn_cast(&I); CB && CB->isIndirectCall()) + return CB; + return nullptr; + }(); + ASSERT_NE(IndirectCS, nullptr); + promoteCallWithIfThenElse(*IndirectCS, *Callee, CtxProf); + + std::string Str; + raw_string_ostream OS(Str); + CtxProfAnalysisPrinterPass Printer( + OS, CtxProfAnalysisPrinterPass::PrintMode::JSON); + Printer.run(*M, MAM); + const char *Expected = R"json( + [ + { + "Guid": 1000, + "Counters": [1, 11, 22], + "Callsites": [ + [{ "Guid": 1001, + "Counters": [10]}, + { "Guid": 1003, + "Counters": [12] + }], + [{ "Guid": 1002, + "Counters": [11], + "Callsites": [ + [{ "Guid": 1004, + "Counters": [13] }]]}]] + }, + { + "Guid": 1005, + "Counters": [2], + "Callsites": [ + [{ "Guid": 1000, + "Counters": [1, 102, 204], + "Callsites": [ + [{ "Guid": 1001, + "Counters": [101]}, + { "Guid": 1003, + "Counters": [103]}], + [{ "Guid": 1002, + "Counters": [102], + "Callsites": [ + [{ "Guid": 1004, + "Counters": [104]}]]}]]}]]} +])json"; + auto ExpectedJSON = json::parse(Expected); + ASSERT_TRUE(!!ExpectedJSON); + auto ProducedJSON = json::parse(Str); + ASSERT_TRUE(!!ProducedJSON); + EXPECT_EQ(*ProducedJSON, *ExpectedJSON); +}