Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net/uds: add methods to connect/bind with a socket address #1630

Merged
merged 2 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/net/uds/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@ pub struct UnixListener {
}

impl UnixListener {
/// Creates a new `UnixListener` bound to the specified socket.
/// Creates a new `UnixListener` bound to the specified socket `path`.
pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<UnixListener> {
sys::uds::listener::bind(path.as_ref()).map(UnixListener::from_std)
}

/// Creates a new `UnixListener` bound to the specified socket `address`.
pub fn bind_addr(address: &SocketAddr) -> io::Result<UnixListener> {
sys::uds::listener::bind_addr(address).map(UnixListener::from_std)
}

/// Creates a new `UnixListener` from a standard `net::UnixListener`.
///
/// This function is intended to be used to wrap a Unix listener from the
Expand Down
9 changes: 9 additions & 0 deletions src/net/uds/stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::io_source::IoSource;
use crate::net::SocketAddr;
use crate::{event, sys, Interest, Registry, Token};

use std::fmt;
Expand All @@ -22,6 +23,14 @@ impl UnixStream {
sys::uds::stream::connect(path.as_ref()).map(UnixStream::from_std)
}

/// Connects to the socket named by `address`.
///
/// This may return a `WouldBlock` in which case the socket connection
/// cannot be completed immediately. Usually it means the backlog is full.
pub fn connect_addr(address: &SocketAddr) -> io::Result<UnixStream> {
sys::uds::stream::connect_addr(address).map(UnixStream::from_std)
}

/// Creates a new `UnixStream` from a standard `net::UnixStream`.
///
/// This function is intended to be used to wrap a Unix stream from the
Expand Down
8 changes: 8 additions & 0 deletions src/sys/shell/uds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ pub(crate) mod listener {
os_required!()
}

pub(crate) fn bind_addr(_: &SocketAddr) -> io::Result<net::UnixListener> {
os_required!()
}

pub(crate) fn accept(_: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> {
os_required!()
}
Expand All @@ -61,6 +65,10 @@ pub(crate) mod stream {
os_required!()
}

pub(crate) fn connect_addr(_: &SocketAddr) -> io::Result<net::UnixStream> {
os_required!()
}

pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> {
os_required!()
}
Expand Down
3 changes: 2 additions & 1 deletion src/sys/unix/uds/datagram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ use super::{socket_addr, SocketAddr};
use crate::sys::unix::net::new_socket;

use std::io;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net;
use std::path::Path;

pub(crate) fn bind(path: &Path) -> io::Result<net::UnixDatagram> {
let (sockaddr, socklen) = socket_addr(path)?;
let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const _;

let socket = unbound()?;
Expand Down
15 changes: 12 additions & 3 deletions src/sys/unix/uds/listener.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
use super::socket_addr;
use crate::net::{SocketAddr, UnixStream};
use crate::sys::unix::net::new_socket;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net;
use std::path::Path;
use std::{io, mem};

pub(crate) fn bind(path: &Path) -> io::Result<net::UnixListener> {
let (sockaddr, socklen) = socket_addr(path)?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr;
let socket_address = {
let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?;
SocketAddr::from_parts(sockaddr, socklen)
};

bind_addr(&socket_address)
}

pub(crate) fn bind_addr(address: &SocketAddr) -> io::Result<net::UnixListener> {
let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let socket = unsafe { net::UnixListener::from_raw_fd(fd) };
syscall!(bind(fd, sockaddr, socklen))?;
let sockaddr = address.raw_sockaddr() as *const libc::sockaddr_un as *const libc::sockaddr;

syscall!(bind(fd, sockaddr, *address.raw_socklen()))?;
syscall!(listen(fd, 1024))?;

Ok(socket)
Expand Down
12 changes: 4 additions & 8 deletions src/sys/unix/uds/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,14 @@ pub(in crate::sys) fn path_offset(sockaddr: &libc::sockaddr_un) -> usize {

cfg_os_poll! {
use std::cmp::Ordering;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{RawFd, FromRawFd};
use std::path::Path;
use std::{io, mem};

pub(crate) mod datagram;
pub(crate) mod listener;
pub(crate) mod stream;

pub(in crate::sys) fn socket_addr(path: &Path) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
pub(in crate::sys) fn socket_addr(bytes: &[u8]) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
let sockaddr = mem::MaybeUninit::<libc::sockaddr_un>::zeroed();

// This is safe to assume because a `libc::sockaddr_un` filled with `0`
Expand All @@ -39,7 +37,6 @@ cfg_os_poll! {

sockaddr.sun_family = libc::AF_UNIX as libc::sa_family_t;

let bytes = path.as_os_str().as_bytes();
match (bytes.first(), bytes.len().cmp(&sockaddr.sun_path.len())) {
// Abstract paths don't need a null terminator
(Some(&0), Ordering::Greater) => {
Expand Down Expand Up @@ -128,6 +125,7 @@ cfg_os_poll! {
#[cfg(test)]
mod tests {
use super::{path_offset, socket_addr};
use std::os::unix::ffi::OsStrExt;
use std::path::Path;
use std::str;

Expand All @@ -139,7 +137,7 @@ cfg_os_poll! {
// Pathname addresses do have a null terminator, so `socklen` is
// expected to be `PATH_LEN` + `offset` + 1.
let path = Path::new(PATH);
let (sockaddr, actual) = socket_addr(path).unwrap();
let (sockaddr, actual) = socket_addr(path.as_os_str().as_bytes()).unwrap();
let offset = path_offset(&sockaddr);
let expected = PATH_LEN + offset + 1;
assert_eq!(expected as libc::socklen_t, actual)
Expand All @@ -152,9 +150,7 @@ cfg_os_poll! {

// Abstract addresses do not have a null terminator, so `socklen` is
// expected to be `PATH_LEN` + `offset`.
let abstract_path = str::from_utf8(PATH).unwrap();
let path = Path::new(abstract_path);
let (sockaddr, actual) = socket_addr(path).unwrap();
let (sockaddr, actual) = socket_addr(PATH).unwrap();
let offset = path_offset(&sockaddr);
let expected = PATH_LEN + offset;
assert_eq!(expected as libc::socklen_t, actual)
Expand Down
8 changes: 8 additions & 0 deletions src/sys/unix/uds/socketaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ cfg_os_poll! {
SocketAddr { sockaddr, socklen }
}

pub(crate) fn raw_sockaddr(&self) -> &libc::sockaddr_un {
&self.sockaddr
}

pub(crate) fn raw_socklen(&self) -> &libc::socklen_t {
&self.socklen
}

/// Returns `true` if the address is unnamed.
///
/// Documentation reflected in [`SocketAddr`]
Expand Down
15 changes: 12 additions & 3 deletions src/sys/unix/uds/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@ use super::{socket_addr, SocketAddr};
use crate::sys::unix::net::new_socket;

use std::io;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::os::unix::net;
use std::path::Path;

pub(crate) fn connect(path: &Path) -> io::Result<net::UnixStream> {
let (sockaddr, socklen) = socket_addr(path)?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr;
let socket_address = {
let (sockaddr, socklen) = socket_addr(path.as_os_str().as_bytes())?;
SocketAddr::from_parts(sockaddr, socklen)
};

connect_addr(&socket_address)
}

pub(crate) fn connect_addr(address: &SocketAddr) -> io::Result<net::UnixStream> {
let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let socket = unsafe { net::UnixStream::from_raw_fd(fd) };
match syscall!(connect(fd, sockaddr, socklen)) {
let sockaddr = address.raw_sockaddr() as *const libc::sockaddr_un as *const libc::sockaddr;

match syscall!(connect(fd, sockaddr, *address.raw_socklen())) {
Ok(_) => {}
Err(ref err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
Err(e) => return Err(e),
Expand Down
42 changes: 42 additions & 0 deletions tests/unix_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,48 @@ fn unix_stream_connect() {
handle.join().unwrap();
}

#[test]
fn unix_stream_connect_addr() {
let (mut poll, mut events) = init_with_poll();
let barrier = Arc::new(Barrier::new(2));

let path = temp_file("unix_stream_connect_addr");
let listener = net::UnixListener::bind(path.clone()).unwrap();
let mio_listener = mio::net::UnixListener::from_std(listener);

let local_addr = mio_listener.local_addr().unwrap();
let mut stream = UnixStream::connect_addr(&local_addr).unwrap();

let barrier_clone = barrier.clone();
let handle = thread::spawn(move || {
let (stream, _) = mio_listener.accept().unwrap();
barrier_clone.wait();
drop(stream);
});

poll.registry()
.register(
&mut stream,
TOKEN_1,
Interest::READABLE | Interest::WRITABLE,
)
.unwrap();
expect_events(
&mut poll,
&mut events,
vec![ExpectEvent::new(TOKEN_1, Interest::WRITABLE)],
);

barrier.wait();
expect_events(
&mut poll,
&mut events,
vec![ExpectEvent::new(TOKEN_1, Interest::READABLE)],
);

handle.join().unwrap();
}

#[test]
fn unix_stream_from_std() {
smoke_test(
Expand Down