From 5ccce79096f1e42a2cc391c4b84c47d17caea735 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 2 Jul 2024 06:23:56 -0700 Subject: [PATCH] More types (#483) --- bugbear.py | 180 ++++++++++++++++++++++++++--------------------------- 1 file changed, 89 insertions(+), 91 deletions(-) diff --git a/bugbear.py b/bugbear.py index cff5551..96052aa 100644 --- a/bugbear.py +++ b/bugbear.py @@ -12,7 +12,7 @@ from contextlib import suppress from functools import lru_cache, partial from keyword import iskeyword -from typing import Dict, Iterable, Iterator, List, Set, Union +from typing import Dict, Iterable, Iterator, List, Sequence, Set, Union import attr import pycodestyle # type: ignore[import-untyped] @@ -128,7 +128,7 @@ def adapt_error(cls, e): """Adapts the extended error namedtuple to be compatible with Flake8.""" return e._replace(message=e.message.format(*e.vars))[:4] - def load_file(self): + def load_file(self) -> None: """Loads the file in a way that auto-detects source encoding and deals with broken terminal encodings for stdin. @@ -145,7 +145,7 @@ def load_file(self): self.tree = ast.parse("".join(self.lines)) @staticmethod - def add_options(optmanager): + def add_options(optmanager) -> None: """Informs flake8 to ignore B9xx by default.""" optmanager.extend_default_ignore(disabled_by_default) optmanager.add_option( @@ -170,7 +170,7 @@ def add_options(optmanager): ) @lru_cache # noqa: B019 - def should_warn(self, code): + def should_warn(self, code) -> bool: """Returns `True` if Bugbear should emit a particular warning. flake8 overrides default ignores when the user specifies @@ -217,7 +217,7 @@ def should_warn(self, code): return False -def _is_identifier(arg): +def _is_identifier(arg) -> bool: # Return True if arg is a valid identifier, per # https://docs.python.org/2/reference/lexical_analysis.html#identifiers @@ -243,7 +243,7 @@ def _flatten_excepthandler(node: ast.expr | None) -> Iterator[ast.expr | None]: yield expr -def _check_redundant_excepthandlers(names, node): +def _check_redundant_excepthandlers(names: Sequence[str], node): # See if any of the given exception names could be removed, e.g. from: # (MyError, MyError) # duplicate names # (MyError, BaseException) # everything derives from the Base @@ -380,7 +380,7 @@ class BugBearVisitor(ast.NodeVisitor): if False: # Useful for tracing what the hell is going on. - def __getattr__(self, name): + def __getattr__(self, name: str): print(name) return self.__getattribute__(name) @@ -416,7 +416,7 @@ def visit_YieldFrom(self, node: ast.YieldFrom) -> None: self.errors.append(B037(node.lineno, node.col_offset)) self.generic_visit(node) - def visit(self, node): + def visit(self, node) -> None: is_contextful = isinstance(node, CONTEXTFUL_NODES) if is_contextful: @@ -463,14 +463,14 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None: self.errors.append(B040(node.lineno, node.col_offset)) self.b040_caught_exception = old_b040_caught_exception - def visit_UAdd(self, node): + def visit_UAdd(self, node) -> None: trailing_nodes = list(map(type, self.node_window[-4:])) if trailing_nodes == [ast.UnaryOp, ast.UAdd, ast.UnaryOp, ast.UAdd]: originator = self.node_window[-4] self.errors.append(B002(originator.lineno, originator.col_offset)) self.generic_visit(node) - def visit_Call(self, node): + def visit_Call(self, node) -> None: is_b040_add_note = False if isinstance(node.func, ast.Attribute): self.check_for_b005(node) @@ -517,7 +517,7 @@ def visit_Call(self, node): # e.g. `e.add_note(str(e))` self.b040_caught_exception = current_b040_caught_exception - def visit_Module(self, node): + def visit_Module(self, node) -> None: self.generic_visit(node) def visit_Assign(self, node: ast.Assign) -> None: @@ -530,7 +530,7 @@ def visit_Assign(self, node: ast.Assign) -> None: self.errors.append(B003(node.lineno, node.col_offset)) self.generic_visit(node) - def visit_For(self, node): + def visit_For(self, node) -> None: self.check_for_b007(node) self.check_for_b020(node) self.check_for_b023(node) @@ -538,41 +538,41 @@ def visit_For(self, node): self.check_for_b909(node) self.generic_visit(node) - def visit_AsyncFor(self, node): + def visit_AsyncFor(self, node) -> None: self.check_for_b023(node) self.generic_visit(node) - def visit_While(self, node): + def visit_While(self, node) -> None: self.check_for_b023(node) self.generic_visit(node) - def visit_ListComp(self, node): + def visit_ListComp(self, node) -> None: self.check_for_b023(node) self.generic_visit(node) - def visit_SetComp(self, node): + def visit_SetComp(self, node) -> None: self.check_for_b023(node) self.generic_visit(node) - def visit_DictComp(self, node): + def visit_DictComp(self, node) -> None: self.check_for_b023(node) self.check_for_b035(node) self.generic_visit(node) - def visit_GeneratorExp(self, node): + def visit_GeneratorExp(self, node) -> None: self.check_for_b023(node) self.generic_visit(node) - def visit_Assert(self, node): + def visit_Assert(self, node) -> None: self.check_for_b011(node) self.generic_visit(node) - def visit_AsyncFunctionDef(self, node): + def visit_AsyncFunctionDef(self, node) -> None: self.check_for_b902(node) self.check_for_b006_and_b008(node) self.generic_visit(node) - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node) -> None: self.check_for_b901(node) self.check_for_b902(node) self.check_for_b006_and_b008(node) @@ -581,22 +581,22 @@ def visit_FunctionDef(self, node): self.check_for_b906(node) self.generic_visit(node) - def visit_ClassDef(self, node: ast.ClassDef): + def visit_ClassDef(self, node: ast.ClassDef) -> None: self.check_for_b903(node) self.check_for_b021(node) self.check_for_b024_and_b027(node) self.generic_visit(node) - def visit_Try(self, node): + def visit_Try(self, node) -> None: self.check_for_b012(node) self.check_for_b025(node) self.generic_visit(node) - def visit_Compare(self, node): + def visit_Compare(self, node) -> None: self.check_for_b015(node) self.generic_visit(node) - def visit_Raise(self, node: ast.Raise): + def visit_Raise(self, node: ast.Raise) -> None: if node.exc is None: self.b040_caught_exception = None else: @@ -606,40 +606,40 @@ def visit_Raise(self, node: ast.Raise): self.check_for_b904(node) self.generic_visit(node) - def visit_With(self, node): + def visit_With(self, node) -> None: self.check_for_b017(node) self.check_for_b022(node) self.check_for_b908(node) self.generic_visit(node) - def visit_JoinedStr(self, node): + def visit_JoinedStr(self, node) -> None: self.check_for_b907(node) self.generic_visit(node) - def visit_AnnAssign(self, node): + def visit_AnnAssign(self, node) -> None: self.check_for_b032(node) self.check_for_b040_usage(node.value) self.generic_visit(node) - def visit_Import(self, node): + def visit_Import(self, node) -> None: self.check_for_b005(node) self.generic_visit(node) - def visit_ImportFrom(self, node): + def visit_ImportFrom(self, node) -> None: self.visit_Import(node) - def visit_Set(self, node): + def visit_Set(self, node) -> None: self.check_for_b033(node) self.generic_visit(node) - def check_for_b005(self, node): + def check_for_b005(self, node) -> None: if isinstance(node, ast.Import): for name in node.names: self._b005_imports.add(name.asname or name.name) elif isinstance(node, ast.ImportFrom): for name in node.names: self._b005_imports.add(f"{node.module}.{name.name or name.asname}") - elif isinstance(node, ast.Call): + elif isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): if node.func.attr not in B005_METHODS: return # method name doesn't match @@ -665,14 +665,14 @@ def check_for_b005(self, node): self.errors.append(B005(node.lineno, node.col_offset)) - def check_for_b006_and_b008(self, node): + def check_for_b006_and_b008(self, node) -> None: visitor = FunctionDefDefaultsVisitor( B006, B008, self.b008_b039_extend_immutable_calls ) visitor.visit(node.args.defaults + node.args.kw_defaults) self.errors.extend(visitor.errors) - def check_for_b039(self, node: ast.Call): + def check_for_b039(self, node: ast.Call) -> None: if not ( (isinstance(node.func, ast.Name) and node.func.id == "ContextVar") or ( @@ -696,7 +696,7 @@ def check_for_b039(self, node: ast.Call): visitor.visit(kw.value) self.errors.extend(visitor.errors) - def check_for_b007(self, node): + def check_for_b007(self, node) -> None: targets = NameFinder() targets.visit(node.target) ctrl_names = set(filter(lambda s: not s.startswith("_"), targets.names)) @@ -708,12 +708,12 @@ def check_for_b007(self, node): n = targets.names[name][0] self.errors.append(B007(n.lineno, n.col_offset, vars=(name,))) - def check_for_b011(self, node): + def check_for_b011(self, node) -> None: if isinstance(node.test, ast.Constant) and node.test.value is False: self.errors.append(B011(node.lineno, node.col_offset)) - def check_for_b012(self, node): - def _loop(node, bad_node_types): + def check_for_b012(self, node) -> None: + def _loop(node, bad_node_types) -> None: if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef)): return @@ -763,11 +763,11 @@ def check_for_b013_b029_b030(self, node: ast.ExceptHandler) -> list[str]: self.errors.append(maybe_error) return names - def check_for_b015(self, node): + def check_for_b015(self, node) -> None: if isinstance(self.node_stack[-2], ast.Expr): self.errors.append(B015(node.lineno, node.col_offset)) - def check_for_b016(self, node): + def check_for_b016(self, node) -> None: if isinstance(node.exc, ast.JoinedStr) or ( isinstance(node.exc, ast.Constant) and ( @@ -777,7 +777,7 @@ def check_for_b016(self, node): ): self.errors.append(B016(node.lineno, node.col_offset)) - def check_for_b017(self, node): + def check_for_b017(self, node) -> None: """Checks for use of the evil syntax 'with assertRaises(Exception):' or 'with pytest.raises(Exception)'. @@ -821,7 +821,7 @@ def check_for_b017(self, node): ): self.errors.append(B017(node.lineno, node.col_offset)) - def check_for_b019(self, node): + def check_for_b019(self, node) -> None: if ( len(node.decorator_list) == 0 or len(self.contexts) < 2 @@ -847,7 +847,7 @@ def check_for_b019(self, node): ) return - def check_for_b020(self, node): + def check_for_b020(self, node) -> None: targets = NameFinder() targets.visit(node.target) ctrl_names = set(targets.names) @@ -861,7 +861,7 @@ def check_for_b020(self, node): n = targets.names[name][0] self.errors.append(B020(n.lineno, n.col_offset, vars=(name,))) - def check_for_b023(self, loop_node): # noqa: C901 + def check_for_b023(self, loop_node) -> None: # noqa: C901 """Check that functions (including lambdas) do not use loop variables. https://docs.python-guide.org/writing/gotchas/#late-binding-closures from @@ -940,11 +940,11 @@ def check_for_b023(self, loop_node): # noqa: C901 if reassigned_in_loop.issuperset(err.vars): self.errors.append(err) - def check_for_b024_and_b027(self, node: ast.ClassDef): # noqa: C901 + def check_for_b024_and_b027(self, node: ast.ClassDef) -> None: # noqa: C901 """Check for inheritance from abstract classes in abc and lack of any methods decorated with abstract*""" - def is_abc_class(value, name="ABC"): + def is_abc_class(value, name: str = "ABC"): # class foo(metaclass = [abc.]ABCMeta) if isinstance(value, ast.keyword): return value.arg == "metaclass" and is_abc_class(value.value, "ABCMeta") @@ -1022,7 +1022,7 @@ def is_str_or_ellipsis(node): if has_method and not has_abstract_method: self.errors.append(B024(node.lineno, node.col_offset, vars=(node.name,))) - def check_for_b026(self, call: ast.Call): + def check_for_b026(self, call: ast.Call) -> None: if not call.keywords: return @@ -1038,7 +1038,7 @@ def check_for_b026(self, call: ast.Call): ): self.errors.append(B026(starred.lineno, starred.col_offset)) - def check_for_b031(self, loop_node): # noqa: C901 + def check_for_b031(self, loop_node) -> None: # noqa: C901 """Check that `itertools.groupby` isn't iterated over more than once. We emit a warning when the generator returned by `groupby()` is used @@ -1107,7 +1107,7 @@ def _get_dict_comp_loop_and_named_expr_var_names(self, node: ast.DictComp): yield from finder.names.keys() - def check_for_b035(self, node: ast.DictComp): + def check_for_b035(self, node: ast.DictComp) -> None: """Check that a static key isn't used in a dict comprehension. Emit a warning if a likely unchanging key is used - either a constant, @@ -1161,7 +1161,7 @@ def _get_assigned_names(self, loop_node): if isinstance(node, loop_targets + (ast.AnnAssign, ast.AugAssign)): yield from names_from_assignments(node.target) - def check_for_b904(self, node): + def check_for_b904(self, node) -> None: """Checks `raise` without `from` inside an `except` clause. In these cases, you should use explicit exception chaining from the @@ -1308,7 +1308,7 @@ def is_classmethod(decorators: Set[str]) -> bool: B902(lineno, col, vars=(actual_first_arg, kind, expected_first_args[0])) ) - def check_for_b903(self, node): + def check_for_b903(self, node) -> None: body = node.body if ( body @@ -1339,7 +1339,7 @@ def check_for_b903(self, node): self.errors.append(B903(node.lineno, node.col_offset)) - def check_for_b018(self, node): + def check_for_b018(self, node) -> None: if not isinstance(node, ast.Expr): return if isinstance( @@ -1365,7 +1365,7 @@ def check_for_b018(self, node): ) ) - def check_for_b021(self, node): + def check_for_b021(self, node) -> None: if ( node.body and isinstance(node.body[0], ast.Expr) @@ -1375,7 +1375,7 @@ def check_for_b021(self, node): B021(node.body[0].value.lineno, node.body[0].value.col_offset) ) - def check_for_b022(self, node): + def check_for_b022(self, node) -> None: item = node.items[0] item_context = item.context_expr if ( @@ -1415,14 +1415,14 @@ def _is_assertRaises_like(node: ast.withitem) -> bool: else: return False - def check_for_b908(self, node: ast.With): + def check_for_b908(self, node: ast.With) -> None: if len(node.body) < 2: return for node_item in node.items: if self._is_assertRaises_like(node_item): self.errors.append(B908(node.lineno, node.col_offset)) - def check_for_b025(self, node): + def check_for_b025(self, node) -> None: seen = [] for handler in node.handlers: if isinstance(handler.type, (ast.Name, ast.Attribute)): @@ -1473,7 +1473,7 @@ def _is_infinite_iterator(node: ast.expr) -> bool: return False - def check_for_b905(self, node): + def check_for_b905(self, node) -> None: if not (isinstance(node.func, ast.Name) and node.func.id == "zip"): return for arg in node.args: @@ -1482,7 +1482,7 @@ def check_for_b905(self, node): if not any(kw.arg == "strict" for kw in node.keywords): self.errors.append(B905(node.lineno, node.col_offset)) - def check_for_b906(self, node: ast.FunctionDef): + def check_for_b906(self, node: ast.FunctionDef) -> None: if not node.name.startswith("visit_"): return @@ -1530,7 +1530,7 @@ def check_for_b906(self, node: ast.FunctionDef): else: self.errors.append(B906(node.lineno, node.col_offset)) - def check_for_b907(self, node: ast.JoinedStr): # noqa: C901 + def check_for_b907(self, node: ast.JoinedStr) -> None: # noqa: C901 def myunparse(node: ast.AST) -> str: # pragma: no cover if sys.version_info >= (3, 9): return ast.unparse(node) @@ -1633,7 +1633,7 @@ def myunparse(node: ast.AST) -> str: # pragma: no cover # if no pre-mark or variable detected, reset state current_mark = variable = None - def check_for_b028(self, node): + def check_for_b028(self, node) -> None: if ( isinstance(node.func, ast.Attribute) and node.func.attr == "warn" @@ -1644,7 +1644,7 @@ def check_for_b028(self, node): ): self.errors.append(B028(node.lineno, node.col_offset)) - def check_for_b032(self, node): + def check_for_b032(self, node) -> None: if ( node.value is None and hasattr(node.target, "value") @@ -1659,7 +1659,7 @@ def check_for_b032(self, node): ): self.errors.append(B032(node.lineno, node.col_offset)) - def check_for_b033(self, node): + def check_for_b033(self, node) -> None: seen = set() for elt in node.elts: if not isinstance(elt, ast.Constant): @@ -1671,28 +1671,26 @@ def check_for_b033(self, node): else: seen.add(elt.value) - def check_for_b034(self, node: ast.Call): + def check_for_b034(self, node: ast.Call) -> None: if not isinstance(node.func, ast.Attribute): return - if not isinstance(node.func.value, ast.Name) or node.func.value.id != "re": + func = node.func + if not isinstance(func.value, ast.Name) or func.value.id != "re": return - def check(num_args, param_name): + def check(num_args: int, param_name: str) -> None: if len(node.args) > num_args: + arg = node.args[num_args] self.errors.append( - B034( - node.args[num_args].lineno, - node.args[num_args].col_offset, - vars=(node.func.attr, param_name), - ) + B034(arg.lineno, arg.col_offset, vars=(func.attr, param_name)) ) - if node.func.attr in ("sub", "subn"): + if func.attr in ("sub", "subn"): check(3, "count") - elif node.func.attr == "split": + elif func.attr == "split": check(2, "maxsplit") - def check_for_b909(self, node: ast.For): + def check_for_b909(self, node: ast.For) -> None: if isinstance(node.iter, ast.Name): name = _to_name_str(node.iter) key = _to_name_str(node.target) @@ -1806,12 +1804,12 @@ def visit_Assign(self, node: ast.Assign) -> None: self.mutations[self._conditional_block].append(node) self.generic_visit(node) - def visit_AugAssign(self, node: ast.AugAssign): + def visit_AugAssign(self, node: ast.AugAssign) -> None: if _to_name_str(node.target) == self.name: self.mutations[self._conditional_block].append(node) self.generic_visit(node) - def visit_Delete(self, node: ast.Delete): + def visit_Delete(self, node: ast.Delete) -> None: for target in node.targets: if isinstance(target, ast.Subscript): name = _to_name_str(target.value) @@ -1824,7 +1822,7 @@ def visit_Delete(self, node: ast.Delete): if name == self.name: self.mutations[self._conditional_block].append(node) - def visit_Call(self, node: ast.Call): + def visit_Call(self, node: ast.Call) -> None: if isinstance(node.func, ast.Attribute): name = _to_name_str(node.func.value) function_object = name @@ -1838,7 +1836,7 @@ def visit_Call(self, node: ast.Call): self.generic_visit(node) - def visit_If(self, node: ast.If): + def visit_If(self, node: ast.If) -> None: self._conditional_block += 1 self.visit(node.body) self._conditional_block += 1 @@ -1867,7 +1865,7 @@ class NameFinder(ast.NodeVisitor): def visit_Name( # noqa: B906 # names don't contain other names self, node: ast.Name - ): + ) -> None: self.names.setdefault(node.id, []).append(node) def visit(self, node): @@ -1890,7 +1888,7 @@ class NamedExprFinder(ast.NodeVisitor): names: Dict[str, List[ast.Name]] = attr.ib(factory=dict) - def visit_NamedExpr(self, node: ast.NamedExpr): + def visit_NamedExpr(self, node: ast.NamedExpr) -> None: self.names.setdefault(node.target.id, []).append(node.target) self.generic_visit(node) @@ -1912,7 +1910,7 @@ def __init__( error_code_calls, # B006 or B039 error_code_literals, # B008 or B039 b008_b039_extend_immutable_calls=None, - ): + ) -> None: self.b008_b039_extend_immutable_calls = ( b008_b039_extend_immutable_calls or set() ) @@ -1920,11 +1918,11 @@ def __init__( self.error_code_literals = error_code_literals for node in B006_MUTABLE_LITERALS + B006_MUTABLE_COMPREHENSIONS: setattr(self, f"visit_{node}", self.visit_mutable_literal_or_comprehension) - self.errors = [] + self.errors: list[error] = [] self.arg_depth = 0 super().__init__() - def visit_mutable_literal_or_comprehension(self, node): + def visit_mutable_literal_or_comprehension(self, node) -> None: # Flag B006 iff mutable literal/comprehension is not nested. # We only flag these at the top level of the expression as we # cannot easily guarantee that nested mutable structures are not @@ -1940,7 +1938,7 @@ def visit_mutable_literal_or_comprehension(self, node): # Check for nested functions. self.generic_visit(node) - def visit_Call(self, node): + def visit_Call(self, node) -> None: call_path = ".".join(compose_call_path(node.func)) if call_path in B006_MUTABLE_CALLS: self.errors.append(self.error_code_calls(node.lineno, node.col_offset)) @@ -1968,12 +1966,12 @@ def visit_Call(self, node): # Check for nested functions. self.generic_visit(node) - def visit_Lambda(self, node): # noqa: B906 + def visit_Lambda(self, node) -> None: # noqa: B906 # Don't recurse into lambda expressions # as they are evaluated at call time. pass - def visit(self, node): + def visit(self, node) -> None: """Like super-visit but supports iteration over lists.""" self.arg_depth += 1 if isinstance(node, list): @@ -1988,19 +1986,19 @@ def visit(self, node): class B020NameFinder(NameFinder): """Ignore names defined within the local scope of a comprehension.""" - def visit_GeneratorExp(self, node): + def visit_GeneratorExp(self, node) -> None: self.visit(node.generators) - def visit_ListComp(self, node): + def visit_ListComp(self, node) -> None: self.visit(node.generators) - def visit_DictComp(self, node): + def visit_DictComp(self, node) -> None: self.visit(node.generators) - def visit_comprehension(self, node): + def visit_comprehension(self, node) -> None: self.visit(node.iter) - def visit_Lambda(self, node): + def visit_Lambda(self, node) -> None: self.visit(node.body) for lambda_arg in node.args.args: self.names.pop(lambda_arg.arg, None)