Skip to content

Commit

Permalink
Merge branch 'main' into nicolai/memref-to-dsds
Browse files Browse the repository at this point in the history
  • Loading branch information
n-io committed Aug 26, 2024
2 parents 05be80a + ee743a2 commit 1ee9981
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 38 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ dev = [
"nbval<0.12",
"filecheck==1.0.0",
"lit<19.0.0",
"marimo==0.7.20",
"marimo==0.8.3",
"pre-commit==3.8.0",
"ruff==0.6.0",
"ruff==0.6.1",
"asv<0.7",
"nbconvert>=7.7.2,<8.0.0",
"textual-dev==1.5.1",
"pytest-asyncio==0.23.8",
"pytest-asyncio==0.24.0",
"pyright==1.1.345",
]
gui = ["textual==0.76.0", "pyclip==0.7"]
gui = ["textual==0.77.0", "pyclip==0.7"]
jax = ["jax==0.4.31", "numpy==2.1.0"]
onnx = ["onnx==1.16.2", "numpy==2.1.0"]
riscv = ["riscemu==2.2.7"]
Expand Down
35 changes: 35 additions & 0 deletions tests/dialects/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,38 @@ def test_split_handle():
""" %0, %1 = "transform.split_handle"(%2) <{"pass_through_empty_handle" = true, "fail_on_payload_too_small" = true, "overflow_result" = 1 : i64}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op) """,
None,
)


def test_amount_of_loops():
block = Block(
arg_types=[
transform.AnyValueType(),
transform.OperationType("linalg.matmul"),
]
)

target = block.args[0]
static_sizes = DenseArrayBase.create_dense_int_or_index(IndexType(), [8, 0])

assert_print_op(
transform.TileOp(
target=target,
dynamic_sizes=[],
static_sizes=static_sizes,
),
"""%0, %1 = "transform.structured.tile_using_for"(%2) <{"static_sizes" = array<index: 8, 0>}> : (!transform.any_value) -> (!transform.any_op, !transform.any_op)""",
None,
)


def test_structured_match():
handle = test.TestOp(result_types=[transform.AnyOpType()]).results[0]
assert_print_op(
transform.MatchOp(
target=handle,
ops=[],
op_attrs={},
),
""" %0 = "transform.structured.match"(%1) <{"ops" = [], "op_attrs" = {}}> : (!transform.any_op) -> !transform.any_op """,
None,
)
14 changes: 14 additions & 0 deletions tests/filecheck/dialects/stablehlo/attrs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: XDSL_ROUNDTRIP

"test.op"() {
default = #stablehlo<precision DEFAULT>,
high = #stablehlo<precision HIGH>,
highest = #stablehlo<precision HIGHEST>
} : () -> ()

%token = "test.op"() : () -> (!stablehlo.token)

// CHECK: builtin.module {
// CHECK-NEXT: "test.op"() {"default" = #stablehlo<precision DEFAULT>, "high" = #stablehlo<precision HIGH>, "highest" = #stablehlo<precision HIGHEST>} : () -> ()
// CHECK-NEXT: %token = "test.op"() : () -> !stablehlo.token
// CHECK-NEXT: }
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/stablehlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,6 @@

// CHECK: %and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>

// CHECK: "stablehlo.return"(%t0) : (tensor<i32>) -> ()
"stablehlo.return"(%t0) : (tensor<i32>) -> ()
2 changes: 2 additions & 0 deletions tests/filecheck/dialects/transform/transform_types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ builtin.module attributes {"transform.with_named_sequence"} {
%26 = "test.op"() : () -> !transform.any_param
"transform.match.param.cmpi"(%25, %26) <{predicate = 1 : i32}> : (!transform.any_param, !transform.any_param) -> ()
%27:2 = "transform.split_handle"(%24) <{fail_on_payload_too_small = true, pass_through_empty_handle = true}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%28 = "transform.structured.match"(%24) <{"op_attrs" = {"qmatmul_0"}}> : (!transform.any_op) -> !transform.any_op
}


Expand Down Expand Up @@ -84,4 +85,5 @@ builtin.module attributes {"transform.with_named_sequence"} {
//CHECK-NEXT: %22 = "test.op"() : () -> !transform.any_param
//CHECK-NEXT: "transform.match.param.cmpi"(%21, %22) <{"predicate" = 1 : i32}> : (!transform.any_param, !transform.any_param) -> ()
//CHECK-NEXT: %23, %24 = "transform.split_handle"(%20) <{"fail_on_payload_too_small" = true, "pass_through_empty_handle" = true}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
//CHECK-NEXT: %25 = "transform.structured.match"(%20) <{"op_attrs" = {"qmatmul_0"}}> : (!transform.any_op) -> !transform.any_op
//CHECK-NEXT:}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ builtin.module attributes {"transform.with_named_sequence"} {
%26 = "test.op"() : () -> !transform.any_param
"transform.match.param.cmpi"(%25, %26) <{predicate = 1 : i32}> : (!transform.any_param, !transform.any_param) -> ()
%27:2 = "transform.split_handle"(%24) <{fail_on_payload_too_small = true, pass_through_empty_handle = true}> : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%28 = "transform.structured.match"(%24) <{"op_attrs" = {"qmatmul_0"}}> : (!transform.any_op) -> !transform.any_op
}


Expand All @@ -55,15 +56,15 @@ builtin.module attributes {"transform.with_named_sequence"} {
//CHECK-NEXT: %4 = "test.op"() : () -> !transform.param<i64>
//CHECK-NEXT: %5 = "test.op"() : () -> !transform.type
//CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op, %arg1: !transform.op<"linalg.quantized_matmul">, %arg2: !transform.op<"linalg.elemwise_binary">) {
//CHECK-NEXT: %18 = transform.cast %arg1 : !transform.op<"linalg.quantized_matmul"> to !transform.any_op
//CHECK-NEXT: %19 = transform.cast %arg1 : !transform.op<"linalg.quantized_matmul"> to !transform.any_op
//CHECK-NEXT: %tiled_op, %forall_op = transform.structured.tile_using_forall %arg1 tile_sizes [4, 32] : (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op)
//CHECK-NEXT: %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %arg1[8, 8] : (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op, !transform.any_op)
//CHECK-NEXT: transform.yield
//CHECK-NEXT: }
//CHECK-NEXT: transform.sequence failures(propagate) {
//CHECK-NEXT: ^bb0(%arg0: !transform.any_op):
//CHECK-NEXT: %18 = select "linalg.quantized_matmul" in %arg0 : (!transform.any_op) -> !transform.op<"linalg.quantized_matmul">
//CHECK-NEXT: %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %18[8, 8] : (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op, !transform.any_op)
//CHECK-NEXT: %19 = select "linalg.quantized_matmul" in %arg0 : (!transform.any_op) -> !transform.op<"linalg.quantized_matmul">
//CHECK-NEXT: %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %19[8, 8] : (!transform.op<"linalg.quantized_matmul">) -> (!transform.any_op, !transform.any_op, !transform.any_op)
//CHECK-NEXT: }
//CHECK-NEXT: %6 = "test.op"() : () -> !transform.any_op
//CHECK-NEXT: %7 = transform.get_producer_of_operand %6[0] : (!transform.any_op) -> !transform.any_op
Expand All @@ -81,4 +82,5 @@ builtin.module attributes {"transform.with_named_sequence"} {
//CHECK-NEXT: %16 = "test.op"() : () -> !transform.any_param
//CHECK-NEXT: transform.match.param.cmpi ne %15, %16 : !transform.any_param
//CHECK-NEXT: %17:2 = transform.split_handle %14 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
//CHECK-NEXT: %18 = transform.structured.match attributes {qmatmul_0} in %14 : (!transform.any_op) -> !transform.any_op
//CHECK-NEXT: }
82 changes: 80 additions & 2 deletions xdsl/dialects/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,27 @@
from typing import Annotated, TypeAlias, cast

from xdsl.dialects.builtin import AnyTensorType, DenseArrayBase, IntegerType, TensorType
from xdsl.ir import Attribute, Dialect, SSAValue
from xdsl.ir import (
Attribute,
Dialect,
EnumAttribute,
ParametrizedAttribute,
SpacedOpaqueSyntaxAttribute,
SSAValue,
StrEnum,
TypeAttribute,
)
from xdsl.irdl import (
ConstraintVar,
IRDLOperation,
attr_def,
irdl_attr_definition,
irdl_op_definition,
operand_def,
result_def,
var_operand_def,
)
from xdsl.traits import IsTerminator
from xdsl.utils.exceptions import VerifyException

# region Abstract Base Classes
Expand All @@ -41,6 +53,48 @@ def __init__(
super().__init__(operands=(lhs, rhs), result_types=(result_type,))


# endregion

# region Attributes


class Precision(StrEnum):
"""
XLA precision for an operand. Has backend specific meaning.
"""

DEFAULT = "DEFAULT"
HIGH = "HIGH"
HIGHEST = "HIGHEST"


@irdl_attr_definition
class PrecisionAttr(EnumAttribute[Precision], SpacedOpaqueSyntaxAttribute):
"""
XLA precision for an operand. Has backend specific meaning.
https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloEnums.td#L46
"""

name = "stablehlo.precision"


@irdl_attr_definition
class TokenType(TypeAttribute, ParametrizedAttribute):
"""
Token types represent tokens, i.e. opaque values produced and consumed by some operations.
Tokens are used for imposing execution order on operations as described in the Execution section.
E.g.,
// %input0: !stablehlo.token
// %input1: !stablehlo.token
%result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
"""

name = "stablehlo.token"


# endregion


Expand Down Expand Up @@ -159,6 +213,26 @@ class SubtractOp(ElementwiseBinaryOperation):
name = "stablehlo.subtract"


@irdl_op_definition
class ReturnOp(IRDLOperation):
"""This op is un-documented.
StableHLO's return is used inside of the bodies of StableHLO ops.
It behaves like func.return but for StableHLO ops.
The func.return op is used inside of func.func op.
https://discord.com/channels/999073994483433573/1259494021269688360/1259992088565645312
"""

name = "stablehlo.return"

input = var_operand_def(AnyTensorType)
traits = frozenset([IsTerminator()])

def __init__(self, input: list[SSAValue]):
super().__init__(operands=(input,))


@irdl_op_definition
class TransposeOp(IRDLOperation):
"""
Expand Down Expand Up @@ -221,8 +295,12 @@ def verify_(self) -> None:
AddOp,
AndOp,
MultiplyOp,
ReturnOp,
SubtractOp,
TransposeOp,
],
[],
[
PrecisionAttr,
TokenType,
],
)
66 changes: 63 additions & 3 deletions xdsl/dialects/transform.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Annotated, TypeAlias

from xdsl.dialects.builtin import (
Expand Down Expand Up @@ -614,8 +614,11 @@ def __init__(
[
AnyOpType()
for _ in range(
len(static_sizes.as_tuple())
if isinstance(static_sizes, DenseArrayBase)
(
len(static_sizes.as_tuple())
- static_sizes.as_tuple().count(0)
)
if static_sizes
else 0
)
],
Expand Down Expand Up @@ -765,6 +768,62 @@ def __init__(self, input: SSAValue):
super().__init__(operands=[input], result_types=[AnyOpType()])


@irdl_op_definition
class MatchOp(IRDLOperation):
"""
https://mlir.llvm.org/docs/Dialects/Transform/#transformstructuredmatch-transformmatchop
"""

name = "transform.structured.match"

ops = opt_prop_def(ArrayAttr[StringAttr])
interface = opt_prop_def(AnyIntegerAttr)
op_attrs = opt_prop_def(DictionaryAttr)
filter_result_types = opt_prop_def(TypeAttribute)
filter_operand_types = opt_prop_def(TypeAttribute)

target = operand_def(TransformOpHandleType)
result = result_def(TransformOpHandleType)

def __init__(
self,
target: SSAValue,
ops: Sequence[str] | ArrayAttr[StringAttr] | None = None,
interface: int | AnyIntegerAttr | str | None = None,
op_attrs: dict[str, Attribute] | DictionaryAttr | None = None,
filter_result_types: TypeAttribute | None = None,
filter_operand_types: TypeAttribute | None = None,
):
if isinstance(ops, Sequence):
ops = ArrayAttr([StringAttr(op) for op in ops])
if isinstance(interface, str):
match interface:
case "LinalgOp":
interface = IntegerAttr(0, IntegerType(32))
case "TilingInterface":
interface = IntegerAttr(1, IntegerType(32))
case "LoopLikeInterface":
interface = IntegerAttr(2, IntegerType(32))
case _:
raise ValueError(f"Unknown interface: {interface}")
if isinstance(interface, int):
interface = IntegerAttr(interface, IntegerType(32))

if isinstance(op_attrs, Mapping):
op_attrs = DictionaryAttr(op_attrs)
super().__init__(
properties={
"ops": ops,
"interface": interface,
"op_attrs": op_attrs,
"filter_result_types": filter_result_types,
"filter_operand_types": filter_operand_types,
},
operands=[target],
result_types=[AnyOpType()],
)


Transform = Dialect(
"transform",
[
Expand All @@ -788,6 +847,7 @@ def __init__(self, input: SSAValue):
SelectOp,
NamedSequenceOp,
CastOp,
MatchOp,
],
[
# Types
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/csl_stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.hints import isa
from xdsl.utils.isa import isattr
from xdsl.utils.isattr import isattr


def tensor_to_memref_type(t: TensorType[Attribute]) -> memref.MemRefType[Attribute]:
Expand Down
25 changes: 0 additions & 25 deletions xdsl/utils/isa.py

This file was deleted.

4 changes: 4 additions & 0 deletions xdsl/utils/isattr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
def isattr(
arg: Any, hint: type[AttributeInvT] | GenericAttrConstraint[AttributeInvT]
) -> TypeGuard[AttributeInvT]:
"""
A helper method to check whether a given attribute has a given type or conforms to a
given constraint.
"""
from xdsl.irdl import ConstraintContext

if isinstance(hint, GenericAttrConstraint):
Expand Down

0 comments on commit 1ee9981

Please sign in to comment.