diff --git a/src/fixit/rules/deprecated_abc_import.py b/src/fixit/rules/deprecated_abc_import.py index 84ef6cf1..7946eadb 100644 --- a/src/fixit/rules/deprecated_abc_import.py +++ b/src/fixit/rules/deprecated_abc_import.py @@ -167,33 +167,16 @@ def get_import_from( Iterate over a Statement Sequence and return a Statement if it is a `cst.ImportFrom` statement. """ - if m.matches( + imp = m.findall( node, - m.SimpleStatementLine( - body=[ - m.ZeroOrMore(), - m.ImportFrom( - module=m.Name("collections"), - names=m.OneOf( - [m.ImportAlias(name=m.Name(n)) for n in self.imports_names] - ), - ), - m.ZeroOrMore(), - ] - ), - ): - imp = m.findall( - node, - m.ImportFrom( - module=m.Name("collections"), - names=m.OneOf( - [m.ImportAlias(name=m.Name(n)) for n in self.imports_names] - ), + m.ImportFrom( + module=m.Name("collections"), + names=m.OneOf( + [m.ImportAlias(name=m.Name(n)) for n in self.imports_names] ), - )[0] - return imp if isinstance(imp, cst.ImportFrom) else None - - return None + ), + ) + return imp[0] if len(imp) > 0 and isinstance(imp[0], cst.ImportFrom) else None def leave_Module(self, node: cst.Module) -> None: """