Skip to content

Commit

Permalink
New APIs to run trio-based code in the subprocess (#20)
Browse files Browse the repository at this point in the history
New APIs to run trio-based code in the subprocess
  • Loading branch information
gsalgado authored May 13, 2020
1 parent 3ad3420 commit d59b5e7
Show file tree
Hide file tree
Showing 15 changed files with 290 additions and 96 deletions.
2 changes: 2 additions & 0 deletions asyncio_run_in_process/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
)
from .run_in_process import ( # noqa: F401
open_in_process,
open_in_process_with_trio,
run_in_process,
run_in_process_with_trio,
)
from .state import ( # noqa: F401
State,
Expand Down
78 changes: 36 additions & 42 deletions asyncio_run_in_process/_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import (
Any,
BinaryIO,
Callable,
Coroutine,
Sequence,
cast,
Expand All @@ -19,11 +18,18 @@
pickle_value,
receive_pickled_value,
)
from .abc import (
TAsyncFn,
TEngineRunner,
)
from .exceptions import (
ChildCancelled,
)
from .state import (
State,
update_state,
update_state_finished,
update_state_initialized,
)
from .typing import (
TReturn,
Expand All @@ -32,23 +38,6 @@
logger = logging.getLogger("asyncio_run_in_process")


def update_state(to_parent: BinaryIO, state: State) -> None:
to_parent.write(state.value.to_bytes(1, 'big'))
to_parent.flush()


def update_state_initialized(to_parent: BinaryIO) -> None:
payload = State.INITIALIZED.value.to_bytes(1, 'big') + os.getpid().to_bytes(4, 'big')
to_parent.write(payload)
to_parent.flush()


def update_state_finished(to_parent: BinaryIO, finished_payload: bytes) -> None:
payload = State.FINISHED.value.to_bytes(1, 'big') + finished_payload
to_parent.write(payload)
to_parent.flush()


SHUTDOWN_SIGNALS = {signal.SIGTERM}


Expand Down Expand Up @@ -119,7 +108,7 @@ async def _handle_coro(coro: Coroutine[Any, Any, TReturn], got_SIGINT: asyncio.E


async def _do_async_fn(
async_fn: Callable[..., Coroutine[Any, Any, TReturn]],
async_fn: TAsyncFn,
args: Sequence[Any],
to_parent: BinaryIO,
loop: asyncio.AbstractEventLoop,
Expand Down Expand Up @@ -178,7 +167,28 @@ async def _do_async_fn(
raise Exception("unreachable")


def _run_process(parent_pid: int, fd_read: int, fd_write: int) -> None:
def _run_on_asyncio(async_fn: TAsyncFn, args: Sequence[Any], to_parent: BinaryIO) -> None:
loop = asyncio.get_event_loop()
try:
result: Any = loop.run_until_complete(_do_async_fn(async_fn, args, to_parent, loop))
except BaseException:
exc_type, exc_value, exc_tb = sys.exc_info()
# `mypy` thinks that `exc_value` and `exc_tb` are `Optional[..]` types
if exc_type is asyncio.CancelledError:
exc_value = ChildCancelled(*exc_value.args) # type: ignore
remote_exc = RemoteException(exc_value, exc_tb) # type: ignore
finished_payload = pickle_value(remote_exc)
raise
else:
finished_payload = pickle_value(result)
finally:
# XXX: The STOPPING state seems useless as nothing happens between that and the FINISHED
# state.
update_state(to_parent, State.STOPPING)
update_state_finished(to_parent, finished_payload)


def run_process(runner: TEngineRunner, parent_pid: int, fd_read: int, fd_write: int) -> None:
"""
Run the child process.
Expand All @@ -199,36 +209,17 @@ def _run_process(parent_pid: int, fd_read: int, fd_write: int) -> None:
# state: BOOTING
update_state(to_parent, State.BOOTING)

loop = asyncio.get_event_loop()

try:
try:
result = loop.run_until_complete(
_do_async_fn(async_fn, args, to_parent, loop),
)
except BaseException:
exc_type, exc_value, exc_tb = sys.exc_info()
# `mypy` thinks that `exc_value` and `exc_tb` are `Optional[..]` types
if exc_type is asyncio.CancelledError:
exc_value = ChildCancelled(*exc_value.args) # type: ignore
remote_exc = RemoteException(exc_value, exc_tb) # type: ignore
finished_payload = pickle_value(remote_exc)
raise
finally:
# state: STOPPING
update_state(to_parent, State.STOPPING)
runner(async_fn, args, to_parent)
except KeyboardInterrupt:
code = 2
except SystemExit as err:
code = err.args[0]
except BaseException:
code = 1
else:
finished_payload = pickle_value(result)
code = 0
finally:
# state: FINISHED
update_state_finished(to_parent, finished_payload)
sys.exit(code)


Expand Down Expand Up @@ -261,6 +252,9 @@ def _run_process(parent_pid: int, fd_read: int, fd_write: int) -> None:

if __name__ == "__main__":
args = parser.parse_args()
_run_process(
parent_pid=args.parent_pid, fd_read=args.fd_read, fd_write=args.fd_write
run_process(
runner=_run_on_asyncio,
parent_pid=args.parent_pid,
fd_read=args.fd_read,
fd_write=args.fd_write,
)
82 changes: 82 additions & 0 deletions asyncio_run_in_process/_child_trio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import signal
from typing import (
Any,
AsyncIterator,
Awaitable,
BinaryIO,
Callable,
Sequence,
)

import trio
import trio_typing

from ._utils import (
pickle_value,
)
from .abc import (
TAsyncFn,
)
from .state import (
State,
update_state,
update_state_finished,
)
from .typing import (
TReturn,
)

SHUTDOWN_SIGNALS = {signal.SIGTERM}


async def _do_monitor_signals(signal_aiter: AsyncIterator[int]) -> None:
async for signum in signal_aiter:
raise SystemExit(signum)


@trio_typing.takes_callable_and_args
async def _do_async_fn(
async_fn: Callable[..., Awaitable[TReturn]],
args: Sequence[Any],
to_parent: BinaryIO,
) -> TReturn:
with trio.open_signal_receiver(*SHUTDOWN_SIGNALS) as signal_aiter:
# state: STARTED
update_state(to_parent, State.STARTED)

async with trio.open_nursery() as nursery:
nursery.start_soon(_do_monitor_signals, signal_aiter)

# state: EXECUTING
update_state(to_parent, State.EXECUTING)

result = await async_fn(*args)

nursery.cancel_scope.cancel()
return result


def _run_on_trio(async_fn: TAsyncFn, args: Sequence[Any], to_parent: BinaryIO) -> None:
try:
result = trio.run(_do_async_fn, async_fn, args, to_parent)
except BaseException as err:
finished_payload = pickle_value(err)
raise
else:
finished_payload = pickle_value(result)
finally:
# XXX: The STOPPING state seems useless as nothing happens between that and the FINISHED
# state.
update_state(to_parent, State.STOPPING)
update_state_finished(to_parent, finished_payload)


if __name__ == "__main__":
from asyncio_run_in_process._child import parser, run_process
args = parser.parse_args()
run_process(
runner=_run_on_trio,
parent_pid=args.parent_pid,
fd_read=args.fd_read,
fd_write=args.fd_write,
)
18 changes: 8 additions & 10 deletions asyncio_run_in_process/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
AsyncIterator,
BinaryIO,
Tuple,
cast,
)

from async_generator import (
Expand All @@ -22,14 +21,17 @@


def get_subprocess_command(
child_r: int, child_w: int, parent_pid: int
child_r: int, child_w: int, parent_pid: int, use_trio: bool,
) -> Tuple[str, ...]:
from . import _child
if use_trio:
from . import _child_trio as child_runner
else:
from . import _child as child_runner # type: ignore

return (
sys.executable,
"-m",
_child.__name__,
child_runner.__name__,
"--parent-pid",
str(parent_pid),
"--fd-read",
Expand Down Expand Up @@ -105,14 +107,10 @@ def cleanup_tasks(*tasks: 'asyncio.Future[Any]') -> AsyncContextManager[None]:
This function **must** be called with at least one task.
"""
return cast(
AsyncContextManager[None],
_cleanup_tasks(*tasks),
)
return _cleanup_tasks(*tasks)


# mypy recognizes this decorator as being untyped.
@asynccontextmanager # type: ignore
@asynccontextmanager
async def _cleanup_tasks(task: 'asyncio.Future[Any]',
*tasks: 'asyncio.Future[Any]',
) -> AsyncIterator[None]:
Expand Down
9 changes: 9 additions & 0 deletions asyncio_run_in_process/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
)
import signal
from typing import (
Any,
BinaryIO,
Callable,
Coroutine,
Generic,
Optional,
Sequence,
TypeVar,
)

from .state import (
Expand All @@ -15,6 +21,9 @@
TReturn,
)

TAsyncFn = TypeVar("TAsyncFn", bound=Callable[..., Coroutine[Any, Any, TReturn]])
TEngineRunner = TypeVar("TEngineRunner", bound=Callable[[TAsyncFn, Sequence[Any], BinaryIO], None])


class ProcessAPI(ABC, Generic[TReturn]):
sub_proc_payload: bytes
Expand Down
35 changes: 27 additions & 8 deletions asyncio_run_in_process/run_in_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
AsyncIterator,
Callable,
Optional,
cast,
)

from async_generator import (
Expand Down Expand Up @@ -185,10 +184,17 @@ def open_in_process(
loop: asyncio.AbstractEventLoop = None,
subprocess_kwargs: 'SubprocessKwargs' = None,
) -> AsyncContextManager[ProcessAPI[TReturn]]:
return cast(
AsyncContextManager[ProcessAPI[TReturn]],
_open_in_process(async_fn, *args, loop=loop, subprocess_kwargs=subprocess_kwargs),
)
return _open_in_process(
async_fn, *args, loop=loop, subprocess_kwargs=subprocess_kwargs, use_trio=False)


def open_in_process_with_trio(
async_fn: Callable[..., TReturn],
*args: Any,
subprocess_kwargs: 'SubprocessKwargs' = None,
) -> AsyncContextManager[ProcessAPI[TReturn]]:
return _open_in_process(
async_fn, *args, loop=None, subprocess_kwargs=subprocess_kwargs, use_trio=True)


def _update_subprocess_kwargs(subprocess_kwargs: Optional['SubprocessKwargs'],
Expand All @@ -211,21 +217,24 @@ def _update_subprocess_kwargs(subprocess_kwargs: Optional['SubprocessKwargs'],
return updated_kwargs


# mypy recognizes this decorator as being untyped.
@asynccontextmanager # type: ignore
@asynccontextmanager
async def _open_in_process(
async_fn: Callable[..., TReturn],
*args: Any,
loop: asyncio.AbstractEventLoop = None,
subprocess_kwargs: 'SubprocessKwargs' = None,
use_trio: bool = False,
) -> AsyncIterator[ProcessAPI[TReturn]]:
if use_trio and loop is not None:
raise ValueError("If using trio, cannot specify a loop")

proc: Process[TReturn] = Process(async_fn, args)

parent_r, child_w = os.pipe()
child_r, parent_w = os.pipe()
parent_pid = os.getpid()

command = get_subprocess_command(child_r, child_w, parent_pid)
command = get_subprocess_command(child_r, child_w, parent_pid, use_trio)

sub_proc = await asyncio.create_subprocess_exec(
*command,
Expand Down Expand Up @@ -346,3 +355,13 @@ async def run_in_process(async_fn: Callable[..., TReturn],
async with proc_ctx as proc:
await proc.wait()
return proc.get_result_or_raise()


async def run_in_process_with_trio(async_fn: Callable[..., TReturn],
*args: Any,
subprocess_kwargs: 'SubprocessKwargs' = None) -> TReturn:
proc_ctx = open_in_process_with_trio(
async_fn, *args, subprocess_kwargs=subprocess_kwargs)
async with proc_ctx as proc:
await proc.wait()
return proc.get_result_or_raise()
Loading

0 comments on commit d59b5e7

Please sign in to comment.