Skip to content

Commit

Permalink
Updated to use ParentNodeProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
surge119 committed Jul 29, 2024
1 parent e3dfe56 commit 5f2ea83
Showing 1 changed file with 20 additions and 46 deletions.
66 changes: 20 additions & 46 deletions src/fixit/rules/deprecated_abc_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Set, Union
from typing import List, Optional, Union

import libcst as cst
import libcst.matchers as m

from libcst.metadata import ParentNodeProvider

from fixit import Invalid, LintRule, Valid


Expand Down Expand Up @@ -53,6 +55,7 @@ class DeprecatedABCImport(LintRule):

MESSAGE = "ABCs must be imported from collections.abc"
PYTHON_VERSION = ">= 3.3"
METADATA_DEPENDENCIES = (ParentNodeProvider,)

VALID = [
Valid("from collections.abc import Container"),
Expand Down Expand Up @@ -152,59 +155,30 @@ def __init__(self) -> None:
self.update_module: bool = False
# The original imports
self.imports_names: List[str] = []
# Nodes to ignore
self.ignore_nodes: Set[cst.ImportFrom] = set()

def visit_Try(self, node: cst.Try) -> None:
def is_except_block(self, node: cst.CSTNode) -> bool:
"""
Catch instances where a correct import is in a try block with an except block
that fails over to the deprecated import.
Check if the node is in an except block - if it is, we know ti ignore it, as it
may be a fallback import
"""
# If a try block imports the correct import, check the except block
if m.findall(
node,
m.ImportFrom(
module=m.Attribute(value=m.Name("collections"), attr=m.Name("abc")),
names=[
m.AtLeastN(
n=1,
matcher=m.OneOf(*[m.ImportAlias(name=m.Name(n)) for n in ABCS]),
)
],
),
):
# For each handler, ensure it is a ImportError and check that it contains
# the deprecated import
for handler in node.handlers:
if (
import_nodes := m.findall(
handler,
m.ImportFrom(
module=m.Name("collections"),
names=[
m.AtLeastN(
n=1,
matcher=m.OneOf(
*[m.ImportAlias(name=m.Name(n)) for n in ABCS]
),
)
],
),
)
) and m.matches(
node=cst.ensure_type(handler.type, cst.Name),
matcher=m.Name("ImportError"),
):
self.ignore_nodes |= {
cst.ensure_type(import_node, cst.ImportFrom)
for import_node in import_nodes
}
try:
return isinstance(
self.get_metadata(
ParentNodeProvider,
self.get_metadata(
ParentNodeProvider, self.get_metadata(ParentNodeProvider, node)
),
),
cst.ExceptHandler,
)
except KeyError:
return False

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""
This catches the `from collections import <ABC>` cases
"""
if node in self.ignore_nodes:
if self.is_except_block(node):
return

# Get imports in this statement
Expand Down

0 comments on commit 5f2ea83

Please sign in to comment.