Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][OpenMP] Handle expressions in target ... do loop control #107

Merged
merged 3 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/tools.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Transforms/RegionUtils.h"
Expand Down Expand Up @@ -123,9 +124,100 @@ class HostClausesInsertionGuard {
mlir::OpBuilder::InsertPoint ip;
mlir::omp::TargetOp targetOp;

// Finds the list of op operands that escape the target op's region; that is:
// the operands that are used outside the target op but defined inside it.
void
findEscapingOpOperands(llvm::DenseSet<mlir::OpOperand *> &escapingOperands) {
if (!targetOp)
return;

mlir::Region *targetParentRegion = targetOp->getParentRegion();
assert(targetParentRegion != nullptr &&
"Expected omp.target op to be nested in a parent region.");

// Walk the parent region in pre-order to make sure we visit `targetOp`
// before its nested ops.
targetParentRegion->walk<mlir::WalkOrder::PreOrder>(
[&](mlir::Operation *op) {
// Once we come across `targetOp`, we interrupt the walk since we
// already visited all the ops that come before it in the region.
if (op == targetOp)
return mlir::WalkResult::interrupt();

for (mlir::OpOperand &operand : op->getOpOperands()) {
mlir::Operation *operandDefiningOp = operand.get().getDefiningOp();

if (operandDefiningOp == nullptr)
continue;

auto parentTargetOp =
operandDefiningOp->getParentOfType<mlir::omp::TargetOp>();

if (parentTargetOp != targetOp)
continue;

escapingOperands.insert(&operand);
}

return mlir::WalkResult::advance();
});
}

// For an escaping operand, clone its use-def chain (i.e. its backward slice)
// outside the target region.
//
// \return the last op in the chain (this is the op that defines the escaping
// operand).
mlir::Operation *
cloneOperandSliceOutsideTargetOp(mlir::OpOperand *escapingOperand) {
mlir::Operation *operandDefiningOp = escapingOperand->get().getDefiningOp();
llvm::SetVector<mlir::Operation *> backwardSlice;
mlir::BackwardSliceOptions sliceOptions;
sliceOptions.inclusive = true;
mlir::getBackwardSlice(operandDefiningOp, &backwardSlice, sliceOptions);

auto ip = builder.saveInsertionPoint();

mlir::IRMapping mapper;
builder.setInsertionPoint(escapingOperand->getOwner());
mlir::Operation *lastSliceOp;

for (auto *op : backwardSlice)
lastSliceOp = builder.clone(*op, mapper);

builder.restoreInsertionPoint(ip);
return lastSliceOp;
}

/// Fixup any uses of target region block arguments that we have just created
/// outside of the target region, and replace them by their host values.
void fixupExtractedHostOps() {
llvm::DenseSet<mlir::OpOperand *> escapingOperands;
findEscapingOpOperands(escapingOperands);

for (mlir::OpOperand *operand : escapingOperands) {
mlir::Operation *operandDefiningOp = operand->get().getDefiningOp();
ergawy marked this conversation as resolved.
Show resolved Hide resolved
assert(operandDefiningOp != nullptr &&
"Expected escaping operand to have a defining op (i.e. not to be "
"a block argument)");
mlir::Operation *lastSliceOp = cloneOperandSliceOutsideTargetOp(operand);

// Find the index of the operand in the list of results produced by its
// defining op.
unsigned operandResultIdx = 0;
for (auto [idx, res] : llvm::enumerate(operandDefiningOp->getResults())) {
if (res == operand->get()) {
operandResultIdx = idx;
break;
}
}

// Replace the escaping operand with the corresponding value from the
// op that we cloned outside the target op.
operand->getOwner()->setOperand(operand->getOperandNumber(),
lastSliceOp->getResult(operandResultIdx));
}

auto useOutsideTargetRegion = [](mlir::OpOperand &operand) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I'm wondering: Is everything from here to the end of the function needed still? It seems like your additions may potentially be already handling all cases. You can try commenting it out and running check-flang and smoke-fort tests and see what happens.

Copy link
Author

@ergawy ergawy Jul 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything from this point below is still needed. What I added perpares for the following logic to work properly in case a target region argument is used indirectly outside the region. What I added above does not remap the target region arguments.

This is verifiable even using the lit test I added below. If you comment out the logic you previously added, you will find that now we clone the slice of operations needed by the trip count calculation but at the root of this slice, a use of the target region argument is still there.

if (mlir::Operation *owner = operand.getOwner())
return !owner->getParentOfType<mlir::omp::TargetOp>();
Expand Down
35 changes: 35 additions & 0 deletions flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
! Verifies that if expressions are used to compute a target parallel loop, that
! no values escape the target region when flang emits the ops corresponding to
! these expressions (for example the compute the trip count for the target region).

! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

subroutine foo(upper_bound)
implicit none
integer :: upper_bound
integer :: nodes(1 : upper_bound)
integer :: i

!$omp target teams distribute parallel do
do i = 1, ubound(nodes,1)
nodes(i) = i
end do
!$omp end target teams distribute parallel do
end subroutine

! CHECK: func.func @_QPfoo(%[[FUNC_ARG:.*]]: !fir.ref<i32> {fir.bindc_name = "upper_bound"}) {
! CHECK: %[[UB_ALLOC:.*]] = fir.alloca i32
! CHECK: fir.dummy_scope : !fir.dscope
! CHECK: %[[UB_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] {{.*}} {uniq_name = "_QFfooEupper_bound"}

! CHECK: omp.map.info
! CHECK: omp.map.info
! CHECK: omp.map.info

! Verify that we load from the original/host allocation of the `upper_bound`
! variable rather than the corresponding target region arg.

! CHECK: fir.load %[[UB_ALLOC]] : !fir.ref<i32>
! CHECK: omp.target

! CHECK: }