Skip to content

Commit

Permalink
Arith to riscv lowering fastmath cmpf (#3277)
Browse files Browse the repository at this point in the history
`backend`(lowering)-only part of
#3272. Depends on #3275 and
#3276.

This should close #2725 once merged

Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
knickish and superlopuh authored Oct 23, 2024
1 parent 163fde2 commit d67b37d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 18 deletions.
45 changes: 45 additions & 0 deletions tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,51 @@ builtin.module {
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_2, 1 : (!riscv.reg) -> !riscv.reg
%cmpf15 = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 15 : i32} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.li 1 : !riscv.reg

// tests with fastmath flags when set to "fast"
%cmpf1_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 1 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf2_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 2 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf3_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 3 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf4_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 4 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf5_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 5 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
%cmpf6_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 6 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.or %cmpf6_fm_1, %cmpf6_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
%cmpf7_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 7 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.and %cmpf7_fm_1, %cmpf7_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
%cmpf8_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 8 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.or %cmpf8_fm_1, %cmpf8_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf8_fm_2, 1 : (!riscv.reg) -> !riscv.reg
%cmpf9_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 9 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf9_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf10_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 10 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf10_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf11_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 11 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf11_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf12_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 12 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf12_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf13_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 13 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf13_fm, 1 : (!riscv.reg) -> !riscv.reg
%cmpf14_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 14 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.and %cmpf14_fm_1, %cmpf14_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_fm_2, 1 : (!riscv.reg) -> !riscv.reg
%index_cast = "arith.index_cast"(%lhsindex) : (index) -> i32
// CHECK-NEXT: }
}
38 changes: 20 additions & 18 deletions xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,29 +352,31 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None:
lhs, rhs = cast_operands_to_regs(rewriter)
cast_matched_op_results(rewriter)

fastmath = riscv.FastMathFlagsAttr(op.fastmath.data)

match op.predicate.value.data:
# false
case 0:
rewriter.replace_matched_op([riscv.LiOp(0)])
# oeq
case 1:
rewriter.replace_matched_op([riscv.FeqSOP(lhs, rhs)])
rewriter.replace_matched_op([riscv.FeqSOP(lhs, rhs, fastmath=fastmath)])
# ogt
case 2:
rewriter.replace_matched_op([riscv.FltSOP(rhs, lhs)])
rewriter.replace_matched_op([riscv.FltSOP(rhs, lhs, fastmath=fastmath)])
# oge
case 3:
rewriter.replace_matched_op([riscv.FleSOP(rhs, lhs)])
rewriter.replace_matched_op([riscv.FleSOP(rhs, lhs, fastmath=fastmath)])
# olt
case 4:
rewriter.replace_matched_op([riscv.FltSOP(lhs, rhs)])
rewriter.replace_matched_op([riscv.FltSOP(lhs, rhs, fastmath=fastmath)])
# ole
case 5:
rewriter.replace_matched_op([riscv.FleSOP(lhs, rhs)])
rewriter.replace_matched_op([riscv.FleSOP(lhs, rhs, fastmath=fastmath)])
# one
case 6:
flt1 = riscv.FltSOP(lhs, rhs)
flt2 = riscv.FltSOP(rhs, lhs)
flt1 = riscv.FltSOP(lhs, rhs, fastmath=fastmath)
flt2 = riscv.FltSOP(rhs, lhs, fastmath=fastmath)
rewriter.replace_matched_op(
[
flt1,
Expand All @@ -384,8 +386,8 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None:
)
# ord
case 7:
feq1 = riscv.FeqSOP(lhs, lhs)
feq2 = riscv.FeqSOP(rhs, rhs)
feq1 = riscv.FeqSOP(lhs, lhs, fastmath=fastmath)
feq2 = riscv.FeqSOP(rhs, rhs, fastmath=fastmath)
rewriter.replace_matched_op(
[
feq1,
Expand All @@ -395,34 +397,34 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None:
)
# ueq
case 8:
flt1 = riscv.FltSOP(lhs, rhs)
flt2 = riscv.FltSOP(rhs, lhs)
flt1 = riscv.FltSOP(lhs, rhs, fastmath=fastmath)
flt2 = riscv.FltSOP(rhs, lhs, fastmath=fastmath)
or_ = riscv.OrOp(flt2, flt1, rd=riscv.IntRegisterType.unallocated())
rewriter.replace_matched_op([flt1, flt2, or_, riscv.XoriOp(or_, 1)])
# ugt
case 9:
fle = riscv.FleSOP(lhs, rhs)
fle = riscv.FleSOP(lhs, rhs, fastmath=fastmath)
rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)])
# uge
case 10:
fle = riscv.FltSOP(lhs, rhs)
fle = riscv.FltSOP(lhs, rhs, fastmath=fastmath)
rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)])
# ult
case 11:
fle = riscv.FleSOP(rhs, lhs)
fle = riscv.FleSOP(rhs, lhs, fastmath=fastmath)
rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)])
# ule
case 12:
flt = riscv.FltSOP(rhs, lhs)
flt = riscv.FltSOP(rhs, lhs, fastmath=fastmath)
rewriter.replace_matched_op([flt, riscv.XoriOp(flt, 1)])
# une
case 13:
feq = riscv.FeqSOP(lhs, rhs)
feq = riscv.FeqSOP(lhs, rhs, fastmath=fastmath)
rewriter.replace_matched_op([feq, riscv.XoriOp(feq, 1)])
# uno
case 14:
feq1 = riscv.FeqSOP(lhs, lhs)
feq2 = riscv.FeqSOP(rhs, rhs)
feq1 = riscv.FeqSOP(lhs, lhs, fastmath=fastmath)
feq2 = riscv.FeqSOP(rhs, rhs, fastmath=fastmath)
and_ = riscv.AndOp(feq2, feq1, rd=riscv.IntRegisterType.unallocated())
rewriter.replace_matched_op([feq1, feq2, and_, riscv.XoriOp(and_, 1)])
# true
Expand Down

0 comments on commit d67b37d

Please sign in to comment.