Skip to content

Commit

Permalink
Add cap argument to testing mode (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
hynek authored Nov 3, 2024
1 parent 0e1010a commit a8154bf
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 6 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/

## [Unreleased](https://github.com/hynek/stamina/compare/24.3.0...HEAD)

### Added

- *cap* argument to `stamina.set_testing()`.
By default, the value passed as *attempts* is used strictly.
When `cap=True`, it is used as an upper cap; that means that if the original attempts number is lower, it's not changed.
[#80](https://github.com/hynek/stamina/pull/80)


## [24.3.0](https://github.com/hynek/stamina/compare/24.2.0...24.3.0) - 2024-08-27

Expand Down
32 changes: 28 additions & 4 deletions src/stamina/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,29 @@ class _Testing:
Strictly private.
"""

__slots__ = ("attempts",)
__slots__ = ("attempts", "cap")

attempts: int
cap: bool

def __init__(self, attempts: int) -> None:
def __init__(self, attempts: int, cap: bool) -> None:
self.attempts = attempts
self.cap = cap

def get_attempts(self, non_testing_attempts: int | None) -> int:
"""
Get the number of attempts to use.
Args:
non_testing_attempts: The number of attempts specified by the user.
Returns:
The number of attempts to use.
"""
if self.cap:
return min(self.attempts, non_testing_attempts or self.attempts)

return self.attempts


class _Config:
Expand Down Expand Up @@ -137,14 +154,21 @@ def is_testing() -> bool:
return CONFIG.testing is not None


def set_testing(testing: bool, *, attempts: int = 1) -> None:
def set_testing(
testing: bool, *, attempts: int = 1, cap: bool = False
) -> None:
"""
Activate or deactivate test mode.
In testing mode, backoffs are disabled, and attempts are set to *attempts*.
If *cap* is True, the number of attempts is not set but capped at
*attempts*. This means that if *attempts* is greater than the number of
attempts specified by the user, the user's value is used.
Is idempotent and can be called repeatedly with the same values.
.. versionadded:: 24.3.0
.. versionadded:: 24.4.0 *cap*
"""
CONFIG.testing = _Testing(attempts) if testing else None
CONFIG.testing = _Testing(attempts, cap) if testing else None
7 changes: 6 additions & 1 deletion src/stamina/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ class _RetryContextIterator:
"_name",
"_args",
"_kw",
"_attempts",
"_wait_jitter",
"_wait_initial",
"_wait_max",
Expand All @@ -414,6 +415,7 @@ class _RetryContextIterator:
_args: tuple[object, ...]
_kw: dict[str, object]

_attempts: int | None
_wait_jitter: float
_wait_initial: float
_wait_max: float
Expand Down Expand Up @@ -455,6 +457,7 @@ def from_params(
_name=name,
_args=args,
_kw=kw,
_attempts=attempts,
_wait_jitter=wait_jitter,
_wait_initial=wait_initial,
_wait_max=wait_max,
Expand Down Expand Up @@ -494,7 +497,9 @@ def _apply_maybe_test_mode_to_tenacity_kw(

t_kw = self._t_kw.copy()

t_kw["stop"] = _t.stop_after_attempt(testing.attempts)
t_kw["stop"] = _t.stop_after_attempt(
testing.get_attempts(self._attempts)
)

return t_kw

Expand Down
31 changes: 30 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from threading import Lock

from stamina import is_active, set_active
from stamina._config import _Config
from stamina._config import _Config, _Testing


def test_activate_deactivate():
Expand Down Expand Up @@ -38,3 +38,32 @@ def fake_on_retry(self):

assert (1, 2) == cfg._init_on_first_retry()
assert fake_on_retry is cfg._get_on_retry


class TestTesting:
def test_cap_true(self):
"""
If cap is True, get_attempts returns the lower of the two values.
"""
t = _Testing(2, True)

assert 1 == t.get_attempts(1)
assert 2 == t.get_attempts(3)

def test_cap_false(self):
"""
If cap is False, get_attempts always returns the testing value.
"""
t = _Testing(2, False)

assert 2 == t.get_attempts(1)
assert 2 == t.get_attempts(3)

def test_cap_true_with_none(self):
"""
If cap is True and attempts is None, get_attempts returns the
testing value.
"""
t = _Testing(100, True)

assert 100 == t.get_attempts(None)

0 comments on commit a8154bf

Please sign in to comment.