Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial typing of imports.py #6982

Merged
merged 2 commits into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pylint/checkers/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ def deprecated_classes(self, module: str) -> Iterable[str]:
# pylint: disable=unused-argument
return ()

def check_deprecated_module(self, node: nodes.Import, mod_path: str) -> None:
def check_deprecated_module(self, node: nodes.Import, mod_path: str | None) -> None:
"""Checks if the module is deprecated."""
for mod_name in self.deprecated_modules():
if mod_path == mod_name or mod_path.startswith(mod_name + "."):
if mod_path == mod_name or mod_path and mod_path.startswith(mod_name + "."):
self.add_message("deprecated-module", node=node, args=mod_path)

def check_deprecated_method(self, node: nodes.Call, inferred: nodes.NodeNG) -> None:
Expand Down
108 changes: 75 additions & 33 deletions pylint/checkers/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import copy
import os
import sys
from collections import defaultdict
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import astroid
from astroid import nodes
from astroid.nodes._base_nodes import ImportNode

from pylint.checkers import BaseChecker, DeprecatedMixin
from pylint.checkers.utils import (
Expand All @@ -28,6 +31,7 @@
from pylint.reporters.ureports.nodes import Paragraph, Section, VerbatimText
from pylint.typing import MessageDefinitionTuple
from pylint.utils import IsortDriver
from pylint.utils.linterstats import LinterStats

if TYPE_CHECKING:
from pylint.lint import PyLinter
Expand Down Expand Up @@ -69,19 +73,26 @@
}


def _qualified_names(modname):
def _qualified_names(modname: str | None) -> list[str]:
"""Split the names of the given module into subparts.

For example,
_qualified_names('pylint.checkers.ImportsChecker')
returns
['pylint', 'pylint.checkers', 'pylint.checkers.ImportsChecker']
"""
names = modname.split(".")
names = modname.split(".") if modname is not None else ""
return [".".join(names[0 : i + 1]) for i in range(len(names))]


def _get_first_import(node, context, name, base, level, alias):
def _get_first_import(
node: ImportNode,
context: nodes.LocalsDictNodeNG,
name: str,
base: str | None,
level: int | None,
alias: str | None,
) -> nodes.Import | nodes.ImportFrom | None:
"""Return the node where [base.]<name> is imported or None if not found."""
fullname = f"{base}.{name}" if base else name

Expand Down Expand Up @@ -116,7 +127,11 @@ def _get_first_import(node, context, name, base, level, alias):
return None


def _ignore_import_failure(node, modname, ignored_modules):
def _ignore_import_failure(
node: ImportNode,
modname: str | None,
ignored_modules: Sequence[str],
) -> bool:
for submodule in _qualified_names(modname):
if submodule in ignored_modules:
return True
Expand Down Expand Up @@ -186,7 +201,7 @@ def _dependencies_graph(filename: str, dep_info: dict[str, set[str]]) -> str:

def _make_graph(
filename: str, dep_info: dict[str, set[str]], sect: Section, gtype: str
):
) -> None:
"""Generate a dependencies graph and add some information about it in the
report's section.
"""
Expand Down Expand Up @@ -403,7 +418,7 @@ class ImportsChecker(DeprecatedMixin, BaseChecker):

def __init__(self, linter: PyLinter) -> None:
BaseChecker.__init__(self, linter)
self.import_graph: collections.defaultdict = collections.defaultdict(set)
self.import_graph: defaultdict[str, set[str]] = defaultdict(set)
self._imports_stack: list[tuple[Any, Any]] = []
self._first_non_import_node = None
self._module_pkg: dict[
Expand All @@ -415,14 +430,14 @@ def __init__(self, linter: PyLinter) -> None:
("RP0402", "Modules dependencies graph", self._report_dependencies_graph),
)

def open(self):
def open(self) -> None:
"""Called before visiting project (i.e set of modules)."""
self.linter.stats.dependencies = {}
self.linter.stats = self.linter.stats
self.import_graph = collections.defaultdict(set)
self.import_graph = defaultdict(set)
self._module_pkg = {} # mapping of modules to the pkg they belong in
self._excluded_edges = collections.defaultdict(set)
self._ignored_modules = self.linter.config.ignored_modules
self._excluded_edges: defaultdict[str, set[str]] = defaultdict(set)
self._ignored_modules: Sequence[str] = self.linter.config.ignored_modules
# Build a mapping {'module': 'preferred-module'}
self.preferred_modules = dict(
module.split(":")
Expand All @@ -431,13 +446,13 @@ def open(self):
)
self._allow_any_import_level = set(self.linter.config.allow_any_import_level)

def _import_graph_without_ignored_edges(self):
def _import_graph_without_ignored_edges(self) -> defaultdict[str, set[str]]:
filtered_graph = copy.deepcopy(self.import_graph)
for node in filtered_graph:
filtered_graph[node].difference_update(self._excluded_edges[node])
return filtered_graph

def close(self):
def close(self) -> None:
"""Called before visiting project (i.e set of modules)."""
if self.linter.is_message_enabled("cyclic-import"):
graph = self._import_graph_without_ignored_edges()
Expand Down Expand Up @@ -536,7 +551,17 @@ def leave_module(self, node: nodes.Module) -> None:
self._imports_stack = []
self._first_non_import_node = None

def compute_first_non_import_node(self, node):
def compute_first_non_import_node(
self,
node: nodes.If
| nodes.Expr
| nodes.Comprehension
| nodes.IfExp
| nodes.Assign
| nodes.AssignAttr
| nodes.TryExcept
| nodes.TryFinally,
Comment on lines +556 to +563
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this can be everything except an import node it could also be a classDef or a lot of things, right ? Maybe NodeNG would make sense ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See L577. I thought I would be explicit just like we normally are for visit_ methods.

That said L622 shows that L602 is incorrect.

I'm drafting as I'd like to fix that in this PR.

) -> None:
# if the node does not contain an import instruction, and if it is the
# first node of the module, keep a track of it (all the import positions
# of the module will be compared to the position of this first
Expand Down Expand Up @@ -576,7 +601,9 @@ def compute_first_non_import_node(self, node):
visit_ifexp
) = visit_comprehension = visit_expr = visit_if = compute_first_non_import_node

def visit_functiondef(self, node: nodes.FunctionDef) -> None:
def visit_functiondef(
self, node: nodes.FunctionDef | nodes.While | nodes.For | nodes.ClassDef
) -> None:
# If it is the first non import instruction of the module, record it.
if self._first_non_import_node:
return
Expand All @@ -598,7 +625,7 @@ def visit_functiondef(self, node: nodes.FunctionDef) -> None:

visit_classdef = visit_for = visit_while = visit_functiondef

def _check_misplaced_future(self, node):
def _check_misplaced_future(self, node: nodes.ImportFrom) -> None:
basename = node.modname
if basename == "__future__":
# check if this is the first non-docstring statement in the module
Expand All @@ -611,15 +638,15 @@ def _check_misplaced_future(self, node):
self.add_message("misplaced-future", node=node)
return

def _check_same_line_imports(self, node):
def _check_same_line_imports(self, node: nodes.ImportFrom) -> None:
# Detect duplicate imports on the same line.
names = (name for name, _ in node.names)
counter = collections.Counter(names)
for name, count in counter.items():
if count > 1:
self.add_message("reimported", node=node, args=(name, node.fromlineno))

def _check_position(self, node):
def _check_position(self, node: ImportNode) -> None:
"""Check `node` import or importfrom node position is correct.

Send a message if `node` comes before another instruction
Expand All @@ -638,7 +665,11 @@ def _check_position(self, node):
"wrong-import-position", node.fromlineno, node
)

def _record_import(self, node, importedmodnode):
def _record_import(
self,
node: ImportNode,
importedmodnode: nodes.Module | None,
) -> None:
"""Record the package `node` imports from."""
if isinstance(node, nodes.ImportFrom):
importedname = node.modname
Expand Down Expand Up @@ -759,7 +790,9 @@ def _check_imports_order(self, _module_node):
)
return std_imports, external_imports, local_imports

def _get_imported_module(self, importnode, modname):
def _get_imported_module(
self, importnode: ImportNode, modname: str | None
) -> nodes.Module | None:
try:
return importnode.do_import_module(modname)
except astroid.TooManyLevelsError:
Expand Down Expand Up @@ -789,9 +822,7 @@ def _get_imported_module(self, importnode, modname):
raise astroid.AstroidError from e
return None

def _add_imported_module(
self, node: nodes.Import | nodes.ImportFrom, importedmodname: str
) -> None:
def _add_imported_module(self, node: ImportNode, importedmodname: str) -> None:
"""Notify an imported module, used to analyze dependencies."""
module_file = node.root().file
context_name = node.root().name
Expand Down Expand Up @@ -841,7 +872,7 @@ def _check_preferred_module(self, node, mod_path):
args=(self.preferred_modules[mod_path], mod_path),
)

def _check_import_as_rename(self, node: nodes.Import | nodes.ImportFrom) -> None:
def _check_import_as_rename(self, node: ImportNode) -> None:
names = node.names
for name in names:
if not all(name):
Expand All @@ -862,7 +893,12 @@ def _check_import_as_rename(self, node: nodes.Import | nodes.ImportFrom) -> None
args=(splitted_packages[0], import_name),
)

def _check_reimport(self, node, basename=None, level=None):
def _check_reimport(
self,
node: ImportNode,
basename: str | None = None,
level: int | None = None,
) -> None:
"""Check if the import is necessary (i.e. not already done)."""
if not self.linter.is_message_enabled("reimported"):
return
Expand All @@ -883,15 +919,19 @@ def _check_reimport(self, node, basename=None, level=None):
"reimported", node=node, args=(name, first.fromlineno)
)

def _report_external_dependencies(self, sect, _, _dummy):
def _report_external_dependencies(
self, sect: Section, _: LinterStats, _dummy: LinterStats | None
) -> None:
"""Return a verbatim layout for displaying dependencies."""
dep_info = _make_tree_defs(self._external_dependencies_info().items())
if not dep_info:
raise EmptyReportError()
tree_str = _repr_tree_defs(dep_info)
sect.append(VerbatimText(tree_str))

def _report_dependencies_graph(self, sect, _, _dummy):
def _report_dependencies_graph(
self, sect: Section, _: LinterStats, _dummy: LinterStats | None
) -> None:
"""Write dependencies as a dot (graphviz) file."""
dep_info = self.linter.stats.dependencies
if not dep_info or not (
Expand All @@ -910,9 +950,9 @@ def _report_dependencies_graph(self, sect, _, _dummy):
if filename:
_make_graph(filename, self._internal_dependencies_info(), sect, "internal ")

def _filter_dependencies_graph(self, internal):
def _filter_dependencies_graph(self, internal: bool) -> defaultdict[str, set[str]]:
"""Build the internal or the external dependency graph."""
graph = collections.defaultdict(set)
graph: defaultdict[str, set[str]] = defaultdict(set)
for importee, importers in self.linter.stats.dependencies.items():
for importer in importers:
package = self._module_pkg.get(importer, importer)
Expand All @@ -922,20 +962,22 @@ def _filter_dependencies_graph(self, internal):
return graph

@astroid.decorators.cached
def _external_dependencies_info(self):
def _external_dependencies_info(self) -> defaultdict[str, set[str]]:
"""Return cached external dependencies information or build and
cache them.
"""
return self._filter_dependencies_graph(internal=False)

@astroid.decorators.cached
def _internal_dependencies_info(self):
def _internal_dependencies_info(self) -> defaultdict[str, set[str]]:
"""Return cached internal dependencies information or build and
cache them.
"""
return self._filter_dependencies_graph(internal=True)

def _check_wildcard_imports(self, node, imported_module):
def _check_wildcard_imports(
self, node: nodes.ImportFrom, imported_module: nodes.Module | None
) -> None:
if node.root().package:
# Skip the check if in __init__.py issue #2026
return
Expand All @@ -945,14 +987,14 @@ def _check_wildcard_imports(self, node, imported_module):
if name == "*" and not wildcard_import_is_allowed:
self.add_message("wildcard-import", args=node.modname, node=node)

def _wildcard_import_is_allowed(self, imported_module):
def _wildcard_import_is_allowed(self, imported_module: nodes.Module | None) -> bool:
return (
self.linter.config.allow_wildcard_with_all
and imported_module is not None
and "__all__" in imported_module.locals
)

def _check_toplevel(self, node):
def _check_toplevel(self, node: ImportNode) -> None:
"""Check whether the import is made outside the module toplevel."""
# If the scope of the import is a module, then obviously it is
# not outside the module toplevel.
Expand Down
3 changes: 2 additions & 1 deletion pylint/checkers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from astroid import TooManyLevelsError, nodes
from astroid.context import InferenceContext
from astroid.exceptions import AstroidError
from astroid.nodes._base_nodes import ImportNode

if TYPE_CHECKING:
from pylint.checkers import BaseChecker
Expand Down Expand Up @@ -1651,7 +1652,7 @@ def get_subscript_const_value(node: nodes.Subscript) -> nodes.Const:
return inferred


def get_import_name(importnode: nodes.Import | nodes.ImportFrom, modname: str) -> str:
def get_import_name(importnode: ImportNode, modname: str | None) -> str | None:
"""Get a prepared module name from the given import node.

In the case of relative imports, this will return the
Expand Down