diff --git a/CHANGELOG.md b/CHANGELOG.md index db6719c..81b4b95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ subscripted objects) had wrong parameters if they were directly subscripted with an `Unpack` object. Patch by [Daraan](https://github.com/Daraan). +- Fix backport of `get_type_hints` to reflect Python 3.11+ behavior which does not add + `Union[..., NoneType]` to annotations that have a `None` default value anymore. + This fixes wrapping of `Annotated` in an unwanted `Optional` in such cases. + Patch by [Daraan](https://github.com/Daraan). # Release 4.12.2 (June 7, 2024) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 8c2726f..3552dae 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -1645,6 +1645,70 @@ def test_final_forward_ref(self): self.assertNotEqual(gth(Loop, globals())['attr'], Final[int]) self.assertNotEqual(gth(Loop, globals())['attr'], Final) + def test_annotation_and_optional_default(self): + annotation = Annotated[Union[int, None], "data"] + optional_annotation = Optional[annotation] + + cases = { + # (annotation, skip_as_str): expected_type_hints + # Should skip_as_str if contains a ForwardRef. + ((), True): {}, + (int, True): {"x": int}, + ("int", True): {"x": int}, + (Optional[int], False): {"x": Optional[int]}, + (optional_annotation, False): {"x": optional_annotation}, + (annotation, False): {"x": annotation}, + (Union[annotation, T], True): {"x": Union[annotation, T]}, + ("Union[Annotated[Union[int, None], 'data'], T]", True): { + "x": Union[annotation, T] + }, + (Union[str, None, str], False): {"x": Optional[str]}, + (Union[str, None, "str"], True): {"x": Optional[str]}, + (Union[str, "str"], True): {"x": str}, + (List["str"], True): {"x": List[str]}, + (Optional[List[str]], False): {"x": Optional[List[str]]}, + (Tuple[Unpack[Tuple[int, str]]], False): { + "x": Tuple[Unpack[Tuple[int, str]]] + }, + } + for ((annot, skip_as_str), expected), none_default, as_str, wrap_optional in itertools.product( + cases.items(), (False, True), (False, True), (False, True) + ): + if wrap_optional: + if annot == (): + continue + if (get_origin(annot) is not Optional + or (sys.version_info[:2] == (3, 8) and annot._name != "Optional") + ): + annot = Optional[annot] + expected = {"x": Optional[expected['x']]} + if as_str: + if skip_as_str or annot == (): + continue + annot = str(annot) + with self.subTest( + annotation=annot, + as_str=as_str, + none_default=none_default, + expected_type_hints=expected, + wrap_optional=wrap_optional, + ): + # Create function to check + if annot == (): + if none_default: + def func(x=None): pass + else: + def func(x): pass + elif none_default: + def func(x: annot = None): pass + else: + def func(x: annot): pass + type_hints = get_type_hints(func, include_extras=True) + self.assertEqual(type_hints, expected) + for k in type_hints.keys(): + self.assertEqual(hash(type_hints[k]), hash(expected[k])) + self.assertEqual(str(type_hints)+repr(type_hints), str(expected)+repr(type_hints)) + class GetUtilitiesTestCase(TestCase): def test_get_origin(self): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 5bf4f2d..483066a 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -1236,10 +1236,83 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): ) else: # 3.8 hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) + if sys.version_info < (3, 11): + _clean_optional(obj, hint, globalns, localns) + if sys.version_info < (3, 9): + # In 3.8 eval_type does not handle all Optional[ForwardRef] correctly + # this also returns cached versions of Union + hint = { + k: (t + if get_origin(t) != Union + else Union[t.__args__]) + for k, t in hint.items() + } if include_extras: return hint return {k: _strip_extras(t) for k, t in hint.items()} + _NoneType = type(None) + + def _could_be_inserted_optional(t): + """detects Union[..., None] pattern""" + # 3.8+ compatible checking before _UnionGenericAlias + if get_origin(t) is not Union: + return False + # Assume if last argument is not None they are user defined + if t.__args__[-1] is not _NoneType: + return False + return True + + # < 3.11 + def _clean_optional(obj, hints, globalns=None, localns=None): + # reverts injected Union[..., None] cases from typing.get_type_hints + # when a None default value is used. + # see https://github.com/python/typing_extensions/issues/310 + if not hints or isinstance(obj, type): + return + defaults = typing._get_defaults(obj) # avoid accessing __annotations___ + if not defaults: + return + original_hints = obj.__annotations__ + for name, value in hints.items(): + # Not a Union[..., None] or replacement conditions not fullfilled + if (not _could_be_inserted_optional(value) + or name not in defaults + or defaults[name] is not None + ): + continue + original_value = original_hints[name] + if original_value is None: # should not happen + original_value = _NoneType + # Forward reference + if isinstance(original_value, str): + if globalns is None: + if isinstance(obj, _types.ModuleType): + globalns = obj.__dict__ + else: + nsobj = obj + # Find globalns for the unwrapped object. + while hasattr(nsobj, '__wrapped__'): + nsobj = nsobj.__wrapped__ + globalns = getattr(nsobj, '__globals__', {}) + if localns is None: + localns = globalns + elif localns is None: + localns = globalns + if sys.version_info < (3, 9): + original_value = ForwardRef(original_value) + else: + original_value = ForwardRef( + original_value, + is_argument=not isinstance(obj, _types.ModuleType) + ) + original_evaluated = typing._eval_type(original_value, globalns, localns) + if sys.version_info < (3, 9) and get_origin(original_evaluated) is Union: + # Union[str, None, "str"] is not reduced to Union[str, None] + original_evaluated = Union[original_evaluated.__args__] + # Compare if values differ + if original_evaluated != value: + hints[name] = original_evaluated # Python 3.9+ has PEP 593 (Annotated) if hasattr(typing, 'Annotated'):