Skip to content

Commit

Permalink
Remove Optional inside Union (#20)
Browse files Browse the repository at this point in the history
* Remove `Optional` inside `Union`

* Complete
  • Loading branch information
Kludex authored Jun 25, 2022
1 parent 7c0fc72 commit 9082aa0
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 40 deletions.
117 changes: 84 additions & 33 deletions no_optional/command.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List, Set

import libcst as cst
from libcst import ensure_type
from libcst import matchers as m
from libcst.codemod import VisitorBasedCodemodCommand

from no_optional.utils import is_typing_node


class NoOptionalCommand(VisitorBasedCodemodCommand):
@m.leave(
Expand All @@ -14,46 +19,38 @@ def replace_optional(
self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Subscript:
if original_node.value.value == "Optional":
return updated_node.with_changes(
value=cst.Name("Union"),
slice=[
*updated_node.slice,
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("None")),
comma=cst.MaybeSentinel.DEFAULT,
),
],
)
union_type = cst.Name("Union")
else:
return updated_node.with_changes(
value=cst.Attribute(value=cst.Name("typing"), attr=cst.Name("Union")),
slice=[
*updated_node.slice,
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("None")),
comma=cst.MaybeSentinel.DEFAULT,
),
],
)
union_type = cst.Attribute(value=cst.Name("typing"), attr=cst.Name("Union"))

@m.call_if_inside(
m.Annotation(
annotation=m.Subscript(
value=m.Name(value="Optional")
| m.Attribute(value=m.Name("typing"), attr=m.Name("Optional")),
slice=(
m.SubscriptElement(
slice=m.Index(value=m.Subscript(value=m.Name(value="Union"))),
),
m.ZeroOrMore(),
return updated_node.with_changes(
value=union_type,
slice=[
*updated_node.slice,
cst.SubscriptElement(
slice=cst.Index(value=cst.Name("None")),
comma=cst.MaybeSentinel.DEFAULT,
),
)
],
)
)

@m.leave(
m.Subscript(
value=m.Name(value="Optional")
| m.Attribute(value=m.Name("typing"), attr=m.Name("Optional"))
| m.Attribute(value=m.Name("typing"), attr=m.Name("Optional")),
slice=(
m.SubscriptElement(
slice=m.Index(
value=m.Subscript(
value=m.Name(value="Union")
| m.Attribute(
value=m.Name(value="typing"), attr=m.Name(value="Union")
)
)
),
),
m.ZeroOrMore(),
),
)
)
def remove_union_redundancy(
Expand All @@ -63,6 +60,60 @@ def remove_union_redundancy(
slice=[*updated_node.slice[0].slice.value.slice, *updated_node.slice[1:]]
)

@m.leave(
m.Subscript(
value=m.Name(value="Union")
| m.Attribute(value=m.Name("typing"), attr=m.Name("Union")),
slice=(
m.ZeroOrMore(),
m.SubscriptElement(
slice=m.Index(
value=m.Subscript(
value=m.Name(value="Optional")
| m.Attribute(
value=m.Name(value="typing"),
attr=m.Name(value="Optional"),
)
)
),
),
m.ZeroOrMore(),
),
)
)
def remove_internal_optional(
self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Subscript:
slices: List[cst.SubscriptElement] = []
# 1. Iterate over slices, remove optional, and hold the inner slices
for slice in updated_node.slice:
if isinstance(slice.slice, cst.Index):
value = slice.slice.value
if isinstance(value, cst.Name):
slices.append(slice)
elif isinstance(value, cst.Subscript):
slices.extend(value.slice)

# 2. Compute unique slices
unique_slices = []
unique_names: Set[str] = set()
for slice in slices:
index = ensure_type(slice.slice, cst.Index)
if isinstance(index.value, cst.Name):
name = index.value.value
elif isinstance(index.value, cst.Attribute):
name = index.value.attr.value
else:
raise ValueError(f"Unexpected index type: {type(index.value)}")
if name not in unique_names:
unique_slices.append(slice)
unique_names.add(name)

# 3. Send `None` to the end
unique_slices.sort(key=lambda x: is_typing_node(x.slice.value, "None"))

return updated_node.with_changes(slice=unique_slices)

@m.call_if_inside(m.ImportAlias(name=m.Name(value="Optional")))
@m.leave(m.Name(value="Optional"))
def replace_import(
Expand Down
10 changes: 10 additions & 0 deletions no_optional/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Union

import libcst as cst


def is_typing_node(node: Union[cst.Name, cst.Attribute], name: str) -> bool:
if isinstance(node, cst.Name):
return node.value == name
else:
return node.value == "typing" and node.attr.value == name
53 changes: 46 additions & 7 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@pytest.mark.parametrize(
"input,expected",
(
(
pytest.param(
textwrap.dedent(
"""
from typing import Optional
Expand All @@ -28,7 +28,7 @@ def function(a: Union[int, None] = None) -> Union[int, None]:
"""
),
),
(
pytest.param(
textwrap.dedent(
"""
import typing
Expand All @@ -46,7 +46,7 @@ async def function(a: typing.Union[int, None]) -> typing.Union[int, None]:
"""
),
),
(
pytest.param(
textwrap.dedent(
"""
from typing import Optional
Expand All @@ -64,7 +64,25 @@ class Potato:
"""
),
),
(
pytest.param(
textwrap.dedent(
"""
import typing
class Potato:
a: typing.Optional[typing.Union[int, str]]
"""
),
textwrap.dedent(
"""
import typing
class Potato:
a: typing.Union[int, str, None]
"""
),
),
pytest.param(
textwrap.dedent(
"""
a: int = 2
Expand All @@ -76,7 +94,7 @@ class Potato:
"""
),
),
(
pytest.param(
textwrap.dedent(
"""
from typing import List, Optional
Expand All @@ -94,7 +112,7 @@ def function(a: List[Union[int, None]]) -> Union[int, None]:
"""
),
),
(
pytest.param(
textwrap.dedent(
"""
from typing import Dict, Optional
Expand All @@ -112,11 +130,32 @@ def function(a: Dict[str, Union[int, None]]) -> Union[int, None]:
"""
),
),
pytest.param(
textwrap.dedent(
"""
import typing
from typing import Optional, Union
def function(a: Union[A, B, Optional[D], E, typing.Optional[F]] = None):
...
"""
),
textwrap.dedent(
"""
import typing
from typing import Union, Union
def function(a: Union[A, B, D, E, F, None] = None):
...
"""
),
),
),
)
def test_transformer(input: str, expected: str) -> None:
source_tree = cst.parse_module(input)
print(source_tree)
transformer = NoOptionalCommand(CodemodContext())
modified_tree = source_tree.visit(transformer)
assert modified_tree.code == expected

0 comments on commit 9082aa0

Please sign in to comment.