diff --git a/src/lib.rs b/src/lib.rs index 56fdd82..4bb7f2f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,11 +71,15 @@ pub enum Commands { /// Fallback IP address, format: IP:PORT #[arg(short, long)] - fallback: String, + fallback: Option, /// External IP address, format: IP #[arg(short, long)] external: String, + + /// Timeout to stop port reuse + #[arg(short, long)] + timeout: Option, }, } @@ -151,10 +155,11 @@ pub async fn run(cli: Cli) -> Result<()> { remote, fallback, external, + timeout, } => { info!("Starting reuse mode"); - let reuse = Reuse::new(local, remote, fallback, external); + let reuse = Reuse::new(local, remote, fallback, external, timeout); reuse.start().await?; } } diff --git a/src/reuse.rs b/src/reuse.rs index 698830e..ceb3935 100644 --- a/src/reuse.rs +++ b/src/reuse.rs @@ -1,7 +1,11 @@ use std::{io::Result, net::SocketAddr}; use socket2::{Domain, Protocol, Socket, Type}; -use tokio::net::{TcpListener, TcpStream}; +use tokio::{ + net::{TcpListener, TcpStream}, + sync::mpsc, + time, +}; use tracing::{error, info, warn}; use crate::tcp; @@ -9,22 +13,25 @@ use crate::tcp; pub struct Reuse { local_addr: String, remote_addr: String, - fallback_addr: String, + fallback_addr: Option, external_ip: String, + timeout: Option, } impl Reuse { pub fn new( local_addr: String, remote_addr: String, - fallback_addr: String, + fallback_addr: Option, external_ip: String, + timeout: Option, ) -> Self { Self { local_addr, remote_addr, fallback_addr, external_ip, + timeout, } } @@ -45,25 +52,61 @@ impl Reuse { socket.bind(&local_addr.into())?; socket.listen(128)?; - let listener = TcpListener::from_std(socket.into())?; - info!("Bind to {} success", self.local_addr); + let (tx, mut rx) = mpsc::channel(1); - loop { - let (client_stream, client_addr) = listener.accept().await?; - info!("Accepted connection from: {}", client_addr); + let reuse_task = async move { + let listener = TcpListener::from_std(socket.into()).expect("Failed to listen"); + info!("Bind to {} success", local_addr); + loop { + let (client_stream, client_addr) = listener + .accept() + .await + .expect("Failed to accept connection"); + + info!("Accepted connection from: {}", client_addr); + tx.send((client_stream, client_addr)).await.unwrap(); + } + }; + + match self.timeout { + Some(timeout) => { + tokio::spawn(time::timeout( + time::Duration::from_secs(timeout), + reuse_task, + )); + } + None => { + tokio::spawn(reuse_task); + } + } + + let mut alive_tasks = Vec::new(); + + while let Some((client_stream, client_addr)) = rx.recv().await { let server_addr = if client_addr.ip().to_string() == self.external_ip { info!("Redirecting connection to {}", &self.remote_addr); &self.remote_addr } else { - warn!("Invalid external IP, fallback to {}", &self.fallback_addr); - &self.fallback_addr + match &self.fallback_addr { + Some(fallback_addr) => { + warn!("Invalid external IP, fallback to {}", fallback_addr); + fallback_addr + } + None => { + warn!("Invalid external IP, abort the connection"); + continue; + } + } }; - let server_stream = TcpStream::connect(&server_addr).await?; + let server_stream = TcpStream::connect(&server_addr) + .await + .expect(&format!("Failed to connect to {}", server_addr)); + info!("Connect to {} success", server_addr); - tokio::spawn(async move { + let task = tokio::spawn(async move { let client_stream = tcp::NetStream::Tcp(client_stream); let remote_stream = tcp::NetStream::Tcp(server_stream); @@ -73,6 +116,21 @@ impl Reuse { } info!("Close pipe: {} <=> {}", client_addr, local_addr); }); + + alive_tasks.push(task); + } + + if let Some(timeout) = self.timeout { + info!( + "Stop accepting new connections after {} elapsed, wait for alive tasks", + timeout + ) + }; + + for task in alive_tasks { + task.await.expect("Failed to join task"); } + + Ok(()) } }