From 6b65d79fbb4682468333cea42b62f15c2dffd8f3 Mon Sep 17 00:00:00 2001
From: Spenser Bauman
Date: Mon, 1 Jan 2024 12:12:40 -0500
Subject: [PATCH] [mlir][linalg] Fix for invalid IR in eliminate_empty_tensors
 (#73513)

The transform.structured.eliminate_empty_tensors transform can produce
mis-typed IR when traversing use-def chains past tensor reshaping
operations in search of sharing candidates. This results in Linalg
operations whose output types do not match their 'outs' arguments.

This patch filters out candidate tensor.empty operations when their
types do not match the candidate input operand.
---
 .../Transforms/EliminateEmptyTensors.cpp      |  5 +-
 ...ot-bufferize-empty-tensor-elimination.mlir | 86 +++++++++++++++++++
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 5a8320bdb2875..f28f8f0d34a4d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -60,7 +60,10 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
     config.alwaysIncludeLeaves = false;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
         in->get(), /*condition=*/
-        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
+        [&](Value val) {
+          return val.getDefiningOp<tensor::EmptyOp>() &&
+                 val.getType() == in->get().getType();
+        },
         config);
     if (emptyTensors.empty())
       continue;
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
index 0172760576efc..761b75d818373 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -42,3 +42,89 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+// This test checks that the produced IR does not contain any type errors
+// from sharing empty tensor operations with different types. The verifiers
+// are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
+func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<56xf32>
+{
+  %init0 = tensor.empty() : tensor<56xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The collapse shape alters the tensor rank, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0, 1]]
+    : tensor<56x1xf32> into tensor<56xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel"]
+  } ins(%reshaped_tensor : tensor<56xf32>)
+    outs(%init0 : tensor<56xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<56xf32>
+
+  return %bias : tensor<56xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+
+// This test checks that the produced IR does not contain any type errors
+// from sharing empty tensor operations with different types. The verifiers
+// are sufficient to lock down the intended behavior.
+
+// CHECK-LABEL: func.func @collapse_cast_prevents_reuse(
+func.func @collapse_cast_prevents_reuse(%fill_value: f32) -> tensor<56x?xf32>
+{
+  %c1 = arith.constant 1 : index
+  %init0 = tensor.empty(%c1) : tensor<56x?xf32>
+  %init1 = tensor.empty() : tensor<56x1xf32>
+
+  %filled_tensor = linalg.fill
+    ins(%fill_value : f32)
+    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>
+
+  // The cast alters the number of dynamic dims, so the %init1 tensor.empty cannot be
+  // pushed into the output of the linalg.generic.
+  %cast = tensor.cast %filled_tensor : tensor<56x1xf32> to tensor<56x?xf32>
+
+  %bias = linalg.generic {
+    indexing_maps = [#map, #map],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%cast : tensor<56x?xf32>)
+    outs(%init0 : tensor<56x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<56x?xf32>
+
+  return %bias : tensor<56x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
+    transform.yield
+  }
+}
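
Reviewer note: to illustrate the failure mode, here is a hand-written sketch
(assuming the pre-patch sharing behavior described in the commit message; this
is not actual pass output, and %shared is a hypothetical value). If %init1 is
shared through the tensor.collapse_shape in @collapse_shape_prevents_reuse,
the producing Linalg op is left with an 'outs' operand whose type no longer
matches its declared result type:

  // Hypothetical mis-typed IR (sketch only): the rank-1 destination %shared is
  // substituted for the rank-2 %init1, so the 'outs' type disagrees with the
  // declared tensor<56x1xf32> result and the linalg.fill fails verification.
  %filled_tensor = linalg.fill
      ins(%fill_value : f32)
      outs(%shared : tensor<56xf32>) -> tensor<56x1xf32>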