From fa66630e6d8e1e8d9038d19181a3abab5205620b Mon Sep 17 00:00:00 2001 From: "Nathaniel J. Smith" Date: Fri, 15 Sep 2017 02:25:19 -0700 Subject: [PATCH] Update tests for the BlockingTrioPortal transition --- trio/tests/test_threads.py | 63 +++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/trio/tests/test_threads.py b/trio/tests/test_threads.py index 47e84847d4..854c874fc5 100644 --- a/trio/tests/test_threads.py +++ b/trio/tests/test_threads.py @@ -41,23 +41,21 @@ def threadfn(): ("f", trio_thread), expected ] - run_in_trio_thread = current_run_in_trio_thread() + portal = BlockingTrioPortal() def f(record): assert not _core.currently_ki_protected() record.append(("f", threading.current_thread())) return 2 - await check_case(run_in_trio_thread, f, ("got", 2)) + await check_case(portal.run_sync, f, ("got", 2)) def f(record): assert not _core.currently_ki_protected() record.append(("f", threading.current_thread())) raise ValueError - await check_case(run_in_trio_thread, f, ("error", ValueError)) - - await_in_trio_thread = current_await_in_trio_thread() + await check_case(portal.run_sync, f, ("error", ValueError)) async def f(record): assert not _core.currently_ki_protected() @@ -65,7 +63,7 @@ async def f(record): record.append(("f", threading.current_thread())) return 3 - await check_case(await_in_trio_thread, f, ("got", 3)) + await check_case(portal.run, f, ("got", 3)) async def f(record): assert not _core.currently_ki_protected() @@ -73,21 +71,33 @@ async def f(record): record.append(("f", threading.current_thread())) raise KeyError - await check_case(await_in_trio_thread, f, ("error", KeyError)) + await check_case(portal.run, f, ("error", KeyError)) async def test_do_in_trio_thread_from_trio_thread(): - run_in_trio_thread = current_run_in_trio_thread() - await_in_trio_thread = current_await_in_trio_thread() + portal = BlockingTrioPortal() with pytest.raises(RuntimeError): - run_in_trio_thread(lambda: None) # pragma: no branch + portal.run_sync(lambda: None) # pragma: no branch async def foo(): # pragma: no cover pass with pytest.raises(RuntimeError): - await_in_trio_thread(foo) + portal.run(foo) + + +async def test_BlockingTrioPortal_with_explicit_TrioToken(): + token = _core.current_trio_token() + + def worker_thread(token): + with pytest.raises(RuntimeError): + BlockingTrioPortal() + portal = BlockingTrioPortal(token) + return portal.run_sync(threading.current_thread) + + t = await run_sync_in_worker_thread(worker_thread, token) + assert t == threading.current_thread() def test_run_in_trio_thread_ki(): @@ -96,8 +106,7 @@ def test_run_in_trio_thread_ki(): record = set() async def check_run_in_trio_thread(): - run_in_trio_thread = current_run_in_trio_thread() - await_in_trio_thread = current_await_in_trio_thread() + portal = BlockingTrioPortal() def trio_thread_fn(): print("in trio thread") @@ -115,12 +124,12 @@ async def trio_thread_afn(): def external_thread_fn(): try: print("running") - run_in_trio_thread(trio_thread_fn) + portal.run_sync(trio_thread_fn) except KeyboardInterrupt: print("ok1") record.add("ok1") try: - await_in_trio_thread(trio_thread_afn) + portal.run(trio_thread_afn) except KeyboardInterrupt: print("ok2") record.add("ok2") @@ -147,15 +156,15 @@ async def trio_fn(): ev.set() await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED) - def thread_fn(await_in_trio_thread): + def thread_fn(portal): try: - await_in_trio_thread(trio_fn) + portal.run(trio_fn) except _core.Cancelled: record.append("cancelled") async def main(): - aitt = current_await_in_trio_thread() - thread = threading.Thread(target=thread_fn, args=(aitt,)) + portal = BlockingTrioPortal() + thread = threading.Thread(target=thread_fn, args=(portal,)) thread.start() await ev.wait() assert record == ["sleeping"] @@ -319,11 +328,11 @@ class state: state.running = 0 state.parked = 0 - run_in_trio_thread = current_run_in_trio_thread() + portal = BlockingTrioPortal() def thread_fn(cancel_scope): print("thread_fn start") - run_in_trio_thread(cancel_scope.cancel) + portal.run_sync(cancel_scope.cancel) with lock: state.ran += 1 state.running += 1 @@ -454,3 +463,15 @@ def bad_start(self): assert "engines" in str(excinfo.value) assert limiter.borrowed_tokens == 0 + + +# can remove after deleting 0.2.0 deprecations +async def test_deprecated_portal_API(): + trio_thread = threading.current_thread() + + async def async_current_thread(): + return threading.current_thread() + + def worker_thread(portal): + assert portal.run_sync(threading.current_thread) == trio_thread + assert portal.run(async_current_thread) == trio_thread