From db91eca22ef696998110d2bbf2ab7becf64204d1 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Wed, 26 Aug 2020 21:41:17 -0400 Subject: [PATCH] Preliminary configurable trio.run for trio_test() --- trio/testing/_trio_test.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/trio/testing/_trio_test.py b/trio/testing/_trio_test.py index 4fcaeae372..14358a3615 100644 --- a/trio/testing/_trio_test.py +++ b/trio/testing/_trio_test.py @@ -12,18 +12,24 @@ # # Also: if a pytest fixture is passed in that subclasses the Clock abc, then # that clock is passed to trio.run(). -def trio_test(fn): - @wraps(fn) - def wrapper(**kwargs): - __tracebackhide__ = True - clocks = [c for c in kwargs.values() if isinstance(c, Clock)] - if not clocks: - clock = None - elif len(clocks) == 1: - clock = clocks[0] - else: - raise ValueError("too many clocks spoil the broth!") - instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] - return _core.run(partial(fn, **kwargs), clock=clock, instruments=instruments) +def trio_test(fn=None, *, run=_core.run): + def decorator(fn): + @wraps(fn) + def wrapper(**kwargs): + __tracebackhide__ = True + clocks = [c for c in kwargs.values() if isinstance(c, Clock)] + if not clocks: + clock = None + elif len(clocks) == 1: + clock = clocks[0] + else: + raise ValueError("too many clocks spoil the broth!") + instruments = [i for i in kwargs.values() if isinstance(i, Instrument)] + return run(partial(fn, **kwargs), clock=clock, instruments=instruments) - return wrapper + return wrapper + + if fn is None: + return decorator + + return decorator(fn)