Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Translate memref to dsd #3092

Merged
merged 19 commits into from
Aug 28, 2024
3 changes: 3 additions & 0 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@

%4 = "csl.constants"(%inline_const, %inline_const) <{is_const}> : (i32, i32) -> memref<?xi32>

%5 = "csl.zeros"(%const27) <{is_const}> : (i16) -> memref<?xi16>

csl.return
}

Expand Down Expand Up @@ -408,6 +410,7 @@ csl.func @builtins() {
// CHECK-NEXT: const v1 : [const27]i16 = @constants([const27]i16, const27);
// CHECK-NEXT: const v2 : [const27]i32 = @constants([const27]i32, 100);
// CHECK-NEXT: const v3 : [100]i32 = @constants([100]i32, 100);
// CHECK-NEXT: const v4 : [const27]i16 = @zeros([const27]i16);
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-NEXT: {{ *}}
Expand Down
97 changes: 97 additions & 0 deletions tests/filecheck/transforms/memref-to-dsd.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// RUN: xdsl-opt %s -p memref-to-dsd | filecheck %s

builtin.module {
"csl.module"() <{"kind" = #csl<module_kind program>}> ({
// CHECK-NEXT: builtin.module {
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind program>}> ({

%0 = "test.op"() : () -> (index)
%a = memref.alloc() {"alignment" = 64 : i64} : memref<512xf32>
%b = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
%c = memref.alloc() {"alignment" = 64 : i64} : memref<1024xf32>
%d = memref.subview %a[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
%e = memref.subview %a[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
"csl.fadds"(%b, %d, %e) : (memref<510xf32>, memref<510xf32, strided<[1]>>, memref<510xf32, strided<[1], offset: 2>>) -> ()
%f = memref.subview %c[1] [510] [2] : memref<1024xf32> to memref<510xf32, strided<[2], offset: 1>>
"csl.fadds"(%b, %b, %f) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[2], offset: 1>>) -> ()

%1 = "csl.addressof"(%a) : (memref<512xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
%2 = "csl.addressof"(%b) : (memref<510xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
%3 = "csl.addressof"(%c) : (memref<1024xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
"csl.export"(%1) <{"var_name" = "a", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
"csl.export"(%2) <{"var_name" = "b", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
"csl.export"(%2) <{"var_name" = "c", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()

// CHECK-NEXT: %0 = "test.op"() : () -> index
// CHECK-NEXT: %a = "csl.zeros"() : () -> memref<512xf32>
// CHECK-NEXT: %a_1 = arith.constant 512 : i16
// CHECK-NEXT: %a_2 = "csl.get_mem_dsd"(%a, %a_1) : (memref<512xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %b = "csl.zeros"() : () -> memref<510xf32>
// CHECK-NEXT: %b_1 = arith.constant 510 : i16
// CHECK-NEXT: %b_2 = "csl.get_mem_dsd"(%b, %b_1) : (memref<510xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %c = "csl.zeros"() : () -> memref<1024xf32>
// CHECK-NEXT: %c_1 = arith.constant 1024 : i16
// CHECK-NEXT: %c_2 = "csl.get_mem_dsd"(%c, %c_1) : (memref<1024xf32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %d = arith.constant 510 : ui16
// CHECK-NEXT: %d_1 = "csl.set_dsd_length"(%a_2, %d) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %e = arith.constant 510 : ui16
// CHECK-NEXT: %e_1 = "csl.set_dsd_length"(%a_2, %e) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %e_2 = arith.constant 2 : si16
// CHECK-NEXT: %e_3 = "csl.increment_dsd_offset"(%e_1, %e_2) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.fadds"(%b_2, %d_1, %e_3) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()
// CHECK-NEXT: %f = arith.constant 510 : ui16
// CHECK-NEXT: %f_1 = "csl.set_dsd_length"(%c_2, %f) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %f_2 = arith.constant 2 : si8
// CHECK-NEXT: %f_3 = "csl.set_dsd_stride"(%f_1, %f_2) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %f_4 = arith.constant 1 : si16
// CHECK-NEXT: %f_5 = "csl.increment_dsd_offset"(%f_3, %f_4) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.fadds"(%b_2, %b_2, %f_5) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()
// CHECK-NEXT: %1 = "csl.addressof"(%a) : (memref<512xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: %2 = "csl.addressof"(%b) : (memref<510xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: %3 = "csl.addressof"(%c) : (memref<1024xf32>) -> !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>
// CHECK-NEXT: "csl.export"(%1) <{"var_name" = "a", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"(%2) <{"var_name" = "b", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"(%2) <{"var_name" = "c", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()


%23 = memref.alloc() {"alignment" = 64 : i64} : memref<10xi32>
%24 = memref.alloc() {"alignment" = 64 : i64} : memref<10xi32>
"memref.copy"(%23, %24) : (memref<10xi32>, memref<10xi32>) -> ()

// CHECK: %4 = "csl.zeros"() : () -> memref<10xi32>
// CHECK-NEXT: %5 = arith.constant 10 : i16
// CHECK-NEXT: %6 = "csl.get_mem_dsd"(%4, %5) : (memref<10xi32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %7 = "csl.zeros"() : () -> memref<10xi32>
// CHECK-NEXT: %8 = arith.constant 10 : i16
// CHECK-NEXT: %9 = "csl.get_mem_dsd"(%7, %8) : (memref<10xi32>, i16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "csl.mov32"(%9, %6) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()


%25 = arith.constant 0 : index
%26 = arith.constant 510 : index
%27 = arith.constant 1 : index
%28 = arith.constant 2 : index
%29 = memref.subview %b[%25] [%26] [%27] : memref<510xf32> to memref<510xf32, strided<[1]>>
%30 = memref.subview %c[%27] [%26] [%28] : memref<1024xf32> to memref<510xf32, strided<[2], offset: 1>>

// CHECK-NEXT: %10 = arith.constant 0 : index
// CHECK-NEXT: %11 = arith.constant 510 : index
// CHECK-NEXT: %12 = arith.constant 1 : index
// CHECK-NEXT: %13 = arith.constant 2 : index
// CHECK-NEXT: %14 = arith.index_cast %11 : index to ui16
// CHECK-NEXT: %15 = "csl.set_dsd_length"(%b_2, %14) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %16 = arith.index_cast %12 : index to si8
// CHECK-NEXT: %17 = "csl.set_dsd_stride"(%15, %16) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %18 = arith.index_cast %10 : index to si16
// CHECK-NEXT: %19 = "csl.increment_dsd_offset"(%17, %18) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %20 = arith.index_cast %11 : index to ui16
// CHECK-NEXT: %21 = "csl.set_dsd_length"(%c_2, %20) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %22 = arith.index_cast %13 : index to si8
// CHECK-NEXT: %23 = "csl.set_dsd_stride"(%21, %22) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %24 = arith.index_cast %12 : index to si16
// CHECK-NEXT: %25 = "csl.increment_dsd_offset"(%23, %24) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>

}) {sym_name = "program"} : () -> ()
}
// CHECK-NEXT: }) {"sym_name" = "program"} : () -> ()
// CHECK-NEXT: }
5 changes: 5 additions & 0 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,11 @@ def print_block(self, body: Block):
self._print_or_promote_to_inline_expr(
res, f"@concat_structs({a_var}, {b_var})"
)
case csl.ZerosOp(result=res, is_const=constness):
type = self._memref_type_to_string(res)
res_name = self._get_variable_name_for(res)
kind = "const" if constness else "var"
self.print(f"{kind} {res_name} : {type} = @zeros({type});")
case csl.ConstantsOp(value=val, result=res, is_const=constness):
type = self._memref_type_to_string(res)
res_name = self._get_variable_name_for(res)
Expand Down
30 changes: 30 additions & 0 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,35 @@ def verify_(self) -> None:
)


@irdl_op_definition
class ZerosOp(IRDLOperation):
"""
Represents the @zeros operation in CSL.
"""

name = "csl.zeros"

T = Annotated[IntegerType | Float32Type | Float16Type, ConstraintVar("T")]

size = opt_operand_def(T)

result = result_def(MemRefType[T])

is_const = opt_prop_def(builtin.UnitAttr)

def __init__(
self,
memref: MemRefType[T],
dynamic_size: SSAValue | Operation | None = None,
is_const: builtin.UnitAttr | None = None,
):
super().__init__(
operands=[dynamic_size] if dynamic_size else [[]],
result_types=[memref],
properties={"is_const": is_const} if is_const else {},
)


@irdl_op_definition
class ConstantsOp(IRDLOperation):
"""
Expand Down Expand Up @@ -1749,6 +1778,7 @@ def __init__(self, struct_a: Operation, struct_b: Operation):
Xor16Op,
Xp162fhOp,
Xp162fsOp,
ZerosOp,
],
[
ColorType,
Expand Down
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ def get_memref_stream_legalize():

return memref_stream_legalize.MemrefStreamLegalizePass

def get_memref_to_dsd():
from xdsl.transforms import memref_to_dsd

return memref_to_dsd.MemrefToDsdPass

def get_mlir_opt():
from xdsl.transforms import mlir_opt

Expand Down Expand Up @@ -429,6 +434,7 @@ def get_stencil_bufferize():
"memref-stream-interleave": get_memref_stream_interleave,
"memref-stream-tile-outer-loops": get_memref_stream_tile_outer_loops,
"memref-stream-legalize": get_memref_stream_legalize,
"memref-to-dsd": get_memref_to_dsd,
"mlir-opt": get_mlir_opt,
"printf-to-llvm": get_printf_to_llvm,
"printf-to-putchar": get_printf_to_putchar,
Expand Down
Loading
Loading