diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index ec3aaefa1d4431..6a223f1b105219 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -506,6 +506,71 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> { let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// 2.9.1 Canonical Loop +//===----------------------------------------------------------------------===// + +def CanonicalLoopOp : OpenMP_Op<"canonical_loop", [SameVariadicOperandSize, + AllTypesMatch<["lowerBound", "upperBound", "step"]>, + ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskLoopOp", + "WsLoopOp"]>, + RecursiveMemoryEffects]> { + let summary = "canonical loop"; + let description = [{ + All loops that conform to OpenMP's definition of a canonical loop can be + simplified to a CanonicalLoopOp. In particular, there are no loop-carried + variables and the number of iterations it will execute is known before the + operation. This allows e.g. to determine the number of threads and chunks + the iterations space is split into before executing any iteration. More + restrictions may apply in cases such as (collapsed) loop nests, doacross + loops, etc. + + The lower and upper bounds specify a half-open range: the range includes the + lower bound but does not include the upper bound. If the `inclusive` + attribute is specified then the upper bound is also included. + + The body region can contain any number of blocks. The region is terminated + by "omp.yield" instruction without operands. The induction variables, + represented as entry block arguments to the canonical loop operation's + single region, match the types of the `lowerBound`, `upperBound` and `step` + arguments. + + ```mlir + omp.canonical_loop (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) { + %a = load %arrA[%i1, %i2] : memref + %b = load %arrB[%i1, %i2] : memref + %sum = arith.addf %a, %b : f32 + store %sum, %arrC[%i1, %i2] : memref + omp.yield + } + ``` + + This is a temporary simplified definition of canonical loop based on + existing OpenMP loop operations intended to serve as a stopgap solution + until discussion over its long-term representation reaches a conclusion. + Specifically, this approach is not intended to help with the addition of + support for loop transformations. + }]; + + let arguments = (ins Variadic:$lowerBound, + Variadic:$upperBound, + Variadic:$step, + UnitAttr:$inclusive); + + let regions = (region AnyRegion:$region); + + let extraClassDeclaration = [{ + /// Returns the number of loops in the canonical loop nest. + unsigned getNumLoops() { return getLowerBound().size(); } + + /// Returns the induction variables of the canonical loop nest. + ArrayRef getIVs() { return getRegion().getArguments(); } + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // 2.9.2 Workshare Loop Construct //===----------------------------------------------------------------------===// @@ -714,7 +779,7 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments, def YieldOp : OpenMP_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["WsLoopOp", "ReductionDeclareOp", + ParentOneOf<["CanonicalLoopOp", "WsLoopOp", "ReductionDeclareOp", "AtomicUpdateOp", "SimdLoopOp", "PrivateClauseOp"]>]> { let summary = "loop yield and termination operation"; let description = [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 9c2862c568ed2f..a3443f70eda105 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1487,6 +1487,79 @@ LogicalResult SingleOp::verify() { getCopyprivateFuncs()); } +//===----------------------------------------------------------------------===// +// CanonicalLoopOp +//===----------------------------------------------------------------------===// + +ParseResult CanonicalLoopOp::parse(OpAsmParser &parser, + OperationState &result) { + // Parse an opening `(` followed by induction variables followed by `)` + SmallVector ivs; + SmallVector lbs, ubs; + Type loopVarType; + if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || + parser.parseColonType(loopVarType) || + // Parse loop bounds. + parser.parseEqual() || + parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) || + parser.parseKeyword("to") || + parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren)) + return failure(); + + for (auto &iv : ivs) + iv.type = loopVarType; + + // Parse "inclusive" flag. + if (succeeded(parser.parseOptionalKeyword("inclusive"))) + result.addAttribute("inclusive", + UnitAttr::get(parser.getBuilder().getContext())); + + // Parse step values. + SmallVector steps; + if (parser.parseKeyword("step") || + parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) + return failure(); + + // Parse the body. + Region *region = result.addRegion(); + if (parser.parseRegion(*region, ivs)) + return failure(); + + // Resolve operands. + if (parser.resolveOperands(lbs, loopVarType, result.operands) || + parser.resolveOperands(ubs, loopVarType, result.operands) || + parser.resolveOperands(steps, loopVarType, result.operands)) + return failure(); + + // Parse the optional attribute list. + return parser.parseOptionalAttrDict(result.attributes); +} + +void CanonicalLoopOp::print(OpAsmPrinter &p) { + Region ®ion = getRegion(); + auto args = region.getArguments(); + p << " (" << args << ") : " << args[0].getType() << " = (" << getLowerBound() + << ") to (" << getUpperBound() << ") "; + if (getInclusive()) + p << "inclusive "; + p << "step (" << getStep() << ") "; + p.printRegion(region, /*printEntryBlockArgs=*/false); +} + +LogicalResult CanonicalLoopOp::verify() { + if (getLowerBound().size() != getRegion().getNumArguments()) + return emitOpError() << "number of range arguments and IVs do not match"; + + for (auto [iv, lb] : + llvm::zip_equal(getLowerBound(), getRegion().getArguments())) { + if (lb.getType() != iv.getType()) + return emitOpError() + << "range argument type does not match corresponding IV type"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // WsLoopOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 9a3964a844a2ff..835c61246f5ac5 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -87,6 +87,41 @@ func.func @proc_bind_once() { // ----- +func.func @invalid_parent(%lb : index, %ub : index, %step : index) { + // expected-error@+1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}} + omp.canonical_loop (%iv) : index = (%lb) to (%ub) step (%step) { + omp.yield + } +} + +// ----- + +func.func @type_mismatch(%lb : index, %ub : index, %step : index) { + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // expected-error@+1 {{range argument type does not match corresponding IV type}} + "omp.canonical_loop" (%lb, %ub, %step) ({ + ^bb0(%iv2: i32): + omp.yield + }) : (index, index, index) -> () + omp.yield + } +} + +// ----- + +func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) { + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // expected-error@+1 {{number of range arguments and IVs do not match}} + "omp.canonical_loop" (%lb, %ub, %step) ({ + ^bb0(%iv1 : index, %iv2 : index): + omp.yield + }) : (index, index, index) -> () + omp.yield + } +} + +// ----- + func.func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) { // expected-error @below {{expected 'for'}} omp.wsloop nowait inclusive diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index c79659a4159f01..109b5cb62e01c0 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -133,6 +133,81 @@ func.func @omp_parallel_pretty(%data_var : memref, %if_cond : i1, %num_thre return } +// CHECK-LABEL: omp_canonical_loop +func.func @omp_canonical_loop(%lb : index, %ub : index, %step : index) -> () { + + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.canonical_loop + // CHECK-SAME: (%{{.*}}) : index = + // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}}) + "omp.canonical_loop" (%lb, %ub, %step) ({ + ^bb0(%iv2: index): + omp.yield + }) : (index, index, index) -> () + omp.yield + } + + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.canonical_loop + // CHECK-SAME: (%{{.*}}) : index = + // CHECK-SAME: (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) + "omp.canonical_loop" (%lb, %ub, %step) ({ + ^bb0(%iv2: index): + omp.yield + }) {inclusive} : (index, index, index) -> () + omp.yield + } + + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.canonical_loop + // CHECK-SAME: (%{{.*}}, %{{.*}}) : index = + // CHECK-SAME: (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) + "omp.canonical_loop" (%lb, %lb, %ub, %ub, %step, %step) ({ + ^bb0(%iv2: index, %iv3: index): + omp.yield + }) : (index, index, index, index, index, index) -> () + omp.yield + } + + return +} + +// CHECK-LABEL: omp_canonical_loop_pretty +func.func @omp_canonical_loop_pretty(%lb : index, %ub : index, %step : index) -> () { + + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.canonical_loop + // CHECK-SAME: (%{{.*}}) : index = + // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}}) + omp.canonical_loop (%iv2) : index = (%lb) to (%ub) step (%step) { + omp.yield + } + omp.yield + } + + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.canonical_loop + // CHECK-SAME: (%{{.*}}) : index = + // CHECK-SAME: (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}}) + omp.canonical_loop (%iv2) : index = (%lb) to (%ub) inclusive step (%step) { + omp.yield + } + omp.yield + } + + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + // CHECK: omp.canonical_loop + // CHECK-SAME: (%{{.*}}) : index = + // CHECK-SAME: (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}) + omp.canonical_loop (%iv2, %iv3) : index = (%lb, %lb) to (%ub, %ub) step (%step, %step) { + omp.yield + } + omp.yield + } + + return +} + // CHECK-LABEL: omp_wsloop func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref, %linear_var : i32, %chunk_var : i32) -> () {