From 8e4fe93753ba4793bc0c20958d80fb8b4e73dd51 Mon Sep 17 00:00:00 2001 From: ergawy Date: Tue, 9 Jul 2024 03:25:46 -0500 Subject: [PATCH] [flang][OpenMP] Handle non-const bounds in `do concurrent` mapping Lifts a restriction we had so far for `do concurrent` -> OpenMP mapping by supporting non-const bounds in loop headers. --- .../Transforms/DoConcurrentConversion.cpp | 83 +++++++++---------- .../DoConcurrent/non_const_bounds.f90 | 45 ++++++++++ 2 files changed, 82 insertions(+), 46 deletions(-) create mode 100644 flang/test/Transforms/DoConcurrent/non_const_bounds.f90 diff --git a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp index b30379da272ea6c..caf6bc3616dd058 100644 --- a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp @@ -15,6 +15,7 @@ #include "flang/Optimizer/HLFIR/HLFIRDialect.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Diagnostics.h" @@ -249,34 +250,48 @@ extractIndVarUpdateOps(fir::DoLoopOp doLoop) { return std::move(indVarUpdateOps); } -/// Starting with a value and the end of a defintion/conversion chain, walk the -/// chain backwards and collect all the visited ops along the way. For example, -/// given this IR: +/// Starting with a value at the end of a definition/conversion chain, walk the +/// chain backwards and collect all the visited ops along the way. This is the +/// same as the "backward slice" of the use-def chain of \p link. +/// +/// If the root of the chain/slice is a constant op, then populate \p opChain +/// with the extracted chain/slice. If not, then \p opChain will contain a +/// single value: \p link. 
+/// +/// The value of this function is that we pull in the chain of +/// constant+conversion ops inside the parallel region if possible; which +/// prevents creating an unnecessary shared/mapped value that crosses the OpenMP +/// region. +/// +/// For example, given this IR: /// ``` /// %c10 = arith.constant 10 : i32 /// %10 = fir.convert %c10 : (i32) -> index /// ``` /// and giving `%10` as the starting input: `link`, `defChain` would contain /// both of the above ops. -mlir::LogicalResult -collectIndirectOpChain(mlir::Operation *link, - llvm::SmallVectorImpl &opChain) { - while (!mlir::isa_and_present(link)) { - if (auto convertOp = mlir::dyn_cast_if_present(link)) { - opChain.push_back(link); - link = convertOp.getValue().getDefiningOp(); - continue; - } +void collectIndirectConstOpChain(mlir::Operation *link, + llvm::SetVector &opChain) { + mlir::BackwardSliceOptions options; + options.inclusive = true; + mlir::getBackwardSlice(link, &opChain, options); - std::string opStr; - llvm::raw_string_ostream opOs(opStr); - opOs << "Unexpected operation: " << *link; - return mlir::emitError(link->getLoc(), opOs.str()); - } + assert(!opChain.empty()); - opChain.push_back(link); - std::reverse(opChain.begin(), opChain.end()); - return mlir::success(); + bool isConstantChain = [&]() { + if (!mlir::isa_and_present(opChain.front())) + return false; + + return llvm::all_of(llvm::drop_begin(opChain), [](mlir::Operation *op) { + return mlir::isa_and_present(op); + }); + }(); + + if (isConstantChain) + return; + + opChain.clear(); + opChain.insert(link); } /// Starting with `outerLoop` collect a perfectly nested loop nest, if any. 
This @@ -492,25 +507,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { "defining operation."); } - std::function isOpUltimatelyConstant = - [&](mlir::Operation *operation) { - if (mlir::isa_and_present(operation)) - return true; - - if (auto convertOp = - mlir::dyn_cast_if_present(operation)) - return isOpUltimatelyConstant(convertOp.getValue().getDefiningOp()); - - return false; - }; - - if (!isOpUltimatelyConstant(lbOp) || !isOpUltimatelyConstant(ubOp) || - !isOpUltimatelyConstant(stepOp)) { - return rewriter.notifyMatchFailure( - doLoop, "`do concurrent` conversion is currently only supported for " - "constant LB, UB, and step values."); - } - llvm::SmallVector outermostLoopLives; looputils::collectLoopLiveIns(doLoop, outermostLoopLives); assert(!outermostLoopLives.empty()); @@ -763,13 +759,8 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { // isolated from above. auto cloneBoundOrStepOpChain = [&](mlir::Operation *operation) -> mlir::Operation * { - llvm::SmallVector opChain; - mlir::LogicalResult extractResult = - looputils::collectIndirectOpChain(operation, opChain); - - if (failed(extractResult)) { - return nullptr; - } + llvm::SetVector opChain; + looputils::collectIndirectConstOpChain(operation, opChain); mlir::Operation *result; for (mlir::Operation *link : opChain) diff --git a/flang/test/Transforms/DoConcurrent/non_const_bounds.f90 b/flang/test/Transforms/DoConcurrent/non_const_bounds.f90 new file mode 100644 index 000000000000000..4a72d3fbe28761e --- /dev/null +++ b/flang/test/Transforms/DoConcurrent/non_const_bounds.f90 @@ -0,0 +1,45 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %s -o - \ +! RUN: | FileCheck %s + +program main + implicit none + + call foo(10) + + contains + subroutine foo(n) + implicit none + integer :: n + integer :: i + integer, dimension(n) :: a + + do concurrent(i=1:n) + a(i) = i + end do + end subroutine + +end program main + +! 
CHECK: %[[N_DECL:.*]]:2 = hlfir.declare %{{.*}} dummy_scope %{{.*}} {uniq_name = "_QFFfooEn"} +! CHECK: fir.load +! CHECK: %[[N_VAL:.*]] = fir.load %[[N_DECL]]#0 : !fir.ref + +! CHECK: omp.parallel { + +! Verify the constant chain of ops for the lower bound is cloned in the region. +! CHECK: %[[C1:.*]] = arith.constant 1 : i32 +! CHECK: %[[LB:.*]] = fir.convert %[[C1]] : (i32) -> index + +! Verify that we resort to using the outside value for the upper bound since it +! is not originally a constant. +! CHECK: %[[UB:.*]] = fir.convert %[[N_VAL]] : (i32) -> index + +! CHECK: omp.wsloop { +! CHECK: omp.loop_nest (%{{.*}}) : index = (%[[LB]]) to (%[[UB]]) inclusive step (%{{.*}}) { +! CHECK: omp.yield +! CHECK: } +! CHECK: omp.terminator +! CHECK: } +! CHECK: omp.terminator +! CHECK: } +