diff --git a/src/copy_bidirectional.rs b/src/copy_bidirectional.rs index cdf12ee..71fee38 100644 --- a/src/copy_bidirectional.rs +++ b/src/copy_bidirectional.rs @@ -17,6 +17,7 @@ use std::task::{Context, Poll}; #[derive(Debug)] struct CopyBuffer { read_done: bool, + need_flush: bool, start_index: usize, cache_length: usize, size: usize, @@ -31,6 +32,7 @@ impl CopyBuffer { } Self { read_done: false, + need_flush: false, start_index: 0, cache_length: 0, size, @@ -50,6 +52,7 @@ impl CopyBuffer { { loop { let mut read_pending = false; + let mut write_pending = false; // If our buffer has some space, let's read up! if !self.read_done && self.cache_length < self.size { @@ -74,6 +77,8 @@ impl CopyBuffer { } } Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. read_pending = true; } } @@ -107,25 +112,30 @@ impl CopyBuffer { } else { self.start_index = (self.start_index + written) % self.size; } + self.need_flush = true; } } Poll::Pending => { - // Previously we set a write_pending flag, and kept going until - // both read and write were pending, but this might starve other - // tasks. - return Poll::Pending; + write_pending = true; + break; } } } + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + // If we've written all the data and we've seen EOF, flush out the // data and finish the transfer. if self.read_done && self.cache_length == 0 { - ready!(writer.as_mut().poll_flush(cx))?; return Poll::Ready(Ok(())); } - if read_pending { + // Previously we kept going until both read and write were pending, but + // this might starve other tasks. + if read_pending || write_pending { // If we got here, // 1) we hit read_pending on the current iteration. // 2) all data has been written successfully