Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

riscv64: Implement SIMD saturating arithmetic and min/max #6430

Merged
merged 2 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_i16x8_extadd_pairwise_i8x16",
"simd_i16x8_extmul_i8x16",
"simd_i16x8_q15mulr_sat_s",
"simd_i16x8_sat_arith",
"simd_i32x4_arith2",
"simd_i32x4_cmp",
"simd_i32x4_dot_i16x8",
Expand All @@ -240,7 +239,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"simd_i64x2_extmul_i32x4",
"simd_i8x16_arith2",
"simd_i8x16_cmp",
"simd_i8x16_sat_arith",
"simd_int_to_int_extend",
"simd_lane",
"simd_load",
Expand Down
39 changes: 35 additions & 4 deletions cranelift/codegen/src/isa/riscv64/inst/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,32 +278,57 @@ impl VecAluOpRRR {
VecAluOpRRR::VandVV | VecAluOpRRR::VandVX => 0b001001,
VecAluOpRRR::VorVV | VecAluOpRRR::VorVX => 0b001010,
VecAluOpRRR::VxorVV | VecAluOpRRR::VxorVX => 0b001011,
VecAluOpRRR::VminuVV | VecAluOpRRR::VminuVX => 0b000100,
VecAluOpRRR::VminVV | VecAluOpRRR::VminVX => 0b000101,
VecAluOpRRR::VmaxuVV | VecAluOpRRR::VmaxuVX => 0b000110,
VecAluOpRRR::VmaxVV | VecAluOpRRR::VmaxVX => 0b000111,
VecAluOpRRR::VslidedownVX => 0b001111,
VecAluOpRRR::VfrsubVF => 0b100111,
VecAluOpRRR::VmergeVVM | VecAluOpRRR::VmergeVXM | VecAluOpRRR::VfmergeVFM => 0b010111,
VecAluOpRRR::VfdivVV | VecAluOpRRR::VfdivVF => 0b100000,
VecAluOpRRR::VfrdivVF => 0b100001,
VecAluOpRRR::VfdivVV
| VecAluOpRRR::VfdivVF
| VecAluOpRRR::VsadduVV
| VecAluOpRRR::VsadduVX => 0b100000,
VecAluOpRRR::VfrdivVF | VecAluOpRRR::VsaddVV | VecAluOpRRR::VsaddVX => 0b100001,
VecAluOpRRR::VssubuVV | VecAluOpRRR::VssubuVX => 0b100010,
VecAluOpRRR::VssubVV | VecAluOpRRR::VssubVX => 0b100011,
VecAluOpRRR::VfsgnjnVV => 0b001001,
}
}

pub fn category(&self) -> VecOpCategory {
match self {
VecAluOpRRR::VaddVV
| VecAluOpRRR::VsaddVV
| VecAluOpRRR::VsadduVV
| VecAluOpRRR::VsubVV
| VecAluOpRRR::VssubVV
| VecAluOpRRR::VssubuVV
| VecAluOpRRR::VandVV
| VecAluOpRRR::VorVV
| VecAluOpRRR::VxorVV
| VecAluOpRRR::VminuVV
| VecAluOpRRR::VminVV
| VecAluOpRRR::VmaxuVV
| VecAluOpRRR::VmaxVV
| VecAluOpRRR::VmergeVVM => VecOpCategory::OPIVV,
VecAluOpRRR::VmulVV | VecAluOpRRR::VmulhVV | VecAluOpRRR::VmulhuVV => {
VecOpCategory::OPMVV
}
VecAluOpRRR::VaddVX
| VecAluOpRRR::VsaddVX
| VecAluOpRRR::VsadduVX
| VecAluOpRRR::VsubVX
| VecAluOpRRR::VssubVX
| VecAluOpRRR::VssubuVX
| VecAluOpRRR::VrsubVX
| VecAluOpRRR::VandVX
| VecAluOpRRR::VorVX
| VecAluOpRRR::VxorVX
| VecAluOpRRR::VminuVX
| VecAluOpRRR::VminVX
| VecAluOpRRR::VmaxuVX
| VecAluOpRRR::VmaxVX
| VecAluOpRRR::VslidedownVX
| VecAluOpRRR::VmergeVXM => VecOpCategory::OPIVX,
VecAluOpRRR::VfaddVV
Expand Down Expand Up @@ -365,6 +390,8 @@ impl VecAluOpRRImm5 {
VecAluOpRRImm5::VxorVI => 0b001011,
VecAluOpRRImm5::VslidedownVI => 0b001111,
VecAluOpRRImm5::VmergeVIM => 0b010111,
VecAluOpRRImm5::VsadduVI => 0b100000,
VecAluOpRRImm5::VsaddVI => 0b100001,
}
}

Expand All @@ -376,7 +403,9 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VorVI
| VecAluOpRRImm5::VxorVI
| VecAluOpRRImm5::VslidedownVI
| VecAluOpRRImm5::VmergeVIM => VecOpCategory::OPIVI,
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => VecOpCategory::OPIVI,
}
}

Expand All @@ -388,7 +417,9 @@ impl VecAluOpRRImm5 {
| VecAluOpRRImm5::VandVI
| VecAluOpRRImm5::VorVI
| VecAluOpRRImm5::VxorVI
| VecAluOpRRImm5::VmergeVIM => false,
| VecAluOpRRImm5::VmergeVIM
| VecAluOpRRImm5::VsadduVI
| VecAluOpRRImm5::VsaddVI => false,
}
}
}
Expand Down
108 changes: 108 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst_vector.isle
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,21 @@
(type VecAluOpRRR (enum
;; Vector-Vector Opcodes
(VaddVV)
(VsaddVV)
(VsadduVV)
(VsubVV)
(VssubVV)
(VssubuVV)
(VmulVV)
(VmulhVV)
(VmulhuVV)
(VandVV)
(VorVV)
(VxorVV)
(VmaxVV)
(VmaxuVV)
(VminVV)
(VminuVV)
(VfaddVV)
(VfsubVV)
(VfmulVV)
Expand All @@ -107,11 +115,19 @@

;; Vector-Scalar Opcodes
(VaddVX)
(VsaddVX)
(VsadduVX)
(VsubVX)
(VrsubVX)
(VssubVX)
(VssubuVX)
(VandVX)
(VorVX)
(VxorVX)
(VmaxVX)
(VmaxuVX)
(VminVX)
(VminuVX)
(VslidedownVX)
(VfaddVF)
(VfsubVF)
Expand All @@ -127,6 +143,8 @@
(type VecAluOpRRImm5 (enum
;; Regular VI Opcodes
(VaddVI)
(VsaddVI)
(VsadduVI)
(VrsubVI)
(VandVI)
(VorVI)
Expand Down Expand Up @@ -280,6 +298,36 @@
(rule (rv_vadd_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VaddVI) vs2 imm mask vstate))

;; Helper for emitting the `vsadd.vv` instruction.
;; Signed saturating add, vector-vector: on overflow the result clamps to
;; the lane's signed min/max instead of wrapping.
(decl rv_vsadd_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsadd_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsaddVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vsadd.vx` instruction.
;; Same as `vsadd.vv`, but the second operand is a scalar (x) register
;; broadcast across all lanes.
(decl rv_vsadd_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsadd_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsaddVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vsadd.vi` instruction.
;; Same as `vsadd.vv`, but the second operand is a 5-bit immediate (`Imm5`)
;; broadcast across all lanes.
(decl rv_vsadd_vi (Reg Imm5 VecOpMasking VState) Reg)
(rule (rv_vsadd_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VsaddVI) vs2 imm mask vstate))

;; Helper for emitting the `vsaddu.vv` instruction.
;; Unsigned saturating add, vector-vector: on overflow the result clamps
;; to the lane's unsigned maximum.
(decl rv_vsaddu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsaddu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsadduVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vsaddu.vx` instruction.
;; Unsigned saturating add with a broadcast scalar (x) register operand.
(decl rv_vsaddu_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsaddu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VsadduVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vsaddu.vi` instruction.
;; Unsigned saturating add with a broadcast 5-bit immediate operand.
(decl rv_vsaddu_vi (Reg Imm5 VecOpMasking VState) Reg)
(rule (rv_vsaddu_vi vs2 imm mask vstate)
(vec_alu_rr_imm5 (VecAluOpRRImm5.VsadduVI) vs2 imm mask vstate))

;; Helper for emitting the `vsub.vv` instruction.
(decl rv_vsub_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vsub_vv vs2 vs1 mask vstate)
Expand All @@ -295,6 +343,26 @@
(rule (rv_vrsub_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VrsubVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vssub.vv` instruction.
;; Signed saturating subtract, vector-vector: on overflow the result
;; clamps to the lane's signed min/max instead of wrapping.
(decl rv_vssub_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssub_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vssub.vx` instruction.
;; Same as `vssub.vv`, but the subtrahend is a scalar (x) register
;; broadcast across all lanes.
(decl rv_vssub_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssub_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vssubu.vv` instruction.
;; Unsigned saturating subtract, vector-vector: results that would go
;; below zero clamp to zero.
(decl rv_vssubu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssubu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vssubu.vx` instruction.
;; Unsigned saturating subtract with a broadcast scalar (x) subtrahend.
(decl rv_vssubu_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vssubu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VssubuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vneg.v` pseudo-instruction.
(decl rv_vneg_v (Reg VecOpMasking VState) Reg)
(rule (rv_vneg_v vs2 mask vstate)
Expand Down Expand Up @@ -372,6 +440,46 @@
(if-let neg1 (imm5_from_i8 -1))
(rv_vxor_vi vs2 neg1 mask vstate))

;; Helper for emitting the `vmax.vv` instruction.
;; Lane-wise signed maximum of two vector registers.
(decl rv_vmax_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmax_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmax.vx` instruction.
;; Lane-wise signed maximum against a broadcast scalar (x) register.
(decl rv_vmax_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmax_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmin.vv` instruction.
;; Lane-wise signed minimum of two vector registers.
(decl rv_vmin_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmin_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmin.vx` instruction.
;; Lane-wise signed minimum against a broadcast scalar (x) register.
(decl rv_vmin_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmin_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vmaxu.vv` instruction.
;; Lane-wise unsigned maximum of two vector registers.
(decl rv_vmaxu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmaxu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vmaxu.vx` instruction.
;; Lane-wise unsigned maximum against a broadcast scalar (x) register.
(decl rv_vmaxu_vx (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vmaxu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VmaxuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vminu.vv` instruction.
;; Lane-wise unsigned minimum of two vector registers.
(decl rv_vminu_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vminu_vv vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminuVV) vs2 vs1 mask vstate))

;; Helper for emitting the `vminu.vx` instruction.
(decl rv_vminu_vx (Reg Reg VecOpMasking VState) Reg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stray thought, but the x64 backend has different types for Gpr and Xmm which represent a type-level distinction for different classes of registers. I've found that it's actually worked quite well for x64 and it's saved me a few times from mistakes. I realized I should be carefully reading the rules that use *_vx instructions to ensure the variable bound by splat comes second instead of first by accident (since I think that would compile past ISLE but wouldn't get past debug asserts perhaps in the backend).

Having this sort of distinction though is a pretty large refactoring so definitely not necessary on this PR, but perhaps something to consider if you also find yourself trying to carefully ensure each register goes to the right spot.

Copy link
Contributor Author

@afonso360 afonso360 May 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that sounds like a good idea! And I've definitely had that happen to me. Not with .vx, but passing an F register into a .vx opcode. We do hit the regclass asserts in the backend, but it would be better to have it as a type system feature.

It should be fairly easy to start at least with the vector rules and helpers. I'm going to give that a try.

(rule (rv_vminu_vx vs2 vs1 mask vstate)
(vec_alu_rrr (VecAluOpRRR.VminuVX) vs2 vs1 mask vstate))

;; Helper for emitting the `vfadd.vv` instruction.
(decl rv_vfadd_vv (Reg Reg VecOpMasking VState) Reg)
(rule (rv_vfadd_vv vs2 vs1 mask vstate)
Expand Down
105 changes: 97 additions & 8 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -884,25 +884,64 @@
(t2 Reg y))
(value_regs t1 t2)))


;;;;; Rules for `smax`;;;;;;;;;
(rule
(lower (has_type ty (smax x y)))

(rule 0 (lower (has_type (ty_int ty) (smax x y)))
(gen_int_select ty (IntSelectOP.Smax) (ext_int_if_need $true x ty) (ext_int_if_need $true y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (smax x y)))
(rv_vmax_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (smax x (splat y))))
(rv_vmax_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (smax (splat x) y)))
(rv_vmax_vx y x (unmasked) ty))

;;;;; Rules for `smin`;;;;;;;;;
(rule
(lower (has_type ty (smin x y)))

(rule 0 (lower (has_type (ty_int ty) (smin x y)))
(gen_int_select ty (IntSelectOP.Smin) (ext_int_if_need $true x ty) (ext_int_if_need $true y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (smin x y)))
(rv_vmin_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (smin x (splat y))))
(rv_vmin_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (smin (splat x) y)))
(rv_vmin_vx y x (unmasked) ty))

;;;;; Rules for `umax`;;;;;;;;;
(rule
(lower (has_type ty (umax x y)))

(rule 0 (lower (has_type (ty_int ty) (umax x y)))
(gen_int_select ty (IntSelectOP.Umax) (ext_int_if_need $false x ty) (ext_int_if_need $false y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (umax x y)))
(rv_vmaxu_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (umax x (splat y))))
(rv_vmaxu_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (umax (splat x) y)))
(rv_vmaxu_vx y x (unmasked) ty))

;;;;; Rules for `umin`;;;;;;;;;
(rule
(lower (has_type ty (umin x y)))

(rule 0 (lower (has_type (ty_int ty) (umin x y)))
(gen_int_select ty (IntSelectOP.Umin) (ext_int_if_need $false x ty) (ext_int_if_need $false y ty)))

(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (umin x y)))
(rv_vminu_vv x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (umin x (splat y))))
(rv_vminu_vx x y (unmasked) ty))

(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (umin (splat x) y)))
(rv_vminu_vx y x (unmasked) ty))


;;;;; Rules for `debugtrap`;;;;;;;;;
(rule
(lower (debugtrap))
Expand Down Expand Up @@ -1178,3 +1217,53 @@
;; similar in its splat rules.
;; TODO: Look through bitcasts when splatting out registers. We can use
;; `vmv.v.x` in a `(splat.f32x4 (bitcast.f32 val))`. And vice versa for integers.

;;;; Rules for `uadd_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Higher rule numbers take priority, so the most specific pattern
;; (splatted 5-bit immediate) wins over the splatted register form,
;; which in turn wins over the generic vector-vector form.

;; Base case: unsigned saturating add of two vector registers.
(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat x y)))
(rv_vsaddu_vv x y (unmasked) ty))

;; When one operand is a splat of a scalar, use the vector-scalar form and
;; avoid materializing the splatted vector. Addition is commutative, so
;; the splat may appear on either side.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat x (splat y))))
(rv_vsaddu_vx x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat (splat x) y)))
(rv_vsaddu_vx y x (unmasked) ty))

;; When one operand is a splat of a constant that fits in a 5-bit
;; immediate, use the vector-immediate form (no scalar register needed).
(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat x (replicated_imm5 y))))
(rv_vsaddu_vi x y (unmasked) ty))

(rule 4 (lower (has_type (ty_vec_fits_in_register ty) (uadd_sat (replicated_imm5 x) y)))
(rv_vsaddu_vi y x (unmasked) ty))

;;;; Rules for `sadd_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; Same priority structure as `uadd_sat` above: immediate form beats
;; scalar-register form beats the generic vector-vector form.

;; Base case: signed saturating add of two vector registers.
(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat x y)))
(rv_vsadd_vv x y (unmasked) ty))

;; Splat of a scalar on either side: use the vector-scalar form.
;; Addition is commutative, so the operands may be swapped freely.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat x (splat y))))
(rv_vsadd_vx x y (unmasked) ty))

(rule 2 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat (splat x) y)))
(rv_vsadd_vx y x (unmasked) ty))

;; Splat of a constant fitting in a 5-bit immediate on either side:
;; use the vector-immediate form.
(rule 3 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat x (replicated_imm5 y))))
(rv_vsadd_vi x y (unmasked) ty))

(rule 4 (lower (has_type (ty_vec_fits_in_register ty) (sadd_sat (replicated_imm5 x) y)))
(rv_vsadd_vi y x (unmasked) ty))

;;;; Rules for `usub_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Base case: unsigned saturating subtract of two vector registers.
(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (usub_sat x y)))
(rv_vssubu_vv x y (unmasked) ty))

;; Vector-scalar form when the subtrahend is a splat. Unlike the add
;; rules there is no commuted variant: subtraction is not commutative,
;; so a splat on the left-hand side cannot reuse `vssubu.vx`.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (usub_sat x (splat y))))
(rv_vssubu_vx x y (unmasked) ty))

;;;; Rules for `ssub_sat` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Base case: signed saturating subtract of two vector registers.
(rule 0 (lower (has_type (ty_vec_fits_in_register ty) (ssub_sat x y)))
(rv_vssub_vv x y (unmasked) ty))

;; Vector-scalar form when the subtrahend is a splat. No commuted
;; variant: subtraction is not commutative, so a splat on the left-hand
;; side cannot reuse `vssub.vx`.
(rule 1 (lower (has_type (ty_vec_fits_in_register ty) (ssub_sat x (splat y))))
(rv_vssub_vx x y (unmasked) ty))
Loading