From d57125632e78474e2623697fd77b5762df5b5454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Fri, 8 Dec 2023 22:39:11 +0100 Subject: [PATCH] Fix x86 SSE4.1 ptestnzc `(op & mask) == 0` and `(op & mask) == mask` each need to be calculated for the whole vector. For example, given * `op = [0b100, 0b010]` * `mask = [0b100, 0b110]` The correct result would be: * `op & mask = [0b100, 0b010]` Comparisons are done on the vector as a whole: * `all_zero = (op & mask) == [0, 0] = false` * `masked_set = (op & mask) == mask = false` * `!all_zero && !masked_set = true` The previous method: `op & mask = [0b100, 0b010]` Comparisons are done element-wise: * `all_zero = (op & mask) == [0, 0] = [false, false]` * `masked_set = (op & mask) == mask = [true, false]` * `!all_zero && !masked_set = [false, true]` After folding with AND, the final result would be `false`, which is incorrect. --- src/tools/miri/src/shims/x86/mod.rs | 49 ++++++++++--------- src/tools/miri/src/shims/x86/sse41.rs | 23 ++++----- .../miri/tests/pass/intrinsics-x86-sse41.rs | 5 ++ 3 files changed, 41 insertions(+), 36 deletions(-) diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs index d8c3b4826a96a..1aaf820f460ec 100644 --- a/src/tools/miri/src/shims/x86/mod.rs +++ b/src/tools/miri/src/shims/x86/mod.rs @@ -666,30 +666,33 @@ fn conditional_dot_product<'tcx>( Ok(()) } -/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`. -fn bin_op_folded<'tcx, T>( +/// Calculates two booleans. +/// +/// The first is true when all the bits of `op & mask` are zero. 
+/// The second is true when `(op & mask) == mask` +fn test_bits_masked<'tcx>( this: &crate::MiriInterpCx<'_, 'tcx>, - lhs: &OpTy<'tcx, Provenance>, - rhs: &OpTy<'tcx, Provenance>, - init: T, - mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>, -) -> InterpResult<'tcx, T> { - assert_eq!(lhs.layout, rhs.layout); - - let (lhs, lhs_len) = this.operand_to_simd(lhs)?; - let (rhs, rhs_len) = this.operand_to_simd(rhs)?; - - assert_eq!(lhs_len, rhs_len); - - let mut acc = init; - for i in 0..lhs_len { - let lhs = this.project_index(&lhs, i)?; - let rhs = this.project_index(&rhs, i)?; - - let lhs = this.read_immediate(&lhs)?; - let rhs = this.read_immediate(&rhs)?; - acc = f(acc, lhs, rhs)?; + op: &OpTy<'tcx, Provenance>, + mask: &OpTy<'tcx, Provenance>, +) -> InterpResult<'tcx, (bool, bool)> { + assert_eq!(op.layout, mask.layout); + + let (op, op_len) = this.operand_to_simd(op)?; + let (mask, mask_len) = this.operand_to_simd(mask)?; + + assert_eq!(op_len, mask_len); + + let mut all_zero = true; + let mut masked_set = true; + for i in 0..op_len { + let op = this.project_index(&op, i)?; + let mask = this.project_index(&mask, i)?; + + let op = this.read_scalar(&op)?.to_uint(op.layout.size)?; + let mask = this.read_scalar(&mask)?.to_uint(mask.layout.size)?; + all_zero &= (op & mask) == 0; + masked_set &= (op & mask) == mask; } - Ok(acc) + Ok((all_zero, masked_set)) } diff --git a/src/tools/miri/src/shims/x86/sse41.rs b/src/tools/miri/src/shims/x86/sse41.rs index 08e3404a2242b..67bb63f0a3d93 100644 --- a/src/tools/miri/src/shims/x86/sse41.rs +++ b/src/tools/miri/src/shims/x86/sse41.rs @@ -1,7 +1,7 @@ use rustc_span::Symbol; use rustc_target::spec::abi::Abi; -use super::{bin_op_folded, conditional_dot_product, round_all, round_first}; +use super::{conditional_dot_product, round_all, round_first, test_bits_masked}; use crate::*; use shims::foreign_items::EmulateForeignItemResult; @@ -217,21 +217,18 @@ pub(super) trait 
EvalContextExt<'mir, 'tcx: 'mir>: } // Used to implement the _mm_testz_si128, _mm_testc_si128 // and _mm_testnzc_si128 functions. - // Tests `op & mask == 0`, `op & mask == mask` or - // `op & mask != 0 && op & mask != mask` + // Tests `(op & mask) == 0`, `(op & mask) == mask` or + // `(op & mask) != 0 && (op & mask) != mask` "ptestz" | "ptestc" | "ptestnzc" => { let [op, mask] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - let res = bin_op_folded(this, op, mask, true, |acc, op, mask| { - let op = op.to_scalar().to_uint(op.layout.size)?; - let mask = mask.to_scalar().to_uint(mask.layout.size)?; - Ok(match unprefixed_name { - "ptestz" => acc && (op & mask) == 0, - "ptestc" => acc && (op & mask) == mask, - "ptestnzc" => acc && (op & mask) != 0 && (op & mask) != mask, - _ => unreachable!(), - }) - })?; + let (all_zero, masked_set) = test_bits_masked(this, op, mask)?; + let res = match unprefixed_name { + "ptestz" => all_zero, + "ptestc" => masked_set, + "ptestnzc" => !all_zero && !masked_set, + _ => unreachable!(), + }; this.write_scalar(Scalar::from_i32(res.into()), dest)?; } diff --git a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs index 13856d29d3f39..06607f3fd59e1 100644 --- a/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs +++ b/src/tools/miri/tests/pass/intrinsics-x86-sse41.rs @@ -515,6 +515,11 @@ unsafe fn test_sse41() { let mask = _mm_set1_epi8(0b101); let r = _mm_testnzc_si128(a, mask); assert_eq!(r, 0); + + let a = _mm_setr_epi32(0b100, 0, 0, 0b010); + let mask = _mm_setr_epi32(0b100, 0, 0, 0b110); + let r = _mm_testnzc_si128(a, mask); + assert_eq!(r, 1); } test_mm_testnzc_si128(); }