Skip to content

Commit

Permalink
riscv64: Refactor and improve some rotate-related codegen (#7251)
Browse files Browse the repository at this point in the history
* riscv64: Refactor `rotl` rules

Move from `inst.isle` to `lower.isle` since it's the only caller,
reorganize the rules to be a bit cleaner, add immediate shifting
specializations.

* riscv64: Refactor `rotr` lowerings

Same as the prior `rotl` lowerings, move the rules to `lower.isle` and
additionally add constant rules.

* Fix shift-by-128

* Remove empty comments
  • Loading branch information
alexcrichton authored Oct 16, 2023
1 parent f24abd2 commit fe7f060
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 157 deletions.
160 changes: 10 additions & 150 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1645,6 +1645,16 @@
(rule (rv_rorw rs1 rs2)
(alu_rrr (AluOPRRR.Rorw) rs1 rs2))

;; Helper for emitting the `rori` ("Rotate Right") instruction.
(decl rv_rori (XReg Imm12) XReg)
(rule (rv_rori rs1 rs2)
(alu_rr_imm12 (AluOPRRI.Rori) rs1 rs2))

;; Helper for emitting the `roriw` ("Rotate Right Word") instruction.
(decl rv_roriw (XReg Imm12) XReg)
(rule (rv_roriw rs1 rs2)
(alu_rr_imm12 (AluOPRRI.Roriw) rs1 rs2))

;; Helper for emitting the `rev8` ("Byte Reverse") instruction.
(decl rv_rev8 (XReg) XReg)
(rule (rv_rev8 rs1)
Expand Down Expand Up @@ -2186,101 +2196,12 @@
((tmp XReg (rv_mul rs1 rs2)))
(rv_srli tmp (imm12_const (ty_bits ty)))))


(decl lower_rotl (Type XReg XReg) XReg)

(rule 1
(lower_rotl $I64 rs amount)
(if-let $true (has_zbb))
(rv_rol rs amount))

(rule
(lower_rotl $I64 rs amount)
(if-let $false (has_zbb))
(lower_rotl_shift $I64 rs amount))

(rule 1
(lower_rotl $I32 rs amount)
(if-let $true (has_zbb))
(rv_rolw rs amount))

(rule
(lower_rotl $I32 rs amount)
(if-let $false (has_zbb))
(lower_rotl_shift $I32 rs amount))

(rule -1
(lower_rotl ty rs amount)
(lower_rotl_shift ty rs amount))

;;; using shift to implement rotl.
(decl lower_rotl_shift (Type XReg XReg) XReg)

;;; for I8 and I16 ...
(rule
(lower_rotl_shift ty rs amount)
(let
((x ValueRegs (gen_shamt ty amount))
(shamt Reg (value_regs_get x 0))
(len_sub_shamt Reg (value_regs_get x 1))
;;
(part1 Reg (rv_sll rs shamt))
;;
(part2 Reg (rv_srl rs len_sub_shamt))
(part3 Reg (gen_select_xreg (cmp_eqz shamt) (zero_reg) part2)))
(rv_or part1 part3)))


;;;; construct shift amount.rotl on i128 will use shift to implement. So can call this function.
;;;; this will return shift amount and (ty_bits - "shift amount")
;;;; if ty_bits is greater than 64 like i128, then shmat will fallback to 64.because We are 64 bit platform.
(decl gen_shamt (Type XReg) ValueRegs)
(extern constructor gen_shamt gen_shamt)

(decl lower_rotr (Type XReg XReg) XReg)

(rule 1
(lower_rotr $I64 rs amount)
(if-let $true (has_zbb))
(rv_ror rs amount))
(rule
(lower_rotr $I64 rs amount)
(if-let $false (has_zbb))
(lower_rotr_shift $I64 rs amount))

(rule 1
(lower_rotr $I32 rs amount)
(if-let $true (has_zbb))
(rv_rorw rs amount))

(rule
(lower_rotr $I32 rs amount)
(if-let $false (has_zbb))
(lower_rotr_shift $I32 rs amount))

(rule -1
(lower_rotr ty rs amount)
(lower_rotr_shift ty rs amount))

(decl lower_rotr_shift (Type XReg XReg) XReg)

;;;
(rule
(lower_rotr_shift ty rs amount)
(let
((x ValueRegs (gen_shamt ty amount))
(shamt XReg (value_regs_get x 0))
(len_sub_shamt XReg (value_regs_get x 1))
;;
(part1 XReg (rv_srl rs shamt))
;;
(part2 XReg (rv_sll rs len_sub_shamt))
;;
(part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) part2)))
(rv_or part1 part3)))



;; bseti: Set a single bit in a register, indexed by a constant.
(decl gen_bseti (Reg u64) Reg)
(rule (gen_bseti val bit)
Expand Down Expand Up @@ -2308,67 +2229,6 @@
(_ Unit (emit (MInst.Popcnt sum step tmp rs $I64))))
(writable_reg_to_reg sum)))


(decl lower_i128_rotl (ValueRegs ValueRegs) ValueRegs)
(rule
(lower_i128_rotl x y)
(let
((tmp ValueRegs (gen_shamt $I128 (value_regs_get y 0)))
(shamt XReg (value_regs_get tmp 0))
(len_sub_shamt XReg (value_regs_get tmp 1))
;;
(low_part1 XReg (rv_sll (value_regs_get x 0) shamt))
(low_part2 XReg (rv_srl (value_regs_get x 1) len_sub_shamt))
;;; if shamt == 0 low_part2 will overflow we should zero instead.
(low_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) low_part2))
(low XReg (rv_or low_part1 low_part3))
;;
(high_part1 XReg (rv_sll (value_regs_get x 1) shamt))
(high_part2 XReg (rv_srl (value_regs_get x 0) len_sub_shamt))
(high_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) high_part2))
(high XReg (rv_or high_part1 high_part3))
;;
(const64 XReg (imm $I64 64))
(shamt_128 XReg (rv_andi (value_regs_get y 0) (imm12_const 127))))
;; right now we only rotate less than 64 bits.
;; if shamt is greater than or equal 64 , we should switch low and high.
(gen_select_regs
(cmp_geu shamt_128 const64)
(value_regs high low)
(value_regs low high)
)))


(decl lower_i128_rotr (ValueRegs ValueRegs) ValueRegs)
(rule
(lower_i128_rotr x y)
(let
((tmp ValueRegs (gen_shamt $I128 (value_regs_get y 0)))
(shamt XReg (value_regs_get tmp 0))
(len_sub_shamt XReg (value_regs_get tmp 1))
;;
(low_part1 XReg (rv_srl (value_regs_get x 0) shamt))
(low_part2 XReg (rv_sll (value_regs_get x 1) len_sub_shamt))
;;; if shamt == 0 low_part2 will overflow we should zero instead.
(low_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) low_part2))
(low XReg (rv_or low_part1 low_part3))
;;
(high_part1 XReg (rv_srl (value_regs_get x 1) shamt))
(high_part2 XReg (rv_sll (value_regs_get x 0) len_sub_shamt))
(high_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) high_part2))
(high XReg (rv_or high_part1 high_part3))

;;
(const64 XReg (imm $I64 64))
(shamt_128 XReg (rv_andi (value_regs_get y 0) (imm12_const 127))))
;; right now we only rotate less than 64 bits.
;; if shamt is greater than or equal 64 , we should switch low and high.
(gen_select_regs
(cmp_geu shamt_128 const64)
(value_regs high low)
(value_regs low high)
)))

;; Generates a AMode that points to a register plus an offset.
(decl gen_reg_offset_amode (Reg i64 Type) AMode)
(extern constructor gen_reg_offset_amode gen_reg_offset_amode)
Expand Down
109 changes: 102 additions & 7 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1219,19 +1219,114 @@


;;;; Rules for `rotl` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type (fits_in_64 ty) (rotl x y)))
(lower_rotl ty (zext x) (value_regs_get y 0)))

(rule 0 (lower (has_type (fits_in_64 ty) (rotl rs amount)))
(let
((rs XReg (zext rs))
(amount XReg (value_regs_get amount 0))
(x ValueRegs (gen_shamt ty amount))
(shamt XReg (value_regs_get x 0))
(len_sub_shamt Reg (value_regs_get x 1))
(part1 Reg (rv_sll rs shamt))
(part2 Reg (rv_srl rs len_sub_shamt))
(part3 Reg (gen_select_xreg (cmp_eqz shamt) (zero_reg) part2)))
(rv_or part1 part3)))

(rule 1 (lower (has_type $I32 (rotl rs amount)))
(if-let $true (has_zbb))
(rv_rolw rs (value_regs_get amount 0)))

(rule 2 (lower (has_type $I32 (rotl rs (u64_from_iconst n))))
(if-let $true (has_zbb))
(if-let (imm12_from_u64 imm) (u64_sub 32 (u64_and n 31)))
(rv_roriw rs imm))

(rule 1 (lower (has_type $I64 (rotl rs amount)))
(if-let $true (has_zbb))
(rv_rol rs (value_regs_get amount 0)))

(rule 2 (lower (has_type $I64 (rotl rs (u64_from_iconst n))))
(if-let $true (has_zbb))
(if-let (imm12_from_u64 imm) (u64_sub 64 (u64_and n 63)))
(rv_rori rs imm))

(rule 1 (lower (has_type $I128 (rotl x y)))
(lower_i128_rotl x y))
(let
((tmp ValueRegs (gen_shamt $I128 (value_regs_get y 0)))
(shamt XReg (value_regs_get tmp 0))
(len_sub_shamt XReg (value_regs_get tmp 1))
(low_part1 XReg (rv_sll (value_regs_get x 0) shamt))
(low_part2 XReg (rv_srl (value_regs_get x 1) len_sub_shamt))
;;; if shamt == 0 low_part2 will overflow we should zero instead.
(low_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) low_part2))
(low XReg (rv_or low_part1 low_part3))
(high_part1 XReg (rv_sll (value_regs_get x 1) shamt))
(high_part2 XReg (rv_srl (value_regs_get x 0) len_sub_shamt))
(high_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) high_part2))
(high XReg (rv_or high_part1 high_part3))
(const64 XReg (imm $I64 64))
(shamt_128 XReg (rv_andi (value_regs_get y 0) (imm12_const 127))))
;; right now we only rotate less than 64 bits.
;; if shamt is greater than or equal 64 , we should switch low and high.
(gen_select_regs
(cmp_geu shamt_128 const64)
(value_regs high low)
(value_regs low high)
)))

;;;; Rules for `rotr` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type (fits_in_64 ty) (rotr x y)))
(lower_rotr ty (zext x) (value_regs_get y 0)))

(rule 1 (lower (has_type $I128 (rotr x y)))
(lower_i128_rotr x y))
(rule (lower (has_type (fits_in_64 ty) (rotr rs amount)))
(let
((rs XReg (zext rs))
(amount XReg (value_regs_get amount 0))
(x ValueRegs (gen_shamt ty amount))
(shamt XReg (value_regs_get x 0))
(len_sub_shamt XReg (value_regs_get x 1))
(part1 XReg (rv_srl rs shamt))
(part2 XReg (rv_sll rs len_sub_shamt))
(part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) part2)))
(rv_or part1 part3)))

(rule 1 (lower (has_type $I32 (rotr rs amount)))
(if-let $true (has_zbb))
(rv_rorw rs (value_regs_get amount 0)))

(rule 2 (lower (has_type $I32 (rotr rs (imm12_from_value n))))
(if-let $true (has_zbb))
(rv_roriw rs n))

(rule 1 (lower (has_type $I64 (rotr rs amount)))
(if-let $true (has_zbb))
(rv_ror rs (value_regs_get amount 0)))

(rule 2 (lower (has_type $I64 (rotr rs (imm12_from_value n))))
(if-let $true (has_zbb))
(rv_rori rs n))

(rule 1 (lower (has_type $I128 (rotr x y)))
(let
((tmp ValueRegs (gen_shamt $I128 (value_regs_get y 0)))
(shamt XReg (value_regs_get tmp 0))
(len_sub_shamt XReg (value_regs_get tmp 1))
(low_part1 XReg (rv_srl (value_regs_get x 0) shamt))
(low_part2 XReg (rv_sll (value_regs_get x 1) len_sub_shamt))
;;; if shamt == 0 low_part2 will overflow we should zero instead.
(low_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) low_part2))
(low XReg (rv_or low_part1 low_part3))
(high_part1 XReg (rv_srl (value_regs_get x 1) shamt))
(high_part2 XReg (rv_sll (value_regs_get x 0) len_sub_shamt))
(high_part3 XReg (gen_select_xreg (cmp_eqz shamt) (zero_reg) high_part2))
(high XReg (rv_or high_part1 high_part3))
(const64 XReg (imm $I64 64))
(shamt_128 XReg (rv_andi (value_regs_get y 0) (imm12_const 127))))
;; right now we only rotate less than 64 bits.
;; if shamt is greater than or equal 64 , we should switch low and high.
(gen_select_regs
(cmp_geu shamt_128 const64)
(value_regs high low)
(value_regs low high)
)))

;;;; Rules for `fabs` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule 0 (lower (has_type (ty_scalar_float ty) (fabs x)))
Expand Down
79 changes: 79 additions & 0 deletions cranelift/filetests/filetests/isa/riscv64/wasm/zbb.wat
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
;;! target = "riscv64"
;;! compile = true
;;! settings = ["has_zbb", "opt_level=speed"]

(module
(func (export "rolw") (param i32 i32) (result i32)
(i32.rotl (local.get 0) (local.get 1)))
(func (export "rol") (param i64 i64) (result i64)
(i64.rotl (local.get 0) (local.get 1)))
(func (export "rolwi") (param i32 ) (result i32)
(i32.rotl (local.get 0) (i32.const 100)))
(func (export "roli") (param i64) (result i64)
(i64.rotl (local.get 0) (i64.const 40)))

(func (export "rorw") (param i32 i32) (result i32)
(i32.rotr (local.get 0) (local.get 1)))
(func (export "ror") (param i64 i64) (result i64)
(i64.rotr (local.get 0) (local.get 1)))
(func (export "rorwi") (param i32 ) (result i32)
(i32.rotr (local.get 0) (i32.const 100)))
(func (export "rori") (param i64) (result i64)
(i64.rotr (local.get 0) (i64.const 40)))
)

;; function u0:0:
;; block0:
;; j label1
;; block1:
;; rolw a0,a0,a1
;; ret
;;
;; function u0:1:
;; block0:
;; j label1
;; block1:
;; rol a0,a0,a1
;; ret
;;
;; function u0:2:
;; block0:
;; j label1
;; block1:
;; roriw a0,a0,28
;; ret
;;
;; function u0:3:
;; block0:
;; j label1
;; block1:
;; rori a0,a0,24
;; ret
;;
;; function u0:4:
;; block0:
;; j label1
;; block1:
;; rorw a0,a0,a1
;; ret
;;
;; function u0:5:
;; block0:
;; j label1
;; block1:
;; ror a0,a0,a1
;; ret
;;
;; function u0:6:
;; block0:
;; j label1
;; block1:
;; roriw a0,a0,100
;; ret
;;
;; function u0:7:
;; block0:
;; j label1
;; block1:
;; rori a0,a0,40
;; ret

0 comments on commit fe7f060

Please sign in to comment.