Skip to content

Commit

Permalink
Implement B037 check for yielding or returning values in __init__() (#…
Browse files Browse the repository at this point in the history
…442)

* Implement B037 check for yielding or returning values in __init__()

* move return-in-init check to bugbearvisitor
  • Loading branch information
r-downing authored Jan 11, 2024
1 parent b4c661b commit 12c2dc4
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ second usage. Save the result to a list if the result is needed multiple times.

**B036**: Found ``except BaseException:`` without re-raising (no ``raise`` in the top-level of the ``except`` block). This catches all kinds of things (Exception, SystemExit, KeyboardInterrupt...) and may prevent a program from exiting as expected.

**B037**: Found ``return <value>``, ``yield``, ``yield <value>``, or ``yield from <value>`` in class ``__init__()`` method. No values should be returned or yielded, only bare ``return``s are ok.
Opinionated warnings
~~~~~~~~~~~~~~~~~~~~
Expand Down
32 changes: 31 additions & 1 deletion bugbear.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import ast
import builtins
import itertools
Expand Down Expand Up @@ -379,6 +381,30 @@ def node_stack(self):
context, stack = self.contexts[-1]
return stack

def in_class_init(self) -> bool:
return (
len(self.contexts) >= 2
and isinstance(self.contexts[-2].node, ast.ClassDef)
and isinstance(self.contexts[-1].node, ast.FunctionDef)
and self.contexts[-1].node.name == "__init__"
)

def visit_Return(self, node: ast.Return) -> None:
if self.in_class_init():
if node.value is not None:
self.errors.append(B037(node.lineno, node.col_offset))
self.generic_visit(node)

def visit_Yield(self, node: ast.Yield) -> None:
if self.in_class_init():
self.errors.append(B037(node.lineno, node.col_offset))
self.generic_visit(node)

def visit_YieldFrom(self, node: ast.YieldFrom) -> None:
if self.in_class_init():
self.errors.append(B037(node.lineno, node.col_offset))
self.generic_visit(node)

def visit(self, node):
is_contextful = isinstance(node, CONTEXTFUL_NODES)

Expand Down Expand Up @@ -540,7 +566,7 @@ def visit_FunctionDef(self, node):
self.check_for_b906(node)
self.generic_visit(node)

def visit_ClassDef(self, node):
def visit_ClassDef(self, node: ast.ClassDef):
self.check_for_b903(node)
self.check_for_b021(node)
self.check_for_b024_and_b027(node)
Expand Down Expand Up @@ -1986,6 +2012,10 @@ def visit_Lambda(self, node):
message="B036 Don't except `BaseException` unless you plan to re-raise it."
)

B037 = Error(
message="B037 Class `__init__` methods must not return or yield and any values."
)

# Warnings disabled by default.
B901 = Error(
message=(
Expand Down
33 changes: 33 additions & 0 deletions tests/b037.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

class A:
def __init__(self) -> None:
return 1 # bad

class B:
def __init__(self, x) -> None:
if x:
return # ok
else:
return [] # bad

class BNested:
def __init__(self) -> None:
yield # bad


class C:
def func(self):
pass

def __init__(self, k="") -> None:
yield from [] # bad


class D(C):
def __init__(self, k="") -> None:
super().__init__(k)
return None # bad

class E:
def __init__(self) -> None:
yield "a"
15 changes: 15 additions & 0 deletions tests/test_bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
B034,
B035,
B036,
B037,
B901,
B902,
B903,
Expand Down Expand Up @@ -619,6 +620,20 @@ def test_b036(self) -> None:
)
self.assertEqual(errors, expected)

def test_b037(self) -> None:
filename = Path(__file__).absolute().parent / "b037.py"
bbc = BugBearChecker(filename=str(filename))
errors = list(bbc.run())
expected = self.errors(
B037(4, 8),
B037(11, 12),
B037(15, 12),
B037(23, 8),
B037(29, 8),
B037(33, 8),
)
self.assertEqual(errors, expected)

def test_b908(self):
filename = Path(__file__).absolute().parent / "b908.py"
bbc = BugBearChecker(filename=str(filename))
Expand Down

0 comments on commit 12c2dc4

Please sign in to comment.