From abebdd435bf0d4e89bf25682445acbec46b871ac Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 10 Oct 2024 12:27:06 +0200 Subject: [PATCH 1/2] transformations: (memref-to-dsd) Support pre-existing GetMemDsd ops --- tests/filecheck/transforms/memref-to-dsd.mlir | 8 +++++ xdsl/transforms/memref_to_dsd.py | 31 ++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/filecheck/transforms/memref-to-dsd.mlir b/tests/filecheck/transforms/memref-to-dsd.mlir index 9907c976c7..313d9c7a0a 100644 --- a/tests/filecheck/transforms/memref-to-dsd.mlir +++ b/tests/filecheck/transforms/memref-to-dsd.mlir @@ -106,6 +106,14 @@ builtin.module { // CHECK-NEXT: %28 = "csl.load_var"(%27) : (!csl.var>) -> !csl // CHECK-NEXT: "csl.store_var"(%27, %28) : (!csl.var>, !csl) -> () +// ensure that pre-existing get_mem_dsd ops access the underlying buffer, not the get_mem_dsd created on top of it + +%36 = arith.constant 510 : i16 +%37 = "csl.get_mem_dsd"(%b, %36) : (memref<510xf32>, i16) -> !csl + +// CHECK-NEXT: %29 = arith.constant 510 : i16 +// CHECK-NEXT: %30 = "csl.get_mem_dsd"(%b, %29) : (memref<510xf32>, i16) -> !csl + }) {sym_name = "program"} : () -> () } // CHECK-NEXT: }) {"sym_name" = "program"} : () -> () diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index 1d64c3b4f5..7fa73c333f 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -17,7 +17,7 @@ StridedLayoutAttr, UnrealizedConversionCastOp, ) -from xdsl.ir import Attribute, Operation, SSAValue +from xdsl.ir import Attribute, Operation, OpResult, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -64,6 +64,31 @@ def match_and_rewrite(self, op: memref.Alloc, rewriter: PatternRewriter, /): rewriter.replace_matched_op([zeros_op, *shape, dsd_op]) +class FixGetDsdOnGetDsd(RewritePattern): + """ + This rewrite pattern resolves @get_dsd being called on @get_dsd instead of the underlying buffer, + a side effect created by `LowerAllocOpPass` in case of pre-existing @get_dsd ops being present in + the program that were created outside of this pass. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter, /): + if isinstance(op.base_addr.type, csl.DsdType): + if isinstance(op.base_addr, OpResult) and isinstance( + op.base_addr.op, csl.GetMemDsdOp + ): + rewriter.replace_matched_op( + csl.GetMemDsdOp.build( + operands=[op.base_addr.op.base_addr, op.sizes], + properties=op.properties, + attributes=op.attributes, + result_types=op.result_types, + ) + ) + else: + raise ValueError("Failed to resolve GetMemDsdOp called on dsd type") + + class LowerSubviewOpPass(RewritePattern): """Lowers memref.subview to dsd ops""" @@ -335,3 +360,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: apply_recursively=False, ) forward_pass.rewrite_module(op) + cleanup_pass = PatternRewriteWalker( + FixGetDsdOnGetDsd(), + ) + cleanup_pass.rewrite_module(op) From 5a795bea47e561fff4d0fdd1e16f86f19c3fe74a Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 10 Oct 2024 12:30:41 +0200 Subject: [PATCH 2/2] Updating docstring --- xdsl/transforms/memref_to_dsd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xdsl/transforms/memref_to_dsd.py b/xdsl/transforms/memref_to_dsd.py index 7fa73c333f..1c2d0fcc52 100644 --- a/xdsl/transforms/memref_to_dsd.py +++ b/xdsl/transforms/memref_to_dsd.py @@ -66,8 +66,8 @@ def match_and_rewrite(self, op: memref.Alloc, rewriter: PatternRewriter, /): class FixGetDsdOnGetDsd(RewritePattern): """ - This rewrite pattern resolves @get_dsd being called on @get_dsd instead of the underlying buffer, - a side effect created by `LowerAllocOpPass` in case of pre-existing @get_dsd ops being present in + This rewrite pattern resolves GetMemDsdOp being called on GetMemDsdOp instead of the underlying buffer, + a side effect created by `LowerAllocOpPass` in case of pre-existing GetMemDsdOp ops being present in the program that were created outside of this pass. """