diff --git a/src/client/client.rs b/src/client/client.rs index 11b4058..54b3c2a 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,22 +1,22 @@ -use anyhow::{Context, Result}; +use anyhow::{Context as _, Result}; use async_shutdown::{ShutdownManager, ShutdownSignal}; -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tonic::{transport::Channel, Response, Status, Streaming}; use tracing::{debug, error, info, instrument, span}; use tokio::{ io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, - net::{TcpStream, UdpSocket}, select, sync::{mpsc, oneshot}, }; +use crate::socket::Dialer; use crate::{ constant, io::{StreamingReader, StreamingWriter, TrafficToServerWrapper}, pb::{ - self, control_command::Payload, traffic_to_server, tunnel::Config, + self, control_command::Payload, traffic_to_server, tunnel_service_client::TunnelServiceClient, ControlCommand, RegisterReq, TrafficToClient, TrafficToServer, }, @@ -83,13 +83,13 @@ impl Client { ) -> Result> { let (entrypoint_tx, entrypoint_rx) = oneshot::channel(); let pb_tunnel = tunnel.config.to_pb_tunnel(tunnel.name); - let local_endpoint = tunnel.local_endpoint; + let dialer = tunnel.dialer; tokio::spawn(async move { let run_tunnel = self.handle_tunnel( shutdown.wait_shutdown_triggered(), pb_tunnel, - local_endpoint, + dialer, Some(move |entrypoint| { let _ = entrypoint_tx.send(entrypoint); }), @@ -184,11 +184,10 @@ impl Client { &self, shutdown: ShutdownSignal, tunnel: pb::Tunnel, - local_endpoint: SocketAddr, + dial: Dialer, hook: Option) + Send + 'static>, ) -> Result<()> { let mut rpc_client = self.grpc_client.clone(); - let is_udp = matches!(tunnel.config, Some(Config::Udp(_))); let register = self.register_tunnel(&mut rpc_client, tunnel); tokio::select! { @@ -201,8 +200,7 @@ impl Client { shutdown.clone(), rpc_client, register_resp, - local_endpoint, - is_udp, + dial, hook, ).await } @@ -216,8 +214,7 @@ impl Client { shutdown: ShutdownSignal, rpc_client: TunnelServiceClient, register_resp: tonic::Response>, - local_endpoint: SocketAddr, - is_udp: bool, + dialer: Dialer, mut hook: Option) + Send + 'static>, ) -> Result<()> { let mut control_stream = register_resp.into_inner(); @@ -229,14 +226,8 @@ impl Client { hook(entrypoint); } - self.start_streaming( - shutdown, - &mut control_stream, - rpc_client, - local_endpoint, - is_udp, - ) - .await + self.start_streaming(shutdown, &mut control_stream, rpc_client, dialer) + .await } async fn start_streaming( @@ -244,9 +235,9 @@ impl Client { shutdown: ShutdownSignal, control_stream: &mut Streaming, rpc_client: TunnelServiceClient, - local_endpoint: SocketAddr, - is_udp: bool, + dialer: Dialer, ) -> Result<()> { + let dialer = Arc::new(dialer); loop { tokio::select! { result = control_stream.next() => { @@ -268,8 +259,7 @@ impl Client { if let Err(err) = handle_work_traffic( rpc_client.clone() /* cheap clone operation */, &work.connection_id, - local_endpoint, - is_udp, + dialer.clone(), ).await { error!(err = ?err, "failed to handle work traffic"); } else { @@ -307,8 +297,7 @@ async fn new_rpc_client(control_addr: SocketAddr) -> Result, connection_id: &str, - local_endpoint: SocketAddr, - is_udp: bool, + dialer: Arc, ) -> Result<()> { // write response to the streaming_tx // rpc_client sends the data from reading the streaming_rx @@ -331,7 +320,7 @@ async fn handle_work_traffic( // write the data streaming response to transfer_tx, // then forward_traffic_to_local can read the data from transfer_rx - let (transfer_tx, mut transfer_rx) = mpsc::channel::(64); + let (transfer_tx, transfer_rx) = mpsc::channel::(64); let (local_conn_established_tx, local_conn_established_rx) = mpsc::channel::<()>(1); let mut local_conn_established_rx = Some(local_conn_established_rx); @@ -375,72 +364,24 @@ async fn handle_work_traffic( }); let wrapper = TrafficToServerWrapper::new(connection_id.clone()); - let mut writer = StreamingWriter::new(streaming_tx.clone(), wrapper); - - if is_udp { - tokio::spawn(async move { - let local_addr: SocketAddr = if local_endpoint.is_ipv4() { - "0.0.0.0:0" - } else { - "[::]:0" - } - .parse() - .unwrap(); - let socket = UdpSocket::bind(local_addr).await; - if socket.is_err() { - error!(err = ?socket.err(), "failed to init udp socket, so let's notify the server to close the user connection"); - - streaming_tx - .send(TrafficToServer { - connection_id: connection_id.to_string(), - action: traffic_to_server::Action::Close as i32, - ..Default::default() - }) - .await - .context("terrible, the server may be crashed") - .unwrap(); - return; - } - - let socket = socket.unwrap(); - let result = socket.connect(local_endpoint).await; - if let Err(err) = result { - error!(err = ?err, "failed to connect to local endpoint, so let's notify the server to close the user connection"); - } + let writer = StreamingWriter::new(streaming_tx.clone(), wrapper); - local_conn_established_tx.send(()).await.unwrap(); - - let read_transfer_send_to_local = async { - while let Some(buf) = transfer_rx.recv().await { - socket.send(&buf.data).await.unwrap(); - } - }; - - let read_local_send_to_server = async { - loop { - let mut buf = vec![0u8; 65507]; - let result = socket.recv(&mut buf).await; - match result { - Ok(n) => { - writer.write_all(&buf[..n]).await.unwrap(); - } - Err(err) => { - error!(err = ?err, "failed to read from local endpoint"); - break; - } - } + tokio::spawn(async move { + match dialer.dial().await { + Ok((local_r, local_w)) => { + local_conn_established_tx.send(()).await.unwrap(); + if let Err(err) = + transfer(local_r, local_w, StreamingReader::new(transfer_rx), writer).await + { + debug!("failed to forward traffic to local: {:?}", err); } - writer.shutdown().await.unwrap(); - }; - - tokio::join!(read_transfer_send_to_local, read_local_send_to_server); - }); - } else { - tokio::spawn(async move { - // TODO(sword): use a connection pool to reuse the tcp connection - let local_conn = TcpStream::connect(local_endpoint).await; - if local_conn.is_err() { - error!("failed to connect to local endpoint {}, so let's notify the server to close the user connection", local_endpoint); + } + Err(err) => { + error!( + local_endpoint = ?dialer.addr(), + ?err, + "failed to connect to local endpoint, so let's notify the server to close the user connection", + ); streaming_tx .send(TrafficToServer { @@ -451,30 +392,14 @@ async fn handle_work_traffic( .await .context("terrible, the server may be crashed") .unwrap(); - return; - } - - let mut local_conn = local_conn.unwrap(); - let (local_r, local_w) = local_conn.split(); - local_conn_established_tx.send(()).await.unwrap(); - - if let Err(err) = forward_traffic_to_local( - local_r, - local_w, - StreamingReader::new(transfer_rx), - writer, - ) - .await - { - debug!("failed to forward traffic to local: {:?}", err); } - }); - } + } + }); Ok(()) } -/// Forwards the traffic from the server to the local endpoint. +/// transfer the traffic from the server to the local endpoint. /// /// Try to imagine the current client is yourself, /// your mission is to forward the traffic from the server to the local, @@ -482,11 +407,11 @@ async fn handle_work_traffic( /// in this process, there are two underlying connections: /// 1. remote <=> me /// 2. me <=> local -async fn forward_traffic_to_local( +async fn transfer( local_r: impl AsyncRead + Unpin, mut local_w: impl AsyncWrite + Unpin, - remote_r: StreamingReader, - mut remote_w: StreamingWriter, + remote_r: impl AsyncRead + Unpin, + mut remote_w: impl AsyncWrite + Unpin, ) -> Result<()> { let remote_to_me_to_local = async { // read from remote, write to local diff --git a/src/client/tunnel.rs b/src/client/tunnel.rs index 39350a6..d11c803 100644 --- a/src/client/tunnel.rs +++ b/src/client/tunnel.rs @@ -4,13 +4,16 @@ use std::net::SocketAddr; use bytes::Bytes; -use crate::pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig}; +use crate::{ + pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig}, + socket::{dial_tcp, dial_udp, Dialer}, +}; /// Tunnel configuration for the client. #[derive(Debug)] pub struct Tunnel<'a> { pub(crate) name: &'a str, - pub(crate) local_endpoint: SocketAddr, + pub(crate) dialer: Dialer, pub(crate) config: RemoteConfig<'a>, } @@ -19,7 +22,14 @@ impl<'a> Tunnel<'a> { pub fn new(name: &'a str, local_endpoint: SocketAddr, config: RemoteConfig<'a>) -> Self { Self { name, - local_endpoint, + dialer: Dialer::new( + match config { + RemoteConfig::Tcp(_) => |endpoint| Box::pin(dial_tcp(endpoint)), + RemoteConfig::Udp(_) => |endpoint| Box::pin(dial_udp(endpoint)), + RemoteConfig::Http(_) => |endpoint| Box::pin(dial_tcp(endpoint)), + }, + local_endpoint, + ), config, } } diff --git a/src/helper.rs b/src/helper.rs index 21eee56..0fcc465 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,33 +1,8 @@ -use crate::pb::RegisterReq; - -use tokio::net::{TcpListener, UdpSocket}; use tonic::Status; -use tracing::error; -pub(crate) async fn create_tcp_listener(port: u16) -> Result { - TcpListener::bind(("0.0.0.0", port)) - .await - .map_err(map_bind_error) -} - -pub(crate) async fn create_udp_socket(port: u16) -> Result { - UdpSocket::bind(("0.0.0.0", port)) - .await - .map_err(map_bind_error) -} - -fn map_bind_error(err: std::io::Error) -> Status { - match err.kind() { - std::io::ErrorKind::AddrInUse => Status::already_exists("port already in use"), - std::io::ErrorKind::PermissionDenied => Status::permission_denied("permission denied"), - _ => { - error!("failed to bind port: {}", err); - Status::internal("failed to bind port") - } - } -} +use crate::pb::RegisterReq; -pub fn validate_register_req(req: &RegisterReq) -> Option { +pub(crate) fn validate_register_req(req: &RegisterReq) -> Option { if req.tunnel.is_none() { return Some(Status::invalid_argument("tunnel is required")); } diff --git a/src/lib.rs b/src/lib.rs index 2bae815..2aa628e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub(crate) mod constant; pub(crate) mod event; pub(crate) mod helper; pub(crate) mod io; +pub(crate) mod socket; pub mod pb { include!("gen/message.rs"); diff --git a/src/server/control_server.rs b/src/server/control_server.rs index ca3e690..817817a 100644 --- a/src/server/control_server.rs +++ b/src/server/control_server.rs @@ -38,7 +38,7 @@ type DataStream = Pin> + Send> /// /// We treat the control server is grpc server as well, in the concept, /// they are same thing. -/// Although the grpc server provides a [`crate::protocol::pb::tunnel_service_server::TunnelService::data`], +/// Although the grpc server provides a [`crate::pb::tunnel_service_server::TunnelService::data`], /// it's similar to the data server(a little), but in the `data` function body, /// the most of work is to forward the data from client to data server. /// We can understand this is a tunnel between the client and the data server. diff --git a/src/server/tunnel/http.rs b/src/server/tunnel/http.rs index 0ac6b12..ea2a94e 100644 --- a/src/server/tunnel/http.rs +++ b/src/server/tunnel/http.rs @@ -1,6 +1,6 @@ use crate::bridge::BridgeData; use crate::event::{self, IncomingEventSender}; -use crate::helper::create_tcp_listener; +use crate::socket::create_tcp_listener; use super::{init_data_sender_bridge, BridgeResult}; use anyhow::{Context as _, Result}; diff --git a/src/server/tunnel/tcp.rs b/src/server/tunnel/tcp.rs index 2ebc762..b1d9491 100644 --- a/src/server/tunnel/tcp.rs +++ b/src/server/tunnel/tcp.rs @@ -1,8 +1,8 @@ use crate::{ event, - helper::create_tcp_listener, io::{StreamingReader, StreamingWriter, VecWrapper}, server::tunnel::BridgeResult, + socket::create_tcp_listener, }; use anyhow::Context as _; use tokio::{ diff --git a/src/server/tunnel/udp.rs b/src/server/tunnel/udp.rs index 6821bea..de417d5 100644 --- a/src/server/tunnel/udp.rs +++ b/src/server/tunnel/udp.rs @@ -1,6 +1,6 @@ use std::{net::SocketAddr, sync::Arc}; -use crate::{bridge::BridgeData, event, helper::create_udp_socket, server::tunnel::BridgeResult}; +use crate::{bridge::BridgeData, event, server::tunnel::BridgeResult, socket::create_udp_socket}; use dashmap::DashMap; use tokio::{net::UdpSocket, select, sync::mpsc}; use tokio_util::sync::CancellationToken; diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..7629b5a --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,189 @@ +//! Socket utilities for creating listeners, async readers, writers, and dialers. +use std::{ + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::{ + io::{self, AsyncRead, AsyncWrite}, + net::{TcpListener, TcpStream, UdpSocket}, +}; +use tonic::Status; +use tracing::error; + +/// create a tcp listener. +pub(crate) async fn create_tcp_listener(port: u16) -> Result { + TcpListener::bind(("0.0.0.0", port)) + .await + .map_err(map_bind_error) +} + +/// create a udp socket. +pub(crate) async fn create_udp_socket(port: u16) -> Result { + UdpSocket::bind(("0.0.0.0", port)) + .await + .map_err(map_bind_error) +} + +fn map_bind_error(err: std::io::Error) -> Status { + match err.kind() { + std::io::ErrorKind::AddrInUse => Status::already_exists("port already in use"), + std::io::ErrorKind::PermissionDenied => Status::permission_denied("permission denied"), + _ => { + error!("failed to bind port: {}", err); + Status::internal("failed to bind port") + } + } +} + +/// Async reader for a udp connection. +pub(crate) struct AsyncUdpSocket<'a> { + socket: &'a UdpSocket, +} + +impl<'a> AsyncUdpSocket<'a> { + pub(crate) fn new(socket: &'a UdpSocket) -> Self { + Self { socket } + } +} + +impl<'a> AsyncRead for AsyncUdpSocket<'a> { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut().socket.poll_recv_from(cx, buf) { + Poll::Ready(Ok(_addr)) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +impl<'a> AsyncWrite for AsyncUdpSocket<'a> { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.get_mut().socket.poll_send(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) // No-op for UDP + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) // No-op for UDP + } +} + +/// Dialer for connecting to a endpoint to get a async reader and a async writer. +#[derive(Debug)] +pub(crate) struct Dialer { + dial: DialFn, + addr: SocketAddr, +} + +impl Dialer { + /// Create a new dialer. + pub(crate) fn new(dial: DialFn, addr: SocketAddr) -> Self { + Self { dial, addr } + } + + /// Dial the endpoint. + /// + /// # Returns + /// + /// A future that resolves to a async reader and a async writer. + pub(crate) fn dial(&self) -> Pin + Send>> { + (self.dial)(self.addr) + } + + /// Expose the address of the dialer. + pub(crate) fn addr(&self) -> SocketAddr { + self.addr + } +} + +/// Result of dialing a endpoint. +pub(crate) type DialResult = Result< + ( + Box, + Box, + ), + Box, +>; + +/// Function for dialing a endpoint to get a async reader and a async writer. +pub(crate) type DialFn = + fn(SocketAddr) -> Pin + Send>>; + +/// Dial a tcp endpoint. +pub(crate) async fn dial_tcp(local_endpoint: SocketAddr) -> DialResult { + let local_conn = TcpStream::connect(local_endpoint).await?; + let (r, w) = local_conn.into_split(); + Ok((Box::new(r), Box::new(w))) +} + +/// Dial a udp endpoint. +pub(crate) async fn dial_udp(local_endpoint: SocketAddr) -> DialResult { + let local_addr: SocketAddr = if local_endpoint.is_ipv4() { + "0.0.0.0:0" + } else { + "[::]:0" + } + .parse()?; + let socket = UdpSocket::bind(local_addr).await?; + socket.connect(local_endpoint).await?; + let socket = Box::leak(Box::new(socket)); + Ok(( + Box::new(AsyncUdpSocket::new(socket)), + Box::new(AsyncUdpSocket::new(socket)), + )) +} + +#[cfg(test)] +mod test { + use super::*; + use std::net::TcpListener as StdTcpListener; + + #[tokio::test] + async fn test_tcp_listener_and_dialer() { + let port = free_port().unwrap(); + let listener = create_tcp_listener(port).await.unwrap(); + + let dialer = Dialer::new( + |addr| Box::pin(dial_tcp(addr)), + listener.local_addr().unwrap(), + ); + dialer.dial().await.unwrap(); + } + + #[tokio::test] + async fn test_udp_socket_and_dialer() { + let port = free_port().unwrap(); + let socket = create_udp_socket(port).await.unwrap(); + + let dialer = Dialer::new( + |addr| Box::pin(dial_udp(addr)), + socket.local_addr().unwrap(), + ); + dialer.dial().await.unwrap(); + } + + /// free_port returns a free port number for testing. + fn free_port() -> std::io::Result { + let listener = StdTcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + Ok(port) + } +} diff --git a/tests/lib.rs b/tests/lib.rs index d8ff12a..a6e93b7 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,3 +1,5 @@ +#![warn(missing_docs)] + mod common; use crate::common::free_port;