Skip to content

Commit

Permalink
Refactor UnLock -> ConflictDetector
Browse files Browse the repository at this point in the history
Just rename it and make it always do ResourceBusyError, since that's
how it's always used.

Fixes python-triogh-197
  • Loading branch information
njsmith committed Aug 19, 2017
1 parent 38b49fa commit 675555d
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 64 deletions.
11 changes: 5 additions & 6 deletions trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import _core
from . import socket as tsocket
from ._socket import real_socket_type
from ._util import UnLock
from ._util import ConflictDetector
from .abc import HalfCloseableStream, Listener
from ._highlevel_generic import (
ClosedStreamError, BrokenStreamError, ClosedListenerError
Expand Down Expand Up @@ -71,8 +71,7 @@ def __init__(self, socket):
raise err from None

self.socket = socket
self._send_lock = UnLock(
_core.ResourceBusyError,
self._send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SocketStream"
)

Expand Down Expand Up @@ -105,19 +104,19 @@ async def send_all(self, data):
if self.socket.did_shutdown_SHUT_WR:
await _core.yield_briefly()
raise ClosedStreamError("can't send data after sending EOF")
with self._send_lock.sync:
with self._send_conflict_detector.sync:
with _translate_socket_errors_to_stream_errors():
await self.socket.sendall(data)

async def wait_send_all_might_not_block(self):
async with self._send_lock:
async with self._send_conflict_detector:
if self.socket.fileno() == -1:
raise ClosedStreamError
with _translate_socket_errors_to_stream_errors():
await self.socket.wait_writable()

async def send_eof(self):
async with self._send_lock:
async with self._send_conflict_detector:
# On MacOS, calling shutdown a second time raises ENOTCONN, but
# send_eof needs to be idempotent.
if self.socket.did_shutdown_SHUT_WR:
Expand Down
17 changes: 8 additions & 9 deletions trio/_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
BrokenStreamError, ClosedStreamError, aclose_forcefully
)
from . import _sync
from ._util import UnLock
from ._util import ConflictDetector

__all__ = ["SSLStream", "SSLListener"]

Expand Down Expand Up @@ -368,12 +368,10 @@ def __init__(
# These are used to make sure that our caller doesn't attempt to make
# multiple concurrent calls to send_all/wait_send_all_might_not_block
# or to receive_some.
self._outer_send_lock = UnLock(
_core.ResourceBusyError,
self._outer_send_conflict_detector = ConflictDetector(
"another task is currently sending data on this SSLStream"
)
self._outer_recv_lock = UnLock(
_core.ResourceBusyError,
self._outer_recv_conflict_detector = ConflictDetector(
"another task is currently receiving data on this SSLStream"
)

Expand Down Expand Up @@ -624,7 +622,7 @@ async def receive_some(self, max_bytes):
:exc:`trio.BrokenStreamError`.
"""
async with self._outer_recv_lock:
async with self._outer_recv_conflict_detector:
self._check_status()
try:
await self._handshook.ensure(checkpoint=False)
Expand Down Expand Up @@ -666,7 +664,7 @@ async def send_all(self, data):
:exc:`trio.BrokenStreamError`.
"""
async with self._outer_send_lock:
async with self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
# SSLObject interprets write(b"") as an EOF for some reason, which
Expand All @@ -693,7 +691,8 @@ async def unwrap(self):
``transport_stream.receive_some(...)``.
"""
async with self._outer_recv_lock, self._outer_send_lock:
async with self._outer_recv_conflict_detector, \
self._outer_send_conflict_detector:
self._check_status()
await self._handshook.ensure(checkpoint=False)
await self._retry(self._ssl_object.unwrap)
Expand Down Expand Up @@ -797,7 +796,7 @@ async def wait_send_all_might_not_block(self):
# semantics that wait_send_all_might_not_block and send_all
# conflict. This also takes care of providing correct checkpoint
# semantics before we potentially error out from _check_status().
async with self._outer_send_lock:
async with self._outer_send_conflict_detector:
self._check_status()
# Then we take the inner send lock. We know that no other tasks
# are calling self.send_all or self.wait_send_all_might_not_block,
Expand Down
29 changes: 16 additions & 13 deletions trio/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

# There's a dependency loop here... _core is allowed to use this file (in fact
# it's the *only* file in the main trio/ package it's allowed to use), but
# UnLock needs yield_briefly so it also has to import _core. Possibly we
# should split this file into two: one for true generic low-level utility
# code, and one for higher level helpers?
# ConflictDetector needs yield_briefly so it also has to import
# _core. Possibly we should split this file into two: one for true generic
# low-level utility code, and one for higher level helpers?
from . import _core

__all__ = [
"signal_raise",
"aiter_compat",
"acontextmanager",
"UnLock",
"ConflictDetector",
"fixup_module_metadata",
]

Expand Down Expand Up @@ -176,24 +176,24 @@ def helper(*args, **kwds):
return helper


class _UnLockSync:
def __init__(self, exc, *args):
self._exc = exc
self._args = args
class _ConflictDetectorSync:
def __init__(self, msg):
self._msg = msg
self._held = False

def __enter__(self):
if self._held:
raise self._exc(*self._args)
raise _core.ResourceBusyError(self._msg)
else:
self._held = True

def __exit__(self, *args):
self._held = False


class UnLock:
"""An unnecessary lock.
class ConflictDetector:
"""Detect when two tasks are about to perform operations that would
conflict.
Use as an async context manager; if two tasks enter it at the same
time then the second one raises an error. You can use it when there are
Expand All @@ -205,10 +205,13 @@ class UnLock:
This executes a checkpoint on entry. That's the only reason it's async.
To use from sync code, do ``with cd.sync``; this is just like ``async with
cd`` except that it doesn't execute a checkpoint.
"""

def __init__(self, exc, *args):
self.sync = _UnLockSync(exc, *args)
def __init__(self, msg):
self.sync = _ConflictDetectorSync(msg)

async def __aenter__(self):
await _core.yield_briefly()
Expand Down
32 changes: 16 additions & 16 deletions trio/testing/_memory_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(self):
self._data = bytearray()
self._closed = False
self._lot = _core.ParkingLot()
self._fetch_lock = _util.UnLock(
_core.ResourceBusyError, "another task is already fetching data"
self._fetch_lock = _util.ConflictDetector(
"another task is already fetching data"
)

def close(self):
Expand Down Expand Up @@ -102,8 +102,8 @@ def __init__(
wait_send_all_might_not_block_hook=None,
close_hook=None
):
self._lock = _util.UnLock(
_core.ResourceBusyError, "another task is using this stream"
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream"
)
self._outgoing = _UnboundedByteQueue()
self.send_all_hook = send_all_hook
Expand All @@ -118,7 +118,7 @@ async def send_all(self, data):
# The lock itself is a checkpoint, but then we also yield inside the
# lock to give ourselves a chance to detect buggy user code that calls
# this twice at the same time.
async with self._lock:
async with self._conflict_detector:
await _core.yield_briefly()
self._outgoing.put(data)
if self.send_all_hook is not None:
Expand All @@ -132,7 +132,7 @@ async def wait_send_all_might_not_block(self):
# The lock itself is a checkpoint, but then we also yield inside the
# lock to give ourselves a chance to detect buggy user code that calls
# this twice at the same time.
async with self._lock:
async with self._conflict_detector:
await _core.yield_briefly()
# check for being closed:
self._outgoing.put(b"")
Expand Down Expand Up @@ -201,8 +201,8 @@ class MemoryReceiveStream(ReceiveStream):
"""

def __init__(self, receive_some_hook=None, close_hook=None):
self._lock = _util.UnLock(
_core.ResourceBusyError, "another task is using this stream"
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream"
)
self._incoming = _UnboundedByteQueue()
self._closed = False
Expand All @@ -217,7 +217,7 @@ async def receive_some(self, max_bytes):
# The lock itself is a checkpoint, but then we also yield inside the
# lock to give ourselves a chance to detect buggy user code that calls
# this twice at the same time.
async with self._lock:
async with self._conflict_detector:
await _core.yield_briefly()
if max_bytes is None:
raise TypeError("max_bytes must not be None")
Expand Down Expand Up @@ -435,11 +435,11 @@ def __init__(self):
self._receiver_closed = False
self._receiver_waiting = False
self._waiters = _core.ParkingLot()
self._send_lock = _util.UnLock(
_core.ResourceBusyError, "another task is already sending"
self._send_conflict_detector = _util.ConflictDetector(
"another task is already sending"
)
self._receive_lock = _util.UnLock(
_core.ResourceBusyError, "another task is already receiving"
self._receive_conflict_detector = _util.ConflictDetector(
"another task is already receiving"
)

def _something_happened(self):
Expand All @@ -459,7 +459,7 @@ def close_receiver(self):
self._something_happened()

async def send_all(self, data):
async with self._send_lock:
async with self._send_conflict_detector:
if self._sender_closed:
raise ClosedStreamError
if self._receiver_closed:
Expand All @@ -476,7 +476,7 @@ async def send_all(self, data):
return

async def wait_send_all_might_not_block(self):
async with self._send_lock:
async with self._send_conflict_detector:
if self._sender_closed:
raise ClosedStreamError
if self._receiver_closed:
Expand All @@ -486,7 +486,7 @@ async def wait_send_all_might_not_block(self):
)

async def receive_some(self, max_bytes):
async with self._receive_lock:
async with self._receive_conflict_detector:
# Argument validation
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
Expand Down
14 changes: 6 additions & 8 deletions trio/tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .._highlevel_open_tcp_stream import open_tcp_stream
from .. import ssl as tssl
from .. import socket as tsocket
from .._util import UnLock, acontextmanager
from .._util import ConflictDetector, acontextmanager

from .._core.tests.tutil import slow

Expand Down Expand Up @@ -175,12 +175,10 @@ def __init__(self, sleeper=None):
self._lot = _core.ParkingLot()
self._pending_cleartext = bytearray()

self._send_all_mutex = UnLock(
_core.ResourceBusyError,
self._send_all_conflict_detector = ConflictDetector(
"simultaneous calls to PyOpenSSLEchoStream.send_all"
)
self._receive_some_mutex = UnLock(
_core.ResourceBusyError,
self._receive_some_conflict_detector = ConflictDetector(
"simultaneous calls to PyOpenSSLEchoStream.receive_some"
)

Expand All @@ -205,13 +203,13 @@ def renegotiate(self):
assert self._conn.renegotiate()

async def wait_send_all_might_not_block(self):
async with self._send_all_mutex:
async with self._send_all_conflict_detector:
await _core.yield_briefly()
await self.sleeper("wait_send_all_might_not_block")

async def send_all(self, data):
print(" --> transport_stream.send_all")
async with self._send_all_mutex:
async with self._send_all_conflict_detector:
await _core.yield_briefly()
await self.sleeper("send_all")
self._conn.bio_write(data)
Expand All @@ -233,7 +231,7 @@ async def send_all(self, data):

async def receive_some(self, nbytes):
print(" --> transport_stream.receive_some")
async with self._receive_some_mutex:
async with self._receive_some_conflict_detector:
try:
await _core.yield_briefly()
while True:
Expand Down
18 changes: 6 additions & 12 deletions trio/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,34 @@ def handler(signum, _):
assert record == [signal.SIGFPE]


async def test_UnLock():
ul1 = UnLock(RuntimeError, "ul1")
ul2 = UnLock(ValueError)
async def test_ConflictDetector():
ul1 = ConflictDetector("ul1")
ul2 = ConflictDetector("ul2")

async with ul1:
with assert_yields():
async with ul2:
print("ok")

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(_core.ResourceBusyError) as excinfo:
async with ul1:
with assert_yields():
async with ul1:
pass # pragma: no cover
assert "ul1" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
async with ul2:
with assert_yields():
async with ul2:
pass # pragma: no cover

async def wait_with_ul1():
async with ul1:
await wait_all_tasks_blocked()

with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(_core.ResourceBusyError) as excinfo:
async with _core.open_nursery() as nursery:
nursery.spawn(wait_with_ul1)
nursery.spawn(wait_with_ul1)
assert "ul1" in str(excinfo.value)

# mixing sync and async entry
with pytest.raises(RuntimeError) as excinfo:
with pytest.raises(_core.ResourceBusyError) as excinfo:
with ul1.sync:
with assert_yields():
async with ul1:
Expand Down

0 comments on commit 675555d

Please sign in to comment.