Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure the function to rename is imported before renaming #54

Merged
merged 2 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion django_codemod/visitors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def leave_ImportFrom(
obj=self.new_name,
asname=as_name,
)
self.context.scratch[self.rename_from] = not import_alias.asname
else:
new_names.append(import_alias)
if not new_names:
Expand All @@ -85,8 +86,14 @@ def leave_ImportFrom(
return updated_node.with_changes(names=new_names)
return super().leave_ImportFrom(original_node, updated_node)

@property
def _is_context_right(self):
return self.context.scratch.get(self.rename_from, False)

def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
if m.matches(updated_node, m.Call(func=m.Name(self.old_name))):
if self._is_context_right and m.matches(
updated_node, m.Call(func=m.Name(self.old_name))
):
updated_args = self.update_call_args(updated_node)
return Call(args=updated_args, func=Name(self.new_name))
return super().leave_Call(original_node, updated_node)
Expand Down
34 changes: 34 additions & 0 deletions tests/visitors/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_simple_substitution(self) -> None:
self.assertCodemod(before, after)

def test_already_imported(self) -> None:
"""Function to modify is already imported with an alias."""
before = """
from django.dummy.module import func, better_func

Expand All @@ -41,6 +42,7 @@ def test_already_imported(self) -> None:
self.assertCodemod(before, after)

def test_import_with_alias(self) -> None:
"""Function to modify is imported with an alias."""
before = """
from django.dummy.module import func as aliased_func

Expand All @@ -53,6 +55,38 @@ def test_import_with_alias(self) -> None:
"""
self.assertCodemod(before, after)

def test_same_name_function(self) -> None:
"""Should not be fooled by a function bearing the same name."""
before = """
from utils.helpers import func

result = func()
"""
after = """
from utils.helpers import func

result = func()
"""
self.assertCodemod(before, after)

def test_same_name_with_alias_import_function(self) -> None:
"""Imported with alias and other function with the same name."""
before = """
from django.dummy.module import func as aliased_func
from utils.helpers import func

result = func()
aliased_func()
"""
after = """
from utils.helpers import func
from django.dummy.module import better_func as aliased_func

result = func()
aliased_func()
"""
self.assertCodemod(before, after)

def test_extra_trailing_comma_when_last(self) -> None:
"""Extra trailing comma when removed import is the last one."""
before = """
Expand Down