Skip to content

Commit

Permalink
Fix(#83): type for Iter[Iter].flatten() (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff authored Jan 23, 2024
2 parents ce7591d + bebc22c commit c2ea2d3
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 8 deletions.
13 changes: 7 additions & 6 deletions iterpy/_generate_pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class TypeMarker:
base_message_template = """
# TYPE[S] # noqa: ERA001
@overload
def flatten(self: Seq[TYPE[S]]) -> Seq[S]: ...
def flatten(self: Iter[TYPE[S]]) -> Iter[S]: ...
@overload
def flatten(self: Seq[TYPE[S] | S]) -> Seq[S]: ...
def flatten(self: Iter[TYPE[S] | S]) -> Iter[S]: ...
"""

heterogenous_overload = """@overload
def flatten(self: Seq[TYPE[S] | T]) -> Seq[S]: ..."""
def flatten(self: Iter[TYPE[S] | T]) -> Iter[S]: ..."""

combined_interface = f" # Code for generating the following is in {Path(__file__).name}"
for mark in [
Expand All @@ -28,6 +28,7 @@ def flatten(self: Seq[TYPE[S] | T]) -> Seq[S]: ..."""
TypeMarker("list[S]"),
TypeMarker("set[S]"),
TypeMarker("frozenset[S]"),
TypeMarker("Iter[S]"),
]:
message = (
base_message_template.replace(
Expand All @@ -40,15 +41,15 @@ def flatten(self: Seq[TYPE[S] | T]) -> Seq[S]: ..."""
combined_interface += """
# str
@overload
def flatten(self: Seq[str]) -> Seq[str]: ...
def flatten(self: Iter[str]) -> Iter[str]: ...
@overload
def flatten(self: Seq[str | S]) -> Seq[S]: ...
def flatten(self: Iter[str | S]) -> Iter[S]: ...
"""

combined_interface += """
# Generic
@overload
def flatten(self: Seq[S]) -> Seq[S]: ...
def flatten(self: Iter[S]) -> Iter[S]: ...
"""

print(combined_interface)
6 changes: 6 additions & 0 deletions iterpy/_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def _iterator(self) -> Iterator[T]:

return deepcopy(self.__consumable_iterator)

def __iter__(self) -> "Iter[T]":
return self

def __next__(self) -> T:
return next(self._iterator)

def __getitem__(self, index: int | slice) -> T | "Iter[T]":
if isinstance(index, int) and index >= 0:
try:
Expand Down
13 changes: 11 additions & 2 deletions iterpy/_iter.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ from collections.abc import (
Iterator,
Sequence,
)
from types import UnionType
from typing import Generic, TypeVar, overload

S = TypeVar("S")
T = TypeVar("T")
U = TypeVar("U")

iterpyables: TypeVar = (
Iterable | Iterator | tuple | list | set | frozenset # type: ignore
iterpyables: UnionType = (
Iterable | Iterator | "Iter" | tuple | list | set | frozenset
)

class Iter(Generic[T]):
Expand All @@ -24,6 +25,8 @@ class Iter(Generic[T]):
def __getitem__(self, index: int) -> T: ...
@overload
def __getitem__(self, index: slice) -> Iter[T]: ...
def __iter__(self) -> Iter[T]: ...
def __next__(self) -> T: ...
def count(self) -> int: ...
def to_list(self) -> list[T]: ...
def to_tuple(self) -> tuple[T, ...]: ...
Expand Down Expand Up @@ -80,6 +83,12 @@ class Iter(Generic[T]):
@overload
def flatten(self: Iter[frozenset[S] | S]) -> Iter[S]: ...

# Iter[S] # noqa: ERA001
@overload
def flatten(self: Iter[Iter[S]]) -> Iter[S]: ...
@overload
def flatten(self: Iter[Iter[S] | S]) -> Iter[S]: ...

# str
@overload
def flatten(self: Iter[str]) -> Iter[str]: ...
Expand Down
5 changes: 5 additions & 0 deletions iterpy/test_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ def test_flatten_iterator(self):
result: Iter[int] = Iter(test_input).flatten()
assert result.to_list() == [1, 2, 3, 4]

def test_flatten_iter_iter(self):
iterator: Iter[int] = Iter([1, 2])
nested_iter: Iter[Iter[int]] = Iter([iterator])
unnested_iter: Iter[int] = nested_iter.flatten() # noqa: F841, RUF100

def test_flatten_str(self):
test_input: list[str] = ["abcd"]
iterator = Iter(test_input)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ ignore = [
"RET504",
"COM812",
"COM819",
"RUF100",
"W191",
]
ignore-init-module-imports = true
Expand Down

0 comments on commit c2ea2d3

Please sign in to comment.