Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the length of a socket address when calling recvmsg on Linux #2041

Merged
merged 11 commits into from
Jul 17, 2023
103 changes: 99 additions & 4 deletions src/sys/socket/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,22 @@ impl SockaddrLike for UnixAddr {
{
mem::size_of::<libc::sockaddr_un>() as libc::socklen_t
}

unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
cfg_if! {
if #[cfg(any(target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "redox",
))] {
self.sun_len = new_length as u8;
} else {
self.sun.sun_len = new_length as u8;
JarredAllen marked this conversation as resolved.
Show resolved Hide resolved
}
};
Ok(())
}
}

impl AsRef<libc::sockaddr_un> for UnixAddr {
Expand Down Expand Up @@ -912,8 +928,30 @@ pub trait SockaddrLike: private::SockaddrLikePriv {
{
mem::size_of::<Self>() as libc::socklen_t
}

/// Set the length of this socket address
///
/// This method may only be called on socket addresses whose lenghts are dynamic, and it
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// This method may only be called on socket addresses whose lenghts are dynamic, and it
/// This method may only be called on socket addresses whose lengths are dynamic, and it

/// returns an error if called on a type whose length is static.
///
/// # Safety
///
/// `new_length` must be a valid length for this type of address. Specifically, reads of that
/// length from `self` must be valid.
unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you add a default implementation here, then you can avoid a lot of boilerplate elsewhere.

Suggested change
unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic>;
unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}

}

/// The error returned by [`SockaddrLike::set_length`] on an address whose length is statically
/// fixed.
#[derive(Copy, Clone, Debug)]
pub struct SocketAddressLengthNotDynamic;
impl fmt::Display for SocketAddressLengthNotDynamic {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("Attempted to set length on socket whose length is statically fixed")
}
}
impl std::error::Error for SocketAddressLengthNotDynamic {}

impl private::SockaddrLikePriv for () {
fn as_mut_ptr(&mut self) -> *mut libc::sockaddr {
ptr::null_mut()
Expand Down Expand Up @@ -946,6 +984,10 @@ impl SockaddrLike for () {
fn len(&self) -> libc::socklen_t {
0
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

/// An IPv4 socket address
Expand Down Expand Up @@ -1015,6 +1057,10 @@ impl SockaddrLike for SockaddrIn {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

#[cfg(feature = "net")]
Expand Down Expand Up @@ -1134,6 +1180,10 @@ impl SockaddrLike for SockaddrIn6 {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

#[cfg(feature = "net")]
Expand Down Expand Up @@ -1361,6 +1411,27 @@ impl SockaddrLike for SockaddrStorage {
None => mem::size_of_val(self) as libc::socklen_t,
}
}

unsafe fn set_length(&mut self, new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
match self.as_unix_addr_mut() {
Some(addr) => {
cfg_if! {
if #[cfg(any(target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "redox",
))] {
addr.sun_len = new_length as u8;
} else {
addr.sun.sun_len = new_length as u8;
}
}
Ok(())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cfg_if! {
if #[cfg(any(target_os = "android",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "redox",
))] {
addr.sun_len = new_length as u8;
} else {
addr.sun.sun_len = new_length as u8;
}
}
Ok(())
addr.set_length(new_length)

},
None => Err(SocketAddressLengthNotDynamic),
}
}
}

macro_rules! accessors {
Expand Down Expand Up @@ -1679,7 +1750,7 @@ impl PartialEq for SockaddrStorage {
}
}

mod private {
pub(super) mod private {
pub trait SockaddrLikePriv {
/// Returns a mutable raw pointer to the inner structure.
///
Expand Down Expand Up @@ -1754,6 +1825,10 @@ pub mod netlink {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_nl> for NetlinkAddr {
Expand Down Expand Up @@ -1803,6 +1878,10 @@ pub mod alg {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_alg> for AlgAddr {
Expand Down Expand Up @@ -1902,7 +1981,7 @@ pub mod sys_control {
use std::{fmt, mem, ptr};
use std::os::unix::io::RawFd;
use crate::{Errno, Result};
use super::{private, SockaddrLike};
use super::{private, SockaddrLike, SocketAddressLengthNotDynamic};

// FIXME: Move type into `libc`
#[repr(C)]
Expand Down Expand Up @@ -1943,6 +2022,10 @@ pub mod sys_control {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_ctl> for SysControlAddr {
Expand Down Expand Up @@ -2007,7 +2090,7 @@ pub mod sys_control {
mod datalink {
feature! {
#![feature = "net"]
use super::{fmt, mem, private, ptr, SockaddrLike};
use super::{fmt, mem, private, ptr, SockaddrLike, SocketAddressLengthNotDynamic};

/// Hardware Address
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
Expand Down Expand Up @@ -2085,6 +2168,10 @@ mod datalink {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_ll> for LinkAddr {
Expand All @@ -2110,7 +2197,7 @@ mod datalink {
mod datalink {
feature! {
#![feature = "net"]
use super::{fmt, mem, private, ptr, SockaddrLike};
use super::{fmt, mem, private, ptr, SockaddrLike, SocketAddressLengthNotDynamic};

/// Hardware Address
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
Expand Down Expand Up @@ -2209,6 +2296,10 @@ mod datalink {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_dl> for LinkAddr {
Expand Down Expand Up @@ -2257,6 +2348,10 @@ pub mod vsock {
}
Some(Self(ptr::read_unaligned(addr as *const _)))
}

unsafe fn set_length(&mut self, _new_length: usize) -> std::result::Result<(), SocketAddressLengthNotDynamic> {
Err(SocketAddressLengthNotDynamic)
}
}

impl AsRef<libc::sockaddr_vm> for VsockAddr {
Expand Down
13 changes: 9 additions & 4 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,7 @@ impl<S> MultiHeaders<S> {
{
// we will be storing pointers to addresses inside mhdr - convert it into boxed
// slice so it can'be changed later by pushing anything into self.addresses
let mut addresses = vec![std::mem::MaybeUninit::uninit(); num_slices].into_boxed_slice();
let mut addresses = vec![std::mem::MaybeUninit::<S>::uninit(); num_slices].into_boxed_slice();

let msg_controllen = cmsg_buffer.as_ref().map_or(0, |v| v.capacity());

Expand Down Expand Up @@ -1914,7 +1914,7 @@ unsafe fn read_mhdr<'a, 'i, S>(
mhdr: msghdr,
r: isize,
msg_controllen: usize,
address: S,
mut address: S,
) -> RecvMsg<'a, 'i, S>
where S: SockaddrLike
{
Expand All @@ -1930,6 +1930,11 @@ unsafe fn read_mhdr<'a, 'i, S>(
}.as_ref()
};

// Ignore errors if this socket address has statically-known length
//
// This is to ensure that unix socket addresses have their length set appropriately.
let _ = address.set_length(mhdr.msg_namelen as usize);

RecvMsg {
bytes: r as usize,
cmsghdr,
Expand Down Expand Up @@ -1965,7 +1970,7 @@ unsafe fn pack_mhdr_to_receive<S>(
// initialize it.
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
let p = mhdr.as_mut_ptr();
(*p).msg_name = (*address).as_mut_ptr() as *mut c_void;
(*p).msg_name = address as *mut c_void;
(*p).msg_namelen = S::size();
(*p).msg_iov = iov_buffer as *mut iovec;
(*p).msg_iovlen = iov_buffer_len as _;
Expand Down Expand Up @@ -2048,7 +2053,7 @@ pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'i
where S: SockaddrLike + 'a,
'inner: 'outer
{
let mut address = mem::MaybeUninit::uninit();
let mut address = mem::MaybeUninit::zeroed();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you think it's necessary to zero the address here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally thought it might be necessary when I first started working on this (since the length was being left uninitialized initially), but looking back at it now I think everything is initialized now so I put it back to uninit()


let (msg_control, msg_controllen) = cmsg_buffer.as_mut()
.map(|v| (v.as_mut_ptr(), v.capacity()))
Expand Down
43 changes: 43 additions & 0 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,49 @@ pub fn test_socketpair() {
assert_eq!(&buf[..], b"hello");
}

#[test]
pub fn test_recvmsg_sockaddr_un() {
use nix::sys::socket::{
self, bind, socket, AddressFamily, MsgFlags, SockFlag, SockType,
};

let tempdir = tempfile::tempdir().unwrap();
let sockname = tempdir.path().join("sock");
let sock = socket(
AddressFamily::Unix,
SockType::Datagram,
SockFlag::empty(),
None,
)
.expect("socket failed");
let sockaddr = UnixAddr::new(&sockname).unwrap();
bind(sock, &sockaddr).expect("bind failed");

// Send a message
let send_buffer = "hello".as_bytes();
if let Err(e) = socket::sendmsg(
sock,
&[std::io::IoSlice::new(send_buffer)],
&[],
MsgFlags::empty(),
Some(&sockaddr),
) {
crate::skip!("Couldn't send ({e:?}), so skipping test");
}

// Receive the message
let mut recv_buffer = [0u8; 32];
let received = socket::recvmsg(
sock,
&mut [std::io::IoSliceMut::new(&mut recv_buffer)],
None,
MsgFlags::empty(),
)
.unwrap();
// Check the address in the received message
assert_eq!(sockaddr, received.address.unwrap());
}

#[test]
pub fn test_std_conversions() {
use nix::sys::socket::*;
Expand Down