diff --git a/Doc/deprecations/pending-removal-in-future.rst b/Doc/deprecations/pending-removal-in-future.rst index 3be0dabfd1f257..f4b471523c1211 100644 --- a/Doc/deprecations/pending-removal-in-future.rst +++ b/Doc/deprecations/pending-removal-in-future.rst @@ -128,6 +128,11 @@ although there is currently no date scheduled for their removal. * :class:`typing.Text` (:gh:`92332`). +* The internal class ``typing._UnionGenericAlias`` is no longer used to implement + :class:`typing.Union`. To preserve compatibility with users using this private + class, a compatibility shim will be provided until at least Python 3.17. (Contributed by + Jelle Zijlstra in :gh:`105499`.) + * :class:`unittest.IsolatedAsyncioTestCase`: it is deprecated to return a value that is not ``None`` from a test case. diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index e26a2226aa947a..d6332bfd1b6783 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -515,7 +515,7 @@ The :mod:`functools` module defines the following functions: ... for i, elem in enumerate(arg): ... print(i, elem) - :data:`types.UnionType` and :data:`typing.Union` can also be used:: + :class:`typing.Union` can also be used:: >>> @fun.register ... def _(arg: int | float, verbose=False): @@ -651,8 +651,8 @@ The :mod:`functools` module defines the following functions: The :func:`register` attribute now supports using type annotations. .. versionchanged:: 3.11 - The :func:`register` attribute now supports :data:`types.UnionType` - and :data:`typing.Union` as type annotations. + The :func:`register` attribute now supports + :class:`typing.Union` as a type annotation. .. class:: singledispatchmethod(func) diff --git a/Doc/library/stdtypes.rst b/Doc/library/stdtypes.rst index a9b7662dcb212b..a62e58bd6ba0cf 100644 --- a/Doc/library/stdtypes.rst +++ b/Doc/library/stdtypes.rst @@ -5211,7 +5211,7 @@ Union Type A union object holds the value of the ``|`` (bitwise or) operation on multiple :ref:`type objects `. These types are intended primarily for :term:`type annotations `. The union type expression -enables cleaner type hinting syntax compared to :data:`typing.Union`. +enables cleaner type hinting syntax compared to subscripting :class:`typing.Union`. .. describe:: X | Y | ... @@ -5247,9 +5247,10 @@ enables cleaner type hinting syntax compared to :data:`typing.Union`. int | str == str | int - * It is compatible with :data:`typing.Union`:: + * It creates instances of :class:`typing.Union`:: int | str == typing.Union[int, str] + type(int | str) is typing.Union * Optional types can be spelled as a union with ``None``:: @@ -5275,16 +5276,15 @@ enables cleaner type hinting syntax compared to :data:`typing.Union`. TypeError: isinstance() argument 2 cannot be a parameterized generic The user-exposed type for the union object can be accessed from -:data:`types.UnionType` and used for :func:`isinstance` checks. An object cannot be -instantiated from the type:: +:class:`typing.Union` and used for :func:`isinstance` checks:: - >>> import types - >>> isinstance(int | str, types.UnionType) + >>> import typing + >>> isinstance(int | str, typing.Union) True - >>> types.UnionType() + >>> typing.Union() Traceback (most recent call last): File "", line 1, in - TypeError: cannot create 'types.UnionType' instances + TypeError: cannot create 'typing.Union' instances .. note:: The :meth:`!__or__` method for type objects was added to support the syntax diff --git a/Doc/library/types.rst b/Doc/library/types.rst index 439e119461f798..2bedd7fdd3c8c8 100644 --- a/Doc/library/types.rst +++ b/Doc/library/types.rst @@ -314,6 +314,10 @@ Standard names are defined for the following types: .. versionadded:: 3.10 + .. versionchanged:: 3.14 + + This is now an alias for :class:`typing.Union`. + .. class:: TracebackType(tb_next, tb_frame, tb_lasti, tb_lineno) The type of traceback objects such as found in ``sys.exception().__traceback__``. diff --git a/Doc/library/typing.rst b/Doc/library/typing.rst index cd8b90854b0e94..e871d3b49dcde1 100644 --- a/Doc/library/typing.rst +++ b/Doc/library/typing.rst @@ -1086,7 +1086,7 @@ Special forms These can be used as types in annotations. They all support subscription using ``[]``, but each has a unique syntax. -.. data:: Union +.. class:: Union Union type; ``Union[X, Y]`` is equivalent to ``X | Y`` and means either X or Y. @@ -1121,6 +1121,14 @@ These can be used as types in annotations. They all support subscription using Unions can now be written as ``X | Y``. See :ref:`union type expressions`. + .. versionchanged:: 3.14 + :class:`types.UnionType` is now an alias for :class:`Union`, and both + ``Union[int, str]`` and ``int | str`` create instances of the same class. + To check whether an object is a ``Union`` at runtime, use + ``isinstance(obj, Union)``. For compatibility with earlier versions of + Python, use + ``get_origin(obj) is typing.Union or get_origin(obj) is types.UnionType``. + .. data:: Optional ``Optional[X]`` is equivalent to ``X | None`` (or ``Union[X, None]``). diff --git a/Doc/whatsnew/3.10.rst b/Doc/whatsnew/3.10.rst index e4699fbf8edaf7..3c815721a92f8c 100644 --- a/Doc/whatsnew/3.10.rst +++ b/Doc/whatsnew/3.10.rst @@ -722,10 +722,10 @@ PEP 604: New Type Union Operator A new type union operator was introduced which enables the syntax ``X | Y``. This provides a cleaner way of expressing 'either type X or type Y' instead of -using :data:`typing.Union`, especially in type hints. +using :class:`typing.Union`, especially in type hints. In previous versions of Python, to apply a type hint for functions accepting -arguments of multiple types, :data:`typing.Union` was used:: +arguments of multiple types, :class:`typing.Union` was used:: def square(number: Union[int, float]) -> Union[int, float]: return number ** 2 diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst index e5c6d7cd308504..080ea4539cbc87 100644 --- a/Doc/whatsnew/3.11.rst +++ b/Doc/whatsnew/3.11.rst @@ -740,8 +740,8 @@ fractions functools --------- -* :func:`functools.singledispatch` now supports :data:`types.UnionType` - and :data:`typing.Union` as annotations to the dispatch argument.:: +* :func:`functools.singledispatch` now supports :class:`types.UnionType` + and :class:`typing.Union` as annotations to the dispatch argument.:: >>> from functools import singledispatch >>> @singledispatch diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst index de4c7fd4c0486b..d9960745f98817 100644 --- a/Doc/whatsnew/3.13.rst +++ b/Doc/whatsnew/3.13.rst @@ -1488,7 +1488,6 @@ Optimizations Removed Modules And APIs ======================== - .. _whatsnew313-pep594: PEP 594: Remove "dead batteries" from the standard library diff --git a/Include/internal/pycore_unionobject.h b/Include/internal/pycore_unionobject.h index 6ece7134cdeca0..4bd36f6504d42c 100644 --- a/Include/internal/pycore_unionobject.h +++ b/Include/internal/pycore_unionobject.h @@ -18,6 +18,7 @@ PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject *, PyObject *); extern PyObject *_Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *); extern PyObject *_Py_make_parameters(PyObject *); extern PyObject *_Py_union_args(PyObject *self); +extern PyObject *_Py_union_from_tuple(PyObject *args); #ifdef __cplusplus } diff --git a/Lib/functools.py b/Lib/functools.py index 27abd622a8cff1..ac1b184e3e32a4 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -928,16 +928,11 @@ def dispatch(cls): dispatch_cache[cls] = impl return impl - def _is_union_type(cls): - from typing import get_origin, Union - return get_origin(cls) in {Union, UnionType} - def _is_valid_dispatch_type(cls): if isinstance(cls, type): return True - from typing import get_args - return (_is_union_type(cls) and - all(isinstance(arg, type) for arg in get_args(cls))) + return (isinstance(cls, UnionType) and + all(isinstance(arg, type) for arg in cls.__args__)) def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -969,7 +964,7 @@ def register(cls, func=None): from annotationlib import Format, ForwardRef argname, cls = next(iter(get_type_hints(func, format=Format.FORWARDREF).items())) if not _is_valid_dispatch_type(cls): - if _is_union_type(cls): + if isinstance(cls, UnionType): raise TypeError( f"Invalid annotation for {argname!r}. " f"{cls!r} not all arguments are classes." @@ -985,10 +980,8 @@ def register(cls, func=None): f"{cls!r} is not a class." ) - if _is_union_type(cls): - from typing import get_args - - for arg in get_args(cls): + if isinstance(cls, UnionType): + for arg in cls.__args__: registry[arg] = func else: registry[cls] = func diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 2e6c49e29ce828..84f251a3c7801e 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -2313,7 +2313,7 @@ def test_docstring_one_field_with_default_none(self): class C: x: Union[int, type(None)] = None - self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") + self.assertDocStrEqual(C.__doc__, "C(x:int|None=None)") def test_docstring_list_field(self): @dataclass diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index d590af090abc6e..22765c580090f1 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -3033,7 +3033,7 @@ def _(arg: typing.Union[int, typing.Iterable[str]]): "Invalid annotation for 'arg'." )) self.assertTrue(str(exc.exception).endswith( - 'typing.Union[int, typing.Iterable[str]] not all arguments are classes.' + 'int | typing.Iterable[str] not all arguments are classes.' )) def test_invalid_positional_argument(self): diff --git a/Lib/test/test_inspect/test_inspect.py b/Lib/test/test_inspect/test_inspect.py index a4430a868676e2..03aa3b81d141f5 100644 --- a/Lib/test/test_inspect/test_inspect.py +++ b/Lib/test/test_inspect/test_inspect.py @@ -1736,8 +1736,8 @@ class C(metaclass=M): class TestFormatAnnotation(unittest.TestCase): def test_typing_replacement(self): from test.typinganndata.ann_module9 import ann, ann1 - self.assertEqual(inspect.formatannotation(ann), 'Union[List[str], int]') - self.assertEqual(inspect.formatannotation(ann1), 'Union[List[testModule.typing.A], int]') + self.assertEqual(inspect.formatannotation(ann), 'List[str] | int') + self.assertEqual(inspect.formatannotation(ann1), 'List[testModule.typing.A] | int') class TestIsMethodDescriptor(unittest.TestCase): diff --git a/Lib/test/test_pydoc/test_pydoc.py b/Lib/test/test_pydoc/test_pydoc.py index 2a4d3ab73db608..a76e89222d20d2 100644 --- a/Lib/test/test_pydoc/test_pydoc.py +++ b/Lib/test/test_pydoc/test_pydoc.py @@ -132,7 +132,7 @@ class C(builtins.object) c_alias = test.test_pydoc.pydoc_mod.C[int] list_alias1 = typing.List[int] list_alias2 = list[int] - type_union1 = typing.Union[int, str] + type_union1 = int | str type_union2 = int | str VERSION @@ -222,7 +222,7 @@ class C(builtins.object) c_alias = test.test_pydoc.pydoc_mod.C[int] list_alias1 = typing.List[int] list_alias2 = list[int] - type_union1 = typing.Union[int, str] + type_union1 = int | str type_union2 = int | str Author @@ -1363,17 +1363,17 @@ def test_generic_alias(self): self.assertIn(list.__doc__.strip().splitlines()[0], doc) def test_union_type(self): - self.assertEqual(pydoc.describe(typing.Union[int, str]), '_UnionGenericAlias') + self.assertEqual(pydoc.describe(typing.Union[int, str]), 'Union') doc = pydoc.render_doc(typing.Union[int, str], renderer=pydoc.plaintext) - self.assertIn('_UnionGenericAlias in module typing', doc) - self.assertIn('Union = typing.Union', doc) + self.assertIn('Union in module typing', doc) + self.assertIn('class Union(builtins.object)', doc) if typing.Union.__doc__: self.assertIn(typing.Union.__doc__.strip().splitlines()[0], doc) - self.assertEqual(pydoc.describe(int | str), 'UnionType') + self.assertEqual(pydoc.describe(int | str), 'Union') doc = pydoc.render_doc(int | str, renderer=pydoc.plaintext) - self.assertIn('UnionType in module types object', doc) - self.assertIn('\nclass UnionType(builtins.object)', doc) + self.assertIn('Union in module typing', doc) + self.assertIn('class Union(builtins.object)', doc) if not MISSING_C_DOCSTRINGS: self.assertIn(types.UnionType.__doc__.strip().splitlines()[0], doc) diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index d1161719d98040..5a65b5dacaf581 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -709,10 +709,6 @@ def test_or_types_operator(self): y = int | bool with self.assertRaises(TypeError): x < y - # Check that we don't crash if typing.Union does not have a tuple in __args__ - y = typing.Union[str, int] - y.__args__ = [str, int] - self.assertEqual(x, y) def test_hash(self): self.assertEqual(hash(int | str), hash(str | int)) @@ -727,17 +723,40 @@ class B(metaclass=UnhashableMeta): ... self.assertEqual((A | B).__args__, (A, B)) union1 = A | B - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union1) union2 = int | B - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union2) union3 = A | int - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union3) + def test_unhashable_becomes_hashable(self): + is_hashable = False + class UnhashableMeta(type): + def __hash__(self): + if is_hashable: + return 1 + else: + raise TypeError("not hashable") + + class A(metaclass=UnhashableMeta): ... + class B(metaclass=UnhashableMeta): ... + + union = A | B + self.assertEqual(union.__args__, (A, B)) + + with self.assertRaisesRegex(TypeError, "not hashable"): + hash(union) + + is_hashable = True + + with self.assertRaisesRegex(TypeError, "union contains 2 unhashable elements"): + hash(union) + def test_instancecheck_and_subclasscheck(self): for x in (int | str, typing.Union[int, str]): with self.subTest(x=x): @@ -921,7 +940,7 @@ def forward_before(x: ForwardBefore[int]) -> None: ... self.assertEqual(typing.get_args(typing.get_type_hints(forward_after)['x']), (int, Forward)) self.assertEqual(typing.get_args(typing.get_type_hints(forward_before)['x']), - (int, Forward)) + (Forward, int)) def test_or_type_operator_with_Protocol(self): class Proto(typing.Protocol): @@ -1015,9 +1034,14 @@ def __eq__(self, other): return 1 / 0 bt = BadType('bt', (), {}) + bt2 = BadType('bt2', (), {}) # Comparison should fail and errors should propagate out for bad types. + union1 = int | bt + union2 = int | bt2 + with self.assertRaises(ZeroDivisionError): + union1 == union2 with self.assertRaises(ZeroDivisionError): - list[int] | list[bt] + bt | bt2 union_ga = (list[str] | int, collections.abc.Callable[..., str] | int, d | int) @@ -1060,6 +1084,14 @@ def test_or_type_operator_reference_cycle(self): self.assertLessEqual(sys.gettotalrefcount() - before, leeway, msg='Check for union reference leak.') + def test_instantiation(self): + with self.assertRaises(TypeError): + types.UnionType() + self.assertIs(int, types.UnionType[int]) + self.assertIs(int, types.UnionType[int, int]) + self.assertEqual(int | str, types.UnionType[int, str]) + self.assertEqual(int | typing.ForwardRef("str"), types.UnionType[int, "str"]) + class MappingProxyTests(unittest.TestCase): mappingproxy = types.MappingProxyType diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 2f1f9e86a0bce4..561c901b5a24ec 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -511,7 +511,7 @@ def test_cannot_instantiate_vars(self): def test_bound_errors(self): with self.assertRaises(TypeError): - TypeVar('X', bound=Union) + TypeVar('X', bound=Optional) with self.assertRaises(TypeError): TypeVar('X', str, float, bound=Employee) with self.assertRaisesRegex(TypeError, @@ -551,7 +551,7 @@ def test_var_substitution(self): def test_bad_var_substitution(self): T = TypeVar('T') bad_args = ( - (), (int, str), Union, + (), (int, str), Optional, Generic, Generic[T], Protocol, Protocol[T], Final, Final[int], ClassVar, ClassVar[int], ) @@ -2010,10 +2010,6 @@ def test_basics(self): self.assertNotEqual(u, Union) def test_subclass_error(self): - with self.assertRaises(TypeError): - issubclass(int, Union) - with self.assertRaises(TypeError): - issubclass(Union, int) with self.assertRaises(TypeError): issubclass(Union[int, str], int) @@ -2066,41 +2062,40 @@ class B(metaclass=UnhashableMeta): ... self.assertEqual(Union[A, B].__args__, (A, B)) union1 = Union[A, B] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union1) union2 = Union[int, B] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union2) union3 = Union[A, int] - with self.assertRaises(TypeError): + with self.assertRaisesRegex(TypeError, "unhashable type: 'UnhashableMeta'"): hash(union3) def test_repr(self): - self.assertEqual(repr(Union), 'typing.Union') u = Union[Employee, int] - self.assertEqual(repr(u), 'typing.Union[%s.Employee, int]' % __name__) + self.assertEqual(repr(u), f'{__name__}.Employee | int') u = Union[int, Employee] - self.assertEqual(repr(u), 'typing.Union[int, %s.Employee]' % __name__) + self.assertEqual(repr(u), f'int | {__name__}.Employee') T = TypeVar('T') u = Union[T, int][int] self.assertEqual(repr(u), repr(int)) u = Union[List[int], int] - self.assertEqual(repr(u), 'typing.Union[typing.List[int], int]') + self.assertEqual(repr(u), 'typing.List[int] | int') u = Union[list[int], dict[str, float]] - self.assertEqual(repr(u), 'typing.Union[list[int], dict[str, float]]') + self.assertEqual(repr(u), 'list[int] | dict[str, float]') u = Union[int | float] - self.assertEqual(repr(u), 'typing.Union[int, float]') + self.assertEqual(repr(u), 'int | float') u = Union[None, str] - self.assertEqual(repr(u), 'typing.Optional[str]') + self.assertEqual(repr(u), 'None | str') u = Union[str, None] - self.assertEqual(repr(u), 'typing.Optional[str]') + self.assertEqual(repr(u), 'str | None') u = Union[None, str, int] - self.assertEqual(repr(u), 'typing.Union[NoneType, str, int]') + self.assertEqual(repr(u), 'None | str | int') u = Optional[str] - self.assertEqual(repr(u), 'typing.Optional[str]') + self.assertEqual(repr(u), 'str | None') def test_dir(self): dir_items = set(dir(Union[str, int])) @@ -2112,14 +2107,11 @@ def test_dir(self): def test_cannot_subclass(self): with self.assertRaisesRegex(TypeError, - r'Cannot subclass typing\.Union'): + r"type 'typing\.Union' is not an acceptable base type"): class C(Union): pass - with self.assertRaisesRegex(TypeError, CANNOT_SUBCLASS_TYPE): - class D(type(Union)): - pass with self.assertRaisesRegex(TypeError, - r'Cannot subclass typing\.Union\[int, str\]'): + r'Cannot subclass int \| str'): class E(Union[int, str]): pass @@ -2165,7 +2157,7 @@ def f(x: u): ... def test_function_repr_union(self): def fun() -> int: ... - self.assertEqual(repr(Union[fun, int]), 'typing.Union[fun, int]') + self.assertEqual(repr(Union[fun, int]), f'{__name__}.{fun.__qualname__} | int') def test_union_str_pattern(self): # Shouldn't crash; see http://bugs.python.org/issue25390 @@ -4835,11 +4827,11 @@ class Derived(Base): ... def test_extended_generic_rules_repr(self): T = TypeVar('T') self.assertEqual(repr(Union[Tuple, Callable]).replace('typing.', ''), - 'Union[Tuple, Callable]') + 'Tuple | Callable') self.assertEqual(repr(Union[Tuple, Tuple[int]]).replace('typing.', ''), - 'Union[Tuple, Tuple[int]]') + 'Tuple | Tuple[int]') self.assertEqual(repr(Callable[..., Optional[T]][int]).replace('typing.', ''), - 'Callable[..., Optional[int]]') + 'Callable[..., int | None]') self.assertEqual(repr(Callable[[], List[T]][int]).replace('typing.', ''), 'Callable[[], List[int]]') @@ -5019,9 +5011,9 @@ def __contains__(self, item): with self.assertRaises(TypeError): issubclass(Tuple[int, ...], typing.Iterable) - def test_fail_with_bare_union(self): + def test_fail_with_special_forms(self): with self.assertRaises(TypeError): - List[Union] + List[Final] with self.assertRaises(TypeError): Tuple[Optional] with self.assertRaises(TypeError): @@ -5551,8 +5543,6 @@ def test_subclass_special_form(self): for obj in ( ClassVar[int], Final[int], - Union[int, float], - Optional[int], Literal[1, 2], Concatenate[int, ParamSpec("P")], TypeGuard[int], @@ -5584,7 +5574,7 @@ class A: __parameters__ = (T,) # Bare classes should be skipped for a in (List, list): - for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, types.UnionType): + for b in (A, int, TypeVar, TypeVarTuple, ParamSpec, types.GenericAlias, Union): with self.subTest(generic=a, sub=b): with self.assertRaisesRegex(TypeError, '.* is not a generic class'): a[b][str] @@ -5603,7 +5593,7 @@ class A: for s in (int, G, A, List, list, TypeVar, TypeVarTuple, ParamSpec, - types.GenericAlias, types.UnionType): + types.GenericAlias, Union): for t in Tuple, tuple: with self.subTest(tuple=t, sub=s): @@ -7085,7 +7075,7 @@ class C(Generic[T]): pass self.assertIs(get_origin(Callable), collections.abc.Callable) self.assertIs(get_origin(list[int]), list) self.assertIs(get_origin(list), None) - self.assertIs(get_origin(list | str), types.UnionType) + self.assertIs(get_origin(list | str), Union) self.assertIs(get_origin(P.args), P) self.assertIs(get_origin(P.kwargs), P) self.assertIs(get_origin(Required[int]), Required) @@ -10270,7 +10260,6 @@ def test_special_attrs(self): typing.TypeGuard: 'TypeGuard', typing.TypeIs: 'TypeIs', typing.TypeVar: 'TypeVar', - typing.Union: 'Union', typing.Self: 'Self', # Subscripted special forms typing.Annotated[Any, "Annotation"]: 'Annotated', @@ -10281,7 +10270,7 @@ def test_special_attrs(self): typing.Literal[Any]: 'Literal', typing.Literal[1, 2]: 'Literal', typing.Literal[True, 2]: 'Literal', - typing.Optional[Any]: 'Optional', + typing.Optional[Any]: 'Union', typing.TypeGuard[Any]: 'TypeGuard', typing.TypeIs[Any]: 'TypeIs', typing.Union[Any]: 'Any', @@ -10300,7 +10289,10 @@ def test_special_attrs(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): s = pickle.dumps(cls, proto) loaded = pickle.loads(s) - self.assertIs(cls, loaded) + if isinstance(cls, Union): + self.assertEqual(cls, loaded) + else: + self.assertIs(cls, loaded) TypeName = typing.NewType('SpecialAttrsTests.TypeName', Any) @@ -10575,6 +10567,34 @@ def test_is_not_instance_of_iterable(self): self.assertNotIsInstance(type_to_test, collections.abc.Iterable) +class UnionGenericAliasTests(BaseTestCase): + def test_constructor(self): + # Used e.g. in typer, pydantic + with self.assertWarns(DeprecationWarning): + inst = typing._UnionGenericAlias(typing.Union, (int, str)) + self.assertEqual(inst, int | str) + with self.assertWarns(DeprecationWarning): + # name is accepted but ignored + inst = typing._UnionGenericAlias(typing.Union, (int, None), name="Optional") + self.assertEqual(inst, int | None) + + def test_isinstance(self): + # Used e.g. in pydantic + with self.assertWarns(DeprecationWarning): + self.assertTrue(isinstance(Union[int, str], typing._UnionGenericAlias)) + with self.assertWarns(DeprecationWarning): + self.assertFalse(isinstance(int, typing._UnionGenericAlias)) + + def test_eq(self): + # type(t) == _UnionGenericAlias is used in vyos + with self.assertWarns(DeprecationWarning): + self.assertEqual(Union, typing._UnionGenericAlias) + with self.assertWarns(DeprecationWarning): + self.assertEqual(typing._UnionGenericAlias, typing._UnionGenericAlias) + with self.assertWarns(DeprecationWarning): + self.assertNotEqual(int, typing._UnionGenericAlias) + + def load_tests(loader, tests, pattern): import doctest tests.addTests(doctest.DocTestSuite(typing)) diff --git a/Lib/typing.py b/Lib/typing.py index c924c767042552..7281a625e1e31f 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -29,7 +29,13 @@ import operator import sys import types -from types import GenericAlias +from types import ( + WrapperDescriptorType, + MethodWrapperType, + MethodDescriptorType, + GenericAlias, +) +import warnings from _typing import ( _idfunc, @@ -40,6 +46,7 @@ ParamSpecKwargs, TypeAliasType, Generic, + Union, NoDefault, ) @@ -367,21 +374,6 @@ def _compare_args_orderless(first_args, second_args): return False return not t -def _remove_dups_flatten(parameters): - """Internal helper for Union creation and substitution. - - Flatten Unions among parameters, then remove duplicates. - """ - # Flatten out Union[Union[...], ...]. - params = [] - for p in parameters: - if isinstance(p, (_UnionGenericAlias, types.UnionType)): - params.extend(p.__args__) - else: - params.append(p) - - return tuple(_deduplicate(params, unhashable_fallback=True)) - def _flatten_literal_params(parameters): """Internal helper for Literal creation: flatten Literals among parameters.""" @@ -470,7 +462,7 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f return evaluate_forward_ref(t, globals=globalns, locals=localns, type_params=type_params, owner=owner, _recursive_guard=recursive_guard, format=format) - if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)): + if isinstance(t, (_GenericAlias, GenericAlias, Union)): if isinstance(t, GenericAlias): args = tuple( _make_forward_ref(arg) if isinstance(arg, str) else arg @@ -495,7 +487,7 @@ def _eval_type(t, globalns, localns, type_params=_sentinel, *, recursive_guard=f return t if isinstance(t, GenericAlias): return GenericAlias(t.__origin__, ev_args) - if isinstance(t, types.UnionType): + if isinstance(t, Union): return functools.reduce(operator.or_, ev_args) else: return t.copy_with(ev_args) @@ -749,59 +741,6 @@ class FastConnector(Connection): item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True) return _GenericAlias(self, (item,)) -@_SpecialForm -def Union(self, parameters): - """Union type; Union[X, Y] means either X or Y. - - On Python 3.10 and higher, the | operator - can also be used to denote unions; - X | Y means the same thing to the type checker as Union[X, Y]. - - To define a union, use e.g. Union[int, str]. Details: - - The arguments must be types and there must be at least one. - - None as an argument is a special case and is replaced by - type(None). - - Unions of unions are flattened, e.g.:: - - assert Union[Union[int, str], float] == Union[int, str, float] - - - Unions of a single argument vanish, e.g.:: - - assert Union[int] == int # The constructor actually returns int - - - Redundant arguments are skipped, e.g.:: - - assert Union[int, str, int] == Union[int, str] - - - When comparing unions, the argument order is ignored, e.g.:: - - assert Union[int, str] == Union[str, int] - - - You cannot subclass or instantiate a union. - - You can use Optional[X] as a shorthand for Union[X, None]. - """ - if parameters == (): - raise TypeError("Cannot take a Union of no types.") - if not isinstance(parameters, tuple): - parameters = (parameters,) - msg = "Union[arg, ...]: each arg must be a type." - parameters = tuple(_type_check(p, msg) for p in parameters) - parameters = _remove_dups_flatten(parameters) - if len(parameters) == 1: - return parameters[0] - if len(parameters) == 2 and type(None) in parameters: - return _UnionGenericAlias(self, parameters, name="Optional") - return _UnionGenericAlias(self, parameters) - -def _make_union(left, right): - """Used from the C implementation of TypeVar. - - TypeVar.__or__ calls this instead of returning types.UnionType - because we want to allow unions between TypeVars and strings - (forward references). - """ - return Union[left, right] - @_SpecialForm def Optional(self, parameters): """Optional[X] is equivalent to Union[X, None].""" @@ -1708,41 +1647,30 @@ def __getitem__(self, params): return self.copy_with(params) -class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True): - def copy_with(self, params): - return Union[params] +class _UnionGenericAliasMeta(type): + def __instancecheck__(self, inst: type) -> bool: + warnings._deprecated("_UnionGenericAlias", remove=(3, 17)) + return isinstance(inst, Union) def __eq__(self, other): - if not isinstance(other, (_UnionGenericAlias, types.UnionType)): - return NotImplemented - try: # fast path - return set(self.__args__) == set(other.__args__) - except TypeError: # not hashable, slow path - return _compare_args_orderless(self.__args__, other.__args__) - - def __hash__(self): - return hash(frozenset(self.__args__)) + warnings._deprecated("_UnionGenericAlias", remove=(3, 17)) + if other is _UnionGenericAlias or other is Union: + return True + return NotImplemented - def __repr__(self): - args = self.__args__ - if len(args) == 2: - if args[0] is type(None): - return f'typing.Optional[{_type_repr(args[1])}]' - elif args[1] is type(None): - return f'typing.Optional[{_type_repr(args[0])}]' - return super().__repr__() - def __instancecheck__(self, obj): - return self.__subclasscheck__(type(obj)) +class _UnionGenericAlias(metaclass=_UnionGenericAliasMeta): + """Compatibility hack. - def __subclasscheck__(self, cls): - for arg in self.__args__: - if issubclass(cls, arg): - return True + A class named _UnionGenericAlias used to be used to implement + typing.Union. This class exists to serve as a shim to preserve + the meaning of some code that used to use _UnionGenericAlias + directly. - def __reduce__(self): - func, (origin, args) = super().__reduce__() - return func, (Union, args) + """ + def __new__(cls, self_cls, parameters, /, *, name=None): + warnings._deprecated("_UnionGenericAlias", remove=(3, 17)) + return Union[parameters] def _value_and_type_iter(parameters): @@ -2466,7 +2394,7 @@ def _strip_annotations(t): if stripped_args == t.__args__: return t return GenericAlias(t.__origin__, stripped_args) - if isinstance(t, types.UnionType): + if isinstance(t, Union): stripped_args = tuple(_strip_annotations(a) for a in t.__args__) if stripped_args == t.__args__: return t @@ -2500,8 +2428,8 @@ def get_origin(tp): return tp.__origin__ if tp is Generic: return Generic - if isinstance(tp, types.UnionType): - return types.UnionType + if isinstance(tp, Union): + return Union return None @@ -2526,7 +2454,7 @@ def get_args(tp): if _should_unflatten_callable_args(tp, res): res = (list(res[:-1]), res[-1]) return res - if isinstance(tp, types.UnionType): + if isinstance(tp, Union): return tp.__args__ return () diff --git a/Misc/NEWS.d/next/Library/2023-06-08-07-56-05.gh-issue-105499.7jV6cP.rst b/Misc/NEWS.d/next/Library/2023-06-08-07-56-05.gh-issue-105499.7jV6cP.rst new file mode 100644 index 00000000000000..5240f4aa7d1e9c --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-06-08-07-56-05.gh-issue-105499.7jV6cP.rst @@ -0,0 +1,3 @@ +Make :class:`types.UnionType` an alias for :class:`typing.Union`. Both +``int | str`` and ``Union[int, str]`` now create instances of the same +type. Patch by Jelle Zijlstra. diff --git a/Modules/_typingmodule.c b/Modules/_typingmodule.c index 09fbb3c5e8b91d..e51279c808a2e1 100644 --- a/Modules/_typingmodule.c +++ b/Modules/_typingmodule.c @@ -5,9 +5,10 @@ #endif #include "Python.h" -#include "pycore_interp.h" +#include "internal/pycore_interp.h" +#include "internal/pycore_typevarobject.h" +#include "internal/pycore_unionobject.h" // _PyUnion_Type #include "pycore_pystate.h" // _PyInterpreterState_GET() -#include "pycore_typevarobject.h" #include "clinic/_typingmodule.c.h" /*[clinic input] @@ -63,6 +64,9 @@ _typing_exec(PyObject *m) if (PyModule_AddObjectRef(m, "TypeAliasType", (PyObject *)&_PyTypeAlias_Type) < 0) { return -1; } + if (PyModule_AddObjectRef(m, "Union", (PyObject *)&_PyUnion_Type) < 0) { + return -1; + } if (PyModule_AddObjectRef(m, "NoDefault", (PyObject *)&_Py_NoDefaultStruct) < 0) { return -1; } diff --git a/Objects/typevarobject.c b/Objects/typevarobject.c index 91cc37c9a72636..1a70c41071ffb1 100644 --- a/Objects/typevarobject.c +++ b/Objects/typevarobject.c @@ -2,8 +2,8 @@ #include "Python.h" #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK #include "pycore_typevarobject.h" -#include "pycore_unionobject.h" // _Py_union_type_or - +#include "pycore_unionobject.h" // _Py_union_type_or, _Py_union_from_tuple +#include "structmember.h" /*[clinic input] class typevar "typevarobject *" "&_PyTypeVar_Type" @@ -361,9 +361,13 @@ type_check(PyObject *arg, const char *msg) static PyObject * make_union(PyObject *self, PyObject *other) { - PyObject *args[2] = {self, other}; - PyObject *result = call_typing_func_object("_make_union", args, 2); - return result; + PyObject *args = PyTuple_Pack(2, self, other); + if (args == NULL) { + return NULL; + } + PyObject *u = _Py_union_from_tuple(args); + Py_DECREF(args); + return u; } static PyObject * diff --git a/Objects/unionobject.c b/Objects/unionobject.c index 6e65a653a95c46..065b0b8539775c 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -1,17 +1,17 @@ -// types.UnionType -- used to represent e.g. Union[int, str], int | str +// typing.Union -- used to represent e.g. Union[int, str], int | str #include "Python.h" #include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK #include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr #include "pycore_unionobject.h" -static PyObject *make_union(PyObject *); - - typedef struct { PyObject_HEAD - PyObject *args; + PyObject *args; // all args (tuple) + PyObject *hashable_args; // frozenset or NULL + PyObject *unhashable_args; // tuple or NULL PyObject *parameters; + PyObject *weakreflist; } unionobject; static void @@ -20,8 +20,13 @@ unionobject_dealloc(PyObject *self) unionobject *alias = (unionobject *)self; _PyObject_GC_UNTRACK(self); + if (alias->weakreflist != NULL) { + PyObject_ClearWeakRefs((PyObject *)alias); + } Py_XDECREF(alias->args); + Py_XDECREF(alias->hashable_args); + Py_XDECREF(alias->unhashable_args); Py_XDECREF(alias->parameters); Py_TYPE(self)->tp_free(self); } @@ -31,6 +36,8 @@ union_traverse(PyObject *self, visitproc visit, void *arg) { unionobject *alias = (unionobject *)self; Py_VISIT(alias->args); + Py_VISIT(alias->hashable_args); + Py_VISIT(alias->unhashable_args); Py_VISIT(alias->parameters); return 0; } @@ -39,13 +46,67 @@ static Py_hash_t union_hash(PyObject *self) { unionobject *alias = (unionobject *)self; - PyObject *args = PyFrozenSet_New(alias->args); - if (args == NULL) { - return (Py_hash_t)-1; + // If there are any unhashable args, treat this union as unhashable. + // Otherwise, two unions might compare equal but have different hashes. + if (alias->unhashable_args) { + // Attempt to get an error from one of the values. + assert(PyTuple_CheckExact(alias->unhashable_args)); + Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args); + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i); + Py_hash_t hash = PyObject_Hash(arg); + if (hash == -1) { + return -1; + } + } + // The unhashable values somehow became hashable again. Still raise + // an error. + PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n); + return -1; } - Py_hash_t hash = PyObject_Hash(args); - Py_DECREF(args); - return hash; + return PyObject_Hash(alias->hashable_args); +} + +static int +unions_equal(unionobject *a, unionobject *b) +{ + int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ); + if (result == -1) { + return -1; + } + if (result == 0) { + return 0; + } + if (a->unhashable_args && b->unhashable_args) { + Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args); + if (n != PyTuple_GET_SIZE(b->unhashable_args)) { + return 0; + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i); + int result = PySequence_Contains(b->unhashable_args, arg_a); + if (result == -1) { + return -1; + } + if (!result) { + return 0; + } + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i); + int result = PySequence_Contains(a->unhashable_args, arg_b); + if (result == -1) { + return -1; + } + if (!result) { + return 0; + } + } + } + else if (a->unhashable_args || b->unhashable_args) { + return 0; + } + return 1; } static PyObject * @@ -55,93 +116,128 @@ union_richcompare(PyObject *a, PyObject *b, int op) Py_RETURN_NOTIMPLEMENTED; } - PyObject *a_set = PySet_New(((unionobject*)a)->args); - if (a_set == NULL) { + int equal = unions_equal((unionobject*)a, (unionobject*)b); + if (equal == -1) { return NULL; } - PyObject *b_set = PySet_New(((unionobject*)b)->args); - if (b_set == NULL) { - Py_DECREF(a_set); - return NULL; + if (op == Py_EQ) { + return PyBool_FromLong(equal); + } + else { + return PyBool_FromLong(!equal); } - PyObject *result = PyObject_RichCompare(a_set, b_set, op); - Py_DECREF(b_set); - Py_DECREF(a_set); - return result; } -static int -is_same(PyObject *left, PyObject *right) +typedef struct { + PyObject *args; // list + PyObject *hashable_args; // set + PyObject *unhashable_args; // list or NULL + bool is_checked; // whether to call type_check() +} unionbuilder; + +static bool unionbuilder_add_tuple(unionbuilder *, PyObject *); +static PyObject *make_union(unionbuilder *); +static PyObject *type_check(PyObject *, const char *); + +static bool +unionbuilder_init(unionbuilder *ub, bool is_checked) { - int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right); - return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right; + ub->args = PyList_New(0); + if (ub->args == NULL) { + return false; + } + ub->hashable_args = PySet_New(NULL); + if (ub->hashable_args == NULL) { + Py_DECREF(ub->args); + return false; + } + ub->unhashable_args = NULL; + ub->is_checked = is_checked; + return true; } -static int -contains(PyObject **items, Py_ssize_t size, PyObject *obj) +static void +unionbuilder_finalize(unionbuilder *ub) { - for (Py_ssize_t i = 0; i < size; i++) { - int is_duplicate = is_same(items[i], obj); - if (is_duplicate) { // -1 or 1 - return is_duplicate; - } - } - return 0; + Py_DECREF(ub->args); + Py_DECREF(ub->hashable_args); + Py_XDECREF(ub->unhashable_args); } -static PyObject * -merge(PyObject **items1, Py_ssize_t size1, - PyObject **items2, Py_ssize_t size2) +static bool +unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg) { - PyObject *tuple = NULL; - Py_ssize_t pos = 0; - - for (Py_ssize_t i = 0; i < size2; i++) { - PyObject *arg = items2[i]; - int is_duplicate = contains(items1, size1, arg); - if (is_duplicate < 0) { - Py_XDECREF(tuple); - return NULL; - } - if (is_duplicate) { - continue; + Py_hash_t hash = PyObject_Hash(arg); + if (hash == -1) { + PyErr_Clear(); + if (ub->unhashable_args == NULL) { + ub->unhashable_args = PyList_New(0); + if (ub->unhashable_args == NULL) { + return false; + } } - - if (tuple == NULL) { - tuple = PyTuple_New(size1 + size2 - i); - if (tuple == NULL) { - return NULL; + else { + int contains = PySequence_Contains(ub->unhashable_args, arg); + if (contains < 0) { + return false; } - for (; pos < size1; pos++) { - PyObject *a = items1[pos]; - PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a)); + if (contains == 1) { + return true; } } - PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg)); - pos++; + if (PyList_Append(ub->unhashable_args, arg) < 0) { + return false; + } } - - if (tuple) { - (void) _PyTuple_Resize(&tuple, pos); + else { + int contains = PySet_Contains(ub->hashable_args, arg); + if (contains < 0) { + return false; + } + if (contains == 1) { + return true; + } + if (PySet_Add(ub->hashable_args, arg) < 0) { + return false; + } } - return tuple; + return PyList_Append(ub->args, arg) == 0; } -static PyObject ** -get_types(PyObject **obj, Py_ssize_t *size) +static bool +unionbuilder_add_single(unionbuilder *ub, PyObject *arg) { - if (*obj == Py_None) { - *obj = (PyObject *)&_PyNone_Type; + if (Py_IsNone(arg)) { + arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed } - if (_PyUnion_Check(*obj)) { - PyObject *args = ((unionobject *) *obj)->args; - *size = PyTuple_GET_SIZE(args); - return &PyTuple_GET_ITEM(args, 0); + else if (_PyUnion_Check(arg)) { + PyObject *args = ((unionobject *)arg)->args; + return unionbuilder_add_tuple(ub, args); + } + if (ub->is_checked) { + PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type."); + if (type == NULL) { + return false; + } + bool result = unionbuilder_add_single_unchecked(ub, type); + Py_DECREF(type); + return result; } else { - *size = 1; - return obj; + return unionbuilder_add_single_unchecked(ub, arg); + } +} + +static bool +unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple) +{ + Py_ssize_t n = PyTuple_GET_SIZE(tuple); + for (Py_ssize_t i = 0; i < n; i++) { + if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) { + return false; + } } + return true; } static int @@ -164,19 +260,18 @@ _Py_union_type_or(PyObject* self, PyObject* other) Py_RETURN_NOTIMPLEMENTED; } - Py_ssize_t size1, size2; - PyObject **items1 = get_types(&self, &size1); - PyObject **items2 = get_types(&other, &size2); - PyObject *tuple = merge(items1, size1, items2, size2); - if (tuple == NULL) { - if (PyErr_Occurred()) { - return NULL; - } - return Py_NewRef(self); + unionbuilder ub; + // unchecked because we already checked is_unionable() + if (!unionbuilder_init(&ub, false)) { + return NULL; + } + if (!unionbuilder_add_single(&ub, self) || + !unionbuilder_add_single(&ub, other)) { + unionbuilder_finalize(&ub); + return NULL; } - PyObject *new_union = make_union(tuple); - Py_DECREF(tuple); + PyObject *new_union = make_union(&ub); return new_union; } @@ -202,6 +297,18 @@ union_repr(PyObject *self) goto error; } } + +#if 0 + PyUnicodeWriter_WriteUTF8(writer, "|args=", 6); + PyUnicodeWriter_WriteRepr(writer, alias->args); + PyUnicodeWriter_WriteUTF8(writer, "|h=", 3); + PyUnicodeWriter_WriteRepr(writer, alias->hashable_args); + if (alias->unhashable_args) { + PyUnicodeWriter_WriteUTF8(writer, "|u=", 3); + PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args); + } +#endif + return PyUnicodeWriter_Finish(writer); error: @@ -231,21 +338,7 @@ union_getitem(PyObject *self, PyObject *item) return NULL; } - PyObject *res; - Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); - if (nargs == 0) { - res = make_union(newargs); - } - else { - res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0)); - for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) { - PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); - Py_SETREF(res, PyNumber_Or(res, arg)); - if (res == NULL) { - break; - } - } - } + PyObject *res = _Py_union_from_tuple(newargs); Py_DECREF(newargs); return res; } @@ -267,7 +360,25 @@ union_parameters(PyObject *self, void *Py_UNUSED(unused)) return Py_NewRef(alias->parameters); } +static PyObject * +union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored)) +{ + return PyUnicode_FromString("Union"); +} + +static PyObject * +union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored)) +{ + return Py_NewRef(&_PyUnion_Type); +} + static PyGetSetDef union_properties[] = { + {"__name__", union_name, NULL, + PyDoc_STR("Name of the type"), NULL}, + {"__qualname__", union_name, NULL, + PyDoc_STR("Qualified name of the type"), NULL}, + {"__origin__", union_origin, NULL, + PyDoc_STR("Always returns the type"), NULL}, {"__parameters__", union_parameters, (setter)NULL, PyDoc_STR("Type variables in the types.UnionType."), NULL}, {0} @@ -306,10 +417,88 @@ _Py_union_args(PyObject *self) return ((unionobject *) self)->args; } +static PyObject * +call_typing_func_object(const char *name, PyObject **args, size_t nargs) +{ + PyObject *typing = PyImport_ImportModule("typing"); + if (typing == NULL) { + return NULL; + } + PyObject *func = PyObject_GetAttrString(typing, name); + if (func == NULL) { + Py_DECREF(typing); + return NULL; + } + PyObject *result = PyObject_Vectorcall(func, args, nargs, NULL); + Py_DECREF(func); + Py_DECREF(typing); + return result; +} + +static PyObject * +type_check(PyObject *arg, const char *msg) +{ + if (Py_IsNone(arg)) { + // NoneType is immortal, so don't need an INCREF + return (PyObject *)Py_TYPE(arg); + } + // Fast path to avoid calling into typing.py + if (is_unionable(arg)) { + return Py_NewRef(arg); + } + PyObject *message_str = PyUnicode_FromString(msg); + if (message_str == NULL) { + return NULL; + } + PyObject *args[2] = {arg, message_str}; + PyObject *result = call_typing_func_object("_type_check", args, 2); + Py_DECREF(message_str); + return result; +} + +PyObject * +_Py_union_from_tuple(PyObject *args) +{ + unionbuilder ub; + if (!unionbuilder_init(&ub, true)) { + return NULL; + } + if (PyTuple_CheckExact(args)) { + if (!unionbuilder_add_tuple(&ub, args)) { + return NULL; + } + } + else { + if (!unionbuilder_add_single(&ub, args)) { + return NULL; + } + } + return make_union(&ub); +} + +static PyObject * +union_class_getitem(PyObject *cls, PyObject *args) +{ + return _Py_union_from_tuple(args); +} + +static PyObject * +union_mro_entries(PyObject *self, PyObject *args) +{ + return PyErr_Format(PyExc_TypeError, + "Cannot subclass %R", self); +} + +static PyMethodDef union_methods[] = { + {"__mro_entries__", union_mro_entries, METH_O}, + {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")}, + {0} +}; + PyTypeObject _PyUnion_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - .tp_name = "types.UnionType", - .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n" + .tp_name = "typing.Union", + .tp_doc = PyDoc_STR("Represent a union type\n" "\n" "E.g. for int | str"), .tp_basicsize = sizeof(unionobject), @@ -321,25 +510,64 @@ PyTypeObject _PyUnion_Type = { .tp_hash = union_hash, .tp_getattro = union_getattro, .tp_members = union_members, + .tp_methods = union_methods, .tp_richcompare = union_richcompare, .tp_as_mapping = &union_as_mapping, .tp_as_number = &union_as_number, .tp_repr = union_repr, .tp_getset = union_properties, + .tp_weaklistoffset = offsetof(unionobject, weakreflist), }; static PyObject * -make_union(PyObject *args) +make_union(unionbuilder *ub) { - assert(PyTuple_CheckExact(args)); + Py_ssize_t n = PyList_GET_SIZE(ub->args); + if (n == 0) { + PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types."); + unionbuilder_finalize(ub); + return NULL; + } + if (n == 1) { + PyObject *result = PyList_GET_ITEM(ub->args, 0); + Py_INCREF(result); + unionbuilder_finalize(ub); + return result; + } + + PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL; + args = PyList_AsTuple(ub->args); + if (args == NULL) { + goto error; + } + hashable_args = PyFrozenSet_New(ub->hashable_args); + if (hashable_args == NULL) { + goto error; + } + if (ub->unhashable_args != NULL) { + unhashable_args = PyList_AsTuple(ub->unhashable_args); + if (unhashable_args == NULL) { + goto error; + } + } unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type); if (result == NULL) { - return NULL; + goto error; } + unionbuilder_finalize(ub); result->parameters = NULL; - result->args = Py_NewRef(args); + result->args = args; + result->hashable_args = hashable_args; + result->unhashable_args = unhashable_args; + result->weakreflist = NULL; _PyObject_GC_TRACK(result); return (PyObject*)result; +error: + Py_XDECREF(args); + Py_XDECREF(hashable_args); + Py_XDECREF(unhashable_args); + unionbuilder_finalize(ub); + return NULL; }