Skip to content

Commit

Permalink
disallow untyped defs
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed May 12, 2021
1 parent e32999b commit a14d442
Show file tree
Hide file tree
Showing 30 changed files with 342 additions and 280 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ pytest11 =

[mypy]
ignore_missing_imports = true
disallow_untyped_defs = true
109 changes: 55 additions & 54 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,11 @@ def _maybe_set_event_loop_policy(policy: Optional[asyncio.AbstractEventLoopPolic
asyncio.set_event_loop_policy(policy)


def run(func: Callable[..., T_Retval], *args, debug: bool = False, use_uvloop: bool = True,
def run(func: Callable[..., Awaitable[T_Retval]], *args: object, debug: bool = False, use_uvloop: bool = True,
policy: Optional[asyncio.AbstractEventLoopPolicy] = None) -> T_Retval:
@wraps(func)
async def wrapper():
task = current_task()
async def wrapper() -> T_Retval:
task = cast(asyncio.Task, current_task())
task_state = TaskState(None, get_callable_name(func), None)
_task_states[task] = task_state
if _native_task_names:
Expand Down Expand Up @@ -247,7 +247,7 @@ async def wrapper():


class CancelScope(BaseCancelScope):
def __new__(cls, *, deadline: float = math.inf, shield: bool = False):
def __new__(cls, *, deadline: float = math.inf, shield: bool = False) -> "CancelScope":
return object.__new__(cls)

def __init__(self, deadline: float = math.inf, shield: bool = False):
Expand All @@ -262,20 +262,20 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._host_task: Optional[asyncio.Task] = None
self._timeout_expired = False

def __enter__(self):
def __enter__(self) -> "CancelScope":
if self._active:
raise RuntimeError(
"Each CancelScope may only be used for a single 'with' block"
)

self._host_task = current_task()
self._tasks.add(self._host_task)
self._host_task = host_task = cast(asyncio.Task, current_task())
self._tasks.add(host_task)
try:
task_state = _task_states[self._host_task]
task_state = _task_states[host_task]
except KeyError:
task_name = self._host_task.get_name() if _native_task_names else None
task_name = host_task.get_name() if _native_task_names else None
task_state = TaskState(None, task_name, self)
_task_states[self._host_task] = task_state
_task_states[host_task] = task_state
else:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
Expand Down Expand Up @@ -326,7 +326,7 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba

return None

def _timeout(self):
def _timeout(self) -> None:
if self._deadline != math.inf:
loop = get_running_loop()
if loop.time() >= self._deadline:
Expand Down Expand Up @@ -460,9 +460,9 @@ async def cancel_shielded_checkpoint() -> None:
await sleep(0)


def current_effective_deadline():
def current_effective_deadline() -> float:
try:
cancel_scope = _task_states[current_task()].cancel_scope
cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index]
except KeyError:
return math.inf

Expand All @@ -477,7 +477,7 @@ def current_effective_deadline():
return deadline


def current_time():
def current_time() -> float:
return get_running_loop().time()


Expand Down Expand Up @@ -517,17 +517,17 @@ class _AsyncioTaskStatus(abc.TaskStatus):
def __init__(self, future: asyncio.Future):
self._future = future

def started(self, value=None) -> None:
def started(self, value: object = None) -> None:
self._future.set_result(value)


class TaskGroup(abc.TaskGroup):
def __init__(self):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
self._active = False
self._exceptions: List[BaseException] = []

async def __aenter__(self):
async def __aenter__(self) -> "TaskGroup":
self.cancel_scope.__enter__()
self._active = True
return self
Expand Down Expand Up @@ -613,7 +613,7 @@ async def _run_wrapped_task(
self.cancel_scope._tasks.remove(task)
del _task_states[task]

def _spawn(self, func: Callable[..., Coroutine], args: tuple, name,
def _spawn(self, func: Callable[..., Coroutine], args: tuple, name: Optional[str],
task_status_future: Optional[asyncio.Future] = None) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
# This is the code path for Python 3.8+
Expand Down Expand Up @@ -669,10 +669,10 @@ def task_done(_task: asyncio.Task) -> None:
self.cancel_scope._tasks.add(task)
return task

def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None:
def start_soon(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> None:
self._spawn(func, args, name)

async def start(self, func: Callable[..., Coroutine], *args, name=None) -> None:
async def start(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> None:
future: asyncio.Future = asyncio.Future()
task = self._spawn(func, args, name, future)

Expand Down Expand Up @@ -745,7 +745,7 @@ def stop(self, f: Optional[asyncio.Task] = None) -> None:


async def run_sync_in_worker_thread(
func: Callable[..., T_Retval], *args, cancellable: bool = False,
func: Callable[..., T_Retval], *args: object, cancellable: bool = False,
limiter: Optional['CapacityLimiter'] = None) -> T_Retval:
await checkpoint()

Expand Down Expand Up @@ -789,10 +789,10 @@ async def run_sync_in_worker_thread(
idle_workers.append(worker)


def run_sync_from_thread(func: Callable[..., T_Retval], *args,
def run_sync_from_thread(func: Callable[..., T_Retval], *args: object,
loop: Optional[asyncio.AbstractEventLoop] = None) -> T_Retval:
@wraps(func)
def wrapper():
def wrapper() -> None:
try:
f.set_result(func(*args))
except BaseException as exc:
Expand All @@ -806,22 +806,22 @@ def wrapper():
return f.result()


def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval:
def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval:
f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
func(*args), threadlocals.loop)
return f.result()


class BlockingPortal(abc.BlockingPortal):
def __new__(cls):
def __new__(cls) -> "BlockingPortal":
return object.__new__(cls)

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._loop = get_running_loop()

def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any],
name, future: Future) -> None:
name: Optional[str], future: Future) -> None:
run_sync_from_thread(
partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs,
future, loop=self._loop)
Expand Down Expand Up @@ -908,12 +908,12 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]:
return self._stderr


async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int,
async def open_process(command: Union[str, Sequence[str]], *, shell: bool, stdin: int, stdout: int, stderr: int,
cwd: Union[str, bytes, PathLike, None] = None,
env: Optional[Mapping[str, str]] = None) -> Process:
await checkpoint()
if shell:
process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout,
process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout, # type: ignore[arg-type]
stderr=stderr, cwd=cwd, env=env)
else:
process = await asyncio.create_subprocess_exec(*command, stdin=stdin, stdout=stdout,
Expand All @@ -925,7 +925,7 @@ async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr:
return Process(process, stdin_stream, stdout_stream, stderr_stream)


def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task) -> None:
def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task: object) -> None:
"""
Forcibly shuts down worker processes belonging to this event loop."""
child_watcher: Optional[asyncio.AbstractChildWatcher]
Expand Down Expand Up @@ -1142,7 +1142,7 @@ def _raw_socket(self) -> SocketType:
return self.__raw_socket

def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
def callback(f):
def callback(f: object) -> None:
del self._receive_future
loop.remove_reader(self.__raw_socket)

Expand All @@ -1152,7 +1152,7 @@ def callback(f):
return f

def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
def callback(f):
def callback(f: object) -> None:
del self._send_future
loop.remove_writer(self.__raw_socket)

Expand Down Expand Up @@ -1594,10 +1594,10 @@ async def wait_socket_writable(sock: socket.SocketType) -> None:
#

class Event(BaseEvent):
def __new__(cls):
def __new__(cls) -> "Event":
return object.__new__(cls)

def __init__(self):
def __init__(self) -> None:
self._event = asyncio.Event()

def set(self) -> DeprecatedAwaitable:
Expand All @@ -1607,26 +1607,26 @@ def set(self) -> DeprecatedAwaitable:
def is_set(self) -> bool:
return self._event.is_set()

async def wait(self):
async def wait(self) -> None:
if await self._event.wait():
await checkpoint()

def statistics(self) -> EventStatistics:
return EventStatistics(len(self._event._waiters))
return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined]


class CapacityLimiter(BaseCapacityLimiter):
_total_tokens: float = 0

def __new__(cls, total_tokens: float):
def __new__(cls, total_tokens: float) -> "CapacityLimiter":
return object.__new__(cls)

def __init__(self, total_tokens: float):
self._borrowers: Set[Any] = set()
self._wait_queue: Dict[Any, asyncio.Event] = OrderedDict()
self.total_tokens = total_tokens

async def __aenter__(self):
async def __aenter__(self) -> None:
await self.acquire()

async def __aexit__(self, exc_type: Optional[Type[BaseException]],
Expand Down Expand Up @@ -1671,7 +1671,7 @@ def acquire_nowait(self) -> DeprecatedAwaitable:
self.acquire_on_behalf_of_nowait(current_task())
return DeprecatedAwaitable(self.acquire_nowait)

def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable:
def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
if borrower in self._borrowers:
raise RuntimeError("this borrower is already holding one of this CapacityLimiter's "
"tokens")
Expand All @@ -1685,7 +1685,7 @@ def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable:
async def acquire(self) -> None:
return await self.acquire_on_behalf_of(current_task())

async def acquire_on_behalf_of(self, borrower) -> None:
async def acquire_on_behalf_of(self, borrower: object) -> None:
await checkpoint_if_cancelled()
try:
self.acquire_on_behalf_of_nowait(borrower)
Expand All @@ -1705,7 +1705,7 @@ async def acquire_on_behalf_of(self, borrower) -> None:
def release(self) -> None:
self.release_on_behalf_of(current_task())

def release_on_behalf_of(self, borrower) -> None:
def release_on_behalf_of(self, borrower: object) -> None:
try:
self._borrowers.remove(borrower)
except KeyError:
Expand All @@ -1725,7 +1725,7 @@ def statistics(self) -> CapacityLimiterStatistics:
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar('_default_thread_limiter')


def current_default_thread_limiter():
def current_default_thread_limiter() -> CapacityLimiter:
try:
return _default_thread_limiter.get()
except LookupError:
Expand All @@ -1751,18 +1751,19 @@ def _deliver(self, signum: int) -> None:
if not self._future.done():
self._future.set_result(None)

def __enter__(self):
def __enter__(self) -> "_SignalReceiver":
for sig in set(self._signals):
self._loop.add_signal_handler(sig, self._deliver, sig)
self._handled_signals.add(sig)

return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]:
for sig in self._handled_signals:
self._loop.remove_signal_handler(sig)
return None

def __aiter__(self):
def __aiter__(self) -> "_SignalReceiver":
return self

async def __anext__(self) -> int:
Expand Down Expand Up @@ -1825,7 +1826,7 @@ def __init__(self, debug: bool = False, use_uvloop: bool = True,
self._loop.set_debug(debug)
asyncio.set_event_loop(self._loop)

def _cancel_all_tasks(self):
def _cancel_all_tasks(self) -> None:
to_cancel = all_tasks(self._loop)
if not to_cancel:
return
Expand All @@ -1840,7 +1841,7 @@ def _cancel_all_tasks(self):
if task.cancelled():
continue
if task.exception() is not None:
raise task.exception()
raise cast(BaseException, task.exception())

def close(self) -> None:
try:
Expand All @@ -1850,23 +1851,23 @@ def close(self) -> None:
asyncio.set_event_loop(None)
self._loop.close()

def call(self, func: Callable[..., Awaitable], *args, **kwargs):
def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object) -> T_Retval:
def exception_handler(loop: asyncio.AbstractEventLoop, context: Dict[str, Any]) -> None:
exceptions.append(context['exception'])

exceptions: List[Exception] = []
self._loop.set_exception_handler(exception_handler)
try:
retval = self._loop.run_until_complete(func(*args, **kwargs))
retval: T_Retval = self._loop.run_until_complete(func(*args, **kwargs))
except Exception as exc:
retval = None
retval = None # type: ignore[assignment]
exceptions.append(exc)
finally:
self._loop.set_exception_handler(None)

if len(exceptions) == 1:
raise exceptions[0]
elif exceptions:
raise ExceptionGroup(exceptions)
if len(exceptions) == 1:
raise exceptions[0]
elif exceptions:
raise ExceptionGroup(exceptions)

return retval
Loading

0 comments on commit a14d442

Please sign in to comment.