diff --git a/src/fixit/rules/deprecated_abc_import.py b/src/fixit/rules/deprecated_abc_import.py index b82a3ad3..0b0560d5 100644 --- a/src/fixit/rules/deprecated_abc_import.py +++ b/src/fixit/rules/deprecated_abc_import.py @@ -3,11 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional, Set, Union +from typing import List, Optional, Union import libcst as cst import libcst.matchers as m +from libcst.metadata import ParentNodeProvider + from fixit import Invalid, LintRule, Valid @@ -53,6 +55,7 @@ class DeprecatedABCImport(LintRule): MESSAGE = "ABCs must be imported from collections.abc" PYTHON_VERSION = ">= 3.3" + METADATA_DEPENDENCIES = (ParentNodeProvider,) VALID = [ Valid("from collections.abc import Container"), @@ -152,59 +155,30 @@ def __init__(self) -> None: self.update_module: bool = False # The original imports self.imports_names: List[str] = [] - # Nodes to ignore - self.ignore_nodes: Set[cst.ImportFrom] = set() - def visit_Try(self, node: cst.Try) -> None: + def is_except_block(self, node: cst.CSTNode) -> bool: """ - Catch instances where a correct import is in a try block with an except block - that fails over to the deprecated import. + Check if the node is in an except block - if it is, we know ti ignore it, as it + may be a fallback import """ - # If a try block imports the correct import, check the except block - if m.findall( - node, - m.ImportFrom( - module=m.Attribute(value=m.Name("collections"), attr=m.Name("abc")), - names=[ - m.AtLeastN( - n=1, - matcher=m.OneOf(*[m.ImportAlias(name=m.Name(n)) for n in ABCS]), - ) - ], - ), - ): - # For each handler, ensure it is a ImportError and check that it contains - # the deprecated import - for handler in node.handlers: - if ( - import_nodes := m.findall( - handler, - m.ImportFrom( - module=m.Name("collections"), - names=[ - m.AtLeastN( - n=1, - matcher=m.OneOf( - *[m.ImportAlias(name=m.Name(n)) for n in ABCS] - ), - ) - ], - ), - ) - ) and m.matches( - node=cst.ensure_type(handler.type, cst.Name), - matcher=m.Name("ImportError"), - ): - self.ignore_nodes |= { - cst.ensure_type(import_node, cst.ImportFrom) - for import_node in import_nodes - } + try: + return isinstance( + self.get_metadata( + ParentNodeProvider, + self.get_metadata( + ParentNodeProvider, self.get_metadata(ParentNodeProvider, node) + ), + ), + cst.ExceptHandler, + ) + except KeyError: + return False def visit_ImportFrom(self, node: cst.ImportFrom) -> None: """ This catches the `from collections import ` cases """ - if node in self.ignore_nodes: + if self.is_except_block(node): return # Get imports in this statement