Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (arith) split generic binary op definition into specific ones #3274

Merged
merged 7 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions tests/dialects/test_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Addi,
AddUIExtended,
AndI,
BinaryOperation,
CeilDivSI,
CeilDivUI,
Cmpf,
Expand All @@ -20,7 +19,7 @@
ExtSIOp,
ExtUIOp,
FastMathFlagsAttr,
FloatingPointLikeBinaryOp,
FloatingPointLikeBinaryOperation,
FloorDivSI,
FPToSIOp,
IndexCastOp,
Expand All @@ -41,6 +40,7 @@
ShLI,
ShRSI,
ShRUI,
SignlessIntegerBinaryOperation,
SIToFPOp,
Subf,
Subi,
Expand Down Expand Up @@ -103,7 +103,7 @@ class Test_integer_arith_construction:
@pytest.mark.parametrize("return_type", [None, operand_type])
def test_arith_ops_init(
self,
OpClass: type[BinaryOperation[_BinOpArgT]],
OpClass: type[SignlessIntegerBinaryOperation],
return_type: Attribute,
):
op = OpClass(self.a, self.b)
Expand Down Expand Up @@ -210,7 +210,9 @@ class Test_float_arith_construction:
"flags", [FastMathFlagsAttr("none"), FastMathFlagsAttr("fast"), None]
)
def test_arith_ops(
self, func: type[FloatingPointLikeBinaryOp], flags: FastMathFlagsAttr | None
self,
func: type[FloatingPointLikeBinaryOperation],
flags: FastMathFlagsAttr | None,
):
op = func(self.a, self.b, flags)
assert op.operands[0].owner is self.a
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from xdsl.dialects.arith import BinaryOperation, Constant
from xdsl.dialects.arith import Constant, FloatingPointLikeBinaryOperation
from xdsl.dialects.builtin import (
DenseIntOrFPElementsAttr,
FloatAttr,
Expand Down Expand Up @@ -84,7 +84,7 @@ class Test_float_math_binary_construction:
@pytest.mark.parametrize("return_type", [None, operand_type])
def test_float_binary_ops_constant_math_init(
self,
OpClass: type[BinaryOperation[_BinOpArgT]],
OpClass: type[FloatingPointLikeBinaryOperation],
return_type: Attribute,
):
op = OpClass(self.a, self.b)
Expand All @@ -104,7 +104,7 @@ def test_float_binary_ops_constant_math_init(
)
@pytest.mark.parametrize("return_type", [None, f32_vector_type])
def test_flaot_binary_vector_ops_init(
self, OpClass: type[BinaryOperation[_BinOpArgT]], return_type: Attribute
self, OpClass: type[FloatingPointLikeBinaryOperation], return_type: Attribute
):
op = OpClass(self.lhs_vector, self.rhs_vector)
assert isinstance(op, OpClass)
Expand All @@ -123,7 +123,7 @@ def test_flaot_binary_vector_ops_init(
)
@pytest.mark.parametrize("return_type", [None, f32_tensor_type])
def test_float_binary_ops_tensor_math_init(
self, OpClass: type[BinaryOperation[_BinOpArgT]], return_type: Attribute
self, OpClass: type[FloatingPointLikeBinaryOperation], return_type: Attribute
):
op = OpClass(self.lhs_tensor, self.rhs_tensor)
assert isinstance(op, OpClass)
Expand Down
27 changes: 26 additions & 1 deletion tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
AttrSizedOperandSegments,
AttrSizedRegionSegments,
AttrSizedResultSegments,
BaseAttr,
ConstraintVar,
EqAttrConstraint,
GenericAttrConstraint,
IRDLOperation,
ParamAttrConstraint,
ParameterDef,
ParsePropInAttrDict,
VarOperand,
Expand Down Expand Up @@ -1559,7 +1562,7 @@ class OptSuccessorOp(IRDLOperation):
# Inference #
################################################################################

_T = TypeVar("_T", bound=Attribute)
_T = TypeVar("_T", bound=Attribute, covariant=True)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1661,6 +1664,18 @@ class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
p: ParameterDef[_T]
q: ParameterDef[Attribute]

@classmethod
def constr(
cls,
*,
n: GenericAttrConstraint[Attribute] | None = None,
p: GenericAttrConstraint[_T] | None = None,
q: GenericAttrConstraint[Attribute] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if n is None and p is None and q is None:
return BaseAttr(cls)
return ParamAttrConstraint(cls, (n, p, q))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are those changes related?

@irdl_op_definition
class TwoOperandsNestedVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]
Expand Down Expand Up @@ -1695,6 +1710,16 @@ class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
name = "test.param_one"
p: ParameterDef[_T]

@classmethod
def constr(
cls,
*,
p: GenericAttrConstraint[_T] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if p is None:
return BaseAttr(cls)
return ParamAttrConstraint(cls, (p,))

@irdl_op_definition
class OneOperandOneResultNestedOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]
Expand Down
8 changes: 1 addition & 7 deletions tests/tblgen_to_py/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,7 @@ class Test_AttributesOp(IRDLOperation):
name = "test.attributes"

int_attr = prop_def(
ParamAttrConstraint(
IntegerAttr,
(
AnyAttr(),
EqAttrConstraint(IntegerType(16)),
),
)
IntegerAttr[IntegerType].constr(type=EqAttrConstraint(IntegerType(16)))
superlopuh marked this conversation as resolved.
Show resolved Hide resolved
)

in_ = prop_def(BaseAttr(Test_TestAttr), prop_name="in")
Expand Down
4 changes: 2 additions & 2 deletions xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def match_and_rewrite(

@dataclass
class LowerBinaryIntegerOp(RewritePattern):
arith_op_cls: type[arith.SignlessIntegerBinaryOp]
arith_op_cls: type[arith.SignlessIntegerBinaryOperation]
riscv_op_cls: type[RdRsRsIntegerOperation]

def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
Expand All @@ -169,7 +169,7 @@ def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:

@dataclass
class LowerBinaryFloatOp(RewritePattern):
arith_op_cls: type[arith.FloatingPointLikeBinaryOp]
arith_op_cls: type[arith.FloatingPointLikeBinaryOperation]
riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

@dataclass
class LowerBinaryFloatVectorOp(RewritePattern):
arith_op_cls: type[arith.FloatingPointLikeBinaryOp]
arith_op_cls: type[arith.FloatingPointLikeBinaryOperation]
riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
riscv_snitch_v_f32_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
riscv_snitch_v_f16_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
Expand Down
Loading
Loading