Skip to content

Commit

Permalink
If a generic class is subscripted, infer that class itself
Browse files Browse the repository at this point in the history
  • Loading branch information
mthuurne committed Apr 27, 2020
1 parent c44fa4e commit cbffd45
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
4 changes: 4 additions & 0 deletions astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,10 @@ def infer_subscript(self, context=None):
if value is util.Uninferable:
yield util.Uninferable
return None
if isinstance(value, nodes.ClassDef):
if value.is_subtype_of('typing.Generic'):
yield value
return None
for index in self.slice.infer(context):
if index is util.Uninferable:
yield util.Uninferable
Expand Down
24 changes: 24 additions & 0 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,30 @@ class A(B): #@
self.assertIs(a2_ancestors[0], b)
self.assertIs(a2_ancestors[1], a1)

@pytest.mark.skipif(sys.version_info < (3, 5), reason="Needs 'typing' module")
def test_ancestors_generic(self):
code = """
from typing import Generic, TypeVar
T = TypeVar('T')
class A(Generic[T]): #@
pass
class B(A[T]): #@
pass
class C(B[int]): #@
pass
"""
a, b, c = extract_node(code, __name__)
ancestors = list(c.ancestors())
self.assertEqual(len(ancestors), 4)
self.assertIs(ancestors[0], b)
self.assertIs(ancestors[1], a)
self.assertEqual(ancestors[2].name, 'Generic')
self.assertEqual(ancestors[3].name, 'object')

def test_f_arg_f(self):
code = """
def f(f=1):
Expand Down

0 comments on commit cbffd45

Please sign in to comment.