Skip to content

Commit

Permalink
Auto merge of #115827 - eduardosm:miri-sse-reduce-code-dup, r=RalfJung
Browse files Browse the repository at this point in the history
miri: reduce code duplication in some SSE/SSE2 intrinsics

Reduces code duplication in the Miri implementation of some SSE and SSE2 using generics and rustc_const_eval helper functions.

There are also some other minor changes.

r? `@RalfJung`
  • Loading branch information
bors committed Sep 20, 2023
2 parents d255bf0 + 3bb6853 commit 48eadf7
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 521 deletions.
37 changes: 36 additions & 1 deletion src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use rustc_middle::mir;
use rustc_middle::ty::{
self,
layout::{IntegerExt as _, LayoutOf, TyAndLayout},
Ty, TyCtxt,
IntTy, Ty, TyCtxt, UintTy,
};
use rustc_span::{def_id::CrateNum, sym, Span, Symbol};
use rustc_target::abi::{Align, FieldIdx, FieldsShape, Integer, Size, Variants};
Expand Down Expand Up @@ -1066,6 +1066,24 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
),
}
}

/// Returns an integer type that is twice wide as `ty`
fn get_twice_wide_int_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
let this = self.eval_context_ref();
match ty.kind() {
// Unsigned
ty::Uint(UintTy::U8) => this.tcx.types.u16,
ty::Uint(UintTy::U16) => this.tcx.types.u32,
ty::Uint(UintTy::U32) => this.tcx.types.u64,
ty::Uint(UintTy::U64) => this.tcx.types.u128,
// Signed
ty::Int(IntTy::I8) => this.tcx.types.i16,
ty::Int(IntTy::I16) => this.tcx.types.i32,
ty::Int(IntTy::I32) => this.tcx.types.i64,
ty::Int(IntTy::I64) => this.tcx.types.i128,
_ => span_bug!(this.cur_span(), "unexpected type: {ty:?}"),
}
}
}

impl<'mir, 'tcx> MiriMachine<'mir, 'tcx> {
Expand Down Expand Up @@ -1151,3 +1169,20 @@ pub fn get_local_crates(tcx: TyCtxt<'_>) -> Vec<CrateNum> {
pub fn target_os_is_unix(target_os: &str) -> bool {
matches!(target_os, "linux" | "macos" | "freebsd" | "android")
}

pub(crate) fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
// SIMD uses all-1 as pattern for "true". In two's complement,
// -1 has all its bits set to one and `from_int` will truncate or
// sign-extend it to `size` as required.
let val = if b { -1 } else { 0 };
Scalar::from_int(val, size)
}

pub(crate) fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
let val = elem.to_scalar().to_int(elem.layout.size)?;
Ok(match val {
0 => false,
-1 => true,
_ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
})
}
19 changes: 2 additions & 17 deletions src/shims/intrinsics/simd.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use rustc_apfloat::{Float, Round};
use rustc_middle::ty::layout::{HasParamEnv, LayoutOf};
use rustc_middle::{mir, ty, ty::FloatTy};
use rustc_target::abi::{Endian, HasDataLayout, Size};
use rustc_target::abi::{Endian, HasDataLayout};

use crate::*;
use helpers::check_arg_count;
use helpers::{bool_to_simd_element, check_arg_count, simd_element_to_bool};

impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {}
pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
Expand Down Expand Up @@ -612,21 +612,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
}

fn bool_to_simd_element(b: bool, size: Size) -> Scalar<Provenance> {
// SIMD uses all-1 as pattern for "true"
let val = if b { -1 } else { 0 };
Scalar::from_int(val, size)
}

fn simd_element_to_bool(elem: ImmTy<'_, Provenance>) -> InterpResult<'_, bool> {
let val = elem.to_scalar().to_int(elem.layout.size)?;
Ok(match val {
0 => false,
-1 => true,
_ => throw_ub_format!("each element of a SIMD mask must be all-0-bits or all-1-bits"),
})
}

fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
assert!(idx < vec_len);
match endianness {
Expand Down
158 changes: 157 additions & 1 deletion src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::InterpResult;
use rustc_middle::mir;
use rustc_target::abi::Size;

use crate::*;
use helpers::bool_to_simd_element;

pub(super) mod sse;
pub(super) mod sse2;
Expand Down Expand Up @@ -43,3 +47,155 @@ impl FloatCmpOp {
}
}
}

#[derive(Copy, Clone)]
enum FloatBinOp {
/// Arithmetic operation
Arith(mir::BinOp),
/// Comparison
Cmp(FloatCmpOp),
/// Minimum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/minss>
/// <https://www.felixcloutier.com/x86/minps>
/// <https://www.felixcloutier.com/x86/minsd>
/// <https://www.felixcloutier.com/x86/minpd>
Min,
/// Maximum value (with SSE semantics)
///
/// <https://www.felixcloutier.com/x86/maxss>
/// <https://www.felixcloutier.com/x86/maxps>
/// <https://www.felixcloutier.com/x86/maxsd>
/// <https://www.felixcloutier.com/x86/maxpd>
Max,
}

/// Performs `which` scalar operation on `left` and `right` and returns
/// the result.
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
this: &crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &ImmTy<'tcx, Provenance>,
right: &ImmTy<'tcx, Provenance>,
) -> InterpResult<'tcx, Scalar<Provenance>> {
match which {
FloatBinOp::Arith(which) => {
let (res, _overflow, _ty) = this.overflowing_binary_op(which, left, right)?;
Ok(res)
}
FloatBinOp::Cmp(which) => {
let left = left.to_scalar().to_float::<F>()?;
let right = right.to_scalar().to_float::<F>()?;
// FIXME: Make sure that these operations match the semantics
// of cmpps/cmpss/cmppd/cmpsd
let res = match which {
FloatCmpOp::Eq => left == right,
FloatCmpOp::Lt => left < right,
FloatCmpOp::Le => left <= right,
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
FloatCmpOp::Neq => left != right,
FloatCmpOp::Nlt => !(left < right),
FloatCmpOp::Nle => !(left <= right),
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
};
Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
}
FloatBinOp::Min => {
let left_scalar = left.to_scalar();
let left = left_scalar.to_float::<F>()?;
let right_scalar = right.to_scalar();
let right = right_scalar.to_float::<F>()?;
// SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
// is true when `x` is either +0 or -0.
if (left == F::ZERO && right == F::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left >= right
{
Ok(right_scalar)
} else {
Ok(left_scalar)
}
}
FloatBinOp::Max => {
let left_scalar = left.to_scalar();
let left = left_scalar.to_float::<F>()?;
let right_scalar = right.to_scalar();
let right = right_scalar.to_float::<F>()?;
// SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
// is true when `x` is either +0 or -0.
if (left == F::ZERO && right == F::ZERO)
|| left.is_nan()
|| right.is_nan()
|| left <= right
{
Ok(right_scalar)
} else {
Ok(left_scalar)
}
}
}
}

/// Performs `which` operation on the first component of `left` and `right`
/// and copies the other components from `left`. The result is stored in `dest`.
fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

let res0 = bin_op_float::<F>(
this,
which,
&this.read_immediate(&this.project_index(&left, 0)?)?,
&this.read_immediate(&this.project_index(&right, 0)?)?,
)?;
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;

for i in 1..dest_len {
this.copy_op(
&this.project_index(&left, i)?,
&this.project_index(&dest, i)?,
/*allow_transmute*/ false,
)?;
}

Ok(())
}

/// Performs `which` operation on each component of `left` and
/// `right`, storing the result is stored in `dest`.
fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
which: FloatBinOp,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

for i in 0..dest_len {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;
let dest = this.project_index(&dest, i)?;

let res = bin_op_float::<F>(this, which, &left, &right)?;
this.write_scalar(res, &dest)?;
}

Ok(())
}
Loading

0 comments on commit 48eadf7

Please sign in to comment.