From b4908da80436db2c957d0a56c2ef750208dd67d1 Mon Sep 17 00:00:00 2001 From: Sergio Ly Date: Wed, 5 Jun 2024 11:54:41 -0700 Subject: [PATCH] Added new testcase --- src/fixit/rules/deprecated_abc_import.py | 66 ++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/fixit/rules/deprecated_abc_import.py b/src/fixit/rules/deprecated_abc_import.py index 59e89786..84ef6cf1 100644 --- a/src/fixit/rules/deprecated_abc_import.py +++ b/src/fixit/rules/deprecated_abc_import.py @@ -64,6 +64,13 @@ class DeprecatedABCImport(LintRule): Valid("import collections"), Valid("import collections.abc"), Valid("import collections.abc.Container"), + Valid( + """ + class MyTest(collections.Something): + def test(self): + pass + """ + ), ] INVALID = [ Invalid( @@ -94,6 +101,18 @@ class DeprecatedABCImport(LintRule): "from collections import defaultdict\nfrom collections import Container", expected_replacement="from collections import defaultdict\nfrom collections.abc import Container", ), + Invalid( + """ + class MyTest(collections.Container): + def test(self): + pass + """, + expected_replacement=""" + class MyTest(collections.abc.Container): + def test(self): + pass + """, + ), ] def __init__(self) -> None: @@ -263,3 +282,50 @@ def visit_ImportAlias(self, node: cst.ImportAlias) -> None: ) ), ) + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + # Iterate over inherited Classes and search for `collections.` + for base in node.bases: + if m.matches( + base, + m.Arg( + value=m.Attribute( + value=m.Name("collections"), + attr=m.OneOf(*[m.Name(abc) for abc in ABCS]), + ) + ), + ): + # Report + replace `collections.` with `collections.abc.` + # while keeping the remaining classes. + self.report( + node, + replacement=node.with_changes( + bases=[ + ( + cst.Arg( + value=cst.Attribute( + value=cst.Attribute( + value=cst.Name("collections"), + attr=cst.Name("abc"), + ), + attr=base.value.attr, + ), + ) + if m.matches( + base, + m.Arg( + value=m.Attribute( + value=m.Name("collections"), + attr=m.OneOf( + *[m.Name(abc) for abc in ABCS] + ), + ) + ), + ) + and isinstance(base.value, cst.Attribute) + else base + ) + for base in node.bases + ] + ), + )