Skip to content

Commit

Permalink
Move implementation of SSE4.1 dpps/dppd to helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardosm committed Dec 8, 2023
1 parent b1fcba4 commit 44bf5fc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 37 deletions.
50 changes: 50 additions & 0 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,56 @@ fn horizontal_bin_op<'tcx>(
Ok(())
}

/// Conditionally multiplies the packed floating-point elements in
/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
/// products (up to 4), and conditionally stores the sum in `dest` using
/// the low 4 bits of `imm`.
fn conditional_dot_product<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
imm: &OpTy<'tcx, Provenance>,
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert!(dest_len <= 4);

let imm = this.read_scalar(imm)?.to_u8()?;

let element_layout = left.layout.field(this, 0);

// Calculate dot product
// Elements are floating point numbers, but we can use `from_int`
// because the representation of 0.0 is all zero bits.
let mut sum = ImmTy::from_int(0u8, element_layout);
for i in 0..left_len {
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;

let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
}
}

// Write to destination (conditioned to imm)
for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;

if imm & (1 << i) != 0 {
this.write_immediate(*sum, &dest)?;
} else {
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
}
}

Ok(())
}

/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
fn bin_op_folded<'tcx, T>(
this: &crate::MiriInterpCx<'_, 'tcx>,
Expand Down
39 changes: 2 additions & 37 deletions src/shims/x86/sse41.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use rustc_middle::mir;
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;

use super::{bin_op_folded, round_all, round_first};
use super::{bin_op_folded, conditional_dot_product, round_all, round_first};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -104,41 +103,7 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
let [left, right, imm] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert!(dest_len <= 4);

let imm = this.read_scalar(imm)?.to_u8()?;

let element_layout = left.layout.field(this, 0);

// Calculate dot product
// Elements are floating point numbers, but we can use `from_int`
// because the representation of 0.0 is all zero bits.
let mut sum = ImmTy::from_int(0u8, element_layout);
for i in 0..left_len {
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
let left = this.read_immediate(&this.project_index(&left, i)?)?;
let right = this.read_immediate(&this.project_index(&right, i)?)?;

let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
}
}

// Write to destination (conditioned to imm)
for i in 0..dest_len {
let dest = this.project_index(&dest, i)?;

if imm & (1 << i) != 0 {
this.write_immediate(*sum, &dest)?;
} else {
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
}
}
conditional_dot_product(this, left, right, imm, dest)?;
}
// Used to implement the _mm_floor_ss, _mm_ceil_ss and _mm_round_ss
// functions. Rounds the first element of `right` according to `rounding`
Expand Down

0 comments on commit 44bf5fc

Please sign in to comment.