Skip to content

Commit

Permalink
Improve the floating point parser in dec2flt.
Browse files Browse the repository at this point in the history
* Remove all remaining traces of unsafe.
* Put `parse_8digits` inside a loop.
* Rework parsing of inf/NaN values.
  • Loading branch information
TDecking committed Apr 9, 2023
1 parent 39bf777 commit 0f96c71
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 288 deletions.
178 changes: 30 additions & 148 deletions library/core/src/num/dec2flt/common.rs
Original file line number Diff line number Diff line change
@@ -1,165 +1,60 @@
//! Common utilities, for internal use only.

use crate::ptr;

/// Helper methods to process immutable bytes.
pub(crate) trait ByteSlice: AsRef<[u8]> {
unsafe fn first_unchecked(&self) -> u8 {
debug_assert!(!self.is_empty());
// SAFETY: safe as long as self is not empty
unsafe { *self.as_ref().get_unchecked(0) }
}

/// Get if the slice contains no elements.
fn is_empty(&self) -> bool {
self.as_ref().is_empty()
}

/// Check if the slice at least `n` length.
fn check_len(&self, n: usize) -> bool {
n <= self.as_ref().len()
}

/// Check if the first character in the slice is equal to c.
fn first_is(&self, c: u8) -> bool {
self.as_ref().first() == Some(&c)
}

/// Check if the first character in the slice is equal to c1 or c2.
fn first_is2(&self, c1: u8, c2: u8) -> bool {
if let Some(&c) = self.as_ref().first() { c == c1 || c == c2 } else { false }
}

/// Bounds-checked test if the first character in the slice is a digit.
fn first_isdigit(&self) -> bool {
if let Some(&c) = self.as_ref().first() { c.is_ascii_digit() } else { false }
}

/// Check if self starts with u with a case-insensitive comparison.
fn starts_with_ignore_case(&self, u: &[u8]) -> bool {
debug_assert!(self.as_ref().len() >= u.len());
let iter = self.as_ref().iter().zip(u.iter());
let d = iter.fold(0, |i, (&x, &y)| i | (x ^ y));
d == 0 || d == 32
}

/// Get the remaining slice after the first N elements.
fn advance(&self, n: usize) -> &[u8] {
&self.as_ref()[n..]
}

/// Get the slice after skipping all leading characters equal c.
fn skip_chars(&self, c: u8) -> &[u8] {
let mut s = self.as_ref();
while s.first_is(c) {
s = s.advance(1);
}
s
}

/// Get the slice after skipping all leading characters equal c1 or c2.
fn skip_chars2(&self, c1: u8, c2: u8) -> &[u8] {
let mut s = self.as_ref();
while s.first_is2(c1, c2) {
s = s.advance(1);
}
s
}

pub(crate) trait ByteSlice {
/// Read 8 bytes as a 64-bit integer in little-endian order.
unsafe fn read_u64_unchecked(&self) -> u64 {
debug_assert!(self.check_len(8));
let src = self.as_ref().as_ptr() as *const u64;
// SAFETY: safe as long as self is at least 8 bytes
u64::from_le(unsafe { ptr::read_unaligned(src) })
}
fn read_u64(&self) -> u64;

/// Try to read the next 8 bytes from the slice.
fn read_u64(&self) -> Option<u64> {
if self.check_len(8) {
// SAFETY: self must be at least 8 bytes.
Some(unsafe { self.read_u64_unchecked() })
} else {
None
}
}

/// Calculate the offset of slice from another.
fn offset_from(&self, other: &Self) -> isize {
other.as_ref().len() as isize - self.as_ref().len() as isize
}
}

impl ByteSlice for [u8] {}

/// Helper methods to process mutable bytes.
pub(crate) trait ByteSliceMut: AsMut<[u8]> {
/// Write a 64-bit integer as 8 bytes in little-endian order.
unsafe fn write_u64_unchecked(&mut self, value: u64) {
debug_assert!(self.as_mut().len() >= 8);
let dst = self.as_mut().as_mut_ptr() as *mut u64;
// NOTE: we must use `write_unaligned`, since dst is not
// guaranteed to be properly aligned. Miri will warn us
// if we use `write` instead of `write_unaligned`, as expected.
// SAFETY: safe as long as self is at least 8 bytes
unsafe {
ptr::write_unaligned(dst, u64::to_le(value));
}
}
}
fn write_u64(&mut self, value: u64);

impl ByteSliceMut for [u8] {}
/// Calculate the offset of a slice from another.
fn offset_from(&self, other: &Self) -> isize;

/// Bytes wrapper with specialized methods for ASCII characters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct AsciiStr<'a> {
slc: &'a [u8],
/// Iteratively parse and consume digits from bytes.
/// Returns the same bytes with consumed digits being
/// elided.
fn parse_digits(&self, func: impl FnMut(u8)) -> &Self;
}

impl<'a> AsciiStr<'a> {
pub fn new(slc: &'a [u8]) -> Self {
Self { slc }
impl ByteSlice for [u8] {
#[inline(always)] // inlining this is crucial to remove bound checks
fn read_u64(&self) -> u64 {
let mut tmp = [0; 8];
tmp.copy_from_slice(&self[..8]);
u64::from_le_bytes(tmp)
}

/// Advance the view by n, advancing it in-place to (n..).
pub unsafe fn step_by(&mut self, n: usize) -> &mut Self {
// SAFETY: safe as long n is less than the buffer length
self.slc = unsafe { self.slc.get_unchecked(n..) };
self
#[inline(always)] // inlining this is crucial to remove bound checks
fn write_u64(&mut self, value: u64) {
self[..8].copy_from_slice(&value.to_le_bytes())
}

/// Advance the view by n, advancing it in-place to (1..).
pub unsafe fn step(&mut self) -> &mut Self {
// SAFETY: safe as long as self is not empty
unsafe { self.step_by(1) }
#[inline]
fn offset_from(&self, other: &Self) -> isize {
other.len() as isize - self.len() as isize
}

/// Iteratively parse and consume digits from bytes.
pub fn parse_digits(&mut self, mut func: impl FnMut(u8)) {
while let Some(&c) = self.as_ref().first() {
#[inline]
fn parse_digits(&self, mut func: impl FnMut(u8)) -> &Self {
let mut s = self;

// FIXME: Can't use s.split_first() here yet,
// see https://github.com/rust-lang/rust/issues/109328
while let [c, s_next @ ..] = s {
let c = c.wrapping_sub(b'0');
if c < 10 {
func(c);
// SAFETY: self cannot be empty
unsafe {
self.step();
}
s = s_next;
} else {
break;
}
}
}
}

impl<'a> AsRef<[u8]> for AsciiStr<'a> {
#[inline]
fn as_ref(&self) -> &[u8] {
self.slc
s
}
}

impl<'a> ByteSlice for AsciiStr<'a> {}

/// Determine if 8 bytes are all decimal digits.
/// This does not care about the order in which the bytes were loaded.
pub(crate) fn is_8digits(v: u64) -> bool {
Expand All @@ -168,19 +63,6 @@ pub(crate) fn is_8digits(v: u64) -> bool {
(a | b) & 0x8080_8080_8080_8080 == 0
}

/// Iteratively parse and consume digits from bytes.
pub(crate) fn parse_digits(s: &mut &[u8], mut f: impl FnMut(u8)) {
while let Some(&c) = s.get(0) {
let c = c.wrapping_sub(b'0');
if c < 10 {
f(c);
*s = s.advance(1);
} else {
break;
}
}
}

/// A custom 64-bit floating point type, representing `f * 2^e`.
/// e is biased, so it be directly shifted into the exponent bits.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
Expand Down
65 changes: 36 additions & 29 deletions library/core/src/num/dec2flt/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! algorithm can be found in "ParseNumberF64 by Simple Decimal Conversion",
//! available online: <https://nigeltao.github.io/blog/2020/parse-number-f64-simple.html>.

use crate::num::dec2flt::common::{is_8digits, parse_digits, ByteSlice, ByteSliceMut};
use crate::num::dec2flt::common::{is_8digits, ByteSlice};

#[derive(Clone)]
pub struct Decimal {
Expand Down Expand Up @@ -205,29 +205,32 @@ impl Decimal {
pub fn parse_decimal(mut s: &[u8]) -> Decimal {
let mut d = Decimal::default();
let start = s;
s = s.skip_chars(b'0');
parse_digits(&mut s, |digit| d.try_add_digit(digit));
if s.first_is(b'.') {
s = s.advance(1);

while let Some((&b'0', s_next)) = s.split_first() {
s = s_next;
}

s = s.parse_digits(|digit| d.try_add_digit(digit));

if let Some((b'.', s_next)) = s.split_first() {
s = s_next;
let first = s;
// Skip leading zeros.
if d.num_digits == 0 {
s = s.skip_chars(b'0');
while let Some((&b'0', s_next)) = s.split_first() {
s = s_next;
}
}
while s.len() >= 8 && d.num_digits + 8 < Decimal::MAX_DIGITS {
// SAFETY: s is at least 8 bytes.
let v = unsafe { s.read_u64_unchecked() };
let v = s.read_u64();
if !is_8digits(v) {
break;
}
// SAFETY: d.num_digits + 8 is less than d.digits.len()
unsafe {
d.digits[d.num_digits..].write_u64_unchecked(v - 0x3030_3030_3030_3030);
}
d.digits[d.num_digits..].write_u64(v - 0x3030_3030_3030_3030);
d.num_digits += 8;
s = s.advance(8);
s = &s[8..];
}
parse_digits(&mut s, |digit| d.try_add_digit(digit));
s = s.parse_digits(|digit| d.try_add_digit(digit));
d.decimal_point = s.len() as i32 - first.len() as i32;
}
if d.num_digits != 0 {
Expand All @@ -248,22 +251,26 @@ pub fn parse_decimal(mut s: &[u8]) -> Decimal {
d.num_digits = Decimal::MAX_DIGITS;
}
}
if s.first_is2(b'e', b'E') {
s = s.advance(1);
let mut neg_exp = false;
if s.first_is(b'-') {
neg_exp = true;
s = s.advance(1);
} else if s.first_is(b'+') {
s = s.advance(1);
}
let mut exp_num = 0_i32;
parse_digits(&mut s, |digit| {
if exp_num < 0x10000 {
exp_num = 10 * exp_num + digit as i32;
if let Some((&ch, s_next)) = s.split_first() {
if ch == b'e' || ch == b'E' {
s = s_next;
let mut neg_exp = false;
if let Some((&ch, s_next)) = s.split_first() {
neg_exp = ch == b'-';
if ch == b'-' || ch == b'+' {
s = s_next;
}
}
});
d.decimal_point += if neg_exp { -exp_num } else { exp_num };
let mut exp_num = 0_i32;

s.parse_digits(|digit| {
if exp_num < 0x10000 {
exp_num = 10 * exp_num + digit as i32;
}
});

d.decimal_point += if neg_exp { -exp_num } else { exp_num };
}
}
for i in d.num_digits..Decimal::MAX_DIGITS_WITHOUT_OVERFLOW {
d.digits[i] = 0;
Expand Down
7 changes: 4 additions & 3 deletions library/core/src/num/dec2flt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ use crate::error::Error;
use crate::fmt;
use crate::str::FromStr;

use self::common::{BiasedFp, ByteSlice};
use self::common::BiasedFp;
use self::float::RawFloat;
use self::lemire::compute_float;
use self::parse::{parse_inf_nan, parse_number};
Expand Down Expand Up @@ -238,17 +238,18 @@ pub fn dec2flt<F: RawFloat>(s: &str) -> Result<F, ParseFloatError> {
};
let negative = c == b'-';
if c == b'-' || c == b'+' {
s = s.advance(1);
s = &s[1..];
}
if s.is_empty() {
return Err(pfe_invalid());
}

let num = match parse_number(s, negative) {
let mut num = match parse_number(s) {
Some(r) => r,
None if let Some(value) = parse_inf_nan(s, negative) => return Ok(value),
None => return Err(pfe_invalid()),
};
num.negative = negative;
if let Some(value) = num.try_fast_path::<F>() {
return Ok(value);
}
Expand Down
Loading

0 comments on commit 0f96c71

Please sign in to comment.