From 06dbe83c60cbfa603c9437aef9dfd119f068b1df Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Thu, 19 Oct 2023 05:56:34 -0700 Subject: [PATCH] Fix `sendmsg_unix`'s address encoding. (#885) (#886) When encoding the address for `sendmsg_unix`, use the `unix` field of `SocketAddrUnix`, since the `unix` field is the `sockaddr_un` that the OS will read. Fixes #884. --- src/backend/libc/net/msghdr.rs | 2 +- src/backend/linux_raw/net/msghdr.rs | 2 +- tests/net/unix.rs | 133 ++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 2 deletions(-) diff --git a/src/backend/libc/net/msghdr.rs b/src/backend/libc/net/msghdr.rs index cb5053982..08b16fc5a 100644 --- a/src/backend/libc/net/msghdr.rs +++ b/src/backend/libc/net/msghdr.rs @@ -115,7 +115,7 @@ pub(crate) fn with_unix_msghdr( ) -> R { f({ let mut h: c::msghdr = unsafe { zeroed() }; - h.msg_name = as_ptr(addr) as _; + h.msg_name = as_ptr(&addr.unix) as _; h.msg_namelen = addr.addr_len(); h.msg_iov = iov.as_ptr() as _; h.msg_iovlen = msg_iov_len(iov.len()); diff --git a/src/backend/linux_raw/net/msghdr.rs b/src/backend/linux_raw/net/msghdr.rs index 06bcd594d..af5cad4c4 100644 --- a/src/backend/linux_raw/net/msghdr.rs +++ b/src/backend/linux_raw/net/msghdr.rs @@ -131,7 +131,7 @@ pub(crate) fn with_unix_msghdr( f: impl FnOnce(c::msghdr) -> R, ) -> R { f(c::msghdr { - msg_name: as_ptr(addr) as _, + msg_name: as_ptr(&addr.unix) as _, msg_namelen: addr.addr_len() as _, msg_iov: iov.as_ptr() as _, msg_iovlen: msg_iov_len(iov.len()), diff --git a/tests/net/unix.rs b/tests/net/unix.rs index 0c571618b..028bf2e01 100644 --- a/tests/net/unix.rs +++ b/tests/net/unix.rs @@ -283,6 +283,113 @@ fn do_test_unix_msg(addr: SocketAddrUnix) { server.join().unwrap(); } +/// Similar to `do_test_unix_msg` but uses an unconnected socket and +/// `sendmsg_unix` instead of `sendmsg`. +#[cfg(not(any(target_os = "redox", target_os = "wasi")))] +fn do_test_unix_msg_unconnected(addr: SocketAddrUnix) { + use rustix::io::{IoSlice, IoSliceMut}; + use rustix::net::{recvmsg, sendmsg_unix, RecvFlags, SendFlags}; + + let server = { + let runs: &[i32] = &[3, 184, 187, 0]; + let data_socket = + socket(AddressFamily::UNIX, SocketType::DGRAM, Default::default()).unwrap(); + bind_unix(&data_socket, &addr).unwrap(); + + move || { + let mut buffer = vec![0; BUFFER_SIZE]; + for expected_sum in runs { + let mut sum = 0; + loop { + let nread = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap() + .bytes; + + assert_ne!(&buffer[..nread], b"exit"); + if &buffer[..nread] == b"sum" { + break; + } + + sum += i32::from_str(&String::from_utf8_lossy(&buffer[..nread])).unwrap(); + } + + assert_eq!(sum, *expected_sum); + } + let nread = recvmsg( + &data_socket, + &mut [IoSliceMut::new(&mut buffer)], + &mut Default::default(), + RecvFlags::empty(), + ) + .unwrap() + .bytes; + + assert_eq!(&buffer[..nread], b"exit"); + } + }; + + let client = move || { + let runs: &[&[&str]] = &[&["1", "2"], &["4", "77", "103"], &["5", "78", "104"], &[]]; + + for args in runs { + let data_socket = + socket(AddressFamily::UNIX, SocketType::DGRAM, Default::default()).unwrap(); + + for arg in *args { + sendmsg_unix( + &data_socket, + &addr, + &[IoSlice::new(arg.as_bytes())], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + sendmsg_unix( + &data_socket, + &addr, + &[IoSlice::new(b"sum")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + } + + let data_socket = + socket(AddressFamily::UNIX, SocketType::DGRAM, Default::default()).unwrap(); + sendmsg_unix( + &data_socket, + &addr, + &[IoSlice::new(b"exit")], + &mut Default::default(), + SendFlags::empty(), + ) + .unwrap(); + }; + + let server = thread::Builder::new() + .name("server".to_string()) + .spawn(move || { + server(); + }) + .unwrap(); + + let client = thread::Builder::new() + .name("client".to_string()) + .spawn(move || { + client(); + }) + .unwrap(); + + client.join().unwrap(); + server.join().unwrap(); +} + #[cfg(not(any(target_os = "redox", target_os = "wasi")))] #[test] fn test_unix_msg() { @@ -295,6 +402,19 @@ fn test_unix_msg() { unlinkat(cwd(), path, AtFlags::empty()).unwrap(); } +/// Like `test_unix_msg` but tests `do_test_unix_msg_unconnected`. +#[cfg(not(any(target_os = "espidf", target_os = "redox", target_os = "wasi")))] +#[test] +fn test_unix_msg_unconnected() { + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new(&path).unwrap(); + do_test_unix_msg_unconnected(name); + + unlinkat(cwd(), path, AtFlags::empty()).unwrap(); +} + #[cfg(linux_kernel)] #[test] fn test_abstract_unix_msg() { @@ -307,6 +427,19 @@ fn test_abstract_unix_msg() { do_test_unix_msg(name); } +/// Like `test_abstract_unix_msg` but tests `do_test_unix_msg_unconnected`. +#[cfg(linux_kernel)] +#[test] +fn test_abstract_unix_msg_unconnected() { + use std::os::unix::ffi::OsStrExt; + + let tmpdir = tempfile::tempdir().unwrap(); + let path = tmpdir.path().join("scp_4804"); + + let name = SocketAddrUnix::new_abstract_name(path.as_os_str().as_bytes()).unwrap(); + do_test_unix_msg_unconnected(name); +} + #[cfg(not(any(target_os = "redox", target_os = "wasi")))] #[test] fn test_unix_msg_with_scm_rights() {