Skip to content

Commit

Permalink
feat: add timeout and optional fallback (#23)
Browse files Browse the repository at this point in the history
* feat: add timeout and optional fallback

add timeout and optional fallback

* refactor: timeout handling in port reuse mode

change the logic in timeout handing in port reuse mode

---------

Co-authored-by: X1r0z <i@exp10it.io>
  • Loading branch information
zwxxb and X1r0z authored Dec 25, 2024
1 parent 44699ad commit 484bc6f
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 14 deletions.
9 changes: 7 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ pub enum Commands {

/// Fallback IP address, format: IP:PORT
#[arg(short, long)]
fallback: String,
fallback: Option<String>,

/// External IP address, format: IP
#[arg(short, long)]
external: String,

/// Timeout to stop port reuse
#[arg(short, long)]
timeout: Option<u64>,
},
}

Expand Down Expand Up @@ -151,10 +155,11 @@ pub async fn run(cli: Cli) -> Result<()> {
remote,
fallback,
external,
timeout,
} => {
info!("Starting reuse mode");

let reuse = Reuse::new(local, remote, fallback, external);
let reuse = Reuse::new(local, remote, fallback, external, timeout);
reuse.start().await?;
}
}
Expand Down
82 changes: 70 additions & 12 deletions src/reuse.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
use std::{io::Result, net::SocketAddr};

use socket2::{Domain, Protocol, Socket, Type};
use tokio::net::{TcpListener, TcpStream};
use tokio::{
net::{TcpListener, TcpStream},
sync::mpsc,
time,
};
use tracing::{error, info, warn};

use crate::tcp;

pub struct Reuse {
local_addr: String,
remote_addr: String,
fallback_addr: String,
fallback_addr: Option<String>,
external_ip: String,
timeout: Option<u64>,
}

impl Reuse {
pub fn new(
local_addr: String,
remote_addr: String,
fallback_addr: String,
fallback_addr: Option<String>,
external_ip: String,
timeout: Option<u64>,
) -> Self {
Self {
local_addr,
remote_addr,
fallback_addr,
external_ip,
timeout,
}
}

Expand All @@ -45,25 +52,61 @@ impl Reuse {
socket.bind(&local_addr.into())?;
socket.listen(128)?;

let listener = TcpListener::from_std(socket.into())?;
info!("Bind to {} success", self.local_addr);
let (tx, mut rx) = mpsc::channel(1);

loop {
let (client_stream, client_addr) = listener.accept().await?;
info!("Accepted connection from: {}", client_addr);
let reuse_task = async move {
let listener = TcpListener::from_std(socket.into()).expect("Failed to listen");
info!("Bind to {} success", local_addr);

loop {
let (client_stream, client_addr) = listener
.accept()
.await
.expect("Failed to accept connection");

info!("Accepted connection from: {}", client_addr);
tx.send((client_stream, client_addr)).await.unwrap();
}
};

match self.timeout {
Some(timeout) => {
tokio::spawn(time::timeout(
time::Duration::from_secs(timeout),
reuse_task,
));
}
None => {
tokio::spawn(reuse_task);
}
}

let mut alive_tasks = Vec::new();

while let Some((client_stream, client_addr)) = rx.recv().await {
let server_addr = if client_addr.ip().to_string() == self.external_ip {
info!("Redirecting connection to {}", &self.remote_addr);
&self.remote_addr
} else {
warn!("Invalid external IP, fallback to {}", &self.fallback_addr);
&self.fallback_addr
match &self.fallback_addr {
Some(fallback_addr) => {
warn!("Invalid external IP, fallback to {}", fallback_addr);
fallback_addr
}
None => {
warn!("Invalid external IP, abort the connection");
continue;
}
}
};

let server_stream = TcpStream::connect(&server_addr).await?;
let server_stream = TcpStream::connect(&server_addr)
.await
.expect(&format!("Failed to connect to {}", server_addr));

info!("Connect to {} success", server_addr);

tokio::spawn(async move {
let task = tokio::spawn(async move {
let client_stream = tcp::NetStream::Tcp(client_stream);
let remote_stream = tcp::NetStream::Tcp(server_stream);

Expand All @@ -73,6 +116,21 @@ impl Reuse {
}
info!("Close pipe: {} <=> {}", client_addr, local_addr);
});

alive_tasks.push(task);
}

if let Some(timeout) = self.timeout {
info!(
"Stop accepting new connections after {} elapsed, wait for alive tasks",
timeout
)
};

for task in alive_tasks {
task.await.expect("Failed to join task");
}

Ok(())
}
}

0 comments on commit 484bc6f

Please sign in to comment.