Skip to content

Commit

Permalink
[ctx_prof] Add support for ICP
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrofin committed Aug 21, 2024
1 parent 0d7c720 commit 61e37e3
Show file tree
Hide file tree
Showing 8 changed files with 364 additions and 33 deletions.
18 changes: 14 additions & 4 deletions llvm/include/llvm/Analysis/CtxProfAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class PGOContextualProfile {
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex++;
}

/// Callback type for read-only traversal: receives each visited context by
/// const reference.
using ConstVisitor = function_ref<void(const PGOCtxProfContext &)>;
/// Callback type for mutating traversal: receives each visited context by
/// mutable reference.
using Visitor = function_ref<void(PGOCtxProfContext &)>;

/// Walk the profile's contexts in preorder and invoke the visitor on them,
/// allowing it to modify each context. When \p F is non-null, the visitor is
/// only invoked on contexts whose GUID matches that of \p F; when it is null,
/// every context is visited.
void update(Visitor, const Function *F = nullptr);
/// Read-only counterpart of update(): same traversal and same \p F filter,
/// but contexts are handed to the visitor as const.
void visit(ConstVisitor, const Function *F = nullptr) const;

const CtxProfFlatProfile flatten() const;

bool invalidate(Module &, const PreservedAnalyses &PA,
Expand Down Expand Up @@ -105,13 +111,18 @@ class CtxProfAnalysis : public AnalysisInfoMixin<CtxProfAnalysis> {

/// Module pass that prints the contextual profile to a caller-supplied
/// stream.
// NOTE(review): this listing appears to contain both the pre-change and the
// post-change declarations of `OS` and of the constructor (it looks like diff
// residue without +/- markers) — confirm which ones are meant to remain.
class CtxProfAnalysisPrinterPass
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
raw_ostream &OS;

public:
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS) : OS(OS) {}
/// Selects how much the pass prints: `Everything` is the default; `JSON`
/// restricts the output to the JSON rendering of the profile.
enum class PrintMode { Everything, JSON };
/// \param OS   stream the pass writes to; held by reference.
/// \param Mode what to print, defaulting to PrintMode::Everything.
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
PrintMode Mode = PrintMode::Everything)
: OS(OS), Mode(Mode) {}

PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }

private:
// Output stream; not owned by the pass.
raw_ostream &OS;
// Fixed at construction.
const PrintMode Mode;
};

/// Assign a GUID to functions as metadata. GUID calculation takes linkage into
Expand All @@ -134,6 +145,5 @@ class AssignGUIDPass : public PassInfoMixin<AssignGUIDPass> {
// This should become GlobalValue::getGUID
static uint64_t getGUID(const Function &F);
};

} // namespace llvm
#endif // LLVM_ANALYSIS_CTXPROFANALYSIS_H
2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,7 @@ class InstrProfCntrInstBase : public InstrProfInstBase {
ConstantInt *getNumCounters() const;
// The index of the counter that this instruction acts on.
ConstantInt *getIndex() const;
// Replace the counter index operand with the 32-bit constant Idx.
void setIndex(uint32_t Idx);
};

/// This represents the llvm.instrprof.cover intrinsic.
Expand Down Expand Up @@ -1569,6 +1570,7 @@ class InstrProfCallsite : public InstrProfCntrInstBase {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}
Value *getCallee() const;
// Overwrite the callee argument operand of this intrinsic with the given
// value.
void setCallee(Value *);
};

/// This represents the llvm.instrprof.timestamp intrinsic.
Expand Down
20 changes: 20 additions & 0 deletions llvm/include/llvm/ProfileData/PGOCtxProfReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,23 @@ class PGOCtxProfContext final {

/// GUID stored for this context.
GlobalValue::GUID guid() const { return GUID; }
/// Read-only access to this context's counter values.
const SmallVectorImpl<uint64_t> &counters() const { return Counters; }
/// Mutable access to this context's counter values.
SmallVectorImpl<uint64_t> &counters() { return Counters; }

/// The counter at index 0, exposed as the entry count. No bounds check is
/// made, so this assumes the counter vector is non-empty.
uint64_t getEntrycount() const { return Counters[0]; }

/// Read-only access to the map from callsite index to the call targets
/// recorded at that callsite.
const CallsiteMapTy &callsites() const { return Callsites; }
/// Mutable access to the callsite map.
CallsiteMapTy &callsites() { return Callsites; }

/// Take ownership of \p Other and file it under callsite \p CSId, keyed by
/// \p Other's GUID. The callsite entry is created if it does not exist yet; if
/// a context with the same GUID is already recorded there, nothing is
/// inserted.
void ingestContext(uint32_t CSId, PGOCtxProfContext &&Other) {
  // Read the key before the context is handed over to the map.
  const auto TargetGUID = Other.guid();
  auto Inserted = callsites().try_emplace(CSId, CallTargetMapTy());
  auto &Targets = Inserted.first->second;
  Targets.emplace(TargetGUID, std::move(Other));
}

/// Ensure the counter vector holds at least \p Size elements. The vector is
/// never shrunk: a request smaller than the current size is ignored.
void growCounters(uint32_t Size) {
  const bool WouldShrink = Counters.size() > Size;
  if (WouldShrink)
    return;
  Counters.resize(Size);
}

/// Whether any call target has been recorded under callsite index \p I.
bool hasCallsite(uint32_t I) const {
  const auto Pos = Callsites.find(I);
  const bool Missing = Pos == Callsites.end();
  return !Missing;
}
Expand All @@ -68,6 +82,12 @@ class PGOCtxProfContext final {
assert(hasCallsite(I) && "Callsite not found");
return Callsites.find(I)->second;
}

/// Mutable access to the call targets recorded at callsite index \p I.
/// The callsite must exist; check with hasCallsite() first.
CallTargetMapTy &callsite(uint32_t I) {
  assert(hasCallsite(I) && "Callsite not found");
  auto Pos = Callsites.find(I);
  CallTargetMapTy &Targets = Pos->second;
  return Targets;
}

void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const;
};

Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
#define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H

#include "llvm/Analysis/CtxProfAnalysis.h"
namespace llvm {
template <typename T> class ArrayRef;
class Constant;
Expand Down Expand Up @@ -56,6 +57,9 @@ CallBase &promoteCall(CallBase &CB, Function *Callee,
CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
MDNode *BranchWeights = nullptr);

/// Variant of `promoteCallWithIfThenElse` that also updates the contextual
/// profile \p CtxProf to account for the promotion of the indirect call \p CB
/// to \p Callee. Returns the new direct call, or nullptr when the promotion is
/// not performed because \p Callee is not known to the profile or \p CB has no
/// callsite instrumentation.
CallBase *promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
PGOContextualProfile &CtxProf);

/// This is similar to `promoteCallWithIfThenElse` except that the condition to
/// promote a virtual call is that \p VPtr is the same as any of \p
/// AddressPoints.
Expand Down
79 changes: 50 additions & 29 deletions llvm/lib/Analysis/CtxProfAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,22 @@ PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
return PreservedAnalyses::all();
}

OS << "Function Info:\n";
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
OS << Guid << " : " << FuncInfo.Name
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
if (Mode == PrintMode::Everything) {
OS << "Function Info:\n";
for (const auto &[Guid, FuncInfo] : C.FuncInfo)
OS << Guid << " : " << FuncInfo.Name
<< ". MaxCounterID: " << FuncInfo.NextCounterIndex
<< ". MaxCallsiteID: " << FuncInfo.NextCallsiteIndex << "\n";
}

const auto JSONed = ::llvm::json::toJSON(C.profiles());

OS << "\nCurrent Profile:\n";
if (Mode == PrintMode::Everything)
OS << "\nCurrent Profile:\n";
OS << formatv("{0:2}", JSONed);
if (Mode == PrintMode::JSON)
return PreservedAnalyses::all();

OS << "\n";
OS << "\nFlat Profile:\n";
auto Flat = C.flatten();
Expand All @@ -209,34 +215,49 @@ InstrProfIncrementInst *CtxProfAnalysis::getBBInstrumentation(BasicBlock &BB) {
return nullptr;
}

static void
preorderVisit(const PGOCtxProfContext::CallTargetMapTy &Profiles,
function_ref<void(const PGOCtxProfContext &)> Visitor) {
std::function<void(const PGOCtxProfContext &)> Traverser =
[&](const auto &Ctx) {
Visitor(Ctx);
for (const auto &[_, SubCtxSet] : Ctx.callsites())
for (const auto &[__, Subctx] : SubCtxSet)
Traverser(Subctx);
};
for (const auto &[_, P] : Profiles)
/// Parent-before-children walk over every context reachable from \p Profiles.
/// \p Visitor is invoked on a context only when \p Match is 0 or equals that
/// context's GUID; sub-contexts are descended into either way.
template <class ProfilesTy, class ProfTy>
static void preorderVisit(ProfilesTy &Profiles,
                          function_ref<void(ProfTy &)> Visitor,
                          GlobalValue::GUID Match = 0) {
  // The lambda receives itself as its first argument so that it can recurse.
  auto Walk = [&](auto &Self, ProfTy &Node) -> void {
    const bool Selected = Match == 0 || Node.guid() == Match;
    if (Selected)
      Visitor(Node);
    // The callsite map is read only after the visitor has run on Node.
    for (auto &CallsiteEntry : Node.callsites())
      for (auto &TargetEntry : CallsiteEntry.second)
        Self(Self, TargetEntry.second);
  };
  for (auto &RootEntry : Profiles)
    Walk(Walk, RootEntry.second);
}

/// Run \p V, which may mutate what it is given, over the profile's contexts in
/// preorder. A non-null \p F restricts the callback to contexts whose GUID is
/// that of \p F; a null \p F selects every context.
void PGOContextualProfile::update(Visitor V, const Function *F) {
  using MapTy = PGOCtxProfContext::CallTargetMapTy;
  // A filter of 0 makes the traversal visit every context.
  GlobalValue::GUID Filter = 0U;
  if (F != nullptr)
    Filter = getDefinedFunctionGUID(*F);
  preorderVisit<MapTy, PGOCtxProfContext>(*Profiles, V, Filter);
}

/// Run the read-only callback \p V over the profile's contexts in preorder.
/// A non-null \p F restricts the callback to contexts whose GUID is that of
/// \p F; a null \p F selects every context.
void PGOContextualProfile::visit(ConstVisitor V, const Function *F) const {
  using ConstMapTy = const PGOCtxProfContext::CallTargetMapTy;
  // A filter of 0 makes the traversal visit every context.
  GlobalValue::GUID Filter = 0U;
  if (F != nullptr)
    Filter = getDefinedFunctionGUID(*F);
  preorderVisit<ConstMapTy, const PGOCtxProfContext>(*Profiles, V, Filter);
}

/// Collapse the contextual profile into one counter vector per GUID: the
/// first context seen for a GUID seeds the vector with its counters, and each
/// later context with the same GUID is added to it element-wise.
/// Requires that a profile is present.
// NOTE(review): the two preorderVisit calls below look like the pre-change
// and post-change spellings of the same statement (diff residue without +/-
// markers) — confirm that only one of them is meant to remain, since running
// both would accumulate every context twice.
const CtxProfFlatProfile PGOContextualProfile::flatten() const {
assert(Profiles.has_value());
CtxProfFlatProfile Flat;
preorderVisit(*Profiles, [&](const PGOCtxProfContext &Ctx) {
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
if (Ins) {
llvm::append_range(It->second, Ctx.counters());
return;
}
assert(It->second.size() == Ctx.counters().size() &&
"All contexts corresponding to a function should have the exact "
"same number of counters.");
for (size_t I = 0, E = It->second.size(); I < E; ++I)
It->second[I] += Ctx.counters()[I];
});
preorderVisit<const PGOCtxProfContext::CallTargetMapTy,
const PGOCtxProfContext>(
*Profiles, [&](const PGOCtxProfContext &Ctx) {
// Ins is true when this is the first context seen for the GUID.
auto [It, Ins] = Flat.insert({Ctx.guid(), {}});
if (Ins) {
llvm::append_range(It->second, Ctx.counters());
return;
}
assert(It->second.size() == Ctx.counters().size() &&
"All contexts corresponding to a function should have the exact "
"same number of counters.");
// Accumulate this context's counters into the running totals.
for (size_t I = 0, E = It->second.size(); I < E; ++I)
It->second[I] += Ctx.counters()[I];
});
return Flat;
}
10 changes: 10 additions & 0 deletions llvm/lib/IR/IntrinsicInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ ConstantInt *InstrProfCntrInstBase::getIndex() const {
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(3)));
}

// Rewrite the counter index (argument operand 3) to the i32 constant Idx.
void InstrProfCntrInstBase::setIndex(uint32_t Idx) {
  assert(isa<InstrProfCntrInstBase>(this));
  auto *Int32Ty = Type::getInt32Ty(getContext());
  auto *NewIndex = ConstantInt::get(Int32Ty, Idx);
  setArgOperand(3, NewIndex);
}

Value *InstrProfIncrementInst::getStep() const {
if (InstrProfIncrementInstStep::classof(this)) {
return const_cast<Value *>(getArgOperand(4));
Expand All @@ -300,6 +305,11 @@ Value *InstrProfCallsite::getCallee() const {
return nullptr;
}

// Store V as the callee of this callsite intrinsic, i.e. as argument
// operand 4.
void InstrProfCallsite::setCallee(Value *V) {
assert(isa<InstrProfCallsite>(this));
setArgOperand(4, V);
}

std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
unsigned NumOperands = arg_size();
Metadata *MD = nullptr;
Expand Down
86 changes: 86 additions & 0 deletions llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@

#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/CtxProfAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/AttributeMask.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/ProfileData/PGOCtxProfReader.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
Expand Down Expand Up @@ -572,6 +575,89 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
return promoteCall(NewInst, Callee);
}

/// Promote the indirect call \p CB to a guarded direct call to \p Callee and
/// update the contextual profile \p CtxProf to match the new control flow.
///
/// Returns nullptr, before any IR is modified, when \p Callee is not known to
/// the profile or when \p CB has no callsite instrumentation. Otherwise
/// returns the newly created direct call.
///
/// On success:
///  - the direct call gets its own callsite instrumentation with a freshly
///    allocated callsite index;
///  - the blocks holding the direct and the indirect call each get a new
///    counter, cloned from the caller's entry-block instrumentation;
///  - every context of the caller has its counter vector grown by two, and,
///    where \p Callee was observed at the original callsite, the two new
///    counters are filled in and the callee's sub-context is moved under the
///    new callsite index.
CallBase *llvm::promoteCallWithIfThenElse(CallBase &CB, Function &Callee,
                                          PGOContextualProfile &CtxProf) {
  assert(CB.isIndirectCall());
  if (!CtxProf.isFunctionKnown(Callee))
    return nullptr;
  auto &Caller = *CB.getParent()->getParent();
  auto *CSInstr = CtxProfAnalysis::getCallsiteInstrumentation(CB);
  if (!CSInstr)
    return nullptr;
  const auto CSIndex = CSInstr->getIndex()->getZExtValue();

  CallBase &DirectCall = promoteCall(
      versionCallSite(CB, &Callee, /*BranchWeights=*/nullptr), &Callee);
  // Keep the original callsite instrumentation next to the indirect call.
  CSInstr->moveBefore(&CB);
  const auto NewCSID = CtxProf.allocateNextCallsiteIndex(Caller);
  auto *NewCSInstr = cast<InstrProfCallsite>(CSInstr->clone());
  NewCSInstr->setIndex(NewCSID);
  NewCSInstr->setCallee(&Callee);
  NewCSInstr->insertBefore(&DirectCall);
  auto &DirectBB = *DirectCall.getParent();
  auto &IndirectBB = *CB.getParent();

  assert((CtxProfAnalysis::getBBInstrumentation(IndirectBB) == nullptr) &&
         "The ICP indirect BB is new, it shouldn't have instrumentation");
  assert((CtxProfAnalysis::getBBInstrumentation(DirectBB) == nullptr) &&
         "The ICP direct BB is new, it shouldn't have instrumentation");

  // Make the 2 new BBs have counters.
  const uint32_t DirectID = CtxProf.allocateNextCounterIndex(Caller);
  const uint32_t IndirectID = CtxProf.allocateNextCounterIndex(Caller);
  const uint32_t NewCountersSize = IndirectID + 1;
  auto *EntryBBIns =
      CtxProfAnalysis::getBBInstrumentation(Caller.getEntryBlock());
  // The entry-block instrumentation is the template the new counters are
  // cloned from; without it the clone below would dereference null.
  assert(EntryBBIns &&
         "The caller's entry BB is expected to have instrumentation");
  auto *DirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
  DirectBBIns->setIndex(DirectID);
  DirectBBIns->insertInto(&DirectBB, DirectBB.getFirstInsertionPt());

  auto *IndirectBBIns = cast<InstrProfCntrInstBase>(EntryBBIns->clone());
  IndirectBBIns->setIndex(IndirectID);
  IndirectBBIns->insertInto(&IndirectBB, IndirectBB.getFirstInsertionPt());

  const GlobalValue::GUID CalleeGUID = AssignGUIDPass::getGUID(Callee);

  auto ProfileUpdater = [&](PGOCtxProfContext &Ctx) {
    assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
    assert(NewCountersSize - 2 == Ctx.counters().size());
    // Regardless of what happens next, all the contexts belonging to a
    // function must have the same number of counters.
    Ctx.growCounters(NewCountersSize);

    // Maybe in this context, the indirect callsite wasn't observed at all.
    if (!Ctx.hasCallsite(CSIndex))
      return;
    auto &CSData = Ctx.callsite(CSIndex);
    auto It = CSData.find(CalleeGUID);

    // Maybe we did notice the indirect callsite, but to other targets.
    if (It == CSData.end())
      return;

    assert(CalleeGUID == It->second.guid());

    // Counters are 64-bit; do the arithmetic in 64 bits so that large counts
    // are neither truncated nor wrapped while being summed.
    const uint64_t DirectCount = It->second.getEntrycount();
    uint64_t TotalCount = 0;
    for (const auto &[_, V] : CSData)
      TotalCount += V.getEntrycount();
    assert(TotalCount >= DirectCount);
    const uint64_t IndirectCount = TotalCount - DirectCount;
    // The ICP's effect is as-if the direct BB would have been taken DirectCount
    // times, and the indirect BB, IndirectCount times.
    Ctx.counters()[DirectID] = DirectCount;
    Ctx.counters()[IndirectID] = IndirectCount;

    // This particular indirect target needs to be moved to this caller under
    // the newly-allocated callsite index.
    assert(Ctx.callsites().count(NewCSID) == 0);
    Ctx.ingestContext(NewCSID, std::move(It->second));
    CSData.erase(CalleeGUID);
  };
  CtxProf.update(ProfileUpdater, &Caller);
  return &DirectCall;
}

CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr,
Function *Callee,
ArrayRef<Constant *> AddressPoints,
Expand Down
Loading

0 comments on commit 61e37e3

Please sign in to comment.