-
Notifications
You must be signed in to change notification settings - Fork 12.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][Transforms] Dialect conversion: Fix missing source materializa…
…tion (#97903) This commit fixes a bug in the dialect conversion. During a 1:N signature conversion, the dialect conversion did not insert a cast back to the original block argument type, producing invalid IR. See `test-block-legalization.mlir`: Without this commit, the operand type of the op changes because an `unrealized_conversion_cast` is missing: ``` "test.consumer_of_complex"(%v) : (!llvm.struct<(f64, f64)>) -> () ``` To implement this fix, it was necessary to change the meaning of argument materializations. An argument materialization now maps from the new block argument types to the original block argument type. (It now behaves almost like a source materialization.) This also addresses a `FIXME` in the code base: ``` // FIXME: The current argument materialization hook expects the original // output type, even though it doesn't use that as the actual output type // of the generated IR. The output type is just used as an indicator of // the type of materialization to do. This behavior is really awkward in // that it diverges from the behavior of the other hooks, and can be // easily misunderstood. We should clean up the argument hooks to better // represent the desired invariants we actually care about. ``` It is no longer necessary to distinguish between the "output type" and the "original output type". Most type converter are already written according to the new API. (Most implementations use the same conversion functions as for source materializations.) One exception is the MemRef-to-LLVM type converter, which materialized an `!llvm.struct` based on the elements of a memref descriptor. It still does that, but casts the `!llvm.struct` back to the original memref type. The dialect conversion inserts a target materialization (to `!llvm.struct`) which cancels out with the other cast. This commit also fixes a bug in `computeNecessaryMaterializations`. The implementation did not account for the possibility that a value was replaced multiple times. E.g., replace `a` by `b`, then `b` by `c`. This commit also adds a transform dialect op to populate SCF-to-CF patterns. This transform op was needed to write a test case. The bug described here appears only during a complex interplay of 1:N signature conversions and op replacements. (I was not able to trigger it with ops and patterns from the `test` dialect without duplicating the `scf.if` pattern.) Note for LLVM integration: Make sure that all `addArgument/Source/TargetMaterialization` functions produce an SSA of the specified type. Depends on #98743.
- Loading branch information
1 parent
dd7d81e
commit acc159a
Showing
9 changed files
with
141 additions
and
60 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s | ||
|
||
// CHECK-LABEL: func @complex_block_signature_conversion( | ||
// CHECK: %[[cst:.*]] = complex.constant | ||
// CHECK: %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)> | ||
// Note: Some blocks are omitted. | ||
// CHECK: llvm.br ^[[block1:.*]](%[[complex_llvm]] | ||
// CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>): | ||
// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64> | ||
// CHECK: llvm.br ^[[block2:.*]] | ||
// CHECK: ^[[block2]]: | ||
// CHECK: "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> () | ||
func.func @complex_block_signature_conversion() { | ||
%cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64> | ||
%true = arith.constant true | ||
%0 = scf.if %true -> complex<f64> { | ||
scf.yield %cst : complex<f64> | ||
} else { | ||
scf.yield %cst : complex<f64> | ||
} | ||
|
||
// Regression test to ensure that the a source materialization is inserted. | ||
// The operand of "test.consumer_of_complex" must not change. | ||
"test.consumer_of_complex"(%0) : (complex<f64>) -> () | ||
return | ||
} | ||
|
||
module attributes {transform.with_named_sequence} { | ||
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) { | ||
%func = transform.structured.match ops{["func.func"]} in %toplevel_module | ||
: (!transform.any_op) -> !transform.any_op | ||
transform.apply_conversion_patterns to %func { | ||
transform.apply_conversion_patterns.dialect_to_llvm "cf" | ||
transform.apply_conversion_patterns.func.func_to_llvm | ||
transform.apply_conversion_patterns.scf.scf_to_control_flow | ||
} with type_converter { | ||
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter | ||
} { | ||
legal_dialects = ["llvm"], | ||
partial_conversion | ||
} : !transform.any_op | ||
transform.yield | ||
} | ||
} |