From d67b37dbbe1b36c07b98a33ea71a7f9676b2ff52 Mon Sep 17 00:00:00 2001 From: knickish Date: Wed, 23 Oct 2024 13:04:37 -0500 Subject: [PATCH] Arith to riscv lowering fastmath cmpf (#3277) `backend`(lowering)-only part of https://github.com/xdslproject/xdsl/pull/3272. Depends on #3275 and #3276. This should close #2725 once merged Co-authored-by: Sasha Lopoukhine --- .../backend/riscv/convert_arith_to_riscv.mlir | 45 +++++++++++++++++++ .../riscv/lowering/convert_arith_to_riscv.py | 38 ++++++++-------- 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir b/tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir index 2b9712c223..527e4404c1 100644 --- a/tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir +++ b/tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir @@ -219,6 +219,51 @@ builtin.module { // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_2, 1 : (!riscv.reg) -> !riscv.reg %cmpf15 = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 15 : i32} : (f32, f32) -> i1 // CHECK-NEXT: %{{.*}} = riscv.li 1 : !riscv.reg + + // tests with fastmath flags when set to "fast" + %cmpf1_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 1 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + %cmpf2_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 2 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + %cmpf3_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 3 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + %cmpf4_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 4 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + %cmpf5_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 5 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + %cmpf6_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 6 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.or %cmpf6_fm_1, %cmpf6_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg + %cmpf7_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 7 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.and %cmpf7_fm_1, %cmpf7_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg + %cmpf8_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 8 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.or %cmpf8_fm_1, %cmpf8_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf8_fm_2, 1 : (!riscv.reg) -> !riscv.reg + %cmpf9_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 9 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf9_fm, 1 : (!riscv.reg) -> !riscv.reg + %cmpf10_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 10 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf10_fm, 1 : (!riscv.reg) -> !riscv.reg + %cmpf11_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 11 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf11_fm, 1 : (!riscv.reg) -> !riscv.reg + %cmpf12_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 12 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf12_fm, 1 : (!riscv.reg) -> !riscv.reg + %cmpf13_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 13 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf13_fm, 1 : (!riscv.reg) -> !riscv.reg + %cmpf14_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 14 : i32, "fastmath" = #arith.fastmath} : (f32, f32) -> i1 + // CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath : (!riscv.freg, !riscv.freg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.and %cmpf14_fm_1, %cmpf14_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg + // CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_fm_2, 1 : (!riscv.reg) -> !riscv.reg %index_cast = "arith.index_cast"(%lhsindex) : (index) -> i32 // CHECK-NEXT: } } diff --git a/xdsl/backend/riscv/lowering/convert_arith_to_riscv.py b/xdsl/backend/riscv/lowering/convert_arith_to_riscv.py index 5a6795db81..f474aa94ed 100644 --- a/xdsl/backend/riscv/lowering/convert_arith_to_riscv.py +++ b/xdsl/backend/riscv/lowering/convert_arith_to_riscv.py @@ -352,29 +352,31 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None: lhs, rhs = cast_operands_to_regs(rewriter) cast_matched_op_results(rewriter) + fastmath = riscv.FastMathFlagsAttr(op.fastmath.data) + match op.predicate.value.data: # false case 0: rewriter.replace_matched_op([riscv.LiOp(0)]) # oeq case 1: - rewriter.replace_matched_op([riscv.FeqSOP(lhs, rhs)]) + rewriter.replace_matched_op([riscv.FeqSOP(lhs, rhs, fastmath=fastmath)]) # ogt case 2: - rewriter.replace_matched_op([riscv.FltSOP(rhs, lhs)]) + rewriter.replace_matched_op([riscv.FltSOP(rhs, lhs, fastmath=fastmath)]) # oge case 3: - rewriter.replace_matched_op([riscv.FleSOP(rhs, lhs)]) + rewriter.replace_matched_op([riscv.FleSOP(rhs, lhs, fastmath=fastmath)]) # olt case 4: - rewriter.replace_matched_op([riscv.FltSOP(lhs, rhs)]) + rewriter.replace_matched_op([riscv.FltSOP(lhs, rhs, fastmath=fastmath)]) # ole case 5: - rewriter.replace_matched_op([riscv.FleSOP(lhs, rhs)]) + rewriter.replace_matched_op([riscv.FleSOP(lhs, rhs, fastmath=fastmath)]) # one case 6: - flt1 = riscv.FltSOP(lhs, rhs) - flt2 = riscv.FltSOP(rhs, lhs) + flt1 = riscv.FltSOP(lhs, rhs, fastmath=fastmath) + flt2 = riscv.FltSOP(rhs, lhs, fastmath=fastmath) rewriter.replace_matched_op( [ flt1, @@ -384,8 +386,8 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None: ) # ord case 7: - feq1 = riscv.FeqSOP(lhs, lhs) - feq2 = riscv.FeqSOP(rhs, rhs) + feq1 = riscv.FeqSOP(lhs, lhs, fastmath=fastmath) + feq2 = riscv.FeqSOP(rhs, rhs, fastmath=fastmath) rewriter.replace_matched_op( [ feq1, @@ -395,34 +397,34 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None: ) # ueq case 8: - flt1 = riscv.FltSOP(lhs, rhs) - flt2 = riscv.FltSOP(rhs, lhs) + flt1 = riscv.FltSOP(lhs, rhs, fastmath=fastmath) + flt2 = riscv.FltSOP(rhs, lhs, fastmath=fastmath) or_ = riscv.OrOp(flt2, flt1, rd=riscv.IntRegisterType.unallocated()) rewriter.replace_matched_op([flt1, flt2, or_, riscv.XoriOp(or_, 1)]) # ugt case 9: - fle = riscv.FleSOP(lhs, rhs) + fle = riscv.FleSOP(lhs, rhs, fastmath=fastmath) rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)]) # uge case 10: - fle = riscv.FltSOP(lhs, rhs) + fle = riscv.FltSOP(lhs, rhs, fastmath=fastmath) rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)]) # ult case 11: - fle = riscv.FleSOP(rhs, lhs) + fle = riscv.FleSOP(rhs, lhs, fastmath=fastmath) rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)]) # ule case 12: - flt = riscv.FltSOP(rhs, lhs) + flt = riscv.FltSOP(rhs, lhs, fastmath=fastmath) rewriter.replace_matched_op([flt, riscv.XoriOp(flt, 1)]) # une case 13: - feq = riscv.FeqSOP(lhs, rhs) + feq = riscv.FeqSOP(lhs, rhs, fastmath=fastmath) rewriter.replace_matched_op([feq, riscv.XoriOp(feq, 1)]) # uno case 14: - feq1 = riscv.FeqSOP(lhs, lhs) - feq2 = riscv.FeqSOP(rhs, rhs) + feq1 = riscv.FeqSOP(lhs, lhs, fastmath=fastmath) + feq2 = riscv.FeqSOP(rhs, rhs, fastmath=fastmath) and_ = riscv.AndOp(feq2, feq1, rd=riscv.IntRegisterType.unallocated()) rewriter.replace_matched_op([feq1, feq2, and_, riscv.XoriOp(and_, 1)]) # true