From 23d106551d14214f24a0bcd577b74a672ccfd0c0 Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 25 Oct 2023 09:52:26 +0100 Subject: [PATCH] Don't accept NaN in float and decimal constraints (#1037) --- src/validators/decimal.rs | 19 +++++++++++++++---- src/validators/float.rs | 10 ++++++---- tests/validators/test_decimal.py | 4 ++++ tests/validators/test_float.py | 4 ++++ 4 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/validators/decimal.rs b/src/validators/decimal.rs index be19d1eda..730eeac69 100644 --- a/src/validators/decimal.rs +++ b/src/validators/decimal.rs @@ -182,8 +182,19 @@ impl Validator for DecimalValidator { } } + // Decimal raises DecimalOperation when comparing NaN, so if it's necessary to compare + // the value to a number, we need to check for NaN first. We cache the result on the first + // time we check it. + let mut is_nan: Option = None; + let mut is_nan = || -> PyResult { + match is_nan { + Some(is_nan) => Ok(is_nan), + None => Ok(*is_nan.insert(decimal.call_method0(intern!(py, "is_nan"))?.extract()?)), + } + }; + if let Some(le) = &self.le { - if !decimal.le(le)? { + if is_nan()? || !decimal.le(le)? { return Err(ValError::new( ErrorType::LessThanEqual { le: Number::String(le.to_string()), @@ -194,7 +205,7 @@ impl Validator for DecimalValidator { } } if let Some(lt) = &self.lt { - if !decimal.lt(lt)? { + if is_nan()? || !decimal.lt(lt)? { return Err(ValError::new( ErrorType::LessThan { lt: Number::String(lt.to_string()), @@ -205,7 +216,7 @@ impl Validator for DecimalValidator { } } if let Some(ge) = &self.ge { - if !decimal.ge(ge)? { + if is_nan()? || !decimal.ge(ge)? { return Err(ValError::new( ErrorType::GreaterThanEqual { ge: Number::String(ge.to_string()), @@ -216,7 +227,7 @@ impl Validator for DecimalValidator { } } if let Some(gt) = &self.gt { - if !decimal.gt(gt)? { + if is_nan()? || !decimal.gt(gt)? { return Err(ValError::new( ErrorType::GreaterThan { gt: Number::String(gt.to_string()), diff --git a/src/validators/float.rs b/src/validators/float.rs index 1d62d2006..646d8f4d8 100644 --- a/src/validators/float.rs +++ b/src/validators/float.rs @@ -1,3 +1,5 @@ +use std::cmp::Ordering; + use pyo3::intern; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -129,7 +131,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(le) = self.le { - if float > le { + if !matches!(float.partial_cmp(&le), Some(Ordering::Less | Ordering::Equal)) { return Err(ValError::new( ErrorType::LessThanEqual { le: le.into(), @@ -140,7 +142,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(lt) = self.lt { - if float >= lt { + if !matches!(float.partial_cmp(<), Some(Ordering::Less)) { return Err(ValError::new( ErrorType::LessThan { lt: lt.into(), @@ -151,7 +153,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(ge) = self.ge { - if float < ge { + if !matches!(float.partial_cmp(&ge), Some(Ordering::Greater | Ordering::Equal)) { return Err(ValError::new( ErrorType::GreaterThanEqual { ge: ge.into(), @@ -162,7 +164,7 @@ impl Validator for ConstrainedFloatValidator { } } if let Some(gt) = self.gt { - if float <= gt { + if !matches!(float.partial_cmp(>), Some(Ordering::Greater)) { return Err(ValError::new( ErrorType::GreaterThan { gt: gt.into(), diff --git a/tests/validators/test_decimal.py b/tests/validators/test_decimal.py index 376a9816a..cd54c89ae 100644 --- a/tests/validators/test_decimal.py +++ b/tests/validators/test_decimal.py @@ -148,12 +148,16 @@ def test_decimal_strict_json(input_value, expected): ({'le': 0}, 0, Decimal(0)), ({'le': 0}, -1, Decimal(-1)), ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), + ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), + ({'gt': 0, 'allow_inf_nan': True}, float('inf'), Decimal('inf')), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')), ], ) def test_decimal_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'decimal', **kwargs}) + if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): + expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value) diff --git a/tests/validators/test_float.py b/tests/validators/test_float.py index 74f0024ca..b18181fbb 100644 --- a/tests/validators/test_float.py +++ b/tests/validators/test_float.py @@ -86,10 +86,14 @@ def test_float_strict(py_and_json: PyAndJson, input_value, expected): ({'le': 0}, 0.1, Err('Input should be less than or equal to 0')), ({'lt': 0}, 0, Err('Input should be less than 0')), ({'lt': 0.123456}, 1, Err('Input should be less than 0.123456')), + ({'lt': 0, 'allow_inf_nan': True}, float('nan'), Err('Input should be less than 0')), + ({'gt': 0, 'allow_inf_nan': True}, float('inf'), float('inf')), ], ) def test_float_kwargs(py_and_json: PyAndJson, kwargs: Dict[str, Any], input_value, expected): v = py_and_json({'type': 'float', **kwargs}) + if v.validator_type == 'json' and isinstance(input_value, float) and not math.isfinite(input_value): + expected = Err('Invalid JSON') if isinstance(expected, Err): with pytest.raises(ValidationError, match=re.escape(expected.message)): v.validate_test(input_value)