diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 9da60c9410d571..ab768edca01623 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -33,6 +33,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/openmp-directive-sets.h" #include "flang/Semantics/tools.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Transforms/RegionUtils.h" @@ -123,9 +124,100 @@ class HostClausesInsertionGuard { mlir::OpBuilder::InsertPoint ip; mlir::omp::TargetOp targetOp; + // Finds the list of op operands that escape the target op's region; that is: + // the operands that are used outside the target op but defined inside it. + void + findEscapingOpOperands(llvm::DenseSet &escapingOperands) { + if (!targetOp) + return; + + mlir::Region *targetParentRegion = targetOp->getParentRegion(); + assert(targetParentRegion != nullptr && + "Expected omp.target op to be nested in a parent region."); + + // Walk the parent region in pre-order to make sure we visit `targetOp` + // before its nested ops. + targetParentRegion->walk( + [&](mlir::Operation *op) { + // Once we come across `targetOp`, we interrupt the walk since we + // already visited all the ops that come before it in the region. + if (op == targetOp) + return mlir::WalkResult::interrupt(); + + for (mlir::OpOperand &operand : op->getOpOperands()) { + mlir::Operation *operandDefiningOp = operand.get().getDefiningOp(); + + if (operandDefiningOp == nullptr) + continue; + + auto parentTargetOp = + operandDefiningOp->getParentOfType(); + + if (parentTargetOp != targetOp) + continue; + + escapingOperands.insert(&operand); + } + + return mlir::WalkResult::advance(); + }); + } + + // For an escaping operand, clone its use-def chain (i.e. its backward slice) + // outside the target region. + // + // \return the last op in the chain (this is the op that defines the escaping + // operand). + mlir::Operation * + cloneOperandSliceOutsideTargetOp(mlir::OpOperand *escapingOperand) { + mlir::Operation *operandDefiningOp = escapingOperand->get().getDefiningOp(); + llvm::SetVector backwardSlice; + mlir::BackwardSliceOptions sliceOptions; + sliceOptions.inclusive = true; + mlir::getBackwardSlice(operandDefiningOp, &backwardSlice, sliceOptions); + + auto ip = builder.saveInsertionPoint(); + + mlir::IRMapping mapper; + builder.setInsertionPoint(escapingOperand->getOwner()); + mlir::Operation *lastSliceOp; + + for (auto *op : backwardSlice) + lastSliceOp = builder.clone(*op, mapper); + + builder.restoreInsertionPoint(ip); + return lastSliceOp; + } + /// Fixup any uses of target region block arguments that we have just created /// outside of the target region, and replace them by their host values. void fixupExtractedHostOps() { + llvm::DenseSet escapingOperands; + findEscapingOpOperands(escapingOperands); + + for (mlir::OpOperand *operand : escapingOperands) { + mlir::Operation *operandDefiningOp = operand->get().getDefiningOp(); + assert(operandDefiningOp != nullptr && + "Expected escaping operand to have a defining op (i.e. not to be " + "a block argument)"); + mlir::Operation *lastSliceOp = cloneOperandSliceOutsideTargetOp(operand); + + // Find the index of the operand in the list of results produced by its + // defining op. + unsigned operandResultIdx = 0; + for (auto [idx, res] : llvm::enumerate(operandDefiningOp->getResults())) { + if (res == operand->get()) { + operandResultIdx = idx; + break; + } + } + + // Replace the escaping operand with the corresponding value from the + // op that we cloned outside the target op. + operand->getOwner()->setOperand(operand->getOperandNumber(), + lastSliceOp->getResult(operandResultIdx)); + } + auto useOutsideTargetRegion = [](mlir::OpOperand &operand) { if (mlir::Operation *owner = operand.getOwner()) return !owner->getParentOfType(); diff --git a/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 b/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 new file mode 100644 index 00000000000000..027251801ff5df --- /dev/null +++ b/flang/test/Lower/OpenMP/target-do-loop-control-exprs.f90 @@ -0,0 +1,35 @@ +! Verifies that if expressions are used to compute a target parallel loop, that +! no values escape the target region when flang emits the ops corresponding to +! these expressions (for example the compute the trip count for the target region). + +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +subroutine foo(upper_bound) + implicit none + integer :: upper_bound + integer :: nodes(1 : upper_bound) + integer :: i + + !$omp target teams distribute parallel do + do i = 1, ubound(nodes,1) + nodes(i) = i + end do + !$omp end target teams distribute parallel do +end subroutine + +! CHECK: func.func @_QPfoo(%[[FUNC_ARG:.*]]: !fir.ref {fir.bindc_name = "upper_bound"}) { +! CHECK: %[[UB_ALLOC:.*]] = fir.alloca i32 +! CHECK: fir.dummy_scope : !fir.dscope +! CHECK: %[[UB_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] {{.*}} {uniq_name = "_QFfooEupper_bound"} + +! CHECK: omp.map.info +! CHECK: omp.map.info +! CHECK: omp.map.info + +! Verify that we load from the original/host allocation of the `upper_bound` +! variable rather than the corresponding target region arg. + +! CHECK: fir.load %[[UB_ALLOC]] : !fir.ref +! CHECK: omp.target + +! CHECK: }