Skip to content

Commit

Permalink
Pass MyPy in strict mode
Browse files Browse the repository at this point in the history
  • Loading branch information
aiudirog committed Jun 23, 2021
1 parent e7670ab commit e2351fc
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 174 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,26 @@ on:
push:

jobs:
flake8:
name: Flake8
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set Up Python
uses: actions/setup-python@v1
with:
python-version: '3.8'
python-version: '3.9'
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install flake8
- name: Check
python -m pip install flake8 mypy pytest
- name: Flake8
run: python -m flake8
- name: MyPy
run: python -m mypy aiuti --strict
test:
name: Test
needs: flake8
needs: lint
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand Down
3 changes: 1 addition & 2 deletions aiuti/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from ._version import get_versions
__version__ = get_versions()['version']
__version__ = get_versions()['version'] # type: ignore
del get_versions
129 changes: 78 additions & 51 deletions aiuti/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,25 @@
from functools import partial, wraps
from concurrent.futures import ThreadPoolExecutor
from typing import (
Iterator, Iterable, AsyncIterable, Callable, Set, Awaitable, Optional,
Type, Generic,
Iterator, Iterable, AsyncIterable, Callable, Set, Awaitable, Optional, Any,
Type, Generic, Dict, DefaultDict, Union, TypeVar, FrozenSet, overload,
)

from .typing import T, Yields, AYields

E = TypeVar('E', bound=BaseException)
X = TypeVar('X')
Loop = aio.AbstractEventLoop

logger = logging.getLogger(__name__)

_DONE = object()


async def gather_excs(aws: Iterable[Awaitable],
only: Type[T] = BaseException) -> AYields[T]:
async def gather_excs(
aws: Iterable[Awaitable[Any]],
only: Type[E] = BaseException, # type: ignore # MyPy bug
) -> AYields[E]:
"""
Gather the given awaitables and yield any exceptions they raise.
Expand Down Expand Up @@ -86,8 +90,8 @@ async def gather_excs(aws: Iterable[Awaitable],
yield res


async def raise_first_exc(aws: Iterable[Awaitable],
only: Type[T] = BaseException):
async def raise_first_exc(aws: Iterable[Awaitable[Any]],
only: Type[BaseException] = BaseException) -> None:
"""
Gather the given awaitables using :func:`gather_excs` and raise the
first exception encountered.
Expand Down Expand Up @@ -168,14 +172,14 @@ async def to_async_iter(iterable: Iterable[T]) -> AYields[T]:
yield x
return

def _queue_elements():
def _queue_elements() -> None:
try:
for x in iterable:
put(x)
finally:
put(_DONE)

q = aio.Queue()
q: 'aio.Queue[Union[T, object]]' = aio.Queue()
loop = aio.get_event_loop()
put = partial(loop.call_soon_threadsafe, q.put_nowait)
with ThreadPoolExecutor(1) as pool:
Expand All @@ -184,11 +188,13 @@ def _queue_elements():
i = await q.get()
if i is _DONE:
break
yield i
yield i # type: ignore
await future # Bubble any errors


def to_sync_iter(iterable: AsyncIterable[T], *, loop: Loop = None) -> Yields[T]:
def to_sync_iter(iterable: AsyncIterable[T],
*,
loop: Optional[Loop] = None) -> Yields[T]:
"""
Convert the given iterable from asynchronous to synchrounous by
by using a background thread running a new event loop to iterate it.
Expand Down Expand Up @@ -224,25 +230,24 @@ def to_sync_iter(iterable: AsyncIterable[T], *, loop: Loop = None) -> Yields[T]:
:param iterable: Asynchonrous iterable to process
:param loop: Optional specific loop to use to process the iterable
"""
if loop is None:
loop = aio.new_event_loop()

def _set_loop_and_queue_elements():
aio.set_event_loop(loop)
return loop.run_until_complete(_queue_elements())
def _set_loop_and_queue_elements(_loop: Loop) -> None:
aio.set_event_loop(_loop)
_loop.run_until_complete(_queue_elements())

async def _queue_elements():
async def _queue_elements() -> None:
try:
async for x in iterable:
put(x)
finally:
put(_DONE)
put(_DONE) # type: ignore

if loop is None:
loop = aio.new_event_loop()

q = queue.Queue()
q: 'queue.Queue[T]' = queue.Queue()
put = q.put_nowait
with ThreadPoolExecutor(1) as pool:
future = pool.submit(_set_loop_and_queue_elements)
future = pool.submit(_set_loop_and_queue_elements, loop)
try:
while True:
i = q.get()
Expand All @@ -253,7 +258,9 @@ async def _queue_elements():
future.result()


def threadsafe_async_cache(func: T) -> T:
def threadsafe_async_cache(
func: Callable[..., Awaitable[T]],
) -> Callable[..., Awaitable[T]]:
"""
Simple thread-safe asynchronous cache decorator which ensures that
the decorated/wrapped function is only ever called once for a given
Expand Down Expand Up @@ -300,12 +307,12 @@ def threadsafe_async_cache(func: T) -> T:
per-instance basis during intilaztion so that the cache only
lives as long as the object.
"""
cache = {}

locks = defaultdict(Lock) # 1 lock per input key
cache: Dict[FrozenSet[Any], T] = {}
# 1 lock per input key
locks: DefaultDict[FrozenSet[Any], Lock] = defaultdict(Lock)

@wraps(func)
async def _wrapper(*args, **kwargs):
async def _wrapper(*args: Any, **kwargs: Any) -> T:
# Get the key from the input arguments
key = frozenset(args + tuple(kwargs.values()))
while True:
Expand Down Expand Up @@ -336,9 +343,32 @@ async def _wrapper(*args, **kwargs):
return _wrapper


def buffer_until_timeout(func: Callable[[Set[T]], Awaitable[None]] = None,
*,
timeout: float = 1) -> 'BufferAsyncCalls[T]':
_BufferFunc = Callable[[Set[T]], Awaitable[None]]


@overload
def buffer_until_timeout(
*,
timeout: float = 1,
) -> 'Callable[[_BufferFunc[T]], BufferAsyncCalls[T]]':
...


@overload
def buffer_until_timeout(
func: _BufferFunc[T],
*,
timeout: float = 1,
) -> 'BufferAsyncCalls[T]':
...


def buffer_until_timeout(
func: Optional[_BufferFunc[T]] = None,
*,
timeout: float = 1,
) -> Union['BufferAsyncCalls[T]',
'Callable[[_BufferFunc[T]], BufferAsyncCalls[T]]']:
"""
Async function decorator/wrapper which buffers the arg passed in
each call. After a given timeout has passed since the last call,
Expand Down Expand Up @@ -376,7 +406,7 @@ def buffer_until_timeout(func: Callable[[Set[T]], Awaitable[None]] = None,
:return: Non-async function which will buffer the given argument
"""
if func is None:
return partial(buffer_until_timeout, timeout=timeout)
return partial(buffer_until_timeout, timeout=timeout) # type: ignore
return wraps(func)(BufferAsyncCalls(func, timeout=timeout))


Expand All @@ -392,10 +422,7 @@ class BufferAsyncCalls(Generic[T]):
for it to work.
"""

def __init__(self,
func: Callable[[Set[T]], Awaitable[None]],
*,
timeout: float = 1):
def __init__(self, func: _BufferFunc[T], *, timeout: float = 1):
#: Wrapped function which should take a set and will be called
#: once inputs are buffered and timeout is reached
self.func = func
Expand Down Expand Up @@ -424,16 +451,16 @@ def __init__(self,
name=f"Buffering {self.func!r}",
)
#: Current task that is waiting for a new element from the queue
self._getting: Optional[aio.Task] = None
self._getting: Optional['aio.Task[AsyncIterable[T]]'] = None

def __call__(self, _arg: T):
def __call__(self, _arg: T) -> None:
"""
Place the given argument on the queue to be processed on the
next execution of the function.
"""
self._put(_obj_to_aiter(_arg))

def await_(self, _arg: Awaitable[T]):
def await_(self, _arg: Awaitable[T]) -> None:
"""
Schedule the given awaitable to be be put onto the queue after
it has been awaited.
Expand All @@ -454,7 +481,7 @@ def await_(self, _arg: Awaitable[T]):
"""
self._put(_awaitable_to_aiter(_arg))

def map(self, _args: Iterable[T]):
def map(self, _args: Iterable[T]) -> None:
"""
Place an iterable of args onto the queue to be processed.
Expand All @@ -470,7 +497,7 @@ def map(self, _args: Iterable[T]):
"""
self._put(to_async_iter(_args))

def amap(self, _args: AsyncIterable[T]):
def amap(self, _args: AsyncIterable[T]) -> None:
"""
Schedule an async iterable of args to be put onto the queue.
Expand All @@ -486,7 +513,7 @@ def amap(self, _args: AsyncIterable[T]):
"""
self._put(_args)

async def wait(self, *, cancel: bool = True):
async def wait(self, *, cancel: bool = True) -> None:
"""
Wait for the event to be set indicating that all arguments
currently in the queue have been processed.
Expand All @@ -513,29 +540,29 @@ async def wait(self, *, cancel: bool = True):
# Wait for the function to finish processing
await self.event.wait()

async def wait_from_anywhere(self, *, cancel: bool = True):
async def wait_from_anywhere(self, *, cancel: bool = True) -> None:
"""
Wrapper around :meth:`wait` which uses :func:`ensure_aw` to
handle waiting from a possibly different event loop.
"""
return await ensure_aw(self.wait(cancel=cancel), self.loop)

async def _waiter(self):
async def _waiter(self) -> None:
"""
Simple loop which tries to process the queue infinitely using
:meth:`_process_queue`. This is spawned automatically upon
"""
while True:
await self._process_queue()

async def _process_queue(self):
async def _process_queue(self) -> None:
"""
Internal method to retrieve all current elements from queue and
execute the function on timeout or cancellation.
"""
inputs: Set[T] = set()

async def _load_inputs(iterable: AsyncIterable[T]):
async def _load_inputs(iterable: AsyncIterable[T]) -> None:
try:
async for i in iterable:
inputs.add(i)
Expand Down Expand Up @@ -570,7 +597,7 @@ async def _load_inputs(iterable: AsyncIterable[T]):
else:
self.q.task_done()

async def _run_func(self, inputs: Set[T]):
async def _run_func(self, inputs: Set[T]) -> None:
"""
Run :attr:`func` with the given set of inputs and set
:attr:`event` once it has finished successfully.
Expand All @@ -587,14 +614,14 @@ async def _run_func(self, inputs: Set[T]):
else:
self.event.set()

def _schedule_with_timeout(self, coro: Awaitable) -> aio.Task:
def _schedule_with_timeout(self, coro: Awaitable[X]) -> 'aio.Task[X]':
"""
Helper method to create a task for the given coroutine with the
configured :attr:`timeout`.
"""
return self.loop.create_task(aio.wait_for(coro, self.timeout))

def _put(self, iterable: AsyncIterable[T]):
def _put(self, iterable: AsyncIterable[T]) -> None:
"""
Helper method to put an async iterable onto the queue and clear
the event.
Expand Down Expand Up @@ -713,13 +740,13 @@ def loop_in_thread(loop: Loop) -> Callable[[], None]:
"""
stop = aio.Event()

async def _spin():
async def _spin() -> None:
aio.set_event_loop(loop)
await stop.wait()

Thread(target=loop.run_until_complete, args=(_spin(),), daemon=True).start()

def _stopper(): loop.call_soon_threadsafe(stop.set)
def _stopper() -> None: loop.call_soon_threadsafe(stop.set)

return _stopper

Expand Down Expand Up @@ -758,18 +785,18 @@ async def _awaitable_to_aiter(o: Awaitable[T]) -> AsyncIterable[T]:
yield await o


class DaemonTask(aio.Task):
class DaemonTask(aio.Task): # type: ignore
"""
Custom :class:`asyncio.Task` which is meant to run forever and
therefore doesn't warn when it is still pending at loop shutdown.
"""

if sys.version_info <= (3, 8):
if sys.version_info < (3, 8):

# Ignore name arg when it isn't available for compatibility
def __init__(self, coro, *, loop=None, name=None):
super().__init__(coro, loop=loop)

# Skip the __del__ defined by aio.Task which does the logging and
# then calls super()
__del__ = aio.Task.__base__.__del__ # noqa
__del__ = aio.Task.__base__.__del__ # type: ignore
Loading

0 comments on commit e2351fc

Please sign in to comment.