Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.9] bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) #23335

Merged
merged 1 commit into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def test_repr(self):
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
self.assertEqual(repr(Literal), "typing.Literal")
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")

def test_cannot_init(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -563,6 +564,30 @@ def test_no_multiple_subscripts(self):
with self.assertRaises(TypeError):
Literal[1][1]

def test_equal(self):
self.assertNotEqual(Literal[0], Literal[False])
self.assertNotEqual(Literal[True], Literal[1])
self.assertNotEqual(Literal[1], Literal[2])
self.assertNotEqual(Literal[1, True], Literal[1])
self.assertEqual(Literal[1], Literal[1])
self.assertEqual(Literal[1, 2], Literal[2, 1])
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])

def test_args(self):
self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
# Mutable arguments will not be deduplicated
self.assertEqual(Literal[[], []].__args__, ([], []))

def test_flatten(self):
l1 = Literal[Literal[1], Literal[2], Literal[3]]
l2 = Literal[Literal[1, 2], 3]
l3 = Literal[Literal[1, 2, 3]]
for l in l1, l2, l3:
self.assertEqual(l, Literal[1, 2, 3])
self.assertEqual(l.__args__, (1, 2, 3))


XK = TypeVar('XK', str, bytes)
XV = TypeVar('XV')
Expand Down
100 changes: 77 additions & 23 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,20 @@ def _check_generic(cls, parameters, elen):
f" actual {alen}, expected {elen}")


def _deduplicate(params):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params


def _remove_dups_flatten(parameters):
"""An internal helper for Union creation and substitution: flatten Unions
among parameters, then remove duplicates.
Expand All @@ -213,38 +227,45 @@ def _remove_dups_flatten(parameters):
params.extend(p[1:])
else:
params.append(p)
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params

return tuple(_deduplicate(params))


def _flatten_literal_params(parameters):
"""An internal helper for Literal creation: flatten Literals among parameters"""
params = []
for p in parameters:
if isinstance(p, _LiteralGenericAlias):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)


_cleanups = []


def _tp_cache(func):
def _tp_cache(func=None, /, *, typed=False):
"""Internal wrapper caching __getitem__ of generic types with a fallback to
original function for non-hashable arguments.
"""
cached = functools.lru_cache()(func)
_cleanups.append(cached.cache_clear)
def decorator(func):
cached = functools.lru_cache(typed=typed)(func)
_cleanups.append(cached.cache_clear)

@functools.wraps(func)
def inner(*args, **kwds):
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
return inner
@functools.wraps(func)
def inner(*args, **kwds):
try:
return cached(*args, **kwds)
except TypeError:
pass # All real errors (not unhashable args) are raised below.
return func(*args, **kwds)
return inner

if func is not None:
return decorator(func)

return decorator

def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t.
Expand Down Expand Up @@ -317,6 +338,13 @@ def __subclasscheck__(self, cls):
def __getitem__(self, parameters):
return self._getitem(self, parameters)


class _LiteralSpecialForm(_SpecialForm, _root=True):
@_tp_cache(typed=True)
def __getitem__(self, parameters):
return self._getitem(self, parameters)


@_SpecialForm
def Any(self, parameters):
"""Special type indicating an unconstrained type.
Expand Down Expand Up @@ -434,7 +462,7 @@ def Optional(self, parameters):
arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(None)]

@_SpecialForm
@_LiteralSpecialForm
def Literal(self, parameters):
"""Special typing form to define literal types (a.k.a. value types).

Expand All @@ -458,7 +486,17 @@ def open_helper(file: str, mode: MODE) -> str:
"""
# There is no '_type_check' call because arguments to Literal[...] are
# values, not types.
return _GenericAlias(self, parameters)
if not isinstance(parameters, tuple):
parameters = (parameters,)

parameters = _flatten_literal_params(parameters)

try:
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
except TypeError: # unhashable parameters
pass

return _LiteralGenericAlias(self, parameters)


class ForwardRef(_Final, _root=True):
Expand Down Expand Up @@ -881,6 +919,22 @@ def __repr__(self):
return super().__repr__()


def _value_and_type_iter(parameters):
return ((p, type(p)) for p in parameters)


class _LiteralGenericAlias(_GenericAlias, _root=True):

def __eq__(self, other):
if not isinstance(other, _LiteralGenericAlias):
return NotImplemented

return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))

def __hash__(self):
return hash(tuple(_value_and_type_iter(self.__args__)))


class Generic:
"""Abstract base class for generic types.

Expand Down
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ Jan Kanis
Rafe Kaplan
Jacob Kaplan-Moss
Allison Kaptur
Yurii Karabas
Janne Karila
Per Øyvind Karlsen
Anton Kasyanov
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix various issues with ``typing.Literal`` parameter handling (flatten,
deduplicate, use type to cache key). Patch provided by Yurii Karabas.