From a185594bb3a3fe8c330e6663f24dc9435292904a Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 7 Jun 2024 15:44:33 +0100 Subject: [PATCH 1/2] Implement memref.atomic_rmw --- .../filecheck/dialects/memref/memref_ops.mlir | 8 ++++++- xdsl/dialects/memref.py | 21 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/filecheck/dialects/memref/memref_ops.mlir b/tests/filecheck/dialects/memref/memref_ops.mlir index 903cd927b2..15e9bee4e9 100644 --- a/tests/filecheck/dialects/memref/memref_ops.mlir +++ b/tests/filecheck/dialects/memref/memref_ops.mlir @@ -39,12 +39,15 @@ builtin.module { memref.dealloc %8 : memref<1xindex> memref.dealloc %10 : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> memref.dealloc %11 : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> + %fmemref = "test.op"() : () -> memref<32x32xf32> + %e = "test.op"() : () -> f32 + %207 = "memref.atomic_rmw"(%e, %fmemref, %1, %1) <{kind = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32 func.return } } -// CHECK-NEXT: builtin.module { +// CHECK-NEXT: builtin.module { // CHECK-NEXT: func.func @memref_alloca_scope() { // CHECK-NEXT: "memref.alloca_scope"() ({ // CHECK-NEXT: "memref.alloca_scope.return"() : () -> () @@ -83,6 +86,9 @@ builtin.module { // CHECK-NEXT: memref.dealloc %{{.*}} : memref<1xindex> // CHECK-NEXT: memref.dealloc %{{.*}} : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> // CHECK-NEXT: memref.dealloc %{{.*}} : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32> +// CHECK-NEXT: %{{.*}} = "test.op"() : () -> memref<32x32xf32> +// CHECK-NEXT: %{{.*}} = "test.op"() : () -> f32 +// CHECK-NEXT: %{{.*}} = "memref.atomic_rmw"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{"kind" = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32 // CHECK-NEXT: func.return // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 9f0af468ec..3b4d18b347 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -6,7 +6,9 @@ from typing_extensions import Self from xdsl.dialects.builtin import ( + AnyFloat, AnyIntegerAttr, + AnySignlessIntegerType, ArrayAttr, BoolAttr, DenseArrayBase, @@ -341,6 +343,24 @@ def verify_(self) -> None: ) +@irdl_op_definition +class AtomicRMWOp(IRDLOperation): + name = "memref.atomic_rmw" + + T = Annotated[ + AnyFloat | AnySignlessIntegerType, + ConstraintVar("T"), + ] + + value = operand_def(T) + memref = operand_def(MemRefType[T]) + indices = var_operand_def(IndexType) + + kind = prop_def(IntegerAttr[i64]) + + result = result_def(T) + + @irdl_op_definition class Dealloc(IRDLOperation): name = "memref.dealloc" @@ -784,6 +804,7 @@ def verify_(self) -> None: Alloca, AllocaScopeOp, AllocaScopeReturnOp, + AtomicRMWOp, CopyOp, CollapseShapeOp, ExpandShapeOp, From b549fe93ef9762859da0de7474614e4be3b2da65 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Mon, 10 Jun 2024 10:13:42 +0100 Subject: [PATCH 2/2] Interop and tweaks. --- .../dialects/memref/memref_ops_mlir_conversion.mlir | 6 ++++++ xdsl/dialects/builtin.py | 3 +++ xdsl/dialects/memref.py | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/memref/memref_ops_mlir_conversion.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/memref/memref_ops_mlir_conversion.mlir index bdface6a51..b5c7554611 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/memref/memref_ops_mlir_conversion.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/memref/memref_ops_mlir_conversion.mlir @@ -33,6 +33,9 @@ memref<100xi32>, index, index ) -> () + %fmemref = memref.alloc() : memref<32x32xf32> + %e = arith.constant 1.0 : f32 + %207 = "memref.atomic_rmw"(%e, %fmemref, %1, %1) <{kind = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32 "func.return"() : () -> () }) {"sym_name" = "memref_test", "function_type" = () -> (), "sym_visibility" = "private"} : () -> () }) : () -> () @@ -61,6 +64,9 @@ // CHECK: "memref.dma_start"(%9, %1, %10, %1, %3, %11, %1) {"operandSegmentSizes" = array} : (memref<100xi32, 10 : i64>, index, memref<100xi32, 9 : i64>, index, index, memref<100xi32>, index) -> () // CHECK-NEXT: "memref.dma_wait"(%11, %1, %3) {"operandSegmentSizes" = array} : (memref<100xi32>, index, index) -> () +// CHECK-NEXT: %12 = "memref.alloc"() <{"operandSegmentSizes" = array}> : () -> memref<32x32xf32> +// CHECK-NEXT: %13 = "arith.constant"() <{"value" = 1.000000e+00 : f32}> : () -> f32 +// CHECK-NEXT: %14 = "memref.atomic_rmw"(%13, %12, %1, %1) <{"kind" = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32 // CHECK-NEXT: "func.return"() : () -> () // CHECK-NEXT: }) : () -> () // CHECK-NEXT: }) : () -> () diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 7e47c41b5e..d5d5525ac0 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -359,6 +359,9 @@ def value_range(self) -> tuple[int, int]: i64 = IntegerType(64) i32 = IntegerType(32) i1 = IntegerType(1) +I64 = Annotated[IntegerType, i64] +I32 = Annotated[IntegerType, i32] +I1 = Annotated[IntegerType, i1] SignlessIntegerConstraint = ParamAttrConstraint( diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 3b4d18b347..4dd12285ec 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -356,7 +356,7 @@ class AtomicRMWOp(IRDLOperation): memref = operand_def(MemRefType[T]) indices = var_operand_def(IndexType) - kind = prop_def(IntegerAttr[i64]) + kind = prop_def(IntegerAttr[Annotated[IntegerType, i64]]) result = result_def(T)