diff --git a/futures-channel/src/mpsc/mod.rs b/futures-channel/src/mpsc/mod.rs index e276c745e0..cd34f3569c 100644 --- a/futures-channel/src/mpsc/mod.rs +++ b/futures-channel/src/mpsc/mod.rs @@ -94,11 +94,8 @@ use crate::mpsc::queue::Queue; mod queue; -/// The transmission end of a bounded mpsc channel. -/// -/// This value is created by the [`channel`](channel) function. #[derive(Debug)] -pub struct Sender { +struct SenderInner { // Channel state shared between the sender and receiver. inner: Arc>, @@ -112,14 +109,20 @@ pub struct Sender { maybe_parked: bool, } -// We never project Pin<&mut Sender> to `Pin<&mut T>` -impl Unpin for Sender {} +// We never project Pin<&mut SenderInner> to `Pin<&mut T>` +impl Unpin for SenderInner {} + +/// The transmission end of a bounded mpsc channel. +/// +/// This value is created by the [`channel`](channel) function. +#[derive(Debug)] +pub struct Sender(Option>); /// The transmission end of an unbounded mpsc channel. /// /// This value is created by the [`unbounded`](unbounded) function. #[derive(Debug)] -pub struct UnboundedSender(Sender); +pub struct UnboundedSender(Option>); trait AssertKinds: Send + Sync + Clone {} impl AssertKinds for UnboundedSender {} @@ -357,7 +360,8 @@ pub fn channel(buffer: usize) -> (Sender, Receiver) { // Check that the requested buffer size does not exceed the maximum buffer // size permitted by the system. assert!(buffer < MAX_BUFFER, "requested buffer size too large"); - channel2(Some(buffer)) + let (tx, rx) = channel2(Some(buffer)); + (Sender(Some(tx)), rx) } /// Creates an unbounded mpsc channel for communicating between asynchronous @@ -372,10 +376,10 @@ pub fn channel(buffer: usize) -> (Sender, Receiver) { /// process to run out of memory. In this case, the process will be aborted. pub fn unbounded() -> (UnboundedSender, UnboundedReceiver) { let (tx, rx) = channel2(None); - (UnboundedSender(tx), UnboundedReceiver(rx)) + (UnboundedSender(Some(tx)), UnboundedReceiver(rx)) } -fn channel2(buffer: Option) -> (Sender, Receiver) { +fn channel2(buffer: Option) -> (SenderInner, Receiver) { let inner = Arc::new(Inner { buffer, state: AtomicUsize::new(INIT_STATE), @@ -385,7 +389,7 @@ fn channel2(buffer: Option) -> (Sender, Receiver) { recv_task: AtomicWaker::new(), }); - let tx = Sender { + let tx = SenderInner { inner: inner.clone(), sender_task: Arc::new(Mutex::new(SenderTask::new())), maybe_parked: false, @@ -404,10 +408,10 @@ fn channel2(buffer: Option) -> (Sender, Receiver) { * */ -impl Sender { +impl SenderInner { /// Attempts to send a message on this `Sender`, returning the message /// if there was an error. - pub fn try_send(&mut self, msg: T) -> Result<(), TrySendError> { + fn try_send(&mut self, msg: T) -> Result<(), TrySendError> { // If the sender is currently blocked, reject the message if !self.poll_unparked(None).is_ready() { return Err(TrySendError { @@ -422,16 +426,6 @@ impl Sender { self.do_send_b(msg) } - /// Send a message on the channel. - /// - /// This function should only be called after - /// [`poll_ready`](Sender::poll_ready) has reported that the channel is - /// ready to receive a message. - pub fn start_send(&mut self, msg: T) -> Result<(), SendError> { - self.try_send(msg) - .map_err(|e| e.err) - } - // Do the send without failing. // Can be called only by bounded sender. fn do_send_b(&mut self, msg: T) @@ -484,7 +478,7 @@ impl Sender { Poll::Ready(Ok(())) } else { Poll::Ready(Err(SendError { - kind: SendErrorKind::Full, + kind: SendErrorKind::Disconnected, })) } } @@ -559,7 +553,7 @@ impl Sender { /// capacity, in which case the current task is queued to be notified once /// capacity is available; /// - `Err(SendError)` if the receiver has been dropped. - pub fn poll_ready( + fn poll_ready( &mut self, lw: &LocalWaker ) -> Poll> { @@ -574,12 +568,12 @@ impl Sender { } /// Returns whether this channel is closed without needing a context. - pub fn is_closed(&self) -> bool { + fn is_closed(&self) -> bool { !decode_state(self.inner.state.load(SeqCst)).is_open } /// Closes this channel from the sender side, preventing any new messages. - pub fn close_channel(&mut self) { + fn close_channel(&self) { // There's no need to park this sender, its dropping, // and we don't want to check for capacity, so skip // that stuff from `do_send`. @@ -615,43 +609,116 @@ impl Sender { } } +impl Sender { + /// Attempts to send a message on this `Sender`, returning the message + /// if there was an error. + pub fn try_send(&mut self, msg: T) -> Result<(), TrySendError> { + if let Some(inner) = &mut self.0 { + inner.try_send(msg) + } else { + Err(TrySendError { + err: SendError { + kind: SendErrorKind::Disconnected, + }, + val: msg, + }) + } + } + + /// Send a message on the channel. + /// + /// This function should only be called after + /// [`poll_ready`](Sender::poll_ready) has reported that the channel is + /// ready to receive a message. + pub fn start_send(&mut self, msg: T) -> Result<(), SendError> { + self.try_send(msg) + .map_err(|e| e.err) + } + + /// Polls the channel to determine if there is guaranteed capacity to send + /// at least one item without waiting. + /// + /// # Return value + /// + /// This method returns: + /// + /// - `Ok(Async::Ready(_))` if there is sufficient capacity; + /// - `Ok(Async::Pending)` if the channel may not have + /// capacity, in which case the current task is queued to be notified once + /// capacity is available; + /// - `Err(SendError)` if the receiver has been dropped. + pub fn poll_ready( + &mut self, + lw: &LocalWaker + ) -> Poll> { + let inner = self.0.as_mut().ok_or(SendError { + kind: SendErrorKind::Disconnected, + })?; + inner.poll_ready(lw) + } + + /// Returns whether this channel is closed without needing a context. + pub fn is_closed(&self) -> bool { + self.0.as_ref().map(SenderInner::is_closed).unwrap_or(true) + } + + /// Closes this channel from the sender side, preventing any new messages. + pub fn close_channel(&mut self) { + if let Some(inner) = &mut self.0 { + inner.close_channel(); + } + } + + /// Disconnects this sender from the channel, closing it if there are no more senders left. + pub fn disconnect(&mut self) { + self.0 = None; + } +} + impl UnboundedSender { /// Check if the channel is ready to receive a message. pub fn poll_ready( &self, _: &LocalWaker, ) -> Poll> { - self.0.poll_ready_nb() + let inner = self.0.as_ref().ok_or(SendError { + kind: SendErrorKind::Disconnected, + })?; + inner.poll_ready_nb() } /// Returns whether this channel is closed without needing a context. pub fn is_closed(&self) -> bool { - self.0.is_closed() + self.0.as_ref().map(SenderInner::is_closed).unwrap_or(true) } /// Closes this channel from the sender side, preventing any new messages. pub fn close_channel(&self) { - self.0.inner.set_closed(); - self.0.inner.recv_task.wake(); + if let Some(inner) = &self.0 { + inner.close_channel(); + } + } + + /// Disconnects this sender from the channel, closing it if there are no more senders left. + pub fn disconnect(&mut self) { + self.0 = None; } // Do the send without parking current task. fn do_send_nb(&self, msg: T) -> Result<(), TrySendError> { - match self.0.inc_num_messages() { - Some(_num_messages) => {} - None => { - return Err(TrySendError { - err: SendError { - kind: SendErrorKind::Disconnected, - }, - val: msg, - }); - }, - }; - - self.0.queue_push_and_signal(msg); + if let Some(inner) = &self.0 { + if inner.inc_num_messages().is_some() { + inner.queue_push_and_signal(msg); + return Ok(()); + } + } - Ok(()) + Err(TrySendError { + err: SendError { + kind: SendErrorKind::Disconnected, + }, + val: msg, + }) } /// Send a message on the channel. @@ -673,15 +740,20 @@ impl UnboundedSender { } } +impl Clone for Sender { + fn clone(&self) -> Sender { + Sender(self.0.clone()) + } +} + impl Clone for UnboundedSender { fn clone(&self) -> UnboundedSender { UnboundedSender(self.0.clone()) } } - -impl Clone for Sender { - fn clone(&self) -> Sender { +impl Clone for SenderInner { + fn clone(&self) -> SenderInner { // Since this atomic op isn't actually guarding any memory and we don't // care about any orderings besides the ordering on the single atomic // variable, a relaxed ordering is acceptable. @@ -701,7 +773,7 @@ impl Clone for Sender { // The ABA problem doesn't matter here. We only care that the // number of senders never exceeds the maximum. if actual == curr { - return Sender { + return SenderInner { inner: self.inner.clone(), sender_task: Arc::new(Mutex::new(SenderTask::new())), maybe_parked: false, @@ -713,7 +785,7 @@ impl Clone for Sender { } } -impl Drop for Sender { +impl Drop for SenderInner { fn drop(&mut self) { // Ordering between variables don't matter here let prev = self.inner.num_senders.fetch_sub(1, SeqCst); diff --git a/futures-channel/tests/mpsc-close.rs b/futures-channel/tests/mpsc-close.rs index b8ab5fe6e3..a0fdf88eac 100644 --- a/futures-channel/tests/mpsc-close.rs +++ b/futures-channel/tests/mpsc-close.rs @@ -17,3 +17,85 @@ fn smoke() { t.join().unwrap() } + +#[test] +fn multiple_senders_disconnect() { + { + let (mut tx1, mut rx) = mpsc::channel(1); + let (tx2, mut tx3, mut tx4) = (tx1.clone(), tx1.clone(), tx1.clone()); + + // disconnect, dropping and Sink::poll_close should all close this sender but leave the + // channel open for other senders + tx1.disconnect(); + drop(tx2); + block_on(tx3.close()).unwrap(); + + assert!(tx1.is_closed()); + assert!(tx3.is_closed()); + assert!(!tx4.is_closed()); + + block_on(tx4.send(5)).unwrap(); + assert_eq!(block_on(rx.next()), Some(5)); + + // dropping the final sender will close the channel + drop(tx4); + assert_eq!(block_on(rx.next()), None); + } + + { + let (mut tx1, mut rx) = mpsc::unbounded(); + let (tx2, mut tx3, mut tx4) = (tx1.clone(), tx1.clone(), tx1.clone()); + + // disconnect, dropping and Sink::poll_close should all close this sender but leave the + // channel open for other senders + tx1.disconnect(); + drop(tx2); + block_on(tx3.close()).unwrap(); + + assert!(tx1.is_closed()); + assert!(tx3.is_closed()); + assert!(!tx4.is_closed()); + + block_on(tx4.send(5)).unwrap(); + assert_eq!(block_on(rx.next()), Some(5)); + + // dropping the final sender will close the channel + drop(tx4); + assert_eq!(block_on(rx.next()), None); + } +} + +#[test] +fn multiple_senders_close_channel() { + { + let (mut tx1, mut rx) = mpsc::channel(1); + let mut tx2 = tx1.clone(); + + // close_channel should shut down the whole channel + tx1.close_channel(); + + assert!(tx1.is_closed()); + assert!(tx2.is_closed()); + + let err = block_on(tx2.send(5)).unwrap_err(); + assert!(err.is_disconnected()); + + assert_eq!(block_on(rx.next()), None); + } + + { + let (tx1, mut rx) = mpsc::unbounded(); + let mut tx2 = tx1.clone(); + + // close_channel should shut down the whole channel + tx1.close_channel(); + + assert!(tx1.is_closed()); + assert!(tx2.is_closed()); + + let err = block_on(tx2.send(5)).unwrap_err(); + assert!(err.is_disconnected()); + + assert_eq!(block_on(rx.next()), None); + } +} diff --git a/futures-sink/src/channel_impls.rs b/futures-sink/src/channel_impls.rs index edd3f6fc82..16a655296e 100644 --- a/futures-sink/src/channel_impls.rs +++ b/futures-sink/src/channel_impls.rs @@ -20,7 +20,7 @@ impl Sink for Sender { } fn poll_close(mut self: Pin<&mut Self>, _: &LocalWaker) -> Poll> { - self.close_channel(); + self.disconnect(); Poll::Ready(Ok(())) } } @@ -41,8 +41,8 @@ impl Sink for UnboundedSender { Poll::Ready(Ok(())) } - fn poll_close(self: Pin<&mut Self>, _: &LocalWaker) -> Poll> { - self.close_channel(); + fn poll_close(mut self: Pin<&mut Self>, _: &LocalWaker) -> Poll> { + self.disconnect(); Poll::Ready(Ok(())) } }