Skip to content

Commit

Permalink
Prohibit referring to class within its definition
Browse files Browse the repository at this point in the history
This partially, but not completely, resolves
python#3088

Previously, code of the following form was accepted by mypy:

    class A:
        def foo(self) -> A:
            ...

This results in an error at runtime because 'A' is not yet defined.
This pull request modifies the parsing process to report an error when
a type is being incorrectly used within its class definition.

However, this commit does *not* attempt to handle other cases where
the user is using a type that is undefined at runtime. For example, mypy
will continue to accept the following code without an error:

    def f() -> A:
        ...

    class A:
        ...

I decided it would probably be best to address this in a future pull
request.
  • Loading branch information
Michael0x2a committed Jul 5, 2017
1 parent 0d045cd commit 8279c20
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 53 deletions.
66 changes: 53 additions & 13 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ def parse_type_comment(type_comment: str, line: int, errors: Optional[Errors]) -
raise
else:
assert isinstance(typ, ast3.Expression)
return TypeConverter(errors, line=line).visit(typ.body)

# parse_type_comments() is meant to be used on types within strings or comments, so
# there's no need to check if the class is currently being defined or not. It also
# doesn't matter if we're using stub files or not.
return TypeConverter(errors, set(), line=line).visit(typ.body)


def with_line(f: Callable[['ASTConverter', T], U]) -> Callable[['ASTConverter', T], U]:
Expand Down Expand Up @@ -142,7 +146,7 @@ def __init__(self,
options: Options,
is_stub: bool,
errors: Errors) -> None:
self.class_nesting = 0
self.classes_being_defined = [set()] # type: List[Set[str]]
self.imports = [] # type: List[ImportBase]

self.options = options
Expand All @@ -152,6 +156,14 @@ def __init__(self,
def fail(self, msg: str, line: int, column: int) -> None:
self.errors.report(line, column, msg)

def convert_to_type(self, node: ast3.AST, lineno: int, skip_class_check: bool = False) -> Type:
if skip_class_check or self.is_stub:
classes = set() # type: Set[str]
else:
classes = self.classes_being_defined[-1]

return TypeConverter(self.errors, classes, line=lineno).visit(node)

def generic_visit(self, node: ast3.AST) -> None:
raise RuntimeError('AST node not implemented: ' + str(type(node)))

Expand Down Expand Up @@ -254,7 +266,21 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
return ret

def in_class(self) -> bool:
return self.class_nesting > 0
return len(self.classes_being_defined[-1]) > 0

def enter_function_body(self) -> None:
# When defining a method, the body is not processed until
# after the containing class is fully defined, so we reset
# the set of classes being defined since to record that we
# can refer to our parent class directly, without needing
# forward references.
#
# If this is a regular function, not a method, pushing an
# empty set is a harmless no-op.
self.classes_being_defined.append(set())

def leave_function_body(self) -> None:
self.classes_being_defined.pop()

def translate_module_id(self, id: str) -> str:
"""Return the actual, internal module id for a source text id.
Expand Down Expand Up @@ -326,12 +352,12 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
# PEP 484 disallows both type annotations and type comments
if n.returns or any(a.type_annotation is not None for a in args):
self.fail(messages.DUPLICATE_TYPE_SIGNATURES, n.lineno, n.col_offset)
translated_args = (TypeConverter(self.errors, line=n.lineno)
translated_args = (TypeConverter(self.errors, set(), line=n.lineno)
.translate_expr_list(func_type_ast.argtypes))
arg_types = [a if a is not None else AnyType()
for a in translated_args]
return_type = TypeConverter(self.errors,
line=n.lineno).visit(func_type_ast.returns)
return_type = TypeConverter(
self.errors, set(), line=n.lineno).visit(func_type_ast.returns)

# add implicit self type
if self.in_class() and len(arg_types) < len(args):
Expand All @@ -342,8 +368,9 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
return_type = AnyType()
else:
arg_types = [a.type_annotation for a in args]
return_type = TypeConverter(self.errors, line=n.returns.lineno
if n.returns else n.lineno).visit(n.returns)
return_type = self.convert_to_type(
n.returns,
n.returns.lineno if n.returns else n.lineno)

for arg, arg_type in zip(args, arg_types):
self.set_type_optional(arg_type, arg.initializer)
Expand All @@ -366,10 +393,13 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
AnyType(implicit=True),
None)

self.enter_function_body()
func_def = FuncDef(n.name,
args,
self.as_block(n.body, n.lineno),
func_type)
self.leave_function_body()

if is_coroutine:
# A coroutine is also a generator, mostly for internal reasons.
func_def.is_generator = func_def.is_coroutine = True
Expand Down Expand Up @@ -410,7 +440,7 @@ def make_argument(arg: ast3.arg, default: Optional[ast3.expr], kind: int) -> Arg
self.fail(messages.DUPLICATE_TYPE_SIGNATURES, arg.lineno, arg.col_offset)
arg_type = None
if arg.annotation is not None:
arg_type = TypeConverter(self.errors, line=arg.lineno).visit(arg.annotation)
arg_type = self.convert_to_type(arg.annotation, arg.lineno)
elif arg.type_comment is not None:
arg_type = parse_type_comment(arg.type_comment, arg.lineno, self.errors)
return Argument(Var(arg.arg), arg_type, self.visit(default), kind)
Expand Down Expand Up @@ -460,7 +490,7 @@ def fail_arg(msg: str, arg: ast3.arg) -> None:
# expr* decorator_list)
@with_line
def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
self.class_nesting += 1
self.classes_being_defined[-1].add(n.name)
metaclass_arg = find(lambda x: x.arg == 'metaclass', n.keywords)
metaclass = None
if metaclass_arg:
Expand All @@ -477,7 +507,7 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
metaclass=metaclass,
keywords=keywords)
cdef.decorators = self.translate_expr_list(n.decorator_list)
self.class_nesting -= 1
self.classes_being_defined[-1].remove(n.name)
return cdef

# Return(expr? value)
Expand Down Expand Up @@ -513,7 +543,7 @@ def visit_AnnAssign(self, n: ast3.AnnAssign) -> AssignmentStmt:
rvalue = TempNode(AnyType()) # type: Expression
else:
rvalue = self.visit(n.value)
typ = TypeConverter(self.errors, line=n.lineno).visit(n.annotation)
typ = self.convert_to_type(n.annotation, n.lineno)
typ.column = n.annotation.col_offset
return AssignmentStmt([self.visit(n.target)], rvalue, type=typ, new_syntax=True)

Expand Down Expand Up @@ -961,11 +991,18 @@ def visit_Index(self, n: ast3.Index) -> Node:


class TypeConverter(ast3.NodeTransformer): # type: ignore # typeshed PR #931
def __init__(self, errors: Errors, line: int = -1) -> None:
def __init__(self,
errors: Errors,
classes_being_defined: Set[str],
line: int = -1) -> None:
self.errors = errors
self.classes_being_defined = classes_being_defined
self.line = line
self.node_stack = [] # type: List[ast3.AST]

def _definition_is_incomplete(self, name: str) -> bool:
return name in self.classes_being_defined

def visit(self, node: ast3.AST) -> Type:
"""Modified visit -- keep track of the stack of nodes"""
self.node_stack.append(node)
Expand Down Expand Up @@ -1049,6 +1086,9 @@ def _extract_argument_name(self, n: ast3.expr) -> str:
return None

def visit_Name(self, n: ast3.Name) -> Type:
if self._definition_is_incomplete(n.id):
self.fail("class '{}' is not fully defined; use a forward reference".format(n.id),
n.lineno, n.col_offset)
return UnboundType(n.id, line=self.line)

def visit_NameConstant(self, n: ast3.NameConstant) -> Type:
Expand Down
4 changes: 2 additions & 2 deletions mypy/fastparse2.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def visit_Module(self, mod: ast27.Module) -> MypyFile:
# arg? kwarg, expr* defaults)
@with_line
def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
converter = TypeConverter(self.errors, line=n.lineno)
converter = TypeConverter(self.errors, set(), line=n.lineno)
args, decompose_stmts = self.transform_args(n.args, n.lineno)

arg_kinds = [arg.kind for arg in args]
Expand Down Expand Up @@ -378,7 +378,7 @@ def transform_args(self,
# TODO: remove the cast once https://github.com/python/typeshed/pull/522
# is accepted and synced
type_comments = cast(List[str], n.type_comments) # type: ignore
converter = TypeConverter(self.errors, line=line)
converter = TypeConverter(self.errors, set(), line=line)
decompose_stmts = [] # type: List[Statement]

def extract_names(arg: ast27.expr) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-class-namedtuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ class XRepr(NamedTuple):
y: int = 1
def __str__(self) -> str:
return 'string'
def __add__(self, other: XRepr) -> int:
def __add__(self, other: 'XRepr') -> int:
return 0

reveal_type(XMeth(1).double()) # E: Revealed type is 'builtins.int'
Expand Down
28 changes: 14 additions & 14 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1477,24 +1477,24 @@ from typing import Any
def deco(f: Any) -> Any: return f
class C:
@deco
def __add__(self, other: C) -> C: return C()
def __radd__(self, other: C) -> C: return C()
def __add__(self, other: 'C') -> 'C': return C()
def __radd__(self, other: 'C') -> 'C': return C()
[out]

[case testReverseOperatorMethodForwardIsAny2]
from typing import Any
def deco(f: Any) -> Any: return f
class C:
__add__ = None # type: Any
def __radd__(self, other: C) -> C: return C()
def __radd__(self, other: 'C') -> 'C': return C()
[out]

[case testReverseOperatorMethodForwardIsAny3]
from typing import Any
def deco(f: Any) -> Any: return f
class C:
__add__ = 42
def __radd__(self, other: C) -> C: return C()
def __radd__(self, other: 'C') -> 'C': return C()
[out]
main:5: error: Forward operator "__add__" is not callable

Expand Down Expand Up @@ -1631,7 +1631,7 @@ main:8: error: Signatures of "__iadd__" and "__add__" are incompatible

a, b = None, None # type: A, B
class A:
def __getattribute__(self, x: str) -> A:
def __getattribute__(self, x: str) -> 'A':
return A()
class B: pass

Expand All @@ -1642,11 +1642,11 @@ main:9: error: Incompatible types in assignment (expression has type "A", variab

[case testGetAttributeSignature]
class A:
def __getattribute__(self, x: str) -> A: pass
def __getattribute__(self, x: str) -> 'A': pass
class B:
def __getattribute__(self, x: A) -> B: pass
def __getattribute__(self, x: A) -> 'B': pass
class C:
def __getattribute__(self, x: str, y: str) -> C: pass
def __getattribute__(self, x: str, y: str) -> 'C': pass
class D:
def __getattribute__(self, x: str) -> None: pass
[out]
Expand All @@ -1657,7 +1657,7 @@ main:6: error: Invalid signature "def (__main__.C, builtins.str, builtins.str) -

a, b = None, None # type: A, B
class A:
def __getattr__(self, x: str) -> A:
def __getattr__(self, x: str) -> 'A':
return A()
class B: pass

Expand All @@ -1668,11 +1668,11 @@ main:9: error: Incompatible types in assignment (expression has type "A", variab

[case testGetAttrSignature]
class A:
def __getattr__(self, x: str) -> A: pass
def __getattr__(self, x: str) -> 'A': pass
class B:
def __getattr__(self, x: A) -> B: pass
def __getattr__(self, x: A) -> 'B': pass
class C:
def __getattr__(self, x: str, y: str) -> C: pass
def __getattr__(self, x: str, y: str) -> 'C': pass
class D:
def __getattr__(self, x: str) -> None: pass
[out]
Expand Down Expand Up @@ -1776,7 +1776,7 @@ a = a(b) # E: Argument 1 to "__call__" of "A" has incompatible type "B"; expect
b = a(a) # E: Incompatible types in assignment (expression has type "A", variable has type "B")

class A:
def __call__(self, x: A) -> A:
def __call__(self, x: 'A') -> 'A':
pass
class B: pass

Expand Down Expand Up @@ -3280,7 +3280,7 @@ def r(ta: Type[TA], tta: TTA) -> None:

class Class(metaclass=M):
@classmethod
def f1(cls: Type[Class]) -> None: pass
def f1(cls: Type['Class']) -> None: pass
@classmethod
def f2(cls: M) -> None: pass
cl: Type[Class] = m # E: Incompatible types in assignment (expression has type "M", variable has type Type[Class])
Expand Down
16 changes: 8 additions & 8 deletions test-data/unit/check-selftype.test
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class A:
pass

class C(A):
def copy(self: C) -> C:
def copy(self: 'C') -> 'C':
pass

class D(A):
Expand Down Expand Up @@ -276,10 +276,10 @@ class B:
return cls()

class C:
def foo(self: C) -> C: return self
def foo(self: 'C') -> 'C': return self

@classmethod
def cfoo(cls: Type[C]) -> C:
def cfoo(cls: Type['C']) -> 'C':
return cls()

class D:
Expand Down Expand Up @@ -330,21 +330,21 @@ class B:
pass

class C:
def __new__(cls: Type[C]) -> C:
def __new__(cls: Type['C']) -> 'C':
return cls()

def __init_subclass__(cls: Type[C]) -> None:
def __init_subclass__(cls: Type['C']) -> None:
pass

class D:
def __new__(cls: D) -> D: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
def __new__(cls: 'D') -> 'D': # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
return cls

def __init_subclass__(cls: D) -> None: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
def __init_subclass__(cls: 'D') -> None: # E: The erased type of self '__main__.D' is not a supertype of its class 'Type[__main__.D]'
pass

class E:
def __new__(cls) -> E:
def __new__(cls) -> 'E':
reveal_type(cls) # E: Revealed type is 'def () -> __main__.E'
return cls()

Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class A(object):
self.a = 0

def __iadd__(self, a):
# type: (int) -> A
# type: (int) -> 'A'
self.a += 1
return self

Expand Down
18 changes: 9 additions & 9 deletions test-data/unit/check-typevar-values.test
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ f(S())
[case testCheckGenericFunctionBodyWithTypeVarValues]
from typing import TypeVar
class A:
def f(self, x: int) -> A: return self
def f(self, x: int) -> 'A': return self
class B:
def f(self, x: int) -> B: return self
def f(self, x: int) -> 'B': return self
AB = TypeVar('AB', A, B)
def f(x: AB) -> AB:
x = x.f(1)
Expand All @@ -58,11 +58,11 @@ def f(x: AB) -> AB:
[case testCheckGenericFunctionBodyWithTypeVarValues2]
from typing import TypeVar
class A:
def f(self) -> A: return A()
def g(self) -> B: return B()
def f(self) -> 'A': return A()
def g(self) -> 'B': return B()
class B:
def f(self) -> A: return A()
def g(self) -> B: return B()
def g(self) -> 'B': return B()
AB = TypeVar('AB', A, B)
def f(x: AB) -> AB:
return x.f() # Error
Expand All @@ -75,11 +75,11 @@ main:12: error: Incompatible return value type (got "B", expected "A")
[case testTypeInferenceAndTypeVarValues]
from typing import TypeVar
class A:
def f(self) -> A: return self
def g(self) -> B: return B()
def f(self) -> 'A': return self
def g(self) -> 'B': return B()
class B:
def f(self) -> B: return self
def g(self) -> B: return B()
def f(self) -> 'B': return self
def g(self) -> 'B': return B()
AB = TypeVar('AB', A, B)
def f(x: AB) -> AB:
y = x
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/fine-grained.test
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class A:
def g(self) -> None: pass
[file m.py.2]
class A:
def g(self, a: A) -> None: pass
def g(self, a: 'A') -> None: pass
[out]
==
main:4: error: Too few arguments for "g" of "A"
Expand Down
Loading

0 comments on commit 8279c20

Please sign in to comment.