diff --git a/tests/test_node.py b/tests/test_node.py index 085415fbb..ed8c96aa2 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -44,6 +44,10 @@ def fn() -> int: assert node_copy_copy.name == "rename_fn_again" +np_version = np.__version__ +major, minor, _ = map(int, np_version.split(".")) + + @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher") def test_node_handles_annotated(): from typing import Annotated @@ -56,15 +60,24 @@ def annotated_func(first: ArrayN[np.float64], other: float = 2.0) -> ArrayN[np.f node = Node.from_fn(annotated_func) assert node.name == "annotated_func" - expected = { - "first": ( - Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]], - DependencyType.REQUIRED, - ), - "other": (float, DependencyType.OPTIONAL), - } + if major == 2 and minor > 1: # greater that 2.1 + expected = { + "first": ( + Annotated[np.ndarray[tuple[int, ...], np.dtype[np.float64]], Literal["N"]], + DependencyType.REQUIRED, + ), + "other": (float, DependencyType.OPTIONAL), + } + else: + expected = { + "first": ( + Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]], + DependencyType.REQUIRED, + ), + "other": (float, DependencyType.OPTIONAL), + } assert node.input_types == expected - assert node.type == Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]] + assert node.type == Annotated[np.ndarray[tuple[int, ...], np.dtype[np.float64]], Literal["N"]] @pytest.mark.parametrize(