From 89ee697cd8e4d8a174c2395441a79fcb914e3dfd Mon Sep 17 00:00:00 2001 From: Leo Date: Wed, 15 Sep 2021 17:21:36 +0800 Subject: [PATCH] Add support for generic netlink Squashed 41 commits from branch 'generic-netlink' Implemented 'netlink-packet-generic' and 'genetlink' to provide generic netlink packet definition and asynchronous connection. --- .github/workflows/main.yml | 10 + Cargo.toml | 6 + genetlink/Cargo.toml | 38 +++ genetlink/examples/dump_family_policy.rs | 60 ++++ genetlink/examples/list_generic_family.rs | 108 +++++++ genetlink/src/connection.rs | 32 ++ genetlink/src/error.rs | 24 ++ genetlink/src/handle.rs | 175 +++++++++++ genetlink/src/lib.rs | 12 + genetlink/src/message.rs | 178 +++++++++++ genetlink/src/resolver.rs | 151 ++++++++++ netlink-packet-generic/Cargo.toml | 21 ++ netlink-packet-generic/LICENSE-MIT | 1 + .../examples/list_generic_family.rs | 111 +++++++ netlink-packet-generic/src/buffer.rs | 37 +++ netlink-packet-generic/src/constants.rs | 70 +++++ netlink-packet-generic/src/ctrl/mod.rs | 137 +++++++++ netlink-packet-generic/src/ctrl/nlas/mcast.rs | 60 ++++ netlink-packet-generic/src/ctrl/nlas/mod.rs | 160 ++++++++++ .../src/ctrl/nlas/oppolicy.rs | 96 ++++++ netlink-packet-generic/src/ctrl/nlas/ops.rs | 57 ++++ .../src/ctrl/nlas/policy.rs | 279 ++++++++++++++++++ netlink-packet-generic/src/header.rs | 32 ++ netlink-packet-generic/src/lib.rs | 81 +++++ netlink-packet-generic/src/message.rs | 184 ++++++++++++ netlink-packet-generic/src/traits.rs | 33 +++ .../tests/query_family_id.rs | 55 ++++ 27 files changed, 2208 insertions(+) create mode 100644 genetlink/Cargo.toml create mode 100644 genetlink/examples/dump_family_policy.rs create mode 100644 genetlink/examples/list_generic_family.rs create mode 100644 genetlink/src/connection.rs create mode 100644 genetlink/src/error.rs create mode 100644 genetlink/src/handle.rs create mode 100644 genetlink/src/lib.rs create mode 100644 genetlink/src/message.rs create mode 100644 genetlink/src/resolver.rs create mode 100644 netlink-packet-generic/Cargo.toml create mode 120000 netlink-packet-generic/LICENSE-MIT create mode 100644 netlink-packet-generic/examples/list_generic_family.rs create mode 100644 netlink-packet-generic/src/buffer.rs create mode 100644 netlink-packet-generic/src/constants.rs create mode 100644 netlink-packet-generic/src/ctrl/mod.rs create mode 100644 netlink-packet-generic/src/ctrl/nlas/mcast.rs create mode 100644 netlink-packet-generic/src/ctrl/nlas/mod.rs create mode 100644 netlink-packet-generic/src/ctrl/nlas/oppolicy.rs create mode 100644 netlink-packet-generic/src/ctrl/nlas/ops.rs create mode 100644 netlink-packet-generic/src/ctrl/nlas/policy.rs create mode 100644 netlink-packet-generic/src/header.rs create mode 100644 netlink-packet-generic/src/lib.rs create mode 100644 netlink-packet-generic/src/message.rs create mode 100644 netlink-packet-generic/src/traits.rs create mode 100644 netlink-packet-generic/tests/query_family_id.rs diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index bd7d93cc..3182f267 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,6 +38,11 @@ jobs: cd netlink-packet-core cargo test + - name: test (netlink-packet-generic) + run: | + cd netlink-packet-generic + cargo test + - name: test (netlink-packet-route) run: | cd netlink-packet-route @@ -71,3 +76,8 @@ jobs: run: | cd audit cargo test + + - name: test (genetlink) + run: | + cd genetlink + cargo test --features tokio_socket diff --git a/Cargo.toml b/Cargo.toml index a83e52b7..02d52891 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,12 +4,14 @@ members = [ "netlink-sys", "netlink-packet-core", "netlink-packet-utils", + "netlink-packet-generic", "netlink-packet-route", "netlink-packet-route/fuzz", "netlink-packet-audit", "netlink-packet-audit/fuzz", "netlink-packet-sock-diag", "netlink-proto", + "genetlink", "rtnetlink", "audit", ] @@ -19,10 +21,12 @@ default-members = [ "netlink-sys", "netlink-packet-core", "netlink-packet-utils", + "netlink-packet-generic", "netlink-packet-route", "netlink-packet-audit", "netlink-packet-sock-diag", "netlink-proto", + "genetlink", "rtnetlink", "audit", ] @@ -31,9 +35,11 @@ default-members = [ netlink-sys = { path = "netlink-sys" } netlink-packet-core = { path = "netlink-packet-core" } netlink-packet-utils = { path = "netlink-packet-utils" } +netlink-packet-generic = { path = "netlink-packet-generic" } netlink-packet-route = { path = "netlink-packet-route" } netlink-packet-audit = { path = "netlink-packet-audit" } netlink-packet-sock-diag = { path = "netlink-packet-sock-diag" } netlink-proto = { path = "netlink-proto" } +genetlink = { path = "genetlink" } rtnetlink = { path = "rtnetlink" } audit = { path = "audit" } diff --git a/genetlink/Cargo.toml b/genetlink/Cargo.toml new file mode 100644 index 00000000..2e15b8c2 --- /dev/null +++ b/genetlink/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "genetlink" +version = "0.1.0" +authors = ["Leo "] +edition = "2018" +homepage = "https://github.com/little-dude/netlink" +repository = "https://github.com/little-dude/netlink" +keywords = ["netlink", "linux"] +license = "MIT" +readme = "../README.md" +description = "communicate with generic netlink" + +[features] +default = ["tokio_socket"] +tokio_socket = ["netlink-proto/tokio_socket","netlink-proto/workaround-audit-bug", "tokio"] +smol_socket = ["netlink-proto/smol_socket","netlink-proto/workaround-audit-bug","async-std"] + +[dependencies] +futures = "0.3.16" +netlink-packet-generic = "0.1.0" +netlink-proto = { default-features = false, version = "0.7.0" } +tokio = { version = "1.9.0", features = ["rt"], optional = true } +async-std = { version = "1.9.0", optional = true } +netlink-packet-utils = "0.4.1" +netlink-packet-core = "0.2.4" +thiserror = "1.0.26" + +[dev-dependencies] +anyhow = "1.0.42" +tokio = { version = "1.9.0", features = ["rt", "rt-multi-thread", "macros"] } + +[[example]] +name = "list_generic_family" +required-features = ["tokio_socket"] + +[[example]] +name = "dump_family_policy" +required-features = ["tokio_socket"] diff --git a/genetlink/examples/dump_family_policy.rs b/genetlink/examples/dump_family_policy.rs new file mode 100644 index 00000000..64fb89a7 --- /dev/null +++ b/genetlink/examples/dump_family_policy.rs @@ -0,0 +1,60 @@ +use std::env::args; + +use anyhow::{bail, Error}; +use futures::StreamExt; +use genetlink::new_connection; +use netlink_packet_core::{ + NetlinkHeader, + NetlinkMessage, + NetlinkPayload, + NLM_F_DUMP, + NLM_F_REQUEST, +}; +use netlink_packet_generic::{ + ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + GenlMessage, +}; + +#[tokio::main] +async fn main() -> Result<(), Error> { + let argv: Vec<_> = args().collect(); + + if argv.len() < 2 { + eprintln!("Usage: dump_family_policy "); + bail!("Required arguments not given"); + } + + let nlmsg = NetlinkMessage { + header: NetlinkHeader { + flags: NLM_F_REQUEST | NLM_F_DUMP, + ..Default::default() + }, + payload: GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetPolicy, + nlas: vec![GenlCtrlAttrs::FamilyName(argv[1].to_owned())], + }) + .into(), + }; + let (conn, mut handle, _) = new_connection()?; + tokio::spawn(conn); + + let mut responses = handle.request(nlmsg).await?; + + while let Some(result) = responses.next().await { + let resp = result?; + match resp.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + if genlmsg.payload.cmd == GenlCtrlCmd::GetPolicy { + println!("<<< {:?}", genlmsg); + } + } + NetlinkPayload::Error(err) => { + eprintln!("Received a netlink error message: {:?}", err); + bail!(err); + } + _ => {} + } + } + + Ok(()) +} diff --git a/genetlink/examples/list_generic_family.rs b/genetlink/examples/list_generic_family.rs new file mode 100644 index 00000000..5f12998c --- /dev/null +++ b/genetlink/examples/list_generic_family.rs @@ -0,0 +1,108 @@ +//! Example of listing generic families based on `netlink_proto` +//! +//! This example's functionality is same as the identical name example in `netlink_packet_generic`. +//! But this example shows you the usage of this crate to run generic netlink protocol asynchronously. + +use anyhow::{bail, Error}; +use futures::StreamExt; +use genetlink::new_connection; +use netlink_packet_core::{ + NetlinkHeader, + NetlinkMessage, + NetlinkPayload, + NLM_F_DUMP, + NLM_F_REQUEST, +}; +use netlink_packet_generic::{ + ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + GenlMessage, +}; + +#[tokio::main] +async fn main() -> Result<(), Error> { + let nlmsg = NetlinkMessage { + header: NetlinkHeader { + flags: NLM_F_REQUEST | NLM_F_DUMP, + ..Default::default() + }, + payload: GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![], + }) + .into(), + }; + let (conn, mut handle, _) = new_connection()?; + tokio::spawn(conn); + + let mut responses = handle.request(nlmsg).await?; + + while let Some(result) = responses.next().await { + let resp = result?; + match resp.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + if genlmsg.payload.cmd == GenlCtrlCmd::NewFamily { + print_entry(genlmsg.payload.nlas); + } + } + NetlinkPayload::Error(err) => { + eprintln!("Received a netlink error message: {:?}", err); + bail!(err); + } + _ => {} + } + } + + Ok(()) +} + +fn print_entry(entry: Vec) { + let family_id = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::FamilyId(id) = nla { + Some(*id) + } else { + None + } + }) + .expect("Cannot find FamilyId attribute"); + let family_name = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::FamilyName(name) = nla { + Some(name.as_str()) + } else { + None + } + }) + .expect("Cannot find FamilyName attribute"); + let version = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::Version(ver) = nla { + Some(*ver) + } else { + None + } + }) + .expect("Cannot find Version attribute"); + let hdrsize = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::HdrSize(hdr) = nla { + Some(*hdr) + } else { + None + } + }) + .expect("Cannot find HdrSize attribute"); + + if hdrsize == 0 { + println!("0x{:04x} {} [Version {}]", family_id, family_name, version); + } else { + println!( + "0x{:04x} {} [Version {}] [Header {} bytes]", + family_id, family_name, version, hdrsize + ); + } +} diff --git a/genetlink/src/connection.rs b/genetlink/src/connection.rs new file mode 100644 index 00000000..d82af5e7 --- /dev/null +++ b/genetlink/src/connection.rs @@ -0,0 +1,32 @@ +use crate::{message::RawGenlMessage, GenetlinkHandle}; +use futures::channel::mpsc::UnboundedReceiver; +use netlink_packet_core::NetlinkMessage; +use netlink_proto::{ + self, + sys::{protocols::NETLINK_GENERIC, SocketAddr}, + Connection, +}; +use std::io; + +/// Construct a generic netlink connection +/// +/// The function would return a tuple containing three objects. +/// - an async netlink connection +/// - a connection handle to interact with the connection +/// - a receiver of the unsolicited messages +/// +/// The connection object is also a event loop which implements [`std::future::Future`]. +/// In most cases, users spawn it on an async runtime and use the handle to send +/// messages. For detailed documentation, please refer to [`netlink_proto::new_connection`]. +/// +/// The [`GenetlinkHandle`] can send and receive any type of generic netlink message. +/// And it can automatic resolve the generic family id before sending. +#[allow(clippy::type_complexity)] +pub fn new_connection() -> io::Result<( + Connection, + GenetlinkHandle, + UnboundedReceiver<(NetlinkMessage, SocketAddr)>, +)> { + let (conn, handle, messages) = netlink_proto::new_connection(NETLINK_GENERIC)?; + Ok((conn, GenetlinkHandle::new(handle), messages)) +} diff --git a/genetlink/src/error.rs b/genetlink/src/error.rs new file mode 100644 index 00000000..2e2f54f4 --- /dev/null +++ b/genetlink/src/error.rs @@ -0,0 +1,24 @@ +use crate::message::RawGenlMessage; + +/// Error type of genetlink +#[derive(Debug, Error)] +pub enum GenetlinkError { + #[error("Failed to send netlink request")] + ProtocolError(#[from] netlink_proto::Error), + #[error("Failed to decode generic packet")] + DecodeError(#[from] netlink_packet_utils::DecodeError), + #[error("Netlink error message: {0}")] + NetlinkError(std::io::Error), + #[error("Cannot find specified netlink attribute: {0}")] + AttributeNotFound(String), + #[error("Desire netlink message type not received")] + NoMessageReceived, +} + +// Since `netlink_packet_core::error::ErrorMessage` doesn't impl `Error` trait, +// it need to convert to `std::io::Error` here +impl From for GenetlinkError { + fn from(err_msg: netlink_packet_core::error::ErrorMessage) -> Self { + Self::NetlinkError(err_msg.to_io()) + } +} diff --git a/genetlink/src/handle.rs b/genetlink/src/handle.rs new file mode 100644 index 00000000..1a312913 --- /dev/null +++ b/genetlink/src/handle.rs @@ -0,0 +1,175 @@ +use crate::{ + error::GenetlinkError, + message::{map_from_rawgenlmsg, map_to_rawgenlmsg, RawGenlMessage}, + resolver::Resolver, +}; +use futures::{lock::Mutex, Stream, StreamExt}; +use netlink_packet_core::{DecodeError, NetlinkMessage, NetlinkPayload}; +use netlink_packet_generic::{GenlFamily, GenlHeader, GenlMessage}; +use netlink_packet_utils::{Emitable, ParseableParametrized}; +use netlink_proto::{sys::SocketAddr, ConnectionHandle}; +use std::{fmt::Debug, sync::Arc}; + +/// The generic netlink connection handle +/// +/// The handle is used to send messages to the connection. It also resolves +/// the family id automatically before sending messages. +/// +/// # Family id resolving +/// There is a resolver with cache inside each connection. When you send generic +/// netlink message, the handle resolves and fills the family id into the message. +/// +/// Since the resolver is created in [`new_connection()`](crate::new_connection), +/// the cache state wouldn't share between different connections. +/// +/// P.s. The cloned handles use the same connection with the original handle. So, +/// they share the same cache state. +/// +/// # Detailed process of sending generic messages +/// 1. Check if the message's family id is resolved. If yes, jump to step 6. +/// 2. Query the family id using the builtin resolver. +/// 3. If the id is in the cache, returning the id in the cache and skip step 4. +/// 4. The resolver sends `CTRL_CMD_GETFAMILY` request to get the id and records it in the cache. +/// 5. fill the family id using [`GenlMessage::set_resolved_family_id()`]. +/// 6. Serialize the payload to [`RawGenlMessage`]. +/// 7. Send it through the connection. +/// - The family id filled into `message_type` field in [`NetlinkMessage::finalize()`]. +/// 8. In the response stream, deserialize the payload back to [`GenlMessage`]. +#[derive(Clone, Debug)] +pub struct GenetlinkHandle { + handle: ConnectionHandle, + resolver: Arc>, +} + +impl GenetlinkHandle { + pub(crate) fn new(handle: ConnectionHandle) -> Self { + Self { + handle, + resolver: Arc::new(Mutex::new(Resolver::new())), + } + } + + /// Resolve the family id of the given [`GenlFamily`]. + pub async fn resolve_family_id(&self) -> Result + where + F: GenlFamily, + { + self.resolver + .lock() + .await + .query_family_id(self, F::family_name()) + .await + } + + /// Clear the resolver's fanily id cache + pub async fn clear_family_id_cache(&self) { + self.resolver.lock().await.clear_cache(); + } + + /// Send the generic netlink message and get the response stream + /// + /// The function resolves the family id before sending the request. If the + /// resolving process is failed, the function would return an error. + pub async fn request( + &mut self, + mut message: NetlinkMessage>, + ) -> Result< + impl Stream>, DecodeError>>, + GenetlinkError, + > + where + F: GenlFamily + + Emitable + + ParseableParametrized<[u8], GenlHeader> + + Clone + + Debug + + PartialEq + + Eq, + { + self.resolve_message_family_id(&mut message).await?; + self.send_request(message) + } + + /// Send the request without resolving family id + /// + /// This function is identical to [`request()`](Self::request) but it doesn't + /// resolve the family id for you. + pub fn send_request( + &mut self, + message: NetlinkMessage>, + ) -> Result< + impl Stream>, DecodeError>>, + GenetlinkError, + > + where + F: GenlFamily + + Emitable + + ParseableParametrized<[u8], GenlHeader> + + Clone + + Debug + + PartialEq + + Eq, + { + let raw_msg = map_to_rawgenlmsg(message); + + let stream = self.handle.request(raw_msg, SocketAddr::new(0, 0))?; + Ok(stream.map(map_from_rawgenlmsg)) + } + + /// Send the generic netlink message without returning the response stream + pub async fn notify( + &mut self, + mut message: NetlinkMessage>, + ) -> Result<(), GenetlinkError> + where + F: GenlFamily + + Emitable + + ParseableParametrized<[u8], GenlHeader> + + Clone + + Debug + + PartialEq + + Eq, + { + self.resolve_message_family_id(&mut message).await?; + self.send_notify(message) + } + + /// Send the notify without resolving family id + pub fn send_notify( + &mut self, + message: NetlinkMessage>, + ) -> Result<(), GenetlinkError> + where + F: GenlFamily + + Emitable + + ParseableParametrized<[u8], GenlHeader> + + Clone + + Debug + + PartialEq + + Eq, + { + let raw_msg = map_to_rawgenlmsg(message); + + self.handle.notify(raw_msg, SocketAddr::new(0, 0))?; + Ok(()) + } + + async fn resolve_message_family_id( + &mut self, + message: &mut NetlinkMessage>, + ) -> Result<(), GenetlinkError> + where + F: GenlFamily + Clone + Debug + PartialEq + Eq, + { + if let NetlinkPayload::InnerMessage(genlmsg) = &mut message.payload { + if genlmsg.family_id() == 0 { + // The family id is not resolved + // Resolve it before send it + let id = self.resolve_family_id::().await?; + genlmsg.set_resolved_family_id(id); + } + } + + Ok(()) + } +} diff --git a/genetlink/src/lib.rs b/genetlink/src/lib.rs new file mode 100644 index 00000000..e145b43f --- /dev/null +++ b/genetlink/src/lib.rs @@ -0,0 +1,12 @@ +#[macro_use] +extern crate thiserror; + +mod connection; +mod error; +mod handle; +pub mod message; +mod resolver; + +pub use connection::new_connection; +pub use error::GenetlinkError; +pub use handle::GenetlinkHandle; diff --git a/genetlink/src/message.rs b/genetlink/src/message.rs new file mode 100644 index 00000000..7cc0fcc2 --- /dev/null +++ b/genetlink/src/message.rs @@ -0,0 +1,178 @@ +//! Raw generic netlink payload message +//! +//! # Design +//! Since we use generic type to represent different generic family's message type, +//! and it is not easy to create a underlying [`netlink_proto::new_connection()`] +//! with trait object to multiplex different generic netlink family's message. +//! +//! Therefore, I decided to serialize the generic type payload into bytes before +//! sending to the underlying connection. The [`RawGenlMessage`] is meant for this. +//! +//! This special message doesn't use generic type and its payload is `Vec`. +//! Therefore, its type is easier to use. +//! +//! Another advantage is that it can let users know when the generic netlink payload +//! fails to decode instead of just dropping the messages. +//! (`netlink_proto` would drop messages if they fails to decode.) +//! I think this can help developers debug their deserializing implementation. +use netlink_packet_core::{ + DecodeError, + NetlinkDeserializable, + NetlinkHeader, + NetlinkMessage, + NetlinkPayload, + NetlinkSerializable, +}; +use netlink_packet_generic::{GenlBuffer, GenlFamily, GenlHeader, GenlMessage}; +use netlink_packet_utils::{Emitable, Parseable, ParseableParametrized}; +use std::fmt::Debug; + +/// Message type to hold serialized generic netlink payload +/// +/// **Note** This message type is not intend to be used by normal users, unless +/// you need to use the `UnboundedReceiver<(NetlinkMessage, SocketAddr)>` +/// return by [`new_connection()`](crate::new_connection) +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct RawGenlMessage { + pub header: GenlHeader, + pub payload: Vec, + pub family_id: u16, +} + +impl RawGenlMessage { + /// Construct the message + pub fn new(header: GenlHeader, payload: Vec, family_id: u16) -> Self { + Self { + header, + payload, + family_id, + } + } + + /// Consume this message and return its header and payload + pub fn into_parts(self) -> (GenlHeader, Vec) { + (self.header, self.payload) + } + + /// Serialize the generic netlink payload into raw bytes + pub fn from_genlmsg(genlmsg: GenlMessage) -> Self + where + F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, + { + let mut payload_buf = vec![0u8; genlmsg.payload.buffer_len()]; + genlmsg.payload.emit(&mut payload_buf); + + Self { + header: genlmsg.header, + payload: payload_buf, + family_id: genlmsg.family_id(), + } + } + + /// Try to deserialize the generic netlink payload from raw bytes + pub fn parse_into_genlmsg(&self) -> Result, DecodeError> + where + F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + { + let inner = F::parse_with_param(&self.payload, self.header)?; + Ok(GenlMessage::new(self.header, inner, self.family_id)) + } +} + +impl Emitable for RawGenlMessage { + fn buffer_len(&self) -> usize { + self.header.buffer_len() + self.payload.len() + } + + fn emit(&self, buffer: &mut [u8]) { + self.header.emit(buffer); + + let buffer = &mut buffer[self.header.buffer_len()..]; + buffer.copy_from_slice(&self.payload); + } +} + +impl<'a, T> ParseableParametrized, u16> for RawGenlMessage +where + T: AsRef<[u8]> + ?Sized, +{ + fn parse_with_param(buf: &GenlBuffer<&'a T>, message_type: u16) -> Result { + let header = GenlHeader::parse(buf)?; + let payload_buf = buf.payload(); + Ok(RawGenlMessage::new( + header, + payload_buf.to_vec(), + message_type, + )) + } +} + +impl NetlinkSerializable for RawGenlMessage { + fn message_type(&self) -> u16 { + self.family_id + } + + fn buffer_len(&self) -> usize { + ::buffer_len(self) + } + + fn serialize(&self, buffer: &mut [u8]) { + self.emit(buffer) + } +} + +impl NetlinkDeserializable for RawGenlMessage { + type Error = DecodeError; + fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { + let buffer = GenlBuffer::new_checked(payload)?; + RawGenlMessage::parse_with_param(&buffer, header.message_type) + } +} + +impl From for NetlinkPayload { + fn from(message: RawGenlMessage) -> Self { + NetlinkPayload::InnerMessage(message) + } +} + +/// Helper function to map the [`NetlinkPayload`] types in [`NetlinkMessage`] +/// and serialize the generic netlink payload into raw bytes. +pub fn map_to_rawgenlmsg( + message: NetlinkMessage>, +) -> NetlinkMessage +where + F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, +{ + let raw_payload = match message.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + NetlinkPayload::InnerMessage(RawGenlMessage::from_genlmsg(genlmsg)) + } + NetlinkPayload::Done => NetlinkPayload::Done, + NetlinkPayload::Error(i) => NetlinkPayload::Error(i), + NetlinkPayload::Ack(i) => NetlinkPayload::Ack(i), + NetlinkPayload::Noop => NetlinkPayload::Noop, + NetlinkPayload::Overrun(i) => NetlinkPayload::Overrun(i), + }; + NetlinkMessage::new(message.header, raw_payload) +} + +/// Helper function to map the [`NetlinkPayload`] types in [`NetlinkMessage`] +/// and try to deserialize the generic netlink payload from raw bytes. +pub fn map_from_rawgenlmsg( + raw_msg: NetlinkMessage, +) -> Result>, DecodeError> +where + F: GenlFamily + ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, +{ + let payload = match raw_msg.payload { + NetlinkPayload::InnerMessage(raw_genlmsg) => { + NetlinkPayload::InnerMessage(raw_genlmsg.parse_into_genlmsg()?) + } + NetlinkPayload::Done => NetlinkPayload::Done, + NetlinkPayload::Error(i) => NetlinkPayload::Error(i), + NetlinkPayload::Ack(i) => NetlinkPayload::Ack(i), + NetlinkPayload::Noop => NetlinkPayload::Noop, + NetlinkPayload::Overrun(i) => NetlinkPayload::Overrun(i), + }; + Ok(NetlinkMessage::new(raw_msg.header, payload)) +} diff --git a/genetlink/src/resolver.rs b/genetlink/src/resolver.rs new file mode 100644 index 00000000..7c151492 --- /dev/null +++ b/genetlink/src/resolver.rs @@ -0,0 +1,151 @@ +use crate::{error::GenetlinkError, GenetlinkHandle}; +use futures::{future::Either, StreamExt}; +use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST}; +use netlink_packet_generic::{ + ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + GenlMessage, +}; +use std::{collections::HashMap, future::Future}; + +#[derive(Clone, Debug, Default)] +pub struct Resolver { + cache: HashMap<&'static str, u16>, +} + +impl Resolver { + pub fn new() -> Self { + Self { + cache: HashMap::new(), + } + } + + pub fn get_cache_by_name(&self, family_name: &str) -> Option { + self.cache.get(family_name).copied() + } + + pub fn query_family_id( + &mut self, + handle: &GenetlinkHandle, + family_name: &'static str, + ) -> impl Future> + '_ { + if let Some(id) = self.get_cache_by_name(family_name) { + Either::Left(futures::future::ready(Ok(id))) + } else { + let mut handle = handle.clone(); + Either::Right(async move { + let mut genlmsg: GenlMessage = GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyName(family_name.to_owned())], + }); + genlmsg.finalize(); + // We don't have to set family id here, since nlctrl has static family id (0x10) + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST; + nlmsg.finalize(); + + let mut res = handle.send_request(nlmsg)?; + + while let Some(result) = res.next().await { + let rx_packet = result?; + match rx_packet.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + let family_id = genlmsg + .payload + .nlas + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::FamilyId(id) = nla { + Some(*id) + } else { + None + } + }) + .ok_or_else(|| { + GenetlinkError::AttributeNotFound( + "CTRL_ATTR_FAMILY_ID".to_owned(), + ) + })?; + + self.cache.insert(family_name, family_id); + return Ok(family_id); + } + NetlinkPayload::Error(e) => return Err(e.into()), + _ => (), + } + } + + Err(GenetlinkError::NoMessageReceived) + }) + } + } + + pub fn clear_cache(&mut self) { + self.cache.clear(); + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::new_connection; + use std::io::ErrorKind; + + #[tokio::test] + async fn test_resolver_nlctrl() { + let (conn, handle, _) = new_connection().unwrap(); + tokio::spawn(conn); + + let mut resolver = Resolver::new(); + // nlctrl should always be 0x10 + let nlctrl_fid = resolver.query_family_id(&handle, "nlctrl").await.unwrap(); + assert_eq!(nlctrl_fid, 0x10); + } + + const TEST_FAMILIES: &[&str] = &[ + "devlink", + "ethtool", + "acpi_event", + "tcp_metrics", + "TASKSTATS", + "nl80211", + ]; + + #[tokio::test] + async fn test_resolver_cache() { + let (conn, handle, _) = new_connection().unwrap(); + tokio::spawn(conn); + + let mut resolver = Resolver::new(); + + // Test if family id cached + for name in TEST_FAMILIES.iter().copied() { + let id = resolver + .query_family_id(&handle, name) + .await + .or_else(|e| { + if let GenetlinkError::NetlinkError(io_err) = &e { + if io_err.kind() == ErrorKind::NotFound { + // Ignore non exist entries + Ok(0) + } else { + Err(e) + } + } else { + Err(e) + } + }) + .unwrap(); + if id == 0 { + eprintln!( + "Generic family \"{}\" not exist or not loaded in this environment. Ignored.", + name + ); + continue; + } + + let cache = resolver.get_cache_by_name(name).unwrap(); + assert_eq!(id, cache); + eprintln!("{:?}", (name, cache)); + } + } +} diff --git a/netlink-packet-generic/Cargo.toml b/netlink-packet-generic/Cargo.toml new file mode 100644 index 00000000..b92a5d95 --- /dev/null +++ b/netlink-packet-generic/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "netlink-packet-generic" +version = "0.1.0" +authors = ["Leo "] +edition = "2018" +homepage = "https://github.com/little-dude/netlink" +repository = "https://github.com/little-dude/netlink" +keywords = ["netlink", "linux"] +license = "MIT" +readme = "../README.md" +description = "generic netlink packet types" + +[dependencies] +anyhow = "1.0.39" +libc = "0.2.86" +byteorder = "1.4.2" +netlink-packet-core = "0.2" +netlink-packet-utils = "0.4" + +[dev-dependencies] +netlink-sys = { path = "../netlink-sys", version = "0.7" } diff --git a/netlink-packet-generic/LICENSE-MIT b/netlink-packet-generic/LICENSE-MIT new file mode 120000 index 00000000..76219eb7 --- /dev/null +++ b/netlink-packet-generic/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/netlink-packet-generic/examples/list_generic_family.rs b/netlink-packet-generic/examples/list_generic_family.rs new file mode 100644 index 00000000..00cb55a4 --- /dev/null +++ b/netlink-packet-generic/examples/list_generic_family.rs @@ -0,0 +1,111 @@ +use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_DUMP, NLM_F_REQUEST}; +use netlink_packet_generic::{ + ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + GenlMessage, +}; +use netlink_sys::{protocols::NETLINK_GENERIC, Socket, SocketAddr}; + +fn main() { + let mut socket = Socket::new(NETLINK_GENERIC).unwrap(); + socket.bind_auto().unwrap(); + socket.connect(&SocketAddr::new(0, 0)).unwrap(); + + let mut genlmsg = GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![], + }); + genlmsg.finalize(); + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST | NLM_F_DUMP; + nlmsg.finalize(); + + let mut txbuf = vec![0u8; nlmsg.buffer_len()]; + nlmsg.serialize(&mut txbuf); + + socket.send(&txbuf, 0).unwrap(); + + let mut rxbuf = vec![0u8; 4096]; + let mut offset = 0; + + 'outer: loop { + let size = socket.recv(&mut rxbuf, 0).unwrap(); + + loop { + let buf = &rxbuf[offset..]; + // Parse the message + let msg = >>::deserialize(buf).unwrap(); + + match msg.payload { + NetlinkPayload::Done => break 'outer, + NetlinkPayload::InnerMessage(genlmsg) => { + if GenlCtrlCmd::NewFamily == genlmsg.payload.cmd { + print_entry(genlmsg.payload.nlas); + } + } + NetlinkPayload::Error(err) => { + eprintln!("Received a netlink error message: {:?}", err); + return; + } + _ => {} + } + + offset += msg.header.length as usize; + if offset == size || msg.header.length == 0 { + offset = 0; + break; + } + } + } +} + +fn print_entry(entry: Vec) { + let family_id = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::FamilyId(id) = nla { + Some(*id) + } else { + None + } + }) + .expect("Cannot find FamilyId attribute"); + let family_name = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::FamilyName(name) = nla { + Some(name.as_str()) + } else { + None + } + }) + .expect("Cannot find FamilyName attribute"); + let version = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::Version(ver) = nla { + Some(*ver) + } else { + None + } + }) + .expect("Cannot find Version attribute"); + let hdrsize = entry + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::HdrSize(hdr) = nla { + Some(*hdr) + } else { + None + } + }) + .expect("Cannot find HdrSize attribute"); + + if hdrsize == 0 { + println!("0x{:04x} {} [Version {}]", family_id, family_name, version); + } else { + println!( + "0x{:04x} {} [Version {}] [Header {} bytes]", + family_id, family_name, version, hdrsize + ); + } +} diff --git a/netlink-packet-generic/src/buffer.rs b/netlink-packet-generic/src/buffer.rs new file mode 100644 index 00000000..7bd29d24 --- /dev/null +++ b/netlink-packet-generic/src/buffer.rs @@ -0,0 +1,37 @@ +//! Buffer definition of generic netlink packet +use crate::{constants::GENL_HDRLEN, header::GenlHeader, message::GenlMessage}; +use netlink_packet_core::DecodeError; +use netlink_packet_utils::{Parseable, ParseableParametrized}; +use std::fmt::Debug; + +buffer!(GenlBuffer(GENL_HDRLEN) { + cmd: (u8, 0), + version: (u8, 1), + payload: (slice, GENL_HDRLEN..), +}); + +impl ParseableParametrized<[u8], u16> for GenlMessage +where + F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, +{ + fn parse_with_param(buf: &[u8], message_type: u16) -> Result { + let buf = GenlBuffer::new_checked(buf)?; + Self::parse_with_param(&buf, message_type) + } +} + +impl<'a, F, T> ParseableParametrized, u16> for GenlMessage +where + F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, + T: AsRef<[u8]> + ?Sized, +{ + fn parse_with_param(buf: &GenlBuffer<&'a T>, message_type: u16) -> Result { + let header = GenlHeader::parse(buf)?; + let payload_buf = buf.payload(); + Ok(GenlMessage::new( + header, + F::parse_with_param(payload_buf, header)?, + message_type, + )) + } +} diff --git a/netlink-packet-generic/src/constants.rs b/netlink-packet-generic/src/constants.rs new file mode 100644 index 00000000..03e85748 --- /dev/null +++ b/netlink-packet-generic/src/constants.rs @@ -0,0 +1,70 @@ +//! Define constants related to generic netlink +pub const GENL_ID_CTRL: u16 = libc::GENL_ID_CTRL as u16; +pub const GENL_HDRLEN: usize = 4; + +pub const CTRL_CMD_UNSPEC: u8 = libc::CTRL_CMD_UNSPEC as u8; +pub const CTRL_CMD_NEWFAMILY: u8 = libc::CTRL_CMD_NEWFAMILY as u8; +pub const CTRL_CMD_DELFAMILY: u8 = libc::CTRL_CMD_DELFAMILY as u8; +pub const CTRL_CMD_GETFAMILY: u8 = libc::CTRL_CMD_GETFAMILY as u8; +pub const CTRL_CMD_NEWOPS: u8 = libc::CTRL_CMD_NEWOPS as u8; +pub const CTRL_CMD_DELOPS: u8 = libc::CTRL_CMD_DELOPS as u8; +pub const CTRL_CMD_GETOPS: u8 = libc::CTRL_CMD_GETOPS as u8; +pub const CTRL_CMD_NEWMCAST_GRP: u8 = libc::CTRL_CMD_NEWMCAST_GRP as u8; +pub const CTRL_CMD_DELMCAST_GRP: u8 = libc::CTRL_CMD_DELMCAST_GRP as u8; +pub const CTRL_CMD_GETMCAST_GRP: u8 = libc::CTRL_CMD_GETMCAST_GRP as u8; +pub const CTRL_CMD_GETPOLICY: u8 = 10; + +pub const CTRL_ATTR_UNSPEC: u16 = libc::CTRL_ATTR_UNSPEC as u16; +pub const CTRL_ATTR_FAMILY_ID: u16 = libc::CTRL_ATTR_FAMILY_ID as u16; +pub const CTRL_ATTR_FAMILY_NAME: u16 = libc::CTRL_ATTR_FAMILY_NAME as u16; +pub const CTRL_ATTR_VERSION: u16 = libc::CTRL_ATTR_VERSION as u16; +pub const CTRL_ATTR_HDRSIZE: u16 = libc::CTRL_ATTR_HDRSIZE as u16; +pub const CTRL_ATTR_MAXATTR: u16 = libc::CTRL_ATTR_MAXATTR as u16; +pub const CTRL_ATTR_OPS: u16 = libc::CTRL_ATTR_OPS as u16; +pub const CTRL_ATTR_MCAST_GROUPS: u16 = libc::CTRL_ATTR_MCAST_GROUPS as u16; +pub const CTRL_ATTR_POLICY: u16 = 8; +pub const CTRL_ATTR_OP_POLICY: u16 = 9; +pub const CTRL_ATTR_OP: u16 = 10; + +pub const CTRL_ATTR_OP_UNSPEC: u16 = libc::CTRL_ATTR_OP_UNSPEC as u16; +pub const CTRL_ATTR_OP_ID: u16 = libc::CTRL_ATTR_OP_ID as u16; +pub const CTRL_ATTR_OP_FLAGS: u16 = libc::CTRL_ATTR_OP_FLAGS as u16; + +pub const CTRL_ATTR_MCAST_GRP_UNSPEC: u16 = libc::CTRL_ATTR_MCAST_GRP_UNSPEC as u16; +pub const CTRL_ATTR_MCAST_GRP_NAME: u16 = libc::CTRL_ATTR_MCAST_GRP_NAME as u16; +pub const CTRL_ATTR_MCAST_GRP_ID: u16 = libc::CTRL_ATTR_MCAST_GRP_ID as u16; + +pub const CTRL_ATTR_POLICY_UNSPEC: u16 = 0; +pub const CTRL_ATTR_POLICY_DO: u16 = 1; +pub const CTRL_ATTR_POLICY_DUMP: u16 = 2; + +pub const NL_ATTR_TYPE_INVALID: u32 = 0; +pub const NL_ATTR_TYPE_FLAG: u32 = 1; +pub const NL_ATTR_TYPE_U8: u32 = 2; +pub const NL_ATTR_TYPE_U16: u32 = 3; +pub const NL_ATTR_TYPE_U32: u32 = 4; +pub const NL_ATTR_TYPE_U64: u32 = 5; +pub const NL_ATTR_TYPE_S8: u32 = 6; +pub const NL_ATTR_TYPE_S16: u32 = 7; +pub const NL_ATTR_TYPE_S32: u32 = 8; +pub const NL_ATTR_TYPE_S64: u32 = 9; +pub const NL_ATTR_TYPE_BINARY: u32 = 10; +pub const NL_ATTR_TYPE_STRING: u32 = 11; +pub const NL_ATTR_TYPE_NUL_STRING: u32 = 12; +pub const NL_ATTR_TYPE_NESTED: u32 = 13; +pub const NL_ATTR_TYPE_NESTED_ARRAY: u32 = 14; +pub const NL_ATTR_TYPE_BITFIELD32: u32 = 15; + +pub const NL_POLICY_TYPE_ATTR_UNSPEC: u16 = 0; +pub const NL_POLICY_TYPE_ATTR_TYPE: u16 = 1; +pub const NL_POLICY_TYPE_ATTR_MIN_VALUE_S: u16 = 2; +pub const NL_POLICY_TYPE_ATTR_MAX_VALUE_S: u16 = 3; +pub const NL_POLICY_TYPE_ATTR_MIN_VALUE_U: u16 = 4; +pub const NL_POLICY_TYPE_ATTR_MAX_VALUE_U: u16 = 5; +pub const NL_POLICY_TYPE_ATTR_MIN_LENGTH: u16 = 6; +pub const NL_POLICY_TYPE_ATTR_MAX_LENGTH: u16 = 7; +pub const NL_POLICY_TYPE_ATTR_POLICY_IDX: u16 = 8; +pub const NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE: u16 = 9; +pub const NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: u16 = 10; +pub const NL_POLICY_TYPE_ATTR_PAD: u16 = 11; +pub const NL_POLICY_TYPE_ATTR_MASK: u16 = 12; diff --git a/netlink-packet-generic/src/ctrl/mod.rs b/netlink-packet-generic/src/ctrl/mod.rs new file mode 100644 index 00000000..71b0132a --- /dev/null +++ b/netlink-packet-generic/src/ctrl/mod.rs @@ -0,0 +1,137 @@ +//! Generic netlink controller implementation +//! +//! This module provides the definition of the controller packet. +//! It also serves as an example for creating a generic family. + +use self::nlas::*; +use crate::{constants::*, traits::*, GenlHeader}; +use anyhow::Context; +use netlink_packet_utils::{nla::NlasIterator, traits::*, DecodeError}; +use std::convert::{TryFrom, TryInto}; + +/// Netlink attributes for this family +pub mod nlas; + +/// Command code definition of Netlink controller (nlctrl) family +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum GenlCtrlCmd { + /// Notify from event + NewFamily, + /// Notify from event + DelFamily, + /// Request to get family info + GetFamily, + /// Currently unused + NewOps, + /// Currently unused + DelOps, + /// Currently unused + GetOps, + /// Notify from event + NewMcastGrp, + /// Notify from event + DelMcastGrp, + /// Currently unused + GetMcastGrp, + /// Request to get family policy + GetPolicy, +} + +impl From for u8 { + fn from(cmd: GenlCtrlCmd) -> u8 { + use GenlCtrlCmd::*; + match cmd { + NewFamily => CTRL_CMD_NEWFAMILY, + DelFamily => CTRL_CMD_DELFAMILY, + GetFamily => CTRL_CMD_GETFAMILY, + NewOps => CTRL_CMD_NEWOPS, + DelOps => CTRL_CMD_DELOPS, + GetOps => CTRL_CMD_GETOPS, + NewMcastGrp => CTRL_CMD_NEWMCAST_GRP, + DelMcastGrp => CTRL_CMD_DELMCAST_GRP, + GetMcastGrp => CTRL_CMD_GETMCAST_GRP, + GetPolicy => CTRL_CMD_GETPOLICY, + } + } +} + +impl TryFrom for GenlCtrlCmd { + type Error = DecodeError; + + fn try_from(value: u8) -> Result { + use GenlCtrlCmd::*; + Ok(match value { + CTRL_CMD_NEWFAMILY => NewFamily, + CTRL_CMD_DELFAMILY => DelFamily, + CTRL_CMD_GETFAMILY => GetFamily, + CTRL_CMD_NEWOPS => NewOps, + CTRL_CMD_DELOPS => DelOps, + CTRL_CMD_GETOPS => GetOps, + CTRL_CMD_NEWMCAST_GRP => NewMcastGrp, + CTRL_CMD_DELMCAST_GRP => DelMcastGrp, + CTRL_CMD_GETMCAST_GRP => GetMcastGrp, + CTRL_CMD_GETPOLICY => GetPolicy, + cmd => { + return Err(DecodeError::from(format!( + "Unknown control command: {}", + cmd + ))) + } + }) + } +} + +/// Payload of generic netlink controller +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GenlCtrl { + /// Command code of this message + pub cmd: GenlCtrlCmd, + /// Netlink attributes in this message + pub nlas: Vec, +} + +impl GenlFamily for GenlCtrl { + fn family_name() -> &'static str { + "nlctrl" + } + + fn family_id(&self) -> u16 { + GENL_ID_CTRL + } + + fn command(&self) -> u8 { + self.cmd.into() + } + + fn version(&self) -> u8 { + 2 + } +} + +impl Emitable for GenlCtrl { + fn emit(&self, buffer: &mut [u8]) { + self.nlas.as_slice().emit(buffer) + } + + fn buffer_len(&self) -> usize { + self.nlas.as_slice().buffer_len() + } +} + +impl ParseableParametrized<[u8], GenlHeader> for GenlCtrl { + fn parse_with_param(buf: &[u8], header: GenlHeader) -> Result { + Ok(Self { + cmd: header.cmd.try_into()?, + nlas: parse_ctrlnlas(buf)?, + }) + } +} + +fn parse_ctrlnlas(buf: &[u8]) -> Result, DecodeError> { + let nlas = NlasIterator::new(buf) + .map(|nla| nla.and_then(|nla| GenlCtrlAttrs::parse(&nla))) + .collect::, _>>() + .context("failed to parse control message attributes")?; + + Ok(nlas) +} diff --git a/netlink-packet-generic/src/ctrl/nlas/mcast.rs b/netlink-packet-generic/src/ctrl/nlas/mcast.rs new file mode 100644 index 00000000..0ddee4f2 --- /dev/null +++ b/netlink-packet-generic/src/ctrl/nlas/mcast.rs @@ -0,0 +1,60 @@ +use crate::constants::*; +use anyhow::Context; +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer}, + parsers::*, + traits::*, + DecodeError, +}; +use std::mem::size_of_val; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum McastGrpAttrs { + Name(String), + Id(u32), +} + +impl Nla for McastGrpAttrs { + fn value_len(&self) -> usize { + use McastGrpAttrs::*; + match self { + Name(s) => s.as_bytes().len() + 1, + Id(v) => size_of_val(v), + } + } + + fn kind(&self) -> u16 { + use McastGrpAttrs::*; + match self { + Name(_) => CTRL_ATTR_MCAST_GRP_NAME, + Id(_) => CTRL_ATTR_MCAST_GRP_ID, + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + use McastGrpAttrs::*; + match self { + Name(s) => { + buffer[..s.len()].copy_from_slice(s.as_bytes()); + buffer[s.len()] = 0; + } + Id(v) => NativeEndian::write_u32(buffer, *v), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for McastGrpAttrs { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + Ok(match buf.kind() { + CTRL_ATTR_MCAST_GRP_NAME => { + Self::Name(parse_string(payload).context("invalid CTRL_ATTR_MCAST_GRP_NAME value")?) + } + CTRL_ATTR_MCAST_GRP_ID => { + Self::Id(parse_u32(payload).context("invalid CTRL_ATTR_MCAST_GRP_ID value")?) + } + kind => return Err(DecodeError::from(format!("Unknown NLA type: {}", kind))), + }) + } +} diff --git a/netlink-packet-generic/src/ctrl/nlas/mod.rs b/netlink-packet-generic/src/ctrl/nlas/mod.rs new file mode 100644 index 00000000..cf5c6d44 --- /dev/null +++ b/netlink-packet-generic/src/ctrl/nlas/mod.rs @@ -0,0 +1,160 @@ +use crate::constants::*; +use anyhow::Context; +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NlasIterator}, + parsers::*, + traits::*, + DecodeError, +}; +use std::mem::size_of_val; + +mod mcast; +mod oppolicy; +mod ops; +mod policy; + +pub use mcast::*; +pub use oppolicy::*; +pub use ops::*; +pub use policy::*; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum GenlCtrlAttrs { + FamilyId(u16), + FamilyName(String), + Version(u32), + HdrSize(u32), + MaxAttr(u32), + Ops(Vec>), + McastGroups(Vec>), + Policy(PolicyAttr), + OpPolicy(OppolicyAttr), + Op(u32), +} + +impl Nla for GenlCtrlAttrs { + fn value_len(&self) -> usize { + use GenlCtrlAttrs::*; + match self { + FamilyId(v) => size_of_val(v), + FamilyName(s) => s.len() + 1, + Version(v) => size_of_val(v), + HdrSize(v) => size_of_val(v), + MaxAttr(v) => size_of_val(v), + Ops(nlas) => nlas.iter().map(|op| op.as_slice().buffer_len()).sum(), + McastGroups(nlas) => nlas.iter().map(|op| op.as_slice().buffer_len()).sum(), + Policy(nla) => nla.buffer_len(), + OpPolicy(nla) => nla.buffer_len(), + Op(v) => size_of_val(v), + } + } + + fn kind(&self) -> u16 { + use GenlCtrlAttrs::*; + match self { + FamilyId(_) => CTRL_ATTR_FAMILY_ID, + FamilyName(_) => CTRL_ATTR_FAMILY_NAME, + Version(_) => CTRL_ATTR_VERSION, + HdrSize(_) => CTRL_ATTR_HDRSIZE, + MaxAttr(_) => CTRL_ATTR_MAXATTR, + Ops(_) => CTRL_ATTR_OPS, + McastGroups(_) => CTRL_ATTR_MCAST_GROUPS, + Policy(_) => CTRL_ATTR_POLICY, + OpPolicy(_) => CTRL_ATTR_OP_POLICY, + Op(_) => CTRL_ATTR_OP, + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + use GenlCtrlAttrs::*; + match self { + FamilyId(v) => NativeEndian::write_u16(buffer, *v), + FamilyName(s) => { + buffer[..s.len()].copy_from_slice(s.as_bytes()); + buffer[s.len()] = 0; + } + Version(v) => NativeEndian::write_u32(buffer, *v), + HdrSize(v) => NativeEndian::write_u32(buffer, *v), + MaxAttr(v) => NativeEndian::write_u32(buffer, *v), + Ops(nlas) => { + let mut len = 0; + for op in nlas { + op.as_slice().emit(&mut buffer[len..]); + len += op.as_slice().buffer_len(); + } + } + McastGroups(nlas) => { + let mut len = 0; + for op in nlas { + op.as_slice().emit(&mut buffer[len..]); + len += op.as_slice().buffer_len(); + } + } + Policy(nla) => nla.emit_value(buffer), + OpPolicy(nla) => nla.emit_value(buffer), + Op(v) => NativeEndian::write_u32(buffer, *v), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for GenlCtrlAttrs { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + Ok(match buf.kind() { + CTRL_ATTR_FAMILY_ID => { + Self::FamilyId(parse_u16(payload).context("invalid CTRL_ATTR_FAMILY_ID value")?) + } + CTRL_ATTR_FAMILY_NAME => Self::FamilyName( + parse_string(payload).context("invalid CTRL_ATTR_FAMILY_NAME value")?, + ), + CTRL_ATTR_VERSION => { + Self::Version(parse_u32(payload).context("invalid CTRL_ATTR_VERSION value")?) + } + CTRL_ATTR_HDRSIZE => { + Self::HdrSize(parse_u32(payload).context("invalid CTRL_ATTR_HDRSIZE value")?) + } + CTRL_ATTR_MAXATTR => { + Self::MaxAttr(parse_u32(payload).context("invalid CTRL_ATTR_MAXATTR value")?) + } + CTRL_ATTR_OPS => { + let ops = NlasIterator::new(payload) + .map(|nlas| { + nlas.and_then(|nlas| { + NlasIterator::new(nlas.value()) + .map(|nla| nla.and_then(|nla| OpAttrs::parse(&nla))) + .collect::, _>>() + }) + }) + .collect::>, _>>() + .context("failed to parse CTRL_ATTR_OPS")?; + + Self::Ops(ops) + } + CTRL_ATTR_MCAST_GROUPS => { + let groups = NlasIterator::new(payload) + .map(|nlas| { + nlas.and_then(|nlas| { + NlasIterator::new(nlas.value()) + .map(|nla| nla.and_then(|nla| McastGrpAttrs::parse(&nla))) + .collect::, _>>() + }) + }) + .collect::>, _>>() + .context("failed to parse CTRL_ATTR_MCAST_GROUPS")?; + + Self::McastGroups(groups) + } + CTRL_ATTR_POLICY => Self::Policy( + PolicyAttr::parse(&NlaBuffer::new(payload)) + .context("failed to parse CTRL_ATTR_POLICY")?, + ), + CTRL_ATTR_OP_POLICY => Self::OpPolicy( + OppolicyAttr::parse(&NlaBuffer::new(payload)) + .context("failed to parse CTRL_ATTR_OP_POLICY")?, + ), + CTRL_ATTR_OP => Self::Op(parse_u32(payload)?), + kind => return Err(DecodeError::from(format!("Unknown NLA type: {}", kind))), + }) + } +} diff --git a/netlink-packet-generic/src/ctrl/nlas/oppolicy.rs b/netlink-packet-generic/src/ctrl/nlas/oppolicy.rs new file mode 100644 index 00000000..a19c8d1c --- /dev/null +++ b/netlink-packet-generic/src/ctrl/nlas/oppolicy.rs @@ -0,0 +1,96 @@ +use crate::constants::*; +use anyhow::Context; +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NlasIterator}, + parsers::*, + traits::*, + DecodeError, +}; +use std::mem::size_of_val; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct OppolicyAttr { + pub cmd: u8, + pub policy_idx: Vec, +} + +impl Nla for OppolicyAttr { + fn value_len(&self) -> usize { + self.policy_idx.as_slice().buffer_len() + } + + fn kind(&self) -> u16 { + self.cmd as u16 + } + + fn emit_value(&self, buffer: &mut [u8]) { + self.policy_idx.as_slice().emit(buffer); + } + + fn is_nested(&self) -> bool { + true + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for OppolicyAttr { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + let policy_idx = NlasIterator::new(payload) + .map(|nla| nla.and_then(|nla| OppolicyIndexAttr::parse(&nla))) + .collect::, _>>() + .context("failed to parse OppolicyAttr")?; + + Ok(Self { + cmd: buf.kind() as u8, + policy_idx, + }) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OppolicyIndexAttr { + Do(u32), + Dump(u32), +} + +impl Nla for OppolicyIndexAttr { + fn value_len(&self) -> usize { + use OppolicyIndexAttr::*; + match self { + Do(v) => size_of_val(v), + Dump(v) => size_of_val(v), + } + } + + fn kind(&self) -> u16 { + use OppolicyIndexAttr::*; + match self { + Do(_) => CTRL_ATTR_POLICY_DO, + Dump(_) => CTRL_ATTR_POLICY_DUMP, + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + use OppolicyIndexAttr::*; + match self { + Do(v) => NativeEndian::write_u32(buffer, *v), + Dump(v) => NativeEndian::write_u32(buffer, *v), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for OppolicyIndexAttr { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + Ok(match buf.kind() { + CTRL_ATTR_POLICY_DO => { + Self::Do(parse_u32(payload).context("invalid CTRL_ATTR_POLICY_DO value")?) + } + CTRL_ATTR_POLICY_DUMP => { + Self::Dump(parse_u32(payload).context("invalid CTRL_ATTR_POLICY_DUMP value")?) + } + kind => return Err(DecodeError::from(format!("Unknown NLA type: {}", kind))), + }) + } +} diff --git a/netlink-packet-generic/src/ctrl/nlas/ops.rs b/netlink-packet-generic/src/ctrl/nlas/ops.rs new file mode 100644 index 00000000..03191f2a --- /dev/null +++ b/netlink-packet-generic/src/ctrl/nlas/ops.rs @@ -0,0 +1,57 @@ +use crate::constants::*; +use anyhow::Context; +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer}, + parsers::*, + traits::*, + DecodeError, +}; +use std::mem::size_of_val; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum OpAttrs { + Id(u32), + Flags(u32), +} + +impl Nla for OpAttrs { + fn value_len(&self) -> usize { + use OpAttrs::*; + match self { + Id(v) => size_of_val(v), + Flags(v) => size_of_val(v), + } + } + + fn kind(&self) -> u16 { + use OpAttrs::*; + match self { + Id(_) => CTRL_ATTR_OP_ID, + Flags(_) => CTRL_ATTR_OP_FLAGS, + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + use OpAttrs::*; + match self { + Id(v) => NativeEndian::write_u32(buffer, *v), + Flags(v) => NativeEndian::write_u32(buffer, *v), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for OpAttrs { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + Ok(match buf.kind() { + CTRL_ATTR_OP_ID => { + Self::Id(parse_u32(payload).context("invalid CTRL_ATTR_OP_ID value")?) + } + CTRL_ATTR_OP_FLAGS => { + Self::Flags(parse_u32(payload).context("invalid CTRL_ATTR_OP_FLAGS value")?) + } + kind => return Err(DecodeError::from(format!("Unknown NLA type: {}", kind))), + }) + } +} diff --git a/netlink-packet-generic/src/ctrl/nlas/policy.rs b/netlink-packet-generic/src/ctrl/nlas/policy.rs new file mode 100644 index 00000000..4fafd85f --- /dev/null +++ b/netlink-packet-generic/src/ctrl/nlas/policy.rs @@ -0,0 +1,279 @@ +use crate::constants::*; +use anyhow::Context; +use byteorder::{ByteOrder, NativeEndian}; +use netlink_packet_utils::{ + nla::{Nla, NlaBuffer, NlasIterator}, + parsers::*, + traits::*, + DecodeError, +}; +use std::{ + convert::TryFrom, + mem::{size_of, size_of_val}, +}; + +// PolicyAttr + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PolicyAttr { + pub index: u16, + pub attr_policy: AttributePolicyAttr, +} + +impl Nla for PolicyAttr { + fn value_len(&self) -> usize { + self.attr_policy.buffer_len() + } + + fn kind(&self) -> u16 { + self.index + } + + fn emit_value(&self, buffer: &mut [u8]) { + self.attr_policy.emit(buffer); + } + + fn is_nested(&self) -> bool { + true + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for PolicyAttr { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + + Ok(Self { + index: buf.kind(), + attr_policy: AttributePolicyAttr::parse(&NlaBuffer::new(payload)) + .context("failed to parse PolicyAttr")?, + }) + } +} + +// AttributePolicyAttr + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct AttributePolicyAttr { + pub index: u16, + pub policies: Vec, +} + +impl Nla for AttributePolicyAttr { + fn value_len(&self) -> usize { + self.policies.as_slice().buffer_len() + } + + fn kind(&self) -> u16 { + self.index + } + + fn emit_value(&self, buffer: &mut [u8]) { + self.policies.as_slice().emit(buffer); + } + + fn is_nested(&self) -> bool { + true + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for AttributePolicyAttr { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + let policies = NlasIterator::new(payload) + .map(|nla| nla.and_then(|nla| NlPolicyTypeAttrs::parse(&nla))) + .collect::, _>>() + .context("failed to parse AttributePolicyAttr")?; + + Ok(Self { + index: buf.kind(), + policies, + }) + } +} + +// PolicyTypeAttrs + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum NlPolicyTypeAttrs { + Type(NlaType), + MinValueSigned(i64), + MaxValueSigned(i64), + MaxValueUnsigned(u64), + MinValueUnsigned(u64), + MinLength(u32), + MaxLength(u32), + PolicyIdx(u32), + PolicyMaxType(u32), + Bitfield32Mask(u32), + Mask(u64), +} + +impl Nla for NlPolicyTypeAttrs { + fn value_len(&self) -> usize { + use NlPolicyTypeAttrs::*; + match self { + Type(v) => size_of_val(v), + MinValueSigned(v) => size_of_val(v), + MaxValueSigned(v) => size_of_val(v), + MaxValueUnsigned(v) => size_of_val(v), + MinValueUnsigned(v) => size_of_val(v), + MinLength(v) => size_of_val(v), + MaxLength(v) => size_of_val(v), + PolicyIdx(v) => size_of_val(v), + PolicyMaxType(v) => size_of_val(v), + Bitfield32Mask(v) => size_of_val(v), + Mask(v) => size_of_val(v), + } + } + + fn kind(&self) -> u16 { + use NlPolicyTypeAttrs::*; + match self { + Type(_) => NL_POLICY_TYPE_ATTR_TYPE, + MinValueSigned(_) => NL_POLICY_TYPE_ATTR_MIN_VALUE_S, + MaxValueSigned(_) => NL_POLICY_TYPE_ATTR_MAX_VALUE_S, + MaxValueUnsigned(_) => NL_POLICY_TYPE_ATTR_MIN_VALUE_U, + MinValueUnsigned(_) => NL_POLICY_TYPE_ATTR_MAX_VALUE_U, + MinLength(_) => NL_POLICY_TYPE_ATTR_MIN_LENGTH, + MaxLength(_) => NL_POLICY_TYPE_ATTR_MAX_LENGTH, + PolicyIdx(_) => NL_POLICY_TYPE_ATTR_POLICY_IDX, + PolicyMaxType(_) => NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE, + Bitfield32Mask(_) => NL_POLICY_TYPE_ATTR_BITFIELD32_MASK, + Mask(_) => NL_POLICY_TYPE_ATTR_MASK, + } + } + + fn emit_value(&self, buffer: &mut [u8]) { + use NlPolicyTypeAttrs::*; + match self { + Type(v) => NativeEndian::write_u32(buffer, u32::from(*v)), + MinValueSigned(v) => NativeEndian::write_i64(buffer, *v), + MaxValueSigned(v) => NativeEndian::write_i64(buffer, *v), + MaxValueUnsigned(v) => NativeEndian::write_u64(buffer, *v), + MinValueUnsigned(v) => NativeEndian::write_u64(buffer, *v), + MinLength(v) => NativeEndian::write_u32(buffer, *v), + MaxLength(v) => NativeEndian::write_u32(buffer, *v), + PolicyIdx(v) => NativeEndian::write_u32(buffer, *v), + PolicyMaxType(v) => NativeEndian::write_u32(buffer, *v), + Bitfield32Mask(v) => NativeEndian::write_u32(buffer, *v), + Mask(v) => NativeEndian::write_u64(buffer, *v), + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Parseable> for NlPolicyTypeAttrs { + fn parse(buf: &NlaBuffer<&'a T>) -> Result { + let payload = buf.value(); + Ok(match buf.kind() { + NL_POLICY_TYPE_ATTR_TYPE => { + let value = parse_u32(payload).context("invalid NL_POLICY_TYPE_ATTR_TYPE value")?; + Self::Type(NlaType::try_from(value)?) + } + NL_POLICY_TYPE_ATTR_MIN_VALUE_S => Self::MinValueSigned( + parse_i64(payload).context("invalid NL_POLICY_TYPE_ATTR_MIN_VALUE_S value")?, + ), + NL_POLICY_TYPE_ATTR_MAX_VALUE_S => Self::MaxValueSigned( + parse_i64(payload).context("invalid NL_POLICY_TYPE_ATTR_MAX_VALUE_S value")?, + ), + NL_POLICY_TYPE_ATTR_MIN_VALUE_U => Self::MinValueUnsigned( + parse_u64(payload).context("invalid NL_POLICY_TYPE_ATTR_MIN_VALUE_U value")?, + ), + NL_POLICY_TYPE_ATTR_MAX_VALUE_U => Self::MaxValueUnsigned( + parse_u64(payload).context("invalid NL_POLICY_TYPE_ATTR_MAX_VALUE_U value")?, + ), + NL_POLICY_TYPE_ATTR_MIN_LENGTH => Self::MinLength( + parse_u32(payload).context("invalid NL_POLICY_TYPE_ATTR_MIN_LENGTH value")?, + ), + NL_POLICY_TYPE_ATTR_MAX_LENGTH => Self::MaxLength( + parse_u32(payload).context("invalid NL_POLICY_TYPE_ATTR_MAX_LENGTH value")?, + ), + NL_POLICY_TYPE_ATTR_POLICY_IDX => Self::PolicyIdx( + parse_u32(payload).context("invalid NL_POLICY_TYPE_ATTR_POLICY_IDX value")?, + ), + NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE => Self::PolicyMaxType( + parse_u32(payload).context("invalid NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE value")?, + ), + NL_POLICY_TYPE_ATTR_BITFIELD32_MASK => Self::Bitfield32Mask( + parse_u32(payload).context("invalid NL_POLICY_TYPE_ATTR_BITFIELD32_MASK value")?, + ), + NL_POLICY_TYPE_ATTR_MASK => { + Self::Mask(parse_u64(payload).context("invalid NL_POLICY_TYPE_ATTR_MASK value")?) + } + kind => return Err(DecodeError::from(format!("Unknown NLA type: {}", kind))), + }) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum NlaType { + Flag, + U8, + U16, + U32, + U64, + S8, + S16, + S32, + S64, + Binary, + String, + NulString, + Nested, + NestedArray, + Bitfield32, +} + +impl From for u32 { + fn from(nlatype: NlaType) -> u32 { + match nlatype { + NlaType::Flag => NL_ATTR_TYPE_FLAG, + NlaType::U8 => NL_ATTR_TYPE_U8, + NlaType::U16 => NL_ATTR_TYPE_U16, + NlaType::U32 => NL_ATTR_TYPE_U32, + NlaType::U64 => NL_ATTR_TYPE_U64, + NlaType::S8 => NL_ATTR_TYPE_S8, + NlaType::S16 => NL_ATTR_TYPE_S16, + NlaType::S32 => NL_ATTR_TYPE_S32, + NlaType::S64 => NL_ATTR_TYPE_S64, + NlaType::Binary => NL_ATTR_TYPE_BINARY, + NlaType::String => NL_ATTR_TYPE_STRING, + NlaType::NulString => NL_ATTR_TYPE_NUL_STRING, + NlaType::Nested => NL_ATTR_TYPE_NESTED, + NlaType::NestedArray => NL_ATTR_TYPE_NESTED_ARRAY, + NlaType::Bitfield32 => NL_ATTR_TYPE_BITFIELD32, + } + } +} + +impl TryFrom for NlaType { + type Error = DecodeError; + + fn try_from(value: u32) -> Result { + Ok(match value { + NL_ATTR_TYPE_FLAG => NlaType::Flag, + NL_ATTR_TYPE_U8 => NlaType::U8, + NL_ATTR_TYPE_U16 => NlaType::U16, + NL_ATTR_TYPE_U32 => NlaType::U32, + NL_ATTR_TYPE_U64 => NlaType::U64, + NL_ATTR_TYPE_S8 => NlaType::S8, + NL_ATTR_TYPE_S16 => NlaType::S16, + NL_ATTR_TYPE_S32 => NlaType::S32, + NL_ATTR_TYPE_S64 => NlaType::S64, + NL_ATTR_TYPE_BINARY => NlaType::Binary, + NL_ATTR_TYPE_STRING => NlaType::String, + NL_ATTR_TYPE_NUL_STRING => NlaType::NulString, + NL_ATTR_TYPE_NESTED => NlaType::Nested, + NL_ATTR_TYPE_NESTED_ARRAY => NlaType::NestedArray, + NL_ATTR_TYPE_BITFIELD32 => NlaType::Bitfield32, + _ => return Err(DecodeError::from(format!("invalid NLA type: {}", value))), + }) + } +} + +// FIXME: Add this into netlink_packet_utils::parser +fn parse_i64(payload: &[u8]) -> Result { + if payload.len() != size_of::() { + return Err(format!("invalid i64: {:?}", payload).into()); + } + Ok(NativeEndian::read_i64(payload)) +} diff --git a/netlink-packet-generic/src/header.rs b/netlink-packet-generic/src/header.rs new file mode 100644 index 00000000..b031f4a8 --- /dev/null +++ b/netlink-packet-generic/src/header.rs @@ -0,0 +1,32 @@ +//! header definition of generic netlink packet +use crate::{buffer::GenlBuffer, constants::GENL_HDRLEN}; +use netlink_packet_core::DecodeError; +use netlink_packet_utils::{Emitable, Parseable}; + +/// Generic Netlink header +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct GenlHeader { + pub cmd: u8, + pub version: u8, +} + +impl Emitable for GenlHeader { + fn buffer_len(&self) -> usize { + GENL_HDRLEN + } + + fn emit(&self, buffer: &mut [u8]) { + let mut packet = GenlBuffer::new(buffer); + packet.set_cmd(self.cmd); + packet.set_version(self.version); + } +} + +impl> Parseable> for GenlHeader { + fn parse(buf: &GenlBuffer) -> Result { + Ok(Self { + cmd: buf.cmd(), + version: buf.version(), + }) + } +} diff --git a/netlink-packet-generic/src/lib.rs b/netlink-packet-generic/src/lib.rs new file mode 100644 index 00000000..19ec8fc8 --- /dev/null +++ b/netlink-packet-generic/src/lib.rs @@ -0,0 +1,81 @@ +//! This crate provides the packet of generic netlink family and its controller. +//! +//! The `[GenlMessage]` provides a generic netlink family message which is +//! sub-protocol independant. +//! You can wrap your message into the type, then it can be used in `netlink-proto` crate. +//! +//! # Implementing a generic netlink family +//! A generic netlink family contains several commands, and a version number in +//! the header. +//! +//! The payload usually consists of netlink attributes, carrying the messages to +//! the peer. In order to help you to make your payload into a valid netlink +//! packet, this crate requires the informations about the family id, +//! and the informations in the generic header. So, you need to implement some +//! traits on your types. +//! +//! All the things in the payload including all netlink attributes used +//! and the optional header should be handled by your implementation. +//! +//! ## Serializaion / Deserialization +//! To implement your generic netlink family, you should handle the payload +//! serialization process including its specific header (if any) and the netlink +//! attributes. +//! +//! To achieve this, you should implement [`netlink_packet_utils::Emitable`] +//! trait for the payload type. +//! +//! For deserialization, [`netlink_packet_utils::ParseableParametrized<[u8], GenlHeader>`](netlink_packet_utils::ParseableParametrized) +//! trait should be implemented. As mention above, to provide more scalability, +//! we use the simplest buffer type: `[u8]` here. You can turn it into other +//! buffer type easily during deserializing. +//! +//! ## `GenlFamily` trait +//! The trait is aim to provide some necessary informations in order to build +//! the packet headers of netlink (nlmsghdr) and generic netlink (genlmsghdr). +//! +//! ### `family_name()` +//! The method let the resolver to obtain the name registered in the kernel. +//! +//! ### `family_id()` +//! Few netlink family has static family ID (e.g. controller). The method is +//! mainly used to let those family to return their familt ID. +//! +//! If you don't know what is this, please **DO NOT** implement this method. +//! Since the default implementation return `GENL_ID_GENERATE`, which means +//! the family ID is allocated by the kernel dynamically. +//! +//! ### `command()` +//! This method tells the generic netlink command id of the packet +//! The return value is used to fill the `cmd` field in the generic netlink header. +//! +//! ### `version()` +//! This method return the family version of the payload. +//! The return value is used to fill the `version` field in the generic netlink header. +//! +//! ## Family Header +//! Few family would use a family specific message header. For simplification +//! and scalability, this crate treats it as a part of the payload, and make +//! implementations to handle the header by themselves. +//! +//! If you are implementing such a generic family, note that you should define +//! the header data structure in your payload type and handle the serialization. + +#[macro_use] +extern crate netlink_packet_utils; + +pub mod buffer; +pub use self::buffer::GenlBuffer; + +pub mod constants; + +pub mod ctrl; + +pub mod header; +pub use self::header::GenlHeader; + +pub mod message; +pub use self::message::GenlMessage; + +pub mod traits; +pub use self::traits::GenlFamily; diff --git a/netlink-packet-generic/src/message.rs b/netlink-packet-generic/src/message.rs new file mode 100644 index 00000000..3d0bfde6 --- /dev/null +++ b/netlink-packet-generic/src/message.rs @@ -0,0 +1,184 @@ +//! Message definition and method implementations + +use crate::{buffer::GenlBuffer, header::GenlHeader, traits::*}; +use netlink_packet_core::{ + DecodeError, + NetlinkDeserializable, + NetlinkHeader, + NetlinkPayload, + NetlinkSerializable, +}; +use netlink_packet_utils::{Emitable, ParseableParametrized}; +use std::fmt::Debug; + +#[cfg(doc)] +use netlink_packet_core::NetlinkMessage; + +/// Represent the generic netlink messages +/// +/// This type can wrap data types `F` which represents a generic family payload. +/// The message can be serialize/deserialize if the type `F` implements [`GenlFamily`], +/// [`Emitable`], and [`ParseableParametrized<[u8], GenlHeader>`](ParseableParametrized). +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct GenlMessage +where + F: Clone + Debug + PartialEq + Eq, +{ + pub header: GenlHeader, + pub payload: F, + resolved_family_id: u16, +} + +impl GenlMessage +where + F: Clone + Debug + PartialEq + Eq, +{ + /// Construct the message + pub fn new(header: GenlHeader, payload: F, family_id: u16) -> Self { + Self { + header, + payload, + resolved_family_id: family_id, + } + } + + /// Construct the message by the given header and payload + pub fn from_parts(header: GenlHeader, payload: F) -> Self { + Self { + header, + payload, + resolved_family_id: 0, + } + } + + /// Consume this message and return its header and payload + pub fn into_parts(self) -> (GenlHeader, F) { + (self.header, self.payload) + } + + /// Return the previously set resolved family ID in this message. + /// + /// This value would be used to serialize the message only if + /// the ([`GenlFamily::family_id()`]) return 0 in the underlying type. + pub fn resolved_family_id(&self) -> u16 { + self.resolved_family_id + } + + /// Set the resolved dynamic family ID of the message, if the generic family + /// uses dynamic generated ID by kernel. + /// + /// This method is a interface to provide other high level library to + /// set the resolved family ID before the message is serialized. + /// + /// # Usage + /// Normally, you don't have to call this function directly if you are + /// using library which helps you handle the dynamic family id. + /// + /// If you are the developer of some high level generic netlink library, + /// you can call this method to set the family id resolved by your resolver. + /// Without having to modify the `message_type` field of the serialized + /// netlink packet header before sending it. + pub fn set_resolved_family_id(&mut self, family_id: u16) { + self.resolved_family_id = family_id; + } +} + +impl GenlMessage +where + F: GenlFamily + Clone + Debug + PartialEq + Eq, +{ + /// Build the message from the payload + /// + /// This function would automatically fill the header for you. You can directly emit + /// the message without having to call [`finalize()`](Self::finalize). + pub fn from_payload(payload: F) -> Self { + Self { + header: GenlHeader { + cmd: payload.command(), + version: payload.version(), + }, + payload, + resolved_family_id: 0, + } + } + + /// Ensure the header ([`GenlHeader`]) is consistent with the payload (`F: GenlFamily`): + /// + /// - Fill the command and version number into the header + /// + /// If you are not 100% sure the header is correct, this method should be called before calling + /// [`Emitable::emit()`], as it could get error result if the header is inconsistent with the message. + pub fn finalize(&mut self) { + self.header.cmd = self.payload.command(); + self.header.version = self.payload.version(); + } + + /// Return the resolved family ID which should be filled into the `message_type` + /// field in [`NetlinkHeader`]. + /// + /// The implementation of [`NetlinkSerializable::message_type()`] would use + /// this function's result as its the return value. Thus, the family id can + /// be automatically filled into the `message_type` during the call to + /// [`NetlinkMessage::finalize()`]. + pub fn family_id(&self) -> u16 { + let static_id = self.payload.family_id(); + if static_id == 0 { + self.resolved_family_id + } else { + static_id + } + } +} + +impl Emitable for GenlMessage +where + F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, +{ + fn buffer_len(&self) -> usize { + self.header.buffer_len() + self.payload.buffer_len() + } + + fn emit(&self, buffer: &mut [u8]) { + self.header.emit(buffer); + + let buffer = &mut buffer[self.header.buffer_len()..]; + self.payload.emit(buffer); + } +} + +impl NetlinkSerializable> for GenlMessage +where + F: GenlFamily + Emitable + Clone + Debug + PartialEq + Eq, +{ + fn message_type(&self) -> u16 { + self.family_id() + } + + fn buffer_len(&self) -> usize { + ::buffer_len(self) + } + + fn serialize(&self, buffer: &mut [u8]) { + self.emit(buffer) + } +} + +impl<'a, F> NetlinkDeserializable> for GenlMessage +where + F: ParseableParametrized<[u8], GenlHeader> + Clone + Debug + PartialEq + Eq, +{ + type Error = DecodeError; + fn deserialize(header: &NetlinkHeader, payload: &[u8]) -> Result { + let buffer = GenlBuffer::new_checked(payload)?; + GenlMessage::parse_with_param(&buffer, header.message_type) + } +} + +impl From> for NetlinkPayload> +where + F: Clone + Debug + PartialEq + Eq, +{ + fn from(message: GenlMessage) -> Self { + NetlinkPayload::InnerMessage(message) + } +} diff --git a/netlink-packet-generic/src/traits.rs b/netlink-packet-generic/src/traits.rs new file mode 100644 index 00000000..1110fa3d --- /dev/null +++ b/netlink-packet-generic/src/traits.rs @@ -0,0 +1,33 @@ +//! Traits for implementing generic netlink family + +/// Provide the definition for generic netlink family +/// +/// Family payload type should implement this trait to provide necessary +/// informations in order to build the packet headers (`nlmsghdr` and `genlmsghdr`). +/// +/// If you are looking for an example implementation, you can refer to the +/// [`crate::ctrl`] module. +pub trait GenlFamily { + /// Return the unique family name registered in the kernel + /// + /// Let the resolver lookup the dynamically assigned ID + fn family_name() -> &'static str; + + /// Return the assigned family ID + /// + /// # Note + /// The implementation of generic family should assign the ID to `GENL_ID_GENERATE` (0x0). + /// So the controller can dynamically assign the family ID. + /// + /// Regarding to the reason above, you should not have to implement the function + /// unless the family uses static ID. + fn family_id(&self) -> u16 { + 0 + } + + /// Return the command type of the current message + fn command(&self) -> u8; + + /// Indicate the protocol version + fn version(&self) -> u8; +} diff --git a/netlink-packet-generic/tests/query_family_id.rs b/netlink-packet-generic/tests/query_family_id.rs new file mode 100644 index 00000000..a74eadf4 --- /dev/null +++ b/netlink-packet-generic/tests/query_family_id.rs @@ -0,0 +1,55 @@ +use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST}; +use netlink_packet_generic::{ + ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + GenlMessage, +}; +use netlink_sys::{protocols::NETLINK_GENERIC, Socket, SocketAddr}; + +#[test] +fn query_family_id() { + let mut socket = Socket::new(NETLINK_GENERIC).unwrap(); + socket.bind_auto().unwrap(); + socket.connect(&SocketAddr::new(0, 0)).unwrap(); + + let mut genlmsg = GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyName("nlctrl".to_owned())], + }); + genlmsg.finalize(); + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST; + nlmsg.finalize(); + + println!("Buffer length: {}", nlmsg.buffer_len()); + let mut txbuf = vec![0u8; nlmsg.buffer_len()]; + nlmsg.serialize(&mut txbuf); + + socket.send(&txbuf, 0).unwrap(); + + let mut rxbuf = vec![0u8; 2048]; + socket.recv(&mut rxbuf, 0).unwrap(); + let rx_packet = >>::deserialize(&rxbuf).unwrap(); + + if let NetlinkPayload::InnerMessage(genlmsg) = rx_packet.payload { + if GenlCtrlCmd::NewFamily == genlmsg.payload.cmd { + let family_id = genlmsg + .payload + .nlas + .iter() + .find_map(|nla| { + if let GenlCtrlAttrs::FamilyId(id) = nla { + Some(*id) + } else { + None + } + }) + .expect("Cannot find FamilyId attribute"); + // nlctrl's family must be 0x10 + assert_eq!(0x10, family_id); + } else { + panic!("Invalid payload type: {:?}", genlmsg.payload.cmd); + } + } else { + panic!("Failed to get family ID"); + } +}