diff --git a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp index 912d33e0e38e9d..75e5e6f1168220 100644 --- a/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp @@ -565,9 +565,9 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { "defining operation."); } - llvm::SmallVector outermostLoopLives; - looputils::collectLoopLiveIns(doLoop, outermostLoopLives); - assert(!outermostLoopLives.empty()); + llvm::SmallVector outermostLoopLiveIns; + looputils::collectLoopLiveIns(doLoop, outermostLoopLiveIns); + assert(!outermostLoopLiveIns.empty()); looputils::LoopNestToIndVarMap loopNest; bool hasRemainingNestedLoops = @@ -577,28 +577,35 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { "Some `do concurent` loops are not perfectly-nested. " "These will be serialzied."); - mlir::IRMapping mapper; - llvm::SetVector locals; looputils::collectLoopLocalValues(loopNest.back().first, locals); + // We do not want to map "loop-local" values to the device through + // `omp.map.info` ops. Therefore, we remove them from the list of live-ins. + outermostLoopLiveIns.erase(llvm::remove_if(outermostLoopLiveIns, + [&](mlir::Value liveIn) { + return locals.contains(liveIn); + }), + outermostLoopLiveIns.end()); looputils::sinkLoopIVArgs(rewriter, loopNest); mlir::omp::TargetOp targetOp; mlir::omp::LoopNestOperands loopNestClauseOps; + mlir::IRMapping mapper; + if (mapToDevice) { mlir::omp::TargetOperands targetClauseOps; // The outermost loop will contain all the live-in values in all nested // loops since live-in values are collected recursively for all nested // ops. - for (mlir::Value liveIn : outermostLoopLives) + for (mlir::Value liveIn : outermostLoopLiveIns) targetClauseOps.mapVars.push_back( genMapInfoOpForLiveIn(rewriter, liveIn)); targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper, - outermostLoopLives, targetClauseOps); + outermostLoopLiveIns, targetClauseOps); genTeamsOp(doLoop.getLoc(), rewriter); } diff --git a/flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90 b/flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90 index 8b79d87d12c907..58d95e24d830f1 100644 --- a/flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90 +++ b/flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90 @@ -5,7 +5,10 @@ ! occur due to multiple teams trying to access the same allocation. ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %s -o - \ -! RUN: | FileCheck %s +! RUN: | FileCheck %s --check-prefixes=COMMON + +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %s -o - \ +! RUN: | FileCheck %s --check-prefixes=COMMON,DEVICE module struct_mod type test_struct @@ -49,18 +52,27 @@ program main print *, "total =", total end program main -! CHECK: omp.parallel { -! CHECK: %[[LOCAL_TEMP:.*]] = fir.alloca !fir.type<_QMstruct_modTtest_struct{x_:!fir.box>}> {bindc_name = ".result"} -! CHECK: omp.wsloop { -! CHECK: omp.loop_nest {{.*}} { -! CHECK: %[[TEMP_VAL:.*]] = fir.call @_QMstruct_modPconstruct_from_components -! CHECK: fir.save_result %[[TEMP_VAL]] to %[[LOCAL_TEMP]] -! CHECK: %[[EMBOXED_LOCAL:.*]] = fir.embox %[[LOCAL_TEMP]] -! CHECK: %[[CONVERTED_LOCAL:.*]] = fir.convert %[[EMBOXED_LOCAL]] -! CHECK: fir.call @_FortranADestroy(%[[CONVERTED_LOCAL]]) -! CHECK: omp.yield -! CHECK: } -! CHECK: omp.terminator -! CHECK: } -! CHECK: omp.terminator -! CHECK: } +! DEVICE: omp.target {{.*}} { +! DEVICE: omp.teams { +! COMMON: omp.parallel { +! COMMON: %[[LOCAL_TEMP:.*]] = fir.alloca !fir.type<_QMstruct_modTtest_struct{x_:!fir.box>}> {bindc_name = ".result"} +! DEVICE: omp.distribute { +! COMMON: omp.wsloop { +! COMMON: omp.loop_nest {{.*}} { +! COMMON: %[[TEMP_VAL:.*]] = fir.call @_QMstruct_modPconstruct_from_components +! COMMON: fir.save_result %[[TEMP_VAL]] to %[[LOCAL_TEMP]] +! COMMON: %[[EMBOXED_LOCAL:.*]] = fir.embox %[[LOCAL_TEMP]] +! COMMON: %[[CONVERTED_LOCAL:.*]] = fir.convert %[[EMBOXED_LOCAL]] +! COMMON: fir.call @_FortranADestroy(%[[CONVERTED_LOCAL]]) +! COMMON: omp.yield +! COMMON: } +! COMMON: omp.terminator +! COMMON: } +! DEVICE: omp.terminator +! DEVICE: } +! COMMON: omp.terminator +! COMMON: } +! DEVICE: omp.terminator +! DEVICE: } +! DEVICE: omp.terminator +! DEVICE: }