Skip to content

Commit

Permalink
riscv64: Implement vector floating point rounding instructions (#6920)
Browse files Browse the repository at this point in the history
* riscv64: Add CSR Instructions

* riscv64: Add float to int vector instructions

* cranelift: Split vector rounding mode tests

* riscv64: Implement float rounding ops for vectors

* riscv64: Update tests
  • Loading branch information
afonso360 authored Aug 30, 2023
1 parent 134dddc commit d6b4825
Show file tree
Hide file tree
Showing 22 changed files with 1,028 additions and 100 deletions.
2 changes: 0 additions & 2 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,6 @@ fn ignore(testsuite: &str, testname: &str, strategy: &str) -> bool {
"cvt_from_uint",
"issue_3327_bnot_lowering",
"simd_conversions",
"simd_f32x4_rounding",
"simd_f64x2_rounding",
"simd_i32x4_trunc_sat_f32x4",
"simd_i32x4_trunc_sat_f64x2",
"simd_load",
Expand Down
80 changes: 78 additions & 2 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@
(rs Reg)
(imm12 Imm12))

;; A CSR Reading or Writing instruction with a register source and a register destination.
(CsrReg
(op CsrRegOP)
(rd WritableReg)
(rs Reg)
(csr CSR))

;; A CSR Writing instruction with an immediate source and a register destination.
(CsrImm
(op CsrImmOP)
(rd WritableReg)
(imm UImm5)
(csr CSR))

;; An load
(Load
(rd WritableReg)
Expand Down Expand Up @@ -689,6 +703,30 @@
(Bseti)
))

(type CsrRegOP (enum
;; Atomic Read/Write CSR
(CsrRW)
;; Atomic Read and Set Bits in CSR
(CsrRS)
;; Atomic Read and Clear Bits in CSR
(CsrRC)
))

(type CsrImmOP (enum
;; Atomic Read/Write CSR (Immediate Source)
(CsrRWI)
;; Atomic Read and Set Bits in CSR (Immediate Source)
(CsrRSI)
;; Atomic Read and Clear Bits in CSR (Immediate Source)
(CsrRCI)
))

;; Enum of the known CSR registers
(type CSR (enum
;; Floating-Point Dynamic Rounding Mode
(Frm)
))


(type FRM (enum
;; Round to Nearest, ties to Even
Expand All @@ -706,6 +744,10 @@
(Fcsr)
))

(decl pure frm_bits (FRM) UImm5)
(extern constructor frm_bits frm_bits)
(convert FRM UImm5 frm_bits)

(type FFlagsException (enum
;; Invalid Operation
(NV)
Expand Down Expand Up @@ -1508,6 +1550,30 @@
(alu_rrr (AluOPRRR.Packw) rs1 rs2))


;; `Zicsr` Extension Instructions

;; Helper for emitting the `csrrwi` instruction.
(decl rv_csrrwi (CSR UImm5) XReg)
(rule (rv_csrrwi csr imm)
(csr_imm (CsrImmOP.CsrRWI) csr imm))

;; This is a special case of `csrrwi` when the CSR is the `frm` CSR.
(decl rv_fsrmi (FRM) XReg)
(rule (rv_fsrmi frm) (rv_csrrwi (CSR.Frm) frm))


;; Helper for emitting the `csrw` instruction. This is a special case of
;; `csrrw` where the destination register is always `x0`.
(decl rv_csrw (CSR XReg) Unit)
(rule (rv_csrw csr rs)
(csr_reg_dst_zero (CsrRegOP.CsrRW) csr rs))

;; This is a special case of `csrw` when the CSR is the `frm` CSR.
(decl rv_fsrm (XReg) Unit)
(rule (rv_fsrm rs) (rv_csrw (CSR.Frm) rs))





;; Generate a mask for the bit-width of the given type
Expand Down Expand Up @@ -1686,7 +1752,6 @@
(_ Unit (emit (MInst.FpuRRR op (gen_default_frm) dst src1 src2))))
dst))


;; Helper for emitting `MInst.FpuRRRR` instructions.
(decl fpu_rrrr (FpuOPRRRR Type Reg Reg Reg) Reg)
(rule (fpu_rrrr op ty src1 src2 src3)
Expand All @@ -1710,7 +1775,6 @@
(_ Unit (emit (MInst.AluRRImm12 op dst src (imm12_zero)))))
dst))


;; Helper for emitting the `Lui` instruction.
;; TODO: This should be something like `emit_u_type`. And should share the
;; `MInst` with `auipc` since these instructions share the U-Type format.
Expand All @@ -1720,6 +1784,18 @@
(_ Unit (emit (MInst.Lui dst imm))))
dst))

;; Helper for emitting `MInst.CsrImm` instructions.
(decl csr_imm (CsrImmOP CSR UImm5) XReg)
(rule (csr_imm op csr imm)
(let ((dst WritableXReg (temp_writable_xreg))
(_ Unit (emit (MInst.CsrImm op dst imm csr))))
dst))

;; Helper for emitting a `MInst.CsrReg` instruction that writes the result to x0.
(decl csr_reg_dst_zero (CsrRegOP CSR XReg) Unit)
(rule (csr_reg_dst_zero op csr rs)
(emit (MInst.CsrReg op (writable_zero_reg) rs csr)))



(decl select_addi (Type) AluOPRRI)
Expand Down
76 changes: 76 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1810,3 +1810,79 @@ pub(crate) fn f64_cvt_to_int_bounds(signed: bool, out_bits: u8) -> (f64, f64) {
_ => unreachable!(),
}
}

impl CsrRegOP {
pub(crate) fn funct3(self) -> u32 {
match self {
CsrRegOP::CsrRW => 0b001,
CsrRegOP::CsrRS => 0b010,
CsrRegOP::CsrRC => 0b011,
}
}

pub(crate) fn opcode(self) -> u32 {
0b1110011
}

pub(crate) fn name(self) -> &'static str {
match self {
CsrRegOP::CsrRW => "csrrw",
CsrRegOP::CsrRS => "csrrs",
CsrRegOP::CsrRC => "csrrc",
}
}
}

impl Display for CsrRegOP {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "{}", self.name())
}
}

impl CsrImmOP {
pub(crate) fn funct3(self) -> u32 {
match self {
CsrImmOP::CsrRWI => 0b101,
CsrImmOP::CsrRSI => 0b110,
CsrImmOP::CsrRCI => 0b111,
}
}

pub(crate) fn opcode(self) -> u32 {
0b1110011
}

pub(crate) fn name(self) -> &'static str {
match self {
CsrImmOP::CsrRWI => "csrrwi",
CsrImmOP::CsrRSI => "csrrsi",
CsrImmOP::CsrRCI => "csrrci",
}
}
}

impl Display for CsrImmOP {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "{}", self.name())
}
}

impl CSR {
pub(crate) fn bits(self) -> Imm12 {
Imm12::from_bits(match self {
CSR::Frm => 0x0002,
})
}

pub(crate) fn name(self) -> &'static str {
match self {
CSR::Frm => "frm",
}
}
}

impl Display for CSR {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "{}", self.name())
}
}
13 changes: 13 additions & 0 deletions cranelift/codegen/src/isa/riscv64/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ impl Inst {
| Inst::AluRRR { .. }
| Inst::FpuRRR { .. }
| Inst::AluRRImm12 { .. }
| Inst::CsrReg { .. }
| Inst::CsrImm { .. }
| Inst::Load { .. }
| Inst::Store { .. }
| Inst::Args { .. }
Expand Down Expand Up @@ -595,6 +597,17 @@ impl MachInstEmit for Inst {
| alu_op.imm12(imm12) << 20;
sink.put4(x);
}
&Inst::CsrReg { op, rd, rs, csr } => {
let rs = allocs.next(rs);
let rd = allocs.next_writable(rd);

sink.put4(encode_csr_reg(op, rd, rs, csr));
}
&Inst::CsrImm { op, rd, csr, imm } => {
let rd = allocs.next_writable(rd);

sink.put4(encode_csr_imm(op, rd, csr, imm));
}
&Inst::Load {
rd,
op,
Expand Down
43 changes: 35 additions & 8 deletions cranelift/codegen/src/isa/riscv64/inst/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! Some instructions especially in extensions have slight variations from
//! the base RISC-V specification.
use super::{Imm12, Imm5, UImm5, VType};
use super::*;
use crate::isa::riscv64::inst::reg_to_gpr_num;
use crate::isa::riscv64::lower::isle::generated_code::{
VecAluOpRImm5, VecAluOpRR, VecAluOpRRImm5, VecAluOpRRR, VecAluOpRRRImm5, VecAluOpRRRR,
Expand Down Expand Up @@ -53,21 +53,30 @@ pub fn encode_r_type(
)
}

/// Encode an I-type instruction.
///
/// Layout:
/// 0-------6-7-------11-12------14-15------19-20------------------31
/// | Opcode | rd | width | rs1 | Offset[11:0] |
pub fn encode_i_type(opcode: u32, rd: WritableReg, width: u32, rs1: Reg, offset: Imm12) -> u32 {
fn encode_i_type_bits(opcode: u32, rd: u32, funct3: u32, rs1: u32, offset: u32) -> u32 {
let mut bits = 0;
bits |= unsigned_field_width(opcode, 7);
bits |= reg_to_gpr_num(rd.to_reg()) << 7;
bits |= unsigned_field_width(width, 3) << 12;
bits |= reg_to_gpr_num(rs1) << 15;
bits |= unsigned_field_width(offset.as_u32(), 12) << 20;
bits |= unsigned_field_width(rd, 5) << 7;
bits |= unsigned_field_width(funct3, 3) << 12;
bits |= unsigned_field_width(rs1, 5) << 15;
bits |= unsigned_field_width(offset, 12) << 20;
bits
}

/// Encode an I-type instruction.
pub fn encode_i_type(opcode: u32, rd: WritableReg, width: u32, rs1: Reg, offset: Imm12) -> u32 {
encode_i_type_bits(
opcode,
reg_to_gpr_num(rd.to_reg()),
width,
reg_to_gpr_num(rs1),
offset.as_u32(),
)
}

/// Encode an S-type instruction.
///
/// Layout:
Expand Down Expand Up @@ -297,3 +306,21 @@ pub fn encode_vmem_store(
// with different names on the fields.
encode_vmem_load(opcode, vs3, width, rs1, sumop, masking, mop, nf)
}

// The CSR Reg instruction is really just an I type instruction with the CSR in
// the immediate field.
pub fn encode_csr_reg(op: CsrRegOP, rd: WritableReg, rs: Reg, csr: CSR) -> u32 {
encode_i_type(op.opcode(), rd, op.funct3(), rs, csr.bits())
}

// The CSR Imm instruction is an I type instruction with the CSR in
// the immediate field and the value to be set in the `rs1` field.
pub fn encode_csr_imm(op: CsrImmOP, rd: WritableReg, csr: CSR, imm: UImm5) -> u32 {
encode_i_type_bits(
op.opcode(),
reg_to_gpr_num(rd.to_reg()),
op.funct3(),
imm.bits(),
csr.bits().as_u32(),
)
}
37 changes: 35 additions & 2 deletions cranelift/codegen/src/isa/riscv64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ pub(crate) type VecWritableReg = Vec<Writable<Reg>>;
// Instructions (top level): definition

pub use crate::isa::riscv64::lower::isle::generated_code::{
AluOPRRI, AluOPRRR, AtomicOP, FClassResult, FFlagsException, FloatRoundOP, FloatSelectOP,
FpuOPRR, FpuOPRRR, FpuOPRRRR, IntSelectOP, LoadOP, MInst as Inst, StoreOP, FRM,
AluOPRRI, AluOPRRR, AtomicOP, CsrImmOP, CsrRegOP, FClassResult, FFlagsException, FloatRoundOP,
FloatSelectOP, FpuOPRR, FpuOPRRR, FpuOPRRRR, IntSelectOP, LoadOP, MInst as Inst, StoreOP, CSR,
FRM,
};
use crate::isa::riscv64::lower::isle::generated_code::{MInst, VecAluOpRRImm5, VecAluOpRRR};

Expand Down Expand Up @@ -399,6 +400,13 @@ fn riscv64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut Operan
collector.reg_use(rs);
collector.reg_def(rd);
}
&Inst::CsrReg { rd, rs, .. } => {
collector.reg_use(rs);
collector.reg_def(rd);
}
&Inst::CsrImm { rd, .. } => {
collector.reg_def(rd);
}
&Inst::Load { rd, from, .. } => {
if let Some(r) = from.get_allocatable_register() {
collector.reg_use(r);
Expand Down Expand Up @@ -1512,6 +1520,31 @@ impl Inst {
}
}
}
&Inst::CsrReg { op, rd, rs, csr } => {
let rs_s = format_reg(rs, allocs);
let rd_s = format_reg(rd.to_reg(), allocs);

match (op, csr, rd) {
(CsrRegOP::CsrRW, CSR::Frm, rd) if rd.to_reg() == zero_reg() => {
format!("fsrm {rs_s}")
}
_ => {
format!("{op} {rd_s},{csr},{rs_s}")
}
}
}
&Inst::CsrImm { op, rd, csr, imm } => {
let rd_s = format_reg(rd.to_reg(), allocs);

match (op, csr, rd) {
(CsrImmOP::CsrRWI, CSR::Frm, rd) if rd.to_reg() != zero_reg() => {
format!("fsrmi {rd_s},{imm}")
}
_ => {
format!("{op} {rd_s},{csr},{imm}")
}
}
}
&Inst::Load {
rd,
op,
Expand Down
Loading

0 comments on commit d6b4825

Please sign in to comment.