Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport CPython PR 26067 #132

Merged
merged 12 commits into from
Apr 12, 2023
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
- Add `typing_extensions.Buffer`, a marker class for buffer types, as proposed
by PEP 688. Equivalent to `collections.abc.Buffer` in Python 3.12. Patch by
Jelle Zijlstra.
- Backport [CPython PR 26067](https://github.com/python/cpython/pull/26067)
(originally by Yurii Karabas), ensuring that `isinstance()` calls on
protocols raise `TypeError` when the protocol is not decorated with
`@runtime_checkable`. Patch by Alex Waygood.

# Release 4.5.0 (February 14, 2023)

Expand Down
31 changes: 25 additions & 6 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,22 @@ class E(C, BP): pass
self.assertNotIsInstance(D(), E)
self.assertNotIsInstance(E(), D)

@skipUnless(
hasattr(typing, "Protocol"),
"Test is only relevant if typing.Protocol exists"
)
def test_runtimecheckable_on_typing_dot_Protocol(self):
@runtime_checkable
class Foo(typing.Protocol):
x: int

class Bar:
def __init__(self):
self.x = 42

self.assertIsInstance(Bar(), Foo)
self.assertNotIsInstance(object(), Foo)

def test_no_instantiation(self):
class P(Protocol): pass
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -1829,11 +1845,7 @@ def meth(self):
self.assertTrue(P._is_protocol)
self.assertTrue(PR._is_protocol)
self.assertTrue(PG._is_protocol)
if hasattr(typing, 'Protocol'):
self.assertFalse(P._is_runtime_protocol)
else:
with self.assertRaises(AttributeError):
self.assertFalse(P._is_runtime_protocol)
self.assertFalse(P._is_runtime_protocol)
self.assertTrue(PR._is_runtime_protocol)
self.assertTrue(PG[int]._is_protocol)
self.assertEqual(typing_extensions._get_protocol_attrs(P), {'meth'})
Expand Down Expand Up @@ -1929,6 +1941,13 @@ class CustomProtocol(TestCase, Protocol):
class CustomContextManager(typing.ContextManager, Protocol):
pass

def test_non_runtime_protocol_isinstance_check(self):
class P(Protocol):
x: int

with self.assertRaisesRegex(TypeError, "@runtime_checkable"):
isinstance(1, P)

def test_no_init_same_for_different_protocol_implementations(self):
class CustomProtocolWithoutInitA(Protocol):
pass
Expand Down Expand Up @@ -3314,7 +3333,7 @@ def test_typing_extensions_defers_when_possible(self):
'is_typeddict',
}
if sys.version_info < (3, 10):
exclude |= {'get_args', 'get_origin'}
exclude |= {'get_args', 'get_origin', 'Protocol', 'runtime_checkable'}
if sys.version_info < (3, 11):
exclude |= {'final', 'NamedTuple', 'Any'}
for item in typing_extensions.__all__:
Expand Down
69 changes: 47 additions & 22 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,21 +398,33 @@ def clear_overloads():
}


_EXCLUDED_ATTRS = {
"__abstractmethods__", "__annotations__", "__weakref__", "_is_protocol",
"_is_runtime_protocol", "__dict__", "__slots__", "__parameters__",
"__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__",
"__subclasshook__", "__orig_class__", "__init__", "__new__",
}

if sys.version_info < (3, 8):
_EXCLUDED_ATTRS |= {
"_gorg", "__next_in_mro__", "__extra__", "__tree_hash__", "__args__",
"__origin__"
}

if sys.version_info >= (3, 9):
_EXCLUDED_ATTRS.add("__class_getitem__")

_EXCLUDED_ATTRS = frozenset(_EXCLUDED_ATTRS)


def _get_protocol_attrs(cls):
attrs = set()
for base in cls.__mro__[:-1]: # without object
if base.__name__ in ('Protocol', 'Generic'):
continue
annotations = getattr(base, '__annotations__', {})
for attr in list(base.__dict__.keys()) + list(annotations.keys()):
if (not attr.startswith('_abc_') and attr not in (
'__abstractmethods__', '__annotations__', '__weakref__',
'_is_protocol', '_is_runtime_protocol', '__dict__',
'__args__', '__slots__',
'__next_in_mro__', '__parameters__', '__origin__',
'__orig_bases__', '__extra__', '__tree_hash__',
'__doc__', '__subclasshook__', '__init__', '__new__',
'__module__', '_MutableMapping__marker', '_gorg')):
if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS):
attrs.add(attr)
return attrs

Expand Down Expand Up @@ -468,11 +480,18 @@ def _caller(depth=2):
return None


# 3.8+
if hasattr(typing, 'Protocol'):
# A bug in runtime-checkable protocols was fixed in 3.10+,
# but we backport it to all versions
if sys.version_info >= (3, 10):
Protocol = typing.Protocol
# 3.7
runtime_checkable = typing.runtime_checkable
else:
def _allow_reckless_class_checks(depth=4):
"""Allow instance and class checks for special stdlib modules.
The abc and functools modules indiscriminately call isinstance() and
issubclass() on the whole MRO of a user class, which may contain protocols.
"""
return _caller(depth) in {'abc', 'functools', None}

def _no_init(self, *args, **kwargs):
if type(self)._is_protocol:
Expand All @@ -484,11 +503,19 @@ class _ProtocolMeta(abc.ABCMeta):
def __instancecheck__(cls, instance):
# We need this method for situations where attributes are
# assigned in __init__.
if ((not getattr(cls, '_is_protocol', False) or
is_protocol_cls = getattr(cls, "_is_protocol", False)
if (
is_protocol_cls and
not getattr(cls, '_is_runtime_protocol', False) and
not _allow_reckless_class_checks(depth=2)
):
raise TypeError("Instance and class checks can only be used with"
" @runtime_checkable protocols")
if ((not is_protocol_cls or
_is_callable_members_only(cls)) and
issubclass(instance.__class__, cls)):
return True
if cls._is_protocol:
if is_protocol_cls:
if all(hasattr(instance, attr) and
(not callable(getattr(cls, attr, None)) or
getattr(instance, attr) is not None)
Expand Down Expand Up @@ -530,6 +557,7 @@ def meth(self) -> T:
"""
__slots__ = ()
_is_protocol = True
_is_runtime_protocol = False

def __new__(cls, *args, **kwds):
if cls is Protocol:
Expand Down Expand Up @@ -581,12 +609,12 @@ def _proto_hook(other):
if not cls.__dict__.get('_is_protocol', None):
return NotImplemented
if not getattr(cls, '_is_runtime_protocol', False):
if _caller(depth=3) in {'abc', 'functools'}:
if _allow_reckless_class_checks():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it correct that the depth is changed from 3 to 4?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because by adding a new function that the code has to pass through before it gets to the sys._getframe call, the call stack becomes "another frame deep"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could just have looked at _caller and figured that out myself. 🤦

return NotImplemented
raise TypeError("Instance and class checks can only be used with"
" @runtime protocols")
if not _is_callable_members_only(cls):
if _caller(depth=3) in {'abc', 'functools'}:
if _allow_reckless_class_checks():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

return NotImplemented
raise TypeError("Protocols with non-method members"
" don't support issubclass()")
Expand Down Expand Up @@ -625,12 +653,6 @@ def _proto_hook(other):
f' protocols, got {repr(base)}')
cls.__init__ = _no_init


# 3.8+
if hasattr(typing, 'runtime_checkable'):
runtime_checkable = typing.runtime_checkable
# 3.7
else:
def runtime_checkable(cls):
"""Mark a protocol class as a runtime protocol, so that it
can be used with isinstance() and issubclass(). Raise TypeError
Expand All @@ -639,7 +661,10 @@ def runtime_checkable(cls):
This allows a simple-minded structural check very similar to the
one-offs in collections.abc such as Hashable.
"""
if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol:
if not (
(isinstance(cls, _ProtocolMeta) or issubclass(cls, typing.Generic))
and getattr(cls, "_is_protocol", False)
):
raise TypeError('@runtime_checkable can be only applied to protocol classes,'
f' got {cls!r}')
cls._is_runtime_protocol = True
Expand Down