diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 33286258543e53..be34ef8bbd6258 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -318,6 +318,13 @@ fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, consumer.iterator_types(), /*doc=*/nullptr, /*library_call=*/nullptr); + if (!fusedOp.getShapesToLoopsMap()) { + // Fused op has invalid indexing maps. Typically this means something is off + // in the input, but going ahead here would result in verification errors. + // So cleanup and abort. + rewriter.eraseOp(fusedOp); + return llvm::None; + } // Construct an AffineMap from consumer loops to producer loops. // consumer loop -> tensor index diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index 6ae9e15543e1ca..3f68820b18cc73 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -945,3 +945,33 @@ func @no_fusion_missing_reduction_shape(%arg0: tensor, %arg1: index) -> ten } -> tensor return %8 : tensor } + +// ----- + +func @illegal_fusion(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi32>) -> tensor<5000xi32> { + %c1_i32 = arith.constant 1 : i32 + %0 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + outs(%arg0 : tensor<5000xi64>) { + ^bb0(%arg3: i64): // no predecessors + %22 = linalg.index 0 : index + %23 = arith.index_cast %22 : index to i64 + linalg.yield %23 : i64 + } -> tensor<5000xi64> + %1 = linalg.init_tensor [5000] : tensor<5000xi32> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0 : tensor<5000xi64>) outs(%1 : tensor<5000xi32>) { + ^bb0(%arg3: i64, %arg5: i32): // no predecessors + %22 = arith.index_cast %arg3 : i64 to index + %23 = tensor.extract %arg1[%22] : tensor<5000xi32> + linalg.yield %23 : i32 + } -> tensor<5000xi32> + return %2 : tensor<5000xi32> +} +// CHECK-LABEL: func @illegal_fusion( +// CHECK: %[[PRODUCER:.+]] = linalg.generic +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[PRODUCER]]