From a7d8ebef21f675b8e20c425f6ed7b2286122ec82 Mon Sep 17 00:00:00 2001 From: knickish Date: Tue, 15 Oct 2024 04:19:22 -0500 Subject: [PATCH] dialects: (riscv) add fastmath flag to RdRsRs Float Float Int operations (#3276) `riscv`-only part of https://github.com/xdslproject/xdsl/pull/3272, with a new dialect test added. Thanks for the suggestion of separate tests, I didn't realize those existed and caught an error in the LiOp implementation while adding the test cases. --- tests/filecheck/dialects/riscv/riscv_ops.mlir | 15 ++++- xdsl/dialects/riscv.py | 65 +++++++++++++++++-- 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/tests/filecheck/dialects/riscv/riscv_ops.mlir b/tests/filecheck/dialects/riscv/riscv_ops.mlir index 09b4b5a8fe..8823bec412 100644 --- a/tests/filecheck/dialects/riscv/riscv_ops.mlir +++ b/tests/filecheck/dialects/riscv/riscv_ops.mlir @@ -273,6 +273,12 @@ // CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg %fle_s = riscv.fle.s %f0, %f1 : (!riscv.freg, !riscv.freg) -> !riscv.reg // CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg + %feq_s_fm = riscv.feq.s %f0, %f1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.feq.s %{{.*}}, %{{.*}} fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + %flt_s_fm = riscv.flt.s %f0, %f1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + %fle_s_fm = riscv.fle.s %f0, %f1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg %fclass_s = riscv.fclass.s %f0 : (!riscv.freg) -> !riscv.reg // CHECK-NEXT: %{{.*}} = riscv.fclass.s %{{.*}} : (!riscv.freg) -> !riscv.reg %fcvt_s_w = riscv.fcvt.s.w %0 : (!riscv.reg) -> !riscv.freg @@ -457,9 +463,12 @@ // CHECK-GENERIC-NEXT: %fcvt_w_s = "riscv.fcvt.w.s"(%f0) : (!riscv.freg) -> !riscv.reg // CHECK-GENERIC-NEXT: %fcvt_wu_s = "riscv.fcvt.wu.s"(%f0) : (!riscv.freg) -> !riscv.reg // CHECK-GENERIC-NEXT: %fmv_x_w = "riscv.fmv.x.w"(%f0) : (!riscv.freg) -> !riscv.reg -// CHECK-GENERIC-NEXT: %feq_s = "riscv.feq.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg -// CHECK-GENERIC-NEXT: %flt_s = "riscv.flt.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg -// CHECK-GENERIC-NEXT: %fle_s = "riscv.fle.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg +// CHECK-GENERIC-NEXT: %feq_s = "riscv.feq.s"(%f0, %f1) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.reg +// CHECK-GENERIC-NEXT: %flt_s = "riscv.flt.s"(%f0, %f1) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.reg +// CHECK-GENERIC-NEXT: %fle_s = "riscv.fle.s"(%f0, %f1) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.reg +// CHECK-GENERIC-NEXT: %feq_s_fm = "riscv.feq.s"(%f0, %f1) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.reg +// CHECK-GENERIC-NEXT: %flt_s_fm = "riscv.flt.s"(%f0, %f1) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.reg +// CHECK-GENERIC-NEXT: %fle_s_fm = "riscv.fle.s"(%f0, %f1) {"fastmath" = #riscv.fastmath} : (!riscv.freg, !riscv.freg) -> !riscv.reg // CHECK-GENERIC-NEXT: %fclass_s = "riscv.fclass.s"(%f0) : (!riscv.freg) -> !riscv.reg // CHECK-GENERIC-NEXT: %fcvt_s_w = "riscv.fcvt.s.w"(%0) : (!riscv.reg) -> !riscv.freg // CHECK-GENERIC-NEXT: %fcvt_s_wu = "riscv.fcvt.s.wu"(%0) : (!riscv.reg) -> !riscv.freg diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index 5da0c3cbe3..49bb52b0e6 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -2532,7 +2532,7 @@ def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]: def custom_print_attributes(self, printer: Printer) -> Set[str]: printer.print(" ") print_immediate_value(printer, self.immediate) - return {"immediate"} + return {"immediate", "fastmath"} @classmethod def parse_op_type( @@ -3016,6 +3016,63 @@ def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]: return self.rd, self.rs1, self.rs2 +class RdRsRsFloatFloatIntegerOperationWithFastMath(RISCVInstruction, ABC): + """ + A base class for RISC-V operations that have two source floating-point + registers with an integer destination register, and can be annotated with fastmath flags. + + This is called R-Type in the RISC-V specification. + """ + + rd = result_def(IntRegisterType) + rs1 = operand_def(FloatRegisterType) + rs2 = operand_def(FloatRegisterType) + fastmath = attr_def(FastMathFlagsAttr) + + def __init__( + self, + rs1: Operation | SSAValue, + rs2: Operation | SSAValue, + *, + rd: IntRegisterType | str | None = None, + fastmath: FastMathFlagsAttr = FastMathFlagsAttr("none"), + comment: str | StringAttr | None = None, + ): + if rd is None: + rd = IntRegisterType.unallocated() + elif isinstance(rd, str): + rd = IntRegisterType(rd) + if isinstance(comment, str): + comment = StringAttr(comment) + + super().__init__( + operands=[rs1, rs2], + attributes={ + "comment": comment, + "fastmath": fastmath, + }, + result_types=[rd], + ) + + def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]: + return self.rd, self.rs1, self.rs2 + + @classmethod + def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]: + attributes = dict[str, Attribute]() + fast = FastMathFlagsAttr("none") + if parser.parse_optional_keyword("fastmath") is not None: + fast = FastMathFlagsAttr(FastMathFlagsAttr.parse_parameter(parser)) + attributes["fastmath"] = fast + return attributes + + def custom_print_attributes(self, printer: Printer) -> Set[str]: + if self.fastmath is not None and self.fastmath != FastMathFlagsAttr("none"): + printer.print(" fastmath") + self.fastmath.print_parameter(printer) + return {"fastmath"} + + class RsRsImmFloatOperation(RISCVInstruction, ABC): """ A base class for RV32F operations that have two source registers @@ -3352,7 +3409,7 @@ class FMvXWOp(RdRsOperation[IntRegisterType, FloatRegisterType]): @irdl_op_definition -class FeqSOP(RdRsRsFloatFloatIntegerOperation): +class FeqSOP(RdRsRsFloatFloatIntegerOperationWithFastMath): """ Performs a quiet equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd. Only signaling NaN inputs cause an Invalid Operation exception. @@ -3367,7 +3424,7 @@ class FeqSOP(RdRsRsFloatFloatIntegerOperation): @irdl_op_definition -class FltSOP(RdRsRsFloatFloatIntegerOperation): +class FltSOP(RdRsRsFloatFloatIntegerOperationWithFastMath): """ Performs a quiet less comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd. Only signaling NaN inputs cause an Invalid Operation exception. @@ -3382,7 +3439,7 @@ class FltSOP(RdRsRsFloatFloatIntegerOperation): @irdl_op_definition -class FleSOP(RdRsRsFloatFloatIntegerOperation): +class FleSOP(RdRsRsFloatFloatIntegerOperationWithFastMath): """ Performs a quiet less or equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd. Only signaling NaN inputs cause an Invalid Operation exception.