Skip to content

Commit

Permalink
Don't type check most function bodies if ignoring errors (#14150)
Browse files Browse the repository at this point in the history
If errors are ignored, type checking function bodies often can have no
effect. Remove function bodies after parsing to speed up type checking.

Methods that define attributes have an externally visible effect even if
errors are ignored. The body of any method that assigns to any attribute
is preserved to deal with this (even if it doesn't actually define a new
attribute). Most methods don't assign to an attribute, so stripping
bodies is still effective for methods.

There are a couple of additional interesting things in the
implementation:

1. We need to know whether an abstract method has a trivial body (e.g.
just `...`) to check `super()` method calls. The approach here is to
preserve such trivial bodies and treat them differently from no body at
all.
2. Stubgen analyzes the bodies of functions to e.g. infer some return
types. As a workaround, explicitly preserve full ASTs when using
stubgen.

The main benefit is faster type checking when using installed packages
with inline type information (PEP 561). Errors are ignored in this case,
and it's common to have a large number of third-party code to type
check. For example, a self check (code under `mypy/`) is now about **20%
faster**, with a compiled mypy on Python 3.11.

Another, more subtle benefit is improved reliability. A third-party
library may have some code that triggers a mypy crash or an invalid
blocking error. If bodies are stripped, often the error will no longer
be triggered, since the amount code to type check is much lower.

---------

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
  • Loading branch information
JukkaL and hauntsaninja authored Apr 24, 2023
1 parent bdac4bc commit aee983e
Show file tree
Hide file tree
Showing 9 changed files with 692 additions and 26 deletions.
2 changes: 2 additions & 0 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,8 @@ def parse_file(
Raise CompileError if there is a parse error.
"""
t0 = time.time()
if ignore_errors:
self.errors.ignored_files.add(path)
tree = parse(source, path, id, self.errors, options=options)
tree._fullname = id
self.add_stats(
Expand Down
5 changes: 1 addition & 4 deletions mypy/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,7 @@ def split_directive(s: str) -> tuple[list[str], list[str]]:


def mypy_comments_to_config_map(line: str, template: Options) -> tuple[dict[str, str], list[str]]:
"""Rewrite the mypy comment syntax into ini file syntax.
Returns
"""
"""Rewrite the mypy comment syntax into ini file syntax."""
options = {}
entries, errors = split_directive(line)
for entry in entries:
Expand Down
179 changes: 166 additions & 13 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
)
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
from mypy.sharedparse import argument_elide_name, special_function_elide_names
from mypy.traverser import TraverserVisitor
from mypy.types import (
AnyType,
CallableArgument,
Expand Down Expand Up @@ -260,6 +261,11 @@ def parse(
Return the parse tree. If errors is not provided, raise ParseError
on failure. Otherwise, use the errors object to report parse errors.
"""
ignore_errors = (options is not None and options.ignore_errors) or (
errors is not None and fnam in errors.ignored_files
)
# If errors are ignored, we can drop many function bodies to speed up type checking.
strip_function_bodies = ignore_errors and (options is None or not options.preserve_asts)
raise_on_error = False
if options is None:
options = Options()
Expand All @@ -281,7 +287,13 @@ def parse(
warnings.filterwarnings("ignore", category=DeprecationWarning)
ast = ast3_parse(source, fnam, "exec", feature_version=feature_version)

tree = ASTConverter(options=options, is_stub=is_stub_file, errors=errors).visit(ast)
tree = ASTConverter(
options=options,
is_stub=is_stub_file,
errors=errors,
ignore_errors=ignore_errors,
strip_function_bodies=strip_function_bodies,
).visit(ast)
tree.path = fnam
tree.is_stub = is_stub_file
except SyntaxError as e:
Expand Down Expand Up @@ -400,14 +412,24 @@ def is_no_type_check_decorator(expr: ast3.expr) -> bool:


class ASTConverter:
def __init__(self, options: Options, is_stub: bool, errors: Errors) -> None:
# 'C' for class, 'F' for function
self.class_and_function_stack: list[Literal["C", "F"]] = []
def __init__(
self,
options: Options,
is_stub: bool,
errors: Errors,
*,
ignore_errors: bool,
strip_function_bodies: bool,
) -> None:
# 'C' for class, 'D' for function signature, 'F' for function, 'L' for lambda
self.class_and_function_stack: list[Literal["C", "D", "F", "L"]] = []
self.imports: list[ImportBase] = []

self.options = options
self.is_stub = is_stub
self.errors = errors
self.ignore_errors = ignore_errors
self.strip_function_bodies = strip_function_bodies

self.type_ignores: dict[int, list[str]] = {}

Expand Down Expand Up @@ -475,7 +497,12 @@ def get_lineno(self, node: ast3.expr | ast3.stmt) -> int:
return node.lineno

def translate_stmt_list(
self, stmts: Sequence[ast3.stmt], ismodule: bool = False
self,
stmts: Sequence[ast3.stmt],
*,
ismodule: bool = False,
can_strip: bool = False,
is_coroutine: bool = False,
) -> list[Statement]:
# A "# type: ignore" comment before the first statement of a module
# ignores the whole module:
Expand Down Expand Up @@ -504,11 +531,41 @@ def translate_stmt_list(
mark_block_unreachable(block)
return [block]

stack = self.class_and_function_stack
if self.strip_function_bodies and len(stack) == 1 and stack[0] == "F":
return []

res: list[Statement] = []
for stmt in stmts:
node = self.visit(stmt)
res.append(node)

if (
self.strip_function_bodies
and can_strip
and stack[-2:] == ["C", "F"]
and not is_possible_trivial_body(res)
):
# We only strip method bodies if they don't assign to an attribute, as
# this may define an attribute which has an externally visible effect.
visitor = FindAttributeAssign()
for s in res:
s.accept(visitor)
if visitor.found:
break
else:
if is_coroutine:
# Yields inside an async function affect the return type and should not
# be stripped.
yield_visitor = FindYield()
for s in res:
s.accept(yield_visitor)
if yield_visitor.found:
break
else:
return []
else:
return []
return res

def translate_type_comment(
Expand Down Expand Up @@ -573,9 +630,20 @@ def as_block(self, stmts: list[ast3.stmt], lineno: int) -> Block | None:
b.set_line(lineno)
return b

def as_required_block(self, stmts: list[ast3.stmt], lineno: int) -> Block:
def as_required_block(
self,
stmts: list[ast3.stmt],
lineno: int,
*,
can_strip: bool = False,
is_coroutine: bool = False,
) -> Block:
assert stmts # must be non-empty
b = Block(self.fix_function_overloads(self.translate_stmt_list(stmts)))
b = Block(
self.fix_function_overloads(
self.translate_stmt_list(stmts, can_strip=can_strip, is_coroutine=is_coroutine)
)
)
# TODO: in most call sites line is wrong (includes first line of enclosing statement)
# TODO: also we need to set the column, and the end position here.
b.set_line(lineno)
Expand Down Expand Up @@ -831,9 +899,6 @@ def _is_stripped_if_stmt(self, stmt: Statement) -> bool:
# For elif, IfStmt are stored recursively in else_body
return self._is_stripped_if_stmt(stmt.else_body.body[0])

def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ["C", "F"]

def translate_module_id(self, id: str) -> str:
"""Return the actual, internal module id for a source text id."""
if id == self.options.custom_typing_module:
Expand Down Expand Up @@ -868,7 +933,7 @@ def do_func_def(
self, n: ast3.FunctionDef | ast3.AsyncFunctionDef, is_coroutine: bool = False
) -> FuncDef | Decorator:
"""Helper shared between visit_FunctionDef and visit_AsyncFunctionDef."""
self.class_and_function_stack.append("F")
self.class_and_function_stack.append("D")
no_type_check = bool(
n.decorator_list and any(is_no_type_check_decorator(d) for d in n.decorator_list)
)
Expand Down Expand Up @@ -915,7 +980,8 @@ def do_func_def(
return_type = TypeConverter(self.errors, line=lineno).visit(func_type_ast.returns)

# add implicit self type
if self.in_method_scope() and len(arg_types) < len(args):
in_method_scope = self.class_and_function_stack[-2:] == ["C", "D"]
if in_method_scope and len(arg_types) < len(args):
arg_types.insert(0, AnyType(TypeOfAny.special_form))
except SyntaxError:
stripped_type = n.type_comment.split("#", 2)[0].strip()
Expand Down Expand Up @@ -965,7 +1031,10 @@ def do_func_def(
end_line = getattr(n, "end_lineno", None)
end_column = getattr(n, "end_col_offset", None)

func_def = FuncDef(n.name, args, self.as_required_block(n.body, lineno), func_type)
self.class_and_function_stack.pop()
self.class_and_function_stack.append("F")
body = self.as_required_block(n.body, lineno, can_strip=True, is_coroutine=is_coroutine)
func_def = FuncDef(n.name, args, body, func_type)
if isinstance(func_def.type, CallableType):
# semanal.py does some in-place modifications we want to avoid
func_def.unanalyzed_type = func_def.type.copy_modified()
Expand Down Expand Up @@ -1409,9 +1478,11 @@ def visit_Lambda(self, n: ast3.Lambda) -> LambdaExpr:
body.lineno = n.body.lineno
body.col_offset = n.body.col_offset

self.class_and_function_stack.append("L")
e = LambdaExpr(
self.transform_args(n.args, n.lineno), self.as_required_block([body], n.lineno)
)
self.class_and_function_stack.pop()
e.set_line(n.lineno, n.col_offset) # Overrides set_line -- can't use self.set_line
return e

Expand Down Expand Up @@ -2081,3 +2152,85 @@ def stringify_name(n: AST) -> str | None:
if sv is not None:
return f"{sv}.{n.attr}"
return None # Can't do it.


class FindAttributeAssign(TraverserVisitor):
"""Check if an AST contains attribute assignments (e.g. self.x = 0)."""

def __init__(self) -> None:
self.lvalue = False
self.found = False

def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
self.lvalue = True
for lv in s.lvalues:
lv.accept(self)
self.lvalue = False

def visit_with_stmt(self, s: WithStmt) -> None:
self.lvalue = True
for lv in s.target:
if lv is not None:
lv.accept(self)
self.lvalue = False
s.body.accept(self)

def visit_for_stmt(self, s: ForStmt) -> None:
self.lvalue = True
s.index.accept(self)
self.lvalue = False
s.body.accept(self)
if s.else_body:
s.else_body.accept(self)

def visit_expression_stmt(self, s: ExpressionStmt) -> None:
# No need to look inside these
pass

def visit_call_expr(self, e: CallExpr) -> None:
# No need to look inside these
pass

def visit_index_expr(self, e: IndexExpr) -> None:
# No need to look inside these
pass

def visit_member_expr(self, e: MemberExpr) -> None:
if self.lvalue:
self.found = True


class FindYield(TraverserVisitor):
"""Check if an AST contains yields or yield froms."""

def __init__(self) -> None:
self.found = False

def visit_yield_expr(self, e: YieldExpr) -> None:
self.found = True

def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
self.found = True


def is_possible_trivial_body(s: list[Statement]) -> bool:
"""Could the statements form a "trivial" function body, such as 'pass'?
This mimics mypy.semanal.is_trivial_body, but this runs before
semantic analysis so some checks must be conservative.
"""
l = len(s)
if l == 0:
return False
i = 0
if isinstance(s[0], ExpressionStmt) and isinstance(s[0].expr, StrExpr):
# Skip docstring
i += 1
if i == l:
return True
if l > i + 1:
return False
stmt = s[i]
return isinstance(stmt, (PassStmt, RaiseStmt)) or (
isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
)
11 changes: 9 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6684,7 +6684,7 @@ def is_trivial_body(block: Block) -> bool:
"..." (ellipsis), or "raise NotImplementedError()". A trivial body may also
start with a statement containing just a string (e.g. a docstring).
Note: functions that raise other kinds of exceptions do not count as
Note: Functions that raise other kinds of exceptions do not count as
"trivial". We use this function to help us determine when it's ok to
relax certain checks on body, but functions that raise arbitrary exceptions
are more likely to do non-trivial work. For example:
Expand All @@ -6694,11 +6694,18 @@ def halt(self, reason: str = ...) -> NoReturn:
A function that raises just NotImplementedError is much less likely to be
this complex.
Note: If you update this, you may also need to update
mypy.fastparse.is_possible_trivial_body!
"""
body = block.body
if not body:
# Functions have empty bodies only if the body is stripped or the function is
# generated or deserialized. In these cases the body is unknown.
return False

# Skip a docstring
if body and isinstance(body[0], ExpressionStmt) and isinstance(body[0].expr, StrExpr):
if isinstance(body[0], ExpressionStmt) and isinstance(body[0].expr, StrExpr):
body = block.body[1:]

if len(body) == 0:
Expand Down
1 change: 1 addition & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,7 @@ def mypy_options(stubgen_options: Options) -> MypyOptions:
options.python_version = stubgen_options.pyversion
options.show_traceback = True
options.transform_source = remove_misplaced_type_comments
options.preserve_asts = True
return options


Expand Down
15 changes: 10 additions & 5 deletions mypy/test/testparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from pytest import skip

from mypy import defaults
from mypy.config_parser import parse_mypy_comments
from mypy.errors import CompileError
from mypy.options import Options
from mypy.parse import parse
from mypy.test.data import DataDrivenTestCase, DataSuite
from mypy.test.helpers import assert_string_arrays_equal, find_test_files, parse_options
from mypy.util import get_mypy_comments


class ParserSuite(DataSuite):
Expand Down Expand Up @@ -40,13 +42,16 @@ def test_parser(testcase: DataDrivenTestCase) -> None:
else:
options.python_version = defaults.PYTHON3_VERSION

source = "\n".join(testcase.input)

# Apply mypy: comments to options.
comments = get_mypy_comments(source)
changes, _ = parse_mypy_comments(comments, options)
options = options.apply_changes(changes)

try:
n = parse(
bytes("\n".join(testcase.input), "ascii"),
fnam="main",
module="__main__",
errors=None,
options=options,
bytes(source, "ascii"), fnam="main", module="__main__", errors=None, options=options
)
a = n.str_with_options(options).split("\n")
except CompileError as e:
Expand Down
Loading

0 comments on commit aee983e

Please sign in to comment.