Skip to content

Commit

Permalink
move_on_ and fail_ functions accepts shield kwarg (#3051)
Browse files Browse the repository at this point in the history
* move_on_ and fail_ context managers accepts shield arg

* make it a kwarg

* news rst

* news rst

* better docstring and parametrize test

* undo

* black

* new line

* update news rst to issue number

* no need to explicitly link to docs

---------

Co-authored-by: EXPLOSION <git@helvetica.moe>
  • Loading branch information
agnesnatasya and A5rocks authored Aug 26, 2024
1 parent 7c08af7 commit cd19652
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
3 changes: 1 addition & 2 deletions docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,7 @@ attribute to :data:`True`:
try:
await conn.send_hello_msg()
finally:
with trio.move_on_after(CLEANUP_TIMEOUT) as cleanup_scope:
cleanup_scope.shield = True
with trio.move_on_after(CLEANUP_TIMEOUT, shield=True) as cleanup_scope:
await conn.send_goodbye_msg()
So long as you're inside a scope with ``shield = True`` set, then
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3052.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`trio.move_on_at`, `trio.move_on_after`, `trio.fail_at` and `trio.fail_after` now accept *shield* as a keyword argument. If specified, it provides an initial value for the `~trio.CancelScope.shield` attribute of the `trio.CancelScope` object created by the context manager.
45 changes: 44 additions & 1 deletion src/trio/_tests/test_timeouts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Awaitable, Callable, TypeVar
from typing import Awaitable, Callable, Protocol, TypeVar

import outcome
import pytest
Expand Down Expand Up @@ -75,6 +75,49 @@ async def sleep_3() -> None:
await check_takes_about(sleep_3, TARGET)


class TimeoutScope(Protocol):
def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ...


@pytest.mark.parametrize("scope", [move_on_after, fail_after])
async def test_context_shields_from_outer(scope: TimeoutScope) -> None:
with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner:
outer.cancel()
try:
await trio.lowlevel.checkpoint()
except trio.Cancelled:
pytest.fail("shield didn't work")
inner.shield = False
with pytest.raises(trio.Cancelled):
await trio.lowlevel.checkpoint()


@slow
async def test_move_on_after_moves_on_even_if_shielded() -> None:
async def task() -> None:
with _core.CancelScope() as outer, move_on_after(TARGET, shield=True):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()

await check_takes_about(task, TARGET)


@slow
async def test_fail_after_fails_even_if_shielded() -> None:
async def task() -> None:
with pytest.raises(TooSlowError), _core.CancelScope() as outer, fail_after(
TARGET, shield=True
):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()

await check_takes_about(task, TARGET)


@slow
async def test_fail() -> None:
async def sleep_4() -> None:
Expand Down
26 changes: 18 additions & 8 deletions src/trio/_timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,40 @@
import trio


def move_on_at(deadline: float) -> trio.CancelScope:
def move_on_at(deadline: float, *, shield: bool = False) -> trio.CancelScope:
"""Use as a context manager to create a cancel scope with the given
absolute deadline.
Args:
deadline (float): The deadline.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.
Raises:
ValueError: if deadline is NaN.
"""
if math.isnan(deadline):
raise ValueError("deadline must not be NaN")
return trio.CancelScope(deadline=deadline)
return trio.CancelScope(deadline=deadline, shield=shield)


def move_on_after(seconds: float) -> trio.CancelScope:
def move_on_after(seconds: float, *, shield: bool = False) -> trio.CancelScope:
"""Use as a context manager to create a cancel scope whose deadline is
set to now + *seconds*.
Args:
seconds (float): The timeout.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.
Raises:
ValueError: if timeout is less than zero or NaN.
"""
if seconds < 0:
raise ValueError("timeout must be non-negative")
return move_on_at(trio.current_time() + seconds)
return move_on_at(trio.current_time() + seconds, shield=shield)


async def sleep_forever() -> None:
Expand Down Expand Up @@ -96,7 +100,7 @@ class TooSlowError(Exception):

# workaround for PyCharm not being able to infer return type from @contextmanager
# see https://youtrack.jetbrains.com/issue/PY-36444/PyCharm-doesnt-infer-types-when-using-contextlib.contextmanager-decorator
def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # type: ignore[misc]
def fail_at(deadline: float, *, shield: bool = False) -> AbstractContextManager[trio.CancelScope]: # type: ignore[misc]
"""Creates a cancel scope with the given deadline, and raises an error if it
is actually cancelled.
Expand All @@ -110,14 +114,16 @@ def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # typ
Args:
deadline (float): The deadline.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.
Raises:
TooSlowError: if a :exc:`Cancelled` exception is raised in this scope
and caught by the context manager.
ValueError: if deadline is NaN.
"""
with move_on_at(deadline) as scope:
with move_on_at(deadline, shield=shield) as scope:
yield scope
if scope.cancelled_caught:
raise TooSlowError
Expand All @@ -127,7 +133,9 @@ def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # typ
fail_at = contextmanager(fail_at)


def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]:
def fail_after(
seconds: float, *, shield: bool = False
) -> AbstractContextManager[trio.CancelScope]:
"""Creates a cancel scope with the given timeout, and raises an error if
it is actually cancelled.
Expand All @@ -140,6 +148,8 @@ def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]:
Args:
seconds (float): The timeout.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.
Raises:
TooSlowError: if a :exc:`Cancelled` exception is raised in this scope
Expand All @@ -149,4 +159,4 @@ def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]:
"""
if seconds < 0:
raise ValueError("timeout must be non-negative")
return fail_at(trio.current_time() + seconds)
return fail_at(trio.current_time() + seconds, shield=shield)

0 comments on commit cd19652

Please sign in to comment.