Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the length of a socket address when calling recvmsg on Linux #2041

Merged
merged 11 commits into from
Jul 17, 2023
2 changes: 1 addition & 1 deletion src/sys/socket/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1679,7 +1679,7 @@ impl PartialEq for SockaddrStorage {
}
}

mod private {
pub(super) mod private {
pub trait SockaddrLikePriv {
/// Returns a mutable raw pointer to the inner structure.
///
Expand Down
34 changes: 18 additions & 16 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,7 @@ impl<S> MultiHeaders<S> {
{
// we will be storing pointers to addresses inside mhdr - convert it into boxed
// slice so it can'be changed later by pushing anything into self.addresses
let mut addresses = vec![std::mem::MaybeUninit::uninit(); num_slices].into_boxed_slice();
let mut addresses = vec![std::mem::MaybeUninit::<S>::uninit(); num_slices].into_boxed_slice();

let msg_controllen = cmsg_buffer.as_ref().map_or(0, |v| v.capacity());

Expand All @@ -1626,7 +1626,9 @@ impl<S> MultiHeaders<S> {
Some(v) => ((&v[ix * msg_controllen] as *const u8), msg_controllen),
None => (std::ptr::null(), 0),
};
let msg_hdr = unsafe { pack_mhdr_to_receive(std::ptr::null(), 0, ptr, cap, address.as_mut_ptr()) };
let msg_hdr = unsafe {
pack_mhdr_to_receive(std::ptr::null(), 0, ptr, cap, <S as addr::private::SockaddrLikePriv>::as_mut_ptr(address.assume_init_mut()).cast())
};
libc::mmsghdr {
msg_hdr,
msg_len: 0,
Expand Down Expand Up @@ -1761,7 +1763,7 @@ where
mmsghdr.msg_hdr,
mmsghdr.msg_len as isize,
self.rmm.msg_controllen,
address,
Some(address),
)
})
}
Expand Down Expand Up @@ -1914,7 +1916,7 @@ unsafe fn read_mhdr<'a, 'i, S>(
mhdr: msghdr,
r: isize,
msg_controllen: usize,
address: S,
address: Option<S>,
) -> RecvMsg<'a, 'i, S>
where S: SockaddrLike
{
Expand All @@ -1933,7 +1935,7 @@ unsafe fn read_mhdr<'a, 'i, S>(
RecvMsg {
bytes: r as usize,
cmsghdr,
address: Some(address),
address,
flags: MsgFlags::from_bits_truncate(mhdr.msg_flags),
mhdr,
iobufs: std::marker::PhantomData,
Expand All @@ -1951,22 +1953,19 @@ unsafe fn read_mhdr<'a, 'i, S>(
/// headers are not used
///
/// Buffers must remain valid for the whole lifetime of msghdr
unsafe fn pack_mhdr_to_receive<S>(
unsafe fn pack_mhdr_to_receive(
iov_buffer: *const IoSliceMut,
iov_buffer_len: usize,
cmsg_buffer: *const u8,
cmsg_capacity: usize,
address: *mut S,
) -> msghdr
where
S: SockaddrLike
{
address: *mut libc::sockaddr_storage,
) -> msghdr {
// Musl's msghdr has private fields, so this is the only way to
// initialize it.
let mut mhdr = mem::MaybeUninit::<msghdr>::zeroed();
let p = mhdr.as_mut_ptr();
(*p).msg_name = (*address).as_mut_ptr() as *mut c_void;
(*p).msg_namelen = S::size();
(*p).msg_name = address as *mut c_void;
(*p).msg_namelen = mem::size_of::<libc::sockaddr_storage>() as u32;
(*p).msg_iov = iov_buffer as *mut iovec;
(*p).msg_iovlen = iov_buffer_len as _;
(*p).msg_control = cmsg_buffer as *mut c_void;
Expand Down Expand Up @@ -2048,20 +2047,23 @@ pub fn recvmsg<'a, 'outer, 'inner, S>(fd: RawFd, iov: &'outer mut [IoSliceMut<'i
where S: SockaddrLike + 'a,
'inner: 'outer
{
let mut address = mem::MaybeUninit::uninit();
let mut address: libc::sockaddr_storage = unsafe { mem::MaybeUninit::zeroed().assume_init() };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use sockaddr_storage instead of S?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't pass tests when I used S, I think because () is shorter than the actual socket address type so it ran out of room (though I didn't check for the cause, so I might be wrong about that). Though the point is moot as I put it back to how it was before for implementing your suggestion about the different approach.

let address_ptr: *mut libc::sockaddr_storage = &mut address as *mut libc::sockaddr_storage;

let (msg_control, msg_controllen) = cmsg_buffer.as_mut()
.map(|v| (v.as_mut_ptr(), v.capacity()))
.unwrap_or((ptr::null_mut(), 0));
let mut mhdr = unsafe {
pack_mhdr_to_receive(iov.as_ref().as_ptr(), iov.len(), msg_control, msg_controllen, address.as_mut_ptr())
pack_mhdr_to_receive(iov.as_ref().as_ptr(), iov.len(), msg_control, msg_controllen, address_ptr)
};

let ret = unsafe { libc::recvmsg(fd, &mut mhdr, flags.bits()) };

let r = Errno::result(ret)?;

Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address.assume_init()) })
let address = unsafe { S::from_raw(address_ptr.cast::<libc::sockaddr>(), Some(mhdr.msg_namelen)) };

Ok(unsafe { read_mhdr(mhdr, r, msg_controllen, address) })
}
}

Expand Down
44 changes: 44 additions & 0 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,50 @@ pub fn test_socketpair() {
assert_eq!(&buf[..], b"hello");
}

#[test]
pub fn test_recvmsg_sockaddr_un() {
use nix::sys::socket::{
self, bind, socket, AddressFamily, MsgFlags, SockFlag, SockType,
};

let tempdir = tempfile::tempdir().unwrap();
let sockname = tempdir.path().join("sock");
let sock = socket(
AddressFamily::Unix,
SockType::Datagram,
SockFlag::empty(),
None,
)
.expect("socket failed");
let sockaddr = UnixAddr::new(&sockname).unwrap();
bind(sock, &sockaddr).expect("bind failed");

// Send a message
let send_buffer = "hello".as_bytes();
if let Err(e) = socket::sendmsg(
sock,
&[std::io::IoSlice::new(send_buffer)],
&[],
MsgFlags::empty(),
Some(&sockaddr),
) {
print!("Couldn't send ({e:?}), so skipping test");
return;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print!("Couldn't send ({e:?}), so skipping test");
return;
skip!("Couldn't send ({e:?}), so skipping test");

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 72a2f56

}

// Receive the message
let mut recv_buffer = [0u8; 32];
let received = socket::recvmsg(
sock,
&mut [std::io::IoSliceMut::new(&mut recv_buffer)],
None,
MsgFlags::empty(),
)
.unwrap();
// Check the address in the received message
assert_eq!(sockaddr, received.address.unwrap());
}

#[test]
pub fn test_std_conversions() {
use nix::sys::socket::*;
Expand Down