From f9e0f775a581cf1fc77ccb1f6b2cbe1e8b514734 Mon Sep 17 00:00:00 2001 From: Joe Young <80432516+jpy-git@users.noreply.github.com> Date: Wed, 23 Mar 2022 21:24:09 +0000 Subject: [PATCH] B006 and B008: Cover additional test cases (#239) * B006 and B008: Cover additional test cases * Add change log entry * Account for inconsistent ast between python versions * Use ast.literal_eval to simplify infinity float detection --- .pre-commit-config.yaml | 2 +- README.rst | 5 ++ bugbear.py | 143 +++++++++++++++++++++++----------------- tests/b006_b008.py | 105 +++++++++++++++++++++-------- tests/test_bugbear.py | 38 +++++++---- 5 files changed, 193 insertions(+), 100 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81197f1..381e186 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 21.10b0 + rev: 22.1.0 hooks: - id: black args: diff --git a/README.rst b/README.rst index 961324a..db2bb20 100644 --- a/README.rst +++ b/README.rst @@ -279,6 +279,11 @@ MIT Change Log ---------- + +~~~~~~~~~~ + +* B006 and B008: Detect function calls at any level of the default expression. + 22.3.20 ~~~~~~~~~~ diff --git a/bugbear.py b/bugbear.py index d1f5cd8..40803e7 100644 --- a/bugbear.py +++ b/bugbear.py @@ -2,8 +2,8 @@ import builtins import itertools import logging +import math import re -import sys from collections import namedtuple from contextlib import suppress from functools import lru_cache, partial @@ -354,13 +354,13 @@ def visit_Assert(self, node): def visit_AsyncFunctionDef(self, node): self.check_for_b902(node) - self.check_for_b006(node) + self.check_for_b006_and_b008(node) self.generic_visit(node) def visit_FunctionDef(self, node): self.check_for_b901(node) self.check_for_b902(node) - self.check_for_b006(node) + self.check_for_b006_and_b008(node) self.check_for_b018(node) self.check_for_b019(node) self.check_for_b021(node) @@ -390,15 +390,6 @@ def visit_With(self, node): self.check_for_b022(node) self.generic_visit(node) - def compose_call_path(self, node): - if isinstance(node, ast.Attribute): - yield from self.compose_call_path(node.value) - yield node.attr - elif isinstance(node, ast.Call): - yield from self.compose_call_path(node.func) - elif isinstance(node, ast.Name): - yield node.id - def check_for_b005(self, node): if node.func.attr not in B005.methods: return # method name doesn't match @@ -406,7 +397,7 @@ def check_for_b005(self, node): if len(node.args) != 1 or not isinstance(node.args[0], ast.Str): return # used arguments don't match the builtin strip - call_path = ".".join(self.compose_call_path(node.func.value)) + call_path = ".".join(compose_call_path(node.func.value)) if call_path in B005.valid_paths: return # path is exempt @@ -419,48 +410,10 @@ def check_for_b005(self, node): self.errors.append(B005(node.lineno, node.col_offset)) - def check_for_b006(self, node): - for default in node.args.defaults + node.args.kw_defaults: - if isinstance( - default, (*B006.mutable_literals, *B006.mutable_comprehensions) - ): - self.errors.append(B006(default.lineno, default.col_offset)) - elif isinstance(default, ast.Call): - call_path = ".".join(self.compose_call_path(default.func)) - if call_path in B006.mutable_calls: - self.errors.append(B006(default.lineno, default.col_offset)) - elif ( - call_path - not in B008.immutable_calls | self.b008_extend_immutable_calls - ): - # Check if function call is actually a float infinity/NaN literal - if call_path == "float" and len(default.args) == 1: - float_arg = default.args[0] - if sys.version_info < (3, 8, 0): - # NOTE: pre-3.8, string literals are represented with ast.Str - if isinstance(float_arg, ast.Str): - str_val = float_arg.s - else: - str_val = "" - else: - # NOTE: post-3.8, string literals are represented with ast.Constant - if isinstance(float_arg, ast.Constant): - str_val = float_arg.value - if not isinstance(str_val, str): - str_val = "" - else: - str_val = "" - - # NOTE: regex derived from documentation at: - # https://docs.python.org/3/library/functions.html#float - inf_nan_regex = r"^[+-]?(inf|infinity|nan)$" - re_result = re.search(inf_nan_regex, str_val.lower()) - is_float_literal = re_result is not None - else: - is_float_literal = False - - if not is_float_literal: - self.errors.append(B008(default.lineno, default.col_offset)) + def check_for_b006_and_b008(self, node): + visitor = FuntionDefDefaultsVisitor(self.b008_extend_immutable_calls) + visitor.visit(node.args.defaults + node.args.kw_defaults) + self.errors.extend(visitor.errors) def check_for_b007(self, node): targets = NameFinder() @@ -536,8 +489,7 @@ def check_for_b019(self, node): # Preserve decorator order so we can get the lineno from the decorator node # rather than the function node (this location definition changes in Python 3.8) resolved_decorators = ( - ".".join(self.compose_call_path(decorator)) - for decorator in node.decorator_list + ".".join(compose_call_path(decorator)) for decorator in node.decorator_list ) for idx, decorator in enumerate(resolved_decorators): if decorator in {"classmethod", "staticmethod"}: @@ -755,6 +707,16 @@ def check_for_b022(self, node): self.errors.append(B022(node.lineno, node.col_offset)) +def compose_call_path(node): + if isinstance(node, ast.Attribute): + yield from compose_call_path(node.value) + yield node.attr + elif isinstance(node, ast.Call): + yield from compose_call_path(node.func) + elif isinstance(node, ast.Name): + yield node.id + + @attr.s class NameFinder(ast.NodeVisitor): """Finds a name within a tree of nodes. @@ -778,6 +740,69 @@ def visit(self, node): return node +class FuntionDefDefaultsVisitor(ast.NodeVisitor): + def __init__(self, b008_extend_immutable_calls=None): + self.b008_extend_immutable_calls = b008_extend_immutable_calls or set() + for node in B006.mutable_literals + B006.mutable_comprehensions: + setattr(self, f"visit_{node}", self.visit_mutable_literal_or_comprehension) + self.errors = [] + self.arg_depth = 0 + super().__init__() + + def visit_mutable_literal_or_comprehension(self, node): + # Flag B006 iff mutable literal/comprehension is not nested. + # We only flag these at the top level of the expression as we + # cannot easily guarantee that nested mutable structures are not + # made immutable by outer operations, so we prefer no false positives. + # e.g. + # >>> def this_is_fine(a=frozenset({"a", "b", "c"})): ... + # + # >>> def this_is_not_fine_but_hard_to_detect(a=(lambda x: x)([1, 2, 3])) + # + # We do still search for cases of B008 within mutable structures though. + if self.arg_depth == 1: + self.errors.append(B006(node.lineno, node.col_offset)) + # Check for nested functions. + self.generic_visit(node) + + def visit_Call(self, node): + call_path = ".".join(compose_call_path(node.func)) + if call_path in B006.mutable_calls: + self.errors.append(B006(node.lineno, node.col_offset)) + self.generic_visit(node) + return + + if call_path in B008.immutable_calls | self.b008_extend_immutable_calls: + self.generic_visit(node) + return + + # Check if function call is actually a float infinity/NaN literal + if call_path == "float" and len(node.args) == 1: + try: + value = float(ast.literal_eval(node.args[0])) + except Exception: + pass + else: + if math.isfinite(value): + self.errors.append(B008(node.lineno, node.col_offset)) + else: + self.errors.append(B008(node.lineno, node.col_offset)) + + # Check for nested functions. + self.generic_visit(node) + + def visit(self, node): + """Like super-visit but supports iteration over lists.""" + self.arg_depth += 1 + if isinstance(node, list): + for elem in node: + if elem is not None: + super().visit(elem) + else: + super().visit(node) + self.arg_depth -= 1 + + class B020NameFinder(NameFinder): """Ignore names defined within the local scope of a comprehension.""" @@ -851,8 +876,8 @@ def visit_comprehension(self, node): "between them." ) ) -B006.mutable_literals = (ast.Dict, ast.List, ast.Set) -B006.mutable_comprehensions = (ast.ListComp, ast.DictComp, ast.SetComp) +B006.mutable_literals = ("Dict", "List", "Set") +B006.mutable_comprehensions = ("ListComp", "DictComp", "SetComp") B006.mutable_calls = { "Counter", "OrderedDict", diff --git a/tests/b006_b008.py b/tests/b006_b008.py index e095323..781af9f 100644 --- a/tests/b006_b008.py +++ b/tests/b006_b008.py @@ -1,6 +1,8 @@ import collections +import datetime as dt import logging import operator +import random import re import time import types @@ -8,11 +10,13 @@ from types import MappingProxyType +# B006 +# Allow immutable literals/calls/comprehensions def this_is_okay(value=(1, 2, 3)): ... -def and_this_also(value=tuple()): +async def and_this_also(value=tuple()): pass @@ -26,6 +30,33 @@ def mappingproxytype_okay( pass +def re_compile_ok(value=re.compile("foo")): + pass + + +def operators_ok( + v=operator.attrgetter("foo"), + v2=operator.itemgetter("foo"), + v3=operator.methodcaller("foo"), +): + pass + + +def operators_ok_unqualified( + v=attrgetter("foo"), + v2=itemgetter("foo"), + v3=methodcaller("foo"), +): + pass + + +def kwonlyargs_immutable(*, value=()): + ... + + +# Flag mutable literals/comprehensions + + def this_is_wrong(value=[1, 2, 3]): ... @@ -42,35 +73,61 @@ def this_too(value=collections.OrderedDict()): ... -async def async_this_too(value=collections.OrderedDict()): +async def async_this_too(value=collections.defaultdict()): + ... + + +def dont_forget_me(value=collections.deque()): ... +# N.B. we're also flagging the function call in the comprehension +def list_comprehension_also_not_okay(default=[i**2 for i in range(3)]): + pass + + +def dict_comprehension_also_not_okay(default={i: i**2 for i in range(3)}): + pass + + +def set_comprehension_also_not_okay(default={i**2 for i in range(3)}): + pass + + +def kwonlyargs_mutable(*, value=[]): + ... + + +# Recommended approach for mutable defaults def do_this_instead(value=None): if value is None: value = set() +# B008 +# Flag function calls as default args (including if they are part of a sub-expression) def in_fact_all_calls_are_wrong(value=time.time()): ... -LOGGER = logging.getLogger(__name__) +def f(when=dt.datetime.now() + dt.timedelta(days=7)): + pass -def do_this_instead_of_calls_in_defaults(logger=LOGGER): - # That makes it more obvious that this one value is reused. +def can_even_catch_lambdas(a=(lambda x: x)()): ... -def kwonlyargs_immutable(*, value=()): - ... +# Recommended approach for function calls as default args +LOGGER = logging.getLogger(__name__) -def kwonlyargs_mutable(*, value=[]): +def do_this_instead_of_calls_in_defaults(logger=LOGGER): + # That makes it more obvious that this one value is reused. ... +# Handle inf/infinity/nan special case def float_inf_okay(value=float("inf")): pass @@ -95,39 +152,31 @@ def float_minus_NaN_okay(value=float("-NaN")): pass -def float_int_is_wrong(value=float(3)): +def float_infinity_literal(value=float("1e999")): pass -def float_str_not_inf_or_nan_is_wrong(value=float("3.14")): - pass - - -def re_compile_ok(value=re.compile("foo")): - pass - - -def operators_ok( - v=operator.attrgetter("foo"), - v2=operator.itemgetter("foo"), - v3=operator.methodcaller("foo"), -): +# But don't allow standard floats +def float_int_is_wrong(value=float(3)): pass -def operators_ok_unqualified( - v=attrgetter("foo"), v2=itemgetter("foo"), v3=methodcaller("foo") -): +def float_str_not_inf_or_nan_is_wrong(value=float("3.14")): pass -def list_comprehension_also_not_okay(default=[i ** 2 for i in range(3)]): +# B006 and B008 +# We should handle arbitrary nesting of these B008. +def nested_combo(a=[float(3), dt.datetime.now()]): pass -def dict_comprehension_also_not_okay(default={i: i ** 2 for i in range(3)}): +# Don't flag nested B006 since we can't guarantee that +# it isn't made mutable by the outer operation. +def no_nested_b006(a=map(lambda s: s.upper(), ["a", "b", "c"])): pass -def set_comprehension_also_not_okay(default={i ** 2 for i in range(3)}): +# B008-ception. +def nested_b008(a=random.randint(0, dt.datetime.now().year)): pass diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index 7954b5d..9a61cef 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -101,18 +101,32 @@ def test_b006_b008(self): self.assertEqual( errors, self.errors( - B006(29, 24), - B006(33, 29), - B006(37, 19), - B006(41, 19), - B006(45, 31), - B008(54, 38), - B006(70, 32), - B008(98, 29), - B008(102, 44), - B006(124, 45 if sys.version_info >= (3, 8) else 46), - B006(128, 45), - B006(132, 44), + B006(60, 24), + B006(64, 29), + B006(68, 19), + B006(72, 19), + B006(76, 31), + B006(80, 25), + B006(85, 45 if sys.version_info >= (3, 8) else 46), + B008(85, 60), + B006(89, 45), + B008(89, 63), + B006(93, 44), + B008(93, 59), + B006(97, 32), + B008(109, 38), + B008(113, 11), + B008(113, 31), + B008(117, 29 if sys.version_info >= (3, 8) else 30), + B008(160, 29), + B008(164, 44), + B006(170, 19), + B008(170, 20), + B008(170, 30), + B008(176, 21), + B008(176, 35), + B008(181, 18), + B008(181, 36), ), )