Skip to content
This repository has been archived by the owner on Oct 26, 2022. It is now read-only.

Commit

Permalink
Make tokio and smol not conflicting in netlink-proto
Browse files Browse the repository at this point in the history
- This adds a new type paramater to connection for the socket being used.
- can drop tokio/smol selection from netlink-packet-audit
  • Loading branch information
stbuehler committed Nov 14, 2021
1 parent 26a1d7a commit a3241c4
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 36 deletions.
4 changes: 2 additions & 2 deletions audit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ netlink-proto = { default-features = false, version = "0.8" }

[features]
default = ["tokio_socket"]
tokio_socket = ["netlink-proto/tokio_socket", "netlink-packet-audit/tokio_socket"]
smol_socket = ["netlink-proto/smol_socket", "netlink-packet-audit/smol_socket"]
tokio_socket = ["netlink-proto/tokio_socket"]
smol_socket = ["netlink-proto/smol_socket"]

[dev-dependencies]
tokio = { version = "1.0.1", default-features = false, features = ["macros", "rt-multi-thread"] }
Expand Down
18 changes: 17 additions & 1 deletion audit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,30 @@ use std::io;
use futures::channel::mpsc::UnboundedReceiver;

#[allow(clippy::type_complexity)]
#[cfg(feature = "tokio_socket")]
pub fn new_connection() -> io::Result<(
proto::Connection<packet::AuditMessage, packet::NetlinkAuditCodec>,
proto::Connection<packet::AuditMessage, sys::TokioSocket, packet::NetlinkAuditCodec>,
Handle,
UnboundedReceiver<(
packet::NetlinkMessage<packet::AuditMessage>,
sys::SocketAddr,
)>,
)> {
new_connection_with_socket()
}

#[allow(clippy::type_complexity)]
pub fn new_connection_with_socket<S>() -> io::Result<(
proto::Connection<packet::AuditMessage, S, packet::NetlinkAuditCodec>,
Handle,
UnboundedReceiver<(
packet::NetlinkMessage<packet::AuditMessage>,
sys::SocketAddr,
)>,
)>
where
S: sys::AsyncSocket,
{
let (conn, handle, messages) =
netlink_proto::new_connection_with_codec(sys::protocols::NETLINK_AUDIT)?;
Ok((conn, Handle::new(handle), messages))
Expand Down
17 changes: 15 additions & 2 deletions ethtool/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,29 @@ use futures::channel::mpsc::UnboundedReceiver;
use genetlink::message::RawGenlMessage;
use netlink_packet_core::NetlinkMessage;
use netlink_proto::Connection;
use netlink_sys::SocketAddr;
use netlink_sys::{AsyncSocket, SocketAddr};

use crate::EthtoolHandle;

#[cfg(feature = "tokio_socket")]
#[allow(clippy::type_complexity)]
pub fn new_connection() -> io::Result<(
Connection<RawGenlMessage>,
EthtoolHandle,
UnboundedReceiver<(NetlinkMessage<RawGenlMessage>, SocketAddr)>,
)> {
let (conn, handle, messages) = genetlink::new_connection()?;
new_connection_with_socket()
}

#[allow(clippy::type_complexity)]
pub fn new_connection_with_socket<S>() -> io::Result<(
Connection<RawGenlMessage, S>,
EthtoolHandle,
UnboundedReceiver<(NetlinkMessage<RawGenlMessage>, SocketAddr)>,
)>
where
S: AsyncSocket,
{
let (conn, handle, messages) = genetlink::new_connection_with_socket()?;
Ok((conn, EthtoolHandle::new(handle), messages))
}
2 changes: 2 additions & 0 deletions ethtool/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ mod pause;
mod ring;

pub use coalesce::{EthtoolCoalesceAttr, EthtoolCoalesceGetRequest, EthtoolCoalesceHandle};
#[cfg(feature = "tokio_socket")]
pub use connection::new_connection;
pub use connection::new_connection_with_socket;
pub use error::EthtoolError;
pub use feature::{
EthtoolFeatureAttr,
Expand Down
18 changes: 16 additions & 2 deletions genetlink/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use futures::channel::mpsc::UnboundedReceiver;
use netlink_packet_core::NetlinkMessage;
use netlink_proto::{
self,
sys::{protocols::NETLINK_GENERIC, SocketAddr},
sys::{protocols::NETLINK_GENERIC, AsyncSocket, SocketAddr},
Connection,
};
use std::io;
Expand All @@ -21,12 +21,26 @@ use std::io;
///
/// The [`GenetlinkHandle`] can send and receive any type of generic netlink message.
/// And it can automatic resolve the generic family id before sending.
#[cfg(feature = "tokio_socket")]
#[allow(clippy::type_complexity)]
pub fn new_connection() -> io::Result<(
Connection<RawGenlMessage>,
GenetlinkHandle,
UnboundedReceiver<(NetlinkMessage<RawGenlMessage>, SocketAddr)>,
)> {
let (conn, handle, messages) = netlink_proto::new_connection(NETLINK_GENERIC)?;
new_connection_with_socket()
}

/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling
#[allow(clippy::type_complexity)]
pub fn new_connection_with_socket<S>() -> io::Result<(
Connection<RawGenlMessage, S>,
GenetlinkHandle,
UnboundedReceiver<(NetlinkMessage<RawGenlMessage>, SocketAddr)>,
)>
where
S: AsyncSocket,
{
let (conn, handle, messages) = netlink_proto::new_connection_with_socket(NETLINK_GENERIC)?;
Ok((conn, GenetlinkHandle::new(handle), messages))
}
2 changes: 2 additions & 0 deletions genetlink/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ mod handle;
pub mod message;
mod resolver;

#[cfg(feature = "tokio_socket")]
pub use connection::new_connection;
pub use connection::new_connection_with_socket;
pub use error::GenetlinkError;
pub use handle::GenetlinkHandle;
5 changes: 0 additions & 5 deletions netlink-packet-audit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ readme = "../README.md"
repository = "https://github.com/little-dude/netlink"
description = "netlink packet types"

[features]
default = ["tokio_socket"]
tokio_socket = ["netlink-proto/tokio_socket"]
smol_socket = ["netlink-proto/smol_socket"]

[dependencies]
anyhow = "1.0.31"
bytes = "1.0"
Expand Down
21 changes: 14 additions & 7 deletions netlink-proto/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,26 @@ use netlink_packet_core::{
use crate::{
codecs::{NetlinkCodec, NetlinkMessageCodec},
framed::NetlinkFramed,
sys::{AsyncSocket, Socket, SocketAddr},
sys::{AsyncSocket, SocketAddr},
Protocol,
Request,
Response,
};

#[cfg(feature = "tokio_socket")]
use netlink_sys::TokioSocket as DefaultSocket;
#[cfg(not(feature = "tokio_socket"))]
type DefaultSocket = ();

/// Connection to a Netlink socket, running in the background.
///
/// [`ConnectionHandle`](struct.ConnectionHandle.html) are used to pass new requests to the
/// `Connection`, that in turn, sends them through the netlink socket.
pub struct Connection<T, C = NetlinkCodec>
pub struct Connection<T, S = DefaultSocket, C = NetlinkCodec>
where
T: Debug + NetlinkSerializable + NetlinkDeserializable,
{
socket: NetlinkFramed<T, C>,
socket: NetlinkFramed<T, S, C>,

protocol: Protocol<T, UnboundedSender<NetlinkMessage<T>>>,

Expand All @@ -50,17 +55,18 @@ where
socket_closed: bool,
}

impl<T, C> Connection<T, C>
impl<T, S, C> Connection<T, S, C>
where
T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
S: AsyncSocket,
C: NetlinkMessageCodec,
{
pub(crate) fn new(
requests_rx: UnboundedReceiver<Request<T>>,
unsolicited_messages_tx: UnboundedSender<(NetlinkMessage<T>, SocketAddr)>,
protocol: isize,
) -> io::Result<Self> {
let socket = Socket::new(protocol)?;
let socket = S::new(protocol)?;
Ok(Connection {
socket: NetlinkFramed::new(socket),
protocol: Protocol::new(),
Expand All @@ -70,7 +76,7 @@ where
})
}

pub fn socket_mut(&mut self) -> &mut Socket {
pub fn socket_mut(&mut self) -> &mut S {
self.socket.get_mut()
}

Expand Down Expand Up @@ -251,9 +257,10 @@ where
}
}

impl<T, C> Future for Connection<T, C>
impl<T, S, C> Future for Connection<T, S, C>
where
T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
S: AsyncSocket,
C: NetlinkMessageCodec,
{
type Output = ();
Expand Down
22 changes: 12 additions & 10 deletions netlink-proto/src/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ use log::error;

use crate::{
codecs::NetlinkMessageCodec,
sys::{AsyncSocket, Socket, SocketAddr},
sys::{AsyncSocket, SocketAddr},
};
use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable};

pub struct NetlinkFramed<T, C> {
socket: Socket,
pub struct NetlinkFramed<T, S, C> {
socket: S,
msg_type: PhantomData<fn(T) -> T>, // invariant
codec: PhantomData<fn(C) -> C>, // invariant
reader: BytesMut,
Expand All @@ -27,9 +27,10 @@ pub struct NetlinkFramed<T, C> {
flushed: bool,
}

impl<T, C> Stream for NetlinkFramed<T, C>
impl<T, S, C> Stream for NetlinkFramed<T, S, C>
where
T: NetlinkDeserializable + Debug,
S: AsyncSocket,
C: NetlinkMessageCodec,
{
type Item = (NetlinkMessage<T>, SocketAddr);
Expand Down Expand Up @@ -66,9 +67,10 @@ where
}
}

impl<T, C> Sink<(NetlinkMessage<T>, SocketAddr)> for NetlinkFramed<T, C>
impl<T, S, C> Sink<(NetlinkMessage<T>, SocketAddr)> for NetlinkFramed<T, S, C>
where
T: NetlinkSerializable + Debug,
S: AsyncSocket,
C: NetlinkMessageCodec,
{
type Error = io::Error;
Expand Down Expand Up @@ -142,11 +144,11 @@ where
const INITIAL_READER_CAPACITY: usize = 64 * 1024;
const INITIAL_WRITER_CAPACITY: usize = 8 * 1024;

impl<T, C> NetlinkFramed<T, C> {
impl<T, S, C> NetlinkFramed<T, S, C> {
/// Create a new `NetlinkFramed` backed by the given socket and codec.
///
/// See struct level documentation for more details.
pub fn new(socket: Socket) -> Self {
pub fn new(socket: S) -> Self {
Self {
socket,
msg_type: PhantomData,
Expand All @@ -166,7 +168,7 @@ impl<T, C> NetlinkFramed<T, C> {
/// Care should be taken to not tamper with the underlying stream of data
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_ref(&self) -> &Socket {
pub fn get_ref(&self) -> &S {
&self.socket
}

Expand All @@ -178,12 +180,12 @@ impl<T, C> NetlinkFramed<T, C> {
/// Care should be taken to not tamper with the underlying stream of data
/// coming in as it may corrupt the stream of frames otherwise being worked
/// with.
pub fn get_mut(&mut self) -> &mut Socket {
pub fn get_mut(&mut self) -> &mut S {
&mut self.socket
}

/// Consumes the `Framed`, returning its underlying I/O stream.
pub fn into_inner(self) -> Socket {
pub fn into_inner(self) -> S {
self.socket
}
}
28 changes: 23 additions & 5 deletions netlink-proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,10 @@ pub mod sys {
pub use netlink_sys::{protocols, AsyncSocket, AsyncSocketExt, SocketAddr};

#[cfg(feature = "tokio_socket")]
pub use netlink_sys::TokioSocket as Socket;
pub use netlink_sys::TokioSocket;

#[cfg(feature = "smol_socket")]
pub use netlink_sys::SmolSocket as Socket;
pub use netlink_sys::SmolSocket;
}

/// Create a new Netlink connection for the given Netlink protocol, and returns a handle to that
Expand All @@ -218,6 +218,7 @@ pub mod sys {
/// handle to send messages.
///
/// [protos]: crate::sys::protocols
#[cfg(feature = "tokio_socket")]
#[allow(clippy::type_complexity)]
pub fn new_connection<T>(
protocol: isize,
Expand All @@ -232,17 +233,34 @@ where
new_connection_with_codec(protocol)
}

/// Variant of [`new_connection`] that allows specifying a separate codec
/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling
#[allow(clippy::type_complexity)]
pub fn new_connection_with_codec<T, C>(
pub fn new_connection_with_socket<T, S>(
protocol: isize,
) -> io::Result<(
Connection<T, C>,
Connection<T, S>,
ConnectionHandle<T>,
UnboundedReceiver<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
)>
where
T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin,
S: sys::AsyncSocket,
{
new_connection_with_codec(protocol)
}

/// Variant of [`new_connection`] that allows specifying a socket type to use for async handling and a special codec
#[allow(clippy::type_complexity)]
pub fn new_connection_with_codec<T, S, C>(
protocol: isize,
) -> io::Result<(
Connection<T, S, C>,
ConnectionHandle<T>,
UnboundedReceiver<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
)>
where
T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin,
S: sys::AsyncSocket,
C: NetlinkMessageCodec,
{
let (requests_tx, requests_rx) = unbounded::<Request<T>>();
Expand Down
17 changes: 15 additions & 2 deletions rtnetlink/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,29 @@ use futures::channel::mpsc::UnboundedReceiver;
use crate::{
packet::{NetlinkMessage, RtnlMessage},
proto::Connection,
sys::{protocols::NETLINK_ROUTE, SocketAddr},
sys::{protocols::NETLINK_ROUTE, AsyncSocket, SocketAddr},
Handle,
};

#[cfg(feature = "tokio_socket")]
#[allow(clippy::type_complexity)]
pub fn new_connection() -> io::Result<(
Connection<RtnlMessage>,
Handle,
UnboundedReceiver<(NetlinkMessage<RtnlMessage>, SocketAddr)>,
)> {
let (conn, handle, messages) = netlink_proto::new_connection(NETLINK_ROUTE)?;
new_connection_with_socket()
}

#[allow(clippy::type_complexity)]
pub fn new_connection_with_socket<S>() -> io::Result<(
Connection<RtnlMessage, S>,
Handle,
UnboundedReceiver<(NetlinkMessage<RtnlMessage>, SocketAddr)>,
)>
where
S: AsyncSocket,
{
let (conn, handle, messages) = netlink_proto::new_connection_with_socket(NETLINK_ROUTE)?;
Ok((conn, Handle::new(handle), messages))
}

0 comments on commit a3241c4

Please sign in to comment.