Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add run_sync and ensure_async functions #315

Merged
merged 2 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions jupyter_core/tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Tests for async helper functions"""

# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio

from jupyter_core.utils import ensure_async, run_sync


async def afunc():
return "afunc"


def func():
return "func"


sync_afunc = run_sync(afunc)


def test_ensure_async():
async def main():
assert await ensure_async(afunc()) == "afunc"
assert await ensure_async(func()) == "func"

asyncio.run(main())


def test_run_sync():
async def main():
assert sync_afunc() == "afunc"

asyncio.run(main())
97 changes: 97 additions & 0 deletions jupyter_core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio
import atexit
import errno
import inspect
import os
import sys
import threading
import warnings
from pathlib import Path
from typing import Any, Awaitable, Callable, Optional, TypeVar, Union


def ensure_dir_exists(path, mode=0o777):
Expand Down Expand Up @@ -81,3 +85,96 @@ def deprecation(message, internal="jupyter_core/"):

# The call to .warn adds one frame, so bump the stacklevel up by one
warnings.warn(message, DeprecationWarning, stacklevel=stacklevel + 1)


T = TypeVar("T")


class _TaskRunner:
"""A task runner that runs an asyncio event loop on a background thread."""

def __init__(self):
self.__io_loop: Optional[asyncio.AbstractEventLoop] = None
self.__runner_thread: Optional[threading.Thread] = None
self.__lock = threading.Lock()
atexit.register(self._close)

def _close(self):
if self.__io_loop:
self.__io_loop.stop()

def _runner(self):
loop = self.__io_loop
assert loop is not None
try:
loop.run_forever()
finally:
loop.close()

def run(self, coro):
"""Synchronously run a coroutine on a background thread."""
with self.__lock:
name = f"{threading.current_thread().name} - runner"
if self.__io_loop is None:
self.__io_loop = asyncio.new_event_loop()
self.__runner_thread = threading.Thread(target=self._runner, daemon=True, name=name)
self.__runner_thread.start()
fut = asyncio.run_coroutine_threadsafe(coro, self.__io_loop)
return fut.result(None)


_runner_map = {}
_loop_map = {}


def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
"""Runs a coroutine and blocks until it has executed.

Parameters
----------
coro : coroutine
The coroutine to be executed.
Returns
-------
result :
Whatever the coroutine returns.
"""

def wrapped(*args, **kwargs):
name = threading.current_thread().name
inner = coro(*args, **kwargs)
try:
# If a loop is currently running in this thread,
# use a task runner.
asyncio.get_running_loop()
if name not in _runner_map:
_runner_map[name] = _TaskRunner()
return _runner_map[name].run(inner)
except RuntimeError:
pass

# Run the loop for this thread.
if name not in _loop_map:
_loop_map[name] = asyncio.new_event_loop()
loop = _loop_map[name]
return loop.run_until_complete(inner)

wrapped.__doc__ = coro.__doc__
return wrapped


async def ensure_async(obj: Union[Awaitable[Any], Any]) -> Any:
"""Convert a non-awaitable object to a coroutine if needed,
and await it if it was not already awaited.
"""
if inspect.isawaitable(obj):
try:
result = await obj
except RuntimeError as e:
if str(e) == "cannot reuse already awaited coroutine":
# obj is already the coroutine's result
return obj
raise
return result
# obj doesn't need to be awaited
return obj