Skip to content

Commit

Permalink
[mlir][Linalg] Avoid generating illegal operations during elementwise…
Browse files Browse the repository at this point in the history
… fusion.

In some cases, fusion can produce illegal operations if after fusion
the range of some of the loops cannot be computed from shapes of its
operands. Check for this case and abort the fusion if this happens.

Differential Revision: https://reviews.llvm.org/D117602
  • Loading branch information
MaheshRavishankar committed Jan 21, 2022
1 parent e6de53b commit a99e06a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,13 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr);
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
// So cleanup and abort.
rewriter.eraseOp(fusedOp);
return llvm::None;
}

// Construct an AffineMap from consumer loops to producer loops.
// consumer loop -> tensor index
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -945,3 +945,33 @@ func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> ten
} -> tensor<?xf32>
return %8 : tensor<?xf32>
}

// -----

func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> {
%c1_i32 = arith.constant 1 : i32
%0 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
outs(%arg0 : tensor<5000xi64>) {
^bb0(%arg3: i64): // no predecessors
%22 = linalg.index 0 : index
%23 = arith.index_cast %22 : index to i64
linalg.yield %23 : i64
} -> tensor<5000xi64>
%1 = linalg.init_tensor [5000] : tensor<5000xi32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%0 : tensor<5000xi64>) outs(%1 : tensor<5000xi32>) {
^bb0(%arg3: i64, %arg5: i32): // no predecessors
%22 = arith.index_cast %arg3 : i64 to index
%23 = tensor.extract %arg1[%22] : tensor<5000xi32>
linalg.yield %23 : i32
} -> tensor<5000xi32>
return %2 : tensor<5000xi32>
}
// CHECK-LABEL: func @illegal_fusion(
// CHECK: %[[PRODUCER:.+]] = linalg.generic
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[PRODUCER]]

0 comments on commit a99e06a

Please sign in to comment.