Skip to content

Commit

Permalink
dialects: (riscv) add fastmath flag to RdRsRs Float Float Int operati…
Browse files Browse the repository at this point in the history
…ons (#3276)

`riscv`-only part of #3272, with
a new dialect test added. Thanks for the suggestion of separate tests, I
didn't realize those existed and caught an error in the LiOp
implementation while adding the test cases.
  • Loading branch information
knickish authored Oct 15, 2024
1 parent 322c1b1 commit a7d8ebe
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
15 changes: 12 additions & 3 deletions tests/filecheck/dialects/riscv/riscv_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,12 @@
// CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg
%fle_s = riscv.fle.s %f0, %f1 : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg
%feq_s_fm = riscv.feq.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.feq.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%flt_s_fm = riscv.flt.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%fle_s_fm = riscv.fle.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%fclass_s = riscv.fclass.s %f0 : (!riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.fclass.s %{{.*}} : (!riscv.freg) -> !riscv.reg
%fcvt_s_w = riscv.fcvt.s.w %0 : (!riscv.reg) -> !riscv.freg
Expand Down Expand Up @@ -457,9 +463,12 @@
// CHECK-GENERIC-NEXT: %fcvt_w_s = "riscv.fcvt.w.s"(%f0) : (!riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %fcvt_wu_s = "riscv.fcvt.wu.s"(%f0) : (!riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %fmv_x_w = "riscv.fmv.x.w"(%f0) : (!riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %feq_s = "riscv.feq.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %flt_s = "riscv.flt.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %fle_s = "riscv.fle.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %feq_s = "riscv.feq.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<none>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %flt_s = "riscv.flt.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<none>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %fle_s = "riscv.fle.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<none>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %feq_s_fm = "riscv.feq.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %flt_s_fm = "riscv.flt.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %fle_s_fm = "riscv.fle.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %fclass_s = "riscv.fclass.s"(%f0) : (!riscv.freg) -> !riscv.reg
// CHECK-GENERIC-NEXT: %fcvt_s_w = "riscv.fcvt.s.w"(%0) : (!riscv.reg) -> !riscv.freg
// CHECK-GENERIC-NEXT: %fcvt_s_wu = "riscv.fcvt.s.wu"(%0) : (!riscv.reg) -> !riscv.freg
Expand Down
65 changes: 61 additions & 4 deletions xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2532,7 +2532,7 @@ def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
def custom_print_attributes(self, printer: Printer) -> Set[str]:
printer.print(" ")
print_immediate_value(printer, self.immediate)
return {"immediate"}
return {"immediate", "fastmath"}

@classmethod
def parse_op_type(
Expand Down Expand Up @@ -3016,6 +3016,63 @@ def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
return self.rd, self.rs1, self.rs2


class RdRsRsFloatFloatIntegerOperationWithFastMath(RISCVInstruction, ABC):
"""
A base class for RISC-V operations that have two source floating-point
registers with an integer destination register, and can be annotated with fastmath flags.
This is called R-Type in the RISC-V specification.
"""

rd = result_def(IntRegisterType)
rs1 = operand_def(FloatRegisterType)
rs2 = operand_def(FloatRegisterType)
fastmath = attr_def(FastMathFlagsAttr)

def __init__(
self,
rs1: Operation | SSAValue,
rs2: Operation | SSAValue,
*,
rd: IntRegisterType | str | None = None,
fastmath: FastMathFlagsAttr = FastMathFlagsAttr("none"),
comment: str | StringAttr | None = None,
):
if rd is None:
rd = IntRegisterType.unallocated()
elif isinstance(rd, str):
rd = IntRegisterType(rd)
if isinstance(comment, str):
comment = StringAttr(comment)

super().__init__(
operands=[rs1, rs2],
attributes={
"comment": comment,
"fastmath": fastmath,
},
result_types=[rd],
)

def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
return self.rd, self.rs1, self.rs2

@classmethod
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
attributes = dict[str, Attribute]()
fast = FastMathFlagsAttr("none")
if parser.parse_optional_keyword("fastmath") is not None:
fast = FastMathFlagsAttr(FastMathFlagsAttr.parse_parameter(parser))
attributes["fastmath"] = fast
return attributes

def custom_print_attributes(self, printer: Printer) -> Set[str]:
if self.fastmath is not None and self.fastmath != FastMathFlagsAttr("none"):
printer.print(" fastmath")
self.fastmath.print_parameter(printer)
return {"fastmath"}


class RsRsImmFloatOperation(RISCVInstruction, ABC):
"""
A base class for RV32F operations that have two source registers
Expand Down Expand Up @@ -3352,7 +3409,7 @@ class FMvXWOp(RdRsOperation[IntRegisterType, FloatRegisterType]):


@irdl_op_definition
class FeqSOP(RdRsRsFloatFloatIntegerOperation):
class FeqSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
"""
Performs a quiet equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
Only signaling NaN inputs cause an Invalid Operation exception.
Expand All @@ -3367,7 +3424,7 @@ class FeqSOP(RdRsRsFloatFloatIntegerOperation):


@irdl_op_definition
class FltSOP(RdRsRsFloatFloatIntegerOperation):
class FltSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
"""
Performs a quiet less comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
Only signaling NaN inputs cause an Invalid Operation exception.
Expand All @@ -3382,7 +3439,7 @@ class FltSOP(RdRsRsFloatFloatIntegerOperation):


@irdl_op_definition
class FleSOP(RdRsRsFloatFloatIntegerOperation):
class FleSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
"""
Performs a quiet less or equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
Only signaling NaN inputs cause an Invalid Operation exception.
Expand Down

0 comments on commit a7d8ebe

Please sign in to comment.