diff --git a/Cargo.toml b/Cargo.toml index d522d4eff3c..0d88dd62c38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ futures-executor = "0.3.29" hyper = "0.14.27" tokio = "1.33" +tokio-util = "0.7" tower = { version = "0.4.13", features = ["util"] } tracing = "0.1.40" diff --git a/crates/transport-ipc/Cargo.toml b/crates/transport-ipc/Cargo.toml index 8a98337e51e..0ee14eb4aba 100644 --- a/crates/transport-ipc/Cargo.toml +++ b/crates/transport-ipc/Cargo.toml @@ -20,6 +20,7 @@ futures.workspace = true pin-project.workspace = true serde_json.workspace = true tokio.workspace = true +tokio-util = { workspace = true, features = ["io", "compat"]} tracing.workspace = true bytes = "1.5.0" diff --git a/crates/transport-ipc/src/lib.rs b/crates/transport-ipc/src/lib.rs index fe550ca1938..fa94430b20a 100644 --- a/crates/transport-ipc/src/lib.rs +++ b/crates/transport-ipc/src/lib.rs @@ -23,13 +23,12 @@ pub mod mock; #[cfg(feature = "mock")] pub use mock::MockIpcServer; -use std::task::Poll::{Pending, Ready}; - -use alloy_json_rpc::PubSubItem; use bytes::{Buf, BytesMut}; -use futures::{io::BufReader, ready, AsyncBufRead, AsyncRead, AsyncWriteExt, StreamExt}; +use futures::{ready, AsyncRead, AsyncWriteExt, StreamExt}; use interprocess::local_socket::{tokio::LocalSocketStream, ToLocalSocketName}; +use std::task::Poll::Ready; use tokio::select; +use tokio_util::compat::FuturesAsyncReadCompatExt; type Result = std::result::Result; @@ -113,11 +112,11 @@ impl IpcBackend { pub struct ReadJsonStream { /// The underlying reader. #[pin] - reader: BufReader, - /// A buffer of bytes read from the reader. + reader: tokio_util::compat::Compat, + /// A buffer for reading data from the reader. buf: BytesMut, - /// A buffer of items deserialized from the reader. - items: Vec, + /// Whether the buffer has been drained. + drained: bool, } impl ReadJsonStream @@ -125,7 +124,7 @@ where T: AsyncRead, { fn new(reader: T) -> Self { - Self { reader: BufReader::new(reader), buf: BytesMut::with_capacity(4096), items: vec![] } + Self { reader: reader.compat(), buf: BytesMut::with_capacity(4096), drained: true } } } @@ -148,57 +147,53 @@ where self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let this = self.project(); + use tokio_util::io::poll_read_buf; - // Deserialize any buffered items. - if !this.buf.is_empty() { - this.reader.consume(this.buf.len()); + let mut this = self.project(); - tracing::debug!(buf_len = this.buf.len(), "Deserializing buffered IPC data"); - let mut de = serde_json::Deserializer::from_slice(this.buf.as_ref()).into_iter(); + loop { + // try decoding from the buffer, but only if we have new data + if !*this.drained { + tracing::debug!(buf_len = this.buf.len(), "Deserializing buffered IPC data"); + let mut de = serde_json::Deserializer::from_slice(this.buf.as_ref()).into_iter(); - let item = de.next(); - match item { - Some(Ok(response)) => { - this.items.push(response); - } - Some(Err(e)) => { - tracing::error!(%e, "IPC response contained invalid JSON. Buffer contents will be logged at trace level"); - tracing::trace!( - buffer = %String::from_utf8_lossy(this.buf.as_ref()), - "IPC response contained invalid JSON. NOTE: Buffer contents do not include invalid utf8.", - ); + let item = de.next(); - return Ready(None); + // advance the buffer + this.buf.advance(de.byte_offset()); + + match item { + Some(Ok(response)) => { + return Ready(Some(response)); + } + Some(Err(e)) => { + tracing::error!(%e, "IPC response contained invalid JSON. Buffer contents will be logged at trace level"); + tracing::trace!( + buffer = %String::from_utf8_lossy(this.buf.as_ref()), + "IPC response contained invalid JSON. NOTE: Buffer contents do not include invalid utf8.", + ); + + return Ready(None); + } + None => { + // nothing decoded + *this.drained = true; + } } - None => {} } - this.buf.advance(de.byte_offset()); - cx.waker().wake_by_ref(); - return Pending; - } - - // Return any buffered items, rewaking. - if !this.items.is_empty() { - // may have more work! - cx.waker().wake_by_ref(); - return Ready(this.items.pop()); - } - tracing::debug!(buf_len = this.buf.len(), "Polling IPC socket for data"); + // read more data into the buffer + match ready!(poll_read_buf(this.reader.as_mut(), cx, &mut this.buf)) { + Ok(data_len) => { + tracing::debug!(%data_len, "Read data from IPC socket"); - let data = ready!(this.reader.poll_fill_buf(cx)); - match data { - Err(e) => { - tracing::error!(%e, "Failed to read from IPC socket, shutting down"); - Ready(None) - } - Ok(data) => { - tracing::debug!(data_len = data.len(), "Read data from IPC socket"); - this.buf.extend_from_slice(data); - // wake task to run deserialization - cx.waker().wake_by_ref(); - Pending + // can try decoding again + *this.drained = false; + } + Err(e) => { + tracing::error!(%e, "Failed to read from IPC socket, shutting down"); + return Ready(None); + } } } }