diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir index 4d226eaa754c1..61f18008633d5 100644 --- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir +++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir @@ -429,13 +429,14 @@ func.func @_QPopenmp_target_data_region() { func.func @_QPomp_target_data_empty() { %0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"} - omp.target_data use_device_addr(%0 : !fir.ref>) { + omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref>) { + omp.terminator } return } // CHECK-LABEL: llvm.func @_QPomp_target_data_empty -// CHECK: omp.target_data use_device_addr(%1 : !llvm.ptr) { +// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) { // CHECK: } // ----- diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90 index dedce58143649..ab33b6b380831 100644 --- a/flang/test/Lower/OpenMP/target.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -506,9 +506,8 @@ subroutine omp_target_device_ptr type(c_ptr) :: a integer, target :: b !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"} - !CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}}) + !CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}} -> %[[VAL_1:.*]] : !fir.ref>) !$omp target data map(tofrom: a) use_device_ptr(a) - !CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref>): !CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref>, !fir.field) -> !fir.ref a = c_loc(b) !CHECK: omp.terminator @@ -529,9 +528,8 @@ subroutine omp_target_device_addr !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr>) 
map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr> {name = ""} !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref>>, !fir.box>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = "a"} - !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]], %[[DEV_ADDR]] : {{.*}}) { + !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr>, !fir.ref>>) { !$omp target data map(tofrom: a) use_device_addr(a) - !CHECK: ^bb0(%[[ARG_0:.*]]: !fir.llvm_ptr>, %[[ARG_1:.*]]: !fir.ref>>): !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref>>) -> (!fir.ref>>, !fir.ref>>) !CHECK: %[[C10:.*]] = arith.constant 10 : i32 !CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref>> diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 index 085f5419fa7f8..cb26246a6e80f 100644 --- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 +++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 @@ -6,8 +6,7 @@ ! use_device_ptr to use_device_addr works, without breaking any functionality. 
!CHECK: func.func @{{.*}}only_use_device_ptr() -!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr>>, !fir.ref>>>, !fir.llvm_ptr>>, !fir.ref>>>) use_device_ptr(%{{.*}} : !fir.ref>) { -!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>, %{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>>>, %{{.*}}: !fir.ref>>>): +!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr>>, !fir.ref>>>, !fir.llvm_ptr>>, !fir.ref>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref>) { subroutine only_use_device_ptr use iso_c_binding integer, pointer, dimension(:) :: array @@ -19,8 +18,7 @@ subroutine only_use_device_ptr end subroutine !CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr() -!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr>>, !fir.ref>>>, !fir.llvm_ptr>>, !fir.ref>>>) use_device_ptr({{.*}} : !fir.ref>) { -!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>>>, %{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>, %{{.*}}: !fir.ref>>>): +!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr>>, !fir.ref>>>, !fir.llvm_ptr>>, !fir.ref>>>) use_device_ptr({{.*}} : !fir.ref>) { subroutine mix_use_device_ptr_and_addr use iso_c_binding integer, pointer, dimension(:) :: array @@ -32,8 +30,7 @@ subroutine mix_use_device_ptr_and_addr end subroutine !CHECK: func.func @{{.*}}only_use_device_addr() - !CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr>>, !fir.ref>>>, !fir.ref>, !fir.llvm_ptr>>, !fir.ref>>>) { - !CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>>>, %{{.*}}: !fir.ref>, %{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>>>): + !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr>>, 
!fir.ref>>>, !fir.ref>, !fir.llvm_ptr>>, !fir.ref>>>) { subroutine only_use_device_addr use iso_c_binding integer, pointer, dimension(:) :: array @@ -45,8 +42,7 @@ subroutine only_use_device_addr end subroutine !CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map() - !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref, !fir.ref) use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr>>, !fir.ref>>>, !fir.llvm_ptr>>, !fir.ref>>>) use_device_ptr(%{{.*}} : !fir.ref>) { - !CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>>>, %{{.*}}: !fir.llvm_ptr>>, %{{.*}}: !fir.ref>, %{{.*}}: !fir.ref>>>): + !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref, !fir.ref) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr>>, !fir.ref>>>, !fir.llvm_ptr>>, !fir.ref>>>) use_device_ptr(%{{.*}} : !fir.ref>) { subroutine mix_use_device_ptr_and_addr_and_map use iso_c_binding integer :: i, j diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 97e8b36805072..886554f66afff 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -1209,18 +1209,28 @@ class OpenMP_UseDeviceAddrClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause { + let traits = [ + BlockArgOpenMPOpInterface + ]; + let arguments = (ins Variadic:$use_device_addr_vars ); - let optAssemblyFormat = [{ - `use_device_addr` `(` $use_device_addr_vars `:` type($use_device_addr_vars) `)` + let extraClassDeclaration = [{ + unsigned numUseDeviceAddrBlockArgs() { + return getUseDeviceAddrVars().size(); + } }]; let description = [{ The optional `use_device_addr_vars` specifies the address of the objects in the device data environment. 
}]; + + // Assembly format not defined because this clause must be processed together + // with the first region of the operation, as it defines entry block + // arguments. } def OpenMP_UseDeviceAddrClause : OpenMP_UseDeviceAddrClauseSkip<>; @@ -1234,18 +1244,28 @@ class OpenMP_UseDevicePtrClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause { + let traits = [ + BlockArgOpenMPOpInterface + ]; + let arguments = (ins Variadic:$use_device_ptr_vars ); - let optAssemblyFormat = [{ - `use_device_ptr` `(` $use_device_ptr_vars `:` type($use_device_ptr_vars) `)` + let extraClassDeclaration = [{ + unsigned numUseDevicePtrBlockArgs() { + return getUseDevicePtrVars().size(); + } }]; let description = [{ The optional `use_device_ptr_vars` specifies the device pointers to the corresponding list items in the device data environment. }]; + + // Assembly format not defined because this clause must be processed together + // with the first region of the operation, as it defines entry block + // arguments. 
} def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index e58ccc4e93021..d2a2b44c042fb 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -987,6 +987,12 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [ OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)> ]; + let assemblyFormat = clausesAssemblyFormat # [{ + custom( + $region, $use_device_addr_vars, type($use_device_addr_vars), + $use_device_ptr_vars, type($use_device_ptr_vars)) attr-dict + }]; + let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 2602384744f23..22521b08637cf 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -45,6 +45,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{ return 0; }]>, + InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.", + "unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{ + return 0; + }]>, + InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.", + "unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{ + return 0; + }]>, // Unified access methods for clause-associated entry block arguments. 
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.", @@ -72,6 +80,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { auto iface = ::llvm::cast(*$_op); return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs(); }]>, + InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.", + "unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs(); + }]>, + InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.", + "unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs(); + }]>, InterfaceMethod<"Get block arguments defined by `in_reduction`.", "::llvm::MutableArrayRef<::mlir::BlockArgument>", @@ -109,13 +127,30 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { iface.getTaskReductionBlockArgsStart(), $_op.numTaskReductionBlockArgs()); }]>, + InterfaceMethod<"Get block arguments defined by `use_device_addr`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getUseDeviceAddrBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getUseDeviceAddrBlockArgsStart(), + $_op.numUseDeviceAddrBlockArgs()); + }]>, + InterfaceMethod<"Get block arguments defined by `use_device_ptr`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getUseDevicePtrBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getUseDevicePtrBlockArgsStart(), + $_op.numUseDevicePtrBlockArgs()); + }]>, ]; let verify = [{ auto iface = ::llvm::cast($_op); unsigned expectedArgs = iface.numInReductionBlockArgs() + iface.numMapBlockArgs() + iface.numPrivateBlockArgs() + - iface.numReductionBlockArgs() + 
iface.numTaskReductionBlockArgs(); + iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() + + iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs(); if ($_op->getRegion(0).getNumArguments() < expectedArgs) return $_op->emitOpError() << "expected at least " << expectedArgs << " entry block argument(s)"; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 12b2ade0d9fcb..bb88632323826 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -504,6 +504,8 @@ struct AllRegionParseArgs { std::optional privateArgs; std::optional reductionArgs; std::optional taskReductionArgs; + std::optional useDeviceAddrArgs; + std::optional useDevicePtrArgs; }; } // namespace @@ -648,6 +650,16 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, return parser.emitError(parser.getCurrentLocation()) << "invalid `task_reduction` format"; + if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr", + args.useDeviceAddrArgs))) + return parser.emitError(parser.getCurrentLocation()) + << "invalid `use_device_addr` format"; + + if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr", + args.useDevicePtrArgs))) + return parser.emitError(parser.getCurrentLocation()) + << "invalid `use_device_ptr` format"; + return parser.parseRegion(region, entryBlockArgs); } @@ -735,6 +747,18 @@ static ParseResult parseTaskReductionRegion( return parseBlockArgRegion(parser, region, args); } +static ParseResult parseUseDeviceAddrUseDevicePtrRegion( + OpAsmParser &parser, Region &region, + SmallVectorImpl &useDeviceAddrVars, + SmallVectorImpl &useDeviceAddrTypes, + SmallVectorImpl &useDevicePtrVars, + SmallVectorImpl &useDevicePtrTypes) { + AllRegionParseArgs args; + args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes); + args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes); + return 
parseBlockArgRegion(parser, region, args); +} + //===----------------------------------------------------------------------===// // Printers for operations including clauses that define entry block arguments. //===----------------------------------------------------------------------===// @@ -767,6 +791,8 @@ struct AllRegionPrintArgs { std::optional privateArgs; std::optional reductionArgs; std::optional taskReductionArgs; + std::optional useDeviceAddrArgs; + std::optional useDevicePtrArgs; }; } // namespace @@ -849,6 +875,11 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, printBlockArgClause(p, ctx, "task_reduction", iface.getTaskReductionBlockArgs(), args.taskReductionArgs); + printBlockArgClause(p, ctx, "use_device_addr", + iface.getUseDeviceAddrBlockArgs(), + args.useDeviceAddrArgs); + printBlockArgClause(p, ctx, "use_device_ptr", + iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs); p.printRegion(region, /*printEntryBlockArgs=*/false); } @@ -925,6 +956,18 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, printBlockArgRegion(p, op, region, args); } +static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, + Region &region, + ValueRange useDeviceAddrVars, + TypeRange useDeviceAddrTypes, + ValueRange useDevicePtrVars, + TypeRange useDevicePtrTypes) { + AllRegionPrintArgs args; + args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes); + args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes); + printBlockArgRegion(p, op, region, args); +} + /// Verifies Reduction Clause static LogicalResult verifyReductionVarList(Operation *op, std::optional reductionSyms, diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 7c89d3bd6ec5a..9e5f800dca60b 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ 
b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2462,8 +2462,8 @@ static void collectMapDataFromMapOperands( } }; - addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer); addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address); + addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer); } static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) { @@ -3069,6 +3069,31 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return combinedInfo; }; + // Define a lambda to apply mappings between use_device_addr and + // use_device_ptr base pointers, and their associated block arguments. + auto mapUseDevice = + [&moduleTranslation]( + llvm::OpenMPIRBuilder::DeviceInfoTy type, + llvm::ArrayRef blockArgs, + llvm::OpenMPIRBuilder::MapValuesArrayTy &basePointers, + llvm::OpenMPIRBuilder::MapDeviceInfoArrayTy &devicePointers, + llvm::function_ref mapper = nullptr) { + // Get a range to iterate over `basePointers` after filtering based on + // `devicePointers` and the given device info type. + auto basePtrRange = llvm::map_range( + llvm::make_filter_range( + llvm::zip_equal(basePointers, devicePointers), + [type](auto x) { return std::get<1>(x) == type; }), + [](auto x) { return std::get<0>(x); }); + + // Map block arguments to the corresponding processed base pointer. If + // a mapper is not specified, map the block argument to the base pointer + // directly. + for (auto [arg, basePointer] : llvm::zip_equal(blockArgs, basePtrRange)) + moduleTranslation.mapValue(arg, mapper ? 
mapper(basePointer) + : basePointer); + }; + llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true, /*SeparateBeginEndCalls=*/true); @@ -3077,29 +3102,28 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) { assert(isa(op) && "BodyGen requested for non TargetDataOp"); + auto blockArgIface = cast(op); Region &region = cast(op).getRegion(); switch (bodyGenType) { case BodyGenTy::Priv: // Check if any device ptr/addr info is available if (!info.DevicePtrInfoMap.empty()) { builder.restoreIP(codeGenIP); - unsigned argIndex = 0; - for (auto [basePointer, devicePointer] : llvm::zip_equal( - combinedInfo.BasePointers, combinedInfo.DevicePointers)) { - if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer) { - const auto &arg = region.front().getArgument(argIndex); - moduleTranslation.mapValue( - arg, info.DevicePtrInfoMap[basePointer].second); - argIndex++; - } else if (devicePointer == - llvm::OpenMPIRBuilder::DeviceInfoTy::Address) { - const auto &arg = region.front().getArgument(argIndex); - auto *loadInst = builder.CreateLoad( - builder.getPtrTy(), info.DevicePtrInfoMap[basePointer].second); - moduleTranslation.mapValue(arg, loadInst); - argIndex++; - } - } + + mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address, + blockArgIface.getUseDeviceAddrBlockArgs(), + combinedInfo.BasePointers, combinedInfo.DevicePointers, + [&](llvm::Value *basePointer) -> llvm::Value * { + return builder.CreateLoad( + builder.getPtrTy(), + info.DevicePtrInfoMap[basePointer].second); + }); + mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer, + blockArgIface.getUseDevicePtrBlockArgs(), + combinedInfo.BasePointers, combinedInfo.DevicePointers, + [&](llvm::Value *basePointer) { + return info.DevicePtrInfoMap[basePointer].second; + }); bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region", builder, moduleTranslation); @@ -3114,17 +3138,14 @@ 
convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, // For device pass, if use_device_ptr(addr) mappings were present, // we need to link them here before codegen. if (ompBuilder->Config.IsTargetDevice.value_or(false)) { - unsigned argIndex = 0; - for (auto [basePointer, devicePointer] : - llvm::zip_equal(mapData.BasePointers, mapData.DevicePointers)) { - if (devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer || - devicePointer == llvm::OpenMPIRBuilder::DeviceInfoTy::Address) { - const auto &arg = region.front().getArgument(argIndex); - moduleTranslation.mapValue(arg, basePointer); - argIndex++; - } - } + mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address, + blockArgIface.getUseDeviceAddrBlockArgs(), + mapData.BasePointers, mapData.DevicePointers); + mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer, + blockArgIface.getUseDevicePtrBlockArgs(), + mapData.BasePointers, mapData.DevicePointers); } + bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region", builder, moduleTranslation); } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 4b1468a6761e6..ce3351ba1149f 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -864,9 +864,11 @@ func.func @omp_target_data (%if_cond : i1, %device : si32, %device_ptr: memref){} // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} - // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_addr(%[[VAL_4:.*]] : memref) use_device_ptr(%[[VAL_3:.*]] : memref) + // CHECK: omp.target_data map_entries(%[[MAP_A]] : memref) use_device_addr(%[[VAL_3:.*]] -> %{{.*}} : memref) use_device_ptr(%[[VAL_4:.*]] -> %{{.*}} : memref) %mapv2 = omp.map.info var_ptr(%map1 : memref, tensor) map_clauses(close, present, to) capture(ByRef) -> memref {name = ""} - omp.target_data use_device_ptr(%device_ptr : memref) 
use_device_addr(%device_addr : memref) map_entries(%mapv2 : memref) {} + omp.target_data map_entries(%mapv2 : memref) use_device_addr(%device_addr -> %arg0 : memref) use_device_ptr(%device_ptr -> %arg1 : memref) { + omp.terminator + } // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref, tensor) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} // CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%[[VAL_2:.*]] : memref, tensor) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref {name = ""} diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir index 458d2f28a78f8..654763c577d1a 100644 --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -210,8 +210,7 @@ llvm.func @_QPopenmp_target_use_dev_ptr() { %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr %map1 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} - omp.target_data map_entries(%map1 : !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr) { - ^bb0(%arg0: !llvm.ptr): + omp.target_data map_entries(%map1 : !llvm.ptr) use_device_ptr(%map2 -> %arg0 : !llvm.ptr) { %1 = llvm.mlir.constant(10 : i32) : i32 %2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr llvm.store %1, %2 : i32, !llvm.ptr @@ -255,8 +254,7 @@ llvm.func @_QPopenmp_target_use_dev_addr() { %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr %map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} - omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) { - ^bb0(%arg0: !llvm.ptr): + omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 -> %arg0 : !llvm.ptr) { %1 = 
llvm.mlir.constant(10 : i32) : i32 %2 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr llvm.store %1, %2 : i32, !llvm.ptr @@ -298,8 +296,7 @@ llvm.func @_QPopenmp_target_use_dev_addr_no_ptr() { %a = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr %map = omp.map.info var_ptr(%a : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} %map2 = omp.map.info var_ptr(%a : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} - omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) { - ^bb0(%arg0: !llvm.ptr): + omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 -> %arg0 : !llvm.ptr) { %1 = llvm.mlir.constant(10 : i32) : i32 llvm.store %1, %arg0 : i32, !llvm.ptr omp.terminator @@ -341,8 +338,7 @@ llvm.func @_QPopenmp_target_use_dev_addr_nomap() { %b = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr %map = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""} %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} - omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 : !llvm.ptr) { - ^bb0(%arg0: !llvm.ptr): + omp.target_data map_entries(%map : !llvm.ptr) use_device_addr(%map2 -> %arg0 : !llvm.ptr) { %2 = llvm.mlir.constant(10 : i32) : i32 %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr llvm.store %2, %3 : i32, !llvm.ptr @@ -400,13 +396,12 @@ llvm.func @_QPopenmp_target_use_dev_both() { %map1 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} %map2 = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} %map3 = omp.map.info var_ptr(%b : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} - omp.target_data map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_ptr(%map2 : !llvm.ptr) use_device_addr(%map3 : !llvm.ptr) { - ^bb0(%arg0: 
!llvm.ptr, %arg1: !llvm.ptr): + omp.target_data map_entries(%map, %map1 : !llvm.ptr, !llvm.ptr) use_device_addr(%map3 -> %arg0 : !llvm.ptr) use_device_ptr(%map2 -> %arg1 : !llvm.ptr) { %2 = llvm.mlir.constant(10 : i32) : i32 - %3 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr + %3 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr llvm.store %2, %3 : i32, !llvm.ptr %4 = llvm.mlir.constant(20 : i32) : i32 - %5 = llvm.load %arg1 : !llvm.ptr -> !llvm.ptr + %5 = llvm.load %arg0 : !llvm.ptr -> !llvm.ptr llvm.store %4, %5 : i32, !llvm.ptr omp.terminator } diff --git a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir index a4f8098879a9f..3a71778e7d0a7 100644 --- a/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir +++ b/mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir @@ -22,8 +22,7 @@ module attributes {omp.is_target_device = true } { %0 = llvm.mlir.constant(1 : i64) : i64 %a = llvm.alloca %0 x !llvm.ptr : (i64) -> !llvm.ptr %map = omp.map.info var_ptr(%a : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} - omp.target_data use_device_ptr(%map : !llvm.ptr) { - ^bb0(%arg0: !llvm.ptr): + omp.target_data use_device_ptr(%map -> %arg0 : !llvm.ptr) { %map1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""} omp.target map_entries(%map1 -> %arg1 : !llvm.ptr){ %1 = llvm.mlir.constant(999 : i32) : i32