Skip to content

Commit

Permalink
riscv64: Implement SIMD saturating arithmetic and min/max (#6430)
Browse files Browse the repository at this point in the history
* riscv64: Implement SIMD `{u,s}{add,sub}_sat`

* riscv64: Implement SIMD `{u,s}{min,max}`
  • Loading branch information
afonso360 authored May 23, 2023
1 parent a61be19 commit b4c8509
Show file tree
Hide file tree
Showing 23 changed files with 4,271 additions and 15 deletions.
2 changes: 0 additions & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_i16x8_extadd_pairwise_i8x16",
"simd_i16x8_extmul_i8x16",
"simd_i16x8_q15mulr_sat_s",
"simd_i16x8_sat_arith",
"simd_i32x4_arith2",
"simd_i32x4_cmp",
"simd_i32x4_dot_i16x8",
Expand All @@ -240,7 +239,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_i64x2_extmul_i32x4",
"simd_i8x16_arith2",
"simd_i8x16_cmp",
"simd_i8x16_sat_arith",
"simd_int_to_int_extend",
"simd_lane",
"simd_load",
Expand Down
39 changes: 35 additions & 4 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,32 +278,57 @@ impl VecAluOpRRR {
VecAluOpRRR::VandVV | VecAluOpRRR::VandVX => 0b001001,
VecAluOpRRR::VorVV | VecAluOpRRR::VorVX => 0b001010,
VecAluOpRRR::VxorVV | VecAluOpRRR::VxorVX => 0b001011,
VecAluOpRRR::VminuVV | VecAluOpRRR::VminuVX => 0b000100,
VecAluOpRRR::VminVV | VecAluOpRRR::VminVX => 0b000101,
VecAluOpRRR::VmaxuVV | VecAluOpRRR::VmaxuVX => 0b000110,
VecAluOpRRR::VmaxVV | VecAluOpRRR::VmaxVX => 0b000111,
VecAluOpRRR::VslidedownVX => 0b001111,
VecAluOpRRR::VfrsubVF => 0b100111,
VecAluOpRRR::VmergeVVM | VecAluOpRRR::VmergeVXM | VecAluOpRRR::VfmergeVFM => 0b010111,
VecAluOpRRR::VfdivVV | VecAluOpRRR::VfdivVF => 0b100000,
VecAluOpRRR::VfrdivVF => 0b100001,
VecAluOpRRR::VfdivVV
| VecAluOpRRR::VfdivVF
| VecAluOpRRR::VsadduVV
| VecAluOpRRR::VsadduVX => 0b100000,
VecAluOpRRR::VfrdivVF | VecAluOpRRR::VsaddVV | VecAluOpRRR::VsaddVX => 0b100001,
VecAluOpRRR::VssubuVV | VecAluOpRRR::VssubuVX => 0b100010,
VecAluOpRRR::VssubVV | VecAluOpRRR::VssubVX => 0b100011,
VecAluOpRRR::VfsgnjnVV => 0b001001,
}
}

pub fn category(&self) -> VecOpCategory {
match self {
VecAluOpRRR::VaddVV
| VecAluOpRRR::VsaddVV
| VecAluOpRRR::VsadduVV
| VecAluOpRRR::VsubVV
| VecAluOpRRR::VssubVV
| VecAluOpRRR::VssubuVV
| VecAluOpRRR::VandVV
| VecAluOpRRR::VorVV
| VecAluOpRRR::VxorVV
| VecAluOpRRR::VminuVV
| VecAluOpRRR::VminVV
| VecAluOpRRR::VmaxuVV
| VecAluOpRRR::VmaxVV
| VecAluOpRRR::VmergeVVM => VecOpCategory::OPIVV,
VecAluOpRRR::VmulVV | VecAluOpRRR::VmulhVV | VecAluOpRRR::VmulhuVV => {
VecOpCategory::OPMVV
}
VecAluOpRRR::VaddVX
| VecAluOpRRR::VsaddVX
| VecAluOpRRR::VsadduVX
| VecAluOpRRR::VsubVX
| VecAluOpRRR::VssubVX
| VecAluOpRRR::VssubuVX
| VecAluOpRRR::VrsubVX
| VecAluOpRRR::VandVX
| VecAluOpRRR::VorVX
| VecAluOpRRR::VxorVX
| VecAluOpRRR::VminuVX
| VecAluOpRRR::VminVX
| VecAluOpRRR::VmaxuVX
| VecAluOpRRR::VmaxVX
| VecAluOpRRR::VslidedownVX
| VecAluOpRRR::VmergeVXM => VecOpCategory::OPIVX,
VecAluOpRRR::VfaddVV
Expand Down Expand Up @@ -365,6 +390,8 @@ impl VecAluOpRRImm5 {
VecAluOpRRImm5::VxorVI => 0b001011,
VecAluOpRRImm5::VslidedownVI => 0b001111,
VecAluOpRRImm5::VmergeVIM => 0b010111,
VecAluOpRRImm5::VsadduVI => 0b100000,
VecAluOpRRImm5::VsaddVI => 0b100001,
}
}

Expand All @@ -376,7 +403,9 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VorVI
| VecAluOpRRImm5::VxorVI
| VecAluOpRRImm5::VslidedownVI
| VecAluOpRRImm5::VmergeVIM => VecOpCategory::OPIVI,
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => VecOpCategory::OPIVI,
}
}

Expand All @@ -388,7 +417,9 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VandVI
| VecAluOpRRImm5::VorVI
| VecAluOpRRImm5::VxorVI
| VecAluOpRRImm5::VmergeVIM => false,
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => false,
}
}
}
Expand Down
108 changes: 108 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst_vector.isle
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,21 @@
(type VecAluOpRRR (enum
;; Vector-Vector Opcodes
(VaddVV)
(VsaddVV)
(VsadduVV)
(VsubVV)
(VssubVV)
(VssubuVV)
(VmulVV)
(VmulhVV)
(VmulhuVV)
(VandVV)
(VorVV)
(VxorVV)
(VmaxVV)
(VmaxuVV)
(VminVV)
(VminuVV)
(VfaddVV)
(VfsubVV)
(VfmulVV)
Expand All @@ -107,11 +115,19 @@

;; Vector-Scalar Opcodes
(VaddVX)
(VsaddVX)
(VsadduVX)
(VsubVX)
(VrsubVX)
(VssubVX)
(VssubuVX)
(VandVX)
(VorVX)
(VxorVX)
(VmaxVX)
(VmaxuVX)
(VminVX)
(VminuVX)
(VslidedownVX)
(VfaddVF)
(VfsubVF)
Expand All @@ -127,6 +143,8 @@
(type VecAluOpRRImm5 (enum
;; Regular VI Opcodes
(VaddVI)
(VsaddVI)
(VsadduVI)
(VrsubVI)
(VandVI)
(VorVI)
Expand Down Expand Up @@ -280,6 +298,36 @@
(rule (rv_vadd_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VaddVI) vs2 imm mask vstate))

;; Helper for emitting the `vsadd.vv` instruction.
(decl rv_vsadd_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsadd_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsaddVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vsadd.vx` instruction.
(decl rv_vsadd_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsadd_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsaddVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vsadd.vi` instruction.
(decl rv_vsadd_vi (Reg Imm5 VecOpMasking VState) Reg)
(rule (rv_vsadd_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VsaddVI) vs2 imm mask vstate))

;; Helper for emitting the `vsaddu.vv` instruction.
(decl rv_vsaddu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsaddu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsadduVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vsaddu.vx` instruction.
(decl rv_vsaddu_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsaddu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsadduVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vsaddu.vi` instruction.
(decl rv_vsaddu_vi (Reg Imm5 VecOpMasking VState) Reg)
(rule (rv_vsaddu_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VsadduVI) vs2 imm mask vstate))

;; Helper for emitting the `vsub.vv` instruction.
(decl rv_vsub_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsub_vv vs2 vs1 mask vstate)
Expand All @@ -295,6 +343,26 @@
(rule (rv_vrsub_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VrsubVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vssub.vv` instruction.
(decl rv_vssub_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssub_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vssub.vx` instruction.
(decl rv_vssub_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssub_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vssubu.vv` instruction.
(decl rv_vssubu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssubu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vssubu.vx` instruction.
(decl rv_vssubu_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssubu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vneg.v` pseudo-instruction.
(decl rv_vneg_v (Reg VecOpMasking VState) Reg)
(rule (rv_vneg_v vs2 mask vstate)
Expand Down Expand Up @@ -372,6 +440,46 @@
(if-let neg1 (imm5_from_i8 -1))
(rv_vxor_vi vs2 neg1 mask vstate))

;; Helper for emitting the `vmax.vv` instruction.
(decl rv_vmax_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmax_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmax.vx` instruction.
(decl rv_vmax_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmax_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmin.vv` instruction.
(decl rv_vmin_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmin_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmin.vx` instruction.
(decl rv_vmin_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmin_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmaxu.vv` instruction.
(decl rv_vmaxu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmaxu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmaxu.vx` instruction.
(decl rv_vmaxu_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmaxu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vminu.vv` instruction.
(decl rv_vminu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vminu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vminu.vx` instruction.
(decl rv_vminu_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vminu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vfadd.vv` instruction.
(decl rv_vfadd_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vfadd_vv vs2 vs1 mask vstate)
Expand Down
105 changes: 97 additions & 8 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -884,25 +884,64 @@
(t2 Reg y))
(value_regs t1 t2)))


;;;;; Rules for `smax`;;;;;;;;;
(rule
(lower (has_type ty (smax x y)))

(rule 0 (lower (has_type (ty_int ty) (smax x y)))
(gen_int_select ty (IntSelectOP.Smax) (ext_int_if_need $true x ty) (ext_int_if_need $true y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (smax x y)))
(rv_vmax_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (smax x (splat y))))
(rv_vmax_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (smax (splat x) y)))
(rv_vmax_vx y x (unmasked) ty))

;;;;; Rules for `smin`;;;;;;;;;
(rule
(lower (has_type ty (smin x y)))

(rule 0 (lower (has_type (ty_int ty) (smin x y)))
(gen_int_select ty (IntSelectOP.Smin) (ext_int_if_need $true x ty) (ext_int_if_need $true y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (smin x y)))
(rv_vmin_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (smin x (splat y))))
(rv_vmin_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (smin (splat x) y)))
(rv_vmin_vx y x (unmasked) ty))

;;;;; Rules for `umax`;;;;;;;;;
(rule
(lower (has_type ty (umax x y)))

(rule 0 (lower (has_type (ty_int ty) (umax x y)))
(gen_int_select ty (IntSelectOP.Umax) (ext_int_if_need $false x ty) (ext_int_if_need $false y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (umax x y)))
(rv_vmaxu_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (umax x (splat y))))
(rv_vmaxu_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (umax (splat x) y)))
(rv_vmaxu_vx y x (unmasked) ty))

;;;;; Rules for `umin`;;;;;;;;;
(rule
(lower (has_type ty (umin x y)))

(rule 0 (lower (has_type (ty_int ty) (umin x y)))
(gen_int_select ty (IntSelectOP.Umin) (ext_int_if_need $false x ty) (ext_int_if_need $false y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (umin x y)))
(rv_vminu_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (umin x (splat y))))
(rv_vminu_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (umin (splat x) y)))
(rv_vminu_vx y x (unmasked) ty))


;;;;; Rules for `debugtrap`;;;;;;;;;
(rule
(lower (debugtrap))
Expand Down Expand Up @@ -1178,3 +1217,53 @@
;; similar in its splat rules.
;; TODO: Look through bitcasts when splatting out registers. We can use
;; `vmv.v.x` in a `(splat.f32x4 (bitcast.f32 val))`. And vice versa for integers.

;;;; Rules for `uadd_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat x y)))
(rv_vsaddu_vv x y (unmasked) ty))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat x (splat y))))
(rv_vsaddu_vx x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat (splat x) y)))
(rv_vsaddu_vx y x (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat x (replicated_imm5 y))))
(rv_vsaddu_vi x y (unmasked) ty))

(rule 4 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat (replicated_imm5 x) y)))
(rv_vsaddu_vi y x (unmasked) ty))

;;;; Rules for `sadd_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat x y)))
(rv_vsadd_vv x y (unmasked) ty))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat x (splat y))))
(rv_vsadd_vx x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat (splat x) y)))
(rv_vsadd_vx y x (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat x (replicated_imm5 y))))
(rv_vsadd_vi x y (unmasked) ty))

(rule 4 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat (replicated_imm5 x) y)))
(rv_vsadd_vi y x (unmasked) ty))

;;;; Rules for `usub_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (usub_sat x y)))
(rv_vssubu_vv x y (unmasked) ty))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (usub_sat x (splat y))))
(rv_vssubu_vx x y (unmasked) ty))

;;;; Rules for `ssub_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (ssub_sat x y)))
(rv_vssub_vv x y (unmasked) ty))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (ssub_sat x (splat y))))
(rv_vssub_vx x y (unmasked) ty))
Loading

0 comments on commit b4c8509

Please sign in to comment.