Skip to content

Commit

Permalink
cranelift: Add constant propagation for f16 and f128
Browse files Browse the repository at this point in the history
  • Loading branch information
beetrees committed Jul 2, 2024
1 parent 72cc361 commit e424cd8
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 23 deletions.
78 changes: 58 additions & 20 deletions cranelift/codegen/src/ir/immediates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,9 @@ fn parse_float(s: &str, w: u8, t: u8) -> Result<u128, &'static str> {
impl Ieee16 {
const SIGNIFICAND_BITS: u8 = 10;
const EXPONENT_BITS: u8 = 5;
const SIGN_MASK: u16 = 1 << (Self::EXPONENT_BITS + Self::SIGNIFICAND_BITS);
const SIGNIFICAND_MASK: u16 = u16::MAX >> (Self::EXPONENT_BITS + 1);
const EXPONENT_MASK: u16 = !Self::SIGN_MASK & !Self::SIGNIFICAND_MASK;

/// Create a new `Ieee16` containing the bits of `x`.
pub fn with_bits(x: u16) -> Self {
Expand All @@ -779,6 +782,16 @@ impl Ieee16 {
self.0
}

/// Computes the absolute value of self.
pub fn abs(self) -> Self {
Self::with_bits(self.bits() & !Self::SIGN_MASK)
}

/// Returns a number composed of the magnitude of self and the sign of sign.
pub fn copysign(self, sign: Self) -> Self {
Self::with_bits((self.bits() & !Self::SIGN_MASK) | (sign.bits() & Self::SIGN_MASK))
}

/// Returns true if self is positive or negative zero
pub fn is_zero(&self) -> bool {
self.partial_cmp(&Self::with_bits(0)) == Some(Ordering::Equal)
Expand All @@ -788,14 +801,12 @@ impl Ieee16 {
impl PartialOrd for Ieee16 {
fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
// FIXME(#8312): Use Rust `f16` comparisons once `f16` support is stabalised.
let significand_mask = u16::MAX >> (Self::EXPONENT_BITS + 1);
let sign_mask = 1 << (Self::EXPONENT_BITS + Self::SIGNIFICAND_BITS);
let exponent_mask = !sign_mask & !significand_mask;

let lhs_abs = self.bits() & !sign_mask;
let rhs_abs = rhs.bits() & !sign_mask;
if (lhs_abs & exponent_mask == exponent_mask && lhs_abs & significand_mask != 0)
&& (rhs_abs & exponent_mask == exponent_mask && rhs_abs & significand_mask != 0)
let lhs_abs = self.bits() & !Self::SIGN_MASK;
let rhs_abs = rhs.bits() & !Self::SIGN_MASK;
if (lhs_abs & Self::EXPONENT_MASK == Self::EXPONENT_MASK
&& lhs_abs & Self::SIGNIFICAND_MASK != 0)
&& (rhs_abs & Self::EXPONENT_MASK == Self::EXPONENT_MASK
&& rhs_abs & Self::SIGNIFICAND_MASK != 0)
{
// One of the floats is a NaN.
return None;
Expand All @@ -804,8 +815,8 @@ impl PartialOrd for Ieee16 {
// Zeros are always equal regardless of sign.
return Some(Ordering::Equal);
}
let lhs_positive = self.bits() & sign_mask == 0;
let rhs_positive = rhs.bits() & sign_mask == 0;
let lhs_positive = self.bits() & Self::SIGN_MASK == 0;
let rhs_positive = rhs.bits() & Self::SIGN_MASK == 0;
if lhs_positive != rhs_positive {
// Different signs: negative < positive
return lhs_positive.partial_cmp(&rhs_positive);
Expand Down Expand Up @@ -849,6 +860,14 @@ impl IntoBytes for Ieee16 {
}
}

impl Neg for Ieee16 {
type Output = Self;

fn neg(self) -> Self {
Self::with_bits(self.bits() ^ Self::SIGN_MASK)
}
}

impl Ieee32 {
/// Create a new `Ieee32` containing the bits of `x`.
pub fn with_bits(x: u32) -> Self {
Expand Down Expand Up @@ -1293,6 +1312,9 @@ impl Not for Ieee64 {
impl Ieee128 {
const SIGNIFICAND_BITS: u8 = 112;
const EXPONENT_BITS: u8 = 15;
const SIGN_MASK: u128 = 1 << (Self::EXPONENT_BITS + Self::SIGNIFICAND_BITS);
const SIGNIFICAND_MASK: u128 = u128::MAX >> (Self::EXPONENT_BITS + 1);
const EXPONENT_MASK: u128 = !Self::SIGN_MASK & !Self::SIGNIFICAND_MASK;

/// Create a new `Ieee128` containing the bits of `x`.
pub fn with_bits(x: u128) -> Self {
Expand All @@ -1304,6 +1326,16 @@ impl Ieee128 {
self.0
}

/// Computes the absolute value of self.
pub fn abs(self) -> Self {
Self::with_bits(self.bits() & !Self::SIGN_MASK)
}

/// Returns a number composed of the magnitude of self and the sign of sign.
pub fn copysign(self, sign: Self) -> Self {
Self::with_bits((self.bits() & !Self::SIGN_MASK) | (sign.bits() & Self::SIGN_MASK))
}

/// Returns true if self is positive or negative zero
pub fn is_zero(&self) -> bool {
self.partial_cmp(&Self::with_bits(0)) == Some(Ordering::Equal)
Expand All @@ -1313,14 +1345,12 @@ impl Ieee128 {
impl PartialOrd for Ieee128 {
fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
// FIXME(#8312): Use Rust `f128` comparisons once `f128` support is stabalised.
let significand_mask = u128::MAX >> (Self::EXPONENT_BITS + 1);
let sign_mask = 1 << (Self::EXPONENT_BITS + Self::SIGNIFICAND_BITS);
let exponent_mask = !sign_mask & !significand_mask;

let lhs_abs = self.bits() & !sign_mask;
let rhs_abs = rhs.bits() & !sign_mask;
if (lhs_abs & exponent_mask == exponent_mask && lhs_abs & significand_mask != 0)
&& (rhs_abs & exponent_mask == exponent_mask && rhs_abs & significand_mask != 0)
let lhs_abs = self.bits() & !Self::SIGN_MASK;
let rhs_abs = rhs.bits() & !Self::SIGN_MASK;
if (lhs_abs & Self::EXPONENT_MASK == Self::EXPONENT_MASK
&& lhs_abs & Self::SIGNIFICAND_MASK != 0)
&& (rhs_abs & Self::EXPONENT_MASK == Self::EXPONENT_MASK
&& rhs_abs & Self::SIGNIFICAND_MASK != 0)
{
// One of the floats is a NaN.
return None;
Expand All @@ -1329,8 +1359,8 @@ impl PartialOrd for Ieee128 {
// Zeros are always equal regardless of sign.
return Some(Ordering::Equal);
}
let lhs_positive = self.bits() & sign_mask == 0;
let rhs_positive = rhs.bits() & sign_mask == 0;
let lhs_positive = self.bits() & Self::SIGN_MASK == 0;
let rhs_positive = rhs.bits() & Self::SIGN_MASK == 0;
if lhs_positive != rhs_positive {
// Different signs: negative < positive
return lhs_positive.partial_cmp(&rhs_positive);
Expand Down Expand Up @@ -1369,6 +1399,14 @@ impl IntoBytes for Ieee128 {
}
}

impl Neg for Ieee128 {
type Output = Self;

fn neg(self) -> Self {
Self::with_bits(self.bits() ^ Self::SIGN_MASK)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
24 changes: 24 additions & 0 deletions cranelift/codegen/src/isle_prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,18 @@ macro_rules! isle_common_prelude_methods {
}
}

fn f16_neg(&mut self, n: Ieee16) -> Ieee16 {
-n
}

fn f16_abs(&mut self, n: Ieee16) -> Ieee16 {
n.abs()
}

fn f16_copysign(&mut self, a: Ieee16, b: Ieee16) -> Ieee16 {
a.copysign(b)
}

fn f32_neg(&mut self, n: Ieee32) -> Ieee32 {
n.neg()
}
Expand All @@ -961,5 +973,17 @@ macro_rules! isle_common_prelude_methods {
fn f64_copysign(&mut self, a: Ieee64, b: Ieee64) -> Ieee64 {
a.copysign(b)
}

fn f128_neg(&mut self, n: Ieee128) -> Ieee128 {
-n
}

fn f128_abs(&mut self, n: Ieee128) -> Ieee128 {
n.abs()
}

fn f128_copysign(&mut self, a: Ieee128, b: Ieee128) -> Ieee128 {
a.copysign(b)
}
};
}
10 changes: 9 additions & 1 deletion cranelift/codegen/src/opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use crate::egraph::{NewOrExistingInst, OptimizeCtx};
pub use crate::ir::condcodes::{FloatCC, IntCC};
use crate::ir::dfg::ValueDef;
pub use crate::ir::immediates::{Ieee16, Ieee32, Ieee64, Imm64, Offset32, Uimm8, V128Imm};
pub use crate::ir::immediates::{Ieee128, Ieee16, Ieee32, Ieee64, Imm64, Offset32, Uimm8, V128Imm};
use crate::ir::instructions::InstructionFormat;
pub use crate::ir::types::*;
pub use crate::ir::{
Expand Down Expand Up @@ -292,4 +292,12 @@ impl<'a, 'b, 'c> generated_code::Context for IsleContext<'a, 'b, 'c> {
fn u64_bswap64(&mut self, n: u64) -> u64 {
n.swap_bytes()
}

fn ieee128_constant_extractor(&mut self, n: Constant) -> Option<Ieee128> {
self.ctx.func.dfg.constants.get(n).try_into().ok()
}

fn ieee128_constant(&mut self, n: Ieee128) -> Constant {
self.ctx.func.dfg.constants.insert(n.into())
}
}
16 changes: 16 additions & 0 deletions cranelift/codegen/src/opts/cprop.isle
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,33 @@
(extern constructor u64_bswap64 u64_bswap64)

;; Constant fold bitwise float operations (fneg/fabs/fcopysign)
(rule (simplify (fneg $F16 (f16const $F16 n)))
(subsume (f16const $F16 (f16_neg n))))
(rule (simplify (fneg $F32 (f32const $F32 n)))
(subsume (f32const $F32 (f32_neg n))))
(rule (simplify (fneg $F64 (f64const $F64 n)))
(subsume (f64const $F64 (f64_neg n))))
(rule (simplify (fneg $F128 (f128const $F128 (ieee128_constant n))))
(subsume (f128const $F128 (ieee128_constant (f128_neg n)))))

(rule (simplify (fabs $F16 (f16const $F16 n)))
(subsume (f16const $F16 (f16_abs n))))
(rule (simplify (fabs $F32 (f32const $F32 n)))
(subsume (f32const $F32 (f32_abs n))))
(rule (simplify (fabs $F64 (f64const $F64 n)))
(subsume (f64const $F64 (f64_abs n))))
(rule (simplify (fabs $F128 (f128const $F128 (ieee128_constant n))))
(subsume (f128const $F128 (ieee128_constant (f128_abs n)))))

(rule (simplify (fcopysign $F16 (f16const $F16 n) (f16const $F16 m)))
(subsume (f16const $F16 (f16_copysign n m))))
(rule (simplify (fcopysign $F32 (f32const $F32 n) (f32const $F32 m)))
(subsume (f32const $F32 (f32_copysign n m))))
(rule (simplify (fcopysign $F64 (f64const $F64 n) (f64const $F64 m)))
(subsume (f64const $F64 (f64_copysign n m))))
(rule (simplify (fcopysign $F128 (f128const $F128 (ieee128_constant n)) (f128const $F128 (ieee128_constant m))))
(subsume (f128const $F128 (ieee128_constant (f128_copysign n m)))))

(decl ieee128_constant (Ieee128) Constant)
(extern constructor ieee128_constant ieee128_constant)
(extern extractor ieee128_constant ieee128_constant_extractor)
15 changes: 15 additions & 0 deletions cranelift/codegen/src/prelude.isle
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@

;; Floating point operations

(decl pure f16_neg (Ieee16) Ieee16)
(extern constructor f16_neg f16_neg)
(decl pure f16_abs (Ieee16) Ieee16)
(extern constructor f16_abs f16_abs)
(decl pure f16_copysign (Ieee16 Ieee16) Ieee16)
(extern constructor f16_copysign f16_copysign)
(decl pure f32_neg (Ieee32) Ieee32)
(extern constructor f32_neg f32_neg)
(decl pure f32_abs (Ieee32) Ieee32)
Expand All @@ -251,6 +257,13 @@
(extern constructor f64_abs f64_abs)
(decl pure f64_copysign (Ieee64 Ieee64) Ieee64)
(extern constructor f64_copysign f64_copysign)
(decl pure f128_neg (Ieee128) Ieee128)
(extern constructor f128_neg f128_neg)
(decl pure f128_abs (Ieee128) Ieee128)
(extern constructor f128_abs f128_abs)
(decl pure f128_copysign (Ieee128 Ieee128) Ieee128)
(extern constructor f128_copysign f128_copysign)
(type Ieee128 (primitive Ieee128))

;;;; `cranelift_codegen::ir::Type` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand All @@ -263,8 +276,10 @@
(extern const $R32 Type)
(extern const $R64 Type)

(extern const $F16 Type)
(extern const $F32 Type)
(extern const $F64 Type)
(extern const $F128 Type)

(extern const $I8X8 Type)
(extern const $I8X16 Type)
Expand Down
69 changes: 67 additions & 2 deletions cranelift/filetests/filetests/egraph/cprop.clif
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,37 @@ block0:
; check: v2 = iconst.i64 0xf0de_bc9a_7856_3412
; nextln: return v2

function %f16_fneg() -> f16 {
block0:
v1 = f16const 0.0
v2 = fneg v1
return v2
}

; check: v3 = f16const -0.0
; check: return v3 ; v3 = -0.0

function %f16_fabs() -> f16 {
block0:
v1 = f16const -0.0
v2 = fabs v1
return v2
}

; check: v3 = f16const 0.0
; check: return v3 ; v3 = 0.0

function %f16_fcopysign() -> f16 {
block0:
v1 = f16const -0.0
v2 = f16const NaN
v3 = fcopysign v2, v1
return v3
}

; check: v4 = f16const -NaN
; check: return v4 ; v4 = -NaN

function %f32_fneg() -> f32 {
block0:
v1 = f32const 0.0
Expand All @@ -333,7 +364,7 @@ block0:
; check: v3 = f32const 0.0
; check: return v3 ; v3 = 0.0

function %f32_fabs() -> f32 {
function %f32_fcopysign() -> f32 {
block0:
v1 = f32const -0.0
v2 = f32const NaN
Expand Down Expand Up @@ -364,7 +395,7 @@ block0:
; check: v3 = f64const 0.0
; check: return v3 ; v3 = 0.0

function %f64_fabs() -> f64 {
function %f64_fcopysign() -> f64 {
block0:
v1 = f64const -0.0
v2 = f64const NaN
Expand All @@ -374,3 +405,37 @@ block0:

; check: v4 = f64const -NaN
; check: return v4 ; v4 = -NaN

function %f128_fneg() -> f128 {
block0:
v1 = f128const 0.0
v2 = fneg v1
return v2
}

; check: const1 = 0x80000000000000000000000000000000
; check: v3 = f128const const1
; check: return v3 ; v3 = const1

function %f128_fabs() -> f128 {
block0:
v1 = f128const -0.0
v2 = fabs v1
return v2
}

; check: const1 = 0x00000000000000000000000000000000
; check: v3 = f128const const1
; check: return v3 ; v3 = const1

function %f128_fcopysign() -> f128 {
block0:
v1 = f128const -0.0
v2 = f128const NaN
v3 = fcopysign v2, v1
return v3
}

; check: const2 = 0xffff8000000000000000000000000000
; check: v4 = f128const const2
; check: return v4 ; v4 = const2

0 comments on commit e424cd8

Please sign in to comment.