Skip to content

Commit

Permalink
[flang][OpenMP] Handle expressions in target ... do loop control (#107
Browse files Browse the repository at this point in the history
)

When emitting the ops required to compute the target loop's trip count,
flang might emit ops outside the target regions that operands defined
inside the region. This is fixed up by `HostClausesInsertionGuard`
already.

However, the current support only handles simple cases. If the loop
control contains more elaborate expressions, the fix up logic does not
handle it properly. This PR handles such cases.
  • Loading branch information
ergawy authored Jul 4, 2024
1 parent 0ef3716 commit 6bd0a22
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
92 changes: 92 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Transforms/RegionUtils.h"
Expand Down Expand Up @@ -123,9 +124,100 @@ class HostClausesInsertionGuard {
mlir::OpBuilder::InsertPoint ip;
mlir::omp::TargetOp targetOp;

// Finds the list of op operands that escape the target op's region; that is:
// the operands that are used outside the target op but defined inside it.
void
findEscapingOpOperands(llvm::DenseSet<mlir::OpOperand *> &escapingOperands) {
if (!targetOp)
return;

mlir::Region *targetParentRegion = targetOp->getParentRegion();
assert(targetParentRegion != nullptr &&
"Expected omp.target op to be nested in a parent region.");

// Walk the parent region in pre-order to make sure we visit `targetOp`
// before its nested ops.
targetParentRegion->walk<mlir::WalkOrder::PreOrder>(
[&](mlir::Operation *op) {
// Once we come across `targetOp`, we interrupt the walk since we
// already visited all the ops that come before it in the region.
if (op == targetOp)
return mlir::WalkResult::interrupt();

for (mlir::OpOperand &operand : op->getOpOperands()) {
mlir::Operation *operandDefiningOp = operand.get().getDefiningOp();

if (operandDefiningOp == nullptr)
continue;

auto parentTargetOp =
operandDefiningOp->getParentOfType<mlir::omp::TargetOp>();

if (parentTargetOp != targetOp)
continue;

escapingOperands.insert(&operand);
}

return mlir::WalkResult::advance();
});
}

// For an escaping operand, clone its use-def chain (i.e. its backward slice)
// outside the target region.
//
// \return the last op in the chain (this is the op that defines the escaping
// operand).
mlir::Operation *
cloneOperandSliceOutsideTargetOp(mlir::OpOperand *escapingOperand) {
mlir::Operation *operandDefiningOp = escapingOperand->get().getDefiningOp();
llvm::SetVector<mlir::Operation *> backwardSlice;
mlir::BackwardSliceOptions sliceOptions;
sliceOptions.inclusive = true;
mlir::getBackwardSlice(operandDefiningOp, &backwardSlice, sliceOptions);

auto ip = builder.saveInsertionPoint();

mlir::IRMapping mapper;
builder.setInsertionPoint(escapingOperand->getOwner());
mlir::Operation *lastSliceOp;

for (auto *op : backwardSlice)
lastSliceOp = builder.clone(*op, mapper);

builder.restoreInsertionPoint(ip);
return lastSliceOp;
}

/// Fixup any uses of target region block arguments that we have just created
/// outside of the target region, and replace them by their host values.
void fixupExtractedHostOps() {
llvm::DenseSet<mlir::OpOperand *> escapingOperands;
findEscapingOpOperands(escapingOperands);

for (mlir::OpOperand *operand : escapingOperands) {
mlir::Operation *operandDefiningOp = operand->get().getDefiningOp();
assert(operandDefiningOp != nullptr &&
"Expected escaping operand to have a defining op (i.e. not to be "
"a block argument)");
mlir::Operation *lastSliceOp = cloneOperandSliceOutsideTargetOp(operand);

// Find the index of the operand in the list of results produced by its
// defining op.
unsigned operandResultIdx = 0;
for (auto [idx, res] : llvm::enumerate(operandDefiningOp->getResults())) {
if (res == operand->get()) {
operandResultIdx = idx;
break;
}
}

// Replace the escaping operand with the corresponding value from the
// op that we cloned outside the target op.
operand->getOwner()->setOperand(operand->getOperandNumber(),
lastSliceOp->getResult(operandResultIdx));
}

auto useOutsideTargetRegion = [](mlir::OpOperand &operand) {
if (mlir::Operation *owner = operand.getOwner())
return !owner->getParentOfType<mlir::omp::TargetOp>();
Expand Down
35 changes: 35 additions & 0 deletions flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
! Verifies that if expressions are used to compute a target parallel loop, that
! no values escape the target region when flang emits the ops corresponding to
! these expressions (for example the compute the trip count for the target region).

! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

subroutine foo(upper_bound)
implicit none
integer :: upper_bound
integer :: nodes(1 : upper_bound)
integer :: i

!$omp target teams distribute parallel do
do i = 1, ubound(nodes,1)
nodes(i) = i
end do
!$omp end target teams distribute parallel do
end subroutine

! CHECK: func.func @_QPfoo(%[[FUNC_ARG:.*]]: !fir.ref<i32> {fir.bindc_name = "upper_bound"}) {
! CHECK: %[[UB_ALLOC:.*]] = fir.alloca i32
! CHECK: fir.dummy_scope : !fir.dscope
! CHECK: %[[UB_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] {{.*}} {uniq_name = "_QFfooEupper_bound"}

! CHECK: omp.map.info
! CHECK: omp.map.info
! CHECK: omp.map.info

! Verify that we load from the original/host allocation of the `upper_bound`
! variable rather than the corresponding target region arg.

! CHECK: fir.load %[[UB_ALLOC]] : !fir.ref<i32>
! CHECK: omp.target

! CHECK: }

0 comments on commit 6bd0a22

Please sign in to comment.