Skip to content

Commit

Permalink
Handle more edge cases with if and try block imports.
Browse files Browse the repository at this point in the history
* Test mypy against Python 3.8 not 3.6
  • Loading branch information
domdfcoding committed Apr 2, 2024
1 parent db801a4 commit 37d71bc
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 8 deletions.
59 changes: 51 additions & 8 deletions flake8_dunder_all/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# stdlib
import ast
import sys
from typing import Any, Generator, List, Set, Tuple, Type, Union
from typing import Any, Generator, Iterator, List, Set, Tuple, Type, Union

# 3rd party
from consolekit.terminal_colours import Fore
Expand Down Expand Up @@ -189,16 +189,59 @@ def visit_If(self, node: ast.If) -> None:
:param node: The node being visited.
"""
# Check if the condition is checking for TYPE_CHECKING
if isinstance(node.test, ast.Name) and node.test.id == "TYPE_CHECKING":
# Process the body of the if statement for any import statements
for body_node in node.body:
if isinstance(body_node, (ast.Import, ast.ImportFrom)):
self.handle_import(body_node)
# Continue visiting other nodes

if _is_type_checking(node.test):
if self.use_endlineno and node.end_lineno is not None:
self.last_import = max(self.last_import, node.end_lineno)
else:
self.last_import = max(self.last_import, max(_descend_node(node)))

self.generic_visit(node)

def visit_Try(self, node: ast.Try) -> None:
"""
Visit a Try statement.
:param node: The node being visited.
"""

if any(isinstance(n, (ast.Import, ast.ImportFrom)) for n in node.body):
if self.use_endlineno and node.end_lineno is not None and sys.implementation.name != "pypy": # pragma: no cover (pypy)
self.last_import = max(self.last_import, node.end_lineno)
else: # pragma: no cover (!pypy)
end_lineno = max(
*_descend_node(node),
*_descend_node(node, "handlers"),
*_descend_node(node, "orelse"),
*_descend_node(node, "finalbody"),
)
self.last_import = max(self.last_import, end_lineno)

self.generic_visit(node)


def _descend_node(node, attr: str = "body") -> Iterator[int]:
for child in getattr(node, attr, []):
yield child.lineno
yield from _descend_node(child)


_nameconstant = ast.Constant if sys.version_info >= (3, 8) else ast.NameConstant


def _is_type_checking(node: ast.AST):
"""

Check warning on line 233 in flake8_dunder_all/__init__.py

View workflow job for this annotation

GitHub Actions / Flake8

D400: First line should end with a period
Does the given ``if`` node indicate a `TYPE_CHECKING` block?
"""

if isinstance(node, ast.Name) and node.id == "TYPE_CHECKING":
return True
elif isinstance(node, _nameconstant) and node.value is False:
return True
elif isinstance(node, ast.BoolOp):
return any(_is_type_checking(value) for value in node.values)


class Plugin:
"""
A Flake8 plugin which checks to ensure modules have defined ``__all__``.
Expand Down
36 changes: 36 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,39 @@ def a_function():
def a_function(): ...
"""

if_type_checking_else_source = """
if False or sys.version_info < (3, 7):
import foo
else:
import bar
def a_function(): ...
"""

if_type_checking_try_source = """
from typing import TYPE_CHECKING
if TYPE_CHECKING:
try:
from x import y
except ImportError:
pass
def a_function(): ...
"""

if_type_checking_try_finally_source = """
from typing import TYPE_CHECKING
if TYPE_CHECKING:
try:
from x import y
except ImportError:
pass
finally:
pass
def a_function(): ...
"""
14 changes: 14 additions & 0 deletions tests/test_flake8_dunder_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import pytest
from coincidence.regressions import AdvancedFileRegressionFixture
from common import (
if_type_checking_else_source,
if_type_checking_source,
if_type_checking_try_finally_source,
if_type_checking_try_source,
mangled_source,
results,
testing_source_a,
Expand Down Expand Up @@ -125,6 +128,14 @@ def test_visitor(source: str, members: List[str], found_all: bool, last_import:
pytest.param(testing_source_j, ["a_function"], False, 14, id="multiline import"),
pytest.param(testing_source_m, ["a_function"], False, 7, id="if False"),
pytest.param(if_type_checking_source, ["a_function"], False, 5, id="if TYPE_CHECKING:"),
pytest.param(if_type_checking_else_source, ["a_function"], False, 5, id="if TYPE_CHECKING else"),
pytest.param(if_type_checking_try_source, ["a_function"], False, 8, id="if TYPE_CHECKING try"),
pytest.param(
if_type_checking_try_finally_source, ["a_function"],
False,
10,
id="if TYPE_CHECKING try finally"
),
]
)
def test_visitor_endlineno(source: str, members: List[str], found_all: bool, last_import: int):
Expand Down Expand Up @@ -160,6 +171,9 @@ def test_visitor_endlineno(source: str, members: List[str], found_all: bool, las
pytest.param(testing_source_l, [], 0, id="typing.overload"),
pytest.param(testing_source_m, [], 1, id="if False"),
pytest.param(testing_source_n, [], 1, id="if TYPE_CHECKING"),
pytest.param(if_type_checking_else_source, [], 1, id="if TYPE_CHECKING else"),
pytest.param(if_type_checking_try_source, [], 1, id="if TYPE_CHECKING try"),
pytest.param(if_type_checking_try_finally_source, [], 1, id="if TYPE_CHECKING try finally"),
]
)
def test_check_and_add_all(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

if False or sys.version_info < (3, 7):
import foo
else:
import bar

__all__ = ["a_function"]


def a_function(): ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

from typing import TYPE_CHECKING

if TYPE_CHECKING:
try:
from x import y
except ImportError:
pass

__all__ = ["a_function"]


def a_function(): ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

from typing import TYPE_CHECKING

if TYPE_CHECKING:
try:
from x import y
except ImportError:
pass
finally:
pass

__all__ = ["a_function"]


def a_function(): ...

0 comments on commit 37d71bc

Please sign in to comment.