From 967543eb43f82a372d881117c97de54d5c82b182 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 5 Apr 2023 12:22:55 -0500 Subject: [PATCH] aarch64: Add more lowerings for the CLIF `fma` (#6150) This commit adds new lowerings to the AArch64 backend of the element-based `fmla` and `fmls` instructions. These instructions have one of the multiplicands as an implicit broadcast of a single lane of another register and can help remove `shuffle` or `dup` instructions that would otherwise be used to implement them. --- cranelift/codegen/src/isa/aarch64/inst.isle | 20 ++- .../codegen/src/isa/aarch64/inst/emit.rs | 39 +++++ cranelift/codegen/src/isa/aarch64/inst/mod.rs | 22 ++- cranelift/codegen/src/isa/aarch64/lower.isle | 67 ++++++-- cranelift/codegen/src/isa/riscv64/inst.isle | 8 +- cranelift/codegen/src/prelude.isle | 1 + .../filetests/filetests/isa/aarch64/fma.clif | 149 ++++++++++++++++++ .../filetests/runtests/simd-fma.clif | 36 +++++ 8 files changed, 324 insertions(+), 18 deletions(-) diff --git a/cranelift/codegen/src/isa/aarch64/inst.isle b/cranelift/codegen/src/isa/aarch64/inst.isle index 389fac99b169..1e09a5ed96d3 100644 --- a/cranelift/codegen/src/isa/aarch64/inst.isle +++ b/cranelift/codegen/src/isa/aarch64/inst.isle @@ -651,6 +651,16 @@ (rm Reg) (size VectorSize)) + ;; A vector ALU op modifying a source register. + (VecFmlaElem + (alu_op VecALUModOp) + (rd WritableReg) + (ri Reg) + (rn Reg) + (rm Reg) + (size VectorSize) + (idx u8)) + ;; Vector two register miscellaneous instruction. (VecMisc (op VecMisc2) @@ -1850,7 +1860,7 @@ (_ Unit (emit (MInst.FpuRR op size dst src)))) dst)) -;; Helper for emitting `MInst.VecRRR` instructions which use three registers, +;; Helper for emitting `MInst.VecRRRMod` instructions which use three registers, ;; one of which is both source and output. (decl vec_rrr_mod (VecALUModOp Reg Reg Reg VectorSize) Reg) (rule (vec_rrr_mod op src1 src2 src3 size) @@ -1858,6 +1868,14 @@ (_1 Unit (emit (MInst.VecRRRMod op dst src1 src2 src3 size)))) dst)) +;; Helper for emitting `MInst.VecFmlaElem` instructions which use three registers, +;; one of which is both source and output. +(decl vec_fmla_elem (VecALUModOp Reg Reg Reg VectorSize u8) Reg) +(rule (vec_fmla_elem op src1 src2 src3 size idx) + (let ((dst WritableReg (temp_writable_reg $I8X16)) + (_1 Unit (emit (MInst.VecFmlaElem op dst src1 src2 src3 size idx)))) + dst)) + (decl fpu_rri (FPUOpRI Reg) Reg) (rule (fpu_rri op src) (let ((dst WritableReg (temp_writable_reg $F64)) diff --git a/cranelift/codegen/src/isa/aarch64/inst/emit.rs b/cranelift/codegen/src/isa/aarch64/inst/emit.rs index 2e576fc895c9..808cae255ffc 100644 --- a/cranelift/codegen/src/isa/aarch64/inst/emit.rs +++ b/cranelift/codegen/src/isa/aarch64/inst/emit.rs @@ -2914,6 +2914,45 @@ impl MachInstEmit for Inst { }; sink.put4(enc_vec_rrr(top11 | q << 9, rm, bit15_10, rn, rd)); } + &Inst::VecFmlaElem { + rd, + ri, + rn, + rm, + alu_op, + size, + idx, + } => { + let rd = allocs.next_writable(rd); + let ri = allocs.next(ri); + debug_assert_eq!(rd.to_reg(), ri); + let rn = allocs.next(rn); + let rm = allocs.next(rm); + let idx = u32::from(idx); + + let (q, _size) = size.enc_size(); + let o2 = match alu_op { + VecALUModOp::Fmla => 0b0, + VecALUModOp::Fmls => 0b1, + _ => unreachable!(), + }; + + let (h, l) = match size { + VectorSize::Size32x4 => { + assert!(idx < 4); + (idx >> 1, idx & 1) + } + VectorSize::Size64x2 => { + assert!(idx < 2); + (idx, 0) + } + _ => unreachable!(), + }; + + let top11 = 0b000_011111_00 | (q << 9) | (size.enc_float_size() << 1) | l; + let bit15_10 = 0b000100 | (o2 << 4) | (h << 1); + sink.put4(enc_vec_rrr(top11, rm, bit15_10, rn, rd)); + } &Inst::VecLoadReplicate { rd, rn, diff --git a/cranelift/codegen/src/isa/aarch64/inst/mod.rs b/cranelift/codegen/src/isa/aarch64/inst/mod.rs index b7debcda6620..bc0403f65aba 100644 --- a/cranelift/codegen/src/isa/aarch64/inst/mod.rs +++ b/cranelift/codegen/src/isa/aarch64/inst/mod.rs @@ -812,7 +812,7 @@ fn aarch64_get_operands VReg>(inst: &Inst, collector: &mut Operan collector.reg_use(rn); collector.reg_use(rm); } - &Inst::VecRRRMod { rd, ri, rn, rm, .. } => { + &Inst::VecRRRMod { rd, ri, rn, rm, .. } | &Inst::VecFmlaElem { rd, ri, rn, rm, .. } => { collector.reg_reuse_def(rd, 1); // `rd` == `ri`. collector.reg_use(ri); collector.reg_use(rn); @@ -2171,6 +2171,26 @@ impl Inst { let rm = pretty_print_vreg_vector(rm, size, allocs); format!("{} {}, {}, {}, {}", op, rd, ri, rn, rm) } + &Inst::VecFmlaElem { + rd, + ri, + rn, + rm, + alu_op, + size, + idx, + } => { + let (op, size) = match alu_op { + VecALUModOp::Fmla => ("fmla", size), + VecALUModOp::Fmls => ("fmls", size), + _ => unreachable!(), + }; + let rd = pretty_print_vreg_vector(rd.to_reg(), size, allocs); + let ri = pretty_print_vreg_vector(ri, size, allocs); + let rn = pretty_print_vreg_vector(rn, size, allocs); + let rm = pretty_print_vreg_element(rm, idx.into(), size.lane_size(), allocs); + format!("{} {}, {}, {}, {}", op, rd, ri, rn, rm) + } &Inst::VecRRRLong { rd, rn, diff --git a/cranelift/codegen/src/isa/aarch64/lower.isle b/cranelift/codegen/src/isa/aarch64/lower.isle index a08dcd55343a..8f36c79c2f0a 100644 --- a/cranelift/codegen/src/isa/aarch64/lower.isle +++ b/cranelift/codegen/src/isa/aarch64/lower.isle @@ -513,17 +513,62 @@ ;;;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -(rule (lower (has_type ty @ (multi_lane _ _) (fma x y z))) - (vec_rrr_mod (VecALUModOp.Fmla) z x y (vector_size ty))) - -(rule 1 (lower (has_type ty @ (multi_lane _ _) (fma (fneg x) y z))) - (vec_rrr_mod (VecALUModOp.Fmls) z x y (vector_size ty))) - -(rule 2 (lower (has_type ty @ (multi_lane _ _) (fma x (fneg y) z))) - (vec_rrr_mod (VecALUModOp.Fmls) z x y (vector_size ty))) - -(rule 3 (lower (has_type (ty_scalar_float ty) (fma x y z))) - (fpu_rrrr (FPUOp3.MAdd) (scalar_size ty) x y z)) +(rule (lower (has_type (ty_scalar_float ty) (fma x y z))) + (fpu_rrrr (FPUOp3.MAdd) (scalar_size ty) x y z)) + +;; Delegate vector-based lowerings to helpers below +(rule 1 (lower (has_type ty @ (multi_lane _ _) (fma x y z))) + (lower_fmla (VecALUModOp.Fmla) x y z (vector_size ty))) + +;; Lowers a fused-multiply-add operation handling various forms of the +;; instruction to get maximal coverage of what's available on AArch64. +(decl lower_fmla (VecALUModOp Value Value Value VectorSize) Reg) + +;; Base case, emit the op requested. +(rule (lower_fmla op x y z size) + (vec_rrr_mod op z x y size)) + +;; Special case: if one of the multiplicands are a splat then the element-based +;; fma can be used instead with 0 as the element index. +(rule 1 (lower_fmla op (splat x) y z size) + (vec_fmla_elem op z y x size 0)) +(rule 2 (lower_fmla op x (splat y) z size) + (vec_fmla_elem op z x y size 0)) + +;; Special case: if one of the multiplicands is a shuffle to broadcast a +;; single element of a vector then the element-based fma can be used like splat +;; above. +;; +;; Note that in Cranelift shuffle always has i8x16 inputs and outputs so +;; a `bitcast` is matched here explicitly since that's the main way a shuffle +;; output will be fed into this instruction. +(rule 3 (lower_fmla op (bitcast _ (shuffle x x (shuffle32_from_imm n n n n))) y z size @ (VectorSize.Size32x4)) + (if-let $true (u64_lt n 4)) + (vec_fmla_elem op z y x size n)) +(rule 4 (lower_fmla op x (bitcast _ (shuffle y y (shuffle32_from_imm n n n n))) z size @ (VectorSize.Size32x4)) + (if-let $true (u64_lt n 4)) + (vec_fmla_elem op z x y size n)) +(rule 3 (lower_fmla op (bitcast _ (shuffle x x (shuffle64_from_imm n n))) y z size @ (VectorSize.Size64x2)) + (if-let $true (u64_lt n 2)) + (vec_fmla_elem op z y x size n)) +(rule 4 (lower_fmla op x (bitcast _ (shuffle y y (shuffle64_from_imm n n))) z size @ (VectorSize.Size64x2)) + (if-let $true (u64_lt n 2)) + (vec_fmla_elem op z x y size n)) + +;; Special case: if one of the multiplicands is `fneg` then peel that away, +;; reverse the operation being performed, and then recurse on `lower_fmla` +;; again to generate the actual instruction. +;; +;; Note that these are the highest priority cases for `lower_fmla` to peel +;; away as many `fneg` operations as possible. +(rule 5 (lower_fmla op (fneg x) y z size) + (lower_fmla (neg_fmla op) x y z size)) +(rule 6 (lower_fmla op x (fneg y) z size) + (lower_fmla (neg_fmla op) x y z size)) + +(decl neg_fmla (VecALUModOp) VecALUModOp) +(rule (neg_fmla (VecALUModOp.Fmla)) (VecALUModOp.Fmls)) +(rule (neg_fmla (VecALUModOp.Fmls)) (VecALUModOp.Fmla)) ;;;; Rules for `fcopysign` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index ef6195fe7292..abb15f9cc99e 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -708,8 +708,6 @@ (decl u8_as_i32 (u8) i32) (extern constructor u8_as_i32 u8_as_i32) -(convert u8 u64 u8_as_u64) - (decl convert_valueregs_reg (ValueRegs) Reg) (rule (convert_valueregs_reg x) (value_regs_get x 0)) @@ -1283,7 +1281,7 @@ (rule (load_imm12 x) (rv_addi (zero_reg) (imm12_const x))) - + ;; for load immediate (decl imm_from_bits (u64) Imm12) (extern constructor imm_from_bits imm_from_bits) @@ -1509,7 +1507,7 @@ (_ Unit (emit (MInst.Cltz leading sum step tmp rs ty)))) sum)) - + ;; Extends an integer if it is smaller than 64 bits. (decl ext_int_if_need (bool ValueRegs Type) ValueRegs) ;;; For values smaller than 64 bits, we need to extend them to 64 bits @@ -2117,7 +2115,7 @@ (reuslt VecWritableReg (vec_writable_clone dst)) (_ Unit (emit (MInst.Select dst ty c x y)))) (vec_writable_to_regs reuslt))) - + ;; Parameters are "intcc compare_a compare_b rs1 rs2". (decl gen_select_reg (IntCC Reg Reg Reg Reg) Reg) (extern constructor gen_select_reg gen_select_reg) diff --git a/cranelift/codegen/src/prelude.isle b/cranelift/codegen/src/prelude.isle index 012d5e29af6d..5a3018318f64 100644 --- a/cranelift/codegen/src/prelude.isle +++ b/cranelift/codegen/src/prelude.isle @@ -82,6 +82,7 @@ (decl pure u8_as_u64 (u8) u64) (extern constructor u8_as_u64 u8_as_u64) +(convert u8 u64 u8_as_u64) (decl pure u16_as_u64 (u16) u64) (extern constructor u16_as_u64 u16_as_u64) diff --git a/cranelift/filetests/filetests/isa/aarch64/fma.clif b/cranelift/filetests/filetests/isa/aarch64/fma.clif index e2f4a172c43e..9a59e645cb81 100644 --- a/cranelift/filetests/filetests/isa/aarch64/fma.clif +++ b/cranelift/filetests/filetests/isa/aarch64/fma.clif @@ -157,3 +157,152 @@ block0(v0: f64x2, v1: f64x2, v2: f64x2): ; fmls v0.2d, v5.2d, v1.2d ; ret +function %f32x4_splat0(f32, f32x4, f32x4) -> f32x4 { +block0(v0: f32, v1: f32x4, v2: f32x4): + v3 = splat.f32x4 v0 + v4 = fma v3, v1, v2 + return v4 +} + +; VCode: +; block0: +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmla v0.4s, v0.4s, v1.4s, v5.s[0] +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmla v0.4s, v1.4s, v5.s[0] +; ret + +function %f32x4_splat1(f32x4, f32, f32x4) -> f32x4 { +block0(v0: f32x4, v1: f32, v2: f32x4): + v3 = splat.f32x4 v1 + v4 = fneg v0 + v5 = fma v4, v3, v2 + return v5 +} + +; VCode: +; block0: +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmls v0.4s, v0.4s, v5.4s, v1.s[0] +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmls v0.4s, v5.4s, v1.s[0] +; ret + +function %f32x4_splat2(f32x4, f32x4, f32x4) -> f32x4 { +block0(v0: f32x4, v1: f32x4, v2: f32x4): + v3 = bitcast.i8x16 little v0 + v4 = shuffle v3, v3, 0x07060504_07060504_07060504_07060504 + v5 = bitcast.f32x4 little v4 + v6 = fma v5, v1, v2 + return v6 +} + +; VCode: +; block0: +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmla v0.4s, v0.4s, v1.4s, v5.s[1] +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmla v0.4s, v1.4s, v5.s[1] +; ret + +function %f32x4_splat3(f32x4, f32x4, f32x4) -> f32x4 { +block0(v0: f32x4, v1: f32x4, v2: f32x4): + v3 = bitcast.i8x16 little v1 + v4 = shuffle v3, v3, 0x0f0e0d0c_0f0e0d0c_0f0e0d0c_0f0e0d0c + v5 = bitcast.f32x4 little v4 + v6 = fneg v5 + v7 = fma v0, v6, v2 + return v7 +} + +; VCode: +; block0: +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmls v0.4s, v0.4s, v5.4s, v1.s[3] +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmls v0.4s, v5.4s, v1.s[3] +; ret + +function %f32x4_splat4(f32x4, f32x4, f32x4) -> f32x4 { +block0(v0: f32x4, v1: f32x4, v2: f32x4): + v3 = bitcast.i8x16 little v1 + v4 = shuffle v3, v3, 0x1f1e1d1c_1f1e1d1c_1f1e1d1c_1f1e1d1c + v5 = bitcast.f32x4 little v4 + v6 = fma v0, v5, v2 + return v6 +} + +; VCode: +; block0: +; mov v31.16b, v1.16b +; movz w6, #7452 +; movk w6, w6, #7966, LSL #16 +; dup v17.4s, w6 +; mov v30.16b, v31.16b +; tbl v19.16b, { v30.16b, v31.16b }, v17.16b +; mov v23.16b, v0.16b +; mov v0.16b, v2.16b +; fmla v0.4s, v0.4s, v23.4s, v19.4s +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; mov v31.16b, v1.16b +; mov w6, #0x1d1c +; movk w6, #0x1f1e, lsl #16 +; dup v17.4s, w6 +; mov v30.16b, v31.16b +; tbl v19.16b, {v30.16b, v31.16b}, v17.16b +; mov v23.16b, v0.16b +; mov v0.16b, v2.16b +; fmla v0.4s, v23.4s, v19.4s +; ret + +function %f64x2_splat0(f64x2, f64x2, f64x2) -> f64x2 { +block0(v0: f64x2, v1: f64x2, v2: f64x2): + v3 = bitcast.i8x16 little v1 + v4 = shuffle v3, v3, 0x0f0e0d0c0b0a0908_0f0e0d0c0b0a0908 + v5 = bitcast.f64x2 little v4 + v6 = fneg v5 + v7 = fma v0, v6, v2 + return v7 +} + +; VCode: +; block0: +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmls v0.2d, v0.2d, v5.2d, v1.d[1] +; ret +; +; Disassembled: +; block0: ; offset 0x0 +; mov v5.16b, v0.16b +; mov v0.16b, v2.16b +; fmls v0.2d, v5.2d, v1.d[1] +; ret + diff --git a/cranelift/filetests/filetests/runtests/simd-fma.clif b/cranelift/filetests/filetests/runtests/simd-fma.clif index 8f81ed6b4976..894243cec3f7 100644 --- a/cranelift/filetests/filetests/runtests/simd-fma.clif +++ b/cranelift/filetests/filetests/runtests/simd-fma.clif @@ -87,3 +87,39 @@ block0(v0: f64x2, v1: f64x2, v2: f64x2): ; run: %fma_is_nan_f64x2([0x0.0 0x0.0], [+NaN 0x0.0], [0x0.0 +NaN]) == 1 ; run: %fma_is_nan_f64x2([-NaN 0x0.0], [0x0.0 -NaN], [0x0.0 0x0.0]) == 1 ; run: %fma_is_nan_f64x2([0x0.0 NaN], [0x0.0 NaN], [-NaN NaN]) == 1 + +function %fma_f32x4_splat1(f32x4, f32, f32x4) -> f32x4 { +block0(v0: f32x4, v1: f32, v2: f32x4): + v3 = splat.f32x4 v1 + v4 = fma v0, v3, v2 + return v4 +} +; run: %fma_f32x4_splat1([0x9.0 0x9.0 0x9.0 0x9.0], 0x9.0, [0x9.0 0x9.0 0x9.0 0x9.0]) == [0x1.680000p6 0x1.680000p6 0x1.680000p6 0x1.680000p6] +; run: %fma_f32x4_splat1([0x1.0 0x2.0 0x3.0 0x4.0], 0x0.0, [0x5.0 0x6.0 0x7.0 0x8.0]) == [0x5.0 0x6.0 0x7.0 0x8.0] + +function %fma_f32x4_splat2(f32, f32x4, f32x4) -> f32x4 { +block0(v0: f32, v1: f32x4, v2: f32x4): + v3 = splat.f32x4 v0 + v4 = fma v3, v1, v2 + return v4 +} +; run: %fma_f32x4_splat2(0x9.0, [0x9.0 0x9.0 0x9.0 0x9.0], [0x9.0 0x9.0 0x9.0 0x9.0]) == [0x1.680000p6 0x1.680000p6 0x1.680000p6 0x1.680000p6] +; run: %fma_f32x4_splat2(0x0.0, [0x1.0 0x2.0 0x3.0 0x4.0], [0x5.0 0x6.0 0x7.0 0x8.0]) == [0x5.0 0x6.0 0x7.0 0x8.0] + +function %fma_f64x2_splat1(f64x2, f64, f64x2) -> f64x2 { +block0(v0: f64x2, v1: f64, v2: f64x2): + v3 = splat.f64x2 v1 + v4 = fma v0, v3, v2 + return v4 +} +; run: %fma_f64x2_splat1([0x9.0 0x9.0], 0x9.0, [0x9.0 0x9.0]) == [0x1.680000p6 0x1.680000p6] +; run: %fma_f64x2_splat1([0x1.0 0x2.0], 0x0.0, [0x5.0 0x6.0]) == [0x5.0 0x6.0] + +function %fma_f64x2_splat2(f64, f64x2, f64x2) -> f64x2 { +block0(v0: f64, v1: f64x2, v2: f64x2): + v3 = splat.f64x2 v0 + v4 = fma v3, v1, v2 + return v4 +} +; run: %fma_f64x2_splat2(0x9.0, [0x9.0 0x9.0], [0x9.0 0x9.0]) == [0x1.680000p6 0x1.680000p6] +; run: %fma_f64x2_splat2(0x0.0, [0x1.0 0x2.0], [0x5.0 0x6.0]) == [0x5.0 0x6.0]