Skip to content

Commit

Permalink
Fix x86 SSE4.1 ptestnzc
Browse files Browse the repository at this point in the history
`(op & mask) == 0` and `(op & mask) == mask` need each to be calculated for the whole vector.

For example, given
* `op = [0b100, 0b010]`
* `mask = [0b100, 0b110]`

The correct result would be:
* `op & mask = [0b100, 0b010]`
Comparisons are done on the vector as a whole:
* `all_zero = (op & mask) == [0, 0] = false`
* `masked_set = (op & mask) == mask = false`
* `!all_zero && !masked_set = true`

The previous method:
`op & mask = [0b100, 0b010]`
Comparisons are done element-wise:
* `all_zero = (op & mask) == [0, 0] = [true, true]`
* `masked_set = (op & mask) == mask = [true, false]`
* `!all_zero && !masked_set = [true, false]`
After folding with AND, the final result would be `false`, which is incorrect.
  • Loading branch information
eduardosm committed Dec 8, 2023
1 parent a5b9f54 commit 092eb11
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 36 deletions.
49 changes: 26 additions & 23 deletions src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,30 +666,33 @@ fn conditional_dot_product<'tcx>(
Ok(())
}

/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
fn bin_op_folded<'tcx, T>(
/// Calculates two booleans.
///
/// The first is true when all the bits of `op & mask` are zero.
/// The second is true when `(op & mask) == mask`
fn test_bits_masked<'tcx>(
this: &crate::MiriInterpCx<'_, 'tcx>,
lhs: &OpTy<'tcx, Provenance>,
rhs: &OpTy<'tcx, Provenance>,
init: T,
mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>,
) -> InterpResult<'tcx, T> {
assert_eq!(lhs.layout, rhs.layout);

let (lhs, lhs_len) = this.operand_to_simd(lhs)?;
let (rhs, rhs_len) = this.operand_to_simd(rhs)?;

assert_eq!(lhs_len, rhs_len);

let mut acc = init;
for i in 0..lhs_len {
let lhs = this.project_index(&lhs, i)?;
let rhs = this.project_index(&rhs, i)?;

let lhs = this.read_immediate(&lhs)?;
let rhs = this.read_immediate(&rhs)?;
acc = f(acc, lhs, rhs)?;
op: &OpTy<'tcx, Provenance>,
mask: &OpTy<'tcx, Provenance>,
) -> InterpResult<'tcx, (bool, bool)> {
assert_eq!(op.layout, mask.layout);

let (op, op_len) = this.operand_to_simd(op)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;

assert_eq!(op_len, mask_len);

let mut all_zero = true;
let mut masked_set = true;
for i in 0..op_len {
let op = this.project_index(&op, i)?;
let mask = this.project_index(&mask, i)?;

let op = this.read_scalar(&op)?.to_uint(op.layout.size)?;
let mask = this.read_scalar(&mask)?.to_uint(mask.layout.size)?;
all_zero &= (op & mask) == 0;
masked_set &= (op & mask) == mask;
}

Ok(acc)
Ok((all_zero, masked_set))
}
23 changes: 10 additions & 13 deletions src/shims/x86/sse41.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;

use super::{bin_op_folded, conditional_dot_product, round_all, round_first};
use super::{conditional_dot_product, round_all, round_first, test_bits_masked};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -217,21 +217,18 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
}
// Used to implement the _mm_testz_si128, _mm_testc_si128
// and _mm_testnzc_si128 functions.
// Tests `op & mask == 0`, `op & mask == mask` or
// `op & mask != 0 && op & mask != mask`
// Tests `(op & mask) == 0`, `(op & mask) == mask` or
// `(op & mask) != 0 && (op & mask) != mask`
"ptestz" | "ptestc" | "ptestnzc" => {
let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let res = bin_op_folded(this, op, mask, true, |acc, op, mask| {
let op = op.to_scalar().to_uint(op.layout.size)?;
let mask = mask.to_scalar().to_uint(mask.layout.size)?;
Ok(match unprefixed_name {
"ptestz" => acc && (op & mask) == 0,
"ptestc" => acc && (op & mask) == mask,
"ptestnzc" => acc && (op & mask) != 0 && (op & mask) != mask,
_ => unreachable!(),
})
})?;
let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
let res = match unprefixed_name {
"ptestz" => all_zero,
"ptestc" => masked_set,
"ptestnzc" => !all_zero && !masked_set,
_ => unreachable!(),
};

this.write_scalar(Scalar::from_i32(res.into()), dest)?;
}
Expand Down
5 changes: 5 additions & 0 deletions tests/pass/intrinsics-x86-sse41.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,11 @@ unsafe fn test_sse41() {
let mask = _mm_set1_epi8(0b101);
let r = _mm_testnzc_si128(a, mask);
assert_eq!(r, 0);

let a = _mm_setr_epi32(0b100, 0, 0, 0b010);
let mask = _mm_setr_epi32(0b100, 0, 0, 0b110);
let r = _mm_testnzc_si128(a, mask);
assert_eq!(r, 1);
}
test_mm_testnzc_si128();
}
Expand Down

0 comments on commit 092eb11

Please sign in to comment.