Skip to content

Commit

Permalink
dialects: (arith) split generic binary op definition into specific on…
Browse files Browse the repository at this point in the history
…es (#3274)

This makes the next step for Pyright updating easier, as we can't
actually properly represent generics in the IRDL constraint system.
Currently, the arith base class helpers work because the generics are
all specified by the time the ops are "defined" with the annotation. But
my proposed solution of replacing the Annotated with VarConstraints
means that we need to define the constraint on a specific type already.
My understanding is that this is not a functional change, only requiring
some clients of the base classes to update the names that they refer to.

Part of #3264

Note stacked PR.

---------

Co-authored-by: Alex Rice <alexrice999@hotmail.co.uk>
  • Loading branch information
superlopuh and alexarice authored Oct 10, 2024
1 parent ff389e4 commit 964834a
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 66 deletions.
10 changes: 6 additions & 4 deletions tests/dialects/test_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Addi,
AddUIExtended,
AndI,
BinaryOperation,
CeilDivSI,
CeilDivUI,
Cmpf,
Expand All @@ -20,7 +19,7 @@
ExtSIOp,
ExtUIOp,
FastMathFlagsAttr,
FloatingPointLikeBinaryOp,
FloatingPointLikeBinaryOperation,
FloorDivSI,
FPToSIOp,
IndexCastOp,
Expand All @@ -41,6 +40,7 @@
ShLI,
ShRSI,
ShRUI,
SignlessIntegerBinaryOperation,
SIToFPOp,
Subf,
Subi,
Expand Down Expand Up @@ -103,7 +103,7 @@ class Test_integer_arith_construction:
@pytest.mark.parametrize("return_type", [None, operand_type])
def test_arith_ops_init(
self,
OpClass: type[BinaryOperation[_BinOpArgT]],
OpClass: type[SignlessIntegerBinaryOperation],
return_type: Attribute,
):
op = OpClass(self.a, self.b)
Expand Down Expand Up @@ -210,7 +210,9 @@ class Test_float_arith_construction:
"flags", [FastMathFlagsAttr("none"), FastMathFlagsAttr("fast"), None]
)
def test_arith_ops(
self, func: type[FloatingPointLikeBinaryOp], flags: FastMathFlagsAttr | None
self,
func: type[FloatingPointLikeBinaryOperation],
flags: FastMathFlagsAttr | None,
):
op = func(self.a, self.b, flags)
assert op.operands[0].owner is self.a
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from xdsl.dialects.arith import BinaryOperation, Constant
from xdsl.dialects.arith import Constant, FloatingPointLikeBinaryOperation
from xdsl.dialects.builtin import (
DenseIntOrFPElementsAttr,
FloatAttr,
Expand Down Expand Up @@ -84,7 +84,7 @@ class Test_float_math_binary_construction:
@pytest.mark.parametrize("return_type", [None, operand_type])
def test_float_binary_ops_constant_math_init(
self,
OpClass: type[BinaryOperation[_BinOpArgT]],
OpClass: type[FloatingPointLikeBinaryOperation],
return_type: Attribute,
):
op = OpClass(self.a, self.b)
Expand All @@ -104,7 +104,7 @@ def test_float_binary_ops_constant_math_init(
)
@pytest.mark.parametrize("return_type", [None, f32_vector_type])
def test_flaot_binary_vector_ops_init(
self, OpClass: type[BinaryOperation[_BinOpArgT]], return_type: Attribute
self, OpClass: type[FloatingPointLikeBinaryOperation], return_type: Attribute
):
op = OpClass(self.lhs_vector, self.rhs_vector)
assert isinstance(op, OpClass)
Expand All @@ -123,7 +123,7 @@ def test_flaot_binary_vector_ops_init(
)
@pytest.mark.parametrize("return_type", [None, f32_tensor_type])
def test_float_binary_ops_tensor_math_init(
self, OpClass: type[BinaryOperation[_BinOpArgT]], return_type: Attribute
self, OpClass: type[FloatingPointLikeBinaryOperation], return_type: Attribute
):
op = OpClass(self.lhs_tensor, self.rhs_tensor)
assert isinstance(op, OpClass)
Expand Down
4 changes: 2 additions & 2 deletions xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def match_and_rewrite(

@dataclass
class LowerBinaryIntegerOp(RewritePattern):
arith_op_cls: type[arith.SignlessIntegerBinaryOp]
arith_op_cls: type[arith.SignlessIntegerBinaryOperation]
riscv_op_cls: type[RdRsRsIntegerOperation]

def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
Expand All @@ -169,7 +169,7 @@ def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:

@dataclass
class LowerBinaryFloatOp(RewritePattern):
arith_op_cls: type[arith.FloatingPointLikeBinaryOp]
arith_op_cls: type[arith.FloatingPointLikeBinaryOperation]
riscv_f_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

@dataclass
class LowerBinaryFloatVectorOp(RewritePattern):
arith_op_cls: type[arith.FloatingPointLikeBinaryOp]
arith_op_cls: type[arith.FloatingPointLikeBinaryOperation]
riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
riscv_snitch_v_f32_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
riscv_snitch_v_f16_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
Expand Down
Loading

0 comments on commit 964834a

Please sign in to comment.