From 996778272685e253374798f0c2eb1009f3876adf Mon Sep 17 00:00:00 2001 From: Nick Fitzgerald Date: Tue, 15 Nov 2022 09:18:55 -0800 Subject: [PATCH] Cranelift(Aarch64): Optimize lowering of `icmp`s with immediates (#5252) We can encode more constants into 12-bit immediates if we do the following rewrite for comparisons with odd constants: A >= B + 1 ==> A - 1 >= B ==> A > B --- cranelift/codegen/src/isa/aarch64/inst.isle | 100 +++++++++++----- cranelift/codegen/src/isa/aarch64/lower.isle | 60 ++++++---- .../codegen/src/isa/aarch64/lower/isle.rs | 4 + cranelift/codegen/src/isle_prelude.rs | 5 + cranelift/codegen/src/prelude.isle | 3 + .../filetests/isa/aarch64/heap_addr.clif | 4 +- .../filetests/isa/aarch64/icmp-const.clif | 111 ++++++++++++++++++ 7 files changed, 236 insertions(+), 51 deletions(-) create mode 100644 cranelift/filetests/filetests/isa/aarch64/icmp-const.clif diff --git a/cranelift/codegen/src/isa/aarch64/inst.isle b/cranelift/codegen/src/isa/aarch64/inst.isle index 3debb844ac1a..5f08f6838760 100644 --- a/cranelift/codegen/src/isa/aarch64/inst.isle +++ b/cranelift/codegen/src/isa/aarch64/inst.isle @@ -1614,6 +1614,9 @@ (decl imm12_from_u64 (Imm12) u64) (extern extractor imm12_from_u64 imm12_from_u64) +(decl pure make_imm12_from_u64 (u64) Imm12) +(extern constructor make_imm12_from_u64 make_imm12_from_u64) + (decl u8_into_uimm5 (u8) UImm5) (extern constructor u8_into_uimm5 u8_into_uimm5) @@ -3401,12 +3404,34 @@ (_ Unit (emit (MInst.ElfTlsGetAddr name dst)))) dst)) +;; A tuple of `ProducesFlags` and `IntCC`. +(type FlagsAndCC (enum (FlagsAndCC (flags ProducesFlags) + (cc IntCC)))) + +;; Helper constructor for `FlagsAndCC`. +(decl flags_and_cc (ProducesFlags IntCC) FlagsAndCC) +(rule (flags_and_cc flags cc) (FlagsAndCC.FlagsAndCC flags cc)) + +;; Materialize a `FlagsAndCC` into a boolean `ValueRegs`. 
+(decl flags_and_cc_to_bool (FlagsAndCC) ValueRegs) +(rule (flags_and_cc_to_bool (FlagsAndCC.FlagsAndCC flags cc)) + (with_flags flags (materialize_bool_result (cond_code cc)))) + +;; Get the `ProducesFlags` out of a `FlagsAndCC`. +(decl flags_and_cc_flags (FlagsAndCC) ProducesFlags) +(rule (flags_and_cc_flags (FlagsAndCC.FlagsAndCC flags _cc)) flags) + +;; Get the `IntCC` out of a `FlagsAndCC`. +(decl flags_and_cc_cc (FlagsAndCC) IntCC) +(rule (flags_and_cc_cc (FlagsAndCC.FlagsAndCC _flags cc)) cc) + ;; Helpers for lowering `icmp` sequences. ;; `lower_icmp` contains shared functionality for lowering `icmp` ;; sequences, which `lower_icmp_into_{reg,flags}` extend from. -(decl lower_icmp (IntCC Value Value Type) ProducesFlags) +(decl lower_icmp (IntCC Value Value Type) FlagsAndCC) (decl lower_icmp_into_reg (IntCC Value Value Type Type) ValueRegs) -(decl lower_icmp_into_flags (IntCC Value Value Type) ProducesFlags) +(decl lower_icmp_into_flags (IntCC Value Value Type) FlagsAndCC) +(decl lower_icmp_const (IntCC Value u64 Type) FlagsAndCC) ;; For most cases, `lower_icmp_into_flags` is the same as `lower_icmp`, ;; except for some I128 cases (see below). 
(rule -1 (lower_icmp_into_flags cond x y ty) (lower_icmp cond x y ty)) @@ -3430,38 +3455,59 @@ (rule -2 (lower_icmp_into_reg cond rn rm in_ty out_ty) (if (ty_int_ref_scalar_64 in_ty)) (let ((cc Cond (cond_code cond))) - (with_flags - (lower_icmp cond rn rm in_ty) - (materialize_bool_result cc)))) + (flags_and_cc_to_bool (lower_icmp cond rn rm in_ty)))) (rule 1 (lower_icmp cond rn rm (fits_in_16 ty)) (if (signed_cond_code cond)) (let ((rn Reg (put_in_reg_sext32 rn))) - (cmp_extend (operand_size ty) rn rm (lower_icmp_extend ty $true)))) + (flags_and_cc (cmp_extend (operand_size ty) rn rm (lower_icmp_extend ty $true)) cond))) (rule -1 (lower_icmp cond rn (imm12_from_value rm) (fits_in_16 ty)) (let ((rn Reg (put_in_reg_zext32 rn))) - (cmp_imm (operand_size ty) rn rm))) + (flags_and_cc (cmp_imm (operand_size ty) rn rm) cond))) (rule -2 (lower_icmp cond rn rm (fits_in_16 ty)) (let ((rn Reg (put_in_reg_zext32 rn))) - (cmp_extend (operand_size ty) rn rm (lower_icmp_extend ty $false)))) -(rule -3 (lower_icmp cond rn (imm12_from_value rm) ty) + (flags_and_cc (cmp_extend (operand_size ty) rn rm (lower_icmp_extend ty $false)) cond))) +(rule -3 (lower_icmp cond rn (u64_from_iconst c) ty) (if (ty_int_ref_scalar_64 ty)) - (cmp_imm (operand_size ty) rn rm)) + (lower_icmp_const cond rn c ty)) (rule -4 (lower_icmp cond rn rm ty) (if (ty_int_ref_scalar_64 ty)) - (cmp (operand_size ty) rn rm)) + (flags_and_cc (cmp (operand_size ty) rn rm) cond)) + +;; We get better encodings when testing against an immediate that's even instead +;; of odd, so rewrite comparisons to use even immediates: +;; +;; A >= B + 1 +;; ==> A - 1 >= B +;; ==> A > B +(rule (lower_icmp_const (IntCC.UnsignedGreaterThanOrEqual) a b ty) + (if (ty_int_ref_scalar_64 ty)) + (if-let $true (u64_is_odd b)) + (if-let imm (make_imm12_from_u64 (u64_sub b 1))) + (flags_and_cc (cmp_imm (operand_size ty) a imm) (IntCC.UnsignedGreaterThan))) +(rule (lower_icmp_const (IntCC.SignedGreaterThanOrEqual) a b ty) + (if 
(ty_int_ref_scalar_64 ty)) + (if-let $true (u64_is_odd b)) + (if-let imm (make_imm12_from_u64 (u64_sub b 1))) + (flags_and_cc (cmp_imm (operand_size ty) a imm) (IntCC.SignedGreaterThan))) + +(rule -1 (lower_icmp_const cond rn (imm12_from_u64 c) ty) + (if (ty_int_ref_scalar_64 ty)) + (flags_and_cc (cmp_imm (operand_size ty) rn c) cond)) +(rule -2 (lower_icmp_const cond rn c ty) + (if (ty_int_ref_scalar_64 ty)) + (flags_and_cc (cmp (operand_size ty) rn (imm ty (ImmExtend.Zero) c)) cond)) + ;; 128-bit integers. (rule (lower_icmp_into_reg cond @ (IntCC.Equal) rn rm $I128 $I8) (let ((cc Cond (cond_code cond))) - (with_flags - (lower_icmp cond rn rm $I128) - (materialize_bool_result cc)))) + (flags_and_cc_to_bool + (lower_icmp cond rn rm $I128)))) (rule (lower_icmp_into_reg cond @ (IntCC.NotEqual) rn rm $I128 $I8) (let ((cc Cond (cond_code cond))) - (with_flags - (lower_icmp cond rn rm $I128) - (materialize_bool_result cc)))) + (flags_and_cc_to_bool + (lower_icmp cond rn rm $I128)))) ;; cmp lhs_lo, rhs_lo ;; ccmp lhs_hi, rhs_hi, #0, eq @@ -3478,9 +3524,9 @@ (nzcv $false $false $false $false) (Cond.Eq) cmp_inst))) (rule (lower_icmp (IntCC.Equal) lhs rhs $I128) - (lower_icmp_i128_eq_ne lhs rhs)) + (flags_and_cc (lower_icmp_i128_eq_ne lhs rhs) (IntCC.Equal))) (rule (lower_icmp (IntCC.NotEqual) lhs rhs $I128) - (lower_icmp_i128_eq_ne lhs rhs)) + (flags_and_cc (lower_icmp_i128_eq_ne lhs rhs) (IntCC.NotEqual))) ;; cmp lhs_lo, rhs_lo ;; cset tmp1, unsigned_cond @@ -3564,39 +3610,39 @@ (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0)) (tmp Reg (imm $I64 (ImmExtend.Sign) 1))) ;; mov tmp, #1 - (cmp (OperandSize.Size64) dst tmp))) + (flags_and_cc (cmp (OperandSize.Size64) dst tmp) cond))) (rule (lower_icmp_into_flags cond @ (IntCC.UnsignedGreaterThanOrEqual) lhs rhs $I128) (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0)) (tmp Reg (imm $I64 (ImmExtend.Zero) 1))) - (cmp 
(OperandSize.Size64) dst tmp))) + (flags_and_cc (cmp (OperandSize.Size64) dst tmp) cond))) (rule (lower_icmp_into_flags cond @ (IntCC.SignedLessThanOrEqual) lhs rhs $I128) (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0)) (tmp Reg (imm $I64 (ImmExtend.Sign) 1))) - (cmp (OperandSize.Size64) tmp dst))) + (flags_and_cc (cmp (OperandSize.Size64) tmp dst) cond))) (rule (lower_icmp_into_flags cond @ (IntCC.UnsignedLessThanOrEqual) lhs rhs $I128) (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0)) (tmp Reg (imm $I64 (ImmExtend.Zero) 1))) - (cmp (OperandSize.Size64) tmp dst))) + (flags_and_cc (cmp (OperandSize.Size64) tmp dst) cond))) ;; For strict comparisons, we compare with 0. (rule (lower_icmp_into_flags cond @ (IntCC.SignedGreaterThan) lhs rhs $I128) (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0))) - (cmp (OperandSize.Size64) dst (zero_reg)))) + (flags_and_cc (cmp (OperandSize.Size64) dst (zero_reg)) cond))) (rule (lower_icmp_into_flags cond @ (IntCC.UnsignedGreaterThan) lhs rhs $I128) (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0))) - (cmp (OperandSize.Size64) dst (zero_reg)))) + (flags_and_cc (cmp (OperandSize.Size64) dst (zero_reg)) cond))) (rule (lower_icmp_into_flags cond @ (IntCC.SignedLessThan) lhs rhs $I128) (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0))) - (cmp (OperandSize.Size64) (zero_reg) dst))) + (flags_and_cc (cmp (OperandSize.Size64) (zero_reg) dst) cond))) (rule (lower_icmp_into_flags cond @ (IntCC.UnsignedLessThan) lhs rhs $I128) (let ((dst ValueRegs (lower_icmp_into_reg cond lhs rhs $I128 $I8)) (dst Reg (value_regs_get dst 0))) - (cmp (OperandSize.Size64) (zero_reg) dst))) + (flags_and_cc (cmp (OperandSize.Size64) (zero_reg) dst) cond))) ;; Helpers for generating select instruction sequences. 
(decl lower_select (ProducesFlags Cond Type Value Value) ValueRegs) diff --git a/cranelift/codegen/src/isa/aarch64/lower.isle b/cranelift/codegen/src/isa/aarch64/lower.isle index ad143613761b..24fabee6a2ca 100644 --- a/cranelift/codegen/src/isa/aarch64/lower.isle +++ b/cranelift/codegen/src/isa/aarch64/lower.isle @@ -1695,14 +1695,22 @@ ;;;; Rules for `select` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; (rule (lower (has_type ty - (select _flags @ (icmp cc x @ (value_type in_ty) y) rn rm))) - (let ((cond Cond (cond_code cc))) - (lower_select - (lower_icmp_into_flags cc x y in_ty) - cond ty rn rm))) + (select (icmp cc + x @ (value_type in_ty) + y) + rn + rm))) + (let ((comparison FlagsAndCC (lower_icmp_into_flags cc x y in_ty))) + (lower_select (flags_and_cc_flags comparison) + (cond_code (flags_and_cc_cc comparison)) + ty + rn + rm))) (rule (lower (has_type ty - (select _flags @ (fcmp cc x @ (value_type in_ty) y) rn rm))) + (select (fcmp cc x @ (value_type in_ty) y) + rn + rm))) (let ((cond Cond (fp_cond_code cc))) (lower_select (fpu_cmp (scalar_size in_ty) x y) @@ -1729,12 +1737,16 @@ ;;;; Rules for `select_spectre_guard` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; (rule (lower (has_type ty - (select_spectre_guard - (icmp cc x @ (value_type in_ty) y) if_true if_false))) - (let ((cond Cond (cond_code cc)) + (select_spectre_guard (icmp cc x @ (value_type in_ty) y) + if_true + if_false))) + (let ((comparison FlagsAndCC (lower_icmp_into_flags cc x y in_ty)) (dst ValueRegs (lower_select - (lower_icmp_into_flags cc x y in_ty) - cond ty if_true if_false)) + (flags_and_cc_flags comparison) + (cond_code (flags_and_cc_cc comparison)) + ty + if_true + if_false)) (_ InstOutput (side_effect (csdb)))) dst)) @@ -2381,23 +2393,27 @@ ;; `brz` following `icmp` (rule (lower_branch (brz (icmp cc x @ (value_type ty) y) _ _) targets) - (let ((cond Cond (cond_code cc)) - (cond Cond (invert_cond cond)) ;; negate for `brz` + (let ((comparison FlagsAndCC (lower_icmp_into_flags 
cc x y ty)) + ;; Negate the condition for `brz`. + (cond Cond (invert_cond (cond_code (flags_and_cc_cc comparison)))) (taken BranchTarget (branch_target targets 0)) (not_taken BranchTarget (branch_target targets 1))) - (side_effect - (with_flags_side_effect (lower_icmp_into_flags cc x y ty) - (cond_br taken not_taken - (cond_br_cond cond)))))) + (side_effect + (with_flags_side_effect (flags_and_cc_flags comparison) + (cond_br taken + not_taken + (cond_br_cond cond)))))) ;; `brnz` following `icmp` (rule (lower_branch (brnz (icmp cc x @ (value_type ty) y) _ _) targets) - (let ((cond Cond (cond_code cc)) + (let ((comparison FlagsAndCC (lower_icmp_into_flags cc x y ty)) + (cond Cond (cond_code (flags_and_cc_cc comparison))) (taken BranchTarget (branch_target targets 0)) (not_taken BranchTarget (branch_target targets 1))) - (side_effect - (with_flags_side_effect (lower_icmp_into_flags cc x y ty) - (cond_br taken not_taken - (cond_br_cond cond)))))) + (side_effect + (with_flags_side_effect (flags_and_cc_flags comparison) + (cond_br taken + not_taken + (cond_br_cond cond)))))) ;; `brz` following `fcmp` (rule (lower_branch (brz (fcmp cc x @ (value_type (ty_scalar_float ty)) y) _ _) targets) (let ((cond Cond (fp_cond_code cc)) diff --git a/cranelift/codegen/src/isa/aarch64/lower/isle.rs b/cranelift/codegen/src/isa/aarch64/lower/isle.rs index e47e3fc6b177..5d7ea9b10a1a 100644 --- a/cranelift/codegen/src/isa/aarch64/lower/isle.rs +++ b/cranelift/codegen/src/isa/aarch64/lower/isle.rs @@ -129,6 +129,10 @@ impl Context for IsleContext<'_, '_, MInst, Flags, IsaFlags, 6> { Imm12::maybe_from_u64(n) } + fn make_imm12_from_u64(&mut self, n: u64) -> Option<Imm12> { + Imm12::maybe_from_u64(n) + } + fn imm12_from_negated_u64(&mut self, n: u64) -> Option<Imm12> { Imm12::maybe_from_u64((n as i64).wrapping_neg() as u64) } diff --git a/cranelift/codegen/src/isle_prelude.rs b/cranelift/codegen/src/isle_prelude.rs index f5d6ee1b2907..bdee39cbd86c 100644 --- a/cranelift/codegen/src/isle_prelude.rs +++ 
b/cranelift/codegen/src/isle_prelude.rs @@ -90,6 +90,11 @@ macro_rules! isle_common_prelude_methods { 0 == value } + #[inline] + fn u64_is_odd(&mut self, x: u64) -> Option<bool> { + Some(x & 1 == 1) + } + #[inline] fn u64_sextend_u32(&mut self, x: u64) -> Option<u64> { Some(x as u32 as i32 as i64 as u64) diff --git a/cranelift/codegen/src/prelude.isle b/cranelift/codegen/src/prelude.isle index c469f9fbde16..ca9a307d28d8 100644 --- a/cranelift/codegen/src/prelude.isle +++ b/cranelift/codegen/src/prelude.isle @@ -137,6 +137,9 @@ (decl u64_nonzero (u64) u64) (extractor (u64_nonzero x) (and (u64_is_zero $false) x)) +(decl pure u64_is_odd (u64) bool) +(extern constructor u64_is_odd u64_is_odd) + ;;;; `cranelift_codegen::ir::Type` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; (extern const $I8 Type) diff --git a/cranelift/filetests/filetests/isa/aarch64/heap_addr.clif b/cranelift/filetests/filetests/isa/aarch64/heap_addr.clif index b3b1b8ce335a..4e735fc8598a 100644 --- a/cranelift/filetests/filetests/isa/aarch64/heap_addr.clif +++ b/cranelift/filetests/filetests/isa/aarch64/heap_addr.clif @@ -83,13 +83,13 @@ block0(v0: i64, v1: i32): ; block0: ; mov w9, w1 -; movz x10, #65512 +; movz w10, #65512 ; subs xzr, x9, x10 ; b.ls label1 ; b label2 ; block1: ; add x11, x0, x1, UXTW ; add x11, x11, #16 -; movz x10, #65512 +; movz w10, #65512 ; movz x12, #0 ; subs xzr, x9, x10 ; csel x0, x12, x11, hi diff --git a/cranelift/filetests/filetests/isa/aarch64/icmp-const.clif b/cranelift/filetests/filetests/isa/aarch64/icmp-const.clif new file mode 100644 index 000000000000..d48e8c5019cf --- /dev/null +++ b/cranelift/filetests/filetests/isa/aarch64/icmp-const.clif @@ -0,0 +1,111 @@ +;; Test our lowerings that do things like `A >= B + 1 ==> A > B` to make better +;; use of immediate encodings. 
+ +test compile precise-output +set unwind_info=false +target aarch64 + +function %a(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111001 + v2 = icmp.i32 uge v0, v1 + return v2 +} + +; block0: +; subs wzr, w0, #1118208 +; cset x0, hi +; ret + +function %b(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111000 + v2 = icmp.i32 uge v0, v1 + return v2 +} + +; block0: +; subs wzr, w0, #1118208 +; cset x0, hs +; ret + +function %c(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111111 + v2 = icmp.i32 uge v0, v1 + return v2 +} + +; block0: +; movz w2, #4369 +; movk w2, w2, #17, LSL #16 +; subs wzr, w0, w2 +; cset x0, hs +; ret + +function %d(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111110 + v2 = icmp.i32 uge v0, v1 + return v2 +} + +; block0: +; movz w2, #4368 +; movk w2, w2, #17, LSL #16 +; subs wzr, w0, w2 +; cset x0, hs +; ret + +function %e(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111001 + v2 = icmp.i32 sge v0, v1 + return v2 +} + +; block0: +; subs wzr, w0, #1118208 +; cset x0, gt +; ret + +function %f(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111000 + v2 = icmp.i32 sge v0, v1 + return v2 +} + +; block0: +; subs wzr, w0, #1118208 +; cset x0, ge +; ret + +function %g(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111111 + v2 = icmp.i32 sge v0, v1 + return v2 +} + +; block0: +; movz w2, #4369 +; movk w2, w2, #17, LSL #16 +; subs wzr, w0, w2 +; cset x0, ge +; ret + +function %h(i32) -> i8 { +block0(v0: i32): + v1 = iconst.i32 0x111110 + v2 = icmp.i32 sge v0, v1 + return v2 +} + +; block0: +; movz w2, #4368 +; movk w2, w2, #17, LSL #16 +; subs wzr, w0, w2 +; cset x0, ge +; ret +