From 945ea214d584e749f25bdf31bb0539b0bdebe811 Mon Sep 17 00:00:00 2001 From: Bruno Alla Date: Thu, 7 May 2020 20:29:47 +0100 Subject: [PATCH 1/2] Refactor a check with matchers --- django_codemod/codemods/django_40.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/django_codemod/codemods/django_40.py b/django_codemod/codemods/django_40.py index 5affb6e8..1850c0b6 100644 --- a/django_codemod/codemods/django_40.py +++ b/django_codemod/codemods/django_40.py @@ -1,7 +1,7 @@ """Main module.""" from typing import Union -from libcst import RemovalSentinel, Call, BaseExpression, Name +from libcst import matchers as m, RemovalSentinel, Call, BaseExpression, Name from libcst._nodes.statement import Import, ImportFrom, BaseSmallStatement, ImportAlias from libcst.codemod import VisitorBasedCodemodCommand @@ -44,6 +44,6 @@ def leave_ImportFrom( return super().leave_ImportFrom(original_node, updated_node) def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression: - if getattr(updated_node.func, "value", None) == "force_text": + if m.matches(updated_node, m.Call(func=m.Name("force_text"))): return Call(args=updated_node.args, func=Name("force_str")) return super().leave_Call(original_node, updated_node) From 56847917b0868b3bbcc63960aecd3c040c5694aa Mon Sep 17 00:00:00 2001 From: Bruno Alla Date: Thu, 7 May 2020 20:52:26 +0100 Subject: [PATCH 2/2] Refactor & simplify check with matcher --- django_codemod/codemods/django_40.py | 49 +++++++++++++++------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/django_codemod/codemods/django_40.py b/django_codemod/codemods/django_40.py index 1850c0b6..d82313b4 100644 --- a/django_codemod/codemods/django_40.py +++ b/django_codemod/codemods/django_40.py @@ -18,29 +18,32 @@ def leave_Import( def leave_ImportFrom( self, original_node: ImportFrom, updated_node: ImportFrom ) -> Union[BaseSmallStatement, RemovalSentinel]: - if updated_node.module and len(updated_node.module.children) == 3: - tops, _, last = updated_node.module.children - if len(tops.children) == 3: - top, _, middle = tops.children - if ( - top.value == "django" - and middle.value == "utils" - and last.value == "encoding" - ): - new_names = [] - new_import_missing = True - new_import_alias = None - for import_alias in original_node.names: - if import_alias.evaluated_name == "force_text": - new_import_alias = ImportAlias(name=Name("force_str")) - else: - if import_alias.evaluated_name == "force_str": - new_import_missing = False - new_names.append(import_alias) - if new_import_missing and new_import_alias is not None: - new_names.append(new_import_alias) - new_names = list(sorted(new_names, key=lambda n: n.evaluated_name)) - return ImportFrom(module=updated_node.module, names=new_names) + import_matches = m.matches( + updated_node, + m.ImportFrom( + module=m.Attribute( + attr=m.Name("encoding"), + value=m.Attribute( + value=m.Name("django"), attr=m.Name(value="utils") + ), + ), + ), + ) + if import_matches: + new_names = [] + new_import_missing = True + new_import_alias = None + for import_alias in original_node.names: + if import_alias.evaluated_name == "force_text": + new_import_alias = ImportAlias(name=Name("force_str")) + else: + if import_alias.evaluated_name == "force_str": + new_import_missing = False + new_names.append(import_alias) + if new_import_missing and new_import_alias is not None: + new_names.append(new_import_alias) + new_names = list(sorted(new_names, key=lambda n: n.evaluated_name)) + return ImportFrom(module=updated_node.module, names=new_names) return super().leave_ImportFrom(original_node, updated_node) def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression: