From 6796c5765d7b52271cd09ab977cb617ee3971a3a Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Sun, 8 Oct 2023 12:03:01 +0200 Subject: [PATCH 1/3] miri: make NaN generation non-deterministic --- .../src/interpret/intrinsics.rs | 6 + .../rustc_const_eval/src/interpret/machine.rs | 8 + .../src/interpret/operator.rs | 15 +- src/tools/miri/src/machine.rs | 5 + src/tools/miri/src/operator.rs | 60 ++-- src/tools/miri/tests/pass/float_nan.rs | 316 ++++++++++++++++++ 6 files changed, 385 insertions(+), 25 deletions(-) create mode 100644 src/tools/miri/tests/pass/float_nan.rs diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 2c6a4de456d06..1891d286a3c1f 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -500,6 +500,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { b: &ImmTy<'tcx, M::Provenance>, dest: &PlaceTy<'tcx, M::Provenance>, ) -> InterpResult<'tcx> { + assert_eq!(a.layout.ty, b.layout.ty); + assert!(matches!(a.layout.ty.kind(), ty::Int(..) | ty::Uint(..))); + // Performs an exact division, resulting in undefined behavior where // `x % y != 0` or `y == 0` or `x == T::MIN && y == -1`. // First, check x % y != 0 (or if that computation overflows). @@ -522,7 +525,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { l: &ImmTy<'tcx, M::Provenance>, r: &ImmTy<'tcx, M::Provenance>, ) -> InterpResult<'tcx, Scalar> { + assert_eq!(l.layout.ty, r.layout.ty); + assert!(matches!(l.layout.ty.kind(), ty::Int(..) | ty::Uint(..))); assert!(matches!(mir_op, BinOp::Add | BinOp::Sub)); + let (val, overflowed) = self.overflowing_binary_op(mir_op, l, r)?; Ok(if overflowed { let size = l.layout.size; diff --git a/compiler/rustc_const_eval/src/interpret/machine.rs b/compiler/rustc_const_eval/src/interpret/machine.rs index aaa674a598f84..b172fd9f51774 100644 --- a/compiler/rustc_const_eval/src/interpret/machine.rs +++ b/compiler/rustc_const_eval/src/interpret/machine.rs @@ -6,6 +6,7 @@ use std::borrow::{Borrow, Cow}; use std::fmt::Debug; use std::hash::Hash; +use rustc_apfloat::Float; use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece}; use rustc_middle::mir; use rustc_middle::ty::layout::TyAndLayout; @@ -240,6 +241,13 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized { right: &ImmTy<'tcx, Self::Provenance>, ) -> InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)>; + /// Generate the NaN returned by a float operation, given the list of inputs. + /// (This is all inputs, not just NaN inputs!) + fn generate_nan(_ecx: &InterpCx<'mir, 'tcx, Self>, _inputs: &[F]) -> F { + // By default we always return the preferred NaN. + F::NAN + } + /// Called before writing the specified `local` of the `frame`. /// Since writing a ZST is not actually accessing memory or locals, this is never invoked /// for ZST reads. diff --git a/compiler/rustc_const_eval/src/interpret/operator.rs b/compiler/rustc_const_eval/src/interpret/operator.rs index b084864f3a730..fe8572d9c6fe6 100644 --- a/compiler/rustc_const_eval/src/interpret/operator.rs +++ b/compiler/rustc_const_eval/src/interpret/operator.rs @@ -113,6 +113,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { ) -> (ImmTy<'tcx, M::Provenance>, bool) { use rustc_middle::mir::BinOp::*; + // Performs appropriate non-deterministic adjustments of NaN results. + let adjust_nan = |f: F| -> F { + if f.is_nan() { M::generate_nan(self, &[l, r]) } else { f } + }; + let val = match bin_op { Eq => ImmTy::from_bool(l == r, *self.tcx), Ne => ImmTy::from_bool(l != r, *self.tcx), @@ -120,11 +125,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { Le => ImmTy::from_bool(l <= r, *self.tcx), Gt => ImmTy::from_bool(l > r, *self.tcx), Ge => ImmTy::from_bool(l >= r, *self.tcx), - Add => ImmTy::from_scalar((l + r).value.into(), layout), - Sub => ImmTy::from_scalar((l - r).value.into(), layout), - Mul => ImmTy::from_scalar((l * r).value.into(), layout), - Div => ImmTy::from_scalar((l / r).value.into(), layout), - Rem => ImmTy::from_scalar((l % r).value.into(), layout), + Add => ImmTy::from_scalar(adjust_nan((l + r).value).into(), layout), + Sub => ImmTy::from_scalar(adjust_nan((l - r).value).into(), layout), + Mul => ImmTy::from_scalar(adjust_nan((l * r).value).into(), layout), + Div => ImmTy::from_scalar(adjust_nan((l / r).value).into(), layout), + Rem => ImmTy::from_scalar(adjust_nan((l % r).value).into(), layout), _ => span_bug!(self.cur_span(), "invalid float op: `{:?}`", bin_op), }; (val, false) diff --git a/src/tools/miri/src/machine.rs b/src/tools/miri/src/machine.rs index 930fa053d2091..d7177a4a1d202 100644 --- a/src/tools/miri/src/machine.rs +++ b/src/tools/miri/src/machine.rs @@ -1001,6 +1001,11 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> { ecx.binary_ptr_op(bin_op, left, right) } + #[inline(always)] + fn generate_nan(ecx: &InterpCx<'mir, 'tcx, Self>, inputs: &[F]) -> F { + ecx.generate_nan(inputs) + } + fn thread_local_static_base_pointer( ecx: &mut MiriInterpCx<'mir, 'tcx>, def_id: DefId, diff --git a/src/tools/miri/src/operator.rs b/src/tools/miri/src/operator.rs index 1faf8f9fc1228..d655d1b0465e1 100644 --- a/src/tools/miri/src/operator.rs +++ b/src/tools/miri/src/operator.rs @@ -1,20 +1,16 @@ +use std::iter; + use log::trace; +use rand::{seq::IteratorRandom, Rng}; +use rustc_apfloat::Float; use rustc_middle::mir; use rustc_target::abi::Size; use crate::*; -pub trait EvalContextExt<'tcx> { - fn binary_ptr_op( - &self, - bin_op: mir::BinOp, - left: &ImmTy<'tcx, Provenance>, - right: &ImmTy<'tcx, Provenance>, - ) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)>; -} - -impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> { +impl<'mir, 'tcx: 'mir> EvalContextExt<'mir, 'tcx> for crate::MiriInterpCx<'mir, 'tcx> {} +pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { fn binary_ptr_op( &self, bin_op: mir::BinOp, @@ -23,12 +19,13 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> { ) -> InterpResult<'tcx, (ImmTy<'tcx, Provenance>, bool)> { use rustc_middle::mir::BinOp::*; + let this = self.eval_context_ref(); trace!("ptr_op: {:?} {:?} {:?}", *left, bin_op, *right); Ok(match bin_op { Eq | Ne | Lt | Le | Gt | Ge => { assert_eq!(left.layout.abi, right.layout.abi); // types an differ, e.g. fn ptrs with different `for` - let size = self.pointer_size(); + let size = this.pointer_size(); // Just compare the bits. ScalarPairs are compared lexicographically. // We thus always compare pairs and simply fill scalars up with 0. let left = match **left { @@ -50,7 +47,7 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> { Ge => left >= right, _ => bug!(), }; - (ImmTy::from_bool(res, *self.tcx), false) + (ImmTy::from_bool(res, *this.tcx), false) } // Some more operations are possible with atomics. @@ -58,26 +55,49 @@ impl<'mir, 'tcx> EvalContextExt<'tcx> for super::MiriInterpCx<'mir, 'tcx> { Add | Sub | BitOr | BitAnd | BitXor => { assert!(left.layout.ty.is_unsafe_ptr()); assert!(right.layout.ty.is_unsafe_ptr()); - let ptr = left.to_scalar().to_pointer(self)?; + let ptr = left.to_scalar().to_pointer(this)?; // We do the actual operation with usize-typed scalars. - let left = ImmTy::from_uint(ptr.addr().bytes(), self.machine.layouts.usize); + let left = ImmTy::from_uint(ptr.addr().bytes(), this.machine.layouts.usize); let right = ImmTy::from_uint( - right.to_scalar().to_target_usize(self)?, - self.machine.layouts.usize, + right.to_scalar().to_target_usize(this)?, + this.machine.layouts.usize, ); - let (result, overflowing) = self.overflowing_binary_op(bin_op, &left, &right)?; + let (result, overflowing) = this.overflowing_binary_op(bin_op, &left, &right)?; // Construct a new pointer with the provenance of `ptr` (the LHS). let result_ptr = Pointer::new( ptr.provenance, - Size::from_bytes(result.to_scalar().to_target_usize(self)?), + Size::from_bytes(result.to_scalar().to_target_usize(this)?), ); ( - ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, self), left.layout), + ImmTy::from_scalar(Scalar::from_maybe_pointer(result_ptr, this), left.layout), overflowing, ) } - _ => span_bug!(self.cur_span(), "Invalid operator on pointers: {:?}", bin_op), + _ => span_bug!(this.cur_span(), "Invalid operator on pointers: {:?}", bin_op), }) } + + fn generate_nan(&self, inputs: &[F]) -> F { + let this = self.eval_context_ref(); + let mut rand = this.machine.rng.borrow_mut(); + // Assemble an iterator of possible NaNs: preferred, unchanged propagation, quieting propagation. + let preferred_nan = F::qnan(Some(0)); + let nans = iter::once(preferred_nan) + .chain(inputs.iter().filter(|f| f.is_nan()).copied()) + .chain(inputs.iter().filter(|f| f.is_signaling()).map(|f| { + // Make it quiet, by setting the bit. We assume that `preferred_nan` + // only has bits set that all quiet NaNs need to have set. + F::from_bits(f.to_bits() | preferred_nan.to_bits()) + })); + // Pick one of the NaNs. + let nan = nans.choose(&mut *rand).unwrap(); + // Non-deterministically flip the sign. + if rand.gen() { + // This will properly flip even for NaN. + -nan + } else { + nan + } + } } diff --git a/src/tools/miri/tests/pass/float_nan.rs b/src/tools/miri/tests/pass/float_nan.rs new file mode 100644 index 0000000000000..8fa567aa1061b --- /dev/null +++ b/src/tools/miri/tests/pass/float_nan.rs @@ -0,0 +1,316 @@ +use std::collections::HashSet; +use std::fmt; +use std::hash::Hash; +use std::hint::black_box; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Sign { + Neg = 1, + Pos = 0, +} +use Sign::*; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum NaNKind { + Quiet = 1, + Signaling = 0, +} +use NaNKind::*; + +#[track_caller] +fn check_all_outcomes(expected: HashSet, generate: impl Fn() -> T) { + let mut seen = HashSet::new(); + // Let's give it 8x as many tries as we are expecting values. + let tries = expected.len() * 8; + for _ in 0..tries { + let val = generate(); + assert!(expected.contains(&val), "got an unexpected value: {val}"); + seen.insert(val); + } + // Let's see if we saw them all. + for val in expected { + if !seen.contains(&val) { + panic!("did not get value that should be possible: {val}"); + } + } +} + +// -- f32 support +#[repr(C)] +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +struct F32(u32); + +impl From for F32 { + fn from(x: f32) -> Self { + F32(x.to_bits()) + } +} + +/// Returns a value that is `ones` many 1-bits. +fn u32_ones(ones: u32) -> u32 { + assert!(ones <= 32); + if ones == 0 { + // `>>` by 32 doesn't actually shift. So inconsistent :( + return 0; + } + u32::MAX >> (32 - ones) +} + +const F32_SIGN_BIT: u32 = 32 - 1; // position of the sign bit +const F32_EXP: u32 = 8; // 8 bits of exponent +const F32_MANTISSA: u32 = F32_SIGN_BIT - F32_EXP; +const F32_NAN_PAYLOAD: u32 = F32_MANTISSA - 1; + +impl fmt::Display for F32 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Alaways show raw bits. + write!(f, "0x{:08x} ", self.0)?; + // Also show nice version. + let val = self.0; + let sign = val >> F32_SIGN_BIT; + let val = val & u32_ones(F32_SIGN_BIT); // mask away sign + let exp = val >> F32_MANTISSA; + let mantissa = val & u32_ones(F32_MANTISSA); + if exp == u32_ones(F32_EXP) { + // A NaN! Special printing. + let sign = if sign != 0 { Neg } else { Pos }; + let quiet = if (mantissa >> F32_NAN_PAYLOAD) != 0 { Quiet } else { Signaling }; + let payload = mantissa & u32_ones(F32_NAN_PAYLOAD); + write!(f, "(NaN: {:?}, {:?}, payload = {:#x})", sign, quiet, payload) + } else { + // Normal float value. + write!(f, "({})", f32::from_bits(self.0)) + } + } +} + +impl F32 { + fn nan(sign: Sign, kind: NaNKind, payload: u32) -> Self { + // Either the quiet bit must be set of the payload must be non-0; + // otherwise this is not a NaN but an infinity. + assert!(kind == Quiet || payload != 0); + // Payload must fit in 22 bits. + assert!(payload < (1 << F32_NAN_PAYLOAD)); + // Concatenate the bits (with a 22bit payload). + // Pattern: [negative] ++ [1]^8 ++ [quiet] ++ [payload] + let val = ((sign as u32) << F32_SIGN_BIT) + | (u32_ones(F32_EXP) << F32_MANTISSA) + | ((kind as u32) << F32_NAN_PAYLOAD) + | payload; + // Sanity check. + assert!(f32::from_bits(val).is_nan()); + // Done! + F32(val) + } + + fn as_f32(self) -> f32 { + black_box(f32::from_bits(self.0)) + } +} + +// -- f64 support +#[repr(C)] +#[derive(Copy, Clone, Eq, PartialEq, Hash)] +struct F64(u64); + +impl From for F64 { + fn from(x: f64) -> Self { + F64(x.to_bits()) + } +} + +/// Returns a value that is `ones` many 1-bits. +fn u64_ones(ones: u32) -> u64 { + assert!(ones <= 64); + if ones == 0 { + // `>>` by 32 doesn't actually shift. So inconsistent :( + return 0; + } + u64::MAX >> (64 - ones) +} + +const F64_SIGN_BIT: u32 = 64 - 1; // position of the sign bit +const F64_EXP: u32 = 11; // 11 bits of exponent +const F64_MANTISSA: u32 = F64_SIGN_BIT - F64_EXP; +const F64_NAN_PAYLOAD: u32 = F64_MANTISSA - 1; + +impl fmt::Display for F64 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Alaways show raw bits. + write!(f, "0x{:08x} ", self.0)?; + // Also show nice version. + let val = self.0; + let sign = val >> F64_SIGN_BIT; + let val = val & u64_ones(F64_SIGN_BIT); // mask away sign + let exp = val >> F64_MANTISSA; + let mantissa = val & u64_ones(F64_MANTISSA); + if exp == u64_ones(F64_EXP) { + // A NaN! Special printing. + let sign = if sign != 0 { Neg } else { Pos }; + let quiet = if (mantissa >> F64_NAN_PAYLOAD) != 0 { Quiet } else { Signaling }; + let payload = mantissa & u64_ones(F64_NAN_PAYLOAD); + write!(f, "(NaN: {:?}, {:?}, payload = {:#x})", sign, quiet, payload) + } else { + // Normal float value. + write!(f, "({})", f64::from_bits(self.0)) + } + } +} + +impl F64 { + fn nan(sign: Sign, kind: NaNKind, payload: u64) -> Self { + // Either the quiet bit must be set of the payload must be non-0; + // otherwise this is not a NaN but an infinity. + assert!(kind == Quiet || payload != 0); + // Payload must fit in 52 bits. + assert!(payload < (1 << F64_NAN_PAYLOAD)); + // Concatenate the bits (with a 52bit payload). + // Pattern: [negative] ++ [1]^11 ++ [quiet] ++ [payload] + let val = ((sign as u64) << F64_SIGN_BIT) + | (u64_ones(F64_EXP) << F64_MANTISSA) + | ((kind as u64) << F64_NAN_PAYLOAD) + | payload; + // Sanity check. + assert!(f64::from_bits(val).is_nan()); + // Done! + F64(val) + } + + fn as_f64(self) -> f64 { + black_box(f64::from_bits(self.0)) + } +} + +// -- actual tests + +fn test_f32() { + // Freshly generated NaNs can have either sign. + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(0.0 / black_box(0.0)), + ); + // When there are NaN inputs, their payload can be propagated, with any sign. + let all1_payload = u32_ones(22); + let all1 = F32::nan(Pos, Quiet, all1_payload).as_f32(); + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, all1_payload), + F32::nan(Neg, Quiet, all1_payload), + ]), + || F32::from(0.0 + all1), + ); + // When there are two NaN inputs, the output can be either one, or the preferred NaN. + let just1 = F32::nan(Neg, Quiet, 1).as_f32(); + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, 1), + F32::nan(Neg, Quiet, 1), + F32::nan(Pos, Quiet, all1_payload), + F32::nan(Neg, Quiet, all1_payload), + ]), + || F32::from(just1 - all1), + ); + // When there are *signaling* NaN inputs, they might be quieted or not. + let all1_snan = F32::nan(Pos, Signaling, all1_payload).as_f32(); + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, all1_payload), + F32::nan(Neg, Quiet, all1_payload), + F32::nan(Pos, Signaling, all1_payload), + F32::nan(Neg, Signaling, all1_payload), + ]), + || F32::from(0.0 * all1_snan), + ); + // Mix signaling and non-signaling NaN. + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, 1), + F32::nan(Neg, Quiet, 1), + F32::nan(Pos, Quiet, all1_payload), + F32::nan(Neg, Quiet, all1_payload), + F32::nan(Pos, Signaling, all1_payload), + F32::nan(Neg, Signaling, all1_payload), + ]), + || F32::from(just1 % all1_snan), + ); +} + +fn test_f64() { + // Freshly generated NaNs can have either sign. + check_all_outcomes( + HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]), + || F64::from(0.0 / black_box(0.0)), + ); + // When there are NaN inputs, their payload can be propagated, with any sign. + let all1_payload = u64_ones(51); + let all1 = F64::nan(Pos, Quiet, all1_payload).as_f64(); + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, all1_payload), + F64::nan(Neg, Quiet, all1_payload), + ]), + || F64::from(0.0 + all1), + ); + // When there are two NaN inputs, the output can be either one, or the preferred NaN. + let just1 = F64::nan(Neg, Quiet, 1).as_f64(); + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, 1), + F64::nan(Neg, Quiet, 1), + F64::nan(Pos, Quiet, all1_payload), + F64::nan(Neg, Quiet, all1_payload), + ]), + || F64::from(just1 - all1), + ); + // When there are *signaling* NaN inputs, they might be quieted or not. + let all1_snan = F64::nan(Pos, Signaling, all1_payload).as_f64(); + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, all1_payload), + F64::nan(Neg, Quiet, all1_payload), + F64::nan(Pos, Signaling, all1_payload), + F64::nan(Neg, Signaling, all1_payload), + ]), + || F64::from(0.0 * all1_snan), + ); + // Mix signaling and non-signaling NaN. + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, 1), + F64::nan(Neg, Quiet, 1), + F64::nan(Pos, Quiet, all1_payload), + F64::nan(Neg, Quiet, all1_payload), + F64::nan(Pos, Signaling, all1_payload), + F64::nan(Neg, Signaling, all1_payload), + ]), + || F64::from(just1 % all1_snan), + ); +} + +fn main() { + // Check our constants against std, just to be sure. + // We add 1 since our numbers are the number of bits stored + // to represent the value, and std has the precision of the value, + // which is one more due to the implicit leading 1. + assert_eq!(F32_MANTISSA + 1, f32::MANTISSA_DIGITS); + assert_eq!(F64_MANTISSA + 1, f64::MANTISSA_DIGITS); + + test_f32(); + test_f64(); +} From 615d738abea23715f649e51cf27112451a0f607b Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Sun, 8 Oct 2023 20:36:25 +0200 Subject: [PATCH 2/3] ensure unary minus propagates NaN payloads exactly --- compiler/rustc_const_eval/src/interpret/operator.rs | 1 + src/tools/miri/tests/pass/float_nan.rs | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/compiler/rustc_const_eval/src/interpret/operator.rs b/compiler/rustc_const_eval/src/interpret/operator.rs index fe8572d9c6fe6..a685b499d0e9b 100644 --- a/compiler/rustc_const_eval/src/interpret/operator.rs +++ b/compiler/rustc_const_eval/src/interpret/operator.rs @@ -461,6 +461,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { Ok((ImmTy::from_bool(res, *self.tcx), false)) } ty::Float(fty) => { + // No NaN adjustment here, `-` is a bitwise operation! let res = match (un_op, fty) { (Neg, FloatTy::F32) => Scalar::from_f32(-val.to_f32()?), (Neg, FloatTy::F64) => Scalar::from_f64(-val.to_f64()?), diff --git a/src/tools/miri/tests/pass/float_nan.rs b/src/tools/miri/tests/pass/float_nan.rs index 8fa567aa1061b..cb9cc42c60184 100644 --- a/src/tools/miri/tests/pass/float_nan.rs +++ b/src/tools/miri/tests/pass/float_nan.rs @@ -241,6 +241,14 @@ fn test_f32() { ]), || F32::from(just1 % all1_snan), ); + + // Unary `-` must preserve payloads exactly. + check_all_outcomes(HashSet::from_iter([F32::nan(Neg, Quiet, all1_payload)]), || { + F32::from(-all1) + }); + check_all_outcomes(HashSet::from_iter([F32::nan(Neg, Signaling, all1_payload)]), || { + F32::from(-all1_snan) + }); } fn test_f64() { From 08deb0daed9f4517794e861e1fd3b9621668d560 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Mon, 9 Oct 2023 07:38:00 +0200 Subject: [PATCH 3/3] float-to-float casts also have non-deterministic NaN results --- .../rustc_const_eval/src/interpret/cast.rs | 23 ++++- .../rustc_const_eval/src/interpret/machine.rs | 9 +- .../src/interpret/operator.rs | 4 +- src/tools/miri/src/machine.rs | 5 +- src/tools/miri/src/operator.rs | 36 ++++++-- src/tools/miri/tests/pass/float_nan.rs | 90 +++++++++++++++++++ 6 files changed, 150 insertions(+), 17 deletions(-) diff --git a/compiler/rustc_const_eval/src/interpret/cast.rs b/compiler/rustc_const_eval/src/interpret/cast.rs index b9f88cf635271..b9557eaf6abbf 100644 --- a/compiler/rustc_const_eval/src/interpret/cast.rs +++ b/compiler/rustc_const_eval/src/interpret/cast.rs @@ -311,6 +311,21 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { F: Float + Into> + FloatConvert + FloatConvert, { use rustc_type_ir::sty::TyKind::*; + + fn adjust_nan< + 'mir, + 'tcx: 'mir, + M: Machine<'mir, 'tcx>, + F1: rustc_apfloat::Float + FloatConvert, + F2: rustc_apfloat::Float, + >( + ecx: &InterpCx<'mir, 'tcx, M>, + f1: F1, + f2: F2, + ) -> F2 { + if f2.is_nan() { M::generate_nan(ecx, &[f1]) } else { f2 } + } + match *dest_ty.kind() { // float -> uint Uint(t) => { @@ -330,9 +345,13 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { Scalar::from_int(v, size) } // float -> f32 - Float(FloatTy::F32) => Scalar::from_f32(f.convert(&mut false).value), + Float(FloatTy::F32) => { + Scalar::from_f32(adjust_nan(self, f, f.convert(&mut false).value)) + } // float -> f64 - Float(FloatTy::F64) => Scalar::from_f64(f.convert(&mut false).value), + Float(FloatTy::F64) => { + Scalar::from_f64(adjust_nan(self, f, f.convert(&mut false).value)) + } // That's it. _ => span_bug!(self.cur_span(), "invalid float to {} cast", dest_ty), } diff --git a/compiler/rustc_const_eval/src/interpret/machine.rs b/compiler/rustc_const_eval/src/interpret/machine.rs index b172fd9f51774..b615ced6c7690 100644 --- a/compiler/rustc_const_eval/src/interpret/machine.rs +++ b/compiler/rustc_const_eval/src/interpret/machine.rs @@ -6,7 +6,7 @@ use std::borrow::{Borrow, Cow}; use std::fmt::Debug; use std::hash::Hash; -use rustc_apfloat::Float; +use rustc_apfloat::{Float, FloatConvert}; use rustc_ast::{InlineAsmOptions, InlineAsmTemplatePiece}; use rustc_middle::mir; use rustc_middle::ty::layout::TyAndLayout; @@ -243,9 +243,12 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized { /// Generate the NaN returned by a float operation, given the list of inputs. /// (This is all inputs, not just NaN inputs!) - fn generate_nan(_ecx: &InterpCx<'mir, 'tcx, Self>, _inputs: &[F]) -> F { + fn generate_nan, F2: Float>( + _ecx: &InterpCx<'mir, 'tcx, Self>, + _inputs: &[F1], + ) -> F2 { // By default we always return the preferred NaN. - F::NAN + F2::NAN } /// Called before writing the specified `local` of the `frame`. diff --git a/compiler/rustc_const_eval/src/interpret/operator.rs b/compiler/rustc_const_eval/src/interpret/operator.rs index a685b499d0e9b..53e1756d897aa 100644 --- a/compiler/rustc_const_eval/src/interpret/operator.rs +++ b/compiler/rustc_const_eval/src/interpret/operator.rs @@ -1,4 +1,4 @@ -use rustc_apfloat::Float; +use rustc_apfloat::{Float, FloatConvert}; use rustc_middle::mir; use rustc_middle::mir::interpret::{InterpResult, Scalar}; use rustc_middle::ty::layout::TyAndLayout; @@ -104,7 +104,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { (ImmTy::from_bool(res, *self.tcx), false) } - fn binary_float_op>>( + fn binary_float_op + Into>>( &self, bin_op: mir::BinOp, layout: TyAndLayout<'tcx>, diff --git a/src/tools/miri/src/machine.rs b/src/tools/miri/src/machine.rs index d7177a4a1d202..3de27460860c9 100644 --- a/src/tools/miri/src/machine.rs +++ b/src/tools/miri/src/machine.rs @@ -1002,7 +1002,10 @@ impl<'mir, 'tcx> Machine<'mir, 'tcx> for MiriMachine<'mir, 'tcx> { } #[inline(always)] - fn generate_nan(ecx: &InterpCx<'mir, 'tcx, Self>, inputs: &[F]) -> F { + fn generate_nan, F2: rustc_apfloat::Float>( + ecx: &InterpCx<'mir, 'tcx, Self>, + inputs: &[F1], + ) -> F2 { ecx.generate_nan(inputs) } diff --git a/src/tools/miri/src/operator.rs b/src/tools/miri/src/operator.rs index d655d1b0465e1..e5a437f95f0ea 100644 --- a/src/tools/miri/src/operator.rs +++ b/src/tools/miri/src/operator.rs @@ -3,7 +3,7 @@ use std::iter; use log::trace; use rand::{seq::IteratorRandom, Rng}; -use rustc_apfloat::Float; +use rustc_apfloat::{Float, FloatConvert}; use rustc_middle::mir; use rustc_target::abi::Size; @@ -78,17 +78,35 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { }) } - fn generate_nan(&self, inputs: &[F]) -> F { + fn generate_nan, F2: Float>(&self, inputs: &[F1]) -> F2 { + /// Make the given NaN a signaling NaN. + /// Returns `None` if this would not result in a NaN. + fn make_signaling(f: F) -> Option { + // The quiet/signaling bit is the leftmost bit in the mantissa. + // That's position `PRECISION-1`, since `PRECISION` includes the fixed leading 1 bit, + // and then we subtract 1 more since this is 0-indexed. + let quiet_bit_mask = 1 << (F::PRECISION - 2); + // Unset the bit. Double-check that this wasn't the last bit set in the payload. + // (which would turn the NaN into an infinity). + let f = F::from_bits(f.to_bits() & !quiet_bit_mask); + if f.is_nan() { Some(f) } else { None } + } + let this = self.eval_context_ref(); let mut rand = this.machine.rng.borrow_mut(); - // Assemble an iterator of possible NaNs: preferred, unchanged propagation, quieting propagation. - let preferred_nan = F::qnan(Some(0)); + // Assemble an iterator of possible NaNs: preferred, quieting propagation, unchanged propagation. + // On some targets there are more possibilities; for now we just generate those options that + // are possible everywhere. + let preferred_nan = F2::qnan(Some(0)); let nans = iter::once(preferred_nan) - .chain(inputs.iter().filter(|f| f.is_nan()).copied()) - .chain(inputs.iter().filter(|f| f.is_signaling()).map(|f| { - // Make it quiet, by setting the bit. We assume that `preferred_nan` - // only has bits set that all quiet NaNs need to have set. - F::from_bits(f.to_bits() | preferred_nan.to_bits()) + .chain(inputs.iter().filter(|f| f.is_nan()).map(|&f| { + // Regular apfloat cast is quieting. + f.convert(&mut false).value + })) + .chain(inputs.iter().filter(|f| f.is_signaling()).filter_map(|&f| { + let f: F2 = f.convert(&mut false).value; + // We have to de-quiet this again for unchanged propagation. + make_signaling(f) })); // Pick one of the NaNs. let nan = nans.choose(&mut *rand).unwrap(); diff --git a/src/tools/miri/tests/pass/float_nan.rs b/src/tools/miri/tests/pass/float_nan.rs index cb9cc42c60184..698aa447e266a 100644 --- a/src/tools/miri/tests/pass/float_nan.rs +++ b/src/tools/miri/tests/pass/float_nan.rs @@ -311,6 +311,95 @@ fn test_f64() { ); } +fn test_casts() { + let all1_payload_32 = u32_ones(22); + let all1_payload_64 = u64_ones(51); + let left1_payload_64 = (all1_payload_32 as u64) << (51 - 22); + + // 64-to-32 + check_all_outcomes( + HashSet::from_iter([F32::nan(Pos, Quiet, 0), F32::nan(Neg, Quiet, 0)]), + || F32::from(F64::nan(Pos, Quiet, 0).as_f64() as f32), + ); + // The preferred payload is always a possibility. + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, all1_payload_32), + F32::nan(Neg, Quiet, all1_payload_32), + ]), + || F32::from(F64::nan(Pos, Quiet, all1_payload_64).as_f64() as f32), + ); + // If the input is signaling, then the output *may* also be signaling. + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, all1_payload_32), + F32::nan(Neg, Quiet, all1_payload_32), + F32::nan(Pos, Signaling, all1_payload_32), + F32::nan(Neg, Signaling, all1_payload_32), + ]), + || F32::from(F64::nan(Pos, Signaling, all1_payload_64).as_f64() as f32), + ); + // Check that the low bits are gone (not the high bits). + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + ]), + || F32::from(F64::nan(Pos, Quiet, 1).as_f64() as f32), + ); + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + F32::nan(Pos, Quiet, 1), + F32::nan(Neg, Quiet, 1), + ]), + || F32::from(F64::nan(Pos, Quiet, 1 << (51-22)).as_f64() as f32), + ); + check_all_outcomes( + HashSet::from_iter([ + F32::nan(Pos, Quiet, 0), + F32::nan(Neg, Quiet, 0), + // The `1` payload becomes `0`, and the `0` payload cannot be signaling, + // so these are the only options. + ]), + || F32::from(F64::nan(Pos, Signaling, 1).as_f64() as f32), + ); + + // 32-to-64 + check_all_outcomes( + HashSet::from_iter([F64::nan(Pos, Quiet, 0), F64::nan(Neg, Quiet, 0)]), + || F64::from(F32::nan(Pos, Quiet, 0).as_f32() as f64), + ); + // The preferred payload is always a possibility. + // Also checks that 0s are added on the right. + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, left1_payload_64), + F64::nan(Neg, Quiet, left1_payload_64), + ]), + || F64::from(F32::nan(Pos, Quiet, all1_payload_32).as_f32() as f64), + ); + // If the input is signaling, then the output *may* also be signaling. + check_all_outcomes( + HashSet::from_iter([ + F64::nan(Pos, Quiet, 0), + F64::nan(Neg, Quiet, 0), + F64::nan(Pos, Quiet, left1_payload_64), + F64::nan(Neg, Quiet, left1_payload_64), + F64::nan(Pos, Signaling, left1_payload_64), + F64::nan(Neg, Signaling, left1_payload_64), + ]), + || F64::from(F32::nan(Pos, Signaling, all1_payload_32).as_f32() as f64), + ); +} + fn main() { // Check our constants against std, just to be sure. // We add 1 since our numbers are the number of bits stored @@ -321,4 +410,5 @@ fn main() { test_f32(); test_f64(); + test_casts(); }