Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move some x86 intrinsics code to helper functions in shims::x86 #3214

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 285 additions & 0 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use rand::Rng as _;

use rustc_apfloat::{ieee::Single, Float as _};
use rustc_middle::{mir, ty};
use rustc_span::Symbol;
use rustc_target::abi::Size;
Expand Down Expand Up @@ -331,6 +334,210 @@ fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
Ok(())
}

#[derive(Copy, Clone)]
enum FloatUnaryOp {
/// sqrt(x)
///
/// <https://www.felixcloutier.com/x86/sqrtss>
/// <https://www.felixcloutier.com/x86/sqrtps>
Sqrt,
/// Approximation of 1/x
///
/// <https://www.felixcloutier.com/x86/rcpss>
/// <https://www.felixcloutier.com/x86/rcpps>
Rcp,
/// Approximation of 1/sqrt(x)
///
/// <https://www.felixcloutier.com/x86/rsqrtss>
/// <https://www.felixcloutier.com/x86/rsqrtps>
Rsqrt,
}

/// Performs `which` scalar operation on `op` and returns the result.
#[allow(clippy::arithmetic_side_effects)] // floating point operations without side effects
fn unary_op_f32<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatUnaryOp,
op: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
match which {
FloatUnaryOp::Sqrt => {
let op = op.to_scalar();
// FIXME using host floats
Ok(Scalar::from_u32(f32::from_bits(op.to_u32()?).sqrt().to_bits()))
}
FloatUnaryOp::Rcp => {
let op = op.to_scalar().to_f32()?;
let div = (Single::from_u128(1).value / op).value;
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
// inaccuracy of RCP.
let res = apply_random_float_error(this, div, -12);
Ok(Scalar::from_f32(res))
}
FloatUnaryOp::Rsqrt => {
let op = op.to_scalar().to_u32()?;
// FIXME using host floats
let sqrt = Single::from_bits(f32::from_bits(op).sqrt().to_bits().into());
let rsqrt = (Single::from_u128(1).value / sqrt).value;
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
// inaccuracy of RSQRT.
let res = apply_random_float_error(this, rsqrt, -12);
Ok(Scalar::from_f32(res))
}
}
}

/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale).
#[allow(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic
fn apply_random_float_error<F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, '_>,
val: F,
err_scale: i32,
) -> F {
let rng = this.machine.rng.get_mut();
// generates rand(0, 2^64) * 2^(scale - 64) = rand(0, 1) * 2^scale
let err =
F::from_u128(rng.gen::<u64>().into()).value.scalbn(err_scale.checked_sub(64).unwrap());
// give it a random sign
let err = if rng.gen::<bool>() { -err } else { err };
// multiple the value with (1+err)
(val * (F::from_u128(1).value + err).value).value
}

/// Performs `which` operation on the first component of `op` and copies
/// the other components. The result is stored in `dest`.
fn unary_op_ss<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatUnaryOp,
op: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, op_len);

let res0 = unary_op_f32(this, which, &this.read_immediate(&this.project_index(&op, 0)?)?)?;
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;

for i in 1..dest_len {
this.copy_op(
&this.project_index(&op, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}

Ok(())
}

/// Performs `which` operation on each component of `op`, storing the
/// result is stored in `dest`.
fn unary_op_ps<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatUnaryOp,
op: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, op_len);

for i in 0..dest_len {
let op = this.read_immediate(&this.project_index(&op, i)?)?;
let dest = this.project_index(&dest, i)?;

let res = unary_op_f32(this, which, &op)?;
this.write_scalar(res, &dest)?;
}

Ok(())
}

// Rounds the first element of `right` according to `rounding`
// and copies the remaining elements from `left`.
fn round_first<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
let res = op0.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, 0)?,
)?;

for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}

Ok(())
}

// Rounds all elements of `op` according to `rounding`.
fn round_all<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
op: &OpTy<'tcx, Provenance>,
rounding: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, op_len);

let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;

for i in 0..dest_len {
let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
let res = op.round_to_integral(rounding).value;
this.write_scalar(
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
&this.project_index(&dest, i)?,
)?;
}

Ok(())
}

/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
/// `round.{ss,sd,ps,pd}` intrinsics.
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
// The fourth bit of `rounding` only affects the SSE status
// register, which cannot be accessed from Miri (or from Rust,
// for that matter), so we can ignore it.
match rounding & !0b1000 {
// When the third bit is 0, the rounding mode is determined by the
// first two bits.
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
0b011 => Ok(rustc_apfloat::Round::TowardZero),
// When the third bit is 1, the rounding mode is determined by the
// SSE status register. Since we do not support modifying it from
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
}
}

/// Converts each element of `op` from floating point to signed integer.
///
/// When the input value is NaN or out of range, fall back to minimum value.
Expand Down Expand Up @@ -408,3 +615,81 @@ fn horizontal_bin_op<'tcx>(

Ok(())
}

/// Conditionally multiplies the packed floating-point elements in
/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
/// products (up to 4), and conditionally stores the sum in `dest` using
/// the low 4 bits of `imm`.
fn conditional_dot_product<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
imm: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert!(dest_len <= 4);

let imm = this.read_scalar(imm)?.to_u8()?;

let element_layout = left.layout.field(this, 0);

// Calculate dot product
// Elements are floating point numbers, but we can use `from_int`
// because the representation of 0.0 is all zero bits.
let mut sum = ImmTy::from_int(0u8, element_layout);
for i in 0..left_len {
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;

let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
}
}

// Write to destination (conditioned to imm)
for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;

if imm & (1 << i) != 0 {
this.write_immediate(*sum, &dest)?;
} else {
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
}
}

Ok(())
}

/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
fn bin_op_folded<'tcx, T>(
this: &crate::MiriInterpCx<'_, 'tcx>,
lhs: &OpTy<'tcx, Provenance>,
rhs: &OpTy<'tcx, Provenance>,
init: T,
mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>,
) -> InterpResult<'tcx, T> {
assert_eq!(lhs.layout, rhs.layout);

let (lhs, lhs_len) = this.operand_to_simd(lhs)?;
let (rhs, rhs_len) = this.operand_to_simd(rhs)?;

assert_eq!(lhs_len, rhs_len);

let mut acc = init;
for i in 0..lhs_len {
let lhs = this.project_index(&lhs, i)?;
let rhs = this.project_index(&rhs, i)?;

let lhs = this.read_immediate(&lhs)?;
let rhs = this.read_immediate(&rhs)?;
acc = f(acc, lhs, rhs)?;
}

Ok(acc)
}
Loading
Loading