Skip to content

Commit

Permalink
Add pyright to pre-commit (PR #99)
Browse files Browse the repository at this point in the history
  • Loading branch information
vxgmichel authored May 3, 2024
2 parents 1749428 + f3f047c commit 4a2eff6
Show file tree
Hide file tree
Showing 20 changed files with 421 additions and 374 deletions.
14 changes: 11 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,19 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
# We can't use the latest version due to a regression in mypy 1.7.0
# See https://github.com/python/mypy/issues/17191 for more information
rev: v1.6.1
hooks:
- id: mypy
files: ^(?!tests)
types: [python]
- id: mypy
files: ^(?!tests)
types: [python]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.361
hooks:
- id: pyright
additional_dependencies: [pytest, typing-extensions]
types: [python]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.272
Expand Down
10 changes: 8 additions & 2 deletions aiostream/aiter_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for asynchronous iteration."""
from __future__ import annotations

import sys
from types import TracebackType

import warnings
Expand Down Expand Up @@ -103,7 +105,7 @@ def assert_async_iterator(obj: object) -> None:

# Async iterator context

T = TypeVar("T")
T = TypeVar("T", covariant=True)
Self = TypeVar("Self", bound="AsyncIteratorContext[Any]")


Expand Down Expand Up @@ -195,7 +197,11 @@ async def __aexit__(
# Throw
try:
assert isinstance(self._aiterator, AsyncGenerator)
await self._aiterator.athrow(typ, value, traceback)
if sys.version_info >= (3, 12):
assert value is not None
await self._aiterator.athrow(value)
else:
await self._aiterator.athrow(typ, value, traceback)
raise RuntimeError("Async iterator didn't stop after athrow()")

# Exception has been (most probably) silenced
Expand Down
13 changes: 1 addition & 12 deletions aiostream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class StreamEmpty(Exception):

# Helpers

T = TypeVar("T")
T = TypeVar("T", covariant=True)
X = TypeVar("X")
A = TypeVar("A", contravariant=True)
P = ParamSpec("P")
Expand Down Expand Up @@ -349,14 +349,6 @@ async def random(offset=0., width=1.):
"since the decorated function becomes an operator class"
)

# Look for "more_sources"
for i, p in enumerate(parameters):
if p.name == "more_sources" and p.kind == inspect.Parameter.VAR_POSITIONAL:
more_sources_index = i
break
else:
more_sources_index = None

# Injected parameters
self_parameter = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
inspect.Parameter("cls", inspect.Parameter.POSITIONAL_OR_KEYWORD)
Expand All @@ -371,9 +363,6 @@ async def random(offset=0., width=1.):

# Init method
def init(self: BaseStream[T], *args: P.args, **kwargs: P.kwargs) -> None:
if more_sources_index is not None:
for source in args[more_sources_index:]:
assert_async_iterable(source)
factory = functools.partial(raw, *args, **kwargs)
return BaseStream.__init__(self, factory)

Expand Down
9 changes: 5 additions & 4 deletions aiostream/stream/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
TypeVar,
AsyncIterator,
cast,
Type,
)
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Never

from ..stream import time
from ..core import operator, streamcontext
Expand Down Expand Up @@ -121,22 +122,22 @@ async def call(


@operator
async def throw(exc: Exception) -> AsyncIterator[None]:
async def throw(exc: Exception | Type[Exception]) -> AsyncIterator[Never]:
"""Throw an exception without generating any value."""
if False:
yield
raise exc


@operator
async def empty() -> AsyncIterator[None]:
async def empty() -> AsyncIterator[Never]:
"""Terminate without generating any value."""
if False:
yield


@operator
async def never() -> AsyncIterator[None]:
async def never() -> AsyncIterator[Never]:
"""Hang forever without generating any value."""
if False:
yield
Expand Down
2 changes: 1 addition & 1 deletion aiostream/stream/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def filterindex(


@pipable_operator
def slice(source: AsyncIterable[T], *args: int) -> AsyncIterator[T]:
def slice(source: AsyncIterable[T], *args: int | None) -> AsyncIterator[T]:
"""Slice an asynchronous sequence.
The arguments are the same as the builtin type slice.
Expand Down
151 changes: 87 additions & 64 deletions aiostream/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,41 @@
from __future__ import annotations

import asyncio
from unittest.mock import Mock
from contextlib import contextmanager
from unittest.mock import Mock
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
TypeVar,
AsyncIterable,
AsyncIterator,
ContextManager,
Iterator,
)

import pytest

from .core import StreamEmpty, streamcontext, pipable_operator
from typing import TYPE_CHECKING, Any, Callable, List

if TYPE_CHECKING:
from _pytest.fixtures import SubRequest
from aiostream.core import Stream

__all__ = ["add_resource", "assert_run", "event_loop"]
__all__ = ["add_resource", "assert_run", "event_loop_policy", "assert_cleanup"]


T = TypeVar("T")


@pipable_operator
async def add_resource(source, cleanup_time):
async def add_resource(
source: AsyncIterable[T], cleanup_time: float
) -> AsyncIterator[T]:
"""Simulate an open resource in a stream operator."""
try:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
loop.open_resources += 1
loop.resources += 1
async with streamcontext(source) as streamer:
Expand Down Expand Up @@ -89,7 +104,7 @@ def assert_run(request: SubRequest) -> Callable:


@pytest.fixture
def event_loop():
def event_loop_policy() -> TimeTrackingTestLoopPolicy:
"""Fixture providing a test event loop.
The event loop simulate and records the sleep operation,
Expand All @@ -98,76 +113,84 @@ def event_loop():
It also tracks simulated resources and make sure they are
all released before the loop is closed.
"""
return TimeTrackingTestLoopPolicy()

class TimeTrackingTestLoop(asyncio.BaseEventLoop):
stuck_threshold = 100

def __init__(self):
super().__init__()
self._time = 0
self._timers = []
self._selector = Mock()
self.clear()
@pytest.fixture
def assert_cleanup(
event_loop: TimeTrackingTestLoop,
) -> Callable[[], ContextManager[TimeTrackingTestLoop]]:
"""Fixture to assert cleanup of resources."""
return event_loop.assert_cleanup


class TimeTrackingTestLoop(asyncio.BaseEventLoop):
stuck_threshold: int = 100

def __init__(self):
super().__init__()
self._time: float = 0.0
self._timers: list[float] = []
self._selector = Mock()
self.clear()

# Loop internals

def _run_once(self) -> None:
super()._run_once()
# Update internals
self.busy_count += 1
self._timers = sorted(when for when in self._timers if when > self.time())
# Time advance
if self.time_to_go:
when = self._timers.pop(0)
step = when - self.time()
self.steps.append(step)
self.advance_time(step)
self.busy_count = 0

# Loop internals
def _process_events(self, event_list) -> None:
return

def _run_once(self):
super()._run_once()
# Update internals
self.busy_count += 1
self._timers = sorted(when for when in self._timers if when > loop.time())
# Time advance
if self.time_to_go:
when = self._timers.pop(0)
step = when - loop.time()
self.steps.append(step)
self.advance_time(step)
self.busy_count = 0
def _write_to_self(self) -> None:
return

def _process_events(self, event_list):
return
# Time management

def _write_to_self(self):
return
def time(self) -> float:
return self._time

# Time management
def advance_time(self, advance: float) -> None:
if advance:
self._time += advance

def time(self):
return self._time
def call_at(self, when, callback, *args, **kwargs):
self._timers.append(when)
return super().call_at(when, callback, *args, **kwargs)

def advance_time(self, advance):
if advance:
self._time += advance
@property
def stuck(self) -> bool:
return self.busy_count > self.stuck_threshold

def call_at(self, when, callback, *args, **kwargs):
self._timers.append(when)
return super().call_at(when, callback, *args, **kwargs)
@property
def time_to_go(self) -> bool:
return self._timers and (self.stuck or not self._ready)

@property
def stuck(self):
return self.busy_count > self.stuck_threshold
# Resource management

@property
def time_to_go(self):
return self._timers and (self.stuck or not self._ready)
def clear(self) -> None:
self.steps = []
self.open_resources = 0
self.resources = 0
self.busy_count = 0

# Resource management
@contextmanager
def assert_cleanup(self) -> Iterator[TimeTrackingTestLoop]:
self.clear()
yield self
assert self.open_resources == 0
self.clear()

def clear(self):
self.steps = []
self.open_resources = 0
self.resources = 0
self.busy_count = 0

@contextmanager
def assert_cleanup(self):
self.clear()
yield self
assert self.open_resources == 0
self.clear()

loop = TimeTrackingTestLoop()
asyncio.set_event_loop(loop)
with loop.assert_cleanup():
yield loop
loop.close()
class TimeTrackingTestLoopPolicy(asyncio.DefaultEventLoopPolicy):
_loop_factory = TimeTrackingTestLoop
5 changes: 4 additions & 1 deletion examples/norm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def strip(x: bytes, *_: object) -> str:
def nonempty(x: str) -> bool:
return x != ""

def to_float(x: str, *_: object) -> float:
return float(x)

def square(x: float, *_: object) -> float:
return x**2

Expand All @@ -66,7 +69,7 @@ def square_root(x: float, *_: object) -> float:
| pipe.print("string: {}")
| pipe.map(strip)
| pipe.takewhile(nonempty)
| pipe.map(float)
| pipe.map(to_float)
| pipe.map(square)
| pipe.print("square: {:.2f}")
| pipe.action(write_cursor)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.pyright]
ignore = ["aiostream/test_utils.py"]
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from aiostream.test_utils import (
add_resource,
assert_run,
assert_cleanup,
event_loop_policy,
)

__all__ = [
"add_resource",
"assert_run",
"assert_cleanup",
"event_loop_policy",
]
Loading

0 comments on commit 4a2eff6

Please sign in to comment.