Skip to content

Commit

Permalink
Auto merge of #2028 - RalfJung:simd-round, r=RalfJung
Browse files Browse the repository at this point in the history
implement SIMD float rounding functions

Cc #1912
  • Loading branch information
bors committed Mar 16, 2022
2 parents 39c72db + 1f237b3 commit 49729b5
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 10 deletions.
66 changes: 60 additions & 6 deletions src/shims/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,20 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
}

// Floating-point operations
"fabsf32" => {
let &[ref f] = check_arg_count(args)?;
let f = this.read_scalar(f)?.to_f32()?;
// Can be implemented in soft-floats.
this.write_scalar(Scalar::from_f32(f.abs()), dest)?;
}
"fabsf64" => {
let &[ref f] = check_arg_count(args)?;
let f = this.read_scalar(f)?.to_f64()?;
// Can be implemented in soft-floats.
this.write_scalar(Scalar::from_f64(f.abs()), dest)?;
}
#[rustfmt::skip]
| "sinf32"
| "fabsf32"
| "cosf32"
| "sqrtf32"
| "expf32"
Expand All @@ -110,7 +121,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
let f = f32::from_bits(this.read_scalar(f)?.to_u32()?);
let f = match intrinsic_name {
"sinf32" => f.sin(),
"fabsf32" => f.abs(),
"cosf32" => f.cos(),
"sqrtf32" => f.sqrt(),
"expf32" => f.exp(),
Expand All @@ -129,7 +139,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx

#[rustfmt::skip]
| "sinf64"
| "fabsf64"
| "cosf64"
| "sqrtf64"
| "expf64"
Expand All @@ -147,7 +156,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
let f = f64::from_bits(this.read_scalar(f)?.to_u64()?);
let f = match intrinsic_name {
"sinf64" => f.sin(),
"fabsf64" => f.abs(),
"cosf64" => f.cos(),
"sqrtf64" => f.sqrt(),
"expf64" => f.exp(),
Expand Down Expand Up @@ -317,20 +325,37 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
// SIMD operations
#[rustfmt::skip]
| "simd_neg"
| "simd_fabs" => {
| "simd_fabs"
| "simd_ceil"
| "simd_floor"
| "simd_round"
| "simd_trunc" => {
let &[ref op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, op_len);

#[derive(Copy, Clone)]
enum HostFloatOp {
Ceil,
Floor,
Round,
Trunc,
}
#[derive(Copy, Clone)]
enum Op {
MirOp(mir::UnOp),
Abs,
HostOp(HostFloatOp),
}
let which = match intrinsic_name {
"simd_neg" => Op::MirOp(mir::UnOp::Neg),
"simd_fabs" => Op::Abs,
"simd_ceil" => Op::HostOp(HostFloatOp::Ceil),
"simd_floor" => Op::HostOp(HostFloatOp::Floor),
"simd_round" => Op::HostOp(HostFloatOp::Round),
"simd_trunc" => Op::HostOp(HostFloatOp::Trunc),
_ => unreachable!(),
};

Expand All @@ -342,14 +367,43 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
Op::Abs => {
// Works for f32 and f64.
let ty::Float(float_ty) = op.layout.ty.kind() else {
bug!("simd_fabs operand is not a float")
bug!("{} operand is not a float", intrinsic_name)
};
let op = op.to_scalar()?;
match float_ty {
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
}
}
Op::HostOp(host_op) => {
let ty::Float(float_ty) = op.layout.ty.kind() else {
bug!("{} operand is not a float", intrinsic_name)
};
// FIXME using host floats
match float_ty {
FloatTy::F32 => {
let f = f32::from_bits(op.to_scalar()?.to_u32()?);
let res = match host_op {
HostFloatOp::Ceil => f.ceil(),
HostFloatOp::Floor => f.floor(),
HostFloatOp::Round => f.round(),
HostFloatOp::Trunc => f.trunc(),
};
Scalar::from_u32(res.to_bits())
}
FloatTy::F64 => {
let f = f64::from_bits(op.to_scalar()?.to_u64()?);
let res = match host_op {
HostFloatOp::Ceil => f.ceil(),
HostFloatOp::Floor => f.floor(),
HostFloatOp::Round => f.round(),
HostFloatOp::Trunc => f.trunc(),
};
Scalar::from_u64(res.to_bits())
}
}

}
};
this.write_scalar(val, &dest.into())?;
}
Expand Down
65 changes: 61 additions & 4 deletions tests/run-pass/portable-simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,39 @@ fn simd_ops_i32() {
assert_eq!(a.min(b * i32x4::splat(4)), i32x4::from_array([4, 8, 10, -16]));

assert_eq!(
i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([1, i8::MIN, i8::MAX, 28])),
i8x4::from_array([i8::MAX, -23, 23, i8::MIN]).saturating_add(i8x4::from_array([
1,
i8::MIN,
i8::MAX,
28
])),
i8x4::from_array([i8::MAX, i8::MIN, i8::MAX, -100])
);
assert_eq!(
i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([1, i8::MAX, i8::MAX, -80])),
i8x4::from_array([i8::MAX, -28, 27, 42]).saturating_sub(i8x4::from_array([
1,
i8::MAX,
i8::MAX,
-80
])),
i8x4::from_array([126, i8::MIN, -100, 122])
);
assert_eq!(
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([1, 1, u8::MAX, 200])),
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_add(u8x4::from_array([
1,
1,
u8::MAX,
200
])),
u8x4::from_array([u8::MAX, 1, u8::MAX, 242])
);
assert_eq!(
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([1, 1, u8::MAX, 200])),
u8x4::from_array([u8::MAX, 0, 23, 42]).saturating_sub(u8x4::from_array([
1,
1,
u8::MAX,
200
])),
u8x4::from_array([254, 0, 0, 0])
);

Expand Down Expand Up @@ -259,6 +279,42 @@ fn simd_gather_scatter() {
assert_eq!(vec, vec![124, 11, 12, 82, 14, 15, 16, 17, 18]);
}

/// Exercises the lane-wise rounding methods (`ceil`, `floor`, `round`, `trunc`)
/// on four-lane `f32` and `f64` vectors.
///
/// Both element types use the same four lanes: a fraction below one, a value
/// slightly above an integer, an exact integer, and a negative half-way value.
/// Per the expected results below, `round` maps `-4.5` to `-5.0`, while `trunc`
/// and `ceil` map it to `-4.0`.
fn simd_round() {
    // Single-precision lanes.
    let v32 = f32x4::from_array([0.9, 1.001, 2.0, -4.5]);
    assert_eq!(v32.ceil(), f32x4::from_array([1.0, 2.0, 2.0, -4.0]));
    assert_eq!(v32.floor(), f32x4::from_array([0.0, 1.0, 2.0, -5.0]));
    assert_eq!(v32.round(), f32x4::from_array([1.0, 1.0, 2.0, -5.0]));
    assert_eq!(v32.trunc(), f32x4::from_array([0.0, 1.0, 2.0, -4.0]));

    // Double-precision lanes: same inputs, same expected results.
    let v64 = f64x4::from_array([0.9, 1.001, 2.0, -4.5]);
    assert_eq!(v64.ceil(), f64x4::from_array([1.0, 2.0, 2.0, -4.0]));
    assert_eq!(v64.floor(), f64x4::from_array([0.0, 1.0, 2.0, -5.0]));
    assert_eq!(v64.round(), f64x4::from_array([1.0, 1.0, 2.0, -5.0]));
    assert_eq!(v64.trunc(), f64x4::from_array([0.0, 1.0, 2.0, -4.0]));
}

fn simd_intrinsics() {
extern "platform-intrinsic" {
fn simd_eq<T, U>(x: T, y: T) -> U;
Expand Down Expand Up @@ -299,5 +355,6 @@ fn main() {
simd_cast();
simd_swizzle();
simd_gather_scatter();
simd_round();
simd_intrinsics();
}

0 comments on commit 49729b5

Please sign in to comment.