Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't close the whole mpsc channel when one sender is closed #1443

Merged
merged 1 commit into from
Feb 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 122 additions & 50 deletions futures-channel/src/mpsc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,8 @@ use crate::mpsc::queue::Queue;

mod queue;

/// The transmission end of a bounded mpsc channel.
///
/// This value is created by the [`channel`](channel) function.
#[derive(Debug)]
pub struct Sender<T> {
struct SenderInner<T> {
// Channel state shared between the sender and receiver.
inner: Arc<Inner<T>>,

Expand All @@ -112,14 +109,20 @@ pub struct Sender<T> {
maybe_parked: bool,
}

// We never project Pin<&mut Sender> to `Pin<&mut T>`
impl<T> Unpin for Sender<T> {}
// We never project Pin<&mut SenderInner> to `Pin<&mut T>`
impl<T> Unpin for SenderInner<T> {}

/// The transmission end of a bounded mpsc channel.
///
/// This value is created by the [`channel`](channel) function.
#[derive(Debug)]
pub struct Sender<T>(Option<SenderInner<T>>);

/// The transmission end of an unbounded mpsc channel.
///
/// This value is created by the [`unbounded`](unbounded) function.
#[derive(Debug)]
pub struct UnboundedSender<T>(Sender<T>);
pub struct UnboundedSender<T>(Option<SenderInner<T>>);

trait AssertKinds: Send + Sync + Clone {}
impl AssertKinds for UnboundedSender<u32> {}
Expand Down Expand Up @@ -357,7 +360,8 @@ pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
// Check that the requested buffer size does not exceed the maximum buffer
// size permitted by the system.
assert!(buffer < MAX_BUFFER, "requested buffer size too large");
channel2(Some(buffer))
let (tx, rx) = channel2(Some(buffer));
(Sender(Some(tx)), rx)
}

/// Creates an unbounded mpsc channel for communicating between asynchronous
Expand All @@ -372,10 +376,10 @@ pub fn channel<T>(buffer: usize) -> (Sender<T>, Receiver<T>) {
/// process to run out of memory. In this case, the process will be aborted.
pub fn unbounded<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
let (tx, rx) = channel2(None);
(UnboundedSender(tx), UnboundedReceiver(rx))
(UnboundedSender(Some(tx)), UnboundedReceiver(rx))
}

fn channel2<T>(buffer: Option<usize>) -> (Sender<T>, Receiver<T>) {
fn channel2<T>(buffer: Option<usize>) -> (SenderInner<T>, Receiver<T>) {
let inner = Arc::new(Inner {
buffer,
state: AtomicUsize::new(INIT_STATE),
Expand All @@ -385,7 +389,7 @@ fn channel2<T>(buffer: Option<usize>) -> (Sender<T>, Receiver<T>) {
recv_task: AtomicWaker::new(),
});

let tx = Sender {
let tx = SenderInner {
inner: inner.clone(),
sender_task: Arc::new(Mutex::new(SenderTask::new())),
maybe_parked: false,
Expand All @@ -404,10 +408,10 @@ fn channel2<T>(buffer: Option<usize>) -> (Sender<T>, Receiver<T>) {
*
*/

impl<T> Sender<T> {
impl<T> SenderInner<T> {
/// Attempts to send a message on this `Sender`, returning the message
/// if there was an error.
pub fn try_send(&mut self, msg: T) -> Result<(), TrySendError<T>> {
fn try_send(&mut self, msg: T) -> Result<(), TrySendError<T>> {
// If the sender is currently blocked, reject the message
if !self.poll_unparked(None).is_ready() {
return Err(TrySendError {
Expand All @@ -422,16 +426,6 @@ impl<T> Sender<T> {
self.do_send_b(msg)
}

/// Send a message on the channel.
///
/// This function should only be called after
/// [`poll_ready`](Sender::poll_ready) has reported that the channel is
/// ready to receive a message.
pub fn start_send(&mut self, msg: T) -> Result<(), SendError> {
self.try_send(msg)
.map_err(|e| e.err)
}

// Do the send without failing.
// Can be called only by bounded sender.
fn do_send_b(&mut self, msg: T)
Expand Down Expand Up @@ -484,7 +478,7 @@ impl<T> Sender<T> {
Poll::Ready(Ok(()))
} else {
Poll::Ready(Err(SendError {
kind: SendErrorKind::Full,
kind: SendErrorKind::Disconnected,
}))
}
}
Expand Down Expand Up @@ -559,7 +553,7 @@ impl<T> Sender<T> {
/// capacity, in which case the current task is queued to be notified once
/// capacity is available;
/// - `Err(SendError)` if the receiver has been dropped.
pub fn poll_ready(
fn poll_ready(
&mut self,
lw: &LocalWaker
) -> Poll<Result<(), SendError>> {
Expand All @@ -574,12 +568,12 @@ impl<T> Sender<T> {
}

/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
fn is_closed(&self) -> bool {
!decode_state(self.inner.state.load(SeqCst)).is_open
}

/// Closes this channel from the sender side, preventing any new messages.
pub fn close_channel(&mut self) {
fn close_channel(&self) {
// There's no need to park this sender, its dropping,
// and we don't want to check for capacity, so skip
// that stuff from `do_send`.
Expand Down Expand Up @@ -615,43 +609,116 @@ impl<T> Sender<T> {
}
}

impl<T> Sender<T> {
/// Attempts to send a message on this `Sender`, returning the message
/// if there was an error.
pub fn try_send(&mut self, msg: T) -> Result<(), TrySendError<T>> {
if let Some(inner) = &mut self.0 {
inner.try_send(msg)
} else {
Err(TrySendError {
err: SendError {
kind: SendErrorKind::Disconnected,
},
val: msg,
})
}
}

/// Send a message on the channel.
///
/// This function should only be called after
/// [`poll_ready`](Sender::poll_ready) has reported that the channel is
/// ready to receive a message.
pub fn start_send(&mut self, msg: T) -> Result<(), SendError> {
self.try_send(msg)
.map_err(|e| e.err)
}

/// Polls the channel to determine if there is guaranteed capacity to send
/// at least one item without waiting.
///
/// # Return value
///
/// This method returns:
///
/// - `Ok(Async::Ready(_))` if there is sufficient capacity;
/// - `Ok(Async::Pending)` if the channel may not have
/// capacity, in which case the current task is queued to be notified once
/// capacity is available;
/// - `Err(SendError)` if the receiver has been dropped.
pub fn poll_ready(
&mut self,
lw: &LocalWaker
) -> Poll<Result<(), SendError>> {
let inner = self.0.as_mut().ok_or(SendError {
kind: SendErrorKind::Disconnected,
})?;
inner.poll_ready(lw)
}

/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
self.0.as_ref().map(SenderInner::is_closed).unwrap_or(true)
}

/// Closes this channel from the sender side, preventing any new messages.
pub fn close_channel(&mut self) {
if let Some(inner) = &mut self.0 {
inner.close_channel();
}
}

/// Disconnects this sender from the channel, closing it if there are no more senders left.
pub fn disconnect(&mut self) {
self.0 = None;
}
}

impl<T> UnboundedSender<T> {
/// Check if the channel is ready to receive a message.
pub fn poll_ready(
&self,
_: &LocalWaker,
) -> Poll<Result<(), SendError>> {
self.0.poll_ready_nb()
let inner = self.0.as_ref().ok_or(SendError {
kind: SendErrorKind::Disconnected,
})?;
inner.poll_ready_nb()
}

/// Returns whether this channel is closed without needing a context.
pub fn is_closed(&self) -> bool {
self.0.is_closed()
self.0.as_ref().map(SenderInner::is_closed).unwrap_or(true)
}

/// Closes this channel from the sender side, preventing any new messages.
pub fn close_channel(&self) {
self.0.inner.set_closed();
self.0.inner.recv_task.wake();
if let Some(inner) = &self.0 {
inner.close_channel();
}
}

/// Disconnects this sender from the channel, closing it if there are no more senders left.
pub fn disconnect(&mut self) {
self.0 = None;
}

// Do the send without parking current task.
fn do_send_nb(&self, msg: T) -> Result<(), TrySendError<T>> {
match self.0.inc_num_messages() {
Some(_num_messages) => {}
None => {
return Err(TrySendError {
err: SendError {
kind: SendErrorKind::Disconnected,
},
val: msg,
});
},
};

self.0.queue_push_and_signal(msg);
if let Some(inner) = &self.0 {
if inner.inc_num_messages().is_some() {
inner.queue_push_and_signal(msg);
return Ok(());
}
}

Ok(())
Err(TrySendError {
err: SendError {
kind: SendErrorKind::Disconnected,
},
val: msg,
})
}

/// Send a message on the channel.
Expand All @@ -673,15 +740,20 @@ impl<T> UnboundedSender<T> {
}
}

impl<T> Clone for Sender<T> {
fn clone(&self) -> Sender<T> {
Sender(self.0.clone())
}
}

impl<T> Clone for UnboundedSender<T> {
fn clone(&self) -> UnboundedSender<T> {
UnboundedSender(self.0.clone())
}
}


impl<T> Clone for Sender<T> {
fn clone(&self) -> Sender<T> {
impl<T> Clone for SenderInner<T> {
fn clone(&self) -> SenderInner<T> {
// Since this atomic op isn't actually guarding any memory and we don't
// care about any orderings besides the ordering on the single atomic
// variable, a relaxed ordering is acceptable.
Expand All @@ -701,7 +773,7 @@ impl<T> Clone for Sender<T> {
// The ABA problem doesn't matter here. We only care that the
// number of senders never exceeds the maximum.
if actual == curr {
return Sender {
return SenderInner {
inner: self.inner.clone(),
sender_task: Arc::new(Mutex::new(SenderTask::new())),
maybe_parked: false,
Expand All @@ -713,7 +785,7 @@ impl<T> Clone for Sender<T> {
}
}

impl<T> Drop for Sender<T> {
impl<T> Drop for SenderInner<T> {
fn drop(&mut self) {
// Ordering between variables don't matter here
let prev = self.inner.num_senders.fetch_sub(1, SeqCst);
Expand Down
82 changes: 82 additions & 0 deletions futures-channel/tests/mpsc-close.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,85 @@ fn smoke() {

t.join().unwrap()
}

#[test]
fn multiple_senders_disconnect() {
{
let (mut tx1, mut rx) = mpsc::channel(1);
let (tx2, mut tx3, mut tx4) = (tx1.clone(), tx1.clone(), tx1.clone());

// disconnect, dropping and Sink::poll_close should all close this sender but leave the
// channel open for other senders
tx1.disconnect();
drop(tx2);
block_on(tx3.close()).unwrap();

assert!(tx1.is_closed());
assert!(tx3.is_closed());
assert!(!tx4.is_closed());

block_on(tx4.send(5)).unwrap();
assert_eq!(block_on(rx.next()), Some(5));

// dropping the final sender will close the channel
drop(tx4);
assert_eq!(block_on(rx.next()), None);
}

{
let (mut tx1, mut rx) = mpsc::unbounded();
let (tx2, mut tx3, mut tx4) = (tx1.clone(), tx1.clone(), tx1.clone());

// disconnect, dropping and Sink::poll_close should all close this sender but leave the
// channel open for other senders
tx1.disconnect();
drop(tx2);
block_on(tx3.close()).unwrap();

assert!(tx1.is_closed());
assert!(tx3.is_closed());
assert!(!tx4.is_closed());

block_on(tx4.send(5)).unwrap();
assert_eq!(block_on(rx.next()), Some(5));

// dropping the final sender will close the channel
drop(tx4);
assert_eq!(block_on(rx.next()), None);
}
}

#[test]
fn multiple_senders_close_channel() {
{
let (mut tx1, mut rx) = mpsc::channel(1);
let mut tx2 = tx1.clone();

// close_channel should shut down the whole channel
tx1.close_channel();

assert!(tx1.is_closed());
assert!(tx2.is_closed());

let err = block_on(tx2.send(5)).unwrap_err();
assert!(err.is_disconnected());

assert_eq!(block_on(rx.next()), None);
}

{
let (tx1, mut rx) = mpsc::unbounded();
let mut tx2 = tx1.clone();

// close_channel should shut down the whole channel
tx1.close_channel();

assert!(tx1.is_closed());
assert!(tx2.is_closed());

let err = block_on(tx2.send(5)).unwrap_err();
assert!(err.is_disconnected());

assert_eq!(block_on(rx.next()), None);
}
}
Loading