Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move_on_ and fail_ functions accepts shield kwarg #3051

Merged
merged 11 commits into from
Aug 26, 2024
3 changes: 1 addition & 2 deletions docs/source/reference-core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,7 @@ attribute to :data:`True`:
try:
await conn.send_hello_msg()
finally:
with trio.move_on_after(CLEANUP_TIMEOUT) as cleanup_scope:
cleanup_scope.shield = True
with trio.move_on_after(CLEANUP_TIMEOUT, shield=True) as cleanup_scope:
await conn.send_goodbye_msg()

So long as you're inside a scope with ``shield = True`` set, then
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3052.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`trio.move_on_at`, `trio.move_on_after`, `trio.fail_at` and `trio.fail_after` now accept *shield* as a keyword argument. If specified, it provides an initial value for the `~trio.CancelScope.shield` attribute of the `trio.CancelScope` object created by the context manager.
45 changes: 44 additions & 1 deletion src/trio/_tests/test_timeouts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Awaitable, Callable, TypeVar
from typing import Awaitable, Callable, Protocol, TypeVar

import outcome
import pytest
Expand Down Expand Up @@ -75,6 +75,49 @@
await check_takes_about(sleep_3, TARGET)


class TimeoutScope(Protocol):
def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ...


@pytest.mark.parametrize("scope", [move_on_after, fail_after])
async def test_context_shields_from_outer(scope: TimeoutScope) -> None:
with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner:
outer.cancel()
try:
await trio.lowlevel.checkpoint()
except trio.Cancelled:
pytest.fail("shield didn't work")

Check warning on line 89 in src/trio/_tests/test_timeouts.py

View check run for this annotation

Codecov / codecov/patch

src/trio/_tests/test_timeouts.py#L88-L89

Added lines #L88 - L89 were not covered by tests
inner.shield = False
with pytest.raises(trio.Cancelled):
await trio.lowlevel.checkpoint()


@slow
async def test_move_on_after_moves_on_even_if_shielded() -> None:
async def task() -> None:
with _core.CancelScope() as outer, move_on_after(TARGET, shield=True):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()

await check_takes_about(task, TARGET)


@slow
async def test_fail_after_fails_even_if_shielded() -> None:
async def task() -> None:
with pytest.raises(TooSlowError), _core.CancelScope() as outer, fail_after(
TARGET, shield=True
):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()

await check_takes_about(task, TARGET)


@slow
async def test_fail() -> None:
async def sleep_4() -> None:
Expand Down
26 changes: 18 additions & 8 deletions src/trio/_timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,40 @@
import trio


def move_on_at(deadline: float) -> trio.CancelScope:
def move_on_at(deadline: float, *, shield: bool = False) -> trio.CancelScope:
"""Use as a context manager to create a cancel scope with the given
absolute deadline.

Args:
deadline (float): The deadline.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.

Raises:
ValueError: if deadline is NaN.

"""
if math.isnan(deadline):
raise ValueError("deadline must not be NaN")
return trio.CancelScope(deadline=deadline)
return trio.CancelScope(deadline=deadline, shield=shield)


def move_on_after(seconds: float) -> trio.CancelScope:
def move_on_after(seconds: float, *, shield: bool = False) -> trio.CancelScope:
"""Use as a context manager to create a cancel scope whose deadline is
set to now + *seconds*.

Args:
seconds (float): The timeout.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.

Raises:
ValueError: if timeout is less than zero or NaN.

"""
if seconds < 0:
raise ValueError("timeout must be non-negative")
return move_on_at(trio.current_time() + seconds)
return move_on_at(trio.current_time() + seconds, shield=shield)


async def sleep_forever() -> None:
Expand Down Expand Up @@ -96,7 +100,7 @@ class TooSlowError(Exception):

# workaround for PyCharm not being able to infer return type from @contextmanager
# see https://youtrack.jetbrains.com/issue/PY-36444/PyCharm-doesnt-infer-types-when-using-contextlib.contextmanager-decorator
def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # type: ignore[misc]
def fail_at(deadline: float, *, shield: bool = False) -> AbstractContextManager[trio.CancelScope]: # type: ignore[misc]
"""Creates a cancel scope with the given deadline, and raises an error if it
is actually cancelled.

Expand All @@ -110,14 +114,16 @@ def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # typ

Args:
deadline (float): The deadline.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.

Raises:
TooSlowError: if a :exc:`Cancelled` exception is raised in this scope
and caught by the context manager.
ValueError: if deadline is NaN.

"""
with move_on_at(deadline) as scope:
with move_on_at(deadline, shield=shield) as scope:
yield scope
if scope.cancelled_caught:
raise TooSlowError
Expand All @@ -127,7 +133,9 @@ def fail_at(deadline: float) -> AbstractContextManager[trio.CancelScope]: # typ
fail_at = contextmanager(fail_at)


def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]:
def fail_after(
seconds: float, *, shield: bool = False
) -> AbstractContextManager[trio.CancelScope]:
"""Creates a cancel scope with the given timeout, and raises an error if
it is actually cancelled.

Expand All @@ -140,6 +148,8 @@ def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]:

Args:
seconds (float): The timeout.
shield (bool): Initial value for the `~trio.CancelScope.shield` attribute
of the newly created cancel scope.

Raises:
TooSlowError: if a :exc:`Cancelled` exception is raised in this scope
Expand All @@ -149,4 +159,4 @@ def fail_after(seconds: float) -> AbstractContextManager[trio.CancelScope]:
"""
if seconds < 0:
raise ValueError("timeout must be non-negative")
return fail_at(trio.current_time() + seconds)
return fail_at(trio.current_time() + seconds, shield=shield)
Loading