diff --git a/crates/wasi/src/preview2/host/udp.rs b/crates/wasi/src/preview2/host/udp.rs index b6dba8be158e..e9cbd158ab12 100644 --- a/crates/wasi/src/preview2/host/udp.rs +++ b/crates/wasi/src/preview2/host/udp.rs @@ -1,4 +1,4 @@ -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::SocketAddr; use crate::preview2::{ bindings::{ @@ -6,7 +6,6 @@ use crate::preview2::{ sockets::udp, }, udp::UdpState, - Table, }; use crate::preview2::{Pollable, SocketResult, WasiView}; use cap_net_ext::{AddressFamily, PoolExt}; @@ -20,48 +19,6 @@ use wasmtime::component::Resource; /// In practice, datagrams are typically less than 1500 bytes. const MAX_UDP_DATAGRAM_SIZE: usize = 65535; -fn start_bind( - table: &mut Table, - this: Resource, - network: Resource, - local_address: IpSocketAddress, -) -> SocketResult<()> { - let socket = table.get_resource(&this)?; - match socket.udp_state { - UdpState::Default => {} - UdpState::BindStarted | UdpState::Connecting(..) | UdpState::ConnectReady(..) => { - return Err(ErrorCode::ConcurrencyConflict.into()) - } - UdpState::Bound | UdpState::Connected(..) => return Err(ErrorCode::AlreadyBound.into()), - } - - let network = table.get_resource(&network)?; - let binder = network.pool.udp_binder(local_address)?; - - // Perform the OS bind call. - binder.bind_existing_udp_socket( - &*socket - .udp_socket() - .as_socketlike_view::(), - )?; - - let socket = table.get_resource_mut(&this)?; - socket.udp_state = UdpState::BindStarted; - - Ok(()) -} - -fn finish_bind(table: &mut Table, this: Resource) -> SocketResult<()> { - let socket = table.get_resource_mut(&this)?; - match socket.udp_state { - UdpState::BindStarted => { - socket.udp_state = UdpState::Bound; - Ok(()) - } - _ => Err(ErrorCode::NotInProgress.into()), - } -} - impl udp::Host for T {} impl crate::preview2::host::udp::udp::HostUdpSocket for T { @@ -71,11 +28,44 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { network: Resource, local_address: IpSocketAddress, ) -> SocketResult<()> { - start_bind(self.table_mut(), this, network, local_address) + let table = self.table_mut(); + let socket = table.get_resource(&this)?; + + match socket.udp_state { + UdpState::Default => {} + UdpState::BindStarted | UdpState::Connecting(..) => { + return Err(ErrorCode::ConcurrencyConflict.into()) + } + UdpState::Bound | UdpState::Connected(..) => return Err(ErrorCode::AlreadyBound.into()), + } + + let network = table.get_resource(&network)?; + let binder = network.pool.udp_binder(local_address)?; + + // Perform the OS bind call. + binder.bind_existing_udp_socket( + &*socket + .udp_socket() + .as_socketlike_view::(), + )?; + + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::BindStarted; + + Ok(()) } fn finish_bind(&mut self, this: Resource) -> SocketResult<()> { - finish_bind(self.table_mut(), this) + let table = self.table_mut(); + let socket = table.get_resource_mut(&this)?; + + match socket.udp_state { + UdpState::BindStarted => { + socket.udp_state = UdpState::Bound; + Ok(()) + } + _ => Err(ErrorCode::NotInProgress.into()), + } } fn start_connect( @@ -86,87 +76,39 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { ) -> SocketResult<()> { let table = self.table_mut(); let socket = table.get_resource(&this)?; + let network = table.get_resource(&network)?; + match socket.udp_state { - UdpState::Default => { - let addr = match socket.family { - AddressFamily::Ipv4 => Ipv4Addr::UNSPECIFIED.into(), - AddressFamily::Ipv6 => Ipv6Addr::UNSPECIFIED.into(), - }; - start_bind( - table, - Resource::new_borrow(this.rep()), - Resource::new_borrow(network.rep()), - SocketAddr::new(addr, 0).into(), - )?; - finish_bind(table, Resource::new_borrow(this.rep()))?; - } - UdpState::Bound => {} - UdpState::BindStarted | UdpState::Connecting(..) | UdpState::ConnectReady(..) => { + UdpState::Default | UdpState::Bound => {} + UdpState::BindStarted | UdpState::Connecting(..) => { return Err(ErrorCode::ConcurrencyConflict.into()) } UdpState::Connected(..) => return Err(ErrorCode::AlreadyConnected.into()), } - let socket = table.get_resource(&this)?; - let network = table.get_resource(&network)?; let connecter = network.pool.udp_connecter(remote_address)?; - // Do an OS `connect`. Our socket is non-blocking, so it'll either... - let res = connecter.connect_existing_udp_socket( + // Do an OS `connect`. + connecter.connect_existing_udp_socket( &*socket .udp_socket() .as_socketlike_view::(), - ); - match res { - // succeed immediately, - Ok(()) => { - let socket = table.get_resource_mut(&this)?; - socket.udp_state = UdpState::ConnectReady(remote_address); - Ok(()) - } - // continue in progress, - Err(err) if err.raw_os_error() == Some(INPROGRESS.raw_os_error()) => { - let socket = table.get_resource_mut(&this)?; - socket.udp_state = UdpState::Connecting(remote_address); - Ok(()) - } - // or fail immediately. - Err(err) => Err(err.into()), - } + )?; + + let socket = table.get_resource_mut(&this)?; + socket.udp_state = UdpState::Connecting(remote_address); + Ok(()) } fn finish_connect(&mut self, this: Resource) -> SocketResult<()> { let table = self.table_mut(); let socket = table.get_resource_mut(&this)?; + match socket.udp_state { - UdpState::ConnectReady(addr) => { + UdpState::Connecting(addr) => { socket.udp_state = UdpState::Connected(addr); Ok(()) } - UdpState::Connecting(addr) => { - // Do a `poll` to test for completion, using a timeout of zero - // to avoid blocking. - match rustix::event::poll( - &mut [rustix::event::PollFd::new( - socket.udp_socket(), - rustix::event::PollFlags::OUT, - )], - 0, - ) { - Ok(0) => return Err(ErrorCode::WouldBlock.into()), - Ok(_) => {} - Err(err) => return Err(err.into()), - } - - // Check whether the connect succeeded. - match sockopt::get_socket_error(socket.udp_socket()) { - Ok(Ok(())) => { - socket.udp_state = UdpState::Connected(addr); - Ok(()) - } - Err(err) | Ok(Err(err)) => Err(err.into()), - } - } _ => Err(ErrorCode::NotInProgress.into()), } } @@ -188,7 +130,7 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { let mut buf = [0; MAX_UDP_DATAGRAM_SIZE]; match socket.udp_state { UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::NotBound.into()), - UdpState::Bound | UdpState::Connecting(..) | UdpState::ConnectReady(..) => { + UdpState::Bound | UdpState::Connecting(..) => { for i in 0..max_results { match udp_socket.try_recv_from(&mut buf) { Ok((size, remote_address)) => datagrams.push(udp::Datagram { @@ -235,7 +177,7 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { let mut count = 0; match socket.udp_state { UdpState::Default | UdpState::BindStarted => return Err(ErrorCode::NotBound.into()), - UdpState::Bound | UdpState::Connecting(..) | UdpState::ConnectReady(..) => { + UdpState::Bound | UdpState::Connecting(..) => { for udp::Datagram { data, remote_address, @@ -412,13 +354,3 @@ impl crate::preview2::host::udp::udp::HostUdpSocket for T { Ok(()) } } - -const INPROGRESS: Errno = if cfg!(windows) { - // On Windows, non-blocking UDP socket `connect` uses `WSAEWOULDBLOCK`. - // - Errno::WOULDBLOCK -} else { - // On POSIX, non-blocking UDP socket `connect` uses `EINPROGRESS`. - // - Errno::INPROGRESS -}; diff --git a/crates/wasi/src/preview2/udp.rs b/crates/wasi/src/preview2/udp.rs index 07eae01c342c..2c879f99a6c9 100644 --- a/crates/wasi/src/preview2/udp.rs +++ b/crates/wasi/src/preview2/udp.rs @@ -23,13 +23,10 @@ pub(crate) enum UdpState { /// is not yet listening for connections. Bound, - /// An outgoing connection is started via `start_connect`. + /// A connect call is in progress. Connecting(IpSocketAddress), - /// An outgoing connection is ready to be established. - ConnectReady(IpSocketAddress), - - /// An outgoing connection has been established. + /// The socket is "connected" to a peer address. Connected(IpSocketAddress), }