Skip to content

Commit

Permalink
Add Socket::peek_sender()
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Feb 24, 2023
1 parent d616440 commit 1d5e20f
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,11 +596,36 @@ impl Socket {
/// `peek_from` makes the same safety guarantees regarding the `buf`fer as
/// [`recv`].
///
/// # Note: Datagram Sockets
/// For datagram sockets, the behavior of this method when `buf` is smaller than
/// the datagram at the head of the receive queue differs between Windows and
/// Unix-like platforms (Linux, macOS, BSDs, etc: colloquially termed "*nix").
///
/// On *nix platforms, the datagram is truncated to the length of `buf`.
///
/// On Windows, an error corresponding to `WSAEMSGSIZE` will be returned.
///
/// For consistency between platforms, be sure to provide a sufficiently large buffer to avoid
/// truncation; this depends on the underlying protocol.
///
/// If you just want to know the sender of the data, try [`peek_sender`].
///
/// [`recv`]: Socket::recv
/// [`peek_sender`]: Socket::peek_sender
pub fn peek_from(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<(usize, SockAddr)> {
self.recv_from_with_flags(buf, sys::MSG_PEEK)
}

/// Retrieve the sender for the data at the head of the receive queue.
///
/// This is equivalent to calling [`peek_from`] with a zero-sized buffer,
/// but suppresses the `WSAEMSGSIZE` error on Windows.
///
/// [`peek_from`]: Socket::peek_from
pub fn peek_sender(&self) -> io::Result<SockAddr> {
sys::peek_sender(self.as_raw())
}

/// Sends data on the socket to a connected peer.
///
/// This is typically used on TCP sockets or datagram sockets which have
Expand Down
8 changes: 8 additions & 0 deletions src/sys/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,14 @@ pub(crate) fn recv_from(
}
}

pub(crate) fn peek_sender(fd: Socket) -> io::Result<SockAddr> {
// Unix-like platforms simply truncate the returned data, so this implementation is trivial.
// However, for Windows this requires suppressing the `WSAEMSGSIZE` error so that
// requires a different approach.
let (_, sender) = recv_from(fd, &mut [], MSG_PEEK)?;
Ok(sender)
}

#[cfg(not(target_os = "redox"))]
pub(crate) fn recv_vectored(
fd: Socket,
Expand Down
30 changes: 30 additions & 0 deletions src/sys/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,36 @@ pub(crate) fn recv_from(
}
}

pub(crate) fn peek_sender(socket: Socket) -> io::Result<SockAddr> {
// Safety: `recvfrom` initialises the `SockAddr` for us.
let ((), sender) = unsafe {
SockAddr::try_init(|storage, addrlen| {
let res = syscall!(
recvfrom(
socket,
//
ptr::null_mut(),
0,
MSG_PEEK,
storage.cast(),
addrlen,
),
PartialEq::eq,
SOCKET_ERROR
);
match res {
Ok(_n) => Ok(()),
Err(e) => match e.raw_os_error() {
Some(code) if code == (WSAESHUTDOWN as i32) || code == (WSAEMSGSIZE as i32) => Ok(()),
_ => Err(e)
},
}
})
}?;

Ok(sender)
}

pub(crate) fn recv_from_vectored(
socket: Socket,
bufs: &mut [crate::MaybeUninitSlice<'_>],
Expand Down
22 changes: 22 additions & 0 deletions tests/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,28 @@ fn out_of_band() {
assert_eq!(unsafe { assume_init(&buf[..n]) }, DATA);
}

#[test]
fn peek_sender_udp() {
let socket_a = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap();
let socket_b = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap();

let localhost_port_0 = SockAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));

// Allocate local address/port pairs for both sockets.
socket_a.bind(&localhost_port_0).unwrap();
socket_b.bind(&localhost_port_0).unwrap();

let socket_a_addr = socket_a.local_addr().unwrap();
let socket_b_addr = socket_b.local_addr().unwrap();

socket_b.send_to(b"Hello, world!", &socket_a_addr).unwrap();

let sender_addr = socket_a.peek_sender().unwrap();

assert_eq!(sender_addr.as_socket_ipv4(), socket_b_addr.as_socket_ipv4());
}


#[test]
#[cfg(not(target_os = "redox"))]
fn send_recv_vectored() {
Expand Down

0 comments on commit 1d5e20f

Please sign in to comment.