From 91a1ec8af12aeb4604d3a4cfca3fe5a4c817af99 Mon Sep 17 00:00:00 2001 From: sword-jin Date: Mon, 29 Jul 2024 11:17:57 +0000 Subject: [PATCH 1/5] refactor rust udp --- src/client/client.rs | 45 +++++++++++++++------------------------- src/io.rs | 49 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 29 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 11b4058..e5ea6c7 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,4 +1,4 @@ -use anyhow::{Context, Result}; +use anyhow::{Context as _, Result}; use async_shutdown::{ShutdownManager, ShutdownSignal}; use std::net::SocketAddr; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; @@ -12,6 +12,7 @@ use tokio::{ sync::{mpsc, oneshot}, }; +use crate::io::AsyncUdpSocket; use crate::{ constant, io::{StreamingReader, StreamingWriter, TrafficToServerWrapper}, @@ -331,7 +332,7 @@ async fn handle_work_traffic( // write the data streaming response to transfer_tx, // then forward_traffic_to_local can read the data from transfer_rx - let (transfer_tx, mut transfer_rx) = mpsc::channel::(64); + let (transfer_tx, transfer_rx) = mpsc::channel::(64); let (local_conn_established_tx, local_conn_established_rx) = mpsc::channel::<()>(1); let mut local_conn_established_rx = Some(local_conn_established_rx); @@ -375,7 +376,7 @@ async fn handle_work_traffic( }); let wrapper = TrafficToServerWrapper::new(connection_id.clone()); - let mut writer = StreamingWriter::new(streaming_tx.clone(), wrapper); + let writer = StreamingWriter::new(streaming_tx.clone(), wrapper); if is_udp { tokio::spawn(async move { @@ -410,30 +411,16 @@ async fn handle_work_traffic( local_conn_established_tx.send(()).await.unwrap(); - let read_transfer_send_to_local = async { - while let Some(buf) = transfer_rx.recv().await { - socket.send(&buf.data).await.unwrap(); - } - }; - - let read_local_send_to_server = async { - loop { - let mut buf = vec![0u8; 65507]; - let result = socket.recv(&mut buf).await; - match result { - Ok(n) => { - writer.write_all(&buf[..n]).await.unwrap(); - } - Err(err) => { - error!(err = ?err, "failed to read from local endpoint"); - break; - } - } - } - writer.shutdown().await.unwrap(); - }; - - tokio::join!(read_transfer_send_to_local, read_local_send_to_server); + if let Err(err) = forward_traffic_to_local( + AsyncUdpSocket::new(&socket), + AsyncUdpSocket::new(&socket), + StreamingReader::new(transfer_rx), + writer, + ) + .await + { + debug!("failed to forward traffic to local: {:?}", err); + } }); } else { tokio::spawn(async move { @@ -485,8 +472,8 @@ async fn handle_work_traffic( async fn forward_traffic_to_local( local_r: impl AsyncRead + Unpin, mut local_w: impl AsyncWrite + Unpin, - remote_r: StreamingReader, - mut remote_w: StreamingWriter, + remote_r: impl AsyncRead + Unpin, + mut remote_w: impl AsyncWrite + Unpin, ) -> Result<()> { let remote_to_me_to_local = async { // read from remote, write to local diff --git a/src/io.rs b/src/io.rs index 5882562..81338ab 100644 --- a/src/io.rs +++ b/src/io.rs @@ -8,6 +8,7 @@ use std::fmt::Debug; use std::task::{Context, Poll}; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; +use tokio::net::UdpSocket; use tokio::sync::mpsc; use tokio::{io, sync::mpsc::Sender}; use tokio_util::sync::CancellationToken; @@ -263,3 +264,51 @@ macro_rules! generate_async_write_impl { generate_async_write_impl!(TrafficToServer); generate_async_write_impl!(Vec); + +pub(crate) struct AsyncUdpSocket<'a> { + socket: &'a UdpSocket, +} + +impl<'a> AsyncUdpSocket<'a> { + pub(crate) fn new(socket: &'a UdpSocket) -> Self { + Self { socket } + } +} + +impl<'a> AsyncRead for AsyncUdpSocket<'a> { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut().socket.poll_recv_from(cx, buf) { + Poll::Ready(Ok(_addr)) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +impl<'a> AsyncWrite for AsyncUdpSocket<'a> { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.get_mut().socket.poll_send(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) // No-op for UDP + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) // No-op for UDP + } +} From 5bdf7cc5be088108553b898e3e25fa0359424515 Mon Sep 17 00:00:00 2001 From: sword-jin Date: Tue, 30 Jul 2024 00:47:17 +0000 Subject: [PATCH 2/5] refactor a socket module to dial tcp and udp --- src/client/client.rs | 96 +++++++++---------------------- src/helper.rs | 27 +-------- src/io.rs | 49 ---------------- src/lib.rs | 1 + src/server/tunnel/http.rs | 2 +- src/server/tunnel/tcp.rs | 2 +- src/server/tunnel/udp.rs | 2 +- src/socket.rs | 116 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 148 insertions(+), 147 deletions(-) create mode 100644 src/socket.rs diff --git a/src/client/client.rs b/src/client/client.rs index e5ea6c7..2eb72d7 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; use async_shutdown::{ShutdownManager, ShutdownSignal}; -use std::net::SocketAddr; +use std::{net::SocketAddr, pin::Pin}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tonic::{transport::Channel, Response, Status, Streaming}; use tracing::{debug, error, info, instrument, span}; @@ -12,7 +12,6 @@ use tokio::{ sync::{mpsc, oneshot}, }; -use crate::io::AsyncUdpSocket; use crate::{ constant, io::{StreamingReader, StreamingWriter, TrafficToServerWrapper}, @@ -21,7 +20,9 @@ use crate::{ tunnel_service_client::TunnelServiceClient, ControlCommand, RegisterReq, TrafficToClient, TrafficToServer, }, + socket::{dial_tcp, dial_udp}, }; +use crate::{io::AsyncUdpSocket, socket::DialFn}; use super::tunnel::Tunnel; @@ -378,56 +379,29 @@ async fn handle_work_traffic( let wrapper = TrafficToServerWrapper::new(connection_id.clone()); let writer = StreamingWriter::new(streaming_tx.clone(), wrapper); + let dial_fn: DialFn; if is_udp { - tokio::spawn(async move { - let local_addr: SocketAddr = if local_endpoint.is_ipv4() { - "0.0.0.0:0" - } else { - "[::]:0" - } - .parse() - .unwrap(); - let socket = UdpSocket::bind(local_addr).await; - if socket.is_err() { - error!(err = ?socket.err(), "failed to init udp socket, so let's notify the server to close the user connection"); - - streaming_tx - .send(TrafficToServer { - connection_id: connection_id.to_string(), - action: traffic_to_server::Action::Close as i32, - ..Default::default() - }) - .await - .context("terrible, the server may be crashed") - .unwrap(); - return; - } - - let socket = socket.unwrap(); - let result = socket.connect(local_endpoint).await; - if let Err(err) = result { - error!(err = ?err, "failed to connect to local endpoint, so let's notify the server to close the user connection"); - } - - local_conn_established_tx.send(()).await.unwrap(); + dial_fn = |endpoint| Box::pin(dial_udp(endpoint)); + } else { + dial_fn = |endpoint| Box::pin(dial_tcp(endpoint)); + } - if let Err(err) = forward_traffic_to_local( - AsyncUdpSocket::new(&socket), - AsyncUdpSocket::new(&socket), - StreamingReader::new(transfer_rx), - writer, - ) - .await - { - debug!("failed to forward traffic to local: {:?}", err); + tokio::spawn(async move { + match dial_fn(local_endpoint).await { + Ok((local_r, local_w)) => { + local_conn_established_tx.send(()).await.unwrap(); + if let Err(err) = + transfer(local_r, local_w, StreamingReader::new(transfer_rx), writer).await + { + debug!("failed to forward traffic to local: {:?}", err); + } } - }); - } else { - tokio::spawn(async move { - // TODO(sword): use a connection pool to reuse the tcp connection - let local_conn = TcpStream::connect(local_endpoint).await; - if local_conn.is_err() { - error!("failed to connect to local endpoint {}, so let's notify the server to close the user connection", local_endpoint); + Err(err) => { + error!( + ?local_endpoint, + ?err, + "failed to connect to local endpoint, so let's notify the server to close the user connection", + ); streaming_tx .send(TrafficToServer { @@ -438,30 +412,14 @@ async fn handle_work_traffic( .await .context("terrible, the server may be crashed") .unwrap(); - return; } - - let mut local_conn = local_conn.unwrap(); - let (local_r, local_w) = local_conn.split(); - local_conn_established_tx.send(()).await.unwrap(); - - if let Err(err) = forward_traffic_to_local( - local_r, - local_w, - StreamingReader::new(transfer_rx), - writer, - ) - .await - { - debug!("failed to forward traffic to local: {:?}", err); - } - }); - } + } + }); Ok(()) } -/// Forwards the traffic from the server to the local endpoint. +/// transfer the traffic from the server to the local endpoint. /// /// Try to imagine the current client is yourself, /// your mission is to forward the traffic from the server to the local, @@ -469,7 +427,7 @@ async fn handle_work_traffic( /// in this process, there are two underlying connections: /// 1. remote <=> me /// 2. me <=> local -async fn forward_traffic_to_local( +async fn transfer( local_r: impl AsyncRead + Unpin, mut local_w: impl AsyncWrite + Unpin, remote_r: impl AsyncRead + Unpin, diff --git a/src/helper.rs b/src/helper.rs index 21eee56..b1314e9 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -1,31 +1,6 @@ -use crate::pb::RegisterReq; - -use tokio::net::{TcpListener, UdpSocket}; use tonic::Status; -use tracing::error; -pub(crate) async fn create_tcp_listener(port: u16) -> Result { - TcpListener::bind(("0.0.0.0", port)) - .await - .map_err(map_bind_error) -} - -pub(crate) async fn create_udp_socket(port: u16) -> Result { - UdpSocket::bind(("0.0.0.0", port)) - .await - .map_err(map_bind_error) -} - -fn map_bind_error(err: std::io::Error) -> Status { - match err.kind() { - std::io::ErrorKind::AddrInUse => Status::already_exists("port already in use"), - std::io::ErrorKind::PermissionDenied => Status::permission_denied("permission denied"), - _ => { - error!("failed to bind port: {}", err); - Status::internal("failed to bind port") - } - } -} +use crate::pb::RegisterReq; pub fn validate_register_req(req: &RegisterReq) -> Option { if req.tunnel.is_none() { diff --git a/src/io.rs b/src/io.rs index 81338ab..5882562 100644 --- a/src/io.rs +++ b/src/io.rs @@ -8,7 +8,6 @@ use std::fmt::Debug; use std::task::{Context, Poll}; use tokio::io::AsyncRead; use tokio::io::AsyncWrite; -use tokio::net::UdpSocket; use tokio::sync::mpsc; use tokio::{io, sync::mpsc::Sender}; use tokio_util::sync::CancellationToken; @@ -264,51 +263,3 @@ macro_rules! generate_async_write_impl { generate_async_write_impl!(TrafficToServer); generate_async_write_impl!(Vec); - -pub(crate) struct AsyncUdpSocket<'a> { - socket: &'a UdpSocket, -} - -impl<'a> AsyncUdpSocket<'a> { - pub(crate) fn new(socket: &'a UdpSocket) -> Self { - Self { socket } - } -} - -impl<'a> AsyncRead for AsyncUdpSocket<'a> { - fn poll_read( - self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut io::ReadBuf<'_>, - ) -> Poll> { - match self.get_mut().socket.poll_recv_from(cx, buf) { - Poll::Ready(Ok(_addr)) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } - } -} - -impl<'a> AsyncWrite for AsyncUdpSocket<'a> { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.get_mut().socket.poll_send(cx, buf) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) // No-op for UDP - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) // No-op for UDP - } -} diff --git a/src/lib.rs b/src/lib.rs index 2bae815..2aa628e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub(crate) mod constant; pub(crate) mod event; pub(crate) mod helper; pub(crate) mod io; +pub(crate) mod socket; pub mod pb { include!("gen/message.rs"); diff --git a/src/server/tunnel/http.rs b/src/server/tunnel/http.rs index 0ac6b12..ea2a94e 100644 --- a/src/server/tunnel/http.rs +++ b/src/server/tunnel/http.rs @@ -1,6 +1,6 @@ use crate::bridge::BridgeData; use crate::event::{self, IncomingEventSender}; -use crate::helper::create_tcp_listener; +use crate::socket::create_tcp_listener; use super::{init_data_sender_bridge, BridgeResult}; use anyhow::{Context as _, Result}; diff --git a/src/server/tunnel/tcp.rs b/src/server/tunnel/tcp.rs index 2ebc762..b1d9491 100644 --- a/src/server/tunnel/tcp.rs +++ b/src/server/tunnel/tcp.rs @@ -1,8 +1,8 @@ use crate::{ event, - helper::create_tcp_listener, io::{StreamingReader, StreamingWriter, VecWrapper}, server::tunnel::BridgeResult, + socket::create_tcp_listener, }; use anyhow::Context as _; use tokio::{ diff --git a/src/server/tunnel/udp.rs b/src/server/tunnel/udp.rs index 6821bea..de417d5 100644 --- a/src/server/tunnel/udp.rs +++ b/src/server/tunnel/udp.rs @@ -1,6 +1,6 @@ use std::{net::SocketAddr, sync::Arc}; -use crate::{bridge::BridgeData, event, helper::create_udp_socket, server::tunnel::BridgeResult}; +use crate::{bridge::BridgeData, event, server::tunnel::BridgeResult, socket::create_udp_socket}; use dashmap::DashMap; use tokio::{net::UdpSocket, select, sync::mpsc}; use tokio_util::sync::CancellationToken; diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..fa75113 --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,116 @@ +use std::{ + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; + +use tokio::{ + io::{self, AsyncRead, AsyncWrite}, + net::{TcpListener, TcpStream, UdpSocket}, +}; +use tonic::Status; +use tracing::error; + +pub(crate) async fn create_tcp_listener(port: u16) -> Result { + TcpListener::bind(("0.0.0.0", port)) + .await + .map_err(map_bind_error) +} + +pub(crate) async fn create_udp_socket(port: u16) -> Result { + UdpSocket::bind(("0.0.0.0", port)) + .await + .map_err(map_bind_error) +} + +fn map_bind_error(err: std::io::Error) -> Status { + match err.kind() { + std::io::ErrorKind::AddrInUse => Status::already_exists("port already in use"), + std::io::ErrorKind::PermissionDenied => Status::permission_denied("permission denied"), + _ => { + error!("failed to bind port: {}", err); + Status::internal("failed to bind port") + } + } +} + +pub(crate) struct AsyncUdpSocket<'a> { + socket: &'a UdpSocket, +} + +impl<'a> AsyncUdpSocket<'a> { + pub(crate) fn new(socket: &'a UdpSocket) -> Self { + Self { socket } + } +} + +impl<'a> AsyncRead for AsyncUdpSocket<'a> { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut().socket.poll_recv_from(cx, buf) { + Poll::Ready(Ok(_addr)) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} + +impl<'a> AsyncWrite for AsyncUdpSocket<'a> { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.get_mut().socket.poll_send(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) // No-op for UDP + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) // No-op for UDP + } +} + +pub(crate) type DialResult = Result< + ( + Box, + Box, + ), + Box, +>; + +pub(crate) type DialFn = + fn(SocketAddr) -> Pin + Send>>; + +pub(crate) async fn dial_tcp(local_endpoint: SocketAddr) -> DialResult { + let local_conn = TcpStream::connect(local_endpoint).await?; + let (r, w) = local_conn.into_split(); + Ok((Box::new(r), Box::new(w))) +} + +pub(crate) async fn dial_udp(local_endpoint: SocketAddr) -> DialResult { + let local_addr: SocketAddr = if local_endpoint.is_ipv4() { + "0.0.0.0:0" + } else { + "[::]:0" + } + .parse()?; + let socket = UdpSocket::bind(local_addr).await?; + socket.connect(local_endpoint).await?; + let socket = Box::leak(Box::new(socket)); + Ok(( + Box::new(AsyncUdpSocket::new(socket)), + Box::new(AsyncUdpSocket::new(socket)), + )) +} From db7db97cc4f2acc2cda548f89257bd3cf4da1a77 Mon Sep 17 00:00:00 2001 From: sword-jin Date: Tue, 30 Jul 2024 01:09:51 +0000 Subject: [PATCH 3/5] refactor(server): add a dialer to tunnel to create local connection --- src/client/client.rs | 58 +++++++++++++++----------------------------- src/client/tunnel.rs | 16 +++++++++--- src/socket.rs | 31 +++++++++++++++++++++++ 3 files changed, 63 insertions(+), 42 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 2eb72d7..54b3c2a 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -1,28 +1,26 @@ use anyhow::{Context as _, Result}; use async_shutdown::{ShutdownManager, ShutdownSignal}; -use std::{net::SocketAddr, pin::Pin}; +use std::{net::SocketAddr, sync::Arc}; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use tonic::{transport::Channel, Response, Status, Streaming}; use tracing::{debug, error, info, instrument, span}; use tokio::{ io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}, - net::{TcpStream, UdpSocket}, select, sync::{mpsc, oneshot}, }; +use crate::socket::Dialer; use crate::{ constant, io::{StreamingReader, StreamingWriter, TrafficToServerWrapper}, pb::{ - self, control_command::Payload, traffic_to_server, tunnel::Config, + self, control_command::Payload, traffic_to_server, tunnel_service_client::TunnelServiceClient, ControlCommand, RegisterReq, TrafficToClient, TrafficToServer, }, - socket::{dial_tcp, dial_udp}, }; -use crate::{io::AsyncUdpSocket, socket::DialFn}; use super::tunnel::Tunnel; @@ -85,13 +83,13 @@ impl Client { ) -> Result> { let (entrypoint_tx, entrypoint_rx) = oneshot::channel(); let pb_tunnel = tunnel.config.to_pb_tunnel(tunnel.name); - let local_endpoint = tunnel.local_endpoint; + let dialer = tunnel.dialer; tokio::spawn(async move { let run_tunnel = self.handle_tunnel( shutdown.wait_shutdown_triggered(), pb_tunnel, - local_endpoint, + dialer, Some(move |entrypoint| { let _ = entrypoint_tx.send(entrypoint); }), @@ -186,11 +184,10 @@ impl Client { &self, shutdown: ShutdownSignal, tunnel: pb::Tunnel, - local_endpoint: SocketAddr, + dial: Dialer, hook: Option) + Send + 'static>, ) -> Result<()> { let mut rpc_client = self.grpc_client.clone(); - let is_udp = matches!(tunnel.config, Some(Config::Udp(_))); let register = self.register_tunnel(&mut rpc_client, tunnel); tokio::select! { @@ -203,8 +200,7 @@ impl Client { shutdown.clone(), rpc_client, register_resp, - local_endpoint, - is_udp, + dial, hook, ).await } @@ -218,8 +214,7 @@ impl Client { shutdown: ShutdownSignal, rpc_client: TunnelServiceClient, register_resp: tonic::Response>, - local_endpoint: SocketAddr, - is_udp: bool, + dialer: Dialer, mut hook: Option) + Send + 'static>, ) -> Result<()> { let mut control_stream = register_resp.into_inner(); @@ -231,14 +226,8 @@ impl Client { hook(entrypoint); } - self.start_streaming( - shutdown, - &mut control_stream, - rpc_client, - local_endpoint, - is_udp, - ) - .await + self.start_streaming(shutdown, &mut control_stream, rpc_client, dialer) + .await } async fn start_streaming( @@ -246,9 +235,9 @@ impl Client { shutdown: ShutdownSignal, control_stream: &mut Streaming, rpc_client: TunnelServiceClient, - local_endpoint: SocketAddr, - is_udp: bool, + dialer: Dialer, ) -> Result<()> { + let dialer = Arc::new(dialer); loop { tokio::select! { result = control_stream.next() => { @@ -270,8 +259,7 @@ impl Client { if let Err(err) = handle_work_traffic( rpc_client.clone() /* cheap clone operation */, &work.connection_id, - local_endpoint, - is_udp, + dialer.clone(), ).await { error!(err = ?err, "failed to handle work traffic"); } else { @@ -309,8 +297,7 @@ async fn new_rpc_client(control_addr: SocketAddr) -> Result, connection_id: &str, - local_endpoint: SocketAddr, - is_udp: bool, + dialer: Arc, ) -> Result<()> { // write response to the streaming_tx // rpc_client sends the data from reading the streaming_rx @@ -379,15 +366,8 @@ async fn handle_work_traffic( let wrapper = TrafficToServerWrapper::new(connection_id.clone()); let writer = StreamingWriter::new(streaming_tx.clone(), wrapper); - let dial_fn: DialFn; - if is_udp { - dial_fn = |endpoint| Box::pin(dial_udp(endpoint)); - } else { - dial_fn = |endpoint| Box::pin(dial_tcp(endpoint)); - } - tokio::spawn(async move { - match dial_fn(local_endpoint).await { + match dialer.dial().await { Ok((local_r, local_w)) => { local_conn_established_tx.send(()).await.unwrap(); if let Err(err) = @@ -398,10 +378,10 @@ async fn handle_work_traffic( } Err(err) => { error!( - ?local_endpoint, - ?err, - "failed to connect to local endpoint, so let's notify the server to close the user connection", - ); + local_endpoint = ?dialer.addr(), + ?err, + "failed to connect to local endpoint, so let's notify the server to close the user connection", + ); streaming_tx .send(TrafficToServer { diff --git a/src/client/tunnel.rs b/src/client/tunnel.rs index 39350a6..d11c803 100644 --- a/src/client/tunnel.rs +++ b/src/client/tunnel.rs @@ -4,13 +4,16 @@ use std::net::SocketAddr; use bytes::Bytes; -use crate::pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig}; +use crate::{ + pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig}, + socket::{dial_tcp, dial_udp, Dialer}, +}; /// Tunnel configuration for the client. #[derive(Debug)] pub struct Tunnel<'a> { pub(crate) name: &'a str, - pub(crate) local_endpoint: SocketAddr, + pub(crate) dialer: Dialer, pub(crate) config: RemoteConfig<'a>, } @@ -19,7 +22,14 @@ impl<'a> Tunnel<'a> { pub fn new(name: &'a str, local_endpoint: SocketAddr, config: RemoteConfig<'a>) -> Self { Self { name, - local_endpoint, + dialer: Dialer::new( + match config { + RemoteConfig::Tcp(_) => |endpoint| Box::pin(dial_tcp(endpoint)), + RemoteConfig::Udp(_) => |endpoint| Box::pin(dial_udp(endpoint)), + RemoteConfig::Http(_) => |endpoint| Box::pin(dial_tcp(endpoint)), + }, + local_endpoint, + ), config, } } diff --git a/src/socket.rs b/src/socket.rs index fa75113..5825d41 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -82,6 +82,37 @@ impl<'a> AsyncWrite for AsyncUdpSocket<'a> { } } +/// Dialer for connecting to a endpoint to get a async reader and a async writer. +/// +/// # Examples +/// +/// ``` +/// use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +/// +/// let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); +/// let dialer = Dialer::new(|addr| Box::pin(dial_tcp(addr)), socket); +/// let (r, w) = dialer.dial().await.unwrap(); +/// ``` +#[derive(Debug)] +pub(crate) struct Dialer { + dial: DialFn, + addr: SocketAddr, +} + +impl Dialer { + pub(crate) fn new(dial: DialFn, addr: SocketAddr) -> Self { + Self { dial, addr } + } + + pub(crate) fn dial(&self) -> Pin + Send>> { + (self.dial)(self.addr) + } + + pub(crate) fn addr(&self) -> SocketAddr { + self.addr + } +} + pub(crate) type DialResult = Result< ( Box, From 5a50486bfe63232a0adcb1579f1f3f43f9cbfcb0 Mon Sep 17 00:00:00 2001 From: sword-jin Date: Tue, 30 Jul 2024 01:13:41 +0000 Subject: [PATCH 4/5] add some docs --- src/socket.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/socket.rs b/src/socket.rs index 5825d41..9ddf50f 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -1,3 +1,4 @@ +//! Socket utilities for creating listeners, async readers, writers, and dialers. use std::{ net::SocketAddr, pin::Pin, @@ -11,12 +12,14 @@ use tokio::{ use tonic::Status; use tracing::error; +/// create a tcp listener. pub(crate) async fn create_tcp_listener(port: u16) -> Result { TcpListener::bind(("0.0.0.0", port)) .await .map_err(map_bind_error) } +/// create a udp socket. pub(crate) async fn create_udp_socket(port: u16) -> Result { UdpSocket::bind(("0.0.0.0", port)) .await @@ -34,6 +37,7 @@ fn map_bind_error(err: std::io::Error) -> Status { } } +/// Async reader for a udp connection. pub(crate) struct AsyncUdpSocket<'a> { socket: &'a UdpSocket, } @@ -113,6 +117,7 @@ impl Dialer { } } +/// Result of dialing a endpoint. pub(crate) type DialResult = Result< ( Box, @@ -121,15 +126,18 @@ pub(crate) type DialResult = Result< Box, >; +/// Function for dialing a endpoint to get a async reader and a async writer. pub(crate) type DialFn = fn(SocketAddr) -> Pin + Send>>; +/// Dial a tcp endpoint. pub(crate) async fn dial_tcp(local_endpoint: SocketAddr) -> DialResult { let local_conn = TcpStream::connect(local_endpoint).await?; let (r, w) = local_conn.into_split(); Ok((Box::new(r), Box::new(w))) } +/// Dial a udp endpoint. pub(crate) async fn dial_udp(local_endpoint: SocketAddr) -> DialResult { let local_addr: SocketAddr = if local_endpoint.is_ipv4() { "0.0.0.0:0" From 8c289ef3d7b92e65a39c96b235da838cbfa13b51 Mon Sep 17 00:00:00 2001 From: sword-jin Date: Tue, 30 Jul 2024 01:36:42 +0000 Subject: [PATCH 5/5] add test case --- src/helper.rs | 2 +- src/server/control_server.rs | 2 +- src/socket.rs | 54 +++++++++++++++++++++++++++++------- tests/lib.rs | 2 ++ 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/helper.rs b/src/helper.rs index b1314e9..0fcc465 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -2,7 +2,7 @@ use tonic::Status; use crate::pb::RegisterReq; -pub fn validate_register_req(req: &RegisterReq) -> Option { +pub(crate) fn validate_register_req(req: &RegisterReq) -> Option { if req.tunnel.is_none() { return Some(Status::invalid_argument("tunnel is required")); } diff --git a/src/server/control_server.rs b/src/server/control_server.rs index ca3e690..817817a 100644 --- a/src/server/control_server.rs +++ b/src/server/control_server.rs @@ -38,7 +38,7 @@ type DataStream = Pin> + Send> /// /// We treat the control server is grpc server as well, in the concept, /// they are same thing. -/// Although the grpc server provides a [`crate::protocol::pb::tunnel_service_server::TunnelService::data`], +/// Although the grpc server provides a [`crate::pb::tunnel_service_server::TunnelService::data`], /// it's similar to the data server(a little), but in the `data` function body, /// the most of work is to forward the data from client to data server. /// We can understand this is a tunnel between the client and the data server. diff --git a/src/socket.rs b/src/socket.rs index 9ddf50f..7629b5a 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -87,16 +87,6 @@ impl<'a> AsyncWrite for AsyncUdpSocket<'a> { } /// Dialer for connecting to a endpoint to get a async reader and a async writer. -/// -/// # Examples -/// -/// ``` -/// use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -/// -/// let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); -/// let dialer = Dialer::new(|addr| Box::pin(dial_tcp(addr)), socket); -/// let (r, w) = dialer.dial().await.unwrap(); -/// ``` #[derive(Debug)] pub(crate) struct Dialer { dial: DialFn, @@ -104,14 +94,21 @@ pub(crate) struct Dialer { } impl Dialer { + /// Create a new dialer. pub(crate) fn new(dial: DialFn, addr: SocketAddr) -> Self { Self { dial, addr } } + /// Dial the endpoint. + /// + /// # Returns + /// + /// A future that resolves to a async reader and a async writer. pub(crate) fn dial(&self) -> Pin + Send>> { (self.dial)(self.addr) } + /// Expose the address of the dialer. pub(crate) fn addr(&self) -> SocketAddr { self.addr } @@ -153,3 +150,40 @@ pub(crate) async fn dial_udp(local_endpoint: SocketAddr) -> DialResult { Box::new(AsyncUdpSocket::new(socket)), )) } + +#[cfg(test)] +mod test { + use super::*; + use std::net::TcpListener as StdTcpListener; + + #[tokio::test] + async fn test_tcp_listener_and_dialer() { + let port = free_port().unwrap(); + let listener = create_tcp_listener(port).await.unwrap(); + + let dialer = Dialer::new( + |addr| Box::pin(dial_tcp(addr)), + listener.local_addr().unwrap(), + ); + dialer.dial().await.unwrap(); + } + + #[tokio::test] + async fn test_udp_socket_and_dialer() { + let port = free_port().unwrap(); + let socket = create_udp_socket(port).await.unwrap(); + + let dialer = Dialer::new( + |addr| Box::pin(dial_udp(addr)), + socket.local_addr().unwrap(), + ); + dialer.dial().await.unwrap(); + } + + /// free_port returns a free port number for testing. + fn free_port() -> std::io::Result { + let listener = StdTcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + Ok(port) + } +} diff --git a/tests/lib.rs b/tests/lib.rs index d8ff12a..a6e93b7 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,3 +1,5 @@ +#![warn(missing_docs)] + mod common; use crate::common::free_port;