diff --git a/tests/filecheck/transforms/memref-to-dsd.mlir b/tests/filecheck/transforms/memref-to-dsd.mlir index 9907c976c7..313d9c7a0a 100644 --- a/tests/filecheck/transforms/memref-to-dsd.mlir +++ b/tests/filecheck/transforms/memref-to-dsd.mlir @@ -106,6 +106,14 @@ builtin.module { // CHECK-NEXT: %28 = "csl.load_var"(%27) : (!csl.var>) -> !csl // CHECK-NEXT: "csl.store_var"(%27, %28) : (!csl.var>, !csl) -> () +// ensure that pre-existing get_mem_dsd ops access the underlying buffer, not the get_mem_dsd created on top of it + +%36 = arith.constant 510 : i16 +%37 = "csl.get_mem_dsd"(%b, %36) : (memref<510xf32>, i16) -> !csl + +// CHECK-NEXT: %29 = arith.constant 510 : i16 +// CHECK-NEXT: %30 = "csl.get_mem_dsd"(%b, %29) : (memref<510xf32>, i16) -> !csl + }) {sym_name = "program"} : () -> () } // CHECK-NEXT: }) {"sym_name" = "program"} : () -> () diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index 1d64c3b4f5..1c2d0fcc52 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -17,7 +17,7 @@ StridedLayoutAttr, UnrealizedConversionCastOp, ) -from xdsl.ir import Attribute, Operation, SSAValue +from xdsl.ir import Attribute, Operation, OpResult, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -64,6 +64,31 @@ def match_and_rewrite(self, op: memref.Alloc, rewriter: PatternRewriter, /): rewriter.replace_matched_op([zeros_op, *shape, dsd_op]) +class FixGetDsdOnGetDsd(RewritePattern): + """ + This rewrite pattern resolves GetMemDsdOp being called on GetMemDsdOp instead of the underlying buffer, + a side effect created by `LowerAllocOpPass` in case of pre-existing GetMemDsdOp ops being present in + the program that were created outside of this pass. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter, /): + if isinstance(op.base_addr.type, csl.DsdType): + if isinstance(op.base_addr, OpResult) and isinstance( + op.base_addr.op, csl.GetMemDsdOp + ): + rewriter.replace_matched_op( + csl.GetMemDsdOp.build( + operands=[op.base_addr.op.base_addr, op.sizes], + properties=op.properties, + attributes=op.attributes, + result_types=op.result_types, + ) + ) + else: + raise ValueError("Failed to resolve GetMemDsdOp called on dsd type") + + class LowerSubviewOpPass(RewritePattern): """Lowers memref.subview to dsd ops""" @@ -335,3 +360,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: apply_recursively=False, ) forward_pass.rewrite_module(op) + cleanup_pass = PatternRewriteWalker( + FixGetDsdOnGetDsd(), + ) + cleanup_pass.rewrite_module(op)