diff --git a/pydra/engine/specs.py b/pydra/engine/specs.py index cccd272a9a..bbfbd57941 100644 --- a/pydra/engine/specs.py +++ b/pydra/engine/specs.py @@ -991,8 +991,8 @@ def get_value( f"named '{node.checksum}' in any of the cache locations.\n" + "\n".join(str(p) for p in set(node.cache_locations)) + f"\n\nThis is likely due to hash changes in '{self.name}' node inputs. " - f"Current values and hashes: {self.inputs}, " - f"{self.inputs._hashes}\n\n" + f"Current values and hashes: {node.inputs}, " + f"{node.inputs._hashes}\n\n" "Set loglevel to 'debug' in order to track hash changes " "throughout the execution of the workflow.\n\n " "These issues may have been caused by `bytes_repr()` methods " diff --git a/pydra/utils/hash.py b/pydra/utils/hash.py index 6f35da0f76..1946b4b364 100644 --- a/pydra/utils/hash.py +++ b/pydra/utils/hash.py @@ -1,9 +1,11 @@ """Generic object hashing dispatch""" +import sys import os import struct from datetime import datetime import typing as ty +import types from pathlib import Path from collections.abc import Mapping from functools import singledispatch @@ -467,6 +469,10 @@ def type_name(tp): yield b")" +if sys.version_info >= (3, 10): + register_serializer(types.UnionType)(bytes_repr_type) + + @register_serializer(FileSet) def bytes_repr_fileset( fileset: FileSet, cache: Cache diff --git a/pydra/utils/tests/test_hash.py b/pydra/utils/tests/test_hash.py index 2c74de6e48..de065a03de 100644 --- a/pydra/utils/tests/test_hash.py +++ b/pydra/utils/tests/test_hash.py @@ -1,5 +1,6 @@ import re import os +import sys from hashlib import blake2b from pathlib import Path import time @@ -200,6 +201,14 @@ def test_bytes_special_form1(): assert obj_repr == b"type:(typing.Union[type:(builtins.int)type:(builtins.float)])" +@pytest.mark.skipif(condition=sys.version_info < (3, 10), reason="requires python3.10") +def test_bytes_special_form1a(): + obj_repr = join_bytes_repr(int | float) + assert ( + obj_repr == b"type:(types.UnionType[type:(builtins.int)type:(builtins.float)])" + ) + + def test_bytes_special_form2(): obj_repr = join_bytes_repr(ty.Any) assert re.match(rb"type:\(typing.Any\)", obj_repr) @@ -212,6 +221,15 @@ def test_bytes_special_form3(): ) +@pytest.mark.skipif(condition=sys.version_info < (3, 10), reason="requires python3.10") +def test_bytes_special_form3a(): + obj_repr = join_bytes_repr(Path | None) + assert ( + obj_repr + == b"type:(types.UnionType[type:(pathlib.Path)type:(builtins.NoneType)])" + ) + + def test_bytes_special_form4(): obj_repr = join_bytes_repr(ty.Type[Path]) assert obj_repr == b"type:(builtins.type[type:(pathlib.Path)])" diff --git a/pydra/utils/tests/test_typing.py b/pydra/utils/tests/test_typing.py index 665d79327d..b41aefd2a8 100644 --- a/pydra/utils/tests/test_typing.py +++ b/pydra/utils/tests/test_typing.py @@ -120,6 +120,11 @@ def test_type_check_basic15(): TypeParser(ty.Union[Path, File, float])(lz(int)) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_check_basic15a(): + TypeParser(Path | File | float)(lz(int)) + + def test_type_check_basic16(): with pytest.raises( TypeError, match="Cannot coerce to any of the union types" @@ -127,6 +132,14 @@ def test_type_check_basic16(): TypeParser(ty.Union[Path, File, bool, int])(lz(float)) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_check_basic16a(): + with pytest.raises( + TypeError, match="Cannot coerce to any of the union types" + ): + TypeParser(Path | File | bool | int)(lz(float)) + + def test_type_check_basic17(): TypeParser(ty.Sequence)(lz(ty.Tuple[int, ...])) @@ -194,6 +207,12 @@ def test_type_check_fail2(): TypeParser(ty.Union[Path, File])(lz(int)) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_check_fail2a(): + with pytest.raises(TypeError, match="to any of the union types"): + TypeParser(Path | File)(lz(int)) + + def test_type_check_fail3(): with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( @@ -312,6 +331,18 @@ def test_type_coercion_basic12(): assert TypeParser(ty.Union[Path, File, int], coercible=[(ty.Any, ty.Any)])(1.0) == 1 +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_coercion_basic12a(): + with pytest.raises(TypeError, match="explicitly excluded"): + TypeParser( + list, + coercible=[(ty.Sequence, ty.Sequence)], + not_coercible=[(str, ty.Sequence)], + )("a-string") + + assert TypeParser(Path | File | int, coercible=[(ty.Any, ty.Any)])(1.0) == 1 + + def test_type_coercion_basic13(): assert ( TypeParser(ty.Union[Path, File, bool, int], coercible=[(ty.Any, ty.Any)])(1.0) @@ -319,6 +350,13 @@ def test_type_coercion_basic13(): ) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_coercion_basic13a(): + assert ( + TypeParser(Path | File | bool | int, coercible=[(ty.Any, ty.Any)])(1.0) is True + ) + + def test_type_coercion_basic14(): assert TypeParser(ty.Sequence, coercible=[(ty.Any, ty.Any)])((1, 2, 3)) == ( 1, @@ -404,6 +442,12 @@ def test_type_coercion_fail2(): TypeParser(ty.Union[Path, File], coercible=[(ty.Any, ty.Any)])(1) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_coercion_fail2a(): + with pytest.raises(TypeError, match="to any of the union types"): + TypeParser(Path | File, coercible=[(ty.Any, ty.Any)])(1) + + def test_type_coercion_fail3(): with pytest.raises(TypeError, match="doesn't match any of the explicit inclusion"): TypeParser(ty.Sequence, coercible=[(ty.Sequence, ty.Sequence)])( @@ -446,7 +490,7 @@ def f(x: ty.List[File], y: ty.Dict[str, ty.List[File]]): TypeParser(ty.List[str])(task.lzout.a) # pylint: disable=no-member with pytest.raises( TypeError, - match="Cannot coerce into ", + match="Cannot coerce into ", ): TypeParser(ty.List[int])(task.lzout.a) # pylint: disable=no-member @@ -469,6 +513,27 @@ def test_matches_type_union(): assert not TypeParser.matches_type(ty.Union[int, bool, str], ty.Union[int, bool]) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_matches_type_union_a(): + assert TypeParser.matches_type(int | bool | str, int | bool | str) + assert TypeParser.matches_type(int | bool, int | bool | str) + assert not TypeParser.matches_type(int | bool | str, int | bool) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_matches_type_union_b(): + assert TypeParser.matches_type(int | bool | str, ty.Union[int, bool, str]) + assert TypeParser.matches_type(int | bool, ty.Union[int, bool, str]) + assert not TypeParser.matches_type(int | bool | str, ty.Union[int, bool]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_matches_type_union_c(): + assert TypeParser.matches_type(ty.Union[int, bool, str], int | bool | str) + assert TypeParser.matches_type(ty.Union[int, bool], int | bool | str) + assert not TypeParser.matches_type(ty.Union[int, bool, str], int | bool) + + def test_matches_type_dict(): COERCIBLE = [(str, Path), (Path, str), (int, float)] @@ -713,18 +778,61 @@ def test_union_is_subclass1(): assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml]) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass1a(): + assert TypeParser.is_subclass(Json | Yaml, Json | Yaml | Xml) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass1b(): + assert TypeParser.is_subclass(Json | Yaml, ty.Union[Json, Yaml, Xml]) + + +## Up to here! + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass1c(): + assert TypeParser.is_subclass(ty.Union[Json, Yaml], Json | Yaml | Xml) + + def test_union_is_subclass2(): assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml]) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass2a(): + assert not TypeParser.is_subclass(Json | Yaml | Xml, Json | Yaml) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass2b(): + assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], Json | Yaml) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass2c(): + assert not TypeParser.is_subclass(Json | Yaml | Xml, ty.Union[Json, Yaml]) + + def test_union_is_subclass3(): assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml]) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass3a(): + assert TypeParser.is_subclass(Json, Json | Yaml) + + def test_union_is_subclass4(): assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_union_is_subclass4a(): + assert not TypeParser.is_subclass(Json | Yaml, Json) + + def test_generic_is_subclass1(): assert TypeParser.is_subclass(ty.List[int], list) @@ -737,6 +845,56 @@ def test_generic_is_subclass3(): assert not TypeParser.is_subclass(ty.List[float], ty.List[int]) +def test_none_is_subclass1(): + assert TypeParser.is_subclass(None, ty.Union[int, None]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_none_is_subclass1a(): + assert TypeParser.is_subclass(None, int | None) + + +def test_none_is_subclass2(): + assert not TypeParser.is_subclass(None, ty.Union[int, float]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_none_is_subclass2a(): + assert not TypeParser.is_subclass(None, int | float) + + +def test_none_is_subclass3(): + assert TypeParser.is_subclass(ty.Tuple[int, None], ty.Tuple[int, None]) + + +def test_none_is_subclass4(): + assert TypeParser.is_subclass(None, None) + + +def test_none_is_subclass5(): + assert not TypeParser.is_subclass(None, int) + + +def test_none_is_subclass6(): + assert not TypeParser.is_subclass(int, None) + + +def test_none_is_subclass7(): + assert TypeParser.is_subclass(None, type(None)) + + +def test_none_is_subclass8(): + assert TypeParser.is_subclass(type(None), None) + + +def test_none_is_subclass9(): + assert TypeParser.is_subclass(type(None), type(None)) + + +def test_none_is_subclass10(): + assert TypeParser.is_subclass(type(None), type(None)) + + @pytest.mark.skipif( sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9" ) @@ -780,3 +938,46 @@ def test_type_is_instance3(): def test_type_is_instance4(): assert TypeParser.is_instance(Json, type) + + +def test_type_is_instance5(): + assert TypeParser.is_instance(None, None) + + +def test_type_is_instance6(): + assert TypeParser.is_instance(None, type(None)) + + +def test_type_is_instance7(): + assert not TypeParser.is_instance(None, int) + + +def test_type_is_instance8(): + assert not TypeParser.is_instance(1, None) + + +def test_type_is_instance9(): + assert TypeParser.is_instance(None, ty.Union[int, None]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_is_instance9a(): + assert TypeParser.is_instance(None, int | None) + + +def test_type_is_instance10(): + assert TypeParser.is_instance(1, ty.Union[int, None]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_is_instance10a(): + assert TypeParser.is_instance(1, int | None) + + +def test_type_is_instance11(): + assert not TypeParser.is_instance(None, ty.Union[int, str]) + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="No UnionType < Py3.10") +def test_type_is_instance11a(): + assert not TypeParser.is_instance(None, int | str) diff --git a/pydra/utils/typing.py b/pydra/utils/typing.py index ee8e733e44..c765b1339c 100644 --- a/pydra/utils/typing.py +++ b/pydra/utils/typing.py @@ -3,6 +3,7 @@ from pathlib import Path import os import sys +import types import typing as ty import logging import attr @@ -20,6 +21,11 @@ # Python < 3.8 from typing_extensions import get_origin, get_args # type: ignore +if sys.version_info >= (3, 10): + UNION_TYPES = (ty.Union, types.UnionType) +else: + UNION_TYPES = (ty.Union,) + logger = logging.getLogger("pydra") NO_GENERIC_ISSUBCLASS = sys.version_info.major == 3 and sys.version_info.minor < 10 @@ -128,7 +134,9 @@ def expand_pattern(t): # If no args were provided, or those arguments were an ellipsis assert isinstance(origin, type) return origin - if origin not in (ty.Union, type) and not issubclass(origin, ty.Iterable): + if origin not in UNION_TYPES + (type,) and not issubclass( + origin, ty.Iterable + ): raise TypeError( f"TypeParser doesn't know how to handle args ({args}) for {origin} " f"types{self.label_str}" @@ -209,7 +217,7 @@ def expand_and_coerce(obj, pattern: ty.Union[type, tuple]): if not isinstance(pattern, tuple): return coerce_basic(obj, pattern) origin, pattern_args = pattern - if origin is ty.Union: + if origin in UNION_TYPES: return coerce_union(obj, pattern_args) if origin is type: return coerce_type(obj, pattern_args) @@ -370,7 +378,7 @@ def expand_and_check(tp, pattern: ty.Union[type, tuple]): if not isinstance(pattern, tuple): return check_basic(tp, pattern) pattern_origin, pattern_args = pattern - if pattern_origin is ty.Union: + if pattern_origin in UNION_TYPES: return check_union(tp, pattern_args) tp_origin = get_origin(tp) if tp_origin is None: @@ -402,7 +410,7 @@ def check_basic(tp, target): self.check_coercible(tp, target) def check_union(tp, pattern_args): - if get_origin(tp) is ty.Union: + if get_origin(tp) in UNION_TYPES: for tp_arg in get_args(tp): reasons = [] for pattern_arg in pattern_args: @@ -603,7 +611,7 @@ def matches_type( def is_instance( cls, obj: object, - candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]]], + candidates: ty.Union[ty.Type[ty.Any], ty.Sequence[ty.Type[ty.Any]], None], ) -> bool: """Checks whether the object is an instance of cls or that cls is typing.Any, extending the built-in isinstance to check nested type args @@ -615,6 +623,8 @@ def is_instance( candidates : type or ty.Iterable[type] the candidate types to check the object against """ + if candidates is None: + candidates = [type(None)] if not isinstance(candidates, ty.Sequence): candidates = [candidates] for candidate in candidates: @@ -656,6 +666,9 @@ def is_subclass( any_ok : bool whether klass=typing.Any should return True or False """ + if klass is None: + # Implicitly convert None to NoneType, like in other typing + klass = type(None) if not isinstance(candidates, ty.Sequence): candidates = [candidates] if ty.Any in candidates: @@ -667,6 +680,8 @@ def is_subclass( args = get_args(klass) for candidate in candidates: + if candidate is None: + candidate = type(None) candidate_origin = get_origin(candidate) candidate_args = get_args(candidate) # Handle ty.Type[*] types in klass and candidates @@ -684,9 +699,11 @@ def is_subclass( ): return True else: - if origin is ty.Union: + if origin in UNION_TYPES: union_args = ( - candidate_args if candidate_origin is ty.Union else (candidate,) + candidate_args + if candidate_origin in UNION_TYPES + else (candidate,) ) matches = all( any(cls.is_subclass(a, c) for c in union_args) for a in args @@ -694,7 +711,7 @@ def is_subclass( if matches: return True else: - if candidate_args and candidate_origin is not ty.Union: + if candidate_args and candidate_origin not in UNION_TYPES: if ( origin and issubclass(origin, candidate_origin) # type: ignore[arg-type] @@ -728,7 +745,7 @@ def contains_type(cls, target: ty.Type[ty.Any], type_: ty.Type[ty.Any]): if not type_args: return False type_origin = get_origin(type_) - if type_origin is ty.Union: + if type_origin in UNION_TYPES: for type_arg in type_args: if cls.contains_type(target, type_arg): return True @@ -851,7 +868,7 @@ def strip_splits(cls, type_: ty.Type[ty.Any]) -> ty.Tuple[ty.Type, int]: while cls.is_subclass(type_, StateArray) and not cls.is_subclass(type_, str): origin = get_origin(type_) # If type is a union, pick the first sequence type in the union - if origin is ty.Union: + if origin in UNION_TYPES: for tp in get_args(type_): if cls.is_subclass(tp, ty.Sequence): type_ = tp