Skip to content

Commit

Permalink
Don't accept NaN in float and decimal constraints (#1037)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Oct 25, 2023
1 parent acf15bf commit 23d1065
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 8 deletions.
19 changes: 15 additions & 4 deletions src/validators/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,19 @@ impl Validator for DecimalValidator {
}
}

// Decimal raises DecimalOperation when comparing NaN, so if it's necessary to compare
// the value to a number, we need to check for NaN first. We cache the result on the first
// time we check it.
let mut is_nan: Option<bool> = None;
let mut is_nan = || -> PyResult<bool> {
match is_nan {
Some(is_nan) => Ok(is_nan),
None => Ok(*is_nan.insert(decimal.call_method0(intern!(py, "is_nan"))?.extract()?)),
}
};

if let Some(le) = &self.le {
if !decimal.le(le)? {
if is_nan()? || !decimal.le(le)? {
return Err(ValError::new(
ErrorType::LessThanEqual {
le: Number::String(le.to_string()),
Expand All @@ -194,7 +205,7 @@ impl Validator for DecimalValidator {
}
}
if let Some(lt) = &self.lt {
if !decimal.lt(lt)? {
if is_nan()? || !decimal.lt(lt)? {
return Err(ValError::new(
ErrorType::LessThan {
lt: Number::String(lt.to_string()),
Expand All @@ -205,7 +216,7 @@ impl Validator for DecimalValidator {
}
}
if let Some(ge) = &self.ge {
if !decimal.ge(ge)? {
if is_nan()? || !decimal.ge(ge)? {
return Err(ValError::new(
ErrorType::GreaterThanEqual {
ge: Number::String(ge.to_string()),
Expand All @@ -216,7 +227,7 @@ impl Validator for DecimalValidator {
}
}
if let Some(gt) = &self.gt {
if !decimal.gt(gt)? {
if is_nan()? || !decimal.gt(gt)? {
return Err(ValError::new(
ErrorType::GreaterThan {
gt: Number::String(gt.to_string()),
Expand Down
10 changes: 6 additions & 4 deletions src/validators/float.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::cmp::Ordering;

use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;
Expand Down Expand Up @@ -129,7 +131,7 @@ impl Validator for ConstrainedFloatValidator {
}
}
if let Some(le) = self.le {
if float > le {
if !matches!(float.partial_cmp(&le), Some(Ordering::Less | Ordering::Equal)) {
return Err(ValError::new(
ErrorType::LessThanEqual {
le: le.into(),
Expand All @@ -140,7 +142,7 @@ impl Validator for ConstrainedFloatValidator {
}
}
if let Some(lt) = self.lt {
if float >= lt {
if !matches!(float.partial_cmp(&lt), Some(Ordering::Less)) {
return Err(ValError::new(
ErrorType::LessThan {
lt: lt.into(),
Expand All @@ -151,7 +153,7 @@ impl Validator for ConstrainedFloatValidator {
}
}
if let Some(ge) = self.ge {
if float < ge {
if !matches!(float.partial_cmp(&ge), Some(Ordering::Greater | Ordering::Equal)) {
return Err(ValError::new(
ErrorType::GreaterThanEqual {
ge: ge.into(),
Expand All @@ -162,7 +164,7 @@ impl Validator for ConstrainedFloatValidator {
}
}
if let Some(gt) = self.gt {
if float <= gt {
if !matches!(float.partial_cmp(&gt), Some(Ordering::Greater)) {
return Err(ValError::new(
ErrorType::GreaterThan {
gt: gt.into(),
Expand Down
4 changes: 4 additions & 0 deletions tests/validators/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,16 @@ def test_decimal_strict_json(input_value, expected):
({'le': 0}, 0, Decimal(0)),
({'le': 0}, -1, Decimal(-1)),
({'le': 0}, 0.1, Err('Input should be less than or equal to 0')),
({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')),
({'gt': 0, 'allow_inf_nan': True}, float('inf'), Decimal('inf')),
({'lt': 0}, 0, Err('Input should be less than 0')),
({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')),
],
)
def test_decimal_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected):
v = py_and_json({'type': 'decimal', **kwargs})
if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value):
expected = Err('Invalid JSON')
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
Expand Down
4 changes: 4 additions & 0 deletions tests/validators/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ def test_float_strict(py_and_json: PyAndJson, input_value, expected):
({'le': 0}, 0.1, Err('Input should be less than or equal to 0')),
({'lt': 0}, 0, Err('Input should be less than 0')),
({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')),
({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')),
({'gt': 0, 'allow_inf_nan': True}, float('inf'), float('inf')),
],
)
def test_float_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected):
v = py_and_json({'type': 'float', **kwargs})
if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value):
expected = Err('Invalid JSON')
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
v.validate_test(input_value)
Expand Down

0 comments on commit 23d1065

Please sign in to comment.