From 208b80b65d9a54bac3172b97af81cfe90dd6412b Mon Sep 17 00:00:00 2001 From: Andy Grover Date: Wed, 22 May 2024 16:04:34 -0700 Subject: [PATCH] recvmsg: Check if CMSG buffer was too small and return an error (#2413) If MSG_CTRUNC is set, it is not safe to iterate the cmsgs, since they could have been truncated. Change RecvMsg::cmsgs() to return a Result, and to check for this flag (an API change). Update tests for API change. Add test for too-small buffer. --- changelog/2413.changed.md | 1 + src/sys/socket/mod.rs | 19 ++++++++---- test/sys/test_socket.rs | 64 ++++++++++++++++++++++++--------------- 3 files changed, 53 insertions(+), 31 deletions(-) create mode 100644 changelog/2413.changed.md diff --git a/changelog/2413.changed.md b/changelog/2413.changed.md new file mode 100644 index 0000000000..7bae72f7d8 --- /dev/null +++ b/changelog/2413.changed.md @@ -0,0 +1 @@ +`RecvMsg::cmsgs()` now returns a `Result`, and checks that cmsgs were not truncated. diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 3d1651bd3f..10afacf02d 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -13,6 +13,7 @@ use libc::{self, c_int, size_t, socklen_t}; #[cfg(all(feature = "uio", not(target_os = "redox")))] use libc::{ c_void, iovec, CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE, + MSG_CTRUNC, }; #[cfg(not(target_os = "redox"))] use std::io::{IoSlice, IoSliceMut}; @@ -599,13 +600,19 @@ pub struct RecvMsg<'a, 's, S> { } impl<'a, S> RecvMsg<'a, '_, S> { - /// Iterate over the valid control messages pointed to by this - /// msghdr. - pub fn cmsgs(&self) -> CmsgIterator { - CmsgIterator { + /// Iterate over the valid control messages pointed to by this msghdr. If + /// allocated space for CMSGs was too small it is not safe to iterate, + /// instead return an `Error::ENOBUFS` error. + pub fn cmsgs(&self) -> Result { + + if self.mhdr.msg_flags & MSG_CTRUNC == MSG_CTRUNC { + return Err(Errno::ENOBUFS); + } + + Ok(CmsgIterator { cmsghdr: self.cmsghdr, mhdr: &self.mhdr - } + }) } } @@ -700,7 +707,7 @@ pub enum ControlMessageOwned { /// let mut iov = [IoSliceMut::new(&mut buffer)]; /// let r = recvmsg::(in_socket.as_raw_fd(), &mut iov, Some(&mut cmsgspace), flags) /// .unwrap(); - /// let rtime = match r.cmsgs().next() { + /// let rtime = match r.cmsgs().unwrap().next() { /// Some(ControlMessageOwned::ScmTimestamp(rtime)) => rtime, /// Some(_) => panic!("Unexpected control message"), /// None => panic!("No control message") diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index ee60e62b45..79c97c8720 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -55,7 +55,7 @@ pub fn test_timestamping() { .unwrap(); let mut ts = None; - for c in recv.cmsgs() { + for c in recv.cmsgs().unwrap() { if let ControlMessageOwned::ScmTimestampsns(timestamps) = c { ts = Some(timestamps.system); } @@ -117,7 +117,7 @@ pub fn test_timestamping_realtime() { .unwrap(); let mut ts = None; - for c in recv.cmsgs() { + for c in recv.cmsgs().unwrap() { if let ControlMessageOwned::ScmRealtime(timeval) = c { ts = Some(timeval); } @@ -179,7 +179,7 @@ pub fn test_timestamping_monotonic() { .unwrap(); let mut ts = None; - for c in recv.cmsgs() { + for c in recv.cmsgs().unwrap() { if let ControlMessageOwned::ScmMonotonic(timeval) = c { ts = Some(timeval); } @@ -889,7 +889,7 @@ pub fn test_scm_rights() { ) .unwrap(); - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { if let ControlMessageOwned::ScmRights(fd) = cmsg { assert_eq!(received_r, None); assert_eq!(fd.len(), 1); @@ -1330,7 +1330,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() { .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - let mut cmsgs = msg.cmsgs(); + let mut cmsgs = msg.cmsgs().unwrap(); match cmsgs.next() { Some(ControlMessageOwned::ScmRights(fds)) => { assert_eq!( @@ -1399,7 +1399,7 @@ pub fn test_sendmsg_empty_cmsgs() { ) .unwrap(); - if msg.cmsgs().next().is_some() { + if msg.cmsgs().unwrap().next().is_some() { panic!("unexpected cmsg"); } assert!(!msg @@ -1466,7 +1466,7 @@ fn test_scm_credentials() { .unwrap(); let mut received_cred = None; - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { let cred = match cmsg { #[cfg(linux_android)] ControlMessageOwned::ScmCredentials(cred) => cred, @@ -1497,7 +1497,7 @@ fn test_scm_credentials() { #[test] fn test_scm_credentials_and_rights() { let space = cmsg_space!(libc::ucred, RawFd); - test_impl_scm_credentials_and_rights(space); + test_impl_scm_credentials_and_rights(space).unwrap(); } /// Ensure that passing a an oversized control message buffer to recvmsg @@ -1509,11 +1509,23 @@ fn test_scm_credentials_and_rights() { #[test] fn test_too_large_cmsgspace() { let space = vec![0u8; 1024]; - test_impl_scm_credentials_and_rights(space); + test_impl_scm_credentials_and_rights(space).unwrap(); } #[cfg(linux_android)] -fn test_impl_scm_credentials_and_rights(mut space: Vec) { +#[test] +fn test_too_small_cmsgspace() { + let space = vec![0u8; 4]; + assert_eq!( + test_impl_scm_credentials_and_rights(space), + Err(nix::errno::Errno::ENOBUFS) + ); +} + +#[cfg(linux_android)] +fn test_impl_scm_credentials_and_rights( + mut space: Vec, +) -> Result<(), nix::errno::Errno> { use libc::ucred; use nix::sys::socket::sockopt::PassCred; use nix::sys::socket::{ @@ -1573,9 +1585,9 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec) { .unwrap(); let mut received_cred = None; - assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs"); + assert_eq!(msg.cmsgs()?.count(), 2, "expected 2 cmsgs"); - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs()? { match cmsg { ControlMessageOwned::ScmRights(fds) => { assert_eq!(received_r, None, "already received fd"); @@ -1606,6 +1618,8 @@ fn test_impl_scm_credentials_and_rights(mut space: Vec) { read(received_r.as_raw_fd(), &mut buf).unwrap(); assert_eq!(&buf[..], b"world"); close(received_r).unwrap(); + + Ok(()) } // Test creating and using named unix domain sockets @@ -1837,7 +1851,7 @@ pub fn test_recv_ipv4pktinfo() { .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - let mut cmsgs = msg.cmsgs(); + let mut cmsgs = msg.cmsgs().unwrap(); if let Some(ControlMessageOwned::Ipv4PacketInfo(pktinfo)) = cmsgs.next() { let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex"); @@ -1929,11 +1943,11 @@ pub fn test_recvif() { assert!(!msg .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - assert_eq!(msg.cmsgs().count(), 2, "expected 2 cmsgs"); + assert_eq!(msg.cmsgs().unwrap().count(), 2, "expected 2 cmsgs"); let mut rx_recvif = false; let mut rx_recvdstaddr = false; - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { match cmsg { ControlMessageOwned::Ipv4RecvIf(dl) => { rx_recvif = true; @@ -2027,10 +2041,10 @@ pub fn test_recvif_ipv4() { assert!(!msg .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs"); + assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs"); let mut rx_recvorigdstaddr = false; - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { match cmsg { ControlMessageOwned::Ipv4OrigDstAddr(addr) => { rx_recvorigdstaddr = true; @@ -2113,10 +2127,10 @@ pub fn test_recvif_ipv6() { assert!(!msg .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - assert_eq!(msg.cmsgs().count(), 1, "expected 1 cmsgs"); + assert_eq!(msg.cmsgs().unwrap().count(), 1, "expected 1 cmsgs"); let mut rx_recvorigdstaddr = false; - for cmsg in msg.cmsgs() { + for cmsg in msg.cmsgs().unwrap() { match cmsg { ControlMessageOwned::Ipv6OrigDstAddr(addr) => { rx_recvorigdstaddr = true; @@ -2214,7 +2228,7 @@ pub fn test_recv_ipv6pktinfo() { .flags .intersects(MsgFlags::MSG_TRUNC | MsgFlags::MSG_CTRUNC)); - let mut cmsgs = msg.cmsgs(); + let mut cmsgs = msg.cmsgs().unwrap(); if let Some(ControlMessageOwned::Ipv6PacketInfo(pktinfo)) = cmsgs.next() { let i = if_nametoindex(lo_name.as_bytes()).expect("if_nametoindex"); @@ -2357,7 +2371,7 @@ fn test_recvmsg_timestampns() { flags, ) .unwrap(); - let rtime = match r.cmsgs().next() { + let rtime = match r.cmsgs().unwrap().next() { Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime, Some(_) => panic!("Unexpected control message"), None => panic!("No control message"), @@ -2418,7 +2432,7 @@ fn test_recvmmsg_timestampns() { ) .unwrap() .collect(); - let rtime = match r[0].cmsgs().next() { + let rtime = match r[0].cmsgs().unwrap().next() { Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime, Some(_) => panic!("Unexpected control message"), None => panic!("No control message"), @@ -2508,7 +2522,7 @@ fn test_recvmsg_rxq_ovfl() { MsgFlags::MSG_DONTWAIT, ) { Ok(r) => { - drop_counter = match r.cmsgs().next() { + drop_counter = match r.cmsgs().unwrap().next() { Some(ControlMessageOwned::RxqOvfl(drop_counter)) => { drop_counter } @@ -2687,7 +2701,7 @@ mod linux_errqueue { assert_eq!(msg.address, Some(sock_addr)); // Check for expected control message. - let ext_err = match msg.cmsgs().next() { + let ext_err = match msg.cmsgs().unwrap().next() { Some(cmsg) => testf(&cmsg), None => panic!("No control message"), }; @@ -2878,7 +2892,7 @@ fn test_recvmm2() -> nix::Result<()> { #[cfg(not(any(qemu, target_arch = "aarch64")))] let mut saw_time = false; let mut recvd = 0; - for cmsg in rmsg.cmsgs() { + for cmsg in rmsg.cmsgs().unwrap() { if let ControlMessageOwned::ScmTimestampsns(timestamps) = cmsg { let ts = timestamps.system;