diff --git a/audit/examples/events.rs b/audit/examples/events.rs index fb84fb5f..131db6e6 100644 --- a/audit/examples/events.rs +++ b/audit/examples/events.rs @@ -8,14 +8,14 @@ use futures::stream::StreamExt; #[tokio::main] async fn main() -> Result<(), String> { - let (connection, mut handle, mut messages) = new_connection().map_err(|e| format!("{}", e))?; + let (connection, mut handle, mut events) = new_connection().map_err(|e| format!("{}", e))?; tokio::spawn(connection); handle.enable_events().await.map_err(|e| format!("{}", e))?; env_logger::init(); - while let Some((msg, _)) = messages.next().await { - println!("{:?}", msg); + while let Some(event) = events.next().await { + println!("{event:?}"); } Ok(()) } diff --git a/audit/src/lib.rs b/audit/src/lib.rs index fd34c3a5..d1b1d6a8 100644 --- a/audit/src/lib.rs +++ b/audit/src/lib.rs @@ -21,10 +21,12 @@ use futures::channel::mpsc::UnboundedReceiver; pub fn new_connection() -> io::Result<( proto::Connection, Handle, - UnboundedReceiver<( - packet::NetlinkMessage, - sys::SocketAddr, - )>, + UnboundedReceiver< + packet::NetlinkEvent<( + packet::NetlinkMessage, + sys::SocketAddr, + )>, + >, )> { new_connection_with_socket() } @@ -33,10 +35,12 @@ pub fn new_connection() -> io::Result<( pub fn new_connection_with_socket() -> io::Result<( proto::Connection, Handle, - UnboundedReceiver<( - packet::NetlinkMessage, - sys::SocketAddr, - )>, + UnboundedReceiver< + packet::NetlinkEvent<( + packet::NetlinkMessage, + sys::SocketAddr, + )>, + >, )> where S: sys::AsyncSocket, diff --git a/ethtool/src/connection.rs b/ethtool/src/connection.rs index 8987d991..ee624e2a 100644 --- a/ethtool/src/connection.rs +++ b/ethtool/src/connection.rs @@ -4,7 +4,7 @@ use std::io; use futures::channel::mpsc::UnboundedReceiver; use genetlink::message::RawGenlMessage; -use netlink_packet_core::NetlinkMessage; +use netlink_packet_core::{NetlinkEvent, NetlinkMessage}; use netlink_proto::Connection; use netlink_sys::{AsyncSocket, SocketAddr}; @@ -15,7 +15,7 @@ use crate::EthtoolHandle; pub fn new_connection() -> io::Result<( Connection, EthtoolHandle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> { new_connection_with_socket() } @@ -24,7 +24,7 @@ pub fn new_connection() -> io::Result<( pub fn new_connection_with_socket() -> io::Result<( Connection, EthtoolHandle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> where S: AsyncSocket, diff --git a/genetlink/src/connection.rs b/genetlink/src/connection.rs index 72ce3942..c78b9c59 100644 --- a/genetlink/src/connection.rs +++ b/genetlink/src/connection.rs @@ -2,7 +2,7 @@ use crate::{message::RawGenlMessage, GenetlinkHandle}; use futures::channel::mpsc::UnboundedReceiver; -use netlink_packet_core::NetlinkMessage; +use netlink_packet_core::{NetlinkEvent, NetlinkMessage}; use netlink_proto::{ self, sys::{protocols::NETLINK_GENERIC, AsyncSocket, SocketAddr}, @@ -28,7 +28,7 @@ use std::io; pub fn new_connection() -> io::Result<( Connection, GenetlinkHandle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> { new_connection_with_socket() } @@ -38,7 +38,7 @@ pub fn new_connection() -> io::Result<( pub fn new_connection_with_socket() -> io::Result<( Connection, GenetlinkHandle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> where S: AsyncSocket, diff --git a/mptcp-pm/src/connection.rs b/mptcp-pm/src/connection.rs index 774663ba..ca7a10dc 100644 --- a/mptcp-pm/src/connection.rs +++ b/mptcp-pm/src/connection.rs @@ -4,7 +4,7 @@ use std::io; use futures::channel::mpsc::UnboundedReceiver; use genetlink::message::RawGenlMessage; -use netlink_packet_core::NetlinkMessage; +use netlink_packet_core::{NetlinkEvent, NetlinkMessage}; use netlink_proto::Connection; use netlink_sys::{AsyncSocket, SocketAddr}; @@ -15,7 +15,7 @@ use crate::MptcpPathManagerHandle; pub fn new_connection() -> io::Result<( Connection, MptcpPathManagerHandle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> { new_connection_with_socket() } @@ -24,7 +24,7 @@ pub fn new_connection() -> io::Result<( pub fn new_connection_with_socket() -> io::Result<( Connection, MptcpPathManagerHandle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> where S: AsyncSocket, diff --git a/netlink-packet-audit/src/lib.rs b/netlink-packet-audit/src/lib.rs index 189e26cc..6542c240 100644 --- a/netlink-packet-audit/src/lib.rs +++ b/netlink-packet-audit/src/lib.rs @@ -8,6 +8,7 @@ pub use self::utils::{traits, DecodeError}; pub use netlink_packet_core::{ ErrorMessage, NetlinkBuffer, + NetlinkEvent, NetlinkHeader, NetlinkMessage, NetlinkPayload, diff --git a/netlink-packet-core/src/message.rs b/netlink-packet-core/src/message.rs index b9fcd2d0..537b60d0 100644 --- a/netlink-packet-core/src/message.rs +++ b/netlink-packet-core/src/message.rs @@ -18,6 +18,15 @@ use crate::{ Parseable, }; +/// Represent a Netlink event +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum NetlinkEvent { + /// An actual message was received from Netlink + Message(M), + /// The socket receive buffer filled up + Overrun, +} + /// Represent a netlink message. #[derive(Debug, PartialEq, Eq, Clone)] pub struct NetlinkMessage { diff --git a/netlink-packet-route/src/lib.rs b/netlink-packet-route/src/lib.rs index 43fc7e58..866b8060 100644 --- a/netlink-packet-route/src/lib.rs +++ b/netlink-packet-route/src/lib.rs @@ -10,6 +10,7 @@ pub use self::utils::{traits, DecodeError}; pub use netlink_packet_core::{ ErrorMessage, NetlinkBuffer, + NetlinkEvent, NetlinkHeader, NetlinkMessage, NetlinkPayload, diff --git a/netlink-proto/examples/audit_netlink_events.rs b/netlink-proto/examples/audit_netlink_events.rs index 7c87bd1a..c12328ed 100644 --- a/netlink-proto/examples/audit_netlink_events.rs +++ b/netlink-proto/examples/audit_netlink_events.rs @@ -33,6 +33,7 @@ use std::process; use netlink_proto::{ new_connection, + packet::NetlinkEvent, sys::{protocols::NETLINK_AUDIT, SocketAddr}, }; @@ -50,11 +51,11 @@ async fn main() -> Result<(), String> { // - `handle` is a `Handle` to the `Connection`. We use it to send // netlink messages and receive responses to these messages. // - // - `messages` is a channel receiver through which we receive + // - `events` is a channel receiver through which we receive // messages that we have not sollicated, ie that are not // response to a request we made. In this example, we'll receive // the audit event through that channel. - let (conn, mut handle, mut messages) = new_connection(NETLINK_AUDIT) + let (conn, mut handle, mut events) = new_connection(NETLINK_AUDIT) .map_err(|e| format!("Failed to create a new netlink connection: {}", e))?; // Spawn the `Connection` so that it starts polling the netlink @@ -91,13 +92,18 @@ async fn main() -> Result<(), String> { } }); - // Finally, start receiving event through the `messages` channel. + // Finally, start receiving event through the `events` channel. println!("Starting to print audit events... press ^C to interrupt"); - while let Some((message, _addr)) = messages.next().await { - if let NetlinkPayload::Error(err_message) = message.payload { - eprintln!("received an error message: {:?}", err_message); - } else { - println!("{:?}", message); + while let Some(event) = events.next().await { + match event { + NetlinkEvent::Message((message, _addr)) => { + if let NetlinkPayload::Error(err_message) = message.payload { + eprintln!("received an error message: {:?}", err_message); + } else { + println!("{:?}", message); + } + } + NetlinkEvent::Overrun => println!("Netlink socket overrun. Some messages were lost"), } } diff --git a/netlink-proto/src/connection.rs b/netlink-proto/src/connection.rs index df6d9dff..a9521d39 100644 --- a/netlink-proto/src/connection.rs +++ b/netlink-proto/src/connection.rs @@ -16,6 +16,7 @@ use futures::{ use log::{error, warn}; use netlink_packet_core::{ NetlinkDeserializable, + NetlinkEvent, NetlinkMessage, NetlinkPayload, NetlinkSerializable, @@ -52,7 +53,7 @@ where /// Channel used to transmit to the ConnectionHandle the unsolicited messages received from the /// socket (multicast messages for instance). - unsolicited_messages_tx: Option, SocketAddr)>>, + unsolicited_messages_tx: Option, SocketAddr)>>>, socket_closed: bool, } @@ -65,7 +66,7 @@ where { pub(crate) fn new( requests_rx: UnboundedReceiver>, - unsolicited_messages_tx: UnboundedSender<(NetlinkMessage, SocketAddr)>, + unsolicited_messages_tx: UnboundedSender, SocketAddr)>>, protocol: isize, ) -> io::Result { let socket = S::new(protocol)?; @@ -131,10 +132,14 @@ where loop { trace!("polling socket"); match socket.as_mut().poll_next(cx) { - Poll::Ready(Some((message, addr))) => { + Poll::Ready(Some(NetlinkEvent::Message((message, addr)))) => { trace!("read datagram from socket"); self.protocol.handle_message(message, addr); } + Poll::Ready(Some(NetlinkEvent::Overrun)) => { + warn!("netlink socket buffer full"); + self.protocol.handle_buffer_full(); + } Poll::Ready(None) => { warn!("netlink socket stream shut down"); self.socket_closed = true; @@ -165,11 +170,13 @@ where pub fn forward_unsolicited_messages(&mut self) { if self.unsolicited_messages_tx.is_none() { - while let Some((message, source)) = self.protocol.incoming_requests.pop_front() { - warn!( - "ignoring unsolicited message {:?} from {:?}", - message, source - ); + while let Some(event) = self.protocol.incoming_requests.pop_front() { + match event { + NetlinkEvent::Message((message, source)) => { + warn!("ignoring unsolicited message {message:?} from {source:?}") + } + NetlinkEvent::Overrun => warn!("ignoring unsolicited socket overrun"), + } } return; } @@ -183,11 +190,11 @@ where .. } = self; - while let Some((message, source)) = protocol.incoming_requests.pop_front() { + while let Some(event) = protocol.incoming_requests.pop_front() { if unsolicited_messages_tx .as_mut() .unwrap() - .unbounded_send((message, source)) + .unbounded_send(event) .is_err() { // The channel is unbounded so the only error that can diff --git a/netlink-proto/src/framed.rs b/netlink-proto/src/framed.rs index 5d9d3011..c4c1d8d3 100644 --- a/netlink-proto/src/framed.rs +++ b/netlink-proto/src/framed.rs @@ -16,7 +16,12 @@ use crate::{ codecs::NetlinkMessageCodec, sys::{AsyncSocket, SocketAddr}, }; -use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable}; +use netlink_packet_core::{ + NetlinkDeserializable, + NetlinkEvent, + NetlinkMessage, + NetlinkSerializable, +}; pub struct NetlinkFramed { socket: S, @@ -38,7 +43,7 @@ where S: AsyncSocket, C: NetlinkMessageCodec, { - type Item = (NetlinkMessage, SocketAddr); + type Item = NetlinkEvent<(NetlinkMessage, SocketAddr)>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Self { @@ -50,7 +55,9 @@ where loop { match C::decode::(reader) { - Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))), + Ok(Some(item)) => { + return Poll::Ready(Some(NetlinkEvent::Message((item, *in_addr)))) + } Ok(None) => {} Err(e) => { error!("unrecoverable error in decoder: {:?}", e); @@ -63,6 +70,23 @@ where *in_addr = match ready!(socket.poll_recv_from(cx, reader)) { Ok(addr) => addr, + // When receiving messages in multicast mode (i.e. we subscribed to + // notifications), the kernel will not wait for us to read datagrams before + // sending more. The receive buffer has a finite size, so once it is full (no + // more message can fit in), new messages will be dropped and recv calls will + // return `ENOBUFS`. + // This needs to be handled for applications to resynchronize with the contents + // of the kernel if necessary. + // We don't need to do anything special: + // - contents of the reader is still valid because we won't have partial messages + // in there anyways (large enough buffer) + // - contents of the socket's internal buffer is still valid because the kernel + // won't put partial data in it + Err(e) if e.raw_os_error() == Some(105) => { + // ENOBUFS + warn!("netlink socket buffer full"); + return Poll::Ready(Some(NetlinkEvent::Overrun)); + } Err(e) => { error!("failed to read from netlink socket: {:?}", e); return Poll::Ready(None); diff --git a/netlink-proto/src/lib.rs b/netlink-proto/src/lib.rs index ea3e8fe3..a040acd2 100644 --- a/netlink-proto/src/lib.rs +++ b/netlink-proto/src/lib.rs @@ -17,6 +17,7 @@ //! use futures::stream::StreamExt; //! use netlink_packet_audit::{ //! AuditMessage, +//! NetlinkEvent, //! NetlinkMessage, //! NetlinkPayload, //! StatusMessage, @@ -44,11 +45,11 @@ //! // - `handle` is a `Handle` to the `Connection`. We use it to send //! // netlink messages and receive responses to these messages. //! // -//! // - `messages` is a channel receiver through which we receive +//! // - `events` is a channel receiver through which we receive //! // messages that we have not solicited, ie that are not //! // response to a request we made. In this example, we'll receive //! // the audit event through that channel. -//! let (conn, mut handle, mut messages) = new_connection(NETLINK_AUDIT) +//! let (conn, mut handle, mut events) = new_connection(NETLINK_AUDIT) //! .map_err(|e| format!("Failed to create a new netlink connection: {}", e))?; //! //! // Spawn the `Connection` so that it starts polling the netlink @@ -85,13 +86,21 @@ //! } //! }); //! -//! // Finally, start receiving event through the `messages` channel. +//! // Finally, start receiving event through the `events` channel. //! println!("Starting to print audit events... press ^C to interrupt"); -//! while let Some((message, _addr)) = messages.next().await { -//! if let NetlinkPayload::Error(err_message) = message.payload { -//! eprintln!("received an error message: {:?}", err_message); -//! } else { -//! println!("{:?}", message); +//! while let Some(event) = events.next().await { +//! match event { +//! NetlinkEvent::Message((message, _addr)) => { +//! if let NetlinkPayload::Error(err_message) = message.payload { +//! eprintln!("received an error message: {:?}", err_message); +//! } else { +//! println!("{:?}", message); +//! } +//! } +//! // Netlink sockets have a finite receive buffer that can fill up if there are more +//! // messages sent by the kernel than we can read. +//! // In this case at least one message has been lost. +//! NetlinkEvent::Overrun => println!("Netlink socket overrun. Some messages were lost"), //! } //! } //! @@ -227,7 +236,7 @@ pub fn new_connection( ) -> io::Result<( Connection, ConnectionHandle, - UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, + UnboundedReceiver, sys::SocketAddr)>>, )> where T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, @@ -242,7 +251,7 @@ pub fn new_connection_with_socket( ) -> io::Result<( Connection, ConnectionHandle, - UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, + UnboundedReceiver, sys::SocketAddr)>>, )> where T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, @@ -258,7 +267,7 @@ pub fn new_connection_with_codec( ) -> io::Result<( Connection, ConnectionHandle, - UnboundedReceiver<(packet::NetlinkMessage, sys::SocketAddr)>, + UnboundedReceiver, sys::SocketAddr)>>, )> where T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin, @@ -266,7 +275,8 @@ where C: NetlinkMessageCodec, { let (requests_tx, requests_rx) = unbounded::>(); - let (messages_tx, messages_rx) = unbounded::<(packet::NetlinkMessage, sys::SocketAddr)>(); + let (messages_tx, messages_rx) = + unbounded::, sys::SocketAddr)>>(); Ok(( Connection::new(requests_rx, messages_tx, protocol)?, ConnectionHandle::new(requests_tx), diff --git a/netlink-proto/src/protocol/protocol.rs b/netlink-proto/src/protocol/protocol.rs index 43b7891f..3a0d20f9 100644 --- a/netlink-proto/src/protocol/protocol.rs +++ b/netlink-proto/src/protocol/protocol.rs @@ -8,6 +8,7 @@ use std::{ use netlink_packet_core::{ constants::*, NetlinkDeserializable, + NetlinkEvent, NetlinkMessage, NetlinkPayload, NetlinkSerializable, @@ -57,7 +58,7 @@ pub(crate) struct Protocol { pub incoming_responses: VecDeque>, /// Requests from remote peers - pub incoming_requests: VecDeque<(NetlinkMessage, SocketAddr)>, + pub incoming_requests: VecDeque, SocketAddr)>>, /// The messages to be sent out pub outgoing_messages: VecDeque<(NetlinkMessage, SocketAddr)>, @@ -84,10 +85,15 @@ where if let hash_map::Entry::Occupied(entry) = self.pending_requests.entry(request_id) { Self::handle_response(&mut self.incoming_responses, entry, message); } else { - self.incoming_requests.push_back((message, source)); + self.incoming_requests + .push_back(NetlinkEvent::Message((message, source))); } } + pub fn handle_buffer_full(&mut self) { + self.incoming_requests.push_back(NetlinkEvent::Overrun); + } + fn handle_response( incoming_responses: &mut VecDeque>, entry: hash_map::OccupiedEntry>, diff --git a/netlink-sys/src/socket.rs b/netlink-sys/src/socket.rs index f3c04a92..ad306898 100644 --- a/netlink-sys/src/socket.rs +++ b/netlink-sys/src/socket.rs @@ -81,7 +81,7 @@ impl Socket { let res = unsafe { libc::socket( libc::PF_NETLINK, - libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, + libc::SOCK_RAW | libc::SOCK_CLOEXEC, protocol as libc::c_int, ) }; @@ -446,6 +446,18 @@ impl Socket { let res = getsockopt::(self.0, libc::SOL_NETLINK, libc::NETLINK_CAP_ACK)?; Ok(res == 1) } + + pub fn set_buf_size(&mut self, tx: i32, rx: i32) -> Result<()> { + setsockopt(self.0, libc::SOL_SOCKET, libc::SO_SNDBUF, tx as libc::c_int)?; + setsockopt(self.0, libc::SOL_SOCKET, libc::SO_RCVBUF, rx as libc::c_int)?; + Ok(()) + } + + pub fn get_buf_size(&self) -> Result<(i32, i32)> { + let tx = getsockopt::(self.0, libc::SOL_SOCKET, libc::SO_SNDBUF)?; + let rx = getsockopt::(self.0, libc::SOL_SOCKET, libc::SO_RCVBUF)?; + Ok((tx, rx)) + } } /// Wrapper around `getsockopt`: diff --git a/rtnetlink/examples/ip_monitor.rs b/rtnetlink/examples/ip_monitor.rs index 740a961d..7e797ad5 100644 --- a/rtnetlink/examples/ip_monitor.rs +++ b/rtnetlink/examples/ip_monitor.rs @@ -3,6 +3,7 @@ use futures::stream::StreamExt; use netlink_packet_route::constants::*; +use netlink_proto::packet::NetlinkEvent; use rtnetlink::{ new_connection, sys::{AsyncSocket, SocketAddr}, @@ -25,8 +26,8 @@ async fn main() -> Result<(), String> { // // handle - `Handle` to the `Connection`. Used to send/recv netlink messages. // - // messages - A channel receiver. - let (mut conn, mut _handle, mut messages) = new_connection().map_err(|e| format!("{}", e))?; + // events - A channel receiver. + let (mut conn, mut _handle, mut events) = new_connection().map_err(|e| format!("{}", e))?; // These flags specify what kinds of broadcast messages we want to listen for. let groups = nl_mgrp(RTNLGRP_LINK) @@ -59,10 +60,15 @@ async fn main() -> Result<(), String> { // Create message to enable }); - // Start receiving events through `messages` channel. - while let Some((message, _)) = messages.next().await { - let payload = message.payload; - println!("{:?}", payload); + // Start receiving events through `events` channel. + while let Some(event) = events.next().await { + match event { + NetlinkEvent::Message((message, _)) => { + let payload = message.payload; + println!("Route change message - {:?}", payload); + } + NetlinkEvent::Overrun => println!("Netlink socket overrun. Some messages were lost"), + } } Ok(()) } diff --git a/rtnetlink/examples/listen.rs b/rtnetlink/examples/listen.rs index 9a99dca1..1f1fd2ea 100644 --- a/rtnetlink/examples/listen.rs +++ b/rtnetlink/examples/listen.rs @@ -5,6 +5,7 @@ use futures::stream::StreamExt; +use netlink_proto::packet::NetlinkEvent; use rtnetlink::{ constants::{RTMGRP_IPV4_ROUTE, RTMGRP_IPV6_ROUTE}, new_connection, @@ -14,7 +15,7 @@ use rtnetlink::{ #[tokio::main] async fn main() -> Result<(), String> { // Open the netlink socket - let (mut connection, _, mut messages) = new_connection().map_err(|e| format!("{}", e))?; + let (mut connection, _, mut events) = new_connection().map_err(|e| format!("{}", e))?; // These flags specify what kinds of broadcast messages we want to listen for. let mgroup_flags = RTMGRP_IPV4_ROUTE | RTMGRP_IPV6_ROUTE; @@ -29,9 +30,14 @@ async fn main() -> Result<(), String> { .expect("failed to bind"); tokio::spawn(connection); - while let Some((message, _)) = messages.next().await { - let payload = message.payload; - println!("Route change message - {:?}", payload); + while let Some(event) = events.next().await { + match event { + NetlinkEvent::Message((message, _)) => { + let payload = message.payload; + println!("Route change message - {:?}", payload); + } + NetlinkEvent::Overrun => println!("Netlink socket overrun. Some messages were lost"), + } } Ok(()) } diff --git a/rtnetlink/src/connection.rs b/rtnetlink/src/connection.rs index 20c9e7e4..7714d7bb 100644 --- a/rtnetlink/src/connection.rs +++ b/rtnetlink/src/connection.rs @@ -5,7 +5,7 @@ use std::io; use futures::channel::mpsc::UnboundedReceiver; use crate::{ - packet::{NetlinkMessage, RtnlMessage}, + packet::{NetlinkEvent, NetlinkMessage, RtnlMessage}, proto::Connection, sys::{protocols::NETLINK_ROUTE, AsyncSocket, SocketAddr}, Handle, @@ -16,7 +16,7 @@ use crate::{ pub fn new_connection() -> io::Result<( Connection, Handle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> { new_connection_with_socket() } @@ -25,7 +25,7 @@ pub fn new_connection() -> io::Result<( pub fn new_connection_with_socket() -> io::Result<( Connection, Handle, - UnboundedReceiver<(NetlinkMessage, SocketAddr)>, + UnboundedReceiver, SocketAddr)>>, )> where S: AsyncSocket,