Skip to content

Commit

Permalink
Merge pull request #13 from X1r0z/bugfix/blocked-connections
Browse files Browse the repository at this point in the history
fix: use semaphore to resolve blocked connections
  • Loading branch information
X1r0z authored Dec 22, 2024
2 parents 64ef11b + 106cd82 commit fda96b1
Showing 1 changed file with 84 additions and 59 deletions.
143 changes: 84 additions & 59 deletions src/forward.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
use std::{io::Result, sync::Arc};

use tokio::net::{TcpListener, TcpStream, UdpSocket, UnixStream};
use tokio::{
join,
net::{TcpListener, TcpStream, UdpSocket, UnixStream},
sync,
};
use tracing::{error, info};

use crate::{crypto, tcp, udp};

pub struct Forward {
local_addrs: Vec<String>,
remote_addrs: Vec<String>,
socket: Option<String>,
udp: bool,
local_opts: Vec<bool>,
remote_opts: Vec<bool>,
socket: Option<String>,
udp: bool,
}

impl Forward {
Expand Down Expand Up @@ -91,8 +95,10 @@ impl Forward {
});

loop {
let (stream1, addr1) = listener1.accept().await?;
let (stream2, addr2) = listener2.accept().await?;
let (r1, r2) = join!(listener1.accept(), listener2.accept());

let (stream1, addr1) = r1?;
let (stream2, addr2) = r2?;

info!("Accept connection from {}", addr1);
info!("Accept connection from {}", addr2);
Expand Down Expand Up @@ -128,26 +134,26 @@ impl Forward {
});

loop {
let (stream, addr) = listener.accept().await?;
let remote = TcpStream::connect(&self.remote_addrs[0]).await?;
let (client_stream, client_addr) = listener.accept().await?;
let remote_stream = TcpStream::connect(&self.remote_addrs[0]).await?;

let remote_addr = remote_stream.peer_addr()?;

let peer_addr = remote.peer_addr()?;
info!("Accept connection from {}", client_addr);
info!("Connect to {} success", remote_addr);

let acceptor = acceptor.clone();
let connector = connector.clone();

info!("Accept connection from {}", addr);
info!("Connect to {} success", peer_addr);

tokio::spawn(async move {
let stream = tcp::NetStream::from_acceptor(stream, acceptor).await;
let remote = tcp::NetStream::from_connector(remote, connector).await;
let client_stream = tcp::NetStream::from_acceptor(client_stream, acceptor).await;
let remote_stream = tcp::NetStream::from_connector(remote_stream, connector).await;

info!("Open pipe: {} <=> {}", addr, peer_addr);
if let Err(e) = tcp::handle_forward(stream, remote).await {
info!("Open pipe: {} <=> {}", client_addr, remote_addr);
if let Err(e) = tcp::handle_forward(client_stream, remote_stream).await {
error!("failed to forward: {}", e)
}
info!("Close pipe: {} <=> {}", addr, peer_addr);
info!("Close pipe: {} <=> {}", client_addr, remote_addr);
});
}
}
Expand All @@ -163,25 +169,41 @@ impl Forward {
false => None,
});

// limit the number of concurrent connections
let semaphore = Arc::new(sync::Semaphore::new(32));

loop {
let stream1 = TcpStream::connect(&self.remote_addrs[0]).await?;
let stream2 = TcpStream::connect(&self.remote_addrs[1]).await?;
let permit = semaphore.clone().acquire_owned().await.unwrap();

let (r1, r2) = join!(
TcpStream::connect(&self.remote_addrs[0]),
TcpStream::connect(&self.remote_addrs[1])
);

let (stream1, stream2) = (r1?, r2?);

let peer_addr_1 = stream1.peer_addr()?;
let peer_addr_2 = stream2.peer_addr()?;
let addr1 = stream1.peer_addr()?;
let addr2 = stream2.peer_addr()?;

info!("Connect to {} success", peer_addr_1);
info!("Connect to {} success", peer_addr_2);
info!("Connect to {} success", addr1);
info!("Connect to {} success", addr2);

let connector1 = connector1.clone();
let connector2 = connector2.clone();

let stream1 = tcp::NetStream::from_connector(stream1, connector1).await;
let stream2 = tcp::NetStream::from_connector(stream2, connector2).await;
tokio::spawn(async move {
let stream1 = tcp::NetStream::from_connector(stream1, connector1).await;
let stream2 = tcp::NetStream::from_connector(stream2, connector2).await;

info!("Open pipe: {} <=> {}", addr1, addr2);
if let Err(e) = tcp::handle_forward(stream1, stream2).await {
error!("Failed to forward: {}", e)
}
info!("Close pipe: {} <=> {}", addr1, addr2);

info!("Open pipe: {} <=> {}", peer_addr_1, peer_addr_2);
tcp::handle_forward(stream1, stream2).await?;
info!("Close pipe: {} <=> {}", peer_addr_1, peer_addr_2)
// drop the permit to release the semaphore
drop(permit);
});
}
}

Expand All @@ -195,31 +217,25 @@ impl Forward {
});

loop {
let (local_stream, addr) = local_listener.accept().await?;
let unix_stream = UnixStream::connect(self.socket.as_ref().unwrap()).await?;
let unix_addr = self.socket.clone().unwrap();

let unix_addr = unix_stream.peer_addr()?;
let socket_path = unix_addr
.as_pathname()
.unwrap()
.to_str()
.unwrap()
.to_string();
let (client_stream, client_addr) = local_listener.accept().await?;
let unix_stream = UnixStream::connect(&unix_addr).await?;

info!("Accept connection from {}", addr);
info!("Connect to {} success", socket_path);
info!("Accept connection from {}", client_addr);
info!("Connect to {} success", unix_addr);

let acceptor = acceptor.clone();

tokio::spawn(async move {
let local_stream = tcp::NetStream::from_acceptor(local_stream, acceptor).await;
let client_stream = tcp::NetStream::from_acceptor(client_stream, acceptor).await;
let unix_stream = tcp::NetStream::Unix(unix_stream);

info!("Open pipe: {} <=> {}", socket_path, addr);
if let Err(e) = tcp::handle_forward(unix_stream, local_stream).await {
info!("Open pipe: {} <=> {}", unix_addr, client_addr);
if let Err(e) = tcp::handle_forward(client_stream, unix_stream).await {
error!("Failed to forward: {}", e)
}
info!("Close pipe: {} <=> {}", socket_path, addr);
info!("Close pipe: {} <=> {}", unix_addr, client_addr);
});
}
}
Expand All @@ -230,31 +246,40 @@ impl Forward {
false => None,
});

// limit the number of concurrent connections
let semaphore = Arc::new(sync::Semaphore::new(32));

loop {
let unix_stream = UnixStream::connect(self.socket.as_ref().unwrap()).await?;
let remote_stream = TcpStream::connect(&self.remote_addrs[0]).await?;
let permit = semaphore.clone().acquire_owned().await.unwrap();

let peer_addr = remote_stream.peer_addr()?;
let unix_addr = unix_stream.peer_addr()?;
let unix_addr = self.socket.clone().unwrap();
let remote_addr = self.remote_addrs[0].clone();

let socket_path = unix_addr
.as_pathname()
.unwrap()
.to_str()
.unwrap()
.to_string();
let (r1, r2) = join!(
UnixStream::connect(&unix_addr),
TcpStream::connect(&remote_addr)
);

info!("Connect to {} success", socket_path);
info!("Connect to {} success", peer_addr);
let (unix_stream, remote_stream) = (r1?, r2?);

info!("Connect to {} success", unix_addr);
info!("Connect to {} success", remote_addr);

let connector = connector.clone();

let unix_stream = tcp::NetStream::Unix(unix_stream);
let remote_stream = tcp::NetStream::from_connector(remote_stream, connector).await;
tokio::spawn(async move {
let unix_stream = tcp::NetStream::Unix(unix_stream);
let remote_stream = tcp::NetStream::from_connector(remote_stream, connector).await;

info!("Open pipe: {} <=> {}", socket_path, peer_addr);
tcp::handle_forward(unix_stream, remote_stream).await?;
info!("Close pipe: {} <=> {}", socket_path, peer_addr);
info!("Open pipe: {} <=> {}", unix_addr, remote_addr);
if let Err(e) = tcp::handle_forward(unix_stream, remote_stream).await {
error!("Failed to forward: {}", e)
}
info!("Close pipe: {} <=> {}", unix_addr, remote_addr);

// drop the permit to release the semaphore
drop(permit);
});
}
}

Expand Down

0 comments on commit fda96b1

Please sign in to comment.