diff --git a/examples/test_example_functions_fail.py b/examples/test_example_functions_fail.py index 36755d2..91287b3 100644 --- a/examples/test_example_functions_fail.py +++ b/examples/test_example_functions_fail.py @@ -18,6 +18,10 @@ def test_is(): check.is_(x, y) +def test_is_nan(): + check.is_nan(42) + + def test_is_not(): x = ["foo"] y = x diff --git a/examples/test_example_functions_pass.py b/examples/test_example_functions_pass.py index e216ebe..ec93468 100644 --- a/examples/test_example_functions_pass.py +++ b/examples/test_example_functions_pass.py @@ -2,7 +2,7 @@ Passing versions of all of the check helper functions. """ from pytest_check import check - +import math def test_equal(): check.equal(1, 1) @@ -17,6 +17,8 @@ def test_is(): y = x check.is_(x, y) +def test_is_nan(): + check.is_nan(math.nan) def test_is_not(): x = ["foo"] diff --git a/src/pytest_check/check_functions.py b/src/pytest_check/check_functions.py index c601a31..9e23c61 100644 --- a/src/pytest_check/check_functions.py +++ b/src/pytest_check/check_functions.py @@ -1,7 +1,7 @@ import functools import pytest - +import math from .check_log import log_failure __all__ = [ @@ -15,6 +15,7 @@ "is_none", "is_not_none", "is_in", + "is_nan", "is_not_in", "is_instance", "is_not_instance", @@ -75,6 +76,14 @@ def is_(a, b, msg=""): log_failure(f"check {a} is {b}", msg) return False +def is_nan(a, msg=""): + __tracebackhide__ = True + if math.isnan(a): + return True + else: + log_failure(f"check {a} is NaN", msg) + return False + def is_not(a, b, msg=""): __tracebackhide__ = True diff --git a/tests/test_functions.py b/tests/test_functions.py index 8f220b1..48f9de8 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1,10 +1,10 @@ def test_passing_check_functions(pytester): pytester.copy_example("examples/test_example_functions_pass.py") result = pytester.runpytest() - result.assert_outcomes(failed=0, passed=23) + result.assert_outcomes(failed=0, passed=24) def test_failing_check_functions(pytester): pytester.copy_example("examples/test_example_functions_fail.py") result = pytester.runpytest() - result.assert_outcomes(failed=23, passed=0) + result.assert_outcomes(failed=24, passed=0)