diff --git a/CHANGELOG.md b/CHANGELOG.md index df1110e..1f08203 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,13 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/ ## [Unreleased](https://github.com/hynek/stamina/compare/24.3.0...HEAD) +### Added + +- *cap* argument to `stamina.set_testing()`. + By default, the value passed as *attempts* is used strictly. + When `cap=True`, it is used as an upper cap; that means that if the original attempts number is lower, it's not changed. + [#80](https://github.com/hynek/stamina/pull/80) + ## [24.3.0](https://github.com/hynek/stamina/compare/24.2.0...24.3.0) - 2024-08-27 diff --git a/src/stamina/_config.py b/src/stamina/_config.py index c89673a..1d0bf2d 100644 --- a/src/stamina/_config.py +++ b/src/stamina/_config.py @@ -19,12 +19,29 @@ class _Testing: Strictly private. """ - __slots__ = ("attempts",) + __slots__ = ("attempts", "cap") attempts: int + cap: bool - def __init__(self, attempts: int) -> None: + def __init__(self, attempts: int, cap: bool) -> None: self.attempts = attempts + self.cap = cap + + def get_attempts(self, non_testing_attempts: int | None) -> int: + """ + Get the number of attempts to use. + + Args: + non_testing_attempts: The number of attempts specified by the user. + + Returns: + The number of attempts to use. + """ + if self.cap: + return min(self.attempts, non_testing_attempts or self.attempts) + + return self.attempts class _Config: @@ -137,14 +154,21 @@ def is_testing() -> bool: return CONFIG.testing is not None -def set_testing(testing: bool, *, attempts: int = 1) -> None: +def set_testing( + testing: bool, *, attempts: int = 1, cap: bool = False +) -> None: """ Activate or deactivate test mode. In testing mode, backoffs are disabled, and attempts are set to *attempts*. + If *cap* is True, the number of attempts is not set but capped at + *attempts*. This means that if *attempts* is greater than the number of + attempts specified by the user, the user's value is used. + Is idempotent and can be called repeatedly with the same values. .. versionadded:: 24.3.0 + .. versionadded:: 24.4.0 *cap* """ - CONFIG.testing = _Testing(attempts) if testing else None + CONFIG.testing = _Testing(attempts, cap) if testing else None diff --git a/src/stamina/_core.py b/src/stamina/_core.py index a1df0e0..ec0efc8 100644 --- a/src/stamina/_core.py +++ b/src/stamina/_core.py @@ -403,6 +403,7 @@ class _RetryContextIterator: "_name", "_args", "_kw", + "_attempts", "_wait_jitter", "_wait_initial", "_wait_max", @@ -414,6 +415,7 @@ class _RetryContextIterator: _args: tuple[object, ...] _kw: dict[str, object] + _attempts: int | None _wait_jitter: float _wait_initial: float _wait_max: float @@ -455,6 +457,7 @@ def from_params( _name=name, _args=args, _kw=kw, + _attempts=attempts, _wait_jitter=wait_jitter, _wait_initial=wait_initial, _wait_max=wait_max, @@ -494,7 +497,9 @@ def _apply_maybe_test_mode_to_tenacity_kw( t_kw = self._t_kw.copy() - t_kw["stop"] = _t.stop_after_attempt(testing.attempts) + t_kw["stop"] = _t.stop_after_attempt( + testing.get_attempts(self._attempts) + ) return t_kw diff --git a/tests/test_config.py b/tests/test_config.py index ca16734..528421f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ from threading import Lock from stamina import is_active, set_active -from stamina._config import _Config +from stamina._config import _Config, _Testing def test_activate_deactivate(): @@ -38,3 +38,32 @@ def fake_on_retry(self): assert (1, 2) == cfg._init_on_first_retry() assert fake_on_retry is cfg._get_on_retry + + +class TestTesting: + def test_cap_true(self): + """ + If cap is True, get_attempts returns the lower of the two values. + """ + t = _Testing(2, True) + + assert 1 == t.get_attempts(1) + assert 2 == t.get_attempts(3) + + def test_cap_false(self): + """ + If cap is False, get_attempts always returns the testing value. + """ + t = _Testing(2, False) + + assert 2 == t.get_attempts(1) + assert 2 == t.get_attempts(3) + + def test_cap_true_with_none(self): + """ + If cap is True and attempts is None, get_attempts returns the + testing value. + """ + t = _Testing(100, True) + + assert 100 == t.get_attempts(None)