Skip to content

Commit

Permalink
dialects: Define memref.atomic rmw (#2702)
Browse files Browse the repository at this point in the history
Also add minor helpers to builtin: `I1`,`I32`,`I64`, to use in type
expressions.
  • Loading branch information
PapyChacal authored Jun 10, 2024
1 parent 37ebf90 commit 9cede3c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tests/filecheck/dialects/memref/memref_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@ builtin.module {
memref.dealloc %8 : memref<1xindex>
memref.dealloc %10 : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
memref.dealloc %11 : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
%fmemref = "test.op"() : () -> memref<32x32xf32>
%e = "test.op"() : () -> f32
%207 = "memref.atomic_rmw"(%e, %fmemref, %1, %1) <{kind = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32

func.return
}
}

// CHECK-NEXT: builtin.module {
// CHECK-NEXT: builtin.module {
// CHECK-NEXT: func.func @memref_alloca_scope() {
// CHECK-NEXT: "memref.alloca_scope"() ({
// CHECK-NEXT: "memref.alloca_scope.return"() : () -> ()
Expand Down Expand Up @@ -83,6 +86,9 @@ builtin.module {
// CHECK-NEXT: memref.dealloc %{{.*}} : memref<1xindex>
// CHECK-NEXT: memref.dealloc %{{.*}} : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
// CHECK-NEXT: memref.dealloc %{{.*}} : memref<64x64xindex, strided<[2, 4], offset: 6>, 2 : i32>
// CHECK-NEXT: %{{.*}} = "test.op"() : () -> memref<32x32xf32>
// CHECK-NEXT: %{{.*}} = "test.op"() : () -> f32
// CHECK-NEXT: %{{.*}} = "memref.atomic_rmw"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{"kind" = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
memref<100xi32>, index,
index
) -> ()
%fmemref = memref.alloc() : memref<32x32xf32>
%e = arith.constant 1.0 : f32
%207 = "memref.atomic_rmw"(%e, %fmemref, %1, %1) <{kind = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32
"func.return"() : () -> ()
}) {"sym_name" = "memref_test", "function_type" = () -> (), "sym_visibility" = "private"} : () -> ()
}) : () -> ()
Expand Down Expand Up @@ -61,6 +64,9 @@

// CHECK: "memref.dma_start"(%9, %1, %10, %1, %3, %11, %1) {"operandSegmentSizes" = array<i32: 1, 1, 1, 1, 1, 1, 1>} : (memref<100xi32, 10 : i64>, index, memref<100xi32, 9 : i64>, index, index, memref<100xi32>, index) -> ()
// CHECK-NEXT: "memref.dma_wait"(%11, %1, %3) {"operandSegmentSizes" = array<i32: 1, 1, 1>} : (memref<100xi32>, index, index) -> ()
// CHECK-NEXT: %12 = "memref.alloc"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<32x32xf32>
// CHECK-NEXT: %13 = "arith.constant"() <{"value" = 1.000000e+00 : f32}> : () -> f32
// CHECK-NEXT: %14 = "memref.atomic_rmw"(%13, %12, %1, %1) <{"kind" = 0 : i64}> : (f32, memref<32x32xf32>, index, index) -> f32
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }) : () -> ()
// CHECK-NEXT: }) : () -> ()
3 changes: 3 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ def value_range(self) -> tuple[int, int]:
i64 = IntegerType(64)
i32 = IntegerType(32)
i1 = IntegerType(1)
I64 = Annotated[IntegerType, i64]
I32 = Annotated[IntegerType, i32]
I1 = Annotated[IntegerType, i1]


SignlessIntegerConstraint = ParamAttrConstraint(
Expand Down
21 changes: 21 additions & 0 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from typing_extensions import Self

from xdsl.dialects.builtin import (
AnyFloat,
AnyIntegerAttr,
AnySignlessIntegerType,
ArrayAttr,
BoolAttr,
DenseArrayBase,
Expand Down Expand Up @@ -341,6 +343,24 @@ def verify_(self) -> None:
)


@irdl_op_definition
class AtomicRMWOp(IRDLOperation):
name = "memref.atomic_rmw"

T = Annotated[
AnyFloat | AnySignlessIntegerType,
ConstraintVar("T"),
]

value = operand_def(T)
memref = operand_def(MemRefType[T])
indices = var_operand_def(IndexType)

kind = prop_def(IntegerAttr[Annotated[IntegerType, i64]])

result = result_def(T)


@irdl_op_definition
class Dealloc(IRDLOperation):
name = "memref.dealloc"
Expand Down Expand Up @@ -784,6 +804,7 @@ def verify_(self) -> None:
Alloca,
AllocaScopeOp,
AllocaScopeReturnOp,
AtomicRMWOp,
CopyOp,
CollapseShapeOp,
ExpandShapeOp,
Expand Down

0 comments on commit 9cede3c

Please sign in to comment.