Skip to content

Commit

Permalink
[MLIR][OpenMP] Use map format to represent use_device_{addr,ptr} (llv…
Browse files Browse the repository at this point in the history
…m#109810)

This patch updates the `omp.target_data` operation to use the same
formatting as `map` clauses on `omp.target` for `use_device_addr` and
`use_device_ptr`. This is done so the mapping that is being enforced
between op arguments and associated entry block arguments is explicit.

The way it is achieved is by marking these clauses as entry block
argument-defining and adjusting printer/parsers accordingly.

As a result of this change, block arguments for `use_device_addr` come
before those for `use_device_ptr`, which is the opposite of the previous
undocumented situation. Some unit tests are updated based on this
change, in addition to those updated because of the format change.
  • Loading branch information
skatrak authored Oct 1, 2024
1 parent c66dee4 commit 5894d4e
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 63 deletions.
5 changes: 3 additions & 2 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -429,13 +429,14 @@ func.func @_QPopenmp_target_data_region() {

func.func @_QPomp_target_data_empty() {
%0 = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_data_emptyEa"}
omp.target_data use_device_addr(%0 : !fir.ref<!fir.array<1024xi32>>) {
omp.target_data use_device_addr(%0 -> %arg0 : !fir.ref<!fir.array<1024xi32>>) {
omp.terminator
}
return
}

// CHECK-LABEL: llvm.func @_QPomp_target_data_empty
// CHECK: omp.target_data use_device_addr(%1 : !llvm.ptr) {
// CHECK: omp.target_data use_device_addr(%1 -> %{{.*}} : !llvm.ptr) {
// CHECK: }

// -----
Expand Down
6 changes: 2 additions & 4 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -506,9 +506,8 @@ subroutine omp_target_device_ptr
type(c_ptr) :: a
integer, target :: b
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}}) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"}
!CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}})
!CHECK: omp.target_data map_entries(%[[MAP]]{{.*}}) use_device_ptr({{.*}} -> %[[VAL_1:.*]] : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>)
!$omp target data map(tofrom: a) use_device_ptr(a)
!CHECK: ^bb0(%[[VAL_1:.*]]: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>):
!CHECK: {{.*}} = fir.coordinate_of %[[VAL_1:.*]], {{.*}} : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
a = c_loc(b)
!CHECK: omp.terminator
Expand All @@ -529,9 +528,8 @@ subroutine omp_target_device_addr
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]], %[[DEV_ADDR]] : {{.*}}) {
!CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
!$omp target data map(tofrom: a) use_device_addr(a)
!CHECK: ^bb0(%[[ARG_0:.*]]: !fir.llvm_ptr<!fir.ref<i32>>, %[[ARG_1:.*]]: !fir.ref<!fir.box<!fir.ptr<i32>>>):
!CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
!CHECK: %[[C10:.*]] = arith.constant 10 : i32
!CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
Expand Down
12 changes: 4 additions & 8 deletions flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
! use_device_ptr to use_device_addr works, without breaking any functionality.

!CHECK: func.func @{{.*}}only_use_device_ptr()
!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine only_use_device_ptr
use iso_c_binding
integer, pointer, dimension(:) :: array
Expand All @@ -19,8 +18,7 @@ subroutine only_use_device_ptr
end subroutine

!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
Expand All @@ -32,8 +30,7 @@ subroutine mix_use_device_ptr_and_addr
end subroutine

!CHECK: func.func @{{.*}}only_use_device_addr()
!CHECK: omp.target_data use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
subroutine only_use_device_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
Expand All @@ -45,8 +42,7 @@ subroutine only_use_device_addr
end subroutine

!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
!CHECK: ^bb0(%{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, %{{.*}}: !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, %{{.*}}: !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, %{{.*}}: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>):
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr_and_map
use iso_c_binding
integer :: i, j
Expand Down
28 changes: 24 additions & 4 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -1209,18 +1209,28 @@ class OpenMP_UseDeviceAddrClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
BlockArgOpenMPOpInterface
];

let arguments = (ins
Variadic<OpenMP_PointerLikeType>:$use_device_addr_vars
);

let optAssemblyFormat = [{
`use_device_addr` `(` $use_device_addr_vars `:` type($use_device_addr_vars) `)`
let extraClassDeclaration = [{
unsigned numUseDeviceAddrBlockArgs() {
return getUseDeviceAddrVars().size();
}
}];

let description = [{
The optional `use_device_addr_vars` specifies the address of the objects in
the device data environment.
}];

// Assembly format not defined because this clause must be processed together
// with the first region of the operation, as it defines entry block
// arguments.
}

def OpenMP_UseDeviceAddrClause : OpenMP_UseDeviceAddrClauseSkip<>;
Expand All @@ -1234,18 +1244,28 @@ class OpenMP_UseDevicePtrClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
BlockArgOpenMPOpInterface
];

let arguments = (ins
Variadic<OpenMP_PointerLikeType>:$use_device_ptr_vars
);

let optAssemblyFormat = [{
`use_device_ptr` `(` $use_device_ptr_vars `:` type($use_device_ptr_vars) `)`
let extraClassDeclaration = [{
unsigned numUseDevicePtrBlockArgs() {
return getUseDevicePtrVars().size();
}
}];

let description = [{
The optional `use_device_ptr_vars` specifies the device pointers to the
corresponding list items in the device data environment.
}];

// Assembly format not defined because this clause must be processed together
// with the first region of the operation, as it defines entry block
// arguments.
}

def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,12 @@ def TargetDataOp: OpenMP_Op<"target_data", traits = [
OpBuilder<(ins CArg<"const TargetDataOperands &">:$clauses)>
];

let assemblyFormat = clausesAssemblyFormat # [{
custom<UseDeviceAddrUseDevicePtrRegion>(
$region, $use_device_addr_vars, type($use_device_addr_vars),
$use_device_ptr_vars, type($use_device_ptr_vars)) attr-dict
}];

let hasVerifier = 1;
}

Expand Down
37 changes: 36 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
"unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.",
"unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.",
"unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{
return 0;
}]>,

// Unified access methods for clause-associated entry block arguments.
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
Expand Down Expand Up @@ -72,6 +80,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.",
"unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.",
"unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
}]>,

InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
Expand Down Expand Up @@ -109,13 +127,30 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
iface.getTaskReductionBlockArgsStart(),
$_op.numTaskReductionBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `use_device_addr`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getUseDeviceAddrBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getUseDeviceAddrBlockArgsStart(),
$_op.numUseDeviceAddrBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `use_device_ptr`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getUseDevicePtrBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getUseDevicePtrBlockArgsStart(),
$_op.numUseDevicePtrBlockArgs());
}]>,
];

let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
unsigned expectedArgs = iface.numInReductionBlockArgs() +
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
Expand Down
43 changes: 43 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ struct AllRegionParseArgs {
std::optional<PrivateParseArgs> privateArgs;
std::optional<ReductionParseArgs> reductionArgs;
std::optional<ReductionParseArgs> taskReductionArgs;
std::optional<MapParseArgs> useDeviceAddrArgs;
std::optional<MapParseArgs> useDevicePtrArgs;
};
} // namespace

Expand Down Expand Up @@ -648,6 +650,16 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
return parser.emitError(parser.getCurrentLocation())
<< "invalid `task_reduction` format";

if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
args.useDeviceAddrArgs)))
return parser.emitError(parser.getCurrentLocation())
<< "invalid `use_device_addr` format";

if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
args.useDevicePtrArgs)))
return parser.emitError(parser.getCurrentLocation())
<< "invalid `use_device_addr` format";

return parser.parseRegion(region, entryBlockArgs);
}

Expand Down Expand Up @@ -735,6 +747,18 @@ static ParseResult parseTaskReductionRegion(
return parseBlockArgRegion(parser, region, args);
}

static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
SmallVectorImpl<Type> &useDeviceAddrTypes,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
SmallVectorImpl<Type> &useDevicePtrTypes) {
AllRegionParseArgs args;
args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
return parseBlockArgRegion(parser, region, args);
}

//===----------------------------------------------------------------------===//
// Printers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -767,6 +791,8 @@ struct AllRegionPrintArgs {
std::optional<PrivatePrintArgs> privateArgs;
std::optional<ReductionPrintArgs> reductionArgs;
std::optional<ReductionPrintArgs> taskReductionArgs;
std::optional<MapPrintArgs> useDeviceAddrArgs;
std::optional<MapPrintArgs> useDevicePtrArgs;
};
} // namespace

Expand Down Expand Up @@ -849,6 +875,11 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
printBlockArgClause(p, ctx, "task_reduction",
iface.getTaskReductionBlockArgs(),
args.taskReductionArgs);
printBlockArgClause(p, ctx, "use_device_addr",
iface.getUseDeviceAddrBlockArgs(),
args.useDeviceAddrArgs);
printBlockArgClause(p, ctx, "use_device_ptr",
iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);

p.printRegion(region, /*printEntryBlockArgs=*/false);
}
Expand Down Expand Up @@ -925,6 +956,18 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
printBlockArgRegion(p, op, region, args);
}

static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
Region &region,
ValueRange useDeviceAddrVars,
TypeRange useDeviceAddrTypes,
ValueRange useDevicePtrVars,
TypeRange useDevicePtrTypes) {
AllRegionPrintArgs args;
args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
printBlockArgRegion(p, op, region, args);
}

/// Verifies Reduction Clause
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
Expand Down
Loading

0 comments on commit 5894d4e

Please sign in to comment.