Skip to content
This repository has been archived by the owner on Mar 25, 2024. It is now read-only.

Commit

Permalink
Implement FromStr for serde_yaml::Number
Browse files Browse the repository at this point in the history
  • Loading branch information
dtolnay committed Jul 17, 2023
1 parent 6b212e0 commit 610d7b2
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 29 deletions.
6 changes: 3 additions & 3 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ fn parse_negative_int<T>(
from_str_radix(scalar, 10).ok()
}

fn parse_f64(scalar: &str) -> Option<f64> {
pub(crate) fn parse_f64(scalar: &str) -> Option<f64> {
let unpositive = if let Some(unpositive) = scalar.strip_prefix('+') {
if unpositive.starts_with(['+', '-']) {
return None;
Expand All @@ -1089,14 +1089,14 @@ fn parse_f64(scalar: &str) -> Option<f64> {
None
}

fn digits_but_not_number(scalar: &str) -> bool {
pub(crate) fn digits_but_not_number(scalar: &str) -> bool {
// Leading zero(s) followed by numeric characters is a string according to
// the YAML 1.2 spec. https://yaml.org/spec/1.2/spec.html#id2761292
let scalar = scalar.strip_prefix(['-', '+']).unwrap_or(scalar);
scalar.len() > 1 && scalar.starts_with('0') && scalar[1..].bytes().all(|b| b.is_ascii_digit())
}

fn visit_int<'de, V>(visitor: V, v: &str) -> Result<Result<V::Value>, V>
pub(crate) fn visit_int<'de, V>(visitor: V, v: &str) -> Result<Result<V::Value>, V>
where
V: Visitor<'de>,
{
Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub(crate) enum ErrorImpl {
ScalarInMergeElement,
SequenceInMergeElement,
EmptyTag,
FailedToParseNumber,

Shared(Arc<ErrorImpl>),
}
Expand Down Expand Up @@ -239,6 +240,7 @@ impl ErrorImpl {
f.write_str("expected a mapping for merging, but found sequence")
}
ErrorImpl::EmptyTag => f.write_str("empty YAML tag is not allowed"),
ErrorImpl::FailedToParseNumber => f.write_str("failed to parse YAML number"),
ErrorImpl::Shared(_) => unreachable!(),
}
}
Expand Down
70 changes: 44 additions & 26 deletions src/number.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::Error;
use crate::de;
use crate::error::{self, Error, ErrorImpl};
use serde::de::{Unexpected, Visitor};
use serde::{forward_to_deserialize_any, Deserialize, Deserializer, Serialize, Serializer};
use std::cmp::Ordering;
use std::fmt::{self, Display};
use std::hash::{Hash, Hasher};
use std::str::FromStr;

/// Represents a YAML number, whether integer or floating point.
#[derive(Clone, PartialEq, PartialOrd)]
Expand Down Expand Up @@ -308,6 +310,22 @@ impl Display for Number {
}
}

impl FromStr for Number {
type Err = Error;

fn from_str(repr: &str) -> Result<Self, Self::Err> {
if let Ok(result) = de::visit_int(NumberVisitor, repr) {
return result;
}
if !de::digits_but_not_number(repr) {
if let Some(float) = de::parse_f64(repr) {
return Ok(float.into());
}
}
Err(error::new(ErrorImpl::FailedToParseNumber))
}
}

impl PartialEq for N {
fn eq(&self, other: &N) -> bool {
match (*self, *other) {
Expand Down Expand Up @@ -389,37 +407,37 @@ impl Serialize for Number {
}
}

impl<'de> Deserialize<'de> for Number {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Number, D::Error>
where
D: Deserializer<'de>,
{
struct NumberVisitor;
struct NumberVisitor;

impl<'de> Visitor<'de> for NumberVisitor {
type Value = Number;
impl<'de> Visitor<'de> for NumberVisitor {
type Value = Number;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a number")
}
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a number")
}

#[inline]
fn visit_i64<E>(self, value: i64) -> Result<Number, E> {
Ok(value.into())
}
#[inline]
fn visit_i64<E>(self, value: i64) -> Result<Number, E> {
Ok(value.into())
}

#[inline]
fn visit_u64<E>(self, value: u64) -> Result<Number, E> {
Ok(value.into())
}
#[inline]
fn visit_u64<E>(self, value: u64) -> Result<Number, E> {
Ok(value.into())
}

#[inline]
fn visit_f64<E>(self, value: f64) -> Result<Number, E> {
Ok(value.into())
}
}
#[inline]
fn visit_f64<E>(self, value: f64) -> Result<Number, E> {
Ok(value.into())
}
}

impl<'de> Deserialize<'de> for Number {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Number, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(NumberVisitor)
}
}
Expand Down

0 comments on commit 610d7b2

Please sign in to comment.