Skip to content

Commit

Permalink
[flang][OpenMP] Handle non-const bounds in do concurrent mapping
Browse files Browse the repository at this point in the history
Lifts a restriction we had so far for `do concurrent` -> OpenMP mapping
by supporting non-const bounds in loop headers.
  • Loading branch information
ergawy committed Jul 9, 2024
1 parent d5a3c4d commit 8e4fe93
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 46 deletions.
83 changes: 37 additions & 46 deletions flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -249,34 +250,48 @@ extractIndVarUpdateOps(fir::DoLoopOp doLoop) {
return std::move(indVarUpdateOps);
}

/// Starting with a value and the end of a defintion/conversion chain, walk the
/// chain backwards and collect all the visited ops along the way. For example,
/// given this IR:
/// Starting with a value at the end of a defintion/conversion chain, walk the
/// chain backwards and collect all the visited ops along the way. This is the
/// same as the "backward slice" of the use-def chain of \p link.
///
/// If the root of the chain/slice is a constant op, then populate \p opChain
/// with the extracted chain/slice. If not, then \p opChain will contains a
/// single value: \p link.
///
/// The value of this function is that we pull in the chain of
/// constant+conversion ops inside the parallel region if possible; which
/// prevents creating an unnecessary shared/mapped value that crosses the OpenMP
/// region.
///
/// For example, given this IR:
/// ```
/// %c10 = arith.constant 10 : i32
/// %10 = fir.convert %c10 : (i32) -> index
/// ```
/// and giving `%10` as the starting input: `link`, `defChain` would contain
/// both of the above ops.
mlir::LogicalResult
collectIndirectOpChain(mlir::Operation *link,
llvm::SmallVectorImpl<mlir::Operation *> &opChain) {
while (!mlir::isa_and_present<mlir::arith::ConstantOp>(link)) {
if (auto convertOp = mlir::dyn_cast_if_present<fir::ConvertOp>(link)) {
opChain.push_back(link);
link = convertOp.getValue().getDefiningOp();
continue;
}
void collectIndirectConstOpChain(mlir::Operation *link,
llvm::SetVector<mlir::Operation *> &opChain) {
mlir::BackwardSliceOptions options;
options.inclusive = true;
mlir::getBackwardSlice(link, &opChain, options);

std::string opStr;
llvm::raw_string_ostream opOs(opStr);
opOs << "Unexpected operation: " << *link;
return mlir::emitError(link->getLoc(), opOs.str());
}
assert(!opChain.empty());

opChain.push_back(link);
std::reverse(opChain.begin(), opChain.end());
return mlir::success();
bool isConstantChain = [&]() {
if (!mlir::isa_and_present<mlir::arith::ConstantOp>(opChain.front()))
return false;

return llvm::all_of(llvm::drop_begin(opChain), [](mlir::Operation *op) {
return mlir::isa_and_present<fir::ConvertOp>(op);
});
}();

if (isConstantChain)
return;

opChain.clear();
opChain.insert(link);
}

/// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This
Expand Down Expand Up @@ -492,25 +507,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
"defining operation.");
}

std::function<bool(mlir::Operation *)> isOpUltimatelyConstant =
[&](mlir::Operation *operation) {
if (mlir::isa_and_present<mlir::arith::ConstantOp>(operation))
return true;

if (auto convertOp =
mlir::dyn_cast_if_present<fir::ConvertOp>(operation))
return isOpUltimatelyConstant(convertOp.getValue().getDefiningOp());

return false;
};

if (!isOpUltimatelyConstant(lbOp) || !isOpUltimatelyConstant(ubOp) ||
!isOpUltimatelyConstant(stepOp)) {
return rewriter.notifyMatchFailure(
doLoop, "`do concurrent` conversion is currently only supported for "
"constant LB, UB, and step values.");
}

llvm::SmallVector<mlir::Value> outermostLoopLives;
looputils::collectLoopLiveIns(doLoop, outermostLoopLives);
assert(!outermostLoopLives.empty());
Expand Down Expand Up @@ -763,13 +759,8 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
// isolated from above.
auto cloneBoundOrStepOpChain =
[&](mlir::Operation *operation) -> mlir::Operation * {
llvm::SmallVector<mlir::Operation *> opChain;
mlir::LogicalResult extractResult =
looputils::collectIndirectOpChain(operation, opChain);

if (failed(extractResult)) {
return nullptr;
}
llvm::SetVector<mlir::Operation *> opChain;
looputils::collectIndirectConstOpChain(operation, opChain);

mlir::Operation *result;
for (mlir::Operation *link : opChain)
Expand Down
45 changes: 45 additions & 0 deletions flang/test/Transforms/DoConcurrent/non_const_bounds.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %s -o - \
! RUN: | FileCheck %s

program main
implicit none

call foo(10)

contains
subroutine foo(n)
implicit none
integer :: n
integer :: i
integer, dimension(n) :: a

do concurrent(i=1:n)
a(i) = i
end do
end subroutine

end program main

! CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %{{.*}} dummy_scope %{{.*}} {uniq_name = "_QFFfooEn"}
! CHECK: fir.load
! CHECK: %[[N_VAL:.*]] = fir.load %[[N_DECL]]#0 : !fir.ref<i32>

! CHECK: omp.parallel {

! Verify the constant chain of ops for the lower bound are cloned in the region.
! CHECK: %[[C1:.*]] = arith.constant 1 : i32
! CHECK: %[[LB:.*]] = fir.convert %[[C1]] : (i32) -> index

! Verify that we restort to using the outside value for the upper bound since it
! is not originally a constant.
! CHECK: %[[UB:.*]] = fir.convert %[[N_VAL]] : (i32) -> index

! CHECK: omp.wsloop {
! CHECK: omp.loop_nest (%{{.*}}) : index = (%[[LB]]) to (%[[UB]]) inclusive step (%{{.*}}) {
! CHECK: omp.yield
! CHECK: }
! CHECK: omp.terminator
! CHECK: }
! CHECK: omp.terminator
! CHECK: }

0 comments on commit 8e4fe93

Please sign in to comment.