diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index 2889e23f13..28c8a7c08a 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -6,7 +6,7 @@ from . import _core from . import socket as tsocket from ._socket import real_socket_type -from ._util import UnLock +from ._util import ConflictDetector from .abc import HalfCloseableStream, Listener from ._highlevel_generic import ( ClosedStreamError, BrokenStreamError, ClosedListenerError @@ -71,8 +71,7 @@ def __init__(self, socket): raise err from None self.socket = socket - self._send_lock = UnLock( - _core.ResourceBusyError, + self._send_conflict_detector = ConflictDetector( "another task is currently sending data on this SocketStream" ) @@ -105,19 +104,19 @@ async def send_all(self, data): if self.socket.did_shutdown_SHUT_WR: await _core.yield_briefly() raise ClosedStreamError("can't send data after sending EOF") - with self._send_lock.sync: + with self._send_conflict_detector.sync: with _translate_socket_errors_to_stream_errors(): await self.socket.sendall(data) async def wait_send_all_might_not_block(self): - async with self._send_lock: + async with self._send_conflict_detector: if self.socket.fileno() == -1: raise ClosedStreamError with _translate_socket_errors_to_stream_errors(): await self.socket.wait_writable() async def send_eof(self): - async with self._send_lock: + async with self._send_conflict_detector: # On MacOS, calling shutdown a second time raises ENOTCONN, but # send_eof needs to be idempotent. if self.socket.did_shutdown_SHUT_WR: diff --git a/trio/_ssl.py b/trio/_ssl.py index 2564b30b1e..24f36817bb 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -159,7 +159,7 @@ BrokenStreamError, ClosedStreamError, aclose_forcefully ) from . import _sync -from ._util import UnLock +from ._util import ConflictDetector __all__ = ["SSLStream", "SSLListener"] @@ -368,12 +368,10 @@ def __init__( # These are used to make sure that our caller doesn't attempt to make # multiple concurrent calls to send_all/wait_send_all_might_not_block # or to receive_some. - self._outer_send_lock = UnLock( - _core.ResourceBusyError, + self._outer_send_conflict_detector = ConflictDetector( "another task is currently sending data on this SSLStream" ) - self._outer_recv_lock = UnLock( - _core.ResourceBusyError, + self._outer_recv_conflict_detector = ConflictDetector( "another task is currently receiving data on this SSLStream" ) @@ -624,7 +622,7 @@ async def receive_some(self, max_bytes): :exc:`trio.BrokenStreamError`. """ - async with self._outer_recv_lock: + async with self._outer_recv_conflict_detector: self._check_status() try: await self._handshook.ensure(checkpoint=False) @@ -666,7 +664,7 @@ async def send_all(self, data): :exc:`trio.BrokenStreamError`. """ - async with self._outer_send_lock: + async with self._outer_send_conflict_detector: self._check_status() await self._handshook.ensure(checkpoint=False) # SSLObject interprets write(b"") as an EOF for some reason, which @@ -693,7 +691,8 @@ async def unwrap(self): ``transport_stream.receive_some(...)``. """ - async with self._outer_recv_lock, self._outer_send_lock: + async with self._outer_recv_conflict_detector, \ + self._outer_send_conflict_detector: self._check_status() await self._handshook.ensure(checkpoint=False) await self._retry(self._ssl_object.unwrap) @@ -797,7 +796,7 @@ async def wait_send_all_might_not_block(self): # semantics that wait_send_all_might_not_block and send_all # conflict. This also takes care of providing correct checkpoint # semantics before we potentially error out from _check_status(). - async with self._outer_send_lock: + async with self._outer_send_conflict_detector: self._check_status() # Then we take the inner send lock. We know that no other tasks # are calling self.send_all or self.wait_send_all_might_not_block, diff --git a/trio/_util.py b/trio/_util.py index 32ee70d58b..8a39aa7765 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -8,16 +8,16 @@ # There's a dependency loop here... _core is allowed to use this file (in fact # it's the *only* file in the main trio/ package it's allowed to use), but -# UnLock needs yield_briefly so it also has to import _core. Possibly we -# should split this file into two: one for true generic low-level utility -# code, and one for higher level helpers? +# ConflictDetector needs yield_briefly so it also has to import +# _core. Possibly we should split this file into two: one for true generic +# low-level utility code, and one for higher level helpers? from . import _core __all__ = [ "signal_raise", "aiter_compat", "acontextmanager", - "UnLock", + "ConflictDetector", "fixup_module_metadata", ] @@ -176,15 +176,14 @@ def helper(*args, **kwds): return helper -class _UnLockSync: - def __init__(self, exc, *args): - self._exc = exc - self._args = args +class _ConflictDetectorSync: + def __init__(self, msg): + self._msg = msg self._held = False def __enter__(self): if self._held: - raise self._exc(*self._args) + raise _core.ResourceBusyError(self._msg) else: self._held = True @@ -192,8 +191,9 @@ def __exit__(self, *args): self._held = False -class UnLock: - """An unnecessary lock. +class ConflictDetector: + """Detect when two tasks are about to perform operations that would + conflict. Use as an async context manager; if two tasks enter it at the same time then the second one raises an error. You can use it when there are @@ -205,10 +205,13 @@ class UnLock: This executes a checkpoint on entry. That's the only reason it's async. + To use from sync code, do ``with cd.sync``; this is just like ``async with + cd`` except that it doesn't execute a checkpoint. + """ - def __init__(self, exc, *args): - self.sync = _UnLockSync(exc, *args) + def __init__(self, msg): + self.sync = _ConflictDetectorSync(msg) async def __aenter__(self): await _core.yield_briefly() diff --git a/trio/testing/_memory_streams.py b/trio/testing/_memory_streams.py index e36dcec411..8b7e31fe18 100644 --- a/trio/testing/_memory_streams.py +++ b/trio/testing/_memory_streams.py @@ -27,8 +27,8 @@ def __init__(self): self._data = bytearray() self._closed = False self._lot = _core.ParkingLot() - self._fetch_lock = _util.UnLock( - _core.ResourceBusyError, "another task is already fetching data" + self._fetch_lock = _util.ConflictDetector( + "another task is already fetching data" ) def close(self): @@ -102,8 +102,8 @@ def __init__( wait_send_all_might_not_block_hook=None, close_hook=None ): - self._lock = _util.UnLock( - _core.ResourceBusyError, "another task is using this stream" + self._conflict_detector = _util.ConflictDetector( + "another task is using this stream" ) self._outgoing = _UnboundedByteQueue() self.send_all_hook = send_all_hook @@ -118,7 +118,7 @@ async def send_all(self, data): # The lock itself is a checkpoint, but then we also yield inside the # lock to give ourselves a chance to detect buggy user code that calls # this twice at the same time. - async with self._lock: + async with self._conflict_detector: await _core.yield_briefly() self._outgoing.put(data) if self.send_all_hook is not None: @@ -132,7 +132,7 @@ async def wait_send_all_might_not_block(self): # The lock itself is a checkpoint, but then we also yield inside the # lock to give ourselves a chance to detect buggy user code that calls # this twice at the same time. - async with self._lock: + async with self._conflict_detector: await _core.yield_briefly() # check for being closed: self._outgoing.put(b"") @@ -201,8 +201,8 @@ class MemoryReceiveStream(ReceiveStream): """ def __init__(self, receive_some_hook=None, close_hook=None): - self._lock = _util.UnLock( - _core.ResourceBusyError, "another task is using this stream" + self._conflict_detector = _util.ConflictDetector( + "another task is using this stream" ) self._incoming = _UnboundedByteQueue() self._closed = False @@ -217,7 +217,7 @@ async def receive_some(self, max_bytes): # The lock itself is a checkpoint, but then we also yield inside the # lock to give ourselves a chance to detect buggy user code that calls # this twice at the same time. - async with self._lock: + async with self._conflict_detector: await _core.yield_briefly() if max_bytes is None: raise TypeError("max_bytes must not be None") @@ -435,11 +435,11 @@ def __init__(self): self._receiver_closed = False self._receiver_waiting = False self._waiters = _core.ParkingLot() - self._send_lock = _util.UnLock( - _core.ResourceBusyError, "another task is already sending" + self._send_conflict_detector = _util.ConflictDetector( + "another task is already sending" ) - self._receive_lock = _util.UnLock( - _core.ResourceBusyError, "another task is already receiving" + self._receive_conflict_detector = _util.ConflictDetector( + "another task is already receiving" ) def _something_happened(self): @@ -459,7 +459,7 @@ def close_receiver(self): self._something_happened() async def send_all(self, data): - async with self._send_lock: + async with self._send_conflict_detector: if self._sender_closed: raise ClosedStreamError if self._receiver_closed: @@ -476,7 +476,7 @@ async def send_all(self, data): return async def wait_send_all_might_not_block(self): - async with self._send_lock: + async with self._send_conflict_detector: if self._sender_closed: raise ClosedStreamError if self._receiver_closed: @@ -486,7 +486,7 @@ async def wait_send_all_might_not_block(self): ) async def receive_some(self, max_bytes): - async with self._receive_lock: + async with self._receive_conflict_detector: # Argument validation max_bytes = operator.index(max_bytes) if max_bytes < 1: diff --git a/trio/tests/test_ssl.py b/trio/tests/test_ssl.py index 629c59163d..84509d21f1 100644 --- a/trio/tests/test_ssl.py +++ b/trio/tests/test_ssl.py @@ -20,7 +20,7 @@ from .._highlevel_open_tcp_stream import open_tcp_stream from .. import ssl as tssl from .. import socket as tsocket -from .._util import UnLock, acontextmanager +from .._util import ConflictDetector, acontextmanager from .._core.tests.tutil import slow @@ -175,12 +175,10 @@ def __init__(self, sleeper=None): self._lot = _core.ParkingLot() self._pending_cleartext = bytearray() - self._send_all_mutex = UnLock( - _core.ResourceBusyError, + self._send_all_conflict_detector = ConflictDetector( "simultaneous calls to PyOpenSSLEchoStream.send_all" ) - self._receive_some_mutex = UnLock( - _core.ResourceBusyError, + self._receive_some_conflict_detector = ConflictDetector( "simultaneous calls to PyOpenSSLEchoStream.receive_some" ) @@ -205,13 +203,13 @@ def renegotiate(self): assert self._conn.renegotiate() async def wait_send_all_might_not_block(self): - async with self._send_all_mutex: + async with self._send_all_conflict_detector: await _core.yield_briefly() await self.sleeper("wait_send_all_might_not_block") async def send_all(self, data): print(" --> transport_stream.send_all") - async with self._send_all_mutex: + async with self._send_all_conflict_detector: await _core.yield_briefly() await self.sleeper("send_all") self._conn.bio_write(data) @@ -233,7 +231,7 @@ async def send_all(self, data): async def receive_some(self, nbytes): print(" --> transport_stream.receive_some") - async with self._receive_some_mutex: + async with self._receive_some_conflict_detector: try: await _core.yield_briefly() while True: diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index 7a29012dd7..4765ff969f 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -25,40 +25,34 @@ def handler(signum, _): assert record == [signal.SIGFPE] -async def test_UnLock(): - ul1 = UnLock(RuntimeError, "ul1") - ul2 = UnLock(ValueError) +async def test_ConflictDetector(): + ul1 = ConflictDetector("ul1") + ul2 = ConflictDetector("ul2") async with ul1: with assert_yields(): async with ul2: print("ok") - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(_core.ResourceBusyError) as excinfo: async with ul1: with assert_yields(): async with ul1: pass # pragma: no cover assert "ul1" in str(excinfo.value) - with pytest.raises(ValueError) as excinfo: - async with ul2: - with assert_yields(): - async with ul2: - pass # pragma: no cover - async def wait_with_ul1(): async with ul1: await wait_all_tasks_blocked() - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(_core.ResourceBusyError) as excinfo: async with _core.open_nursery() as nursery: nursery.spawn(wait_with_ul1) nursery.spawn(wait_with_ul1) assert "ul1" in str(excinfo.value) # mixing sync and async entry - with pytest.raises(RuntimeError) as excinfo: + with pytest.raises(_core.ResourceBusyError) as excinfo: with ul1.sync: with assert_yields(): async with ul1: