diff --git a/src/forward.rs b/src/forward.rs index 1783b63..fabed09 100644 --- a/src/forward.rs +++ b/src/forward.rs @@ -1,6 +1,10 @@ use std::{io::Result, sync::Arc}; -use tokio::net::{TcpListener, TcpStream, UdpSocket, UnixStream}; +use tokio::{ + join, + net::{TcpListener, TcpStream, UdpSocket, UnixStream}, + sync, +}; use tracing::{error, info}; use crate::{crypto, tcp, udp}; @@ -91,8 +95,10 @@ impl Forward { }); loop { - let (stream1, addr1) = listener1.accept().await?; - let (stream2, addr2) = listener2.accept().await?; + let (r1, r2) = join!(listener1.accept(), listener2.accept()); + + let (stream1, addr1) = r1?; + let (stream2, addr2) = r2?; info!("Accept connection from {}", addr1); info!("Accept connection from {}", addr2); @@ -128,26 +134,26 @@ impl Forward { }); loop { - let (stream, addr) = listener.accept().await?; + let (client, client_addr) = listener.accept().await?; let remote = TcpStream::connect(&self.remote_addrs[0]).await?; - let peer_addr = remote.peer_addr()?; + let remote_addr = remote.peer_addr()?; + + info!("Accept connection from {}", client_addr); + info!("Connect to {} success", remote_addr); let acceptor = acceptor.clone(); let connector = connector.clone(); - info!("Accept connection from {}", addr); - info!("Connect to {} success", peer_addr); - tokio::spawn(async move { - let stream = tcp::NetStream::from_acceptor(stream, acceptor).await; + let client = tcp::NetStream::from_acceptor(client, acceptor).await; let remote = tcp::NetStream::from_connector(remote, connector).await; - info!("Open pipe: {} <=> {}", addr, peer_addr); - if let Err(e) = tcp::handle_forward(stream, remote).await { + info!("Open pipe: {} <=> {}", client_addr, remote_addr); + if let Err(e) = tcp::handle_forward(client, remote).await { error!("failed to forward: {}", e) } - info!("Close pipe: {} <=> {}", addr, peer_addr); + info!("Close pipe: {} <=> {}", client_addr, remote_addr); }); } } @@ -163,25 +169,41 @@ impl Forward { false => None, }); + // limit the number of concurrent connections + let semaphore = Arc::new(sync::Semaphore::new(32)); + loop { - let stream1 = TcpStream::connect(&self.remote_addrs[0]).await?; - let stream2 = TcpStream::connect(&self.remote_addrs[1]).await?; + let permit = semaphore.clone().acquire_owned().await.unwrap(); - let peer_addr_1 = stream1.peer_addr()?; - let peer_addr_2 = stream2.peer_addr()?; + let (r1, r2) = join!( + TcpStream::connect(&self.remote_addrs[0]), + TcpStream::connect(&self.remote_addrs[1]) + ); - info!("Connect to {} success", peer_addr_1); - info!("Connect to {} success", peer_addr_2); + let (stream1, stream2) = (r1?, r2?); + + let addr1 = stream1.peer_addr()?; + let addr2 = stream2.peer_addr()?; + + info!("Connect to {} success", addr1); + info!("Connect to {} success", addr2); let connector1 = connector1.clone(); let connector2 = connector2.clone(); - let stream1 = tcp::NetStream::from_connector(stream1, connector1).await; - let stream2 = tcp::NetStream::from_connector(stream2, connector2).await; + tokio::spawn(async move { + let stream1 = tcp::NetStream::from_connector(stream1, connector1).await; + let stream2 = tcp::NetStream::from_connector(stream2, connector2).await; - info!("Open pipe: {} <=> {}", peer_addr_1, peer_addr_2); - tcp::handle_forward(stream1, stream2).await?; - info!("Close pipe: {} <=> {}", peer_addr_1, peer_addr_2) + info!("Open pipe: {} <=> {}", addr1, addr2); + if let Err(e) = tcp::handle_forward(stream1, stream2).await { + error!("Failed to forward: {}", e) + } + info!("Close pipe: {} <=> {}", addr1, addr2); + + // drop the permit to release the semaphore + drop(permit); + }); } } @@ -195,19 +217,13 @@ impl Forward { }); loop { - let (local_stream, addr) = local_listener.accept().await?; - let unix_stream = UnixStream::connect(self.socket.as_ref().unwrap()).await?; + let unix_addr = self.socket.clone().unwrap(); - let unix_addr = unix_stream.peer_addr()?; - let socket_path = unix_addr - .as_pathname() - .unwrap() - .to_str() - .unwrap() - .to_string(); + let (local_stream, client_addr) = local_listener.accept().await?; + let unix_stream = UnixStream::connect(&unix_addr).await?; - info!("Accept connection from {}", addr); - info!("Connect to {} success", socket_path); + info!("Accept connection from {}", client_addr); + info!("Connect to {} success", unix_addr); let acceptor = acceptor.clone(); @@ -215,11 +231,11 @@ impl Forward { let local_stream = tcp::NetStream::from_acceptor(local_stream, acceptor).await; let unix_stream = tcp::NetStream::Unix(unix_stream); - info!("Open pipe: {} <=> {}", socket_path, addr); + info!("Open pipe: {} <=> {}", unix_addr, client_addr); if let Err(e) = tcp::handle_forward(unix_stream, local_stream).await { error!("Failed to forward: {}", e) } - info!("Close pipe: {} <=> {}", socket_path, addr); + info!("Close pipe: {} <=> {}", unix_addr, client_addr); }); } } @@ -231,30 +247,23 @@ impl Forward { }); loop { - let unix_stream = UnixStream::connect(self.socket.as_ref().unwrap()).await?; - let remote_stream = TcpStream::connect(&self.remote_addrs[0]).await?; - - let peer_addr = remote_stream.peer_addr()?; - let unix_addr = unix_stream.peer_addr()?; + let unix_addr = self.socket.clone().unwrap(); + let remote_addr = self.remote_addrs[0].clone(); - let socket_path = unix_addr - .as_pathname() - .unwrap() - .to_str() - .unwrap() - .to_string(); + let unix_stream = UnixStream::connect(&unix_addr).await?; + let remote_stream = TcpStream::connect(&remote_addr).await?; - info!("Connect to {} success", socket_path); - info!("Connect to {} success", peer_addr); + info!("Connect to {} success", unix_addr); + info!("Connect to {} success", remote_addr); let connector = connector.clone(); let unix_stream = tcp::NetStream::Unix(unix_stream); let remote_stream = tcp::NetStream::from_connector(remote_stream, connector).await; - info!("Open pipe: {} <=> {}", socket_path, peer_addr); + info!("Open pipe: {} <=> {}", unix_addr, remote_addr); tcp::handle_forward(unix_stream, remote_stream).await?; - info!("Close pipe: {} <=> {}", socket_path, peer_addr); + info!("Close pipe: {} <=> {}", unix_addr, remote_addr); } }