Skip to content

Commit

Permalink
Disable jump threading of float equality
Browse files Browse the repository at this point in the history
Jump threading stores values as `u128` (`ScalarInt`) and does its
comparisons for equality as integer comparisons.
This works great for integers. Sadly, not everything is an integer.

Floats famously have wonky equality semantcs, with `NaN!=NaN` and
`0.0 == -0.0`. This does not match our beautiful integer bitpattern
equality and therefore causes things to go horribly wrong.

While jump threading could be extended to support floats by remembering
that they're floats in the value state and handling them properly,
it's signficantly easier to just disable it for now.
  • Loading branch information
Noratrieb committed Jul 27, 2024
1 parent 3942254 commit f305e18
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 0 deletions.
7 changes: 7 additions & 0 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,13 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
if value.const_.ty().is_floating_point() {
// Floating point equality does not follow bit-patterns.
// -0.0 and NaN both have special rules for equality,
// and therefore we cannot use integer comparisons for them.
// Avoid handling them, though this could be extended in the future.
return None;
}
let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
let conds = conditions.map(self.arena, |c| Condition {
value,
Expand Down
59 changes: 59 additions & 0 deletions tests/mir-opt/jump_threading.floats.JumpThreading.panic-abort.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
- // MIR for `floats` before JumpThreading
+ // MIR for `floats` after JumpThreading

fn floats() -> u32 {
let mut _0: u32;
let _1: f64;
let mut _2: bool;
let mut _3: bool;
let mut _4: f64;
scope 1 {
debug x => _1;
}

bb0: {
StorageLive(_1);
StorageLive(_2);
_2 = const true;
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ goto -> bb1;
}

bb1: {
_1 = const -0f64;
goto -> bb3;
}

bb2: {
_1 = const 1f64;
goto -> bb3;
}

bb3: {
StorageDead(_2);
StorageLive(_3);
StorageLive(_4);
_4 = _1;
_3 = Eq(move _4, const 0f64);
switchInt(move _3) -> [0: bb5, otherwise: bb4];
}

bb4: {
StorageDead(_4);
_0 = const 0_u32;
goto -> bb6;
}

bb5: {
StorageDead(_4);
_0 = const 1_u32;
goto -> bb6;
}

bb6: {
StorageDead(_3);
StorageDead(_1);
return;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
- // MIR for `floats` before JumpThreading
+ // MIR for `floats` after JumpThreading

fn floats() -> u32 {
let mut _0: u32;
let _1: f64;
let mut _2: bool;
let mut _3: bool;
let mut _4: f64;
scope 1 {
debug x => _1;
}

bb0: {
StorageLive(_1);
StorageLive(_2);
_2 = const true;
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ goto -> bb1;
}

bb1: {
_1 = const -0f64;
goto -> bb3;
}

bb2: {
_1 = const 1f64;
goto -> bb3;
}

bb3: {
StorageDead(_2);
StorageLive(_3);
StorageLive(_4);
_4 = _1;
_3 = Eq(move _4, const 0f64);
switchInt(move _3) -> [0: bb5, otherwise: bb4];
}

bb4: {
StorageDead(_4);
_0 = const 0_u32;
goto -> bb6;
}

bb5: {
StorageDead(_4);
_0 = const 1_u32;
goto -> bb6;
}

bb6: {
StorageDead(_3);
StorageDead(_1);
return;
}
}

12 changes: 12 additions & 0 deletions tests/mir-opt/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,16 @@ fn aggregate_copy() -> u32 {
if c == 2 { b.0 } else { 13 }
}

fn floats() -> u32 {
// CHECK-LABEL: fn floats(
// CHECK: switchInt(

// Test for issue #128243, where float equality was assumed to be bitwise.
// When adding float support, it must be ensured that this continues working properly.
let x = if true { -0.0 } else { 1.0 };
if x == 0.0 { 0 } else { 1 }
}

fn main() {
// CHECK-LABEL: fn main(
too_complex(Ok(0));
Expand All @@ -535,6 +545,7 @@ fn main() {
disappearing_bb(7);
aggregate(7);
assume(7, false);
floats();
}

// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
Expand All @@ -550,3 +561,4 @@ fn main() {
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff
// EMIT_MIR jump_threading.assume.JumpThreading.diff
// EMIT_MIR jump_threading.aggregate_copy.JumpThreading.diff
// EMIT_MIR jump_threading.floats.JumpThreading.diff

0 comments on commit f305e18

Please sign in to comment.