Skip to content

Commit

Permalink
[MLIR] [Transforms] Let transform.structured.convert_to_loops retur…
Browse files Browse the repository at this point in the history
…n handles to loops (#83984)

This lets `transform.structured.convert_to_loops` return handles to the
generated loops, making this transformation more useful to use for
(transformation-)nesting purposes. This is modelled after SCFs
`transform.loop.forall_to_for` which returns handles to loops.

Introduced in commit aa2a96a, with a
note that they might move out of the `Linalg`-Dialect, but no reason
given for the non-return of handles. As far as I can see, this transform
always returns loops.
  • Loading branch information
lhunloh authored Mar 6, 2024
1 parent f0eb0c5 commit 47bc565
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1274,33 +1274,29 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
}];
}

//===----------------------------------------------------------------------===//
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//

def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
For operations that implement the `TilingInterface`, and implement
the `generateScalarImplementation` method, lowers the operation to
loops. This operation does not return any handles.
loops. The return handle points to all generated loops.
Fails if the payload ops cannot be lowered to loops.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs);
let results = (outs TransformHandleTypeInterface:$result);

let assemblyFormat = [{
$target attr-dict `:` type($target)
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::TilingInterface target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
$target attr-dict `:` functional-type(operands, results)
}];
}


//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 25 additions & 10 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2112,16 +2112,31 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
// ConvertToLoopsOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
transform::TransformRewriter &rewriter, TilingInterface target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
FailureOr<SmallVector<scf::ForOp>> loops =
scf::lowerToLoopsUsingSCFForOp(rewriter, target);
if (failed(loops))
return emitDefaultDefiniteFailure(target);
rewriter.eraseOp(target);
DiagnosedSilenceableFailure
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Operation *> loops;
for (Operation *target : state.getPayloadOps(getTarget())) {
auto tilingOp = dyn_cast<TilingInterface>(*target);
if (!target) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "expected the payload to implement TilingInterface";
diag.attachNote(target->getLoc()) << "payload op";
return diag;
}
rewriter.setInsertionPoint(target);
FailureOr<SmallVector<scf::ForOp>> generatedLoops =
scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
if (failed(generatedLoops))
return emitDefaultDefiniteFailure(target);
for (scf::ForOp &loop : *generatedLoops) {
loops.push_back(loop.getOperation());
}
rewriter.eraseOp(target);
}
results.set(cast<OpResult>(getResult()), loops);
return DiagnosedSilenceableFailure::success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %matmul : !transform.any_op
%0 = transform.structured.convert_to_loops %matmul
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand All @@ -37,6 +38,57 @@ module attributes {transform.with_named_sequence} {

// -----

func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
%arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : memref<?xf32>) {
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
outs(%arg2 : memref<?x?xf32>)
linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>)
outs(%arg4 : memref<?xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%linalg_ops = transform.structured.match interface{TilingInterface} in %arg1
: (!transform.any_op) -> !transform.any_op
%0 = transform.structured.convert_to_loops %linalg_ops
: (!transform.any_op) -> (!transform.any_op)
%1:5 = transform.split_handle %0
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK-LABEL: func @gemm
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]]
// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]]
// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]]
// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
// CHECK: memref.store %[[ADDF]], %[[ARG4]][%[[IV3]]]

// -----

func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
%arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
linalg.generic {
Expand Down Expand Up @@ -66,7 +118,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %generic : !transform.any_op
%0 = transform.structured.convert_to_loops %generic
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -111,7 +164,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %conv : !transform.any_op
%0 = transform.structured.convert_to_loops %conv
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -165,7 +219,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %pool : !transform.any_op
%0 = transform.structured.convert_to_loops %pool
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -216,7 +271,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%map = transform.structured.match ops{["linalg.map"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %map : !transform.any_op
%0 = transform.structured.convert_to_loops %map
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -248,7 +304,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %transpose : !transform.any_op
%0 = transform.structured.convert_to_loops %transpose
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -285,7 +342,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %reduce : !transform.any_op
%0 = transform.structured.convert_to_loops %reduce
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -322,7 +380,8 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
: (!transform.any_op) -> !transform.any_op
transform.structured.convert_to_loops %broadcast : !transform.any_op
%0 = transform.structured.convert_to_loops %broadcast
: (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}
Expand Down

0 comments on commit 47bc565

Please sign in to comment.