Skip to content

Commit

Permalink
Introduce TcpSocket trait
Browse files Browse the repository at this point in the history
  • Loading branch information
badeend committed Jan 13, 2024
1 parent 6343c9e commit a2ee9a9
Show file tree
Hide file tree
Showing 3 changed files with 346 additions and 114 deletions.
43 changes: 22 additions & 21 deletions crates/wasi/src/preview2/host/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ use crate::preview2::bindings::sockets::{
tcp::ShutdownType,
};
use crate::preview2::pipe::{AsyncReadStream, AsyncWriteStream};
use crate::preview2::tcp::{SystemTcpReader, SystemTcpSocket, SystemTcpWriter};
use crate::preview2::tcp::{TcpReader, TcpSocket, TcpWriter};
use crate::preview2::{
InputStream, OutputStream, Pollable, Preview2Future, SocketAddrFamily, SocketResult, Subscribe,
WasiView,
InputStream, OutputStream, Pollable, Preview2Future, SocketResult, Subscribe, WasiView,
};
use std::io;
use std::net::SocketAddr;
Expand Down Expand Up @@ -36,12 +35,12 @@ enum TcpState {

/// The socket is now listening and waiting for an incoming connection.
Listening {
pending_result: Option<io::Result<(SystemTcpSocket, SystemTcpReader, SystemTcpWriter)>>,
pending_result: Option<io::Result<(Box<dyn TcpSocket>, TcpReader, TcpWriter)>>,
},

/// An outgoing connection is started via `start_connect`.
Connecting {
future: Preview2Future<io::Result<(SystemTcpReader, SystemTcpWriter)>>,
future: Preview2Future<io::Result<(TcpReader, TcpWriter)>>,
},

/// An outgoing connection has been established.
Expand All @@ -55,26 +54,18 @@ enum TcpState {
pub struct TcpSocketWrapper {
/// The part of a `TcpSocket` which is reference-counted so that we
/// can pass it to async tasks.
inner: SystemTcpSocket,
inner: Box<dyn TcpSocket>,

/// The current state in the bind/listen/accept/connect progression.
tcp_state: TcpState,
}

impl TcpSocketWrapper {
/// Create a new socket in the given family.
pub fn new(family: SocketAddrFamily) -> io::Result<Self> {
Ok(Self {
inner: SystemTcpSocket::new(family)?,
tcp_state: TcpState::Default,
})
}

fn new_input_stream(reader: SystemTcpReader) -> InputStream {
fn new_input_stream(reader: TcpReader) -> InputStream {
InputStream::Host(Box::new(AsyncReadStream::new(reader)))
}

fn new_output_stream(writer: SystemTcpWriter) -> OutputStream {
fn new_output_stream(writer: TcpWriter) -> OutputStream {
const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;

Box::new(AsyncWriteStream::new(SOCKET_READY_SIZE, writer))
Expand All @@ -86,8 +77,15 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp_create_socket::Host fo
&mut self,
address_family: IpAddressFamily,
) -> SocketResult<Resource<TcpSocketWrapper>> {
let socket = TcpSocketWrapper::new(address_family.into())?;
let socket = self.table_mut().push(socket)?;
let socket = self
.ctx_mut()
.network
.new_tcp_socket(address_family.into())?;
let wrapper = TcpSocketWrapper {
inner: socket,
tcp_state: TcpState::Default,
};
let socket = self.table_mut().push(wrapper)?;
Ok(socket)
}
}
Expand Down Expand Up @@ -115,7 +113,7 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp::HostTcpSocket for T {
_ => return Err(ErrorCode::InvalidState.into()),
}

socket.inner.bind(&local_address)?;
socket.inner.bind(local_address)?;
socket.tcp_state = TcpState::BindStarted;

Ok(())
Expand Down Expand Up @@ -156,7 +154,7 @@ impl<T: WasiView> crate::preview2::bindings::sockets::tcp::HostTcpSocket for T {
_ => return Err(ErrorCode::InvalidState.into()),
}

let mut future = Preview2Future::new(socket.inner.connect(&remote_address));
let mut future = Preview2Future::new(socket.inner.connect(remote_address));

// Attempt to return (validation) errors immediately:
let future = match future.try_resolve() {
Expand Down Expand Up @@ -529,7 +527,10 @@ impl Subscribe for TcpSocketWrapper {
TcpState::Connecting { future } => future.ready().await,
TcpState::Listening { pending_result } => match pending_result {
Some(_) => {}
None => *pending_result = Some(self.inner.accept().await),
None => {
let result = futures::future::poll_fn(|cx| self.inner.poll_accept(cx)).await;
*pending_result = Some(result);
}
},
}
}
Expand Down
12 changes: 12 additions & 0 deletions crates/wasi/src/preview2/network.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::preview2::bindings::sockets::network::{Ipv4Address, Ipv6Address};
use crate::preview2::bindings::wasi::sockets::network::ErrorCode;
use crate::preview2::ip_name_lookup::resolve_addresses;
use crate::preview2::tcp::{SystemTcpSocket, TcpSocket};
use crate::preview2::{BoxSyncFuture, TrappableError};
use std::io;
use std::net::{IpAddr, SocketAddr};
Expand All @@ -10,6 +11,9 @@ use std::sync::Arc;
pub trait Network: Sync + Send {
/// Given a name, resolve to a list of IP addresses
fn resolve_addresses(&mut self, name: String) -> BoxSyncFuture<io::Result<Vec<IpAddr>>>;

/// Create a new TCP socket.
fn new_tcp_socket(&mut self, family: SocketAddrFamily) -> io::Result<Box<dyn TcpSocket>>;
}

/// The default network implementation
Expand Down Expand Up @@ -44,6 +48,10 @@ impl Network for DefaultNetwork {

self.system.resolve_addresses(name)
}

fn new_tcp_socket(&mut self, family: SocketAddrFamily) -> io::Result<Box<dyn TcpSocket>> {
self.system.new_tcp_socket(family)
}
}

/// An implementation of `Networked` that uses the underlying system
Expand All @@ -61,6 +69,10 @@ impl Network for SystemNetwork {
fn resolve_addresses(&mut self, name: String) -> BoxSyncFuture<io::Result<Vec<IpAddr>>> {
Box::pin(async move { resolve_addresses(&name).await })
}

fn new_tcp_socket(&mut self, family: SocketAddrFamily) -> io::Result<Box<dyn TcpSocket>> {
Ok(Box::new(SystemTcpSocket::new(family)?))
}
}

pub struct NetworkHandle {
Expand Down
Loading

0 comments on commit a2ee9a9

Please sign in to comment.