Skip to content

Commit

Permalink
[Flang][MLIR][OpenMP] Fix num_teams, num_threads, thread_limit loweri…
Browse files Browse the repository at this point in the history
…ng (#132)

* [Flang][MLIR][OpenMP] Fix num_teams, num_threads, thread_limit lowering

This patch fixes lowering for the num_teams, num_threads and thread_limit
clauses when inside of a target region and compiling for the host device.

The current approach requires these to be attached to the parent MLIR
omp.target operation. However, some incorrect checks based on the
`evalHasSiblings()` helper function would result in these clauses being
attached to the `omp.teams` or `omp.parallel` operation instead, triggering
a verifier error.

In this patch, these checks are updated to stop breaking when lowering
combined `target teams [X]` constructs. Also, the `genTeamsClauses()` function
is fixed to avoid processing num_teams and thread_limit twice, which probably
resulted from a recent merge.

Related op verifier checks are relaxed due to the difficulty of checking for the
proper conditions to be met. It is left to the MLIR creator to decide where
these clauses are attached and processed.
  • Loading branch information
skatrak committed Aug 9, 2024
1 parent b7fbf3a commit 4f5f002
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 71 deletions.
111 changes: 64 additions & 47 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ using namespace Fortran::lower::omp;
// Code generation helper functions
//===----------------------------------------------------------------------===//

static bool evalHasSiblings(lower::pft::Evaluation &eval) {
auto checkSiblings = [&eval](const lower::pft::EvaluationList &siblings) {
for (auto &sibling : siblings)
if (&sibling != &eval && !sibling.isEndStmt())
return true;

return false;
};

return eval.parent.visit(common::visitors{
[&](const lower::pft::Program &parent) {
return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
},
[&](const lower::pft::Evaluation &parent) {
return checkSiblings(*parent.evaluationList);
},
[&](const auto &parent) {
return checkSiblings(parent.evaluationList);
}});
}

static mlir::omp::TargetOp findParentTargetOp(mlir::OpBuilder &builder) {
mlir::Operation *parentOp = builder.getBlock()->getParentOp();
if (!parentOp)
Expand Down Expand Up @@ -92,6 +113,38 @@ static void genNestedEvaluations(lower::AbstractConverter &converter,
converter.genEval(e);
}

static bool mustEvalTeamsThreadsOutsideTarget(lower::pft::Evaluation &eval,
mlir::omp::TargetOp targetOp) {
if (!targetOp)
return false;

auto offloadModOp = llvm::cast<mlir::omp::OffloadModuleInterface>(
*targetOp->getParentOfType<mlir::ModuleOp>());
if (offloadModOp.getIsTargetDevice())
return false;

auto dir = Fortran::common::visit(
common::visitors{
[&](const parser::OpenMPBlockConstruct &c) {
return std::get<parser::OmpBlockDirective>(
std::get<parser::OmpBeginBlockDirective>(c.t).t)
.v;
},
[&](const parser::OpenMPLoopConstruct &c) {
return std::get<parser::OmpLoopDirective>(
std::get<parser::OmpBeginLoopDirective>(c.t).t)
.v;
},
[&](const auto &) {
llvm_unreachable("Unexpected OpenMP construct");
return llvm::omp::OMPD_unknown;
},
},
eval.get<parser::OpenMPConstruct>().u);

return llvm::omp::allTargetSet.test(dir) || !evalHasSiblings(eval);
}

//===----------------------------------------------------------------------===//
// HostClausesInsertionGuard
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -425,27 +478,6 @@ createAndSetPrivatizedLoopVar(lower::AbstractConverter &converter,
return storeOp;
}

static bool evalHasSiblings(lower::pft::Evaluation &eval) {
return eval.parent.visit(common::visitors{
[&](const lower::pft::Program &parent) {
return parent.getUnits().size() + parent.getCommonBlocks().size() > 1;
},
[&](const lower::pft::Evaluation &parent) {
for (auto &sibling : *parent.evaluationList)
if (&sibling != &eval && !sibling.isEndStmt())
return true;

return false;
},
[&](const auto &parent) {
for (auto &sibling : parent.evaluationList)
if (&sibling != &eval && !sibling.isEndStmt())
return true;

return false;
}});
}

// This helper function implements the functionality of "promoting"
// non-CPTR arguments of use_device_ptr to use_device_addr
// arguments (automagic conversion of use_device_ptr ->
Expand Down Expand Up @@ -1562,8 +1594,6 @@ genTeamsClauses(lower::AbstractConverter &converter,
cp.processAllocate(clauseOps);
cp.processDefault();
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
// TODO Support delayed privatization.

// Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside
Expand Down Expand Up @@ -2304,17 +2334,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::iterator item) {
lower::StatementContext stmtCtx;

auto offloadModOp = llvm::cast<mlir::omp::OffloadModuleInterface>(
converter.getModuleOp().getOperation());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool mustEvalOutsideTarget = targetOp && !offloadModOp.getIsTargetDevice();
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

mlir::omp::TeamsOperands clauseOps;
mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
mustEvalOutsideTarget, clauseOps, numTeamsClauseOps,
evalOutsideTarget, clauseOps, numTeamsClauseOps,
threadLimitClauseOps);

auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
Expand All @@ -2324,15 +2352,15 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
queue, item, clauseOps);

if (numTeamsClauseOps.numTeamsUpper) {
if (mustEvalOutsideTarget)
if (evalOutsideTarget)
targetOp.getNumTeamsUpperMutable().assign(
numTeamsClauseOps.numTeamsUpper);
else
teamsOp.getNumTeamsUpperMutable().assign(numTeamsClauseOps.numTeamsUpper);
}

if (threadLimitClauseOps.threadLimit) {
if (mustEvalOutsideTarget)
if (evalOutsideTarget)
targetOp.getTeamsThreadLimitMutable().assign(
threadLimitClauseOps.threadLimit);
else
Expand Down Expand Up @@ -2412,12 +2440,9 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
ConstructQueue::iterator item) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

mlir::omp::ParallelOperands parallelClauseOps;
mlir::omp::NumThreadsClauseOps numThreadsClauseOps;
Expand Down Expand Up @@ -2476,12 +2501,9 @@ static void genCompositeDistributeParallelDo(
ConstructQueue::iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
Expand All @@ -2493,9 +2515,8 @@ static void genCompositeDistributeParallelDo(
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
llvm::SmallVector<mlir::Type> parallelReductionTypes;
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
/*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
numThreadsClauseOps, parallelReductionTypes,
parallelReductionSyms);
evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
parallelReductionTypes, parallelReductionSyms);

const auto &privateClauseOps = dsp.getPrivateClauseOps();
parallelClauseOps.privateVars = privateClauseOps.privateVars;
Expand Down Expand Up @@ -2551,12 +2572,9 @@ static void genCompositeDistributeParallelDoSimd(
ConstructQueue::iterator item, DataSharingProcessor &dsp) {
lower::StatementContext stmtCtx;

auto offloadModOp =
llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp());
mlir::omp::TargetOp targetOp =
findParentTargetOp(converter.getFirOpBuilder());
bool evalOutsideTarget =
targetOp && !offloadModOp.getIsTargetDevice() && !evalHasSiblings(eval);
bool evalOutsideTarget = mustEvalTeamsThreadsOutsideTarget(eval, targetOp);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
Expand All @@ -2568,9 +2586,8 @@ static void genCompositeDistributeParallelDoSimd(
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
llvm::SmallVector<mlir::Type> parallelReductionTypes;
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
/*evalOutsideTarget=*/evalOutsideTarget, parallelClauseOps,
numThreadsClauseOps, parallelReductionTypes,
parallelReductionSyms);
evalOutsideTarget, parallelClauseOps, numThreadsClauseOps,
parallelReductionTypes, parallelReductionSyms);

const auto &privateClauseOps = dsp.getPrivateClauseOps();
parallelClauseOps.privateVars = privateClauseOps.privateVars;
Expand Down
Loading

0 comments on commit 4f5f002

Please sign in to comment.