Skip to content

Commit

Permalink
riscv64: Consolidate conditional moves into one instruction (#7145)
Browse files Browse the repository at this point in the history
* riscv64: Add codegen tests for min/max

* riscv64: Remove unused `Type` payload from `Select`

* riscv64: Add codegen test for brif

* riscv64: Consolidate conditional moves into one instruction

This commit removes the `IntSelect` and `SelectReg` pseudo-instructions
from the riscv64 backend and consolidates them into the `Select`
instruction. Additionally the `Select` instruction is updated to subsume
the functionality of these two previous instructions. Namely `Select`
now operates with `ValueRegs` to handle i128 and additionally takes an
`IntegerCompare` as the condition for the conditional branch to use.

This commit touches a fair bit of the backend since conditional
selection of registers was used in quite a few places. The previous
`gen_select_*` functions are replaced with new typed equivalents of
`gen_select_{xreg,vreg,freg,regs}`. Furthermore new `cmp_*` helpers were
added to create `IntegerCompare` instructions which sort-of match
conditional branch instructions, or at least the pnemonics they use.

Finally since this affected the `select` CLIF instruction itself I went
ahead and did some refactoring there too. The `select` instruction
creates an `IntegerCompare` from its argument to use to generate the
appropriate register selection instruction. This is basically the same
thing that `brif` does and now both go through a new helper,
`lower_int_compare`, which takes a `Value` and produces an
`IntegerCompare` representing if that value is either true or false.
This enables folding an `icmp` or an `fcmp`, for example, directly into
a branching instruction.

* Fix a test

* Favor sign-extension in equality comparisons
  • Loading branch information
alexcrichton authored Oct 4, 2023
1 parent 993e26e commit d4e4f61
Show file tree
Hide file tree
Showing 75 changed files with 2,988 additions and 2,322 deletions.
244 changes: 142 additions & 102 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,8 @@

;; select x or y base on condition
(Select
(dst VecWritableReg)
(ty Type)
(condition Reg)
(dst WritableValueRegs)
(condition IntegerCompare)
(x ValueRegs)
(y ValueRegs))

Expand All @@ -242,28 +241,13 @@
(addr Reg)
(v Reg)
(ty Type))
;; select x or y base on op_code
(IntSelect
(op IntSelectOP)
(dst VecWritableReg)
(x ValueRegs)
(y ValueRegs)
(ty Type))
;; an integer compare.
(Icmp
(cc IntCC)
(rd WritableReg)
(a ValueRegs)
(b ValueRegs)
(ty Type))
;; select a reg base on condition.
;; very useful because in lowering stage we can not have condition branch.
(SelectReg
(rd WritableReg)
(rs1 Reg)
(rs2 Reg)
(condition IntegerCompare))
;;
(FcvtToInt
(is_sat bool)
(rd WritableReg)
Expand Down Expand Up @@ -419,13 +403,6 @@
(Trunc)
))

(type IntSelectOP (enum
(Smax)
(Umax)
(Smin)
(Umin)
))

(type AtomicOP (enum
(LrW)
(ScW)
Expand Down Expand Up @@ -2233,7 +2210,7 @@
(part1 Reg (rv_sll rs shamt))
;;
(part2 Reg (rv_srl rs len_sub_shamt))
(part3 Reg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) part2)))
(part3 Reg (gen_select_xreg (cmp_eqz shamt) (zero_reg) part2)))
(rv_or part1 part3)))


Expand Down Expand Up @@ -2282,7 +2259,7 @@
;;
(part2 XReg (rv_sll rs len_sub_shamt))
;;
(part3 XReg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) part2)))
(part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) part2)))
(rv_or part1 part3)))


Expand Down Expand Up @@ -2326,21 +2303,22 @@
(low_part1 XReg (rv_sll (value_regs_get x 0) shamt))
(low_part2 XReg (rv_srl (value_regs_get x 1) len_sub_shamt))
;;; if shamt == 0 low_part2 will overflow we should zero instead.
(low_part3 XReg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) low_part2))
(low_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) low_part2))
(low XReg (rv_or low_part1 low_part3))
;;
(high_part1 XReg (rv_sll (value_regs_get x 1) shamt))
(high_part2 XReg (rv_srl (value_regs_get x 0) len_sub_shamt))
(high_part3 XReg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) high_part2))
(high_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) high_part2))
(high XReg (rv_or high_part1 high_part3))
;;
(const64 XReg (imm $I64 64))
(shamt_128 XReg (rv_andi (value_regs_get y 0) (imm12_const 127))))
;; right now we only rotate less than 64 bits.
;; if shamt is greater than or equal 64 , we should switch low and high.
(value_regs
(gen_select_reg (IntCC.UnsignedGreaterThanOrEqual) shamt_128 const64 high low)
(gen_select_reg (IntCC.UnsignedGreaterThanOrEqual) shamt_128 const64 low high)
(gen_select_regs
(cmp_geu shamt_128 const64)
(value_regs high low)
(value_regs low high)
)))


Expand All @@ -2355,22 +2333,23 @@
(low_part1 XReg (rv_srl (value_regs_get x 0) shamt))
(low_part2 XReg (rv_sll (value_regs_get x 1) len_sub_shamt))
;;; if shamt == 0 low_part2 will overflow we should zero instead.
(low_part3 XReg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) low_part2))
(low_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) low_part2))
(low XReg (rv_or low_part1 low_part3))
;;
(high_part1 XReg (rv_srl (value_regs_get x 1) shamt))
(high_part2 XReg (rv_sll (value_regs_get x 0) len_sub_shamt))
(high_part3 XReg (gen_select_reg (IntCC.Equal) shamt (zero_reg) (zero_reg) high_part2))
(high_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) high_part2))
(high XReg (rv_or high_part1 high_part3))

;;
(const64 XReg (imm $I64 64))
(shamt_128 XReg (rv_andi (value_regs_get y 0) (imm12_const 127))))
;; right now we only rotate less than 64 bits.
;; if shamt is greater than or equal 64 , we should switch low and high.
(value_regs
(gen_select_reg (IntCC.UnsignedGreaterThanOrEqual) shamt_128 const64 high low)
(gen_select_reg (IntCC.UnsignedGreaterThanOrEqual) shamt_128 const64 low high)
(gen_select_regs
(cmp_geu shamt_128 const64)
(value_regs high low)
(value_regs low high)
)))

;; Generates a AMode that points to a register plus an offset.
Expand Down Expand Up @@ -2566,41 +2545,31 @@
(decl gen_stack_addr (StackSlot Offset32) Reg)
(extern constructor gen_stack_addr gen_stack_addr)

;;
(decl gen_select (Type Reg ValueRegs ValueRegs) ValueRegs)
(rule
(gen_select ty c x y)
(decl gen_select_xreg (IntegerCompare XReg XReg) XReg)
(rule (gen_select_xreg c x y)
(let
((dst VecWritableReg (alloc_vec_writable ty))
;;
(reuslt VecWritableReg (vec_writable_clone dst))
(_ Unit (emit (MInst.Select dst ty c x y))))
(vec_writable_to_regs reuslt)))

;; Parameters are "intcc compare_a compare_b rs1 rs2".
(decl gen_select_reg (IntCC XReg XReg Reg Reg) Reg)
(extern constructor gen_select_reg gen_select_reg)

;;; clone WritableReg
;;; if not rust compiler will complain about use moved value.
(decl vec_writable_clone (VecWritableReg) VecWritableReg)
(extern constructor vec_writable_clone vec_writable_clone)

(decl vec_writable_to_regs (VecWritableReg) ValueRegs)
(extern constructor vec_writable_to_regs vec_writable_to_regs)

(decl alloc_vec_writable (Type) VecWritableReg)
(extern constructor alloc_vec_writable alloc_vec_writable)

(decl gen_int_select (Type IntSelectOP ValueRegs ValueRegs) ValueRegs)
(rule
(gen_int_select ty op x y)
((dst WritableReg (temp_writable_xreg))
(_ Unit (emit (MInst.Select dst c x y))))
(writable_reg_to_reg dst)))
(decl gen_select_vreg (IntegerCompare VReg VReg) VReg)
(rule (gen_select_vreg c x y)
(let
((dst WritableReg (temp_writable_vreg))
(_ Unit (emit (MInst.Select dst c (vreg_to_reg x) (vreg_to_reg y)))))
(writable_reg_to_reg dst)))
(decl gen_select_freg (IntegerCompare FReg FReg) FReg)
(rule (gen_select_freg c x y)
(let
( ;;;
(dst VecWritableReg (alloc_vec_writable ty))
;;;
(_ Unit (emit (MInst.IntSelect op (vec_writable_clone dst) x y ty))))
(vec_writable_to_regs dst)))
((dst WritableReg (temp_writable_freg))
(_ Unit (emit (MInst.Select dst c (freg_to_reg x) (freg_to_reg y)))))
(writable_reg_to_reg dst)))
(decl gen_select_regs (IntegerCompare ValueRegs ValueRegs) ValueRegs)
(rule (gen_select_regs c x y)
(let
((dst1 WritableReg (temp_writable_xreg))
(dst2 WritableReg (temp_writable_xreg))
(_ Unit (emit (MInst.Select (writable_value_regs dst1 dst2) c x y))))
(value_regs dst1 dst2)))

(decl udf (TrapCode) InstOutput)
(rule
Expand Down Expand Up @@ -2665,20 +2634,6 @@
(decl int_zero_reg (Type) ValueRegs)
(extern constructor int_zero_reg int_zero_reg)

;; Convert a truthy value, possibly of more than one register (an I128), to
;; one register.
;;
;; Zero-extends as necessary to ensure that the returned register only contains
;; nonzero if the input value was logically nonzero.
(decl truthy_to_reg (Value) XReg)
(rule 1 (truthy_to_reg val @ (value_type (fits_in_64 _)))
(zext val))
(rule 0 (truthy_to_reg val @ (value_type $I128))
(let ((regs ValueRegs val)
(lo XReg (value_regs_get regs 0))
(hi XReg (value_regs_get regs 1)))
(rv_or lo hi)))

;; Consume a CmpResult, producing a branch on its result.
(decl cond_br (IntegerCompare CondBrTarget CondBrTarget) SideEffectNoResult)
(rule (cond_br cmp then else)
Expand All @@ -2698,25 +2653,102 @@
(extern constructor label_to_br_target label_to_br_target)
(convert MachLabel CondBrTarget label_to_br_target)

(decl partial lower_branch (Inst MachLabelSlice) Unit)
(rule (lower_branch (jump _) (single_target label))
(emit_side_effect (rv_j label)))
(decl cmp_eqz (XReg) IntegerCompare)
(rule (cmp_eqz r) (int_compare (IntCC.Equal) r (zero_reg)))

;; Default behavior for branching based on an input value.
(rule (lower_branch (brif v @ (value_type (fits_in_64 ty)) _ _) (two_targets then else))
(emit_side_effect (cond_br (int_compare (IntCC.NotEqual) (zext v) (zero_reg)) then else)))
(rule 2 (lower_branch (brif v @ (value_type $I128)_ _) (two_targets then else))
(emit_side_effect (cond_br (int_compare (IntCC.NotEqual) (truthy_to_reg v) (zero_reg)) then else)))
(decl cmp_nez (XReg) IntegerCompare)
(rule (cmp_nez r) (int_compare (IntCC.NotEqual) r (zero_reg)))

;; Branching on the result of an fcmp.
(rule 1 (lower_branch (brif (maybe_uextend (fcmp cc a @ (value_type ty) b)) _ _) (two_targets then else))
(emit_side_effect (cond_br (emit_fcmp cc ty a b) then else)))
(decl cmp_eq (XReg XReg) IntegerCompare)
(rule (cmp_eq rs1 rs2) (int_compare (IntCC.Equal) rs1 rs2))

(decl fcmp_to_compare (FCmp) IntegerCompare)
(rule (fcmp_to_compare (FCmp.One r)) (int_compare (IntCC.NotEqual) r (zero_reg)))
(rule (fcmp_to_compare (FCmp.Zero r)) (int_compare (IntCC.Equal) r (zero_reg)))
(convert FCmp IntegerCompare fcmp_to_compare)
(decl cmp_ne (XReg XReg) IntegerCompare)
(rule (cmp_ne rs1 rs2) (int_compare (IntCC.NotEqual) rs1 rs2))

(decl cmp_lt (XReg XReg) IntegerCompare)
(rule (cmp_lt rs1 rs2) (int_compare (IntCC.SignedLessThan) rs1 rs2))

(decl cmp_ltz (XReg) IntegerCompare)
(rule (cmp_ltz rs) (int_compare (IntCC.SignedLessThan) rs (zero_reg)))

(decl cmp_gt (XReg XReg) IntegerCompare)
(rule (cmp_gt rs1 rs2) (int_compare (IntCC.SignedGreaterThan) rs1 rs2))

(decl cmp_ge (XReg XReg) IntegerCompare)
(rule (cmp_ge rs1 rs2) (int_compare (IntCC.SignedGreaterThanOrEqual) rs1 rs2))

(decl cmp_le (XReg XReg) IntegerCompare)
(rule (cmp_le rs1 rs2) (int_compare (IntCC.SignedLessThanOrEqual) rs1 rs2))

(decl cmp_gtu (XReg XReg) IntegerCompare)
(rule (cmp_gtu rs1 rs2) (int_compare (IntCC.UnsignedGreaterThan) rs1 rs2))

(decl cmp_geu (XReg XReg) IntegerCompare)
(rule (cmp_geu rs1 rs2) (int_compare (IntCC.UnsignedGreaterThanOrEqual) rs1 rs2))

(decl cmp_ltu (XReg XReg) IntegerCompare)
(rule (cmp_ltu rs1 rs2) (int_compare (IntCC.UnsignedLessThan) rs1 rs2))

(decl cmp_leu (XReg XReg) IntegerCompare)
(rule (cmp_leu rs1 rs2) (int_compare (IntCC.UnsignedLessThanOrEqual) rs1 rs2))

;; Helper to generate an `IntegerCompare` which represents the "truthy" value of
;; the input provided.
;;
;; This is used in `Select` and `brif` for example to generate conditional
;; branches. The returned comparison, when taken, represents that `Value` is
;; nonzero. When not taken the input `Value` is zero.
(decl lower_int_compare (Value) IntegerCompare)

;; Base case - convert to a "truthy" value and compare it against zero.
;;
;; Note that non-64-bit types need to be extended since the upper bits from
;; Cranelift's point of view are undefined. Favor a zero extension for 8-bit
;; types because that's a single `andi` instruction, but favor sign-extension
;; for 16 and 32-bit types because many RISC-V which operate on the low 32-bits.
;; Additionally the base 64-bit ISA has a single instruction for sign-extending
;; from 32 to 64-bits which makes that a bit cheaper if used.
;; of registers sign-extend the results.
(rule 0 (lower_int_compare val @ (value_type (fits_in_64 _)))
(cmp_nez (sext val)))
(rule 1 (lower_int_compare val @ (value_type $I8))
(cmp_nez (zext val)))
(rule 1 (lower_int_compare val @ (value_type $I128))
(cmp_nez (rv_or (value_regs_get val 0) (value_regs_get val 1))))

;; If the input value is itself an `icmp` we can avoid generating the result of
;; the `icmp` and instead move the comparison directly into the `IntegerCompare`
;; that's returned. Note that comparisons compare full registers so
;; sign-extension according to the integer comparison performed here is
;; required.
;;
;; Also note that as a small optimization `Equal` and `NotEqual` use
;; sign-extension for 32-bit values since the same result is produced with
;; either zero-or-sign extension and many values are already sign-extended given
;; the RV64 instruction set (e.g. `addw` adds 32-bit values and sign extends),
;; theoretically resulting in more efficient codegen.
(rule 2 (lower_int_compare (maybe_uextend (icmp cc a b @ (value_type (fits_in_64 in_ty)))))
(int_compare cc (zext a) (zext b)))
(rule 3 (lower_int_compare (maybe_uextend (icmp cc a b @ (value_type (fits_in_64 in_ty)))))
(if (signed_cond_code cc))
(int_compare cc (sext a) (sext b)))
(rule 4 (lower_int_compare (maybe_uextend (icmp cc @ (IntCC.Equal) a b @ (value_type $I32))))
(int_compare cc (sext a) (sext b)))
(rule 4 (lower_int_compare (maybe_uextend (icmp cc @ (IntCC.NotEqual) a b @ (value_type $I32))))
(int_compare cc (sext a) (sext b)))

;; If the input is an `fcmp` then the `FCmp` return value is directly
;; convertible to `IntegerCompare` which can shave off an instruction from the
;; fallback lowering above.
(rule 2 (lower_int_compare (maybe_uextend (fcmp cc a @ (value_type ty) b)))
(emit_fcmp cc ty a b))

(decl partial lower_branch (Inst MachLabelSlice) Unit)
(rule (lower_branch (jump _) (single_target label))
(emit_side_effect (rv_j label)))

(rule (lower_branch (brif v _ _) (two_targets then else))
(emit_side_effect (cond_br (lower_int_compare v) then else)))

(decl lower_br_table (Reg MachLabelSlice) Unit)
(extern constructor lower_br_table lower_br_table)
Expand Down Expand Up @@ -2802,7 +2834,7 @@

(rule (max (fits_in_64 (ty_int ty)) x y)
(if-let $false (has_zbb))
(gen_select_reg (IntCC.SignedGreaterThan) x y x y))
(gen_select_xreg (cmp_gt x y) x y))


;; Builds an instruction sequence that traps if the comparision succeeds.
Expand Down Expand Up @@ -2840,8 +2872,11 @@

;; Generates either 0 if `Value` is zero or -1 otherwise.
(decl gen_bmask (Value) XReg)
(rule (gen_bmask val)
(let ((non_zero XReg (rv_snez (truthy_to_reg val))))
(rule 0 (gen_bmask val @ (value_type (fits_in_64 _)))
(let ((non_zero XReg (rv_snez (sext val))))
(rv_neg non_zero)))
(rule 1 (gen_bmask val @ (value_type $I128))
(let ((non_zero XReg (rv_snez (rv_or (value_regs_get val 0) (value_regs_get val 1)))))
(rv_neg non_zero)))

(decl lower_bmask (Value Type) ValueRegs)
Expand Down Expand Up @@ -2898,6 +2933,11 @@
(rule (fcmp_invert (FCmp.One r)) (FCmp.Zero r))
(rule (fcmp_invert (FCmp.Zero r)) (FCmp.One r))

(decl fcmp_to_compare (FCmp) IntegerCompare)
(rule (fcmp_to_compare (FCmp.One r)) (cmp_nez r))
(rule (fcmp_to_compare (FCmp.Zero r)) (cmp_eqz r))
(convert FCmp IntegerCompare fcmp_to_compare)

;; Compare two floating point numbers and return a zero/non-zero result.
(decl emit_fcmp (FloatCC Type FReg FReg) FCmp)

Expand Down
Loading

0 comments on commit d4e4f61

Please sign in to comment.