From 66c270bb37f4d22215f12d6c0f69ad739fb581dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 14 Nov 2021 15:30:57 +0100 Subject: [PATCH 01/11] Fix clippy lint: useless conversion --- netlink-proto/src/framed.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index 84a85a6d..f7aadf0a 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -125,8 +125,7 @@ impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed< Err(io::Error::new( io::ErrorKind::Other, "failed to write entire datagram to socket", - ) - .into()) + )) }; Poll::Ready(res) From dfb55f52166fd3c485d2ae162a430d60c3685e1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 14 Nov 2021 22:03:50 +0100 Subject: [PATCH 02/11] Fix clippy lint: redundant closure --- netlink-packet-audit/src/rules/buffer.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/netlink-packet-audit/src/rules/buffer.rs b/netlink-packet-audit/src/rules/buffer.rs index 16abc5ed..e0a9dbcc 100644 --- a/netlink-packet-audit/src/rules/buffer.rs +++ b/netlink-packet-audit/src/rules/buffer.rs @@ -175,14 +175,8 @@ impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for RuleMessage { let mut offset = 0; - let fields = buf - .fields() - .chunks(4) - .map(|chunk| NativeEndian::read_u32(chunk)); - let values = buf - .values() - .chunks(4) - .map(|chunk| NativeEndian::read_u32(chunk)); + let fields = buf.fields().chunks(4).map(NativeEndian::read_u32); + let values = buf.values().chunks(4).map(NativeEndian::read_u32); let field_flags = buf .field_flags() .chunks(4) From 1cd3e0fbb8d77d6b9c4fe43b8c4aa745fa6ba66c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 14 Nov 2021 15:01:34 +0100 Subject: [PATCH 03/11] Fix SmolSocket::poll_recv_from error handling --- netlink-sys/src/smol.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/netlink-sys/src/smol.rs b/netlink-sys/src/smol.rs index 77bcaf95..6ef2b5e0 100644 --- a/netlink-sys/src/smol.rs +++ b/netlink-sys/src/smol.rs @@ -66,14 +66,14 @@ impl SmolSocket { trace!("poll_recv_from socket is ready for reading"); match self.0.get_ref().recv_from(buf, 0) { - Ok(x) => { - trace!("poll_recv_from {:?} bytes read", x); - return Poll::Ready(Ok(x)); - } - Err(_would_block) => { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { trace!("poll_recv_from socket would block"); continue; } + x => { + trace!("poll_recv_from {:?} bytes read", x); + return Poll::Ready(x); + } } } } From 94c263282d9a34d01422513de6a7f683ac08addc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 7 Nov 2021 18:05:57 +0100 Subject: [PATCH 04/11] Remove `T` generic from NetlinkDeserializable and NetlinkSerializable `T` always was the same as `Self`. Also do a major bump on depending crates. --- audit/Cargo.toml | 6 +++--- ethtool/Cargo.toml | 10 +++++----- genetlink/Cargo.toml | 8 ++++---- genetlink/src/message.rs | 4 ++-- netlink-packet-audit/Cargo.toml | 4 ++-- netlink-packet-audit/fuzz/Cargo.toml | 4 ++-- netlink-packet-audit/src/message.rs | 4 ++-- netlink-packet-core/Cargo.toml | 4 ++-- netlink-packet-core/examples/protocol.rs | 4 ++-- netlink-packet-core/src/lib.rs | 4 ++-- netlink-packet-core/src/message.rs | 8 ++++---- netlink-packet-core/src/payload.rs | 2 +- netlink-packet-core/src/traits.rs | 11 +++++------ netlink-packet-generic/Cargo.toml | 4 ++-- netlink-packet-generic/src/message.rs | 4 ++-- netlink-packet-route/Cargo.toml | 4 ++-- netlink-packet-route/fuzz/Cargo.toml | 2 +- netlink-packet-route/src/rtnl/message.rs | 4 ++-- netlink-packet-sock-diag/Cargo.toml | 4 ++-- netlink-packet-sock-diag/src/message.rs | 4 ++-- netlink-proto/Cargo.toml | 8 ++++---- netlink-proto/src/codecs.rs | 4 ++-- netlink-proto/src/connection.rs | 6 +++--- netlink-proto/src/lib.rs | 4 ++-- netlink-proto/src/protocol/protocol.rs | 6 +++--- netlink-sys/Cargo.toml | 2 +- rtnetlink/Cargo.toml | 6 +++--- 27 files changed, 67 insertions(+), 68 deletions(-) diff --git a/audit/Cargo.toml b/audit/Cargo.toml index 562f337a..039073f2 100644 --- a/audit/Cargo.toml +++ b/audit/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "audit" -version = "0.4.0" +version = "0.5.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Corentin Henry "] edition = "2018" @@ -14,8 +14,8 @@ description = "linux audit via netlink" [dependencies] futures = "0.3.11" thiserror = "1" -netlink-packet-audit = "0.2" -netlink-proto = { default-features = false, version = "0.7" } +netlink-packet-audit = "0.3" +netlink-proto = { default-features = false, version = "0.8" } [features] default = ["tokio_socket"] diff --git a/ethtool/Cargo.toml b/ethtool/Cargo.toml index 675a4c86..b5b9fb00 100644 --- a/ethtool/Cargo.toml +++ b/ethtool/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ethtool" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Gris Ge "] license = "MIT" edition = "2018" @@ -24,12 +24,12 @@ anyhow = "1.0.44" async-std = { version = "1.9.0", optional = true} byteorder = "1.4.3" futures = "0.3.17" -genetlink = { default-features = false, version = "0.1.0"} +genetlink = { default-features = false, version = "0.2.0"} log = "0.4.14" -netlink-packet-core = "0.2.4" -netlink-packet-generic = "0.1.0" +netlink-packet-core = "0.3.0" +netlink-packet-generic = "0.2.0" netlink-packet-utils = "0.4.1" -netlink-proto = { default-features = false, version = "0.7.0" } +netlink-proto = { default-features = false, version = "0.8.0" } netlink-sys = "0.7.0" thiserror = "1.0.29" tokio = { version = "1.0.1", features = ["rt"], optional = true} diff --git a/genetlink/Cargo.toml b/genetlink/Cargo.toml index 2e15b8c2..52d9126e 100644 --- a/genetlink/Cargo.toml +++ b/genetlink/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "genetlink" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Leo "] edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -17,12 +17,12 @@ smol_socket = ["netlink-proto/smol_socket","netlink-proto/workaround-audit-bug", [dependencies] futures = "0.3.16" -netlink-packet-generic = "0.1.0" -netlink-proto = { default-features = false, version = "0.7.0" } +netlink-packet-generic = "0.2.0" +netlink-proto = { default-features = false, version = "0.8.0" } tokio = { version = "1.9.0", features = ["rt"], optional = true } async-std = { version = "1.9.0", optional = true } netlink-packet-utils = "0.4.1" -netlink-packet-core = "0.2.4" +netlink-packet-core = "0.3.0" thiserror = "1.0.26" [dev-dependencies] diff --git a/genetlink/src/message.rs b/genetlink/src/message.rs index 7cc0fcc2..defd5ff8 100644 --- a/genetlink/src/message.rs +++ b/genetlink/src/message.rs @@ -107,7 +107,7 @@ where } } -impl NetlinkSerializable for RawGenlMessage { +impl NetlinkSerializable for RawGenlMessage { fn message_type(&self) -> u16 { self.family_id } @@ -121,7 +121,7 @@ impl NetlinkSerializable for RawGenlMessage { } } -impl NetlinkDeserializable for RawGenlMessage { +impl NetlinkDeserializable for RawGenlMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { let buffer = GenlBuffer::new_checked(payload)?; diff --git a/netlink-packet-audit/Cargo.toml b/netlink-packet-audit/Cargo.toml index 245c50e3..e87487f2 100644 --- a/netlink-packet-audit/Cargo.toml +++ b/netlink-packet-audit/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-packet-audit" -version = "0.2.2" +version = "0.3.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -14,7 +14,7 @@ description = "netlink packet types" [dependencies] anyhow = "1.0.31" byteorder = "1.3.2" -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-packet-utils = ">= 0.3, <0.5" [dev-dependencies] diff --git a/netlink-packet-audit/fuzz/Cargo.toml b/netlink-packet-audit/fuzz/Cargo.toml index 04067f92..5d2573d9 100644 --- a/netlink-packet-audit/fuzz/Cargo.toml +++ b/netlink-packet-audit/fuzz/Cargo.toml @@ -9,8 +9,8 @@ edition = "2018" cargo-fuzz = true [dependencies] -netlink-packet-audit = "0.2" -netlink-packet-core = "0.2" +netlink-packet-audit = "0.3" +netlink-packet-core = "0.3" libfuzzer-sys = { git = "https://github.com/rust-fuzz/libfuzzer-sys.git" } [[bin]] diff --git a/netlink-packet-audit/src/message.rs b/netlink-packet-audit/src/message.rs index 73dd345b..4a13b921 100644 --- a/netlink-packet-audit/src/message.rs +++ b/netlink-packet-audit/src/message.rs @@ -106,7 +106,7 @@ impl Emitable for AuditMessage { } } -impl NetlinkSerializable for AuditMessage { +impl NetlinkSerializable for AuditMessage { fn message_type(&self) -> u16 { self.message_type() } @@ -120,7 +120,7 @@ impl NetlinkSerializable for AuditMessage { } } -impl NetlinkDeserializable for AuditMessage { +impl NetlinkDeserializable for AuditMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { match AuditBuffer::new_checked(payload) { diff --git a/netlink-packet-core/Cargo.toml b/netlink-packet-core/Cargo.toml index f46fb4fc..ccca7e14 100644 --- a/netlink-packet-core/Cargo.toml +++ b/netlink-packet-core/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-packet-core" -version = "0.2.4" +version = "0.3.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -18,4 +18,4 @@ libc = "0.2.66" netlink-packet-utils = ">=0.3, <0.5" [dev-dependencies] -netlink-packet-route = "0.8" +netlink-packet-route = "0.9" diff --git a/netlink-packet-core/examples/protocol.rs b/netlink-packet-core/examples/protocol.rs index 7c092e0b..59dc1a6d 100644 --- a/netlink-packet-core/examples/protocol.rs +++ b/netlink-packet-core/examples/protocol.rs @@ -44,7 +44,7 @@ impl fmt::Display for DeserializeError { } // NetlinkDeserializable implementation -impl NetlinkDeserializable for PingPongMessage { +impl NetlinkDeserializable for PingPongMessage { type Error = DeserializeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { @@ -59,7 +59,7 @@ impl NetlinkDeserializable for PingPongMessage { } // NetlinkSerializable implementation -impl NetlinkSerializable for PingPongMessage { +impl NetlinkSerializable for PingPongMessage { fn message_type(&self) -> u16 { match self { PingPongMessage::Ping(_) => PING_MESSAGE, diff --git a/netlink-packet-core/src/lib.rs b/netlink-packet-core/src/lib.rs index 7e724829..5faed295 100644 --- a/netlink-packet-core/src/lib.rs +++ b/netlink-packet-core/src/lib.rs @@ -156,7 +156,7 @@ //! } //! //! // NetlinkDeserializable implementation -//! impl NetlinkDeserializable for PingPongMessage { +//! impl NetlinkDeserializable for PingPongMessage { //! type Error = DeserializeError; //! //! fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { @@ -171,7 +171,7 @@ //! } //! //! // NetlinkSerializable implementation -//! impl NetlinkSerializable for PingPongMessage { +//! impl NetlinkSerializable for PingPongMessage { //! fn message_type(&self) -> u16 { //! match self { //! PingPongMessage::Ping(_) => PING_MESSAGE, diff --git a/netlink-packet-core/src/message.rs b/netlink-packet-core/src/message.rs index 497ca0c2..62076241 100644 --- a/netlink-packet-core/src/message.rs +++ b/netlink-packet-core/src/message.rs @@ -45,7 +45,7 @@ where impl NetlinkMessage where - I: NetlinkDeserializable + Debug + PartialEq + Eq + Clone, + I: NetlinkDeserializable + Debug + PartialEq + Eq + Clone, { /// Parse the given buffer as a netlink message pub fn deserialize(buffer: &[u8]) -> Result { @@ -56,7 +56,7 @@ where impl NetlinkMessage where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, { /// Return the length of this message in bytes pub fn buffer_len(&self) -> usize { @@ -92,7 +92,7 @@ where impl<'buffer, B, I> Parseable> for NetlinkMessage where B: AsRef<[u8]> + 'buffer, - I: Debug + PartialEq + Eq + Clone + NetlinkDeserializable, + I: Debug + PartialEq + Eq + Clone + NetlinkDeserializable, { fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result { use self::NetlinkPayload::*; @@ -129,7 +129,7 @@ where impl Emitable for NetlinkMessage where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, { fn buffer_len(&self) -> usize { use self::NetlinkPayload::*; diff --git a/netlink-packet-core/src/payload.rs b/netlink-packet-core/src/payload.rs index 499488b8..40f6f8fe 100644 --- a/netlink-packet-core/src/payload.rs +++ b/netlink-packet-core/src/payload.rs @@ -28,7 +28,7 @@ where impl NetlinkPayload where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, { pub fn message_type(&self) -> u16 { match self { diff --git a/netlink-packet-core/src/traits.rs b/netlink-packet-core/src/traits.rs index 368ba18a..0b9d0a7e 100644 --- a/netlink-packet-core/src/traits.rs +++ b/netlink-packet-core/src/traits.rs @@ -1,16 +1,15 @@ use crate::NetlinkHeader; use std::error::Error; -/// A `NetlinkDeserializable` type can be used to deserialize a buffer -/// into the target type `T` for which it is implemented. -pub trait NetlinkDeserializable { +/// A `NetlinkDeserializable` type can be deserialized from a buffer +pub trait NetlinkDeserializable: Sized { type Error: Error + Send + Sync + 'static; - /// Deserialize the given buffer into `T`. - fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result; + /// Deserialize the given buffer into `Self`. + fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result; } -pub trait NetlinkSerializable { +pub trait NetlinkSerializable { fn message_type(&self) -> u16; /// Return the length of the serialized data. diff --git a/netlink-packet-generic/Cargo.toml b/netlink-packet-generic/Cargo.toml index b92a5d95..39af2717 100644 --- a/netlink-packet-generic/Cargo.toml +++ b/netlink-packet-generic/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "netlink-packet-generic" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Leo "] edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -14,7 +14,7 @@ description = "generic netlink packet types" anyhow = "1.0.39" libc = "0.2.86" byteorder = "1.4.2" -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-packet-utils = "0.4" [dev-dependencies] diff --git a/netlink-packet-generic/src/message.rs b/netlink-packet-generic/src/message.rs index 3d0bfde6..6e2bcc69 100644 --- a/netlink-packet-generic/src/message.rs +++ b/netlink-packet-generic/src/message.rs @@ -146,7 +146,7 @@ where } } -impl NetlinkSerializable> for GenlMessage +impl NetlinkSerializable for GenlMessage where F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, { @@ -163,7 +163,7 @@ where } } -impl<'a, F> NetlinkDeserializable> for GenlMessage +impl NetlinkDeserializable for GenlMessage where F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, { diff --git a/netlink-packet-route/Cargo.toml b/netlink-packet-route/Cargo.toml index 4e45b2ac..fa8e8ff7 100644 --- a/netlink-packet-route/Cargo.toml +++ b/netlink-packet-route/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-packet-route" -version = "0.8.0" +version = "0.9.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -15,7 +15,7 @@ description = "netlink packet types" anyhow = "1.0.31" byteorder = "1.3.2" libc = "0.2.66" -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-packet-utils = "0.4" bitflags = "1.2.1" diff --git a/netlink-packet-route/fuzz/Cargo.toml b/netlink-packet-route/fuzz/Cargo.toml index 741e1626..99888b57 100644 --- a/netlink-packet-route/fuzz/Cargo.toml +++ b/netlink-packet-route/fuzz/Cargo.toml @@ -9,7 +9,7 @@ edition = "2018" cargo-fuzz = true [dependencies] -netlink-packet-route = "0.7" +netlink-packet-route = "0.9" libfuzzer-sys = { git = "https://github.com/rust-fuzz/libfuzzer-sys.git" } [[bin]] diff --git a/netlink-packet-route/src/rtnl/message.rs b/netlink-packet-route/src/rtnl/message.rs index e1627bc3..46baf2ce 100644 --- a/netlink-packet-route/src/rtnl/message.rs +++ b/netlink-packet-route/src/rtnl/message.rs @@ -356,7 +356,7 @@ impl Emitable for RtnlMessage { } } -impl NetlinkSerializable for RtnlMessage { +impl NetlinkSerializable for RtnlMessage { fn message_type(&self) -> u16 { self.message_type() } @@ -370,7 +370,7 @@ impl NetlinkSerializable for RtnlMessage { } } -impl NetlinkDeserializable for RtnlMessage { +impl NetlinkDeserializable for RtnlMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { let buf = RtnlMessageBuffer::new(payload); diff --git a/netlink-packet-sock-diag/Cargo.toml b/netlink-packet-sock-diag/Cargo.toml index e3d09cec..0e06a448 100644 --- a/netlink-packet-sock-diag/Cargo.toml +++ b/netlink-packet-sock-diag/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Flier Lu ", "Corentin Henry "] name = "netlink-packet-sock-diag" -version = "0.1.0" +version = "0.2.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -14,7 +14,7 @@ description = "netlink packet types for the sock_diag subprotocol" [dependencies] anyhow = "1.0.32" byteorder = "1.3.4" -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-packet-utils = ">= 0.3, <0.5" bitflags = "1.2.1" libc = "0.2.77" diff --git a/netlink-packet-sock-diag/src/message.rs b/netlink-packet-sock-diag/src/message.rs index 5a84b5aa..6ab52e86 100644 --- a/netlink-packet-sock-diag/src/message.rs +++ b/netlink-packet-sock-diag/src/message.rs @@ -64,7 +64,7 @@ impl Emitable for SockDiagMessage { } } -impl NetlinkSerializable for SockDiagMessage { +impl NetlinkSerializable for SockDiagMessage { fn message_type(&self) -> u16 { self.message_type() } @@ -78,7 +78,7 @@ impl NetlinkSerializable for SockDiagMessage { } } -impl NetlinkDeserializable for SockDiagMessage { +impl NetlinkDeserializable for SockDiagMessage { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { let buffer = SockDiagBuffer::new_checked(&payload)?; diff --git a/netlink-proto/Cargo.toml b/netlink-proto/Cargo.toml index 830f6698..c1582df0 100644 --- a/netlink-proto/Cargo.toml +++ b/netlink-proto/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-proto" -version = "0.7.0" +version = "0.8.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -17,7 +17,7 @@ log = "0.4.8" futures = "0.3" tokio = { version = "1.0", default-features = false, features = ["io-util"] } tokio-util = { version = "0.6", default-features = false, features = ["codec"] } -netlink-packet-core = "0.2" +netlink-packet-core = "0.3" netlink-sys = { default-features = false, version = "0.7" } [features] @@ -29,8 +29,8 @@ workaround-audit-bug = [] [dev-dependencies] env_logger = "0.8.2" tokio = { version = "1.0.1", default-features = false, features = ["macros", "rt-multi-thread"] } -netlink-packet-route = "0.8" -netlink-packet-audit = "0.2" +netlink-packet-route = "0.9" +netlink-packet-audit = "0.3" async-std = {version = "1.9.0", features = ["attributes"]} [[example]] diff --git a/netlink-proto/src/codecs.rs b/netlink-proto/src/codecs.rs index d5c1af48..596806d4 100644 --- a/netlink-proto/src/codecs.rs +++ b/netlink-proto/src/codecs.rs @@ -31,7 +31,7 @@ impl NetlinkCodec { // See https://github.com/mozilla/libaudit-go/issues/24 impl Decoder for NetlinkCodec> where - T: NetlinkDeserializable + Debug + Eq + PartialEq + Clone, + T: NetlinkDeserializable + Debug + Eq + PartialEq + Clone, { type Item = NetlinkMessage; type Error = io::Error; @@ -141,7 +141,7 @@ where impl Encoder> for NetlinkCodec> where - T: Debug + Eq + PartialEq + Clone + NetlinkSerializable, + T: Debug + Eq + PartialEq + Clone + NetlinkSerializable, { type Error = io::Error; diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index 7406c198..3517d780 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -34,7 +34,7 @@ use crate::{ /// `Connection`, that in turn, sends them through the netlink socket. pub struct Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, + T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, { socket: NetlinkFramed>>, @@ -52,7 +52,7 @@ where impl Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, + T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, { pub(crate) fn new( requests_rx: UnboundedReceiver>, @@ -252,7 +252,7 @@ where impl Future for Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, + T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, { type Output = (); diff --git a/netlink-proto/src/lib.rs b/netlink-proto/src/lib.rs index c9d4e774..d1c09786 100644 --- a/netlink-proto/src/lib.rs +++ b/netlink-proto/src/lib.rs @@ -231,8 +231,8 @@ where + PartialEq + Eq + Clone - + packet::NetlinkSerializable - + packet::NetlinkDeserializable + + packet::NetlinkSerializable + + packet::NetlinkDeserializable + Unpin, { let (requests_tx, requests_rx) = unbounded::>(); diff --git a/netlink-proto/src/protocol/protocol.rs b/netlink-proto/src/protocol/protocol.rs index eec9bc03..d53baa2e 100644 --- a/netlink-proto/src/protocol/protocol.rs +++ b/netlink-proto/src/protocol/protocol.rs @@ -32,7 +32,7 @@ impl RequestId { #[derive(Debug, Eq, PartialEq)] pub(crate) struct Response where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, + T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, M: Debug, { pub done: bool, @@ -49,7 +49,7 @@ struct PendingRequest { #[derive(Debug, Default)] pub(crate) struct Protocol where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, + T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, M: Debug, { /// Counter that is incremented for each message sent @@ -71,7 +71,7 @@ where impl Protocol where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, + T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, M: Clone + Debug, { pub fn new() -> Self { diff --git a/netlink-sys/Cargo.toml b/netlink-sys/Cargo.toml index c7b95662..579044eb 100644 --- a/netlink-sys/Cargo.toml +++ b/netlink-sys/Cargo.toml @@ -42,7 +42,7 @@ tokio_socket = ["tokio", "futures"] smol_socket = ["async-io","futures"] [dev-dependencies] -netlink-packet-audit = "0.2" +netlink-packet-audit = "0.3" [dev-dependencies.tokio] version = "1.0.1" diff --git a/rtnetlink/Cargo.toml b/rtnetlink/Cargo.toml index 0e6d3fa1..36ab9bf6 100644 --- a/rtnetlink/Cargo.toml +++ b/rtnetlink/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rtnetlink" -version = "0.8.1" +version = "0.9.0" # TODO: drop this comment - already bumped version for trait changes authors = ["Corentin Henry "] edition = "2018" @@ -21,8 +21,8 @@ smol_socket = ["netlink-proto/smol_socket", "async-std"] futures = "0.3.11" log = "0.4.8" thiserror = "1" -netlink-packet-route = "0.8" -netlink-proto = { default-features = false, version = "0.7" } +netlink-packet-route = "0.9" +netlink-proto = { default-features = false, version = "0.8" } nix = "0.22.0" tokio = { version = "1.0.1", features = ["rt"], optional = true} async-std = { version = "1.9.0", features = ["unstable"], optional = true} From 71b71ce5688bfba2e66461cbcb5fa9e9482d7c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 7 Nov 2021 18:43:37 +0100 Subject: [PATCH 05/11] Remove constraints from type definitions; remove unneeded constraints in implementations (apart from Debug) - where clauses are almost never needed on type definitions unless you need them in the `Drop` implementation - Debug trait might be useful if (debug) logging gets added - Clone/PartialEq/Eq shouldn't be needed ever in the implementations --- genetlink/src/handle.rs | 34 ++++---------------------- genetlink/src/message.rs | 8 +++--- netlink-packet-core/src/message.rs | 20 ++++++--------- netlink-packet-core/src/payload.rs | 7 ++---- netlink-packet-generic/src/buffer.rs | 4 +-- netlink-packet-generic/src/message.rs | 17 ++++++------- netlink-proto/src/codecs.rs | 4 +-- netlink-proto/src/connection.rs | 6 ++--- netlink-proto/src/errors.rs | 12 ++++----- netlink-proto/src/handle.rs | 4 +-- netlink-proto/src/lib.rs | 8 +----- netlink-proto/src/protocol/protocol.rs | 16 +++--------- netlink-proto/src/protocol/request.rs | 10 +++----- 13 files changed, 48 insertions(+), 102 deletions(-) diff --git a/genetlink/src/handle.rs b/genetlink/src/handle.rs index 1a312913..935b81a8 100644 --- a/genetlink/src/handle.rs +++ b/genetlink/src/handle.rs @@ -78,13 +78,7 @@ impl GenetlinkHandle { GenetlinkError, > where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { self.resolve_message_family_id(&mut message).await?; self.send_request(message) @@ -102,13 +96,7 @@ impl GenetlinkHandle { GenetlinkError, > where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { let raw_msg = map_to_rawgenlmsg(message); @@ -122,13 +110,7 @@ impl GenetlinkHandle { mut message: NetlinkMessage>, ) -> Result<(), GenetlinkError> where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { self.resolve_message_family_id(&mut message).await?; self.send_notify(message) @@ -140,13 +122,7 @@ impl GenetlinkHandle { message: NetlinkMessage>, ) -> Result<(), GenetlinkError> where - F: GenlFamily - + Emitable - + ParseableParametrized<[u8], GenlHeader> - + Clone - + Debug - + PartialEq - + Eq, + F: GenlFamily + Emitable + ParseableParametrized<[u8], GenlHeader> + Debug, { let raw_msg = map_to_rawgenlmsg(message); @@ -159,7 +135,7 @@ impl GenetlinkHandle { message: &mut NetlinkMessage>, ) -> Result<(), GenetlinkError> where - F: GenlFamily + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Debug, { if let NetlinkPayload::InnerMessage(genlmsg) = &mut message.payload { if genlmsg.family_id() == 0 { diff --git a/genetlink/src/message.rs b/genetlink/src/message.rs index defd5ff8..0b5ffabb 100644 --- a/genetlink/src/message.rs +++ b/genetlink/src/message.rs @@ -57,7 +57,7 @@ impl RawGenlMessage { /// Serialize the generic netlink payload into raw bytes pub fn from_genlmsg(genlmsg: GenlMessage) -> Self where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { let mut payload_buf = vec![0u8; genlmsg.payload.buffer_len()]; genlmsg.payload.emit(&mut payload_buf); @@ -72,7 +72,7 @@ impl RawGenlMessage { /// Try to deserialize the generic netlink payload from raw bytes pub fn parse_into_genlmsg(&self) -> Result, DecodeError> where - F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Debug, { let inner = F::parse_with_param(&self.payload, self.header)?; Ok(GenlMessage::new(self.header, inner, self.family_id)) @@ -141,7 +141,7 @@ pub fn map_to_rawgenlmsg( message: NetlinkMessage>, ) -> NetlinkMessage where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { let raw_payload = match message.payload { NetlinkPayload::InnerMessage(genlmsg) => { @@ -162,7 +162,7 @@ pub fn map_from_rawgenlmsg( raw_msg: NetlinkMessage, ) -> Result>, DecodeError> where - F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Debug, { let payload = match raw_msg.payload { NetlinkPayload::InnerMessage(raw_genlmsg) => { diff --git a/netlink-packet-core/src/message.rs b/netlink-packet-core/src/message.rs index 62076241..fb699116 100644 --- a/netlink-packet-core/src/message.rs +++ b/netlink-packet-core/src/message.rs @@ -18,20 +18,14 @@ use crate::{ /// Represent a netlink message. #[derive(Debug, PartialEq, Eq, Clone)] -pub struct NetlinkMessage -where - I: Debug + PartialEq + Eq + Clone, -{ +pub struct NetlinkMessage { /// Message header (this is common to all the netlink protocols) pub header: NetlinkHeader, /// Inner message, which depends on the netlink protocol being used. pub payload: NetlinkPayload, } -impl NetlinkMessage -where - I: Debug + PartialEq + Eq + Clone, -{ +impl NetlinkMessage { /// Create a new netlink message from the given header and payload pub fn new(header: NetlinkHeader, payload: NetlinkPayload) -> Self { NetlinkMessage { header, payload } @@ -45,7 +39,7 @@ where impl NetlinkMessage where - I: NetlinkDeserializable + Debug + PartialEq + Eq + Clone, + I: NetlinkDeserializable + Debug, { /// Parse the given buffer as a netlink message pub fn deserialize(buffer: &[u8]) -> Result { @@ -56,7 +50,7 @@ where impl NetlinkMessage where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug, { /// Return the length of this message in bytes pub fn buffer_len(&self) -> usize { @@ -92,7 +86,7 @@ where impl<'buffer, B, I> Parseable> for NetlinkMessage where B: AsRef<[u8]> + 'buffer, - I: Debug + PartialEq + Eq + Clone + NetlinkDeserializable, + I: NetlinkDeserializable + Debug, { fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result { use self::NetlinkPayload::*; @@ -129,7 +123,7 @@ where impl Emitable for NetlinkMessage where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug, { fn buffer_len(&self) -> usize { use self::NetlinkPayload::*; @@ -163,7 +157,7 @@ where impl From for NetlinkMessage where - T: Into> + Debug + Clone + Eq + PartialEq, + T: Into> + Debug, { fn from(inner_message: T) -> Self { NetlinkMessage { diff --git a/netlink-packet-core/src/payload.rs b/netlink-packet-core/src/payload.rs index 40f6f8fe..d6c598f0 100644 --- a/netlink-packet-core/src/payload.rs +++ b/netlink-packet-core/src/payload.rs @@ -14,10 +14,7 @@ pub const NLMSG_OVERRUN: u16 = 4; pub const NLMSG_ALIGNTO: u16 = 4; #[derive(Debug, PartialEq, Eq, Clone)] -pub enum NetlinkPayload -where - I: Debug + PartialEq + Eq + Clone, -{ +pub enum NetlinkPayload { Done, Error(ErrorMessage), Ack(AckMessage), @@ -28,7 +25,7 @@ where impl NetlinkPayload where - I: NetlinkSerializable + Debug + PartialEq + Eq + Clone, + I: NetlinkSerializable + Debug, { pub fn message_type(&self) -> u16 { match self { diff --git a/netlink-packet-generic/src/buffer.rs b/netlink-packet-generic/src/buffer.rs index 7bd29d24..f7f0f365 100644 --- a/netlink-packet-generic/src/buffer.rs +++ b/netlink-packet-generic/src/buffer.rs @@ -12,7 +12,7 @@ buffer!(GenlBuffer(GENL_HDRLEN) { impl ParseableParametrized<[u8], u16> for GenlMessage where - F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: ParseableParametrized<[u8], GenlHeader> + Debug, { fn parse_with_param(buf: &[u8], message_type: u16) -> Result { let buf = GenlBuffer::new_checked(buf)?; @@ -22,7 +22,7 @@ where impl<'a, F, T> ParseableParametrized, u16> for GenlMessage where - F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: ParseableParametrized<[u8], GenlHeader> + Debug, T: AsRef<[u8]> + ?Sized, { fn parse_with_param(buf: &GenlBuffer<&'a T>, message_type: u16) -> Result { diff --git a/netlink-packet-generic/src/message.rs b/netlink-packet-generic/src/message.rs index 6e2bcc69..5ca5a961 100644 --- a/netlink-packet-generic/src/message.rs +++ b/netlink-packet-generic/src/message.rs @@ -20,10 +20,7 @@ use netlink_packet_core::NetlinkMessage; /// The message can be serialize/deserialize if the type `F` implements [`GenlFamily`], /// [`Emitable`], and [`ParseableParametrized<[u8], GenlHeader>`](ParseableParametrized). #[derive(Clone, Debug, PartialEq, Eq)] -pub struct GenlMessage -where - F: Clone + Debug + PartialEq + Eq, -{ +pub struct GenlMessage { pub header: GenlHeader, pub payload: F, resolved_family_id: u16, @@ -31,7 +28,7 @@ where impl GenlMessage where - F: Clone + Debug + PartialEq + Eq, + F: Debug, { /// Construct the message pub fn new(header: GenlHeader, payload: F, family_id: u16) -> Self { @@ -85,7 +82,7 @@ where impl GenlMessage where - F: GenlFamily + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Debug, { /// Build the message from the payload /// @@ -132,7 +129,7 @@ where impl Emitable for GenlMessage where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { fn buffer_len(&self) -> usize { self.header.buffer_len() + self.payload.buffer_len() @@ -148,7 +145,7 @@ where impl NetlinkSerializable for GenlMessage where - F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + F: GenlFamily + Emitable + Debug, { fn message_type(&self) -> u16 { self.family_id() @@ -165,7 +162,7 @@ where impl NetlinkDeserializable for GenlMessage where - F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + F: ParseableParametrized<[u8], GenlHeader> + Debug, { type Error = DecodeError; fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { @@ -176,7 +173,7 @@ where impl From> for NetlinkPayload> where - F: Clone + Debug + PartialEq + Eq, + F: Debug, { fn from(message: GenlMessage) -> Self { NetlinkPayload::InnerMessage(message) diff --git a/netlink-proto/src/codecs.rs b/netlink-proto/src/codecs.rs index 596806d4..20bb0bc4 100644 --- a/netlink-proto/src/codecs.rs +++ b/netlink-proto/src/codecs.rs @@ -31,7 +31,7 @@ impl NetlinkCodec { // See https://github.com/mozilla/libaudit-go/issues/24 impl Decoder for NetlinkCodec> where - T: NetlinkDeserializable + Debug + Eq + PartialEq + Clone, + T: NetlinkDeserializable + Debug, { type Item = NetlinkMessage; type Error = io::Error; @@ -141,7 +141,7 @@ where impl Encoder> for NetlinkCodec> where - T: Debug + Eq + PartialEq + Clone + NetlinkSerializable, + T: Debug + NetlinkSerializable, { type Error = io::Error; diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index 3517d780..f14df05f 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -34,7 +34,7 @@ use crate::{ /// `Connection`, that in turn, sends them through the netlink socket. pub struct Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, + T: Debug + NetlinkSerializable + NetlinkDeserializable, { socket: NetlinkFramed>>, @@ -52,7 +52,7 @@ where impl Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, + T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, { pub(crate) fn new( requests_rx: UnboundedReceiver>, @@ -252,7 +252,7 @@ where impl Future for Connection where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable + Unpin, + T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, { type Output = (); diff --git a/netlink-proto/src/errors.rs b/netlink-proto/src/errors.rs index fe350f98..31173ec3 100644 --- a/netlink-proto/src/errors.rs +++ b/netlink-proto/src/errors.rs @@ -9,14 +9,14 @@ use netlink_packet_core::NetlinkMessage; #[derive(Debug)] pub struct Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { kind: ErrorKind, } impl Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { pub fn kind(&self) -> &ErrorKind { &self.kind @@ -30,7 +30,7 @@ where #[derive(Debug)] pub enum ErrorKind where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { /// The netlink connection is closed ConnectionClosed, @@ -44,7 +44,7 @@ where impl From> for Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { fn from(kind: ErrorKind) -> Error { Error { kind } @@ -53,7 +53,7 @@ where impl fmt::Display for Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use crate::ErrorKind::*; @@ -67,7 +67,7 @@ where impl StdError for Error where - T: Debug + Eq + PartialEq + Clone, + T: Debug, { fn description(&self) -> &str { use crate::ErrorKind::*; diff --git a/netlink-proto/src/handle.rs b/netlink-proto/src/handle.rs index d47eb4e2..db2d7f9f 100644 --- a/netlink-proto/src/handle.rs +++ b/netlink-proto/src/handle.rs @@ -15,14 +15,14 @@ use crate::{ #[derive(Clone, Debug)] pub struct ConnectionHandle where - T: Debug + Clone + Eq + PartialEq, + T: Debug, { requests_tx: UnboundedSender>, } impl ConnectionHandle where - T: Debug + Clone + Eq + PartialEq, + T: Debug, { pub(crate) fn new(requests_tx: UnboundedSender>) -> Self { ConnectionHandle { requests_tx } diff --git a/netlink-proto/src/lib.rs b/netlink-proto/src/lib.rs index d1c09786..1b18abae 100644 --- a/netlink-proto/src/lib.rs +++ b/netlink-proto/src/lib.rs @@ -227,13 +227,7 @@ pub fn new_connection( UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, )> where - T: Debug - + PartialEq - + Eq - + Clone - + packet::NetlinkSerializable - + packet::NetlinkDeserializable - + Unpin, + T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, { let (requests_tx, requests_rx) = unbounded::>(); let (messages_tx, messages_rx) = unbounded::<(packet::NetlinkMessage, sys::SocketAddr)>(); diff --git a/netlink-proto/src/protocol/protocol.rs b/netlink-proto/src/protocol/protocol.rs index d53baa2e..56f2f4b3 100644 --- a/netlink-proto/src/protocol/protocol.rs +++ b/netlink-proto/src/protocol/protocol.rs @@ -30,11 +30,7 @@ impl RequestId { } #[derive(Debug, Eq, PartialEq)] -pub(crate) struct Response -where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, - M: Debug, -{ +pub(crate) struct Response { pub done: bool, pub message: NetlinkMessage, pub metadata: M, @@ -47,11 +43,7 @@ struct PendingRequest { } #[derive(Debug, Default)] -pub(crate) struct Protocol -where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, - M: Debug, -{ +pub(crate) struct Protocol { /// Counter that is incremented for each message sent sequence_id: u32, @@ -71,8 +63,8 @@ where impl Protocol where - T: Debug + Clone + PartialEq + Eq + NetlinkSerializable + NetlinkDeserializable, - M: Clone + Debug, + T: Debug + NetlinkSerializable + NetlinkDeserializable, + M: Debug + Clone, { pub fn new() -> Self { Self { diff --git a/netlink-proto/src/protocol/request.rs b/netlink-proto/src/protocol/request.rs index 881b4547..f5aa358c 100644 --- a/netlink-proto/src/protocol/request.rs +++ b/netlink-proto/src/protocol/request.rs @@ -5,11 +5,7 @@ use netlink_packet_core::NetlinkMessage; use crate::sys::SocketAddr; #[derive(Debug)] -pub(crate) struct Request -where - T: Debug + Clone + Eq + PartialEq, - M: Debug, -{ +pub(crate) struct Request { pub metadata: M, pub message: NetlinkMessage, pub destination: SocketAddr, @@ -17,7 +13,7 @@ where impl From<(NetlinkMessage, SocketAddr, M)> for Request where - T: Debug + PartialEq + Eq + Clone, + T: Debug, M: Debug, { fn from(parts: (NetlinkMessage, SocketAddr, M)) -> Self { @@ -31,7 +27,7 @@ where impl From> for (NetlinkMessage, SocketAddr, M) where - T: Debug + PartialEq + Eq + Clone, + T: Debug, M: Debug, { fn from(req: Request) -> (NetlinkMessage, SocketAddr, M) { From cac815a73fbdfeaa4ad91323a15f92710c86a91c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 7 Nov 2021 19:59:40 +0100 Subject: [PATCH 06/11] Make codec a Connection type parameter, split audit implementation 1. get rid of "workaround-audit-bug" "feature" 2. no longer use tokio_util::codec::{Decoder, Encoder} tokio_util::codec is "designed" for bytestreams (and building "frames" of messages on top), but we need to deal with datagrams (which still can contain multiples messages, just not one message across multiple datagrams). This make it a little less "combinable", but we actually don't want people to reuse these codecs on bytestream (and writing an adapter wouldn't be that hard anyway). Also we can use a fixed error type, making dealing with it a little bit easier. --- .github/workflows/main.yml | 1 - audit/Cargo.toml | 4 +- audit/src/lib.rs | 5 +- genetlink/Cargo.toml | 4 +- netlink-packet-audit/Cargo.toml | 8 ++ netlink-packet-audit/src/codec.rs | 135 +++++++++++++++++++++++++ netlink-packet-audit/src/lib.rs | 6 ++ netlink-proto/Cargo.toml | 3 - netlink-proto/examples/audit_events.rs | 5 +- netlink-proto/src/codecs.rs | 125 ++++++++--------------- netlink-proto/src/connection.rs | 14 +-- netlink-proto/src/framed.rs | 49 +++++---- netlink-proto/src/lib.rs | 16 +++ 13 files changed, 253 insertions(+), 122 deletions(-) create mode 100644 netlink-packet-audit/src/codec.rs diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ddc83934..6e877e35 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -61,7 +61,6 @@ jobs: run: | cd netlink-proto cargo test - cargo test --features workaround-audit-bug - name: test (rtnetlink) env: diff --git a/audit/Cargo.toml b/audit/Cargo.toml index 039073f2..6f6d896b 100644 --- a/audit/Cargo.toml +++ b/audit/Cargo.toml @@ -19,8 +19,8 @@ netlink-proto = { default-features = false, version = "0.8" } [features] default = ["tokio_socket"] -tokio_socket = ["netlink-proto/tokio_socket", "netlink-proto/workaround-audit-bug"] -smol_socket = ["netlink-proto/smol_socket", "netlink-proto/workaround-audit-bug"] +tokio_socket = ["netlink-proto/tokio_socket", "netlink-packet-audit/tokio_socket"] +smol_socket = ["netlink-proto/smol_socket", "netlink-packet-audit/smol_socket"] [dev-dependencies] tokio = { version = "1.0.1", default-features = false, features = ["macros", "rt-multi-thread"] } diff --git a/audit/src/lib.rs b/audit/src/lib.rs index e38cc9f0..91a29b1a 100644 --- a/audit/src/lib.rs +++ b/audit/src/lib.rs @@ -16,13 +16,14 @@ use futures::channel::mpsc::UnboundedReceiver; #[allow(clippy::type_complexity)] pub fn new_connection() -> io::Result<( - proto::Connection, + proto::Connection, Handle, UnboundedReceiver<( packet::NetlinkMessage, sys::SocketAddr, )>, )> { - let (conn, handle, messages) = netlink_proto::new_connection(sys::protocols::NETLINK_AUDIT)?; + let (conn, handle, messages) = + netlink_proto::new_connection_with_codec(sys::protocols::NETLINK_AUDIT)?; Ok((conn, Handle::new(handle), messages)) } diff --git a/genetlink/Cargo.toml b/genetlink/Cargo.toml index 52d9126e..461a108a 100644 --- a/genetlink/Cargo.toml +++ b/genetlink/Cargo.toml @@ -12,8 +12,8 @@ description = "communicate with generic netlink" [features] default = ["tokio_socket"] -tokio_socket = ["netlink-proto/tokio_socket","netlink-proto/workaround-audit-bug", "tokio"] -smol_socket = ["netlink-proto/smol_socket","netlink-proto/workaround-audit-bug","async-std"] +tokio_socket = ["netlink-proto/tokio_socket", "tokio"] +smol_socket = ["netlink-proto/smol_socket","async-std"] [dependencies] futures = "0.3.16" diff --git a/netlink-packet-audit/Cargo.toml b/netlink-packet-audit/Cargo.toml index e87487f2..7a135fab 100644 --- a/netlink-packet-audit/Cargo.toml +++ b/netlink-packet-audit/Cargo.toml @@ -11,11 +11,19 @@ readme = "../README.md" repository = "https://github.com/little-dude/netlink" description = "netlink packet types" +[features] +default = ["tokio_socket"] +tokio_socket = ["netlink-proto/tokio_socket"] +smol_socket = ["netlink-proto/smol_socket"] + [dependencies] anyhow = "1.0.31" +bytes = "1.0" byteorder = "1.3.2" +log = "0.4.8" netlink-packet-core = "0.3" netlink-packet-utils = ">= 0.3, <0.5" +netlink-proto = { default-features = false, version = "0.8" } [dev-dependencies] lazy_static = "1.4.0" diff --git a/netlink-packet-audit/src/codec.rs b/netlink-packet-audit/src/codec.rs new file mode 100644 index 00000000..36f226af --- /dev/null +++ b/netlink-packet-audit/src/codec.rs @@ -0,0 +1,135 @@ +use std::{fmt::Debug, io}; + +use bytes::BytesMut; +use netlink_packet_core::{ + NetlinkBuffer, + NetlinkDeserializable, + NetlinkMessage, + NetlinkSerializable, +}; +pub(crate) use netlink_proto::{NetlinkCodec, NetlinkMessageCodec}; + +/// audit specific implementation of [`NetlinkMessageCodec`] due to the +/// protocol violations in messages generated by kernal audit. +/// +/// Among the known bugs in kernel audit messages: +/// - `nlmsg_len` sometimes contains the padding too (it shouldn't) +/// - `nlmsg_len` sometimes doesn't contain the header (it really should) +/// +/// See also: +/// - https://blog.des.no/2020/08/netlink-auditing-and-counting-bytes/ +/// - https://github.com/torvalds/linux/blob/b5013d084e03e82ceeab4db8ae8ceeaebe76b0eb/kernel/audit.c#L2386 +/// - https://github.com/mozilla/libaudit-go/issues/24 +/// - https://github.com/linux-audit/audit-userspace/issues/78 +pub struct NetlinkAuditCodec { + // we don't need an instance of this, just the type + _private: (), +} + +impl NetlinkMessageCodec for NetlinkAuditCodec { + fn decode(src: &mut BytesMut) -> io::Result>> + where + T: NetlinkDeserializable + Debug, + { + debug!("NetlinkAuditCodec: decoding next message"); + + loop { + // If there's nothing to read, return Ok(None) + if src.as_ref().is_empty() { + trace!("buffer is empty"); + src.clear(); + return Ok(None); + } + + // This is a bit hacky because we don't want to keep `src` + // borrowed, since we need to mutate it later. + let len_res = match NetlinkBuffer::new_checked(src.as_ref()) { + Ok(buf) => { + if (src.as_ref().len() as isize - buf.length() as isize) <= 16 { + // The audit messages are sometimes truncated, + // because the length specified in the header, + // does not take the header itself into + // account. To workaround this, we tweak the + // length. We've noticed two occurences of + // truncated packets: + // + // - the length of the header is not included (see also: + // https://github.com/mozilla/libaudit-go/issues/24) + // - some rule message have some padding for alignment (see + // https://github.com/linux-audit/audit-userspace/issues/78) which is not + // taken into account in the buffer length. + warn!("found what looks like a truncated audit packet"); + Ok(src.as_ref().len()) + } else { + Ok(buf.length() as usize) + } + } + Err(e) => { + // We either received a truncated packet, or the + // packet if malformed (invalid length field). In + // both case, we can't decode the datagram, and we + // cannot find the start of the next one (if + // any). The only solution is to clear the buffer + // and potentially lose some datagrams. + error!("failed to decode datagram: {:?}: {:#x?}.", e, src.as_ref()); + Err(()) + } + }; + + if len_res.is_err() { + error!("clearing the whole socket buffer. Datagrams may have been lost"); + src.clear(); + return Ok(None); + } + + let len = len_res.unwrap(); + + let bytes = { + let mut bytes = src.split_to(len); + { + let mut buf = NetlinkBuffer::new(bytes.as_mut()); + // If the buffer contains more bytes than what the header says the length is, it + // means we ran into a malformed packet (see comment above), and we just set the + // "right" length ourself, so that parsing does not fail. + // + // How do we know that's the right length? Due to an implementation detail and to + // the fact that netlink is a datagram protocol. + // + // - our implementation of Stream always calls the codec with at most 1 message in + // the buffer, so we know the extra bytes do not belong to another message. + // - because netlink is a datagram protocol, we receive entire messages, so we know + // that if those extra bytes do not belong to another message, they belong to + // this one. + if len != buf.length() as usize { + warn!( + "setting packet length to {} instead of {}", + len, + buf.length() + ); + buf.set_length(len as u32); + } + } + bytes + }; + + let parsed = NetlinkMessage::::deserialize(&bytes); + match parsed { + Ok(packet) => { + trace!("<<< {:?}", packet); + return Ok(Some(packet)); + } + Err(e) => { + error!("failed to decode packet {:#x?}: {}", &bytes, e); + // continue looping, there may be more datagrams in the buffer + } + } + } + } + + fn encode(msg: NetlinkMessage, buf: &mut BytesMut) -> io::Result<()> + where + T: Debug + NetlinkSerializable, + { + NetlinkCodec::encode(msg, buf) + } +} diff --git a/netlink-packet-audit/src/lib.rs b/netlink-packet-audit/src/lib.rs index a80d4340..19bb06bb 100644 --- a/netlink-packet-audit/src/lib.rs +++ b/netlink-packet-audit/src/lib.rs @@ -1,3 +1,6 @@ +#[macro_use] +extern crate log; + pub(crate) extern crate netlink_packet_utils as utils; pub use self::utils::{traits, DecodeError}; pub use netlink_packet_core::{ @@ -13,6 +16,9 @@ use core::ops::Range; /// Represent a multi-bytes field with a fixed size in a packet pub(crate) type Field = Range; +mod codec; +pub use codec::NetlinkAuditCodec; + pub mod status; pub use self::status::*; diff --git a/netlink-proto/Cargo.toml b/netlink-proto/Cargo.toml index c1582df0..74da076b 100644 --- a/netlink-proto/Cargo.toml +++ b/netlink-proto/Cargo.toml @@ -16,7 +16,6 @@ bytes = "1.0" log = "0.4.8" futures = "0.3" tokio = { version = "1.0", default-features = false, features = ["io-util"] } -tokio-util = { version = "0.6", default-features = false, features = ["codec"] } netlink-packet-core = "0.3" netlink-sys = { default-features = false, version = "0.7" } @@ -24,7 +23,6 @@ netlink-sys = { default-features = false, version = "0.7" } default = ["tokio_socket"] tokio_socket = ["netlink-sys/tokio_socket"] smol_socket = ["netlink-sys/smol_socket"] -workaround-audit-bug = [] [dev-dependencies] env_logger = "0.8.2" @@ -42,4 +40,3 @@ required-features = ["smol_socket"] [[example]] name = "audit_events" -required-features = ["workaround-audit-bug"] diff --git a/netlink-proto/examples/audit_events.rs b/netlink-proto/examples/audit_events.rs index 1c451f7c..de82ccab 100644 --- a/netlink-proto/examples/audit_events.rs +++ b/netlink-proto/examples/audit_events.rs @@ -7,10 +7,7 @@ // Compilation: // ------------ // -// cargo build --example audit_events --features="workaround-audit-bug" -// -// Note that the audit protocol has a bug that we have to workaround, -// hence the custom --features flag for that protocol +// cargo build --example audit_events // // Usage: // ------ diff --git a/netlink-proto/src/codecs.rs b/netlink-proto/src/codecs.rs index 20bb0bc4..1b1bd737 100644 --- a/netlink-proto/src/codecs.rs +++ b/netlink-proto/src/codecs.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, io, marker::PhantomData}; +use std::{fmt::Debug, io}; use bytes::{BufMut, BytesMut}; use netlink_packet_core::{ @@ -7,36 +7,48 @@ use netlink_packet_core::{ NetlinkMessage, NetlinkSerializable, }; -use tokio_util::codec::{Decoder, Encoder}; -pub struct NetlinkCodec { - phantom: PhantomData, -} +/// Protocol to serialize and deserialize messages to and from datagrams +/// +/// This is separate from `tokio_util::codec::{Decoder, Encoder}` as the implementations +/// rely on the buffer containing full datagrams; they won't work well with simple +/// bytestreams. +/// +/// Officially there should be exactly one implementation of this, but the audit +/// subsystem ignores way too many rules of the protocol, so they need a separate +/// implementation. +/// +/// Although one could make a tighter binding between `NetlinkMessageCodec` and +/// the message types (NetlinkDeserializable+NetlinkSerializable) it can handle, +/// this would put quite some overhead on subsystems that followed the spec - so +/// we simply default to the proper implementation (in `Connection`) and the +/// `audit` code needs to overwrite it. +pub trait NetlinkMessageCodec { + /// Decode message of given type from datagram payload + /// + /// There might be more than one message; this needs to be called until it + /// either returns `Ok(None)` or an error. + fn decode(src: &mut BytesMut) -> io::Result>> + where + T: NetlinkDeserializable + Debug; -impl Default for NetlinkCodec { - fn default() -> Self { - Self::new() - } + /// Encode message to (datagram) buffer + fn encode(msg: NetlinkMessage, buf: &mut BytesMut) -> io::Result<()> + where + T: NetlinkSerializable + Debug; } -impl NetlinkCodec { - pub fn new() -> Self { - NetlinkCodec { - phantom: PhantomData, - } - } +/// Standard implementation of `NetlinkMessageCodec` +pub struct NetlinkCodec { + // we don't need an instance of this, just the type + _private: (), } -// FIXME: it seems that for audit, we're receiving malformed packets. -// See https://github.com/mozilla/libaudit-go/issues/24 -impl Decoder for NetlinkCodec> -where - T: NetlinkDeserializable + Debug, -{ - type Item = NetlinkMessage; - type Error = io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { +impl NetlinkMessageCodec for NetlinkCodec { + fn decode(src: &mut BytesMut) -> io::Result>> + where + T: NetlinkDeserializable + Debug, + { debug!("NetlinkCodec: decoding next message"); loop { @@ -50,29 +62,7 @@ where // This is a bit hacky because we don't want to keep `src` // borrowed, since we need to mutate it later. let len_res = match NetlinkBuffer::new_checked(src.as_ref()) { - #[cfg(not(feature = "workaround-audit-bug"))] Ok(buf) => Ok(buf.length() as usize), - #[cfg(feature = "workaround-audit-bug")] - Ok(buf) => { - if (src.as_ref().len() as isize - buf.length() as isize) <= 16 { - // The audit messages are sometimes truncated, - // because the length specified in the header, - // does not take the header itself into - // account. To workaround this, we tweak the - // length. We've noticed two occurences of - // truncated packets: - // - // - the length of the header is not included (see also: - // https://github.com/mozilla/libaudit-go/issues/24) - // - some rule message have some padding for alignment (see - // https://github.com/linux-audit/audit-userspace/issues/78) which is not - // taken into account in the buffer length. - warn!("found what looks like a truncated audit packet"); - Ok(src.as_ref().len()) - } else { - Ok(buf.length() as usize) - } - } Err(e) => { // We either received a truncated packet, or the // packet if malformed (invalid length field). In @@ -93,35 +83,6 @@ where let len = len_res.unwrap(); - #[cfg(feature = "workaround-audit-bug")] - let bytes = { - let mut bytes = src.split_to(len); - { - let mut buf = NetlinkBuffer::new(bytes.as_mut()); - // If the buffer contains more bytes than what the header says the length is, it - // means we ran into a malformed packet (see comment above), and we just set the - // "right" length ourself, so that parsing does not fail. - // - // How do we know that's the right length? Due to an implementation detail and to - // the fact that netlink is a datagram protocol. - // - // - our implementation of Stream always calls the codec with at most 1 message in - // the buffer, so we know the extra bytes do not belong to another message. - // - because netlink is a datagram protocol, we receive entire messages, so we know - // that if those extra bytes do not belong to another message, they belong to - // this one. - if len != buf.length() as usize { - warn!( - "setting packet length to {} instead of {}", - len, - buf.length() - ); - buf.set_length(len as u32); - } - } - bytes - }; - #[cfg(not(feature = "workaround-audit-bug"))] let bytes = src.split_to(len); let parsed = NetlinkMessage::::deserialize(&bytes); @@ -137,15 +98,11 @@ where } } } -} - -impl Encoder> for NetlinkCodec> -where - T: Debug + NetlinkSerializable, -{ - type Error = io::Error; - fn encode(&mut self, msg: NetlinkMessage, buf: &mut BytesMut) -> Result<(), Self::Error> { + fn encode(msg: NetlinkMessage, buf: &mut BytesMut) -> io::Result<()> + where + T: Debug + NetlinkSerializable, + { let msg_len = msg.buffer_len(); if buf.remaining_mut() < msg_len { // BytesMut can expand till usize::MAX... unlikely to hit this one. diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index f14df05f..a3418ef0 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -20,7 +20,7 @@ use netlink_packet_core::{ }; use crate::{ - codecs::NetlinkCodec, + codecs::{NetlinkCodec, NetlinkMessageCodec}, framed::NetlinkFramed, sys::{Socket, SocketAddr}, Protocol, @@ -32,11 +32,11 @@ use crate::{ /// /// [`ConnectionHandle`](struct.ConnectionHandle.html) are used to pass new requests to the /// `Connection`, that in turn, sends them through the netlink socket. -pub struct Connection +pub struct Connection where T: Debug + NetlinkSerializable + NetlinkDeserializable, { - socket: NetlinkFramed>>, + socket: NetlinkFramed, protocol: Protocol>>, @@ -50,9 +50,10 @@ where socket_closed: bool, } -impl Connection +impl Connection where T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, + C: NetlinkMessageCodec, { pub(crate) fn new( requests_rx: UnboundedReceiver>, @@ -61,7 +62,7 @@ where ) -> io::Result { let socket = Socket::new(protocol)?; Ok(Connection { - socket: NetlinkFramed::new(socket, NetlinkCodec::>::new()), + socket: NetlinkFramed::new(socket), protocol: Protocol::new(), requests_rx: Some(requests_rx), unsolicited_messages_tx: Some(unsolicited_messages_tx), @@ -250,9 +251,10 @@ where } } -impl Future for Connection +impl Future for Connection where T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, + C: NetlinkMessageCodec, { type Output = (); diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index f7aadf0a..af1cd596 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -1,6 +1,8 @@ use bytes::{BufMut, BytesMut}; use std::{ + fmt::Debug, io, + marker::PhantomData, pin::Pin, slice, task::{Context, Poll}, @@ -8,13 +10,17 @@ use std::{ use futures::{Sink, Stream}; use log::error; -use tokio_util::codec::{Decoder, Encoder}; -use crate::sys::{Socket, SocketAddr}; +use crate::{ + codecs::NetlinkMessageCodec, + sys::{Socket, SocketAddr}, +}; +use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable}; -pub struct NetlinkFramed { +pub struct NetlinkFramed { socket: Socket, - codec: C, + msg_type: PhantomData T>, // invariant + codec: PhantomData C>, // invariant reader: BytesMut, writer: BytesMut, in_addr: SocketAddr, @@ -22,16 +28,15 @@ pub struct NetlinkFramed { flushed: bool, } -impl Stream for NetlinkFramed +impl Stream for NetlinkFramed where - C: Decoder + Unpin, - C::Error: std::error::Error, + T: NetlinkDeserializable + Debug, + C: NetlinkMessageCodec, { - type Item = (C::Item, SocketAddr); + type Item = (NetlinkMessage, SocketAddr); fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Self { - ref mut codec, ref mut socket, ref mut in_addr, ref mut reader, @@ -39,7 +44,7 @@ where } = Pin::get_mut(self); loop { - match codec.decode(reader) { + match C::decode::(reader) { Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))), Ok(None) => {} Err(e) => { @@ -74,8 +79,12 @@ where } } -impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed { - type Error = C::Error; +impl Sink<(NetlinkMessage, SocketAddr)> for NetlinkFramed +where + T: NetlinkSerializable + Debug, + C: NetlinkMessageCodec, +{ + type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if !self.flushed { @@ -88,11 +97,14 @@ impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed< Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, item: (Item, SocketAddr)) -> Result<(), Self::Error> { + fn start_send( + self: Pin<&mut Self>, + item: (NetlinkMessage, SocketAddr), + ) -> Result<(), Self::Error> { trace!("sending frame"); let (frame, out_addr) = item; let pin = self.get_mut(); - pin.codec.encode(frame, &mut pin.writer)?; + C::encode(frame, &mut pin.writer)?; pin.out_addr = out_addr; pin.flushed = false; trace!("frame encoded; length={}", pin.writer.len()); @@ -143,14 +155,15 @@ impl + Unpin, Item> Sink<(Item, SocketAddr)> for NetlinkFramed< const INITIAL_READER_CAPACITY: usize = 64 * 1024; const INITIAL_WRITER_CAPACITY: usize = 8 * 1024; -impl NetlinkFramed { +impl NetlinkFramed { /// Create a new `NetlinkFramed` backed by the given socket and codec. /// /// See struct level documentation for more details. - pub fn new(socket: Socket, codec: C) -> NetlinkFramed { - NetlinkFramed { + pub fn new(socket: Socket) -> Self { + Self { socket, - codec, + msg_type: PhantomData, + codec: PhantomData, out_addr: SocketAddr::new(0, 0), in_addr: SocketAddr::new(0, 0), reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY), diff --git a/netlink-proto/src/lib.rs b/netlink-proto/src/lib.rs index 1b18abae..e2be8b21 100644 --- a/netlink-proto/src/lib.rs +++ b/netlink-proto/src/lib.rs @@ -228,6 +228,22 @@ pub fn new_connection( )> where T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, +{ + new_connection_with_codec(protocol) +} + +/// Variant of [`new_connection`] that allows specifying a separate codec +#[allow(clippy::type_complexity)] +pub fn new_connection_with_codec( + protocol: isize, +) -> io::Result<( + Connection, + ConnectionHandle, + UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, +)> +where + T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, + C: NetlinkMessageCodec, { let (requests_tx, requests_rx) = unbounded::>(); let (messages_tx, messages_rx) = unbounded::<(packet::NetlinkMessage, sys::SocketAddr)>(); From eda45967026d63cb6102f50aa1c9ed020f8cb45c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Tue, 16 Nov 2021 00:06:57 +0100 Subject: [PATCH 07/11] fixup codec --- netlink-proto/src/framed.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index af1cd596..a72b2ac0 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -19,6 +19,9 @@ use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializ pub struct NetlinkFramed { socket: Socket, + // see https://doc.rust-lang.org/nomicon/phantom-data.html + // "invariant" seems like the safe choice; using `fn(T) -> T` + // should make it invariant but still Send+Sync. msg_type: PhantomData T>, // invariant codec: PhantomData C>, // invariant reader: BytesMut, From 7e6cfd743bf822e917e260eb24fbf5b2c541922e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Wed, 10 Nov 2021 20:55:16 +0100 Subject: [PATCH 08/11] Fix UB from unsafe uninitialized buffer slices Creating &mut [u8] (and &[u8]) for unitialized memory is undefined behaviour even when not actually reading the data. Use bytes::BufMut instead, and advance buffer in recv functions. High level functions (without flags param) don't need to return length of read, as it was used to advance the buffer - only low-level read might return larger length than the buffer (PEEK + TRUNC flags). This also bumps netlink-sys; the other crates already got bumped. --- ethtool/Cargo.toml | 2 +- netlink-packet-generic/Cargo.toml | 2 +- .../tests/query_family_id.rs | 3 +- netlink-packet-route/Cargo.toml | 2 +- netlink-packet-route/examples/dump_links.rs | 2 +- .../examples/dump_neighbours.rs | 2 +- netlink-packet-route/examples/dump_rules.rs | 2 +- netlink-packet-sock-diag/Cargo.toml | 2 +- .../examples/dump_ipv4.rs | 2 +- netlink-proto/Cargo.toml | 2 +- netlink-proto/src/framed.rs | 25 ++----- netlink-sys/Cargo.toml | 3 +- .../examples/audit_events_async_std.rs | 10 +-- netlink-sys/examples/audit_events_tokio.rs | 10 +-- netlink-sys/src/smol.rs | 29 ++++++--- netlink-sys/src/socket.rs | 65 ++++++++++++------- netlink-sys/src/tokio.rs | 23 +++++-- 17 files changed, 110 insertions(+), 76 deletions(-) diff --git a/ethtool/Cargo.toml b/ethtool/Cargo.toml index b5b9fb00..f6a92ad5 100644 --- a/ethtool/Cargo.toml +++ b/ethtool/Cargo.toml @@ -30,7 +30,7 @@ netlink-packet-core = "0.3.0" netlink-packet-generic = "0.2.0" netlink-packet-utils = "0.4.1" netlink-proto = { default-features = false, version = "0.8.0" } -netlink-sys = "0.7.0" +netlink-sys = "0.8.0" thiserror = "1.0.29" tokio = { version = "1.0.1", features = ["rt"], optional = true} diff --git a/netlink-packet-generic/Cargo.toml b/netlink-packet-generic/Cargo.toml index 39af2717..bfbbce79 100644 --- a/netlink-packet-generic/Cargo.toml +++ b/netlink-packet-generic/Cargo.toml @@ -18,4 +18,4 @@ netlink-packet-core = "0.3" netlink-packet-utils = "0.4" [dev-dependencies] -netlink-sys = { path = "../netlink-sys", version = "0.7" } +netlink-sys = { path = "../netlink-sys", version = "0.8" } diff --git a/netlink-packet-generic/tests/query_family_id.rs b/netlink-packet-generic/tests/query_family_id.rs index a74eadf4..da639eae 100644 --- a/netlink-packet-generic/tests/query_family_id.rs +++ b/netlink-packet-generic/tests/query_family_id.rs @@ -26,8 +26,7 @@ fn query_family_id() { socket.send(&txbuf, 0).unwrap(); - let mut rxbuf = vec![0u8; 2048]; - socket.recv(&mut rxbuf, 0).unwrap(); + let (rxbuf, _addr) = socket.recv_from_full().unwrap(); let rx_packet = >>::deserialize(&rxbuf).unwrap(); if let NetlinkPayload::InnerMessage(genlmsg) = rx_packet.payload { diff --git a/netlink-packet-route/Cargo.toml b/netlink-packet-route/Cargo.toml index fa8e8ff7..d47c3020 100644 --- a/netlink-packet-route/Cargo.toml +++ b/netlink-packet-route/Cargo.toml @@ -26,7 +26,7 @@ name = "dump_links" criterion = "0.3.0" pcap-file = "1.1.1" lazy_static = "1.4.0" -netlink-sys = "0.7" +netlink-sys = "0.8" [[bench]] name = "link_message" diff --git a/netlink-packet-route/examples/dump_links.rs b/netlink-packet-route/examples/dump_links.rs index cfcaa1d0..3781445f 100644 --- a/netlink-packet-route/examples/dump_links.rs +++ b/netlink-packet-route/examples/dump_links.rs @@ -37,7 +37,7 @@ fn main() { // we set the NLM_F_DUMP flag so we expect a multipart rx_packet in response. loop { - let size = socket.recv(&mut receive_buffer[..], 0).unwrap(); + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); loop { let bytes = &receive_buffer[offset..]; diff --git a/netlink-packet-route/examples/dump_neighbours.rs b/netlink-packet-route/examples/dump_neighbours.rs index 89c06b20..3816824a 100644 --- a/netlink-packet-route/examples/dump_neighbours.rs +++ b/netlink-packet-route/examples/dump_neighbours.rs @@ -38,7 +38,7 @@ fn main() { let mut offset = 0; 'outer: loop { - let size = socket.recv(&mut receive_buffer[..], 0).unwrap(); + let size = socket.recv(&mut &mut receive_buffer[..], 0).unwrap(); loop { let bytes = &receive_buffer[offset..]; diff --git a/netlink-packet-route/examples/dump_rules.rs b/netlink-packet-route/examples/dump_rules.rs index 2b3566e5..a239542d 100644 --- a/netlink-packet-route/examples/dump_rules.rs +++ b/netlink-packet-route/examples/dump_rules.rs @@ -41,7 +41,7 @@ fn main() { let mut offset = 0; // we set the NLM_F_DUMP flag so we expect a multipart rx_packet in response. - while let Ok(size) = socket.recv(&mut receive_buffer[..], 0) { + while let Ok(size) = socket.recv(&mut &mut receive_buffer[..], 0) { loop { let bytes = &receive_buffer[offset..]; let rx_packet = >::deserialize(bytes).unwrap(); diff --git a/netlink-packet-sock-diag/Cargo.toml b/netlink-packet-sock-diag/Cargo.toml index 0e06a448..dfd52347 100644 --- a/netlink-packet-sock-diag/Cargo.toml +++ b/netlink-packet-sock-diag/Cargo.toml @@ -22,4 +22,4 @@ smallvec = "1.4.2" [dev-dependencies] lazy_static = "1.4.0" -netlink-sys = "0.7" +netlink-sys = "0.8" diff --git a/netlink-packet-sock-diag/examples/dump_ipv4.rs b/netlink-packet-sock-diag/examples/dump_ipv4.rs index fc9342e8..d4f8287e 100644 --- a/netlink-packet-sock-diag/examples/dump_ipv4.rs +++ b/netlink-packet-sock-diag/examples/dump_ipv4.rs @@ -46,7 +46,7 @@ fn main() { let mut receive_buffer = vec![0; 4096]; let mut offset = 0; - while let Ok(size) = socket.recv(&mut receive_buffer[..], 0) { + while let Ok(size) = socket.recv(&mut &mut receive_buffer[..], 0) { loop { let bytes = &receive_buffer[offset..]; let rx_packet = >::deserialize(bytes).unwrap(); diff --git a/netlink-proto/Cargo.toml b/netlink-proto/Cargo.toml index 74da076b..f8d4f3aa 100644 --- a/netlink-proto/Cargo.toml +++ b/netlink-proto/Cargo.toml @@ -17,7 +17,7 @@ log = "0.4.8" futures = "0.3" tokio = { version = "1.0", default-features = false, features = ["io-util"] } netlink-packet-core = "0.3" -netlink-sys = { default-features = false, version = "0.7" } +netlink-sys = { default-features = false, version = "0.8" } [features] default = ["tokio_socket"] diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index a72b2ac0..872e1491 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -1,10 +1,9 @@ -use bytes::{BufMut, BytesMut}; +use bytes::BytesMut; use std::{ fmt::Debug, io, marker::PhantomData, pin::Pin, - slice, task::{Context, Poll}, }; @@ -59,23 +58,11 @@ where reader.clear(); reader.reserve(INITIAL_READER_CAPACITY); - *in_addr = unsafe { - // Read into the buffer without having to initialize the memory. - // - // safety: we know poll_recv_from never reads from the - // memory during a recv so it's fine to turn &mut - // [>] into &mut[u8] - let bytes = reader.chunk_mut(); - let bytes = slice::from_raw_parts_mut(bytes.as_mut_ptr(), bytes.len()); - match ready!(socket.poll_recv_from(cx, bytes)) { - Ok((n, addr)) => { - reader.advance_mut(n); - addr - } - Err(e) => { - error!("failed to read from netlink socket: {:?}", e); - return Poll::Ready(None); - } + *in_addr = match ready!(socket.poll_recv_from(cx, reader)) { + Ok(addr) => addr, + Err(e) => { + error!("failed to read from netlink socket: {:?}", e); + return Poll::Ready(None); } }; } diff --git a/netlink-sys/Cargo.toml b/netlink-sys/Cargo.toml index 579044eb..a8e59b4d 100644 --- a/netlink-sys/Cargo.toml +++ b/netlink-sys/Cargo.toml @@ -1,7 +1,7 @@ [package] authors = ["Corentin Henry "] name = "netlink-sys" -version = "0.7.0" +version = "0.8.0" # TODO: drop this comment - already bumped version for trait changes edition = "2018" homepage = "https://github.com/little-dude/netlink" @@ -12,6 +12,7 @@ repository = "https://github.com/little-dude/netlink" description = "netlink sockets, with optional integration with tokio" [dependencies] +bytes = "1.0" libc = "0.2.66" log = "0.4.8" diff --git a/netlink-sys/examples/audit_events_async_std.rs b/netlink-sys/examples/audit_events_async_std.rs index fde12fb5..db318c32 100644 --- a/netlink-sys/examples/audit_events_async_std.rs +++ b/netlink-sys/examples/audit_events_async_std.rs @@ -50,20 +50,22 @@ async fn main() { .await .unwrap(); - let mut buf = vec![0; 1024 * 8]; + let mut buf = bytes::BytesMut::with_capacity(1024 * 8); loop { - let (n, _addr) = socket.recv_from(&mut buf).await.unwrap(); + buf.clear(); + let _addr = socket.recv_from(&mut buf).await.unwrap(); // This dance with the NetlinkBuffer should not be // necessary. It is here to work around a netlink bug. See: // https://github.com/mozilla/libaudit-go/issues/24 // https://github.com/linux-audit/audit-userspace/issues/78 { - let mut nl_buf = NetlinkBuffer::new(&mut buf[0..n]); + let n = buf.len(); + let mut nl_buf = NetlinkBuffer::new(&mut buf); if n != nl_buf.length() as usize { nl_buf.set_length(n as u32); } } - let parsed = NetlinkMessage::::deserialize(&buf[0..n]).unwrap(); + let parsed = NetlinkMessage::::deserialize(&buf).unwrap(); println!("<<< {:?}", parsed); } } diff --git a/netlink-sys/examples/audit_events_tokio.rs b/netlink-sys/examples/audit_events_tokio.rs index afdd5f74..5f121c45 100644 --- a/netlink-sys/examples/audit_events_tokio.rs +++ b/netlink-sys/examples/audit_events_tokio.rs @@ -50,20 +50,22 @@ async fn main() -> Result<(), Box> { .await .unwrap(); - let mut buf = vec![0; 1024 * 8]; + let mut buf = bytes::BytesMut::with_capacity(1024 * 8); loop { - let (n, _addr) = socket.recv_from(&mut buf).await.unwrap(); + buf.clear(); + let _addr = socket.recv_from(&mut buf).await.unwrap(); // This dance with the NetlinkBuffer should not be // necessary. It is here to work around a netlink bug. See: // https://github.com/mozilla/libaudit-go/issues/24 // https://github.com/linux-audit/audit-userspace/issues/78 { - let mut nl_buf = NetlinkBuffer::new(&mut buf[0..n]); + let n = buf.len(); + let mut nl_buf = NetlinkBuffer::new(&mut buf); if n != nl_buf.length() as usize { nl_buf.set_length(n as u32); } } - let parsed = NetlinkMessage::::deserialize(&buf[0..n]).unwrap(); + let parsed = NetlinkMessage::::deserialize(&buf).unwrap(); println!("<<< {:?}", parsed); } } diff --git a/netlink-sys/src/smol.rs b/netlink-sys/src/smol.rs index 6ef2b5e0..1647b38a 100644 --- a/netlink-sys/src/smol.rs +++ b/netlink-sys/src/smol.rs @@ -43,23 +43,36 @@ impl SmolSocket { .await } - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read_with_mut(|sock| sock.recv(buf, 0)).await + pub async fn recv(&mut self, buf: &mut B) -> io::Result<()> + where + B: bytes::BufMut, + { + self.0 + .read_with_mut(|sock| sock.recv(buf, 0).map(|_len| ())) + .await } - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { - self.0.read_with_mut(|sock| sock.recv_from(buf, 0)).await + pub async fn recv_from(&mut self, buf: &mut B) -> io::Result + where + B: bytes::BufMut, + { + self.0 + .read_with_mut(|sock| sock.recv_from(buf, 0).map(|(_len, addr)| addr)) + .await } pub async fn recv_from_full(&mut self) -> io::Result<(Vec, SocketAddr)> { self.0.read_with_mut(|sock| sock.recv_from_full()).await } - pub fn poll_recv_from( + pub fn poll_recv_from( &mut self, cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut, + { loop { trace!("poll_recv_from called"); let _guard = ready!(self.0.poll_readable(cx))?; @@ -72,7 +85,7 @@ impl SmolSocket { } x => { trace!("poll_recv_from {:?} bytes read", x); - return Poll::Ready(x); + return Poll::Ready(x.map(|(_len, addr)| addr)); } } } diff --git a/netlink-sys/src/socket.rs b/netlink-sys/src/socket.rs index bdc53cbe..2c9b3c86 100644 --- a/netlink-sys/src/socket.rs +++ b/netlink-sys/src/socket.rs @@ -36,7 +36,7 @@ use crate::SocketAddr; /// let mut buf = vec![0; 4096]; /// loop { /// // receive a datagram -/// let (n_received, sender_addr) = socket.recv_from(&mut buf[..], 0).unwrap(); +/// let (n_received, sender_addr) = socket.recv_from(&mut &mut buf[..], 0).unwrap(); /// assert_eq!(sender_addr, kernel_addr); /// println!("received datagram {:?}", &buf[..n_received]); /// if buf[4] == 2 && buf[5] == 0 { @@ -165,7 +165,7 @@ impl Socket { /// // buffer for receiving the response /// let mut buf = vec![0; 4096]; /// loop { - /// let mut n_received = socket.recv(&mut buf[..], 0).unwrap(); + /// let mut n_received = socket.recv(&mut &mut buf[..], 0).unwrap(); /// println!("received {:?}", &buf[..n_received]); /// if buf[4] == 2 && buf[5] == 0 { /// println!("the kernel responded with an error"); @@ -224,7 +224,10 @@ impl Socket { /// In datagram oriented protocols, `recv` and `recvfrom` receive normally only ONE datagram, but this seems not to /// be always true for netlink sockets: with some protocols like `NETLINK_AUDIT`, multiple netlink packets can be /// read with a single call. - pub fn recv_from(&self, buf: &mut [u8], flags: libc::c_int) -> Result<(usize, SocketAddr)> { + pub fn recv_from(&self, buf: &mut B, flags: libc::c_int) -> Result<(usize, SocketAddr)> + where + B: bytes::BufMut, + { // Create an empty storage for the address. Note that Rust standard library create a // sockaddr_storage so that it works for any address family, but here, we already know that // we'll have a Netlink address, so we can create the appropriate storage. @@ -252,34 +255,51 @@ impl Socket { // a pointer to it. let addrlen_ptr = &mut addrlen as *mut usize as *mut libc::socklen_t; - // Cast the *mut u8 into *mut void. - // This is equivalent to casting a *char into *void - // See [thread] - // ^ - // Create a *mut u8 | - // ^ | - // | | - // +-----+-----+ +--------+-------+ - // / \ / \ - let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void; - let buf_len = buf.len() as libc::size_t; + let chunk = buf.chunk_mut(); + // Cast the *mut u8 into *mut void. + // This is equivalent to casting a *char into *void + // See [thread] + // ^ + // Create a *mut u8 | + // ^ | + // | | + // +------+-------+ +--------+-------+ + // / \ / \ + let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void; + let buf_len = chunk.len() as libc::size_t; let res = unsafe { libc::recvfrom(self.0, buf_ptr, buf_len, flags, addr_ptr, addrlen_ptr) }; if res < 0 { return Err(Error::last_os_error()); + } else { + // with `MSG_TRUNC` `res` might exceed `buf_len` + let written = std::cmp::min(buf_len, res as usize); + unsafe { + buf.advance_mut(written); + } } Ok((res as usize, SocketAddr(addr))) } /// For a connected socket, `recv` reads a datagram from the socket. The sender is the remote peer the socket is /// connected to (see [`Socket::connect`]). See also [`Socket::recv_from`] - pub fn recv(&self, buf: &mut [u8], flags: libc::c_int) -> Result { - let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void; - let buf_len = buf.len() as libc::size_t; + pub fn recv(&self, buf: &mut B, flags: libc::c_int) -> Result + where + B: bytes::BufMut, + { + let chunk = buf.chunk_mut(); + let buf_ptr = chunk.as_mut_ptr() as *mut libc::c_void; + let buf_len = chunk.len() as libc::size_t; let res = unsafe { libc::recv(self.0, buf_ptr, buf_len, flags) }; if res < 0 { return Err(Error::last_os_error()); + } else { + // with `MSG_TRUNC` `res` might exceed `buf_len` + let written = std::cmp::min(buf_len, res as usize); + unsafe { + buf.advance_mut(written); + } } Ok(res as usize) } @@ -288,13 +308,14 @@ impl Socket { /// buffer passed as argument, this method always reads a whole message, no matter its size. pub fn recv_from_full(&self) -> Result<(Vec, SocketAddr)> { // Peek - let mut buf = Vec::::new(); - let (rlen, _) = self.recv_from(&mut buf, libc::MSG_PEEK | libc::MSG_TRUNC)?; + let mut buf: Vec = Vec::new(); + let (peek_len, _) = self.recv_from(&mut buf, libc::MSG_PEEK | libc::MSG_TRUNC)?; // Receive - let mut buf = vec![0; rlen as usize]; - let (_, addr) = self.recv_from(&mut buf, 0)?; - + buf.clear(); + buf.reserve(peek_len); + let (rlen, addr) = self.recv_from(&mut buf, 0)?; + assert_eq!(rlen, peek_len); Ok((buf, addr)) } diff --git a/netlink-sys/src/tokio.rs b/netlink-sys/src/tokio.rs index a83fe93d..c4eebae3 100644 --- a/netlink-sys/src/tokio.rs +++ b/netlink-sys/src/tokio.rs @@ -71,7 +71,10 @@ impl TokioSocket { } } - pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result { + pub async fn recv(&mut self, buf: &mut B) -> io::Result<()> + where + B: bytes::BufMut, + { poll_fn(|cx| loop { // Check if the socket is readable. If not, // AsyncFd::poll_read_ready would have arranged for the @@ -80,14 +83,17 @@ impl TokioSocket { let mut guard = ready!(self.0.poll_read_ready(cx))?; match guard.try_io(|inner| inner.get_ref().recv(buf, 0)) { - Ok(x) => return Poll::Ready(x), + Ok(x) => return Poll::Ready(x.map(|_len| ())), Err(_would_block) => continue, } }) .await } - pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { + pub async fn recv_from(&mut self, buf: &mut B) -> io::Result + where + B: bytes::BufMut, + { poll_fn(|cx| self.poll_recv_from(cx, buf)).await } @@ -95,11 +101,14 @@ impl TokioSocket { poll_fn(|cx| self.poll_recv_from_full(cx)).await } - pub fn poll_recv_from( + pub fn poll_recv_from( &mut self, cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut, + { loop { trace!("poll_recv_from called"); let mut guard = ready!(self.0.poll_read_ready(cx))?; @@ -108,7 +117,7 @@ impl TokioSocket { match guard.try_io(|inner| inner.get_ref().recv_from(buf, 0)) { Ok(x) => { trace!("poll_recv_from {:?} bytes read", x); - return Poll::Ready(x); + return Poll::Ready(x.map(|(_len, addr)| addr)); } Err(_would_block) => { trace!("poll_recv_from socket would block"); From 1ea0da5ee50587352ce845e88b184729d55c26c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 14 Nov 2021 17:09:42 +0100 Subject: [PATCH 09/11] Add AsyncSocket trait in netlink-sys Preparation to make tokio and smol feature non-conflicting in netlink-proto. --- netlink-proto/src/connection.rs | 2 +- netlink-proto/src/framed.rs | 2 +- netlink-proto/src/lib.rs | 2 +- .../examples/audit_events_async_std.rs | 2 +- netlink-sys/examples/audit_events_tokio.rs | 2 +- ...udit_events_tokio_manual_thread_builder.rs | 2 +- netlink-sys/src/async_socket.rs | 55 +++++ netlink-sys/src/async_socket_ext.rs | 141 ++++++++++++ netlink-sys/src/lib.rs | 6 + netlink-sys/src/smol.rs | 212 ++++++------------ netlink-sys/src/tokio.rs | 158 +++---------- rtnetlink/examples/ip_monitor.rs | 10 +- rtnetlink/examples/listen.rs | 8 +- 13 files changed, 331 insertions(+), 271 deletions(-) create mode 100644 netlink-sys/src/async_socket.rs create mode 100644 netlink-sys/src/async_socket_ext.rs diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index a3418ef0..24311b25 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -22,7 +22,7 @@ use netlink_packet_core::{ use crate::{ codecs::{NetlinkCodec, NetlinkMessageCodec}, framed::NetlinkFramed, - sys::{Socket, SocketAddr}, + sys::{AsyncSocket, Socket, SocketAddr}, Protocol, Request, Response, diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index 872e1491..df6a5f7f 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -12,7 +12,7 @@ use log::error; use crate::{ codecs::NetlinkMessageCodec, - sys::{Socket, SocketAddr}, + sys::{AsyncSocket, Socket, SocketAddr}, }; use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable}; diff --git a/netlink-proto/src/lib.rs b/netlink-proto/src/lib.rs index e2be8b21..5777dd92 100644 --- a/netlink-proto/src/lib.rs +++ b/netlink-proto/src/lib.rs @@ -192,7 +192,7 @@ use std::{fmt::Debug, io}; pub use netlink_packet_core as packet; pub mod sys { - pub use netlink_sys::{protocols, SocketAddr}; + pub use netlink_sys::{protocols, AsyncSocket, AsyncSocketExt, SocketAddr}; #[cfg(feature = "tokio_socket")] pub use netlink_sys::TokioSocket as Socket; diff --git a/netlink-sys/examples/audit_events_async_std.rs b/netlink-sys/examples/audit_events_async_std.rs index db318c32..55227764 100644 --- a/netlink-sys/examples/audit_events_async_std.rs +++ b/netlink-sys/examples/audit_events_async_std.rs @@ -22,7 +22,7 @@ use netlink_packet_audit::{ NLM_F_REQUEST, }; -use netlink_sys::{protocols::NETLINK_AUDIT, SmolSocket, SocketAddr}; +use netlink_sys::{protocols::NETLINK_AUDIT, AsyncSocket, AsyncSocketExt, SmolSocket, SocketAddr}; const AUDIT_STATUS_ENABLED: u32 = 1; const AUDIT_STATUS_PID: u32 = 4; diff --git a/netlink-sys/examples/audit_events_tokio.rs b/netlink-sys/examples/audit_events_tokio.rs index 5f121c45..44df4085 100644 --- a/netlink-sys/examples/audit_events_tokio.rs +++ b/netlink-sys/examples/audit_events_tokio.rs @@ -22,7 +22,7 @@ use netlink_packet_audit::{ NLM_F_REQUEST, }; -use netlink_sys::{protocols::NETLINK_AUDIT, SocketAddr, TokioSocket}; +use netlink_sys::{protocols::NETLINK_AUDIT, AsyncSocket, AsyncSocketExt, SocketAddr, TokioSocket}; const AUDIT_STATUS_ENABLED: u32 = 1; const AUDIT_STATUS_PID: u32 = 4; diff --git a/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs b/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs index 3d7e04a6..f0a0b87b 100644 --- a/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs +++ b/netlink-sys/examples/audit_events_tokio_manual_thread_builder.rs @@ -3,7 +3,7 @@ * to use netlink. */ -use netlink_sys::{protocols::NETLINK_AUDIT, TokioSocket}; +use netlink_sys::{protocols::NETLINK_AUDIT, AsyncSocket, TokioSocket}; fn main() -> Result<(), String> { let rt = tokio::runtime::Builder::new_multi_thread() diff --git a/netlink-sys/src/async_socket.rs b/netlink-sys/src/async_socket.rs new file mode 100644 index 00000000..3e15a0d6 --- /dev/null +++ b/netlink-sys/src/async_socket.rs @@ -0,0 +1,55 @@ +use std::{ + io, + task::{Context, Poll}, +}; + +use crate::{Socket, SocketAddr}; + +/// Trait to support different async backends +pub trait AsyncSocket: Sized + Unpin { + /// Access underyling [`Socket`] + fn socket_ref(&self) -> &Socket; + + /// Mutable access to underyling [`Socket`] + fn socket_mut(&mut self) -> &mut Socket; + + /// Wrapper for [`Socket::new`] + fn new(protocol: isize) -> io::Result; + + /// Polling wrapper for [`Socket::send`] + fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll>; + + /// Polling wrapper for [`Socket::send_to`] + fn poll_send_to( + &mut self, + cx: &mut Context<'_>, + buf: &[u8], + addr: &SocketAddr, + ) -> Poll>; + + /// Polling wrapper for [`Socket::recv`] + /// + /// Passes 0 for flags, and ignores the returned length (the buffer will have advanced by the amount read). + fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut B) -> Poll> + where + B: bytes::BufMut; + + /// Polling wrapper for [`Socket::recv_from`] + /// + /// Passes 0 for flags, and ignores the returned length - just returns the address (the buffer will have advanced by the amount read). + fn poll_recv_from( + &mut self, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut; + + /// Polling wrapper for [`Socket::recv_from_full`] + /// + /// Passes 0 for flags, and ignores the returned length - just returns the address (the buffer will have advanced by the amount read). + fn poll_recv_from_full( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, SocketAddr)>>; +} diff --git a/netlink-sys/src/async_socket_ext.rs b/netlink-sys/src/async_socket_ext.rs new file mode 100644 index 00000000..8ed86502 --- /dev/null +++ b/netlink-sys/src/async_socket_ext.rs @@ -0,0 +1,141 @@ +use std::{ + future::Future, + io, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{AsyncSocket, SocketAddr}; + +/// Support trait for [`AsyncSocket`] +/// +/// Provides awaitable variants of the poll functions from [`AsyncSocket`]. +pub trait AsyncSocketExt: AsyncSocket { + /// `async fn send(&mut self, buf: &[u8]) -> io::Result` + fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> PollSend<'a, 'b, Self> { + PollSend { socket: self, buf } + } + + /// `async fn send(&mut self, buf: &[u8]) -> io::Result` + fn send_to<'a, 'b>( + &'a mut self, + buf: &'b [u8], + addr: &'b SocketAddr, + ) -> PollSendTo<'a, 'b, Self> { + PollSendTo { + socket: self, + buf, + addr, + } + } + + /// `async fn recv(&mut self, buf: &mut [u8]) -> io::Result<()>` + fn recv<'a, 'b, B>(&'a mut self, buf: &'b mut B) -> PollRecv<'a, 'b, Self, B> + where + B: bytes::BufMut, + { + PollRecv { socket: self, buf } + } + + /// `async fn recv(&mut self, buf: &mut [u8]) -> io::Result` + fn recv_from<'a, 'b, B>(&'a mut self, buf: &'b mut B) -> PollRecvFrom<'a, 'b, Self, B> + where + B: bytes::BufMut, + { + PollRecvFrom { socket: self, buf } + } + + /// `async fn recrecv_from_full(&mut self) -> io::Result<(Vec, SocketAddr)>` + fn recv_from_full(&mut self) -> PollRecvFromFull<'_, Self> { + PollRecvFromFull { socket: self } + } +} + +impl AsyncSocketExt for S {} + +pub struct PollSend<'a, 'b, S> { + socket: &'a mut S, + buf: &'b [u8], +} + +impl Future for PollSend<'_, '_, S> +where + S: AsyncSocket, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_send(cx, this.buf) + } +} + +pub struct PollSendTo<'a, 'b, S> { + socket: &'a mut S, + buf: &'b [u8], + addr: &'b SocketAddr, +} + +impl Future for PollSendTo<'_, '_, S> +where + S: AsyncSocket, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_send_to(cx, this.buf, this.addr) + } +} + +pub struct PollRecv<'a, 'b, S, B> { + socket: &'a mut S, + buf: &'b mut B, +} + +impl Future for PollRecv<'_, '_, S, B> +where + S: AsyncSocket, + B: bytes::BufMut, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_recv(cx, this.buf) + } +} + +pub struct PollRecvFrom<'a, 'b, S, B> { + socket: &'a mut S, + buf: &'b mut B, +} + +impl Future for PollRecvFrom<'_, '_, S, B> +where + S: AsyncSocket, + B: bytes::BufMut, +{ + type Output = io::Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_recv_from(cx, this.buf) + } +} + +pub struct PollRecvFromFull<'a, S> { + socket: &'a mut S, +} + +impl Future for PollRecvFromFull<'_, S> +where + S: AsyncSocket, +{ + type Output = io::Result<(Vec, SocketAddr)>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this: &mut Self = Pin::into_inner(self); + this.socket.poll_recv_from_full(cx) + } +} diff --git a/netlink-sys/src/lib.rs b/netlink-sys/src/lib.rs index 69abdeec..6789c969 100644 --- a/netlink-sys/src/lib.rs +++ b/netlink-sys/src/lib.rs @@ -31,6 +31,12 @@ pub use self::socket::Socket; mod addr; pub use self::addr::SocketAddr; +mod async_socket; +pub use self::async_socket::AsyncSocket; + +pub mod async_socket_ext; +pub use self::async_socket_ext::AsyncSocketExt; + #[cfg(feature = "tokio_socket")] mod tokio; #[cfg(feature = "tokio_socket")] diff --git a/netlink-sys/src/smol.rs b/netlink-sys/src/smol.rs index 1647b38a..e3b67d94 100644 --- a/netlink-sys/src/smol.rs +++ b/netlink-sys/src/smol.rs @@ -10,178 +10,112 @@ use futures::ready; use log::trace; -use crate::{Socket, SocketAddr}; +use crate::{AsyncSocket, Socket, SocketAddr}; /// An I/O object representing a Netlink socket. pub struct SmolSocket(Async); -impl SmolSocket { - pub fn new(protocol: isize) -> io::Result { - let socket = Socket::new(protocol)?; - Ok(SmolSocket(Async::new(socket)?)) - } - - pub fn bind(&mut self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_mut().bind(addr) - } - - pub fn bind_auto(&mut self) -> io::Result { - self.0.get_mut().bind_auto() - } - - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_ref().connect(addr) - } - - pub async fn send(&mut self, buf: &[u8]) -> io::Result { - self.0.write_with_mut(|sock| sock.send(buf, 0)).await - } - - pub async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> io::Result { - self.0 - .write_with_mut(|sock| sock.send_to(buf, addr, 0)) - .await - } - - pub async fn recv(&mut self, buf: &mut B) -> io::Result<()> - where - B: bytes::BufMut, - { - self.0 - .read_with_mut(|sock| sock.recv(buf, 0).map(|_len| ())) - .await - } - - pub async fn recv_from(&mut self, buf: &mut B) -> io::Result - where - B: bytes::BufMut, - { - self.0 - .read_with_mut(|sock| sock.recv_from(buf, 0).map(|(_len, addr)| addr)) - .await +impl FromRawFd for SmolSocket { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + let socket = Socket::from_raw_fd(fd); + socket.set_non_blocking(true).unwrap(); + SmolSocket(Async::new(socket).unwrap()) } +} - pub async fn recv_from_full(&mut self) -> io::Result<(Vec, SocketAddr)> { - self.0.read_with_mut(|sock| sock.recv_from_full()).await +impl AsRawFd for SmolSocket { + fn as_raw_fd(&self) -> RawFd { + self.0.get_ref().as_raw_fd() } +} - pub fn poll_recv_from( - &mut self, - cx: &mut Context, - buf: &mut B, - ) -> Poll> +// async_io::Async<..>::{read,write}_with[_mut] functions try IO first, +// and only register context if it would block. +// replicate this in these poll functions: +impl SmolSocket { + fn poll_write_with(&mut self, cx: &mut Context<'_>, mut op: F) -> Poll> where - B: bytes::BufMut, + F: FnMut(&mut Self) -> io::Result, { loop { - trace!("poll_recv_from called"); - let _guard = ready!(self.0.poll_readable(cx))?; - trace!("poll_recv_from socket is ready for reading"); - - match self.0.get_ref().recv_from(buf, 0) { - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - trace!("poll_recv_from socket would block"); - continue; - } - x => { - trace!("poll_recv_from {:?} bytes read", x); - return Poll::Ready(x.map(|(_len, addr)| addr)); - } + match op(self) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), } + // try again if writable now, otherwise come back later: + ready!(self.0.poll_writable(cx))?; } } - pub fn poll_send_to( - &mut self, - cx: &mut Context, - buf: &[u8], - addr: &SocketAddr, - ) -> Poll> { + fn poll_read_with(&mut self, cx: &mut Context<'_>, mut op: F) -> Poll> + where + F: FnMut(&mut Self) -> io::Result, + { loop { - let _guard = ready!(self.0.poll_writable(cx))?; - - match self.0.get_ref().send_to(buf, addr, 0) { - Ok(x) => return Poll::Ready(Ok(x)), - Err(_would_block) => continue, + match op(self) { + Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} + res => return Poll::Ready(res), } + // try again if readable now, otherwise come back later: + ready!(self.0.poll_readable(cx))?; } } +} - pub fn set_pktinfo(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_pktinfo(value) - } - - pub fn get_pktinfo(&self) -> io::Result { - self.0.get_ref().get_pktinfo() - } - - pub fn add_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().add_membership(group) - } - - pub fn drop_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().drop_membership(group) - } - - // pub fn list_membership(&self) -> Vec { - // self.0.get_ref().list_membership() - // } - - /// `NETLINK_BROADCAST_ERROR` (since Linux 2.6.30). When not set, `netlink_broadcast()` only - /// reports `ESRCH` errors and silently ignore `NOBUFS` errors. - pub fn set_broadcast_error(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_broadcast_error(value) - } - - pub fn get_broadcast_error(&self) -> io::Result { - self.0.get_ref().get_broadcast_error() - } - - /// `NETLINK_NO_ENOBUFS` (since Linux 2.6.30). This flag can be used by unicast and broadcast - /// listeners to avoid receiving `ENOBUFS` errors. - pub fn set_no_enobufs(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_no_enobufs(value) +impl AsyncSocket for SmolSocket { + fn socket_ref(&self) -> &Socket { + self.0.get_ref() } - pub fn get_no_enobufs(&self) -> io::Result { - self.0.get_ref().get_no_enobufs() + /// Mutable access to underyling [`Socket`] + fn socket_mut(&mut self) -> &mut Socket { + self.0.get_mut() } - /// `NETLINK_LISTEN_ALL_NSID` (since Linux 4.2). When set, this socket will receive netlink - /// notifications from all network namespaces that have an nsid assigned into the network - /// namespace where the socket has been opened. The nsid is sent to user space via an ancillary - /// data. - pub fn set_listen_all_namespaces(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_listen_all_namespaces(value) + fn new(protocol: isize) -> io::Result { + let socket = Socket::new(protocol)?; + Ok(Self(Async::new(socket)?)) } - pub fn get_listen_all_namespaces(&self) -> io::Result { - self.0.get_ref().get_listen_all_namespaces() + fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.poll_write_with(cx, |this| this.0.get_mut().send(buf, 0)) } - /// `NETLINK_CAP_ACK` (since Linux 4.2). The kernel may fail to allocate the necessary room - /// for the acknowledgment message back to user space. This option trims off the payload of - /// the original netlink message. The netlink message header is still included, so the user can - /// guess from the sequence number which message triggered the acknowledgment. - pub fn set_cap_ack(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_cap_ack(value) + fn poll_send_to( + &mut self, + cx: &mut Context<'_>, + buf: &[u8], + addr: &SocketAddr, + ) -> Poll> { + self.poll_write_with(cx, |this| this.0.get_mut().send_to(buf, addr, 0)) } - pub fn get_cap_ack(&self) -> io::Result { - self.0.get_ref().get_cap_ack() + fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut B) -> Poll> + where + B: bytes::BufMut, + { + self.poll_read_with(cx, |this| this.0.get_mut().recv(buf, 0).map(|_len| ())) } -} -impl FromRawFd for SmolSocket { - unsafe fn from_raw_fd(fd: RawFd) -> Self { - let socket = Socket::from_raw_fd(fd); - socket.set_non_blocking(true).unwrap(); - SmolSocket(Async::new(socket).unwrap()) + fn poll_recv_from( + &mut self, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut, + { + self.poll_read_with(cx, |this| { + let x = this.0.get_mut().recv_from(buf, 0); + trace!("poll_recv_from: {:?}", x); + x.map(|(_len, addr)| addr) + }) } -} -impl AsRawFd for SmolSocket { - fn as_raw_fd(&self) -> RawFd { - self.0.get_ref().as_raw_fd() + fn poll_recv_from_full( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, SocketAddr)>> { + self.poll_read_with(cx, |this| this.0.get_mut().recv_from_full()) } } diff --git a/netlink-sys/src/tokio.rs b/netlink-sys/src/tokio.rs index c4eebae3..9794fc9b 100644 --- a/netlink-sys/src/tokio.rs +++ b/netlink-sys/src/tokio.rs @@ -4,38 +4,47 @@ use std::{ task::{Context, Poll}, }; -use futures::{future::poll_fn, ready}; +use futures::ready; use log::trace; use tokio::io::unix::AsyncFd; -use crate::{Socket, SocketAddr}; +use crate::{AsyncSocket, Socket, SocketAddr}; /// An I/O object representing a Netlink socket. pub struct TokioSocket(AsyncFd); -impl TokioSocket { - /// This function will create a new Netlink socket and attempt to bind it to - /// the `addr` provided. - pub fn bind(&mut self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_mut().bind(addr) +impl FromRawFd for TokioSocket { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + let socket = Socket::from_raw_fd(fd); + socket.set_non_blocking(true).unwrap(); + TokioSocket(AsyncFd::new(socket).unwrap()) } +} - pub fn bind_auto(&mut self) -> io::Result { - self.0.get_mut().bind_auto() +impl AsRawFd for TokioSocket { + fn as_raw_fd(&self) -> RawFd { + self.0.get_ref().as_raw_fd() } +} - pub fn new(protocol: isize) -> io::Result { - let socket = Socket::new(protocol)?; - socket.set_non_blocking(true)?; - Ok(TokioSocket(AsyncFd::new(socket)?)) +impl AsyncSocket for TokioSocket { + fn socket_ref(&self) -> &Socket { + self.0.get_ref() + } + + /// Mutable access to underyling [`Socket`] + fn socket_mut(&mut self) -> &mut Socket { + self.0.get_mut() } - pub fn connect(&self, addr: &SocketAddr) -> io::Result<()> { - self.0.get_ref().connect(addr) + fn new(protocol: isize) -> io::Result { + let socket = Socket::new(protocol)?; + socket.set_non_blocking(true)?; + Ok(Self(AsyncFd::new(socket)?)) } - pub async fn send(&mut self, buf: &[u8]) -> io::Result { - poll_fn(|cx| loop { + fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + loop { // Check if the socket it writable. If // AsyncFd::poll_write_ready returns NotReady, it will // already have arranged for the current task to be @@ -47,17 +56,12 @@ impl TokioSocket { Ok(x) => return Poll::Ready(x), Err(_would_block) => continue, } - }) - .await - } - - pub async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> io::Result { - poll_fn(|cx| self.poll_send_to(cx, buf, addr)).await + } } - pub fn poll_send_to( + fn poll_send_to( &mut self, - cx: &mut Context, + cx: &mut Context<'_>, buf: &[u8], addr: &SocketAddr, ) -> Poll> { @@ -71,11 +75,11 @@ impl TokioSocket { } } - pub async fn recv(&mut self, buf: &mut B) -> io::Result<()> + fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut B) -> Poll> where B: bytes::BufMut, { - poll_fn(|cx| loop { + loop { // Check if the socket is readable. If not, // AsyncFd::poll_read_ready would have arranged for the // current task to be polled again when the socket becomes @@ -86,24 +90,12 @@ impl TokioSocket { Ok(x) => return Poll::Ready(x.map(|_len| ())), Err(_would_block) => continue, } - }) - .await - } - - pub async fn recv_from(&mut self, buf: &mut B) -> io::Result - where - B: bytes::BufMut, - { - poll_fn(|cx| self.poll_recv_from(cx, buf)).await - } - - pub async fn recv_from_full(&mut self) -> io::Result<(Vec, SocketAddr)> { - poll_fn(|cx| self.poll_recv_from_full(cx)).await + } } - pub fn poll_recv_from( + fn poll_recv_from( &mut self, - cx: &mut Context, + cx: &mut Context<'_>, buf: &mut B, ) -> Poll> where @@ -127,9 +119,9 @@ impl TokioSocket { } } - pub fn poll_recv_from_full( + fn poll_recv_from_full( &mut self, - cx: &mut Context, + cx: &mut Context<'_>, ) -> Poll, SocketAddr)>> { loop { trace!("poll_recv_from_full called"); @@ -148,82 +140,4 @@ impl TokioSocket { } } } - - pub fn set_pktinfo(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_pktinfo(value) - } - - pub fn get_pktinfo(&self) -> io::Result { - self.0.get_ref().get_pktinfo() - } - - pub fn add_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().add_membership(group) - } - - pub fn drop_membership(&mut self, group: u32) -> io::Result<()> { - self.0.get_mut().drop_membership(group) - } - - // pub fn list_membership(&self) -> Vec { - // self.0.get_ref().list_membership() - // } - - /// `NETLINK_BROADCAST_ERROR` (since Linux 2.6.30). When not set, `netlink_broadcast()` only - /// reports `ESRCH` errors and silently ignore `NOBUFS` errors. - pub fn set_broadcast_error(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_broadcast_error(value) - } - - pub fn get_broadcast_error(&self) -> io::Result { - self.0.get_ref().get_broadcast_error() - } - - /// `NETLINK_NO_ENOBUFS` (since Linux 2.6.30). This flag can be used by unicast and broadcast - /// listeners to avoid receiving `ENOBUFS` errors. - pub fn set_no_enobufs(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_no_enobufs(value) - } - - pub fn get_no_enobufs(&self) -> io::Result { - self.0.get_ref().get_no_enobufs() - } - - /// `NETLINK_LISTEN_ALL_NSID` (since Linux 4.2). When set, this socket will receive netlink - /// notifications from all network namespaces that have an nsid assigned into the network - /// namespace where the socket has been opened. The nsid is sent to user space via an ancillary - /// data. - pub fn set_listen_all_namespaces(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_listen_all_namespaces(value) - } - - pub fn get_listen_all_namespaces(&self) -> io::Result { - self.0.get_ref().get_listen_all_namespaces() - } - - /// `NETLINK_CAP_ACK` (since Linux 4.2). The kernel may fail to allocate the necessary room - /// for the acknowledgment message back to user space. This option trims off the payload of - /// the original netlink message. The netlink message header is still included, so the user can - /// guess from the sequence number which message triggered the acknowledgment. - pub fn set_cap_ack(&mut self, value: bool) -> io::Result<()> { - self.0.get_mut().set_cap_ack(value) - } - - pub fn get_cap_ack(&self) -> io::Result { - self.0.get_ref().get_cap_ack() - } -} - -impl FromRawFd for TokioSocket { - unsafe fn from_raw_fd(fd: RawFd) -> Self { - let socket = Socket::from_raw_fd(fd); - socket.set_non_blocking(true).unwrap(); - TokioSocket(AsyncFd::new(socket).unwrap()) - } -} - -impl AsRawFd for TokioSocket { - fn as_raw_fd(&self) -> RawFd { - self.0.get_ref().as_raw_fd() - } } diff --git a/rtnetlink/examples/ip_monitor.rs b/rtnetlink/examples/ip_monitor.rs index 7dab1ff5..842abc08 100644 --- a/rtnetlink/examples/ip_monitor.rs +++ b/rtnetlink/examples/ip_monitor.rs @@ -1,7 +1,10 @@ use futures::stream::StreamExt; use netlink_packet_route::constants::*; -use rtnetlink::{new_connection, sys::SocketAddr}; +use rtnetlink::{ + new_connection, + sys::{AsyncSocket, SocketAddr}, +}; #[tokio::main] async fn main() -> Result<(), String> { @@ -31,7 +34,10 @@ async fn main() -> Result<(), String> { | RTNLGRP_MPLS_NETCONF; let addr = SocketAddr::new(0, groups); - conn.socket_mut().bind(&addr).expect("Failed to bind"); + conn.socket_mut() + .socket_mut() + .bind(&addr) + .expect("Failed to bind"); // Spawn `Connection` to start polling netlink socket. tokio::spawn(conn); diff --git a/rtnetlink/examples/listen.rs b/rtnetlink/examples/listen.rs index 06c7a6ee..0066fbee 100644 --- a/rtnetlink/examples/listen.rs +++ b/rtnetlink/examples/listen.rs @@ -6,7 +6,7 @@ use futures::stream::StreamExt; use rtnetlink::{ constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE}, new_connection, - sys::SocketAddr, + sys::{AsyncSocket, SocketAddr}, }; #[tokio::main] @@ -20,7 +20,11 @@ async fn main() -> Result<(), String> { // A netlink socket address is created with said flags. let addr = SocketAddr::new(0, mgroup_flags); // Said address is bound so new conenctions and thus new message broadcasts can be received. - connection.socket_mut().bind(&addr).expect("failed to bind"); + connection + .socket_mut() + .socket_mut() + .bind(&addr) + .expect("failed to bind"); tokio::spawn(connection); while let Some((message, _)) = messages.next().await { From 944307ce292682283891f41db8a0ec4706419664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 14 Nov 2021 19:06:43 +0100 Subject: [PATCH 10/11] Make tokio and smol not conflicting in netlink-proto - This adds a new type paramater to connection for the socket being used. - can drop tokio/smol selection from netlink-packet-audit --- audit/Cargo.toml | 4 ++-- audit/src/lib.rs | 18 +++++++++++++++++- ethtool/src/connection.rs | 17 +++++++++++++++-- ethtool/src/lib.rs | 2 ++ genetlink/src/connection.rs | 18 ++++++++++++++++-- genetlink/src/lib.rs | 2 ++ netlink-packet-audit/Cargo.toml | 5 ----- netlink-proto/src/connection.rs | 21 ++++++++++++++------- netlink-proto/src/framed.rs | 22 ++++++++++++---------- netlink-proto/src/lib.rs | 28 +++++++++++++++++++++++----- rtnetlink/src/connection.rs | 17 +++++++++++++++-- 11 files changed, 118 insertions(+), 36 deletions(-) diff --git a/audit/Cargo.toml b/audit/Cargo.toml index 6f6d896b..d84969c1 100644 --- a/audit/Cargo.toml +++ b/audit/Cargo.toml @@ -19,8 +19,8 @@ netlink-proto = { default-features = false, version = "0.8" } [features] default = ["tokio_socket"] -tokio_socket = ["netlink-proto/tokio_socket", "netlink-packet-audit/tokio_socket"] -smol_socket = ["netlink-proto/smol_socket", "netlink-packet-audit/smol_socket"] +tokio_socket = ["netlink-proto/tokio_socket"] +smol_socket = ["netlink-proto/smol_socket"] [dev-dependencies] tokio = { version = "1.0.1", default-features = false, features = ["macros", "rt-multi-thread"] } diff --git a/audit/src/lib.rs b/audit/src/lib.rs index 91a29b1a..eadb16d7 100644 --- a/audit/src/lib.rs +++ b/audit/src/lib.rs @@ -15,14 +15,30 @@ use std::io; use futures::channel::mpsc::UnboundedReceiver; #[allow(clippy::type_complexity)] +#[cfg(feature = "tokio_socket")] pub fn new_connection() -> io::Result<( - proto::Connection, + proto::Connection, Handle, UnboundedReceiver<( packet::NetlinkMessage, sys::SocketAddr, )>, )> { + new_connection_with_socket() +} + +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + proto::Connection, + Handle, + UnboundedReceiver<( + packet::NetlinkMessage, + sys::SocketAddr, + )>, +)> +where + S: sys::AsyncSocket, +{ let (conn, handle, messages) = netlink_proto::new_connection_with_codec(sys::protocols::NETLINK_AUDIT)?; Ok((conn, Handle::new(handle), messages)) diff --git a/ethtool/src/connection.rs b/ethtool/src/connection.rs index c9d9ef2c..165f31ea 100644 --- a/ethtool/src/connection.rs +++ b/ethtool/src/connection.rs @@ -4,16 +4,29 @@ use futures::channel::mpsc::UnboundedReceiver; use genetlink::message::RawGenlMessage; use netlink_packet_core::NetlinkMessage; use netlink_proto::Connection; -use netlink_sys::SocketAddr; +use netlink_sys::{AsyncSocket, SocketAddr}; use crate::EthtoolHandle; +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection() -> io::Result<( Connection, EthtoolHandle, UnboundedReceiver<(NetlinkMessage, SocketAddr)>, )> { - let (conn, handle, messages) = genetlink::new_connection()?; + new_connection_with_socket() +} + +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + Connection, + EthtoolHandle, + UnboundedReceiver<(NetlinkMessage, SocketAddr)>, +)> +where + S: AsyncSocket, +{ + let (conn, handle, messages) = genetlink::new_connection_with_socket()?; Ok((conn, EthtoolHandle::new(handle), messages)) } diff --git a/ethtool/src/lib.rs b/ethtool/src/lib.rs index cf986aab..f1a43c2f 100644 --- a/ethtool/src/lib.rs +++ b/ethtool/src/lib.rs @@ -11,7 +11,9 @@ mod pause; mod ring; pub use coalesce::{EthtoolCoalesceAttr, EthtoolCoalesceGetRequest, EthtoolCoalesceHandle}; +#[cfg(feature = "tokio_socket")] pub use connection::new_connection; +pub use connection::new_connection_with_socket; pub use error::EthtoolError; pub use feature::{ EthtoolFeatureAttr, diff --git a/genetlink/src/connection.rs b/genetlink/src/connection.rs index d82af5e7..34eb2623 100644 --- a/genetlink/src/connection.rs +++ b/genetlink/src/connection.rs @@ -3,7 +3,7 @@ use futures::channel::mpsc::UnboundedReceiver; use netlink_packet_core::NetlinkMessage; use netlink_proto::{ self, - sys::{protocols::NETLINK_GENERIC, SocketAddr}, + sys::{protocols::NETLINK_GENERIC, AsyncSocket, SocketAddr}, Connection, }; use std::io; @@ -21,12 +21,26 @@ use std::io; /// /// The [`GenetlinkHandle`] can send and receive any type of generic netlink message. /// And it can automatic resolve the generic family id before sending. +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection() -> io::Result<( Connection, GenetlinkHandle, UnboundedReceiver<(NetlinkMessage, SocketAddr)>, )> { - let (conn, handle, messages) = netlink_proto::new_connection(NETLINK_GENERIC)?; + new_connection_with_socket() +} + +/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + Connection, + GenetlinkHandle, + UnboundedReceiver<(NetlinkMessage, SocketAddr)>, +)> +where + S: AsyncSocket, +{ + let (conn, handle, messages) = netlink_proto::new_connection_with_socket(NETLINK_GENERIC)?; Ok((conn, GenetlinkHandle::new(handle), messages)) } diff --git a/genetlink/src/lib.rs b/genetlink/src/lib.rs index e145b43f..a3b1e9e1 100644 --- a/genetlink/src/lib.rs +++ b/genetlink/src/lib.rs @@ -7,6 +7,8 @@ mod handle; pub mod message; mod resolver; +#[cfg(feature = "tokio_socket")] pub use connection::new_connection; +pub use connection::new_connection_with_socket; pub use error::GenetlinkError; pub use handle::GenetlinkHandle; diff --git a/netlink-packet-audit/Cargo.toml b/netlink-packet-audit/Cargo.toml index 7a135fab..fabd09d8 100644 --- a/netlink-packet-audit/Cargo.toml +++ b/netlink-packet-audit/Cargo.toml @@ -11,11 +11,6 @@ readme = "../README.md" repository = "https://github.com/little-dude/netlink" description = "netlink packet types" -[features] -default = ["tokio_socket"] -tokio_socket = ["netlink-proto/tokio_socket"] -smol_socket = ["netlink-proto/smol_socket"] - [dependencies] anyhow = "1.0.31" bytes = "1.0" diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index 24311b25..8e022ecd 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -22,21 +22,26 @@ use netlink_packet_core::{ use crate::{ codecs::{NetlinkCodec, NetlinkMessageCodec}, framed::NetlinkFramed, - sys::{AsyncSocket, Socket, SocketAddr}, + sys::{AsyncSocket, SocketAddr}, Protocol, Request, Response, }; +#[cfg(feature = "tokio_socket")] +use netlink_sys::TokioSocket as DefaultSocket; +#[cfg(not(feature = "tokio_socket"))] +type DefaultSocket = (); + /// Connection to a Netlink socket, running in the background. /// /// [`ConnectionHandle`](struct.ConnectionHandle.html) are used to pass new requests to the /// `Connection`, that in turn, sends them through the netlink socket. -pub struct Connection +pub struct Connection where T: Debug + NetlinkSerializable + NetlinkDeserializable, { - socket: NetlinkFramed, + socket: NetlinkFramed, protocol: Protocol>>, @@ -50,9 +55,10 @@ where socket_closed: bool, } -impl Connection +impl Connection where T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, + S: AsyncSocket, C: NetlinkMessageCodec, { pub(crate) fn new( @@ -60,7 +66,7 @@ where unsolicited_messages_tx: UnboundedSender<(NetlinkMessage, SocketAddr)>, protocol: isize, ) -> io::Result { - let socket = Socket::new(protocol)?; + let socket = S::new(protocol)?; Ok(Connection { socket: NetlinkFramed::new(socket), protocol: Protocol::new(), @@ -70,7 +76,7 @@ where }) } - pub fn socket_mut(&mut self) -> &mut Socket { + pub fn socket_mut(&mut self) -> &mut S { self.socket.get_mut() } @@ -251,9 +257,10 @@ where } } -impl Future for Connection +impl Future for Connection where T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin, + S: AsyncSocket, C: NetlinkMessageCodec, { type Output = (); diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index df6a5f7f..b46c1d5b 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -12,12 +12,12 @@ use log::error; use crate::{ codecs::NetlinkMessageCodec, - sys::{AsyncSocket, Socket, SocketAddr}, + sys::{AsyncSocket, SocketAddr}, }; use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable}; -pub struct NetlinkFramed { - socket: Socket, +pub struct NetlinkFramed { + socket: S, // see https://doc.rust-lang.org/nomicon/phantom-data.html // "invariant" seems like the safe choice; using `fn(T) -> T` // should make it invariant but still Send+Sync. @@ -30,9 +30,10 @@ pub struct NetlinkFramed { flushed: bool, } -impl Stream for NetlinkFramed +impl Stream for NetlinkFramed where T: NetlinkDeserializable + Debug, + S: AsyncSocket, C: NetlinkMessageCodec, { type Item = (NetlinkMessage, SocketAddr); @@ -69,9 +70,10 @@ where } } -impl Sink<(NetlinkMessage, SocketAddr)> for NetlinkFramed +impl Sink<(NetlinkMessage, SocketAddr)> for NetlinkFramed where T: NetlinkSerializable + Debug, + S: AsyncSocket, C: NetlinkMessageCodec, { type Error = io::Error; @@ -145,11 +147,11 @@ where const INITIAL_READER_CAPACITY: usize = 64 * 1024; const INITIAL_WRITER_CAPACITY: usize = 8 * 1024; -impl NetlinkFramed { +impl NetlinkFramed { /// Create a new `NetlinkFramed` backed by the given socket and codec. /// /// See struct level documentation for more details. - pub fn new(socket: Socket) -> Self { + pub fn new(socket: S) -> Self { Self { socket, msg_type: PhantomData, @@ -169,7 +171,7 @@ impl NetlinkFramed { /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. - pub fn get_ref(&self) -> &Socket { + pub fn get_ref(&self) -> &S { &self.socket } @@ -181,12 +183,12 @@ impl NetlinkFramed { /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. - pub fn get_mut(&mut self) -> &mut Socket { + pub fn get_mut(&mut self) -> &mut S { &mut self.socket } /// Consumes the `Framed`, returning its underlying I/O stream. - pub fn into_inner(self) -> Socket { + pub fn into_inner(self) -> S { self.socket } } diff --git a/netlink-proto/src/lib.rs b/netlink-proto/src/lib.rs index 5777dd92..3425e7a9 100644 --- a/netlink-proto/src/lib.rs +++ b/netlink-proto/src/lib.rs @@ -195,10 +195,10 @@ pub mod sys { pub use netlink_sys::{protocols, AsyncSocket, AsyncSocketExt, SocketAddr}; #[cfg(feature = "tokio_socket")] - pub use netlink_sys::TokioSocket as Socket; + pub use netlink_sys::TokioSocket; #[cfg(feature = "smol_socket")] - pub use netlink_sys::SmolSocket as Socket; + pub use netlink_sys::SmolSocket; } /// Create a new Netlink connection for the given Netlink protocol, and returns a handle to that @@ -218,6 +218,7 @@ pub mod sys { /// handle to send messages. /// /// [protos]: crate::sys::protocols +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection( protocol: isize, @@ -232,17 +233,34 @@ where new_connection_with_codec(protocol) } -/// Variant of [`new_connection`] that allows specifying a separate codec +/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling #[allow(clippy::type_complexity)] -pub fn new_connection_with_codec( +pub fn new_connection_with_socket( protocol: isize, ) -> io::Result<( - Connection, + Connection, ConnectionHandle, UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, )> where T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, + S: sys::AsyncSocket, +{ + new_connection_with_codec(protocol) +} + +/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling and a special codec +#[allow(clippy::type_complexity)] +pub fn new_connection_with_codec( + protocol: isize, +) -> io::Result<( + Connection, + ConnectionHandle, + UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, +)> +where + T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, + S: sys::AsyncSocket, C: NetlinkMessageCodec, { let (requests_tx, requests_rx) = unbounded::>(); diff --git a/rtnetlink/src/connection.rs b/rtnetlink/src/connection.rs index e8cdf731..8f295b46 100644 --- a/rtnetlink/src/connection.rs +++ b/rtnetlink/src/connection.rs @@ -5,16 +5,29 @@ use futures::channel::mpsc::UnboundedReceiver; use crate::{ packet::{NetlinkMessage, RtnlMessage}, proto::Connection, - sys::{protocols::NETLINK_ROUTE, SocketAddr}, + sys::{protocols::NETLINK_ROUTE, AsyncSocket, SocketAddr}, Handle, }; +#[cfg(feature = "tokio_socket")] #[allow(clippy::type_complexity)] pub fn new_connection() -> io::Result<( Connection, Handle, UnboundedReceiver<(NetlinkMessage, SocketAddr)>, )> { - let (conn, handle, messages) = netlink_proto::new_connection(NETLINK_ROUTE)?; + new_connection_with_socket() +} + +#[allow(clippy::type_complexity)] +pub fn new_connection_with_socket() -> io::Result<( + Connection, + Handle, + UnboundedReceiver<(NetlinkMessage, SocketAddr)>, +)> +where + S: AsyncSocket, +{ + let (conn, handle, messages) = netlink_proto::new_connection_with_socket(NETLINK_ROUTE)?; Ok((conn, Handle::new(handle), messages)) } From aa89fe65545b57f74e693078c028d8e90af55fc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20B=C3=BChler?= Date: Sun, 14 Nov 2021 20:04:18 +0100 Subject: [PATCH 11/11] Get rid of tokio/smol feature conflict in rtnetlink --- rtnetlink/src/ns.rs | 61 ++++++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/rtnetlink/src/ns.rs b/rtnetlink/src/ns.rs index 3ca6cc40..52bed9c4 100644 --- a/rtnetlink/src/ns.rs +++ b/rtnetlink/src/ns.rs @@ -1,9 +1,3 @@ -#[cfg(feature = "tokio_socket")] -use tokio::task; - -#[cfg(feature = "smol_socket")] -use async_std::task; - use crate::Error; use nix::{ fcntl::OFlag, @@ -16,6 +10,43 @@ use nix::{ }; use std::{option::Option, path::Path, process::exit}; +// if "only" smol or smol+tokio were enabled, we use smol because +// it doesn't require an active tokio runtime - just to be sure. +#[cfg(feature = "smol_socket")] +async fn try_spawn_blocking(fut: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + async_std::task::spawn_blocking(fut).await +} + +// only tokio enabled, so use tokio +#[cfg(all(not(feature = "smol_socket"), feature = "tokio_socket"))] +async fn try_spawn_blocking(fut: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + match tokio::task::spawn_blocking(fut).await { + Ok(v) => v, + Err(err) => { + std::panic::resume_unwind(err.into_panic()); + } + } +} + +// neither smol nor tokio - just run blocking op directly. +// hopefully not too blocking... +#[cfg(all(not(feature = "smol_socket"), not(feature = "tokio_socket")))] +async fn try_spawn_blocking(fut: F) -> R +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + fut() +} + pub const NETNS_PATH: &str = "/run/netns/"; pub const SELF_NS_PATH: &str = "/proc/self/ns/net"; pub const NONE_FS: &str = "none"; @@ -46,7 +77,7 @@ impl NetworkNamespace { /// Remove a network namespace /// This is equivalent to `ip netns del NS_NAME`. pub async fn del(ns_name: String) -> Result<(), Error> { - let res = task::spawn_blocking(move || { + try_spawn_blocking(move || { let mut netns_path = String::new(); netns_path.push_str(NETNS_PATH); netns_path.push_str(&ns_name); @@ -62,20 +93,10 @@ impl NetworkNamespace { String::from("Namespace file remove failed (are you running as root?)"); return Err(Error::NamespaceError(err_msg)); } - #[cfg(feature = "tokio_socket")] - return Ok(()); - - #[cfg(feature = "smol_socket")] - return Ok(Ok(())); - }); - match res.await { - Ok(r) => r, - Err(e) => { - let err_msg = format!("Namespace removal failed: {}", e); - Err(Error::NamespaceError(err_msg)) - } - } + Ok(()) + }) + .await } pub fn prep_for_fork() -> Result<(), Error> {