Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

riscv64: Improve f{min,max} codegen #7181

Merged
merged 1 commit into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 10 additions & 23 deletions cranelift/codegen/src/isa/riscv64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,6 @@
(f_tmp WritableReg)
(rs Reg)
(ty Type))
;;;; FMax
(FloatSelect
(op FloatSelectOP)
(rd WritableReg)
;; a integer register
(tmp WritableReg)
(rs1 Reg)
(rs2 Reg)
(ty Type))

;; popcnt if target doesn't support extension B
;; use iteration to implement.
Expand Down Expand Up @@ -391,11 +382,6 @@
))


(type FloatSelectOP (enum
(Max)
(Min)
))

(type FloatRoundOP (enum
(Nearest)
(Ceil)
Expand Down Expand Up @@ -1098,15 +1084,6 @@
(_ Unit (emit (MInst.FloatRound op rd tmp tmp2 rs ty))))
(writable_reg_to_reg rd)))

(decl gen_float_select (FloatSelectOP Reg Reg Type) Reg)
(rule
(gen_float_select op x y ty)
(let
((rd WritableReg (temp_writable_reg ty))
(tmp WritableXReg (temp_writable_xreg))
(_ Unit (emit (MInst.FloatSelect op rd tmp x y ty))))
(writable_reg_to_reg rd)))


;;;; Instruction Helpers ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand Down Expand Up @@ -1527,6 +1504,16 @@
(decl rv_fge (Type FReg FReg) XReg)
(rule (rv_fge ty rs1 rs2) (rv_fle ty rs2 rs1))

;; Helper for emitting the `fmin` instruction.
(decl rv_fmin (Type FReg FReg) FReg)
(rule (rv_fmin $F32 rs1 rs2) (fpu_rrr (FpuOPRRR.FminS) $F32 rs1 rs2))
(rule (rv_fmin $F64 rs1 rs2) (fpu_rrr (FpuOPRRR.FminD) $F64 rs1 rs2))

;; Helper for emitting the `fmax` instruction.
(decl rv_fmax (Type FReg FReg) FReg)
(rule (rv_fmax $F32 rs1 rs2) (fpu_rrr (FpuOPRRR.FmaxS) $F32 rs1 rs2))
(rule (rv_fmax $F64 rs1 rs2) (fpu_rrr (FpuOPRRR.FmaxD) $F64 rs1 rs2))


;; `Zba` Extension Instructions

Expand Down
47 changes: 0 additions & 47 deletions cranelift/codegen/src/isa/riscv64/inst/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1695,53 +1695,6 @@ impl FloatRoundOP {
}
}

impl FloatSelectOP {
pub(crate) fn op_name(self) -> &'static str {
match self {
FloatSelectOP::Max => "max",
FloatSelectOP::Min => "min",
}
}

pub(crate) fn to_fpuoprrr(self, ty: Type) -> FpuOPRRR {
match self {
FloatSelectOP::Max => {
if ty == F32 {
FpuOPRRR::FmaxS
} else {
FpuOPRRR::FmaxD
}
}
FloatSelectOP::Min => {
if ty == F32 {
FpuOPRRR::FminS
} else {
FpuOPRRR::FminD
}
}
}
}
// move qnan bits into int register.
pub(crate) fn snan_bits(self, rd: Writable<Reg>, ty: Type) -> SmallInstVec<Inst> {
let mut insts = SmallInstVec::new();
insts.push(Inst::load_imm12(rd, Imm12::from_i16(-1)));
let x = if ty == F32 { 22 } else { 51 };
insts.push(Inst::AluRRImm12 {
alu_op: AluOPRRI::Srli,
rd: rd,
rs: rd.to_reg(),
imm12: Imm12::from_i16(x),
});
insts.push(Inst::AluRRImm12 {
alu_op: AluOPRRI::Slli,
rd: rd,
rs: rd.to_reg(),
imm12: Imm12::from_i16(x),
});
insts
}
}

pub(crate) fn f32_bits(f: f32) -> u32 {
u32::from_le_bytes(f.to_le_bytes())
}
Expand Down
137 changes: 0 additions & 137 deletions cranelift/codegen/src/isa/riscv64/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,6 @@ impl Inst {
| Inst::Unwind { .. }
| Inst::DummyUse { .. }
| Inst::FloatRound { .. }
| Inst::FloatSelect { .. }
| Inst::Popcnt { .. }
| Inst::Rev8 { .. }
| Inst::Cltz { .. }
Expand Down Expand Up @@ -2587,126 +2586,6 @@ impl Inst {
sink.bind_label(label_jump_over, &mut state.ctrl_plane);
}

&Inst::FloatSelect {
op,
rd,
tmp,
rs1,
rs2,
ty,
} => {
let label_nan = sink.get_label();
let label_jump_over = sink.get_label();
// check if rs1 is nan.
Inst::emit_not_nan(tmp, rs1, ty).emit(&[], sink, emit_info, state);
Inst::CondBr {
taken: CondBrTarget::Label(label_nan),
not_taken: CondBrTarget::Fallthrough,
kind: IntegerCompare {
kind: IntCC::Equal,
rs1: tmp.to_reg(),
rs2: zero_reg(),
},
}
.emit(&[], sink, emit_info, state);
// check if rs2 is nan.
Inst::emit_not_nan(tmp, rs2, ty).emit(&[], sink, emit_info, state);
Inst::CondBr {
taken: CondBrTarget::Label(label_nan),
not_taken: CondBrTarget::Fallthrough,
kind: IntegerCompare {
kind: IntCC::Equal,
rs1: tmp.to_reg(),
rs2: zero_reg(),
},
}
.emit(&[], sink, emit_info, state);
// here rs1 and rs2 is not nan.
Inst::FpuRRR {
alu_op: op.to_fpuoprrr(ty),
frm: None,
rd: rd,
rs1: rs1,
rs2: rs2,
}
.emit(&[], sink, emit_info, state);
// special handle for +0 or -0.
{
// check is rs1 and rs2 all equal to zero.
let label_done = sink.get_label();
{
// if rs1 == 0
let mut insts = Inst::emit_if_float_not_zero(
tmp,
rs1,
ty,
CondBrTarget::Label(label_done),
CondBrTarget::Fallthrough,
);
insts.extend(Inst::emit_if_float_not_zero(
tmp,
rs2,
ty,
CondBrTarget::Label(label_done),
CondBrTarget::Fallthrough,
));
insts
.iter()
.for_each(|i| i.emit(&[], sink, emit_info, state));
}
Inst::FpuRR {
alu_op: FpuOPRR::move_f_to_x_op(ty),
frm: None,
rd: tmp,
rs: rs1,
}
.emit(&[], sink, emit_info, state);
Inst::FpuRR {
alu_op: FpuOPRR::move_f_to_x_op(ty),
frm: None,
rd: writable_spilltmp_reg(),
rs: rs2,
}
.emit(&[], sink, emit_info, state);
Inst::AluRRR {
alu_op: if op == FloatSelectOP::Max {
AluOPRRR::And
} else {
AluOPRRR::Or
},
rd: tmp,
rs1: tmp.to_reg(),
rs2: spilltmp_reg(),
}
.emit(&[], sink, emit_info, state);
// move back to rd.
Inst::FpuRR {
alu_op: FpuOPRR::move_x_to_f_op(ty),
frm: None,
rd,
rs: tmp.to_reg(),
}
.emit(&[], sink, emit_info, state);
//
sink.bind_label(label_done, &mut state.ctrl_plane);
}
// we have the reuslt,jump over.
Inst::gen_jump(label_jump_over).emit(&[], sink, emit_info, state);
// here is nan.
sink.bind_label(label_nan, &mut state.ctrl_plane);
op.snan_bits(tmp, ty)
.into_iter()
.for_each(|i| i.emit(&[], sink, emit_info, state));
// move to rd.
Inst::FpuRR {
alu_op: FpuOPRR::move_x_to_f_op(ty),
frm: None,
rd,
rs: tmp.to_reg(),
}
.emit(&[], sink, emit_info, state);
sink.bind_label(label_jump_over, &mut state.ctrl_plane);
}
&Inst::Popcnt {
sum,
tmp,
Expand Down Expand Up @@ -3708,22 +3587,6 @@ impl Inst {
rd: allocs.next_writable(rd),
},

Inst::FloatSelect {
op,
rd,
tmp,
rs1,
rs2,
ty,
} => Inst::FloatSelect {
op,
ty,
rs1: allocs.next(rs1),
rs2: allocs.next(rs2),
tmp: allocs.next_writable(tmp),
rd: allocs.next_writable(rd),
},

Inst::Popcnt {
sum,
tmp,
Expand Down
9 changes: 0 additions & 9 deletions cranelift/codegen/src/isa/riscv64/inst/emit_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2287,15 +2287,6 @@ fn riscv64_worst_case_instruction_size() {
ty: F64,
});

candidates.push(Inst::FloatSelect {
op: FloatSelectOP::Max,
rd: writable_fa0(),
tmp: writable_a0(),
rs1: fa0(),
rs2: fa0(),
ty: F64,
});

let mut max: (u32, MInst) = (0, Inst::Nop0);
for i in candidates {
let mut buffer = MachBuffer::new();
Expand Down
32 changes: 1 addition & 31 deletions cranelift/codegen/src/isa/riscv64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub(crate) type VecWritableReg = Vec<Writable<Reg>>;

pub use crate::isa::riscv64::lower::isle::generated_code::{
AluOPRRI, AluOPRRR, AtomicOP, CsrImmOP, CsrRegOP, FClassResult, FFlagsException, FloatRoundOP,
FloatSelectOP, FpuOPRR, FpuOPRRR, FpuOPRRRR, LoadOP, MInst as Inst, StoreOP, CSR, FRM,
FpuOPRR, FpuOPRRR, FpuOPRRRR, LoadOP, MInst as Inst, StoreOP, CSR, FRM,
};
use crate::isa::riscv64::lower::isle::generated_code::{CjOp, MInst, VecAluOpRRImm5, VecAluOpRRR};

Expand Down Expand Up @@ -609,13 +609,6 @@ fn riscv64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut Operan
collector.reg_early_def(f_tmp);
collector.reg_early_def(rd);
}
&Inst::FloatSelect {
rd, tmp, rs1, rs2, ..
} => {
collector.reg_uses(&[rs1, rs2]);
collector.reg_early_def(tmp);
collector.reg_early_def(rd);
}
&Inst::Popcnt {
sum, step, rs, tmp, ..
} => {
Expand Down Expand Up @@ -1109,29 +1102,6 @@ impl Inst {
ty
)
}
&Inst::FloatSelect {
op,
rd,
tmp,
rs1,
rs2,
ty,
} => {
let rs1 = format_reg(rs1, allocs);
let rs2 = format_reg(rs2, allocs);
let tmp = format_reg(tmp.to_reg(), allocs);
let rd = format_reg(rd.to_reg(), allocs);
format!(
"f{}.{} {},{},{}##tmp={} ty={}",
op.op_name(),
if ty == F32 { "s" } else { "d" },
rd,
rs1,
rs2,
tmp,
ty
)
}
&Inst::AtomicStore { src, ty, p } => {
let src = format_reg(src, allocs);
let p = format_reg(p, allocs);
Expand Down
21 changes: 19 additions & 2 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1478,8 +1478,16 @@

;;;; Rules for `fmin` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; RISC-V's `fmin` instruction returns the number input if one of inputs is a
;; NaN. We handle this by manually checking if one of the inputs is a NaN
;; and selecting based on that result.
(rule 0 (lower (has_type (ty_scalar_float ty) (fmin x y)))
(gen_float_select (FloatSelectOP.Min) x y ty))
(let (;; Check if both inputs are not nan.
(is_ordered FCmp (emit_fcmp (FloatCC.Ordered) ty x y))
;; `fadd` returns a nan if any of the inputs is a NaN.
(nan FReg (rv_fadd ty x y))
(min FReg (rv_fmin ty x y)))
(gen_select_freg is_ordered min nan)))

;; vfmin does almost the right thing, but it does not handle NaN's correctly.
;; We should return a NaN if any of the inputs is a NaN, but vfmin returns the
Expand All @@ -1496,8 +1504,17 @@

;;;; Rules for `fmax` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; RISC-V's `fmax` instruction returns the number input if one of inputs is a
;; NaN. We handle this by manually checking if one of the inputs is a NaN
;; and selecting based on that result.
(rule 0 (lower (has_type (ty_scalar_float ty) (fmax x y)))
(gen_float_select (FloatSelectOP.Max) x y ty))
(let (;; Check if both inputs are not nan.
(is_ordered FCmp (emit_fcmp (FloatCC.Ordered) ty x y))
;; `fadd` returns a NaN if any of the inputs is a NaN.
(nan FReg (rv_fadd ty x y))
(max FReg (rv_fmax ty x y)))
(gen_select_freg is_ordered max nan)))


;; vfmax does almost the right thing, but it does not handle NaN's correctly.
;; We should return a NaN if any of the inputs is a NaN, but vfmax returns the
Expand Down
Loading