Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JIT: Regularize readbacks for parameters/OSR-locals in physical promotion #87165

Merged
merged 5 commits into from
Jun 17, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/coreclr/jit/block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ BasicBlock* BasicBlock::GetUniquePred(Compiler* compiler) const
//
// Return Value:
// The unique successor of a block, or nullptr if there is no unique successor.

//
BasicBlock* BasicBlock::GetUniqueSucc() const
{
if (bbJumpKind == BBJ_ALWAYS)
Expand Down
221 changes: 167 additions & 54 deletions src/coreclr/jit/promotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,77 @@ GenTree* Promotion::CreateReadBack(Compiler* compiler, unsigned structLclNum, co
return store;
}

//------------------------------------------------------------------------
// StartBlock:
// Handle reaching the end of the currently started block by preparing
// internal state for upcoming basic blocks, and inserting any necessary
// readbacks.
//
// Parameters:
// block - The block
//
void ReplaceVisitor::StartBlock(BasicBlock* block)
{
m_currentBlock = block;

#ifdef DEBUG
// At the start of every block we expect all replacements to be in their
// local home.
for (AggregateInfo* agg : m_aggregates)
{
if (agg == nullptr)
{
continue;
}

for (Replacement& rep : agg->Replacements)
{
assert(!rep.NeedsReadBack);
assert(rep.NeedsWriteBack);
}
}
#endif

// OSR locals and parameters may need an initial read back, which we mark
// when we start the scratch BB.
if (!m_compiler->fgBBisScratch(block))
{
return;
}

for (AggregateInfo* agg : m_aggregates)
{
if (agg == nullptr)
{
continue;
}

LclVarDsc* dsc = m_compiler->lvaGetDesc(agg->LclNum);
if (!dsc->lvIsParam && !dsc->lvIsOSRLocal)
{
continue;
}

JITDUMP("Marking fields of %s V%02u as needing read-back in scratch " FMT_BB "\n",
dsc->lvIsParam ? "parameter" : "OSR-local", agg->LclNum, block->bbNum);

for (size_t i = 0; i < agg->Replacements.size(); i++)
{
Replacement& rep = agg->Replacements[i];
rep.NeedsWriteBack = false;
if (m_liveness->IsReplacementLiveIn(block, agg->LclNum, (unsigned)i))
{
rep.NeedsReadBack = true;
JITDUMP(" V%02u (%s) marked\n", rep.LclNum, rep.Description);
}
else
{
JITDUMP(" V%02u (%s) not marked (not live-in to scratch BB)\n", rep.LclNum, rep.Description);
}
}
}
}

//------------------------------------------------------------------------
// EndBlock:
// Handle reaching the end of the currently started block by preparing
Expand Down Expand Up @@ -1191,11 +1262,27 @@ void ReplaceVisitor::EndBlock()
}
else
{
// We only mark fields as requiring read-back if they are
// live at the point where the stack local was written, so
// at first glance we would not expect this case to ever
// happen. However, it is possible that the field is live
// because it has a future struct use, in which case we may
// not need to insert any readbacks anywhere. For example,
// consider:
//
// V03 = CALL() // V03 is a struct with promoted V03.[000..008)
// CALL(struct V03) // V03.[000.008) marked as live here
//
// While V03.[000.008) gets marked for readback at the
// assignment, no readback is necessary at the location of
// the call argument, and it may die after that.

JITDUMP("Skipping reading back dead replacement V%02u.[%03u..%03u) -> V%02u near the end of " FMT_BB
"\n",
agg->LclNum, rep.Offset, rep.Offset + genTypeSize(rep.AccessType), rep.LclNum,
m_currentBlock->bbNum);
}

rep.NeedsReadBack = false;
}

Expand All @@ -1206,6 +1293,18 @@ void ReplaceVisitor::EndBlock()
m_hasPendingReadBacks = false;
}

//------------------------------------------------------------------------
// PostOrderVisit:
// Visit a node in post-order and make necessary changes for promoted field
// uses.
//
// Parameters:
// use - The use edge
// user - The user
//
// Returns:
// Visitor result.
//
Compiler::fgWalkResult ReplaceVisitor::PostOrderVisit(GenTree** use, GenTree* user)
{
GenTree* tree = *use;
Expand Down Expand Up @@ -1300,16 +1399,13 @@ GenTree** ReplaceVisitor::InsertMidTreeReadBacksIfNecessary(GenTree** use)

for (Replacement& rep : agg->Replacements)
{
// TODO-CQ: We should ensure we do not mark dead fields as
// requiring readback. Currently it is handled by querying liveness
// as part of end-of-block readback insertion, but for these
// mid-tree readbacks we cannot query liveness information for
// arbitrary locals.
if (!rep.NeedsReadBack)
{
continue;
}

JITDUMP(" V%02.[%03u..%03u) -> V%02u\n", agg->LclNum, rep.Offset, genTypeSize(rep.AccessType), rep.LclNum);

rep.NeedsReadBack = false;
GenTree* readBack = Promotion::CreateReadBack(m_compiler, agg->LclNum, rep);
*use =
Expand Down Expand Up @@ -1369,10 +1465,7 @@ void ReplaceVisitor::LoadStoreAroundCall(GenTreeCall* call, GenTree* user)
GenTreeLclVarCommon* retBufLcl = retBufArg->GetNode()->AsLclVarCommon();
unsigned size = m_compiler->typGetObjLayout(call->gtRetClsHnd)->GetSize();

if (MarkForReadBack(retBufLcl->GetLclNum(), retBufLcl->GetLclOffs(), size))
{
JITDUMP("Retbuf has replacements that were marked for read back\n");
}
MarkForReadBack(retBufLcl, size DEBUGARG("used as retbuf"));
}
}

Expand Down Expand Up @@ -1486,9 +1579,9 @@ void ReplaceVisitor::ReplaceLocal(GenTree** use, GenTree* user)

Replacement& rep = replacements[index];
assert(accessType == rep.AccessType);
JITDUMP(" ..replaced with promoted lcl V%02u\n", rep.LclNum);

bool isDef = lcl->OperIsLocalStore();

if (isDef)
{
*use = m_compiler->gtNewStoreLclVarNode(rep.LclNum, lcl->Data());
Expand All @@ -1507,6 +1600,7 @@ void ReplaceVisitor::ReplaceLocal(GenTree** use, GenTree* user)
}
else if (rep.NeedsReadBack)
{
JITDUMP(" ..needs a read back\n");
*use = m_compiler->gtNewOperNode(GT_COMMA, (*use)->TypeGet(),
Promotion::CreateReadBack(m_compiler, lclNum, rep), *use);
rep.NeedsReadBack = false;
Expand Down Expand Up @@ -1537,6 +1631,8 @@ void ReplaceVisitor::ReplaceLocal(GenTree** use, GenTree* user)
m_compiler->lvaGetDesc(rep.LclNum)->lvRedefinedInEmbeddedStatement = true;
}

JITDUMP(" ..replaced with V%02u\n", rep.LclNum);

m_madeChanges = true;
}

Expand Down Expand Up @@ -1617,18 +1713,18 @@ void ReplaceVisitor::WriteBackBefore(GenTree** use, unsigned lcl, unsigned offs,
// back before their next use.
//
// Parameters:
// lcl - The struct local
// offs - The starting offset of the range in the struct local that needs to be read back from.
// size - The size of the range
// lcl - Local node. Its offset is the start of the range.
// size - The size of the range
jakobbotsch marked this conversation as resolved.
Show resolved Hide resolved
//
bool ReplaceVisitor::MarkForReadBack(unsigned lcl, unsigned offs, unsigned size)
void ReplaceVisitor::MarkForReadBack(GenTreeLclVarCommon* lcl, unsigned size DEBUGARG(const char* reason))
{
if (m_aggregates[lcl] == nullptr)
if (m_aggregates[lcl->GetLclNum()] == nullptr)
{
return false;
return;
}

jitstd::vector<Replacement>& replacements = m_aggregates[lcl]->Replacements;
unsigned offs = lcl->GetLclOffs();
jitstd::vector<Replacement>& replacements = m_aggregates[lcl->GetLclNum()]->Replacements;
size_t index = Promotion::BinarySearch<Replacement, &Replacement::Offset>(replacements, offs);

if ((ssize_t)index < 0)
Expand All @@ -1640,20 +1736,37 @@ bool ReplaceVisitor::MarkForReadBack(unsigned lcl, unsigned offs, unsigned size)
}
}

bool any = false;
unsigned end = offs + size;
while ((index < replacements.size()) && (replacements[index].Offset < end))
if ((index >= replacements.size()) || (replacements[index].Offset >= end))
{
// No overlap with any field.
return;
}

StructDeaths deaths = m_liveness->GetDeathsForStructLocal(lcl);
JITDUMP("Fields of [%06u] in range [%03u..%03u) need to be read back: %s\n", Compiler::dspTreeID(lcl), offs,
offs + size, reason);

do
{
any = true;
Replacement& rep = replacements[index];
assert(rep.Overlaps(offs, size));
rep.NeedsReadBack = true;
rep.NeedsWriteBack = false;
m_hasPendingReadBacks = true;
index++;
}

return any;
if (deaths.IsReplacementDying((unsigned)index))
{
JITDUMP(" V%02u (%s) not marked (is dying)\n", rep.LclNum, rep.Description);
}
else
{
rep.NeedsReadBack = true;
m_hasPendingReadBacks = true;
JITDUMP(" V%02u (%s) marked\n", rep.LclNum, rep.Description);
}

rep.NeedsWriteBack = false;

index++;
} while ((index < replacements.size()) && (replacements[index].Offset < end));
}

//------------------------------------------------------------------------
Expand Down Expand Up @@ -1766,17 +1879,43 @@ PhaseStatus Promotion::Run()
return PhaseStatus::MODIFIED_NOTHING;
}

// Check for parameters and OSR locals that need to be read back on entry
// to the function.
for (AggregateInfo* agg : aggregates)
{
if (agg == nullptr)
{
continue;
}

LclVarDsc* dsc = m_compiler->lvaGetDesc(agg->LclNum);
if (dsc->lvIsParam || dsc->lvIsOSRLocal)
{
// We will need an initial readback. We create the scratch BB ahead
// of time so that we get correct liveness and mark the
// parameters/OSR-locals as requiring read-back as part of
// ReplaceVisitor::StartBlock when we get to the scratch block.
m_compiler->fgEnsureFirstBBisScratch();
break;
}
}

// Compute liveness for the fields and remainders.
PromotionLiveness liveness(m_compiler, aggregates);
liveness.Run();

JITDUMP("Making replacements\n\n");

// Make all replacements we decided on.
ReplaceVisitor replacer(this, aggregates, &liveness);
for (BasicBlock* bb : m_compiler->Blocks())
{
replacer.StartBlock(bb);

JITDUMP("\nReplacing in ");
DBEXEC(m_compiler->verbose, bb->dspBlockHeader(m_compiler));
JITDUMP("\n");

for (Statement* stmt : bb->Statements())
{
DISPSTMT(stmt);
Expand All @@ -1795,8 +1934,7 @@ PhaseStatus Promotion::Run()
replacer.EndBlock();
}

// Insert initial IR to read arguments/OSR locals into replacement locals,
// and add necessary explicit zeroing.
// Add necessary explicit zeroing for some locals.
Statement* prevStmt = nullptr;
for (unsigned lclNum = 0; lclNum < numLocals; lclNum++)
{
Expand All @@ -1806,11 +1944,7 @@ PhaseStatus Promotion::Run()
}

LclVarDsc* dsc = m_compiler->lvaGetDesc(lclNum);
if (dsc->lvIsParam || dsc->lvIsOSRLocal)
{
InsertInitialReadBack(lclNum, aggregates[lclNum]->Replacements, &prevStmt);
}
else if (dsc->lvSuppressedZeroInit)
if (dsc->lvSuppressedZeroInit)
{
// We may have suppressed inserting an explicit zero init based on the
// assumption that the entire local will be zero inited in the prolog.
Expand Down Expand Up @@ -1856,27 +1990,6 @@ bool Promotion::IsCandidateForPhysicalPromotion(LclVarDsc* dsc)
return (dsc->TypeGet() == TYP_STRUCT) && !dsc->lvPromoted && !dsc->IsAddressExposed();
}

//------------------------------------------------------------------------
// Promotion::InsertInitialReadBack:
// Insert IR to initially read a struct local's value into its promoted field locals.
//
// Parameters:
// lclNum - The struct local
// replacements - Replacements for the struct local
// prevStmt - [in, out] Previous statement to insert after
//
void Promotion::InsertInitialReadBack(unsigned lclNum,
const jitstd::vector<Replacement>& replacements,
Statement** prevStmt)
{
for (unsigned i = 0; i < replacements.size(); i++)
{
const Replacement& rep = replacements[i];
GenTree* readBack = CreateReadBack(m_compiler, lclNum, rep);
InsertInitStatement(prevStmt, readBack);
}
}

//------------------------------------------------------------------------
// Promotion::ExplicitlyZeroInitReplacementLocals:
// Insert IR to zero out replacement locals if necessary.
Expand Down
10 changes: 3 additions & 7 deletions src/coreclr/jit/promotion.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ class Promotion
static StructSegments SignificantSegments(Compiler* compiler,
ClassLayout* layout DEBUGARG(FixedBitVect** bitVectRepr = nullptr));

void InsertInitialReadBack(unsigned lclNum, const jitstd::vector<Replacement>& replacements, Statement** prevStmt);
void ExplicitlyZeroInitReplacementLocals(unsigned lclNum,
const jitstd::vector<Replacement>& replacements,
Statement** prevStmt);
Expand Down Expand Up @@ -226,6 +225,7 @@ class PromotionLiveness
}

void Run();
bool IsReplacementLiveIn(BasicBlock* bb, unsigned structLcl, unsigned replacement);
bool IsReplacementLiveOut(BasicBlock* bb, unsigned structLcl, unsigned replacement);
StructDeaths GetDeathsForStructLocal(GenTreeLclVarCommon* use);

Expand Down Expand Up @@ -271,11 +271,7 @@ class ReplaceVisitor : public GenTreeVisitor<ReplaceVisitor>
return m_madeChanges;
}

void StartBlock(BasicBlock* block)
{
m_currentBlock = block;
}

void StartBlock(BasicBlock* block);
void EndBlock();

void StartStatement()
Expand All @@ -292,7 +288,7 @@ class ReplaceVisitor : public GenTreeVisitor<ReplaceVisitor>
void ReplaceLocal(GenTree** use, GenTree* user);
void StoreBeforeReturn(GenTreeUnOp* ret);
void WriteBackBefore(GenTree** use, unsigned lcl, unsigned offs, unsigned size);
bool MarkForReadBack(unsigned lcl, unsigned offs, unsigned size);
void MarkForReadBack(GenTreeLclVarCommon* lcl, unsigned size DEBUGARG(const char* reason));

void HandleStore(GenTree** use, GenTree* user);
bool OverlappingReplacements(GenTreeLclVarCommon* lcl,
Expand Down
Loading