Skip to content

Commit

Permalink
Update tests for the BlockingTrioPortal transition
Browse files Browse the repository at this point in the history
  • Loading branch information
njsmith committed Sep 15, 2017
1 parent 5a37021 commit fa66630
Showing 1 changed file with 42 additions and 21 deletions.
63 changes: 42 additions & 21 deletions trio/tests/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,53 +41,63 @@ def threadfn():
("f", trio_thread), expected
]

run_in_trio_thread = current_run_in_trio_thread()
portal = BlockingTrioPortal()

def f(record):
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
return 2

await check_case(run_in_trio_thread, f, ("got", 2))
await check_case(portal.run_sync, f, ("got", 2))

def f(record):
assert not _core.currently_ki_protected()
record.append(("f", threading.current_thread()))
raise ValueError

await check_case(run_in_trio_thread, f, ("error", ValueError))

await_in_trio_thread = current_await_in_trio_thread()
await check_case(portal.run_sync, f, ("error", ValueError))

async def f(record):
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
return 3

await check_case(await_in_trio_thread, f, ("got", 3))
await check_case(portal.run, f, ("got", 3))

async def f(record):
assert not _core.currently_ki_protected()
await _core.checkpoint()
record.append(("f", threading.current_thread()))
raise KeyError

await check_case(await_in_trio_thread, f, ("error", KeyError))
await check_case(portal.run, f, ("error", KeyError))


async def test_do_in_trio_thread_from_trio_thread():
run_in_trio_thread = current_run_in_trio_thread()
await_in_trio_thread = current_await_in_trio_thread()
portal = BlockingTrioPortal()

with pytest.raises(RuntimeError):
run_in_trio_thread(lambda: None) # pragma: no branch
portal.run_sync(lambda: None) # pragma: no branch

async def foo(): # pragma: no cover
pass

with pytest.raises(RuntimeError):
await_in_trio_thread(foo)
portal.run(foo)


async def test_BlockingTrioPortal_with_explicit_TrioToken():
token = _core.current_trio_token()

def worker_thread(token):
with pytest.raises(RuntimeError):
BlockingTrioPortal()
portal = BlockingTrioPortal(token)
return portal.run_sync(threading.current_thread)

t = await run_sync_in_worker_thread(worker_thread, token)
assert t == threading.current_thread()


def test_run_in_trio_thread_ki():
Expand All @@ -96,8 +106,7 @@ def test_run_in_trio_thread_ki():
record = set()

async def check_run_in_trio_thread():
run_in_trio_thread = current_run_in_trio_thread()
await_in_trio_thread = current_await_in_trio_thread()
portal = BlockingTrioPortal()

def trio_thread_fn():
print("in trio thread")
Expand All @@ -115,12 +124,12 @@ async def trio_thread_afn():
def external_thread_fn():
try:
print("running")
run_in_trio_thread(trio_thread_fn)
portal.run_sync(trio_thread_fn)
except KeyboardInterrupt:
print("ok1")
record.add("ok1")
try:
await_in_trio_thread(trio_thread_afn)
portal.run(trio_thread_afn)
except KeyboardInterrupt:
print("ok2")
record.add("ok2")
Expand All @@ -147,15 +156,15 @@ async def trio_fn():
ev.set()
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)

def thread_fn(await_in_trio_thread):
def thread_fn(portal):
try:
await_in_trio_thread(trio_fn)
portal.run(trio_fn)
except _core.Cancelled:
record.append("cancelled")

async def main():
aitt = current_await_in_trio_thread()
thread = threading.Thread(target=thread_fn, args=(aitt,))
portal = BlockingTrioPortal()
thread = threading.Thread(target=thread_fn, args=(portal,))
thread.start()
await ev.wait()
assert record == ["sleeping"]
Expand Down Expand Up @@ -319,11 +328,11 @@ class state:
state.running = 0
state.parked = 0

run_in_trio_thread = current_run_in_trio_thread()
portal = BlockingTrioPortal()

def thread_fn(cancel_scope):
print("thread_fn start")
run_in_trio_thread(cancel_scope.cancel)
portal.run_sync(cancel_scope.cancel)
with lock:
state.ran += 1
state.running += 1
Expand Down Expand Up @@ -454,3 +463,15 @@ def bad_start(self):
assert "engines" in str(excinfo.value)

assert limiter.borrowed_tokens == 0


# can remove after deleting 0.2.0 deprecations
async def test_deprecated_portal_API():
trio_thread = threading.current_thread()

async def async_current_thread():
return threading.current_thread()

def worker_thread(portal):
assert portal.run_sync(threading.current_thread) == trio_thread
assert portal.run(async_current_thread) == trio_thread

0 comments on commit fa66630

Please sign in to comment.