Skip to content

Commit

Permalink
combine remove_imports and remove_duplicated_imports
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile committed Jul 3, 2022
1 parent 1557a71 commit b10eb67
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 37 deletions.
28 changes: 8 additions & 20 deletions reorder_python_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,6 @@ def add_imports(
]


def remove_imports(
partitions: Iterable[CodePartition],
to_remove: set[tuple[str, str | None] | tuple[str, str, str | None]],
) -> list[CodePartition]:
return [
partition for partition in partitions
if (
partition.code_type is not CodeType.IMPORT or
import_obj_from_str(partition.src).key not in to_remove
)
]


class Replacements(NamedTuple):
# (orig_mod, attr) => new_mod
exact: dict[tuple[str, str], str]
Expand Down Expand Up @@ -289,16 +276,18 @@ def _module_to_base_modules(s: str) -> Generator[str, None, None]:

def remove_duplicated_imports(
partitions: Iterable[CodePartition],
*,
to_remove: set[tuple[str, ...]],
) -> list[CodePartition]:
seen: set[Import | ImportFrom] = set()
seen = set(to_remove)
seen_module_names: set[str] = set()
without_exact_duplicates = []

for partition in partitions:
if partition.code_type is CodeType.IMPORT:
import_obj = import_obj_from_str(partition.src)
if import_obj not in seen:
seen.add(import_obj)
if import_obj.key not in seen:
seen.add(import_obj.key)
if (
isinstance(import_obj, Import) and
not import_obj.key.asname
Expand Down Expand Up @@ -385,7 +374,7 @@ def fix_file_contents(
contents: str,
*,
to_add: tuple[str, ...] = (),
to_remove: set[tuple[str, str | None] | tuple[str, str, str | None]],
to_remove: set[tuple[str, ...]],
to_replace: Replacements,
settings: Settings = Settings(),
) -> str:
Expand All @@ -401,9 +390,8 @@ def fix_file_contents(
partitioned = combine_trailing_code_chunks(partitioned)
partitioned = add_imports(partitioned, to_add=to_add)
partitioned = separate_comma_imports(partitioned)
partitioned = remove_imports(partitioned, to_remove=to_remove)
partitioned = replace_imports(partitioned, to_replace=to_replace)
partitioned = remove_duplicated_imports(partitioned)
partitioned = remove_duplicated_imports(partitioned, to_remove=to_remove)
partitioned = apply_import_sorting(partitioned, settings=settings)

return _partitions_to_src(partitioned).replace('\n', nl)
Expand All @@ -413,7 +401,7 @@ def _fix_file(
filename: str,
args: argparse.Namespace,
*,
to_remove: set[tuple[str, str | None] | tuple[str, str, str | None]],
to_remove: set[tuple[str, ...]],
to_replace: Replacements,
settings: Settings = Settings(),
) -> int:
Expand Down
31 changes: 14 additions & 17 deletions tests/reorder_python_imports_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,56 +209,53 @@ def test_separate_comma_imports_does_not_remove_comments_when_not_splitting():


def test_remove_duplicated_imports_trivial():
assert remove_duplicated_imports([]) == []
assert remove_duplicated_imports([], to_remove=set()) == []


def test_remove_duplicated_imports_no_dupes_no_removals():
input_partitions = [
partitions = [
CodePartition(CodeType.IMPORT, 'import sys\n'),
CodePartition(CodeType.NON_CODE, '\n'),
CodePartition(CodeType.IMPORT, 'from six import text_type\n'),
]
assert remove_duplicated_imports(input_partitions) == input_partitions
assert remove_duplicated_imports(partitions, to_remove=set()) == partitions


def test_remove_duplicated_imports_removes_duplicated():
assert remove_duplicated_imports([
partitions = [
CodePartition(CodeType.IMPORT, 'import sys\n'),
CodePartition(CodeType.IMPORT, 'import sys\n'),
]) == [
CodePartition(CodeType.IMPORT, 'import sys\n'),
]
expected = [CodePartition(CodeType.IMPORT, 'import sys\n')]
assert remove_duplicated_imports(partitions, to_remove=set()) == expected


def test_remove_duplicate_redundant_import_imports():
assert remove_duplicated_imports([
CodePartition(CodeType.IMPORT, 'import os\n'),
CodePartition(CodeType.IMPORT, 'import os.path\n'),
]) == [
CodePartition(CodeType.IMPORT, 'import os.path\n'),
]
assert remove_duplicated_imports([
CodePartition(CodeType.IMPORT, 'import os.path\n'),
partitions = [
CodePartition(CodeType.IMPORT, 'import os\n'),
]) == [
CodePartition(CodeType.IMPORT, 'import os.path\n'),
]
expected = [CodePartition(CodeType.IMPORT, 'import os.path\n')]

assert remove_duplicated_imports(partitions, to_remove=set()) == expected
partitions.reverse()
assert remove_duplicated_imports(partitions, to_remove=set()) == expected


def test_aliased_imports_not_considered_redundant():
partitions = [
CodePartition(CodeType.IMPORT, 'import os\n'),
CodePartition(CodeType.IMPORT, 'import os.path as os_path\n'),
]
assert remove_duplicated_imports(partitions) == partitions
assert remove_duplicated_imports(partitions, to_remove=set()) == partitions


def test_aliased_imports_not_considered_redundant_v2():
partitions = [
CodePartition(CodeType.IMPORT, 'import os as osmod\n'),
CodePartition(CodeType.IMPORT, 'import os.path\n'),
]
assert remove_duplicated_imports(partitions) == partitions
assert remove_duplicated_imports(partitions, to_remove=set()) == partitions


def test_apply_import_sorting_trivial():
Expand Down

0 comments on commit b10eb67

Please sign in to comment.