diff --git a/astroid/inference.py b/astroid/inference.py index bc3e1f9701..b297d3eef0 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -368,6 +368,10 @@ def infer_subscript(self, context=None): if value is util.Uninferable: yield util.Uninferable return None + if isinstance(value, nodes.ClassDef): + if value.is_subtype_of('typing.Generic'): + yield value + return None for index in self.slice.infer(context): if index is util.Uninferable: yield util.Uninferable diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index e267f97022..816b011b2a 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -374,6 +374,30 @@ class A(B): #@ self.assertIs(a2_ancestors[0], b) self.assertIs(a2_ancestors[1], a1) + @pytest.mark.skipif(sys.version_info < (3, 5), reason="Needs 'typing' module") + def test_ancestors_generic(self): + code = """ + from typing import Generic, TypeVar + + T = TypeVar('T') + + class A(Generic[T]): #@ + pass + + class B(A[T]): #@ + pass + + class C(B[int]): #@ + pass + """ + a, b, c = extract_node(code, __name__) + ancestors = list(c.ancestors()) + self.assertEqual(len(ancestors), 4) + self.assertIs(ancestors[0], b) + self.assertIs(ancestors[1], a) + self.assertEqual(ancestors[2].name, 'Generic') + self.assertEqual(ancestors[3].name, 'object') + def test_f_arg_f(self): code = """ def f(f=1):