diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ddc83934..6e877e35 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -61,7 +61,6 @@ jobs: run: | cd netlink-proto cargo test - cargo test --features workaround-audit-bug - name: test (rtnetlink) env: diff --git a/audit/Cargo.toml b/audit/Cargo.toml index 562f337a..d84969c1 100644 --- a/audit/Cargo.toml +++ b/audit/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "audit" -version = "0.4.0" +version = "0.5.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Corentin Henry "] edition = "2018" @@ -14,13 +14,13 @@ description = "linux audit via netlink" [dependencies] futures = "0.3.11" thiserror = "1" -netlink-packet-audit = "0.2" -netlink-proto = { default-features = false, version = "0.7" } +netlink-packet-audit = "0.3" +netlink-proto = { default-features = false, version = "0.8" } [features] default = ["tokio_socket"] -tokio_socket = ["netlink-proto/tokio_socket", "netlink-proto/workaround-audit-bug"] -smol_socket = ["netlink-proto/smol_socket", "netlink-proto/workaround-audit-bug"] +tokio_socket = ["netlink-proto/tokio_socket"] +smol_socket = ["netlink-proto/smol_socket"] [dev-dependencies] tokio = { version = "1.0.1", default-features = false, features = ["macros", "rt-multi-thread"] } diff --git a/audit/src/lib.rs b/audit/src/lib.rs index e38cc9f0..eadb16d7 100644 --- a/audit/src/lib.rs +++ b/audit/src/lib.rs @@ -15,14 +15,31 @@ use std::io; use futures::channel::mpsc::UnboundedReceiver; #[allow(clippy::type_complexity)] +#[cfg(feature = "tokio_socket")] pub fn new_connection() -> io::Result<( - proto::Connection, + proto::Connection, Handle, UnboundedReceiver<( packet::NetlinkMessage, sys::SocketAddr, )>, )> { - let (conn, handle, messages) = netlink_proto::new_connection(sys::protocols::NETLINK_AUDIT)?; + new_connection_with_socket() +} + +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + proto::Connection, + Handle, + UnboundedReceiver<( + packet::NetlinkMessage, + sys::SocketAddr, + )>, +)> +where + S: sys::AsyncSocket, +{ + let (conn, handle, messages) = + netlink_proto::new_connection_with_codec(sys::protocols::NETLINK_AUDIT)?; Ok((conn, Handle::new(handle), messages)) } diff --git a/ethtool/Cargo.toml b/ethtool/Cargo.toml index 675a4c86..f6a92ad5 100644 --- a/ethtool/Cargo.toml +++ b/ethtool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ethtool" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Gris Ge "] license = "MIT" edition = "2018" @@ -24,13 +24,13 @@ anyhow = "1.0.44" async-std = { version = "1.9.0", optional = true} byteorder = "1.4.3" futures = "0.3.17" -genetlink = { default-features = false, version = "0.1.0"} +genetlink = { default-features = false, version = "0.2.0"} log = "0.4.14" -netlink-packet-core = "0.2.4" -netlink-packet-generic = "0.1.0" +netlink-packet-core = "0.3.0" +netlink-packet-generic = "0.2.0" netlink-packet-utils = "0.4.1" -netlink-proto = { default-features = false, version = "0.7.0" } -netlink-sys = "0.7.0" +netlink-proto = { default-features = false, version = "0.8.0" } +netlink-sys = "0.8.0" thiserror = "1.0.29" tokio = { version = "1.0.1", features = ["rt"], optional = true} diff --git a/ethtool/src/connection.rs b/ethtool/src/connection.rs index c9d9ef2c..165f31ea 100644 --- a/ethtool/src/connection.rs +++ b/ethtool/src/connection.rs @@ -4,16 +4,29 @@ use futures::channel::mpsc::UnboundedReceiver; use genetlink::message::RawGenlMessage; use netlink_packet_core::NetlinkMessage; use netlink_proto::Connection; -use netlink_sys::SocketAddr; +use netlink_sys::{AsyncSocket, SocketAddr}; use crate::EthtoolHandle; +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection() -> io::Result<( Connection, EthtoolHandle, UnboundedReceiver<(NetlinkMessage, SocketAddr)>, )> { - let (conn, handle, messages) = genetlink::new_connection()?; + new_connection_with_socket() +} + +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + Connection, + EthtoolHandle, + UnboundedReceiver<(NetlinkMessage, SocketAddr)>, +)> +where + S: AsyncSocket, +{ + let (conn, handle, messages) = genetlink::new_connection_with_socket()?; Ok((conn, EthtoolHandle::new(handle), messages)) } diff --git a/ethtool/src/lib.rs b/ethtool/src/lib.rs index cf986aab..f1a43c2f 100644 --- a/ethtool/src/lib.rs +++ b/ethtool/src/lib.rs @@ -11,7 +11,9 @@ mod pause; mod ring; pub use coalesce::{EthtoolCoalesceAttr, EthtoolCoalesceGetRequest, EthtoolCoalesceHandle}; +#[cfg(feature = "tokio_socket")] pub use connection::new_connection; +pub use connection::new_connection_with_socket; pub use error::EthtoolError; pub use feature::{ EthtoolFeatureAttr, diff --git a/genetlink/Cargo.toml b/genetlink/Cargo.toml index 2e15b8c2..461a108a 100644 --- a/genetlink/Cargo.toml +++ b/genetlink/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "genetlink" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Leo "] edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -12,17 +12,17 @@ description = "communicate with generic netlink" [features] default = ["tokio_socket"] -tokio_socket = ["netlink-proto/tokio_socket","netlink-proto/workaround-audit-bug", "tokio"] -smol_socket = ["netlink-proto/smol_socket","netlink-proto/workaround-audit-bug","async-std"] +tokio_socket = ["netlink-proto/tokio_socket", "tokio"] +smol_socket = ["netlink-proto/smol_socket","async-std"] [dependencies] futures = "0.3.16" -netlink-packet-generic = "0.1.0" -netlink-proto = { default-features = false, version = "0.7.0" } +netlink-packet-generic = "0.2.0" +netlink-proto = { default-features = false, version = "0.8.0" } tokio = { version = "1.9.0", features = ["rt"], optional = true } async-std = { version = "1.9.0", optional = true } netlink-packet-utils = "0.4.1" -netlink-packet-core = "0.2.4" +netlink-packet-core = "0.3.0" thiserror = "1.0.26" [dev-dependencies] diff --git a/genetlink/src/connection.rs b/genetlink/src/connection.rs index d82af5e7..34eb2623 100644 --- a/genetlink/src/connection.rs +++ b/genetlink/src/connection.rs @@ -3,7 +3,7 @@ use futures::channel::mpsc::UnboundedReceiver; use netlink_packet_core::NetlinkMessage; use netlink_proto::{ self, - sys::{protocols::NETLINK_GENERIC, SocketAddr}, + sys::{protocols::NETLINK_GENERIC, AsyncSocket, SocketAddr}, Connection, }; use std::io; @@ -21,12 +21,26 @@ use std::io; /// /// The [`GenetlinkHandle`] can send and receive any type of generic netlink message. /// And it can automatic resolve the generic family id before sending. +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection() -> io::Result<( Connection, GenetlinkHandle, UnboundedReceiver<(NetlinkMessage, SocketAddr)>, )> { - let (conn, handle, messages) = netlink_proto::new_connection(NETLINK_GENERIC)?; + new_connection_with_socket() +} + +/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + Connection, + GenetlinkHandle, + UnboundedReceiver<(NetlinkMessage, SocketAddr)>, +)> +where + S: AsyncSocket, +{ + let (conn, handle, messages) = netlink_proto::new_connection_with_socket(NETLINK_GENERIC)?; Ok((conn, GenetlinkHandle::new(handle), messages)) } diff --git a/genetlink/src/handle.rs b/genetlink/src/handle.rs index 1a312913..935b81a8 100644 --- a/genetlink/src/handle.rs +++ b/genetlink/src/handle.rs @@ -78,13 +78,7 @@ impl GenetlinkHandle { GenetlinkError, > where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { self.resolve_message_family_id(&mut message).await?; self.send_request(message) @@ -102,13 +96,7 @@ impl GenetlinkHandle { GenetlinkError, > where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { let raw_msg = map_to_rawgenlmsg(message); @@ -122,13 +110,7 @@ impl GenetlinkHandle { mut message: NetlinkMessage>, ) -> Result<(), GenetlinkError> where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { self.resolve_message_family_id(&mut message).await?; self.send_notify(message) @@ -140,13 +122,7 @@ impl GenetlinkHandle { message: NetlinkMessage>, ) -> Result<(), GenetlinkError> where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { let raw_msg = map_to_rawgenlmsg(message); @@ -159,7 +135,7 @@ impl GenetlinkHandle { message: &mut NetlinkMessage>, ) -> Result<(), GenetlinkError> where - F: GenlFamily + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Debug, { if let NetlinkPayload::InnerMessage(genlmsg) = &mut message.payload { if genlmsg.family_id() == 0 { diff --git a/genetlink/src/lib.rs b/genetlink/src/lib.rs index e145b43f..a3b1e9e1 100644 --- a/genetlink/src/lib.rs +++ b/genetlink/src/lib.rs @@ -7,6 +7,8 @@ mod handle; pub mod message; mod resolver; +#[cfg(feature = "tokio_socket")] pub use connection::new_connection; +pub use connection::new_connection_with_socket; pub use error::GenetlinkError; pub use handle::GenetlinkHandle; diff --git a/genetlink/src/message.rs b/genetlink/src/message.rs index 7cc0fcc2..0b5ffabb 100644 --- a/genetlink/src/message.rs +++ b/genetlink/src/message.rs @@ -57,7 +57,7 @@ impl RawGenlMessage { /// Serialize the generic netlink payload into raw bytes pub fn from_genlmsg(genlmsg: GenlMessage) -> Self where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { let mut payload_buf = vec![0u8; genlmsg.payload.buffer_len()]; genlmsg.payload.emit(&mut payload_buf); @@ -72,7 +72,7 @@ impl RawGenlMessage { /// Try to deserialize the generic netlink payload from raw bytes pub fn parse_into_genlmsg(&self) -> Result, DecodeError> where - F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Debug, { let inner = F::parse_with_param(&self.payload, self.header)?; Ok(GenlMessage::new(self.header, inner, self.family_id)) @@ -107,7 +107,7 @@ where } } -impl NetlinkSerializable for RawGenlMessage { +impl NetlinkSerializable for RawGenlMessage { fn message_type(&self) -> u16 { self.family_id } @@ -121,7 +121,7 @@ impl NetlinkSerializable for RawGenlMessage { } } -impl NetlinkDeserializable for RawGenlMessage { +impl NetlinkDeserializable for RawGenlMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { let buffer = GenlBuffer::new_checked(payload)?; @@ -141,7 +141,7 @@ pub fn map_to_rawgenlmsg( message: NetlinkMessage>, ) -> NetlinkMessage where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { let raw_payload = match message.payload { NetlinkPayload::InnerMessage(genlmsg) => { @@ -162,7 +162,7 @@ pub fn map_from_rawgenlmsg( raw_msg: NetlinkMessage, ) -> Result>, DecodeError> where - F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Debug, { let payload = match raw_msg.payload { NetlinkPayload::InnerMessage(raw_genlmsg) => { diff --git a/netlink-packet-audit/Cargo.toml b/netlink-packet-audit/Cargo.toml index 245c50e3..fabd09d8 100644 --- a/netlink-packet-audit/Cargo.toml +++ b/netlink-packet-audit/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-packet-audit" -version = "0.2.2" +version = "0.3.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -13,9 +13,12 @@ description = "netlink packet types" [dependencies] anyhow = "1.0.31" +bytes = "1.0" byteorder = "1.3.2" -netlink-packet-core = "0.2" +log = "0.4.8" +netlink-packet-core = "0.3" netlink-packet-utils = ">= 0.3, <0.5" +netlink-proto = { default-features = false, version = "0.8" } [dev-dependencies] lazy_static = "1.4.0" diff --git a/netlink-packet-audit/fuzz/Cargo.toml b/netlink-packet-audit/fuzz/Cargo.toml index 04067f92..5d2573d9 100644 --- a/netlink-packet-audit/fuzz/Cargo.toml +++ b/netlink-packet-audit/fuzz/Cargo.toml @@ -9,8 +9,8 @@ edition = "2018" cargo-fuzz = true [dependencies] -netlink-packet-audit = "0.2" -netlink-packet-core = "0.2" +netlink-packet-audit = "0.3" +netlink-packet-core = "0.3" libfuzzer-sys = { git = "https://github.com/rust-fuzz/libfuzzer-sys.git" } [[bin]] diff --git a/netlink-packet-audit/src/codec.rs b/netlink-packet-audit/src/codec.rs new file mode 100644 index 00000000..36f226af --- /dev/null +++ b/netlink-packet-audit/src/codec.rs @@ -0,0 +1,135 @@ +use std::{fmt::Debug, io}; + +use bytes::BytesMut; +use netlink_packet_core::{ + NetlinkBuffer, + NetlinkDeserializable, + NetlinkMessage, + NetlinkSerializable, +}; +pub(crate) use netlink_proto::{NetlinkCodec, NetlinkMessageCodec}; + +/// audit specific implementation of [`NetlinkMessageCodec`] due to the +/// protocol violations in messages generated by kernal audit. +/// +/// Among the known bugs in kernel audit messages: +/// - `nlmsg_len` sometimes contains the padding too (it shouldn't) +/// - `nlmsg_len` sometimes doesn't contain the header (it really should) +/// +/// See also: +/// - https://blog.des.no/2020/08/netlink-auditing-and-counting-bytes/ +/// - https://github.com/torvalds/linux/blob/b5013d084e03e82ceeab4db8ae8ceeaebe76b0eb/kernel/audit.c#L2386 +/// - https://github.com/mozilla/libaudit-go/issues/24 +/// - https://github.com/linux-audit/audit-userspace/issues/78 +pub struct NetlinkAuditCodec { + // we don't need an instance of this, just the type + _private: (), +} + +impl NetlinkMessageCodec for NetlinkAuditCodec { + fn decode(src: &mut BytesMut) -> io::Result>> + where + T: NetlinkDeserializable + Debug, + { + debug!("NetlinkAuditCodec: decoding next message"); + + loop { + // If there's nothing to read, return Ok(None) + if src.as_ref().is_empty() { + trace!("buffer is empty"); + src.clear(); + return Ok(None); + } + + // This is a bit hacky because we don't want to keep `src` + // borrowed, since we need to mutate it later. + let len_res = match NetlinkBuffer::new_checked(src.as_ref()) { + Ok(buf) => { + if (src.as_ref().len() as isize - buf.length() as isize) <= 16 { + // The audit messages are sometimes truncated, + // because the length specified in the header, + // does not take the header itself into + // account. To workaround this, we tweak the + // length. We've noticed two occurences of + // truncated packets: + // + // - the length of the header is not included (see also: + // https://github.com/mozilla/libaudit-go/issues/24) + // - some rule message have some padding for alignment (see + // https://github.com/linux-audit/audit-userspace/issues/78) which is not + // taken into account in the buffer length. + warn!("found what looks like a truncated audit packet"); + Ok(src.as_ref().len()) + } else { + Ok(buf.length() as usize) + } + } + Err(e) => { + // We either received a truncated packet, or the + // packet if malformed (invalid length field). In + // both case, we can't decode the datagram, and we + // cannot find the start of the next one (if + // any). The only solution is to clear the buffer + // and potentially lose some datagrams. + error!("failed to decode datagram: {:?}: {:#x?}.", e, src.as_ref()); + Err(()) + } + }; + + if len_res.is_err() { + error!("clearing the whole socket buffer. Datagrams may have been lost"); + src.clear(); + return Ok(None); + } + + let len = len_res.unwrap(); + + let bytes = { + let mut bytes = src.split_to(len); + { + let mut buf = NetlinkBuffer::new(bytes.as_mut()); + // If the buffer contains more bytes than what the header says the length is, it + // means we ran into a malformed packet (see comment above), and we just set the + // "right" length ourself, so that parsing does not fail. + // + // How do we know that's the right length? Due to an implementation detail and to + // the fact that netlink is a datagram protocol. + // + // - our implementation of Stream always calls the codec with at most 1 message in + // the buffer, so we know the extra bytes do not belong to another message. + // - because netlink is a datagram protocol, we receive entire messages, so we know + // that if those extra bytes do not belong to another message, they belong to + // this one. + if len != buf.length() as usize { + warn!( + "setting packet length to {} instead of {}", + len, + buf.length() + ); + buf.set_length(len as u32); + } + } + bytes + }; + + let parsed = NetlinkMessage::::deserialize(&bytes); + match parsed { + Ok(packet) => { + trace!("<<< {:?}", packet); + return Ok(Some(packet)); + } + Err(e) => { + error!("failed to decode packet {:#x?}: {}", &bytes, e); + // continue looping, there may be more datagrams in the buffer + } + } + } + } + + fn encode(msg: NetlinkMessage, buf: &mut BytesMut) -> io::Result<()> + where + T: Debug + NetlinkSerializable, + { + NetlinkCodec::encode(msg, buf) + } +} diff --git a/netlink-packet-audit/src/lib.rs b/netlink-packet-audit/src/lib.rs index a80d4340..19bb06bb 100644 --- a/netlink-packet-audit/src/lib.rs +++ b/netlink-packet-audit/src/lib.rs @@ -1,3 +1,6 @@ +#[macro_use] +extern crate log; + pub(crate) extern crate netlink_packet_utils as utils; pub use self::utils::{traits, DecodeError}; pub use netlink_packet_core::{ @@ -13,6 +16,9 @@ use core::ops::Range; /// Represent a multi-bytes field with a fixed size in a packet pub(crate) type Field = Range; +mod codec; +pub use codec::NetlinkAuditCodec; + pub mod status; pub use self::status::*; diff --git a/netlink-packet-audit/src/message.rs b/netlink-packet-audit/src/message.rs index 73dd345b..4a13b921 100644 --- a/netlink-packet-audit/src/message.rs +++ b/netlink-packet-audit/src/message.rs @@ -106,7 +106,7 @@ impl Emitable for AuditMessage { } } -impl NetlinkSerializable for AuditMessage { +impl NetlinkSerializable for AuditMessage { fn message_type(&self) -> u16 { self.message_type() } @@ -120,7 +120,7 @@ impl NetlinkSerializable for AuditMessage { } } -impl NetlinkDeserializable for AuditMessage { +impl NetlinkDeserializable for AuditMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { match AuditBuffer::new_checked(payload) { diff --git a/netlink-packet-audit/src/rules/buffer.rs b/netlink-packet-audit/src/rules/buffer.rs index 16abc5ed..e0a9dbcc 100644 --- a/netlink-packet-audit/src/rules/buffer.rs +++ b/netlink-packet-audit/src/rules/buffer.rs @@ -175,14 +175,8 @@ impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for RuleMessage { let mut offset = 0; - let fields = buf - .fields() - .chunks(4) - .map(|chunk| NativeEndian::read_u32(chunk)); - let values = buf - .values() - .chunks(4) - .map(|chunk| NativeEndian::read_u32(chunk)); + let fields = buf.fields().chunks(4).map(NativeEndian::read_u32); + let values = buf.values().chunks(4).map(NativeEndian::read_u32); let field_flags = buf .field_flags() .chunks(4) diff --git a/netlink-packet-core/Cargo.toml b/netlink-packet-core/Cargo.toml index f46fb4fc..ccca7e14 100644 --- a/netlink-packet-core/Cargo.toml +++ b/netlink-packet-core/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-packet-core" -version = "0.2.4" +version = "0.3.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -18,4 +18,4 @@ libc = "0.2.66" netlink-packet-utils = ">=0.3, <0.5" [dev-dependencies] -netlink-packet-route = "0.8" +netlink-packet-route = "0.9" diff --git a/netlink-packet-core/examples/protocol.rs b/netlink-packet-core/examples/protocol.rs index 7c092e0b..59dc1a6d 100644 --- a/netlink-packet-core/examples/protocol.rs +++ b/netlink-packet-core/examples/protocol.rs @@ -44,7 +44,7 @@ impl fmt::Display for DeserializeError { } // NetlinkDeserializable implementation -impl NetlinkDeserializable for PingPongMessage { +impl NetlinkDeserializable for PingPongMessage { type Error = DeserializeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { @@ -59,7 +59,7 @@ impl NetlinkDeserializable for PingPongMessage { } // NetlinkSerializable implementation -impl NetlinkSerializable for PingPongMessage { +impl NetlinkSerializable for PingPongMessage { fn message_type(&self) -> u16 { match self { PingPongMessage::Ping(_) => PING_MESSAGE, diff --git a/netlink-packet-core/src/lib.rs b/netlink-packet-core/src/lib.rs index 7e724829..5faed295 100644 --- a/netlink-packet-core/src/lib.rs +++ b/netlink-packet-core/src/lib.rs @@ -156,7 +156,7 @@ //! } //! //! // NetlinkDeserializable implementation -//! impl NetlinkDeserializable for PingPongMessage { +//! impl NetlinkDeserializable for PingPongMessage { //! type Error = DeserializeError; //! //! fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { @@ -171,7 +171,7 @@ //! } //! //! // NetlinkSerializable implementation -//! impl NetlinkSerializable for PingPongMessage { +//! impl NetlinkSerializable for PingPongMessage { //! fn message_type(&self) -> u16 { //! match self { //! PingPongMessage::Ping(_) => PING_MESSAGE, diff --git a/netlink-packet-core/src/message.rs b/netlink-packet-core/src/message.rs index 497ca0c2..fb699116 100644 --- a/netlink-packet-core/src/message.rs +++ b/netlink-packet-core/src/message.rs @@ -18,20 +18,14 @@ use crate::{ /// Represent a netlink message. #[derive(Debug, PartialEq, Eq, Clone)] -pub struct NetlinkMessage -where - I: Debug + PartialEq + Eq + Clone, -{ +pub struct NetlinkMessage { /// Message header (this is common to all the netlink protocols) pub header: NetlinkHeader, /// Inner message, which depends on the netlink protocol being used. pub payload: NetlinkPayload, } -impl NetlinkMessage -where - I: Debug + PartialEq + Eq + Clone, -{ +impl NetlinkMessage { /// Create a new netlink message from the given header and payload pub fn new(header: NetlinkHeader, payload: NetlinkPayload) -> Self { NetlinkMessage { header, payload } @@ -45,7 +39,7 @@ where impl NetlinkMessage where - I: NetlinkDeserializable + Debug + PartialEq + Eq + Clone, + I: NetlinkDeserializable + Debug, { /// Parse the given buffer as a netlink message pub fn deserialize(buffer: &[u8]) -> Result { @@ -56,7 +50,7 @@ where impl NetlinkMessage where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug, { /// Return the length of this message in bytes pub fn buffer_len(&self) -> usize { @@ -92,7 +86,7 @@ where impl<'buffer, B, I> Parseable> for NetlinkMessage where B: AsRef<[u8]> + 'buffer, - I: Debug + PartialEq + Eq + Clone + NetlinkDeserializable, + I: NetlinkDeserializable + Debug, { fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result { use self::NetlinkPayload::*; @@ -129,7 +123,7 @@ where impl Emitable for NetlinkMessage where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug, { fn buffer_len(&self) -> usize { use self::NetlinkPayload::*; @@ -163,7 +157,7 @@ where impl From for NetlinkMessage where - T: Into> + Debug + Clone + Eq + PartialEq, + T: Into> + Debug, { fn from(inner_message: T) -> Self { NetlinkMessage { diff --git a/netlink-packet-core/src/payload.rs b/netlink-packet-core/src/payload.rs index 499488b8..d6c598f0 100644 --- a/netlink-packet-core/src/payload.rs +++ b/netlink-packet-core/src/payload.rs @@ -14,10 +14,7 @@ pub const NLMSG_OVERRUN: u16 = 4; pub const NLMSG_ALIGNTO: u16 = 4; #[derive(Debug, PartialEq, Eq, Clone)] -pub enum NetlinkPayload -where - I: Debug + PartialEq + Eq + Clone, -{ +pub enum NetlinkPayload { Done, Error(ErrorMessage), Ack(AckMessage), @@ -28,7 +25,7 @@ where impl NetlinkPayload where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug, { pub fn message_type(&self) -> u16 { match self { diff --git a/netlink-packet-core/src/traits.rs b/netlink-packet-core/src/traits.rs index 368ba18a..0b9d0a7e 100644 --- a/netlink-packet-core/src/traits.rs +++ b/netlink-packet-core/src/traits.rs @@ -1,16 +1,15 @@ use crate::NetlinkHeader; use std::error::Error; -/// A `NetlinkDeserializable` type can be used to deserialize a buffer -/// into the target type `T` for which it is implemented. -pub trait NetlinkDeserializable { +/// A `NetlinkDeserializable` type can be deserialized from a buffer +pub trait NetlinkDeserializable: Sized { type Error: Error + Send + Sync + 'static; - /// Deserialize the given buffer into `T`. - fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result; + /// Deserialize the given buffer into `Self`. + fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result; } -pub trait NetlinkSerializable { +pub trait NetlinkSerializable { fn message_type(&self) -> u16; /// Return the length of the serialized data. diff --git a/netlink-packet-generic/Cargo.toml b/netlink-packet-generic/Cargo.toml index b92a5d95..bfbbce79 100644 --- a/netlink-packet-generic/Cargo.toml +++ b/netlink-packet-generic/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "netlink-packet-generic" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Leo "] edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -14,8 +14,8 @@ description = "generic netlink packet types" anyhow = "1.0.39" libc = "0.2.86" byteorder = "1.4.2" -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-packet-utils = "0.4" [dev-dependencies] -netlink-sys = { path = "../netlink-sys", version = "0.7" } +netlink-sys = { path = "../netlink-sys", version = "0.8" } diff --git a/netlink-packet-generic/src/buffer.rs b/netlink-packet-generic/src/buffer.rs index 7bd29d24..f7f0f365 100644 --- a/netlink-packet-generic/src/buffer.rs +++ b/netlink-packet-generic/src/buffer.rs @@ -12,7 +12,7 @@ buffer!(GenlBuffer(GENL_HDRLEN) { impl ParseableParametrized<[u8], u16> for GenlMessage where - F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: ParseableParametrized<[u8], GenlHeader> + Debug, { fn parse_with_param(buf: &[u8], message_type: u16) -> Result { let buf = GenlBuffer::new_checked(buf)?; @@ -22,7 +22,7 @@ where impl<'a, F, T> ParseableParametrized, u16> for GenlMessage where - F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: ParseableParametrized<[u8], GenlHeader> + Debug, T: AsRef<[u8]> + ?Sized, { fn parse_with_param(buf: &GenlBuffer<&'a T>, message_type: u16) -> Result { diff --git a/netlink-packet-generic/src/message.rs b/netlink-packet-generic/src/message.rs index 3d0bfde6..5ca5a961 100644 --- a/netlink-packet-generic/src/message.rs +++ b/netlink-packet-generic/src/message.rs @@ -20,10 +20,7 @@ use netlink_packet_core::NetlinkMessage; /// The message can be serialize/deserialize if the type `F` implements [`GenlFamily`], /// [`Emitable`], and [`ParseableParametrized<[u8], GenlHeader>`](ParseableParametrized). #[derive(Clone, Debug, PartialEq, Eq)] -pub struct GenlMessage -where - F: Clone + Debug + PartialEq + Eq, -{ +pub struct GenlMessage { pub header: GenlHeader, pub payload: F, resolved_family_id: u16, @@ -31,7 +28,7 @@ where impl GenlMessage where - F: Clone + Debug + PartialEq + Eq, + F: Debug, { /// Construct the message pub fn new(header: GenlHeader, payload: F, family_id: u16) -> Self { @@ -85,7 +82,7 @@ where impl GenlMessage where - F: GenlFamily + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Debug, { /// Build the message from the payload /// @@ -132,7 +129,7 @@ where impl Emitable for GenlMessage where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { fn buffer_len(&self) -> usize { self.header.buffer_len() + self.payload.buffer_len() @@ -146,9 +143,9 @@ where } } -impl NetlinkSerializable> for GenlMessage +impl NetlinkSerializable for GenlMessage where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { fn message_type(&self) -> u16 { self.family_id() @@ -163,9 +160,9 @@ where } } -impl<'a, F> NetlinkDeserializable> for GenlMessage +impl NetlinkDeserializable for GenlMessage where - F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: ParseableParametrized<[u8], GenlHeader> + Debug, { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { @@ -176,7 +173,7 @@ where impl From> for NetlinkPayload> where - F: Clone + Debug + PartialEq + Eq, + F: Debug, { fn from(message: GenlMessage) -> Self { NetlinkPayload::InnerMessage(message) diff --git a/netlink-packet-generic/tests/query_family_id.rs b/netlink-packet-generic/tests/query_family_id.rs index a74eadf4..da639eae 100644 --- a/netlink-packet-generic/tests/query_family_id.rs +++ b/netlink-packet-generic/tests/query_family_id.rs @@ -26,8 +26,7 @@ fn query_family_id() { socket.send(&txbuf, 0).unwrap(); - let mut rxbuf = vec![0u8; 2048]; - socket.recv(&mut rxbuf, 0).unwrap(); + let (rxbuf, _addr) = socket.recv_from_full().unwrap(); let rx_packet = >>::deserialize(&rxbuf).unwrap(); if let NetlinkPayload::InnerMessage(genlmsg) = rx_packet.payload { diff --git a/netlink-packet-route/Cargo.toml b/netlink-packet-route/Cargo.toml index 4e45b2ac..d47c3020 100644 --- a/netlink-packet-route/Cargo.toml +++ b/netlink-packet-route/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-packet-route" -version = "0.8.0" +version = "0.9.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -15,7 +15,7 @@ description = "netlink packet types" anyhow = "1.0.31" byteorder = "1.3.2" libc = "0.2.66" -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-packet-utils = "0.4" bitflags = "1.2.1" @@ -26,7 +26,7 @@ name = "dump_links" criterion = "0.3.0" pcap-file = "1.1.1" lazy_static = "1.4.0" -netlink-sys = "0.7" +netlink-sys = "0.8" [[bench]] name = "link_message" diff --git a/netlink-packet-route/examples/dump_links.rs b/netlink-packet-route/examples/dump_links.rs index cfcaa1d0..3781445f 100644 --- a/netlink-packet-route/examples/dump_links.rs +++ b/netlink-packet-route/examples/dump_links.rs @@ -37,7 +37,7 @@ fn main() { // we set the NLM_F_DUMP flag so we expect a multipart rx_packet in response. loop { - let size = socket.recv(&mut receive_buffer[..], 0).unwrap(); + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); loop { let bytes = &receive_buffer[offset..]; diff --git a/netlink-packet-route/examples/dump_neighbours.rs b/netlink-packet-route/examples/dump_neighbours.rs index 89c06b20..3816824a 100644 --- a/netlink-packet-route/examples/dump_neighbours.rs +++ b/netlink-packet-route/examples/dump_neighbours.rs @@ -38,7 +38,7 @@ fn main() { let mut offset = 0; 'outer: loop { - let size = socket.recv(&mut receive_buffer[..], 0).unwrap(); + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); loop { let bytes = &receive_buffer[offset..]; diff --git a/netlink-packet-route/examples/dump_rules.rs b/netlink-packet-route/examples/dump_rules.rs index 2b3566e5..a239542d 100644 --- a/netlink-packet-route/examples/dump_rules.rs +++ b/netlink-packet-route/examples/dump_rules.rs @@ -41,7 +41,7 @@ fn main() { let mut offset = 0; // we set the NLM_F_DUMP flag so we expect a multipart rx_packet in response. - while let Ok(size) = socket.recv(&mut receive_buffer[..], 0) { + while let Ok(size) = socket.recv(&mut &mut receive_buffer[..], 0) { loop { let bytes = &receive_buffer[offset..]; let rx_packet = >::deserialize(bytes).unwrap(); diff --git a/netlink-packet-route/fuzz/Cargo.toml b/netlink-packet-route/fuzz/Cargo.toml index 741e1626..99888b57 100644 --- a/netlink-packet-route/fuzz/Cargo.toml +++ b/netlink-packet-route/fuzz/Cargo.toml @@ -9,7 +9,7 @@ edition = "2018" cargo-fuzz = true [dependencies] -netlink-packet-route = "0.7" +netlink-packet-route = "0.9" libfuzzer-sys = { git = "https://github.com/rust-fuzz/libfuzzer-sys.git" } [[bin]] diff --git a/netlink-packet-route/src/rtnl/message.rs b/netlink-packet-route/src/rtnl/message.rs index e1627bc3..46baf2ce 100644 --- a/netlink-packet-route/src/rtnl/message.rs +++ b/netlink-packet-route/src/rtnl/message.rs @@ -356,7 +356,7 @@ impl Emitable for RtnlMessage { } } -impl NetlinkSerializable for RtnlMessage { +impl NetlinkSerializable for RtnlMessage { fn message_type(&self) -> u16 { self.message_type() } @@ -370,7 +370,7 @@ impl NetlinkSerializable for RtnlMessage { } } -impl NetlinkDeserializable for RtnlMessage { +impl NetlinkDeserializable for RtnlMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { let buf = RtnlMessageBuffer::new(payload); diff --git a/netlink-packet-sock-diag/Cargo.toml b/netlink-packet-sock-diag/Cargo.toml index e3d09cec..dfd52347 100644 --- a/netlink-packet-sock-diag/Cargo.toml +++ b/netlink-packet-sock-diag/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Flier Lu ", "Corentin Henry "] name = "netlink-packet-sock-diag" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -14,7 +14,7 @@ description = "netlink packet types for the sock_diag subprotocol" [dependencies] anyhow = "1.0.32" byteorder = "1.3.4" -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-packet-utils = ">= 0.3, <0.5" bitflags = "1.2.1" libc = "0.2.77" @@ -22,4 +22,4 @@ smallvec = "1.4.2" [dev-dependencies] lazy_static = "1.4.0" -netlink-sys = "0.7" +netlink-sys = "0.8" diff --git a/netlink-packet-sock-diag/examples/dump_ipv4.rs b/netlink-packet-sock-diag/examples/dump_ipv4.rs index fc9342e8..d4f8287e 100644 --- a/netlink-packet-sock-diag/examples/dump_ipv4.rs +++ b/netlink-packet-sock-diag/examples/dump_ipv4.rs @@ -46,7 +46,7 @@ fn main() { let mut receive_buffer = vec![0; 4096]; let mut offset = 0; - while let Ok(size) = socket.recv(&mut receive_buffer[..], 0) { + while let Ok(size) = socket.recv(&mut &mut receive_buffer[..], 0) { loop { let bytes = &receive_buffer[offset..]; let rx_packet = >::deserialize(bytes).unwrap(); diff --git a/netlink-packet-sock-diag/src/message.rs b/netlink-packet-sock-diag/src/message.rs index 5a84b5aa..6ab52e86 100644 --- a/netlink-packet-sock-diag/src/message.rs +++ b/netlink-packet-sock-diag/src/message.rs @@ -64,7 +64,7 @@ impl Emitable for SockDiagMessage { } } -impl NetlinkSerializable for SockDiagMessage { +impl NetlinkSerializable for SockDiagMessage { fn message_type(&self) -> u16 { self.message_type() } @@ -78,7 +78,7 @@ impl NetlinkSerializable for SockDiagMessage { } } -impl NetlinkDeserializable for SockDiagMessage { +impl NetlinkDeserializable for SockDiagMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { let buffer = SockDiagBuffer::new_checked(&payload)?; diff --git a/netlink-proto/Cargo.toml b/netlink-proto/Cargo.toml index 830f6698..f8d4f3aa 100644 --- a/netlink-proto/Cargo.toml +++ b/netlink-proto/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-proto" -version = "0.7.0" +version = "0.8.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -16,21 +16,19 @@ bytes = "1.0" log = "0.4.8" futures = "0.3" tokio = { version = "1.0", default-features = false, features = ["io-util"] } -tokio-util = { version = "0.6", default-features = false, features = ["codec"] } -netlink-packet-core = "0.2" -netlink-sys = { default-features = false, version = "0.7" } +netlink-packet-core = "0.3" +netlink-sys = { default-features = false, version = "0.8" } [features] default = ["tokio_socket"] tokio_socket = ["netlink-sys/tokio_socket"] smol_socket = ["netlink-sys/smol_socket"] -workaround-audit-bug = [] [dev-dependencies] env_logger = "0.8.2" tokio = { version = "1.0.1", default-features = false, features = ["macros", "rt-multi-thread"] } -netlink-packet-route = "0.8" -netlink-packet-audit = "0.2" +netlink-packet-route = "0.9" +netlink-packet-audit = "0.3" async-std = {version = "1.9.0", features = ["attributes"]} [[example]] @@ -42,4 +40,3 @@ required-features = ["smol_socket"] [[example]] name = "audit_events" -required-features = ["workaround-audit-bug"] diff --git a/netlink-proto/examples/audit_events.rs b/netlink-proto/examples/audit_events.rs index 1c451f7c..de82ccab 100644 --- a/netlink-proto/examples/audit_events.rs +++ b/netlink-proto/examples/audit_events.rs @@ -7,10 +7,7 @@ // Compilation: // ------------ // -// cargo build --example audit_events --features="workaround-audit-bug" -// -// Note that the audit protocol has a bug that we have to workaround, -// hence the custom --features flag for that protocol +// cargo build --example audit_events // // Usage: // ------ diff --git a/netlink-proto/src/codecs.rs b/netlink-proto/src/codecs.rs index d5c1af48..1b1bd737 100644 --- a/netlink-proto/src/codecs.rs +++ b/netlink-proto/src/codecs.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, io, marker::PhantomData}; +use std::{fmt::Debug, io}; use bytes::{BufMut, BytesMut}; use netlink_packet_core::{ @@ -7,36 +7,48 @@ use netlink_packet_core::{ NetlinkMessage, NetlinkSerializable, }; -use tokio_util::codec::{Decoder, Encoder}; -pub struct NetlinkCodec { - phantom: PhantomData, -} +/// Protocol to serialize and deserialize messages to and from datagrams +/// +/// This is separate from `tokio_util::codec::{Decoder, Encoder}` as the implementations +/// rely on the buffer containing full datagrams; they won't work well with simple +/// bytestreams. +/// +/// Officially there should be exactly one implementation of this, but the audit +/// subsystem ignores way too many rules of the protocol, so they need a separate +/// implementation. +/// +/// Although one could make a tighter binding between `NetlinkMessageCodec` and +/// the message types (NetlinkDeserializable+NetlinkSerializable) it can handle, +/// this would put quite some overhead on subsystems that followed the spec - so +/// we simply default to the proper implementation (in `Connection`) and the +/// `audit` code needs to overwrite it. +pub trait NetlinkMessageCodec { + /// Decode message of given type from datagram payload + /// + /// There might be more than one message; this needs to be called until it + /// either returns `Ok(None)` or an error. + fn decode(src: &mut BytesMut) -> io::Result>> + where + T: NetlinkDeserializable + Debug; -impl Default for NetlinkCodec { - fn default() -> Self { - Self::new() - } + /// Encode message to (datagram) buffer + fn encode(msg: NetlinkMessage, buf: &mut BytesMut) -> io::Result<()> + where + T: NetlinkSerializable + Debug; } -impl NetlinkCodec { - pub fn new() -> Self { - NetlinkCodec { - phantom: PhantomData, - } - } +/// Standard implementation of `NetlinkMessageCodec` +pub struct NetlinkCodec { + // we don't need an instance of this, just the type + _private: (), } -// FIXME: it seems that for audit, we're receiving malformed packets. -// See https://github.com/mozilla/libaudit-go/issues/24 -impl Decoder for NetlinkCodec> -where - T: NetlinkDeserializable + Debug + Eq + PartialEq + Clone, -{ - type Item = NetlinkMessage; - type Error = io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { +impl NetlinkMessageCodec for NetlinkCodec { + fn decode(src: &mut BytesMut) -> io::Result>> + where + T: NetlinkDeserializable + Debug, + { debug!("NetlinkCodec: decoding next message"); loop { @@ -50,29 +62,7 @@ where // This is a bit hacky because we don't want to keep `src` // borrowed, since we need to mutate it later. let len_res = match NetlinkBuffer::new_checked(src.as_ref()) { - #[cfg(not(feature = "workaround-audit-bug"))] Ok(buf) => Ok(buf.length() as usize), - #[cfg(feature = "workaround-audit-bug")] - Ok(buf) => { - if (src.as_ref().len() as isize - buf.length() as isize) <= 16 { - // The audit messages are sometimes truncated, - // because the length specified in the header, - // does not take the header itself into - // account. To workaround this, we tweak the - // length. We've noticed two occurences of - // truncated packets: - // - // - the length of the header is not included (see also: - // https://github.com/mozilla/libaudit-go/issues/24) - // - some rule message have some padding for alignment (see - // https://github.com/linux-audit/audit-userspace/issues/78) which is not - // taken into account in the buffer length. - warn!("found what looks like a truncated audit packet"); - Ok(src.as_ref().len()) - } else { - Ok(buf.length() as usize) - } - } Err(e) => { // We either received a truncated packet, or the // packet if malformed (invalid length field). In @@ -93,35 +83,6 @@ where let len = len_res.unwrap(); - #[cfg(feature = "workaround-audit-bug")] - let bytes = { - let mut bytes = src.split_to(len); - { - let mut buf = NetlinkBuffer::new(bytes.as_mut()); - // If the buffer contains more bytes than what the header says the length is, it - // means we ran into a malformed packet (see comment above), and we just set the - // "right" length ourself, so that parsing does not fail. - // - // How do we know that's the right length? Due to an implementation detail and to - // the fact that netlink is a datagram protocol. - // - // - our implementation of Stream always calls the codec with at most 1 message in - // the buffer, so we know the extra bytes do not belong to another message. - // - because netlink is a datagram protocol, we receive entire messages, so we know - // that if those extra bytes do not belong to another message, they belong to - // this one. - if len != buf.length() as usize { - warn!( - "setting packet length to {} instead of {}", - len, - buf.length() - ); - buf.set_length(len as u32); - } - } - bytes - }; - #[cfg(not(feature = "workaround-audit-bug"))] let bytes = src.split_to(len); let parsed = NetlinkMessage::::deserialize(&bytes); @@ -137,15 +98,11 @@ where } } } -} - -impl Encoder> for NetlinkCodec> -where - T: Debug + Eq + PartialEq + Clone + NetlinkSerializable, -{ - type Error = io::Error; - fn encode(&mut self, msg: NetlinkMessage, buf: &mut BytesMut) -> Result<(), Self::Error> { + fn encode(msg: NetlinkMessage, buf: &mut BytesMut) -> io::Result<()> + where + T: Debug + NetlinkSerializable, + { let msg_len = msg.buffer_len(); if buf.remaining_mut() < msg_len { // BytesMut can expand till usize::MAX... unlikely to hit this one. diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index 7406c198..8e022ecd 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -20,23 +20,28 @@ use netlink_packet_core::{ }; use crate::{ - codecs::NetlinkCodec, + codecs::{NetlinkCodec, NetlinkMessageCodec}, framed::NetlinkFramed, - sys::{Socket, SocketAddr}, + sys::{AsyncSocket, SocketAddr}, Protocol, Request, Response, }; +#[cfg(feature = "tokio_socket")] +use netlink_sys::TokioSocket as DefaultSocket; +#[cfg(not(feature = "tokio_socket"))] +type DefaultSocket = (); + /// Connection to a Netlink socket, running in the background. /// /// [`ConnectionHandle`](struct.ConnectionHandle.html) are used to pass new requests to the /// `Connection`, that in turn, sends them through the netlink socket. -pub struct Connection +pub struct Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, + T: Debug + NetlinkSerializable + NetlinkDeserializable, { - socket: NetlinkFramed>>, + socket: NetlinkFramed, protocol: Protocol>>, @@ -50,18 +55,20 @@ where socket_closed: bool, } -impl Connection +impl Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, + T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, + S: AsyncSocket, + C: NetlinkMessageCodec, { pub(crate) fn new( requests_rx: UnboundedReceiver>, unsolicited_messages_tx: UnboundedSender<(NetlinkMessage, SocketAddr)>, protocol: isize, ) -> io::Result { - let socket = Socket::new(protocol)?; + let socket = S::new(protocol)?; Ok(Connection { - socket: NetlinkFramed::new(socket, NetlinkCodec::>::new()), + socket: NetlinkFramed::new(socket), protocol: Protocol::new(), requests_rx: Some(requests_rx), unsolicited_messages_tx: Some(unsolicited_messages_tx), @@ -69,7 +76,7 @@ where }) } - pub fn socket_mut(&mut self) -> &mut Socket { + pub fn socket_mut(&mut self) -> &mut S { self.socket.get_mut() } @@ -250,9 +257,11 @@ where } } -impl Future for Connection +impl Future for Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, + T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, + S: AsyncSocket, + C: NetlinkMessageCodec, { type Output = (); diff --git a/netlink-proto/src/errors.rs b/netlink-proto/src/errors.rs index fe350f98..31173ec3 100644 --- a/netlink-proto/src/errors.rs +++ b/netlink-proto/src/errors.rs @@ -9,14 +9,14 @@ use netlink_packet_core::NetlinkMessage; #[derive(Debug)] pub struct Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { kind: ErrorKind, } impl Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { pub fn kind(&self) -> &ErrorKind { &self.kind @@ -30,7 +30,7 @@ where #[derive(Debug)] pub enum ErrorKind where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { /// The netlink connection is closed ConnectionClosed, @@ -44,7 +44,7 @@ where impl From> for Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { fn from(kind: ErrorKind) -> Error { Error { kind } @@ -53,7 +53,7 @@ where impl fmt::Display for Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::ErrorKind::*; @@ -67,7 +67,7 @@ where impl StdError for Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { fn description(&self) -> &str { use crate::ErrorKind::*; diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index 84a85a6d..b46c1d5b 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -1,20 +1,28 @@ -use bytes::{BufMut, BytesMut}; +use bytes::BytesMut; use std::{ + fmt::Debug, io, + marker::PhantomData, pin::Pin, - slice, task::{Context, Poll}, }; use futures::{Sink, Stream}; use log::error; -use tokio_util::codec::{Decoder, Encoder}; -use crate::sys::{Socket, SocketAddr}; - -pub struct NetlinkFramed { - socket: Socket, - codec: C, +use crate::{ + codecs::NetlinkMessageCodec, + sys::{AsyncSocket, SocketAddr}, +}; +use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable}; + +pub struct NetlinkFramed { + socket: S, + // see https://doc.rust-lang.org/nomicon/phantom-data.html + // "invariant" seems like the safe choice; using `fn(T) -> T` + // should make it invariant but still Send+Sync. + msg_type: PhantomData T>, // invariant + codec: PhantomData C>, // invariant reader: BytesMut, writer: BytesMut, in_addr: SocketAddr, @@ -22,16 +30,16 @@ pub struct NetlinkFramed { flushed: bool, } -impl Stream for NetlinkFramed +impl Stream for NetlinkFramed where - C: Decoder + Unpin, - C::Error: std::error::Error, + T: NetlinkDeserializable + Debug, + S: AsyncSocket, + C: NetlinkMessageCodec, { - type Item = (C::Item, SocketAddr); + type Item = (NetlinkMessage, SocketAddr); fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Self { - ref mut codec, ref mut socket, ref mut in_addr, ref mut reader, @@ -39,7 +47,7 @@ where } = Pin::get_mut(self); loop { - match codec.decode(reader) { + match C::decode::(reader) { Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))), Ok(None) => {} Err(e) => { @@ -51,31 +59,24 @@ where reader.clear(); reader.reserve(INITIAL_READER_CAPACITY); - *in_addr = unsafe { - // Read into the buffer without having to initialize the memory. - // - // safety: we know poll_recv_from never reads from the - // memory during a recv so it's fine to turn &mut - // [>] into &mut[u8] - let bytes = reader.chunk_mut(); - let bytes = slice::from_raw_parts_mut(bytes.as_mut_ptr(), bytes.len()); - match ready!(socket.poll_recv_from(cx, bytes)) { - Ok((n, addr)) => { - reader.advance_mut(n); - addr - } - Err(e) => { - error!("failed to read from netlink socket: {:?}", e); - return Poll::Ready(None); - } + *in_addr = match ready!(socket.poll_recv_from(cx, reader)) { + Ok(addr) => addr, + Err(e) => { + error!("failed to read from netlink socket: {:?}", e); + return Poll::Ready(None); } }; } } } -impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed { - type Error = C::Error; +impl Sink<(NetlinkMessage, SocketAddr)> for NetlinkFramed +where + T: NetlinkSerializable + Debug, + S: AsyncSocket, + C: NetlinkMessageCodec, +{ + type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if !self.flushed { @@ -88,11 +89,14 @@ impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed< Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, item: (Item, SocketAddr)) -> Result<(), Self::Error> { + fn start_send( + self: Pin<&mut Self>, + item: (NetlinkMessage, SocketAddr), + ) -> Result<(), Self::Error> { trace!("sending frame"); let (frame, out_addr) = item; let pin = self.get_mut(); - pin.codec.encode(frame, &mut pin.writer)?; + C::encode(frame, &mut pin.writer)?; pin.out_addr = out_addr; pin.flushed = false; trace!("frame encoded; length={}", pin.writer.len()); @@ -125,8 +129,7 @@ impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed< Err(io::Error::new( io::ErrorKind::Other, "failed to write entire datagram to socket", - ) - .into()) + )) }; Poll::Ready(res) @@ -144,14 +147,15 @@ impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed< const INITIAL_READER_CAPACITY: usize = 64 * 1024; const INITIAL_WRITER_CAPACITY: usize = 8 * 1024; -impl NetlinkFramed { +impl NetlinkFramed { /// Create a new `NetlinkFramed` backed by the given socket and codec. /// /// See struct level documentation for more details. - pub fn new(socket: Socket, codec: C) -> NetlinkFramed { - NetlinkFramed { + pub fn new(socket: S) -> Self { + Self { socket, - codec, + msg_type: PhantomData, + codec: PhantomData, out_addr: SocketAddr::new(0, 0), in_addr: SocketAddr::new(0, 0), reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY), @@ -167,7 +171,7 @@ impl NetlinkFramed { /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. - pub fn get_ref(&self) -> &Socket { + pub fn get_ref(&self) -> &S { &self.socket } @@ -179,12 +183,12 @@ impl NetlinkFramed { /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. - pub fn get_mut(&mut self) -> &mut Socket { + pub fn get_mut(&mut self) -> &mut S { &mut self.socket } /// Consumes the `Framed`, returning its underlying I/O stream. - pub fn into_inner(self) -> Socket { + pub fn into_inner(self) -> S { self.socket } } diff --git a/netlink-proto/src/handle.rs b/netlink-proto/src/handle.rs index d47eb4e2..db2d7f9f 100644 --- a/netlink-proto/src/handle.rs +++ b/netlink-proto/src/handle.rs @@ -15,14 +15,14 @@ use crate::{ #[derive(Clone, Debug)] pub struct ConnectionHandle where - T: Debug + Clone + Eq + PartialEq, + T: Debug, { requests_tx: UnboundedSender>, } impl ConnectionHandle where - T: Debug + Clone + Eq + PartialEq, + T: Debug, { pub(crate) fn new(requests_tx: UnboundedSender>) -> Self { ConnectionHandle { requests_tx } diff --git a/netlink-proto/src/lib.rs b/netlink-proto/src/lib.rs index c9d4e774..3425e7a9 100644 --- a/netlink-proto/src/lib.rs +++ b/netlink-proto/src/lib.rs @@ -192,13 +192,13 @@ use std::{fmt::Debug, io}; pub use netlink_packet_core as packet; pub mod sys { - pub use netlink_sys::{protocols, SocketAddr}; + pub use netlink_sys::{protocols, AsyncSocket, AsyncSocketExt, SocketAddr}; #[cfg(feature = "tokio_socket")] - pub use netlink_sys::TokioSocket as Socket; + pub use netlink_sys::TokioSocket; #[cfg(feature = "smol_socket")] - pub use netlink_sys::SmolSocket as Socket; + pub use netlink_sys::SmolSocket; } /// Create a new Netlink connection for the given Netlink protocol, and returns a handle to that @@ -218,6 +218,7 @@ pub mod sys { /// handle to send messages. /// /// [protos]: crate::sys::protocols +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection( protocol: isize, @@ -227,13 +228,40 @@ pub fn new_connection( UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, )> where - T: Debug - + PartialEq - + Eq - + Clone - + packet::NetlinkSerializable - + packet::NetlinkDeserializable - + Unpin, + T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, +{ + new_connection_with_codec(protocol) +} + +/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket( + protocol: isize, +) -> io::Result<( + Connection, + ConnectionHandle, + UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, +)> +where + T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, + S: sys::AsyncSocket, +{ + new_connection_with_codec(protocol) +} + +/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling and a special codec +#[allow(clippy::type_complexity)] +pub fn new_connection_with_codec( + protocol: isize, +) -> io::Result<( + Connection, + ConnectionHandle, + UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, +)> +where + T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, + S: sys::AsyncSocket, + C: NetlinkMessageCodec, { let (requests_tx, requests_rx) = unbounded::>(); let (messages_tx, messages_rx) = unbounded::<(packet::NetlinkMessage, sys::SocketAddr)>(); diff --git a/netlink-proto/src/protocol/protocol.rs b/netlink-proto/src/protocol/protocol.rs index eec9bc03..56f2f4b3 100644 --- a/netlink-proto/src/protocol/protocol.rs +++ b/netlink-proto/src/protocol/protocol.rs @@ -30,11 +30,7 @@ impl RequestId { } #[derive(Debug, Eq, PartialEq)] -pub(crate) struct Response -where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, - M: Debug, -{ +pub(crate) struct Response { pub done: bool, pub message: NetlinkMessage, pub metadata: M, @@ -47,11 +43,7 @@ struct PendingRequest { } #[derive(Debug, Default)] -pub(crate) struct Protocol -where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, - M: Debug, -{ +pub(crate) struct Protocol { /// Counter that is incremented for each message sent sequence_id: u32, @@ -71,8 +63,8 @@ where impl Protocol where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, - M: Clone + Debug, + T: Debug + NetlinkSerializable + NetlinkDeserializable, + M: Debug + Clone, { pub fn new() -> Self { Self { diff --git a/netlink-proto/src/protocol/request.rs b/netlink-proto/src/protocol/request.rs index 881b4547..f5aa358c 100644 --- a/netlink-proto/src/protocol/request.rs +++ b/netlink-proto/src/protocol/request.rs @@ -5,11 +5,7 @@ use netlink_packet_core::NetlinkMessage; use crate::sys::SocketAddr; #[derive(Debug)] -pub(crate) struct Request -where - T: Debug + Clone + Eq + PartialEq, - M: Debug, -{ +pub(crate) struct Request { pub metadata: M, pub message: NetlinkMessage, pub destination: SocketAddr, @@ -17,7 +13,7 @@ where impl From<(NetlinkMessage, SocketAddr, M)> for Request where - T: Debug + PartialEq + Eq + Clone, + T: Debug, M: Debug, { fn from(parts: (NetlinkMessage, SocketAddr, M)) -> Self { @@ -31,7 +27,7 @@ where impl From> for (NetlinkMessage, SocketAddr, M) where - T: Debug + PartialEq + Eq + Clone, + T: Debug, M: Debug, { fn from(req: Request) -> (NetlinkMessage, SocketAddr, M) { diff --git a/netlink-sys/Cargo.toml b/netlink-sys/Cargo.toml index c7b95662..a8e59b4d 100644 --- a/netlink-sys/Cargo.toml +++ b/netlink-sys/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-sys" -version = "0.7.0" +version = "0.8.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -12,6 +12,7 @@ repository = "https://github.com/little-dude/netlink" description = "netlink sockets, with optional integration with tokio" [dependencies] +bytes = "1.0" libc = "0.2.66" log = "0.4.8" @@ -42,7 +43,7 @@ tokio_socket = ["tokio", "futures"] smol_socket = ["async-io","futures"] [dev-dependencies] -netlink-packet-audit = "0.2" +netlink-packet-audit = "0.3" [dev-dependencies.tokio] version = "1.0.1" diff --git a/netlink-sys/examples/audit_events_async_std.rs b/netlink-sys/examples/audit_events_async_std.rs index fde12fb5..55227764 100644 --- a/netlink-sys/examples/audit_events_async_std.rs +++ b/netlink-sys/examples/audit_events_async_std.rs @@ -22,7 +22,7 @@ use netlink_packet_audit::{ NLM_F_REQUEST, }; -use netlink_sys::{protocols::NETLINK_AUDIT, SmolSocket, SocketAddr}; +use netlink_sys::{protocols::NETLINK_AUDIT, AsyncSocket, AsyncSocketExt, SmolSocket, SocketAddr}; const AUDIT_STATUS_ENABLED: u32 = 1; const AUDIT_STATUS_PID: u32 = 4; @@ -50,20 +50,22 @@ async fn main() { .await .unwrap(); - let mut buf = vec![0; 1024 * 8]; + let mut buf = bytes::BytesMut::with_capacity(1024 * 8); loop { - let (n, _addr) = socket.recv_from(&mut buf).await.unwrap(); + buf.clear(); + let _addr = socket.recv_from(&mut buf).await.unwrap(); // This dance with the NetlinkBuffer should not be // necessary. It is here to work around a netlink bug. See: // https://github.com/mozilla/libaudit-go/issues/24 // https://github.com/linux-audit/audit-userspace/issues/78 { - let mut nl_buf = NetlinkBuffer::new(&mut buf[0..n]); + let n = buf.len(); + let mut nl_buf = NetlinkBuffer::new(&mut buf); if n != nl_buf.length() as usize { nl_buf.set_length(n as u32); } } - let parsed = NetlinkMessage::::deserialize(&buf[0..n]).unwrap(); + let parsed = NetlinkMessage::::deserialize(&buf).unwrap(); println!("<<< {:?}", parsed); } } diff --git a/netlink-sys/examples/audit_events_tokio.rs b/netlink-sys/examples/audit_events_tokio.rs index afdd5f74..44df4085 100644 --- a/netlink-sys/examples/audit_events_tokio.rs +++ b/netlink-sys/examples/audit_events_tokio.rs @@ -22,7 +22,7 @@ use netlink_packet_audit::{ NLM_F_REQUEST, }; -use netlink_sys::{protocols::NETLINK_AUDIT, SocketAddr, TokioSocket}; +use netlink_sys::{protocols::NETLINK_AUDIT, AsyncSocket, AsyncSocketExt, SocketAddr, TokioSocket}; const AUDIT_STATUS_ENABLED: u32 = 1; const AUDIT_STATUS_PID: u32 = 4; @@ -50,20 +50,22 @@ async fn main() -> Result<(), Box> { .await .unwrap(); - let mut buf = vec![0; 1024 * 8]; + let mut buf = bytes::BytesMut::with_capacity(1024 * 8); loop { - let (n, _addr) = socket.recv_from(&mut buf).await.unwrap(); + buf.clear(); + let _addr = socket.recv_from(&mut buf).await.unwrap(); // This dance with the NetlinkBuffer should not be // necessary. It is here to work around a netlink bug. See: // https://github.com/mozilla/libaudit-go/issues/24 // https://github.com/linux-audit/audit-userspace/issues/78 { - let mut nl_buf = NetlinkBuffer::new(&mut buf[0..n]); + let n = buf.len(); + let mut nl_buf = NetlinkBuffer::new(&mut buf); if n != nl_buf.length() as usize { nl_buf.set_length(n as u32); } } - let parsed = NetlinkMessage::::deserialize(&buf[0..n]).unwrap(); + let parsed = NetlinkMessage::::deserialize(&buf).unwrap(); println!("<<< {:?}", parsed); } } diff --git a/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs b/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs index 3d7e04a6..f0a0b87b 100644 --- a/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs +++ b/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs @@ -3,7 +3,7 @@ * to use netlink. */ -use netlink_sys::{protocols::NETLINK_AUDIT, TokioSocket}; +use netlink_sys::{protocols::NETLINK_AUDIT, AsyncSocket, TokioSocket}; fn main() -> Result<(), String> { let rt = tokio::runtime::Builder::new_multi_thread() diff --git a/netlink-sys/src/async_socket.rs b/netlink-sys/src/async_socket.rs new file mode 100644 index 00000000..3e15a0d6 --- /dev/null +++ b/netlink-sys/src/async_socket.rs @@ -0,0 +1,55 @@ +use std::{ + io, + task::{Context, Poll}, +}; + +use crate::{Socket, SocketAddr}; + +/// Trait to support different async backends +pub trait AsyncSocket: Sized + Unpin { + /// Access underyling [`Socket`] + fn socket_ref(&self) -> &Socket; + + /// Mutable access to underyling [`Socket`] + fn socket_mut(&mut self) -> &mut Socket; + + /// Wrapper for [`Socket::new`] + fn new(protocol: isize) -> io::Result; + + /// Polling wrapper for [`Socket::send`] + fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll>; + + /// Polling wrapper for [`Socket::send_to`] + fn poll_send_to( + &mut self, + cx: &mut Context<'_>, + buf: &[u8], + addr: &SocketAddr, + ) -> Poll>; + + /// Polling wrapper for [`Socket::recv`] + /// + /// Passes 0 for flags, and ignores the returned length (the buffer will have advanced by the amount read). + fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut B) -> Poll> + where + B: bytes::BufMut; + + /// Polling wrapper for [`Socket::recv_from`] + /// + /// Passes 0 for flags, and ignores the returned length - just returns the address (the buffer will have advanced by the amount read). + fn poll_recv_from( + &mut self, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut; + + /// Polling wrapper for [`Socket::recv_from_full`] + /// + /// Passes 0 for flags, and ignores the returned length - just returns the address (the buffer will have advanced by the amount read). + fn poll_recv_from_full( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, SocketAddr)>>; +} diff --git a/netlink-sys/src/async_socket_ext.rs b/netlink-sys/src/async_socket_ext.rs new file mode 100644 index 00000000..8ed86502 --- /dev/null +++ b/netlink-sys/src/async_socket_ext.rs @@ -0,0 +1,141 @@ +use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{AsyncSocket, SocketAddr}; + +/// Support trait for [`AsyncSocket`] +/// +/// Provides awaitable variants of the poll functions from [`AsyncSocket`]. +pub trait AsyncSocketExt: AsyncSocket { + /// `async fn send(&mut self, buf: &[u8]) -> io::Result` + fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> PollSend<'a, 'b, Self> { + PollSend { socket: self, buf } + } + + /// `async fn send(&mut self, buf: &[u8]) -> io::Result` + fn send_to<'a, 'b>( + &'a mut self, + buf: &'b [u8], + addr: &'b SocketAddr, + ) -> PollSendTo<'a, 'b, Self> { + PollSendTo { + socket: self, + buf, + addr, + } + } + + /// `async fn recv(&mut self, buf: &mut [u8]) -> io::Result<()>` + fn recv<'a, 'b, B>(&'a mut self, buf: &'b mut B) -> PollRecv<'a, 'b, Self, B> + where + B: bytes::BufMut, + { + PollRecv { socket: self, buf } + } + + /// `async fn recv(&mut self, buf: &mut [u8]) -> io::Result` + fn recv_from<'a, 'b, B>(&'a mut self, buf: &'b mut B) -> PollRecvFrom<'a, 'b, Self, B> + where + B: bytes::BufMut, + { + PollRecvFrom { socket: self, buf } + } + + /// `async fn recrecv_from_full(&mut self) -> io::Result<(Vec, SocketAddr)>` + fn recv_from_full(&mut self) -> PollRecvFromFull<'_, Self> { + PollRecvFromFull { socket: self } + } +} + +impl AsyncSocketExt for S {} + +pub struct PollSend<'a, 'b, S> { + socket: &'a mut S, + buf: &'b [u8], +} + +impl Future for PollSend<'_, '_, S> +where + S: AsyncSocket, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_send(cx, this.buf) + } +} + +pub struct PollSendTo<'a, 'b, S> { + socket: &'a mut S, + buf: &'b [u8], + addr: &'b SocketAddr, +} + +impl Future for PollSendTo<'_, '_, S> +where + S: AsyncSocket, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_send_to(cx, this.buf, this.addr) + } +} + +pub struct PollRecv<'a, 'b, S, B> { + socket: &'a mut S, + buf: &'b mut B, +} + +impl Future for PollRecv<'_, '_, S, B> +where + S: AsyncSocket, + B: bytes::BufMut, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_recv(cx, this.buf) + } +} + +pub struct PollRecvFrom<'a, 'b, S, B> { + socket: &'a mut S, + buf: &'b mut B, +} + +impl Future for PollRecvFrom<'_, '_, S, B> +where + S: AsyncSocket, + B: bytes::BufMut, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_recv_from(cx, this.buf) + } +} + +pub struct PollRecvFromFull<'a, S> { + socket: &'a mut S, +} + +impl Future for PollRecvFromFull<'_, S> +where + S: AsyncSocket, +{ + type Output = io::Result<(Vec, SocketAddr)>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_recv_from_full(cx) + } +} diff --git a/netlink-sys/src/lib.rs b/netlink-sys/src/lib.rs index 69abdeec..6789c969 100644 --- a/netlink-sys/src/lib.rs +++ b/netlink-sys/src/lib.rs @@ -31,6 +31,12 @@ pub use self::socket::Socket; mod addr; pub use self::addr::SocketAddr; +mod async_socket; +pub use self::async_socket::AsyncSocket; + +pub mod async_socket_ext; +pub use self::async_socket_ext::AsyncSocketExt; + #[cfg(feature = "tokio_socket")] mod tokio; #[cfg(feature = "tokio_socket")] diff --git a/netlink-sys/src/smol.rs b/netlink-sys/src/smol.rs index 77bcaf95..e3b67d94 100644 --- a/netlink-sys/src/smol.rs +++ b/netlink-sys/src/smol.rs @@ -10,165 +10,112 @@ use futures::ready; use log::trace; -use crate::{Socket, SocketAddr}; +use crate::{AsyncSocket, Socket, SocketAddr}; /// An I/O object representing a Netlink socket. pub struct SmolSocket(Async); -impl SmolSocket { - pub fn new(protocol: isize) -> io::Result { - let socket = Socket::new(protocol)?; - Ok(SmolSocket(Async::new(socket)?)) - } - - pub fn bind(&mut self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_mut().bind(addr) - } - - pub fn bind_auto(&mut self) -> io::Result { - self.0.get_mut().bind_auto() - } - - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_ref().connect(addr) - } - - pub async fn send(&mut self, buf: &[u8]) -> io::Result { - self.0.write_with_mut(|sock| sock.send(buf, 0)).await - } - - pub async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> io::Result { - self.0 - .write_with_mut(|sock| sock.send_to(buf, addr, 0)) - .await - } - - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read_with_mut(|sock| sock.recv(buf, 0)).await - } - - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.0.read_with_mut(|sock| sock.recv_from(buf, 0)).await +impl FromRawFd for SmolSocket { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + let socket = Socket::from_raw_fd(fd); + socket.set_non_blocking(true).unwrap(); + SmolSocket(Async::new(socket).unwrap()) } +} - pub async fn recv_from_full(&mut self) -> io::Result<(Vec, SocketAddr)> { - self.0.read_with_mut(|sock| sock.recv_from_full()).await +impl AsRawFd for SmolSocket { + fn as_raw_fd(&self) -> RawFd { + self.0.get_ref().as_raw_fd() } +} - pub fn poll_recv_from( - &mut self, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { +// async_io::Async<..>::{read,write}_with[_mut] functions try IO first, +// and only register context if it would block. +// replicate this in these poll functions: +impl SmolSocket { + fn poll_write_with(&mut self, cx: &mut Context<'_>, mut op: F) -> Poll> + where + F: FnMut(&mut Self) -> io::Result, + { loop { - trace!("poll_recv_from called"); - let _guard = ready!(self.0.poll_readable(cx))?; - trace!("poll_recv_from socket is ready for reading"); - - match self.0.get_ref().recv_from(buf, 0) { - Ok(x) => { - trace!("poll_recv_from {:?} bytes read", x); - return Poll::Ready(Ok(x)); - } - Err(_would_block) => { - trace!("poll_recv_from socket would block"); - continue; - } + match op(self) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), } + // try again if writable now, otherwise come back later: + ready!(self.0.poll_writable(cx))?; } } - pub fn poll_send_to( - &mut self, - cx: &mut Context, - buf: &[u8], - addr: &SocketAddr, - ) -> Poll> { + fn poll_read_with(&mut self, cx: &mut Context<'_>, mut op: F) -> Poll> + where + F: FnMut(&mut Self) -> io::Result, + { loop { - let _guard = ready!(self.0.poll_writable(cx))?; - - match self.0.get_ref().send_to(buf, addr, 0) { - Ok(x) => return Poll::Ready(Ok(x)), - Err(_would_block) => continue, + match op(self) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), } + // try again if readable now, otherwise come back later: + ready!(self.0.poll_readable(cx))?; } } +} - pub fn set_pktinfo(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_pktinfo(value) - } - - pub fn get_pktinfo(&self) -> io::Result { - self.0.get_ref().get_pktinfo() - } - - pub fn add_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().add_membership(group) - } - - pub fn drop_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().drop_membership(group) - } - - // pub fn list_membership(&self) -> Vec { - // self.0.get_ref().list_membership() - // } - - /// `NETLINK_BROADCAST_ERROR` (since Linux 2.6.30). When not set, `netlink_broadcast()` only - /// reports `ESRCH` errors and silently ignore `NOBUFS` errors. - pub fn set_broadcast_error(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_broadcast_error(value) - } - - pub fn get_broadcast_error(&self) -> io::Result { - self.0.get_ref().get_broadcast_error() - } - - /// `NETLINK_NO_ENOBUFS` (since Linux 2.6.30). This flag can be used by unicast and broadcast - /// listeners to avoid receiving `ENOBUFS` errors. - pub fn set_no_enobufs(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_no_enobufs(value) - } - - pub fn get_no_enobufs(&self) -> io::Result { - self.0.get_ref().get_no_enobufs() +impl AsyncSocket for SmolSocket { + fn socket_ref(&self) -> &Socket { + self.0.get_ref() } - /// `NETLINK_LISTEN_ALL_NSID` (since Linux 4.2). When set, this socket will receive netlink - /// notifications from all network namespaces that have an nsid assigned into the network - /// namespace where the socket has been opened. The nsid is sent to user space via an ancillary - /// data. - pub fn set_listen_all_namespaces(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_listen_all_namespaces(value) + /// Mutable access to underyling [`Socket`] + fn socket_mut(&mut self) -> &mut Socket { + self.0.get_mut() } - pub fn get_listen_all_namespaces(&self) -> io::Result { - self.0.get_ref().get_listen_all_namespaces() + fn new(protocol: isize) -> io::Result { + let socket = Socket::new(protocol)?; + Ok(Self(Async::new(socket)?)) } - /// `NETLINK_CAP_ACK` (since Linux 4.2). The kernel may fail to allocate the necessary room - /// for the acknowledgment message back to user space. This option trims off the payload of - /// the original netlink message. The netlink message header is still included, so the user can - /// guess from the sequence number which message triggered the acknowledgment. - pub fn set_cap_ack(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_cap_ack(value) + fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.poll_write_with(cx, |this| this.0.get_mut().send(buf, 0)) } - pub fn get_cap_ack(&self) -> io::Result { - self.0.get_ref().get_cap_ack() + fn poll_send_to( + &mut self, + cx: &mut Context<'_>, + buf: &[u8], + addr: &SocketAddr, + ) -> Poll> { + self.poll_write_with(cx, |this| this.0.get_mut().send_to(buf, addr, 0)) } -} -impl FromRawFd for SmolSocket { - unsafe fn from_raw_fd(fd: RawFd) -> Self { - let socket = Socket::from_raw_fd(fd); - socket.set_non_blocking(true).unwrap(); - SmolSocket(Async::new(socket).unwrap()) + fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut B) -> Poll> + where + B: bytes::BufMut, + { + self.poll_read_with(cx, |this| this.0.get_mut().recv(buf, 0).map(|_len| ())) } -} -impl AsRawFd for SmolSocket { - fn as_raw_fd(&self) -> RawFd { - self.0.get_ref().as_raw_fd() + fn poll_recv_from( + &mut self, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut, + { + self.poll_read_with(cx, |this| { + let x = this.0.get_mut().recv_from(buf, 0); + trace!("poll_recv_from: {:?}", x); + x.map(|(_len, addr)| addr) + }) + } + + fn poll_recv_from_full( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, SocketAddr)>> { + self.poll_read_with(cx, |this| this.0.get_mut().recv_from_full()) } } diff --git a/netlink-sys/src/socket.rs b/netlink-sys/src/socket.rs index bdc53cbe..2c9b3c86 100644 --- a/netlink-sys/src/socket.rs +++ b/netlink-sys/src/socket.rs @@ -36,7 +36,7 @@ use crate::SocketAddr; /// let mut buf = vec![0; 4096]; /// loop { /// // receive a datagram -/// let (n_received, sender_addr) = socket.recv_from(&mut buf[..], 0).unwrap(); +/// let (n_received, sender_addr) = socket.recv_from(&mut &mut buf[..], 0).unwrap(); /// assert_eq!(sender_addr, kernel_addr); /// println!("received datagram {:?}", &buf[..n_received]); /// if buf[4] == 2 && buf[5] == 0 { @@ -165,7 +165,7 @@ impl Socket { /// // buffer for receiving the response /// let mut buf = vec![0; 4096]; /// loop { - /// let mut n_received = socket.recv(&mut buf[..], 0).unwrap(); + /// let mut n_received = socket.recv(&mut &mut buf[..], 0).unwrap(); /// println!("received {:?}", &buf[..n_received]); /// if buf[4] == 2 && buf[5] == 0 { /// println!("the kernel responded with an error"); @@ -224,7 +224,10 @@ impl Socket { /// In datagram oriented protocols, `recv` and `recvfrom` receive normally only ONE datagram, but this seems not to /// be always true for netlink sockets: with some protocols like `NETLINK_AUDIT`, multiple netlink packets can be /// read with a single call. - pub fn recv_from(&self, buf: &mut [u8], flags: libc::c_int) -> Result<(usize, SocketAddr)> { + pub fn recv_from(&self, buf: &mut B, flags: libc::c_int) -> Result<(usize, SocketAddr)> + where + B: bytes::BufMut, + { // Create an empty storage for the address. Note that Rust standard library create a // sockaddr_storage so that it works for any address family, but here, we already know that // we'll have a Netlink address, so we can create the appropriate storage. @@ -252,34 +255,51 @@ impl Socket { // a pointer to it. let addrlen_ptr = &mut addrlen as *mut usize as *mut libc::socklen_t; - // Cast the *mut u8 into *mut void. - // This is equivalent to casting a *char into *void - // See [thread] - // ^ - // Create a *mut u8 | - // ^ | - // | | - // +-----+-----+ +--------+-------+ - // / \ / \ - let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void; - let buf_len = buf.len() as libc::size_t; + let chunk = buf.chunk_mut(); + // Cast the *mut u8 into *mut void. + // This is equivalent to casting a *char into *void + // See [thread] + // ^ + // Create a *mut u8 | + // ^ | + // | | + // +------+-------+ +--------+-------+ + // / \ / \ + let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void; + let buf_len = chunk.len() as libc::size_t; let res = unsafe { libc::recvfrom(self.0, buf_ptr, buf_len, flags, addr_ptr, addrlen_ptr) }; if res < 0 { return Err(Error::last_os_error()); + } else { + // with `MSG_TRUNC` `res` might exceed `buf_len` + let written = std::cmp::min(buf_len, res as usize); + unsafe { + buf.advance_mut(written); + } } Ok((res as usize, SocketAddr(addr))) } /// For a connected socket, `recv` reads a datagram from the socket. The sender is the remote peer the socket is /// connected to (see [`Socket::connect`]). See also [`Socket::recv_from`] - pub fn recv(&self, buf: &mut [u8], flags: libc::c_int) -> Result { - let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void; - let buf_len = buf.len() as libc::size_t; + pub fn recv(&self, buf: &mut B, flags: libc::c_int) -> Result + where + B: bytes::BufMut, + { + let chunk = buf.chunk_mut(); + let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void; + let buf_len = chunk.len() as libc::size_t; let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) }; if res < 0 { return Err(Error::last_os_error()); + } else { + // with `MSG_TRUNC` `res` might exceed `buf_len` + let written = std::cmp::min(buf_len, res as usize); + unsafe { + buf.advance_mut(written); + } } Ok(res as usize) } @@ -288,13 +308,14 @@ impl Socket { /// buffer passed as argument, this method always reads a whole message, no matter its size. pub fn recv_from_full(&self) -> Result<(Vec, SocketAddr)> { // Peek - let mut buf = Vec::::new(); - let (rlen, _) = self.recv_from(&mut buf, libc::MSG_PEEK | libc::MSG_TRUNC)?; + let mut buf: Vec = Vec::new(); + let (peek_len, _) = self.recv_from(&mut buf, libc::MSG_PEEK | libc::MSG_TRUNC)?; // Receive - let mut buf = vec![0; rlen as usize]; - let (_, addr) = self.recv_from(&mut buf, 0)?; - + buf.clear(); + buf.reserve(peek_len); + let (rlen, addr) = self.recv_from(&mut buf, 0)?; + assert_eq!(rlen, peek_len); Ok((buf, addr)) } diff --git a/netlink-sys/src/tokio.rs b/netlink-sys/src/tokio.rs index a83fe93d..9794fc9b 100644 --- a/netlink-sys/src/tokio.rs +++ b/netlink-sys/src/tokio.rs @@ -4,38 +4,47 @@ use std::{ task::{Context, Poll}, }; -use futures::{future::poll_fn, ready}; +use futures::ready; use log::trace; use tokio::io::unix::AsyncFd; -use crate::{Socket, SocketAddr}; +use crate::{AsyncSocket, Socket, SocketAddr}; /// An I/O object representing a Netlink socket. pub struct TokioSocket(AsyncFd); -impl TokioSocket { - /// This function will create a new Netlink socket and attempt to bind it to - /// the `addr` provided. - pub fn bind(&mut self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_mut().bind(addr) +impl FromRawFd for TokioSocket { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + let socket = Socket::from_raw_fd(fd); + socket.set_non_blocking(true).unwrap(); + TokioSocket(AsyncFd::new(socket).unwrap()) } +} - pub fn bind_auto(&mut self) -> io::Result { - self.0.get_mut().bind_auto() +impl AsRawFd for TokioSocket { + fn as_raw_fd(&self) -> RawFd { + self.0.get_ref().as_raw_fd() } +} - pub fn new(protocol: isize) -> io::Result { - let socket = Socket::new(protocol)?; - socket.set_non_blocking(true)?; - Ok(TokioSocket(AsyncFd::new(socket)?)) +impl AsyncSocket for TokioSocket { + fn socket_ref(&self) -> &Socket { + self.0.get_ref() + } + + /// Mutable access to underyling [`Socket`] + fn socket_mut(&mut self) -> &mut Socket { + self.0.get_mut() } - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_ref().connect(addr) + fn new(protocol: isize) -> io::Result { + let socket = Socket::new(protocol)?; + socket.set_non_blocking(true)?; + Ok(Self(AsyncFd::new(socket)?)) } - pub async fn send(&mut self, buf: &[u8]) -> io::Result { - poll_fn(|cx| loop { + fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + loop { // Check if the socket it writable. If // AsyncFd::poll_write_ready returns NotReady, it will // already have arranged for the current task to be @@ -47,17 +56,12 @@ impl TokioSocket { Ok(x) => return Poll::Ready(x), Err(_would_block) => continue, } - }) - .await - } - - pub async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> io::Result { - poll_fn(|cx| self.poll_send_to(cx, buf, addr)).await + } } - pub fn poll_send_to( + fn poll_send_to( &mut self, - cx: &mut Context, + cx: &mut Context<'_>, buf: &[u8], addr: &SocketAddr, ) -> Poll> { @@ -71,8 +75,11 @@ impl TokioSocket { } } - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { - poll_fn(|cx| loop { + fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut B) -> Poll> + where + B: bytes::BufMut, + { + loop { // Check if the socket is readable. If not, // AsyncFd::poll_read_ready would have arranged for the // current task to be polled again when the socket becomes @@ -80,26 +87,20 @@ impl TokioSocket { let mut guard = ready!(self.0.poll_read_ready(cx))?; match guard.try_io(|inner| inner.get_ref().recv(buf, 0)) { - Ok(x) => return Poll::Ready(x), + Ok(x) => return Poll::Ready(x.map(|_len| ())), Err(_would_block) => continue, } - }) - .await - } - - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - poll_fn(|cx| self.poll_recv_from(cx, buf)).await - } - - pub async fn recv_from_full(&mut self) -> io::Result<(Vec, SocketAddr)> { - poll_fn(|cx| self.poll_recv_from_full(cx)).await + } } - pub fn poll_recv_from( + fn poll_recv_from( &mut self, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut, + { loop { trace!("poll_recv_from called"); let mut guard = ready!(self.0.poll_read_ready(cx))?; @@ -108,7 +109,7 @@ impl TokioSocket { match guard.try_io(|inner| inner.get_ref().recv_from(buf, 0)) { Ok(x) => { trace!("poll_recv_from {:?} bytes read", x); - return Poll::Ready(x); + return Poll::Ready(x.map(|(_len, addr)| addr)); } Err(_would_block) => { trace!("poll_recv_from socket would block"); @@ -118,9 +119,9 @@ impl TokioSocket { } } - pub fn poll_recv_from_full( + fn poll_recv_from_full( &mut self, - cx: &mut Context, + cx: &mut Context<'_>, ) -> Poll, SocketAddr)>> { loop { trace!("poll_recv_from_full called"); @@ -139,82 +140,4 @@ impl TokioSocket { } } } - - pub fn set_pktinfo(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_pktinfo(value) - } - - pub fn get_pktinfo(&self) -> io::Result { - self.0.get_ref().get_pktinfo() - } - - pub fn add_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().add_membership(group) - } - - pub fn drop_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().drop_membership(group) - } - - // pub fn list_membership(&self) -> Vec { - // self.0.get_ref().list_membership() - // } - - /// `NETLINK_BROADCAST_ERROR` (since Linux 2.6.30). When not set, `netlink_broadcast()` only - /// reports `ESRCH` errors and silently ignore `NOBUFS` errors. - pub fn set_broadcast_error(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_broadcast_error(value) - } - - pub fn get_broadcast_error(&self) -> io::Result { - self.0.get_ref().get_broadcast_error() - } - - /// `NETLINK_NO_ENOBUFS` (since Linux 2.6.30). This flag can be used by unicast and broadcast - /// listeners to avoid receiving `ENOBUFS` errors. - pub fn set_no_enobufs(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_no_enobufs(value) - } - - pub fn get_no_enobufs(&self) -> io::Result { - self.0.get_ref().get_no_enobufs() - } - - /// `NETLINK_LISTEN_ALL_NSID` (since Linux 4.2). When set, this socket will receive netlink - /// notifications from all network namespaces that have an nsid assigned into the network - /// namespace where the socket has been opened. The nsid is sent to user space via an ancillary - /// data. - pub fn set_listen_all_namespaces(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_listen_all_namespaces(value) - } - - pub fn get_listen_all_namespaces(&self) -> io::Result { - self.0.get_ref().get_listen_all_namespaces() - } - - /// `NETLINK_CAP_ACK` (since Linux 4.2). The kernel may fail to allocate the necessary room - /// for the acknowledgment message back to user space. This option trims off the payload of - /// the original netlink message. The netlink message header is still included, so the user can - /// guess from the sequence number which message triggered the acknowledgment. - pub fn set_cap_ack(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_cap_ack(value) - } - - pub fn get_cap_ack(&self) -> io::Result { - self.0.get_ref().get_cap_ack() - } -} - -impl FromRawFd for TokioSocket { - unsafe fn from_raw_fd(fd: RawFd) -> Self { - let socket = Socket::from_raw_fd(fd); - socket.set_non_blocking(true).unwrap(); - TokioSocket(AsyncFd::new(socket).unwrap()) - } -} - -impl AsRawFd for TokioSocket { - fn as_raw_fd(&self) -> RawFd { - self.0.get_ref().as_raw_fd() - } } diff --git a/rtnetlink/Cargo.toml b/rtnetlink/Cargo.toml index 0e6d3fa1..36ab9bf6 100644 --- a/rtnetlink/Cargo.toml +++ b/rtnetlink/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rtnetlink" -version = "0.8.1" +version = "0.9.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Corentin Henry "] edition = "2018" @@ -21,8 +21,8 @@ smol_socket = ["netlink-proto/smol_socket", "async-std"] futures = "0.3.11" log = "0.4.8" thiserror = "1" -netlink-packet-route = "0.8" -netlink-proto = { default-features = false, version = "0.7" } +netlink-packet-route = "0.9" +netlink-proto = { default-features = false, version = "0.8" } nix = "0.22.0" tokio = { version = "1.0.1", features = ["rt"], optional = true} async-std = { version = "1.9.0", features = ["unstable"], optional = true} diff --git a/rtnetlink/examples/ip_monitor.rs b/rtnetlink/examples/ip_monitor.rs index 7dab1ff5..842abc08 100644 --- a/rtnetlink/examples/ip_monitor.rs +++ b/rtnetlink/examples/ip_monitor.rs @@ -1,7 +1,10 @@ use futures::stream::StreamExt; use netlink_packet_route::constants::*; -use rtnetlink::{new_connection, sys::SocketAddr}; +use rtnetlink::{ + new_connection, + sys::{AsyncSocket, SocketAddr}, +}; #[tokio::main] async fn main() -> Result<(), String> { @@ -31,7 +34,10 @@ async fn main() -> Result<(), String> { | RTNLGRP_MPLS_NETCONF; let addr = SocketAddr::new(0, groups); - conn.socket_mut().bind(&addr).expect("Failed to bind"); + conn.socket_mut() + .socket_mut() + .bind(&addr) + .expect("Failed to bind"); // Spawn `Connection` to start polling netlink socket. tokio::spawn(conn); diff --git a/rtnetlink/examples/listen.rs b/rtnetlink/examples/listen.rs index 06c7a6ee..0066fbee 100644 --- a/rtnetlink/examples/listen.rs +++ b/rtnetlink/examples/listen.rs @@ -6,7 +6,7 @@ use futures::stream::StreamExt; use rtnetlink::{ constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE}, new_connection, - sys::SocketAddr, + sys::{AsyncSocket, SocketAddr}, }; #[tokio::main] @@ -20,7 +20,11 @@ async fn main() -> Result<(), String> { // A netlink socket address is created with said flags. let addr = SocketAddr::new(0, mgroup_flags); // Said address is bound so new conenctions and thus new message broadcasts can be received. - connection.socket_mut().bind(&addr).expect("failed to bind"); + connection + .socket_mut() + .socket_mut() + .bind(&addr) + .expect("failed to bind"); tokio::spawn(connection); while let Some((message, _)) = messages.next().await { diff --git a/rtnetlink/src/connection.rs b/rtnetlink/src/connection.rs index e8cdf731..8f295b46 100644 --- a/rtnetlink/src/connection.rs +++ b/rtnetlink/src/connection.rs @@ -5,16 +5,29 @@ use futures::channel::mpsc::UnboundedReceiver; use crate::{ packet::{NetlinkMessage, RtnlMessage}, proto::Connection, - sys::{protocols::NETLINK_ROUTE, SocketAddr}, + sys::{protocols::NETLINK_ROUTE, AsyncSocket, SocketAddr}, Handle, }; +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection() -> io::Result<( Connection, Handle, UnboundedReceiver<(NetlinkMessage, SocketAddr)>, )> { - let (conn, handle, messages) = netlink_proto::new_connection(NETLINK_ROUTE)?; + new_connection_with_socket() +} + +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + Connection, + Handle, + UnboundedReceiver<(NetlinkMessage, SocketAddr)>, +)> +where + S: AsyncSocket, +{ + let (conn, handle, messages) = netlink_proto::new_connection_with_socket(NETLINK_ROUTE)?; Ok((conn, Handle::new(handle), messages)) } diff --git a/rtnetlink/src/ns.rs b/rtnetlink/src/ns.rs index 3ca6cc40..52bed9c4 100644 --- a/rtnetlink/src/ns.rs +++ b/rtnetlink/src/ns.rs @@ -1,9 +1,3 @@ -#[cfg(feature = "tokio_socket")] -use tokio::task; - -#[cfg(feature = "smol_socket")] -use async_std::task; - use crate::Error; use nix::{ fcntl::OFlag, @@ -16,6 +10,43 @@ use nix::{ }; use std::{option::Option, path::Path, process::exit}; +// if "only" smol or smol+tokio were enabled, we use smol because +// it doesn't require an active tokio runtime - just to be sure. +#[cfg(feature = "smol_socket")] +async fn try_spawn_blocking(fut: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + async_std::task::spawn_blocking(fut).await +} + +// only tokio enabled, so use tokio +#[cfg(all(not(feature = "smol_socket"), feature = "tokio_socket"))] +async fn try_spawn_blocking(fut: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + match tokio::task::spawn_blocking(fut).await { + Ok(v) => v, + Err(err) => { + std::panic::resume_unwind(err.into_panic()); + } + } +} + +// neither smol nor tokio - just run blocking op directly. +// hopefully not too blocking... +#[cfg(all(not(feature = "smol_socket"), not(feature = "tokio_socket")))] +async fn try_spawn_blocking(fut: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + fut() +} + pub const NETNS_PATH: &str = "/run/netns/"; pub const SELF_NS_PATH: &str = "/proc/self/ns/net"; pub const NONE_FS: &str = "none"; @@ -46,7 +77,7 @@ impl NetworkNamespace { /// Remove a network namespace /// This is equivalent to `ip netns del NS_NAME`. pub async fn del(ns_name: String) -> Result<(), Error> { - let res = task::spawn_blocking(move || { + try_spawn_blocking(move || { let mut netns_path = String::new(); netns_path.push_str(NETNS_PATH); netns_path.push_str(&ns_name); @@ -62,20 +93,10 @@ impl NetworkNamespace { String::from("Namespace file remove failed (are you running as root?)"); return Err(Error::NamespaceError(err_msg)); } - #[cfg(feature = "tokio_socket")] - return Ok(()); - - #[cfg(feature = "smol_socket")] - return Ok(Ok(())); - }); - match res.await { - Ok(r) => r, - Err(e) => { - let err_msg = format!("Namespace removal failed: {}", e); - Err(Error::NamespaceError(err_msg)) - } - } + Ok(()) + }) + .await } pub fn prep_for_fork() -> Result<(), Error> {