Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor rust udp #101

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 38 additions & 113 deletions src/client/client.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
use anyhow::{Context, Result};
use anyhow::{Context as _, Result};
use async_shutdown::{ShutdownManager, ShutdownSignal};
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use tonic::{transport::Channel, Response, Status, Streaming};
use tracing::{debug, error, info, instrument, span};

use tokio::{
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
net::{TcpStream, UdpSocket},
select,
sync::{mpsc, oneshot},
};

use crate::socket::Dialer;
use crate::{
constant,
io::{StreamingReader, StreamingWriter, TrafficToServerWrapper},
pb::{
self, control_command::Payload, traffic_to_server, tunnel::Config,
self, control_command::Payload, traffic_to_server,
tunnel_service_client::TunnelServiceClient, ControlCommand, RegisterReq, TrafficToClient,
TrafficToServer,
},
Expand Down Expand Up @@ -83,13 +83,13 @@ impl Client {
) -> Result<Vec<String>> {
let (entrypoint_tx, entrypoint_rx) = oneshot::channel();
let pb_tunnel = tunnel.config.to_pb_tunnel(tunnel.name);
let local_endpoint = tunnel.local_endpoint;
let dialer = tunnel.dialer;

tokio::spawn(async move {
let run_tunnel = self.handle_tunnel(
shutdown.wait_shutdown_triggered(),
pb_tunnel,
local_endpoint,
dialer,
Some(move |entrypoint| {
let _ = entrypoint_tx.send(entrypoint);
}),
Expand Down Expand Up @@ -184,11 +184,10 @@ impl Client {
&self,
shutdown: ShutdownSignal<i8>,
tunnel: pb::Tunnel,
local_endpoint: SocketAddr,
dial: Dialer,
hook: Option<impl FnOnce(Vec<String>) + Send + 'static>,
) -> Result<()> {
let mut rpc_client = self.grpc_client.clone();
let is_udp = matches!(tunnel.config, Some(Config::Udp(_)));
let register = self.register_tunnel(&mut rpc_client, tunnel);

tokio::select! {
Expand All @@ -201,8 +200,7 @@ impl Client {
shutdown.clone(),
rpc_client,
register_resp,
local_endpoint,
is_udp,
dial,
hook,
).await
}
Expand All @@ -216,8 +214,7 @@ impl Client {
shutdown: ShutdownSignal<i8>,
rpc_client: TunnelServiceClient<Channel>,
register_resp: tonic::Response<Streaming<ControlCommand>>,
local_endpoint: SocketAddr,
is_udp: bool,
dialer: Dialer,
mut hook: Option<impl FnOnce(Vec<String>) + Send + 'static>,
) -> Result<()> {
let mut control_stream = register_resp.into_inner();
Expand All @@ -229,24 +226,18 @@ impl Client {
hook(entrypoint);
}

self.start_streaming(
shutdown,
&mut control_stream,
rpc_client,
local_endpoint,
is_udp,
)
.await
self.start_streaming(shutdown, &mut control_stream, rpc_client, dialer)
.await
}

async fn start_streaming(
&self,
shutdown: ShutdownSignal<i8>,
control_stream: &mut Streaming<ControlCommand>,
rpc_client: TunnelServiceClient<Channel>,
local_endpoint: SocketAddr,
is_udp: bool,
dialer: Dialer,
) -> Result<()> {
let dialer = Arc::new(dialer);
loop {
tokio::select! {
result = control_stream.next() => {
Expand All @@ -268,8 +259,7 @@ impl Client {
if let Err(err) = handle_work_traffic(
rpc_client.clone() /* cheap clone operation */,
&work.connection_id,
local_endpoint,
is_udp,
dialer.clone(),
).await {
error!(err = ?err, "failed to handle work traffic");
} else {
Expand Down Expand Up @@ -307,8 +297,7 @@ async fn new_rpc_client(control_addr: SocketAddr) -> Result<TunnelServiceClient<
async fn handle_work_traffic(
mut rpc_client: TunnelServiceClient<Channel>,
connection_id: &str,
local_endpoint: SocketAddr,
is_udp: bool,
dialer: Arc<Dialer>,
) -> Result<()> {
// write response to the streaming_tx
// rpc_client sends the data from reading the streaming_rx
Expand All @@ -331,7 +320,7 @@ async fn handle_work_traffic(

// write the data streaming response to transfer_tx,
// then forward_traffic_to_local can read the data from transfer_rx
let (transfer_tx, mut transfer_rx) = mpsc::channel::<TrafficToClient>(64);
let (transfer_tx, transfer_rx) = mpsc::channel::<TrafficToClient>(64);

let (local_conn_established_tx, local_conn_established_rx) = mpsc::channel::<()>(1);
let mut local_conn_established_rx = Some(local_conn_established_rx);
Expand Down Expand Up @@ -375,72 +364,24 @@ async fn handle_work_traffic(
});

let wrapper = TrafficToServerWrapper::new(connection_id.clone());
let mut writer = StreamingWriter::new(streaming_tx.clone(), wrapper);

if is_udp {
tokio::spawn(async move {
let local_addr: SocketAddr = if local_endpoint.is_ipv4() {
"0.0.0.0:0"
} else {
"[::]:0"
}
.parse()
.unwrap();
let socket = UdpSocket::bind(local_addr).await;
if socket.is_err() {
error!(err = ?socket.err(), "failed to init udp socket, so let's notify the server to close the user connection");

streaming_tx
.send(TrafficToServer {
connection_id: connection_id.to_string(),
action: traffic_to_server::Action::Close as i32,
..Default::default()
})
.await
.context("terrible, the server may be crashed")
.unwrap();
return;
}

let socket = socket.unwrap();
let result = socket.connect(local_endpoint).await;
if let Err(err) = result {
error!(err = ?err, "failed to connect to local endpoint, so let's notify the server to close the user connection");
}
let writer = StreamingWriter::new(streaming_tx.clone(), wrapper);

local_conn_established_tx.send(()).await.unwrap();

let read_transfer_send_to_local = async {
while let Some(buf) = transfer_rx.recv().await {
socket.send(&buf.data).await.unwrap();
}
};

let read_local_send_to_server = async {
loop {
let mut buf = vec![0u8; 65507];
let result = socket.recv(&mut buf).await;
match result {
Ok(n) => {
writer.write_all(&buf[..n]).await.unwrap();
}
Err(err) => {
error!(err = ?err, "failed to read from local endpoint");
break;
}
}
tokio::spawn(async move {
match dialer.dial().await {
Ok((local_r, local_w)) => {
local_conn_established_tx.send(()).await.unwrap();
if let Err(err) =
transfer(local_r, local_w, StreamingReader::new(transfer_rx), writer).await
{
debug!("failed to forward traffic to local: {:?}", err);
}
writer.shutdown().await.unwrap();
};

tokio::join!(read_transfer_send_to_local, read_local_send_to_server);
});
} else {
tokio::spawn(async move {
// TODO(sword): use a connection pool to reuse the tcp connection
let local_conn = TcpStream::connect(local_endpoint).await;
if local_conn.is_err() {
error!("failed to connect to local endpoint {}, so let's notify the server to close the user connection", local_endpoint);
}
Err(err) => {
error!(
local_endpoint = ?dialer.addr(),
?err,
"failed to connect to local endpoint, so let's notify the server to close the user connection",
);

streaming_tx
.send(TrafficToServer {
Expand All @@ -451,42 +392,26 @@ async fn handle_work_traffic(
.await
.context("terrible, the server may be crashed")
.unwrap();
return;
}

let mut local_conn = local_conn.unwrap();
let (local_r, local_w) = local_conn.split();
local_conn_established_tx.send(()).await.unwrap();

if let Err(err) = forward_traffic_to_local(
local_r,
local_w,
StreamingReader::new(transfer_rx),
writer,
)
.await
{
debug!("failed to forward traffic to local: {:?}", err);
}
});
}
}
});

Ok(())
}

/// Forwards the traffic from the server to the local endpoint.
/// transfer the traffic from the server to the local endpoint.
///
/// Try to imagine the current client is yourself,
/// your mission is to forward the traffic from the server to the local,
/// then write the original response back to the server.
/// in this process, there are two underlying connections:
/// 1. remote <=> me
/// 2. me <=> local
async fn forward_traffic_to_local(
async fn transfer(
local_r: impl AsyncRead + Unpin,
mut local_w: impl AsyncWrite + Unpin,
remote_r: StreamingReader<TrafficToClient>,
mut remote_w: StreamingWriter<TrafficToServer>,
remote_r: impl AsyncRead + Unpin,
mut remote_w: impl AsyncWrite + Unpin,
) -> Result<()> {
let remote_to_me_to_local = async {
// read from remote, write to local
Expand Down
16 changes: 13 additions & 3 deletions src/client/tunnel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ use std::net::SocketAddr;

use bytes::Bytes;

use crate::pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig};
use crate::{
pb::{self, tunnel, HttpConfig, TcpConfig, UdpConfig},
socket::{dial_tcp, dial_udp, Dialer},
};

/// Tunnel configuration for the client.
#[derive(Debug)]
pub struct Tunnel<'a> {
pub(crate) name: &'a str,
pub(crate) local_endpoint: SocketAddr,
pub(crate) dialer: Dialer,
pub(crate) config: RemoteConfig<'a>,
}

Expand All @@ -19,7 +22,14 @@ impl<'a> Tunnel<'a> {
pub fn new(name: &'a str, local_endpoint: SocketAddr, config: RemoteConfig<'a>) -> Self {
Self {
name,
local_endpoint,
dialer: Dialer::new(
match config {
RemoteConfig::Tcp(_) => |endpoint| Box::pin(dial_tcp(endpoint)),
RemoteConfig::Udp(_) => |endpoint| Box::pin(dial_udp(endpoint)),
RemoteConfig::Http(_) => |endpoint| Box::pin(dial_tcp(endpoint)),
},
local_endpoint,
),
config,
}
}
Expand Down
29 changes: 2 additions & 27 deletions src/helper.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,8 @@
use crate::pb::RegisterReq;

use tokio::net::{TcpListener, UdpSocket};
use tonic::Status;
use tracing::error;

pub(crate) async fn create_tcp_listener(port: u16) -> Result<TcpListener, Status> {
TcpListener::bind(("0.0.0.0", port))
.await
.map_err(map_bind_error)
}

pub(crate) async fn create_udp_socket(port: u16) -> Result<UdpSocket, Status> {
UdpSocket::bind(("0.0.0.0", port))
.await
.map_err(map_bind_error)
}

fn map_bind_error(err: std::io::Error) -> Status {
match err.kind() {
std::io::ErrorKind::AddrInUse => Status::already_exists("port already in use"),
std::io::ErrorKind::PermissionDenied => Status::permission_denied("permission denied"),
_ => {
error!("failed to bind port: {}", err);
Status::internal("failed to bind port")
}
}
}
use crate::pb::RegisterReq;

pub fn validate_register_req(req: &RegisterReq) -> Option<Status> {
pub(crate) fn validate_register_req(req: &RegisterReq) -> Option<Status> {
if req.tunnel.is_none() {
return Some(Status::invalid_argument("tunnel is required"));
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub(crate) mod constant;
pub(crate) mod event;
pub(crate) mod helper;
pub(crate) mod io;
pub(crate) mod socket;

pub mod pb {
include!("gen/message.rs");
Expand Down
2 changes: 1 addition & 1 deletion src/server/control_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type DataStream = Pin<Box<dyn Stream<Item = GrpcResult<TrafficToClient>> + Send>
///
/// We treat the control server is grpc server as well, in the concept,
/// they are same thing.
/// Although the grpc server provides a [`crate::protocol::pb::tunnel_service_server::TunnelService::data`],
/// Although the grpc server provides a [`crate::pb::tunnel_service_server::TunnelService::data`],
/// it's similar to the data server(a little), but in the `data` function body,
/// the most of work is to forward the data from client to data server.
/// We can understand this is a tunnel between the client and the data server.
Expand Down
2 changes: 1 addition & 1 deletion src/server/tunnel/http.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::bridge::BridgeData;
use crate::event::{self, IncomingEventSender};
use crate::helper::create_tcp_listener;
use crate::socket::create_tcp_listener;

use super::{init_data_sender_bridge, BridgeResult};
use anyhow::{Context as _, Result};
Expand Down
2 changes: 1 addition & 1 deletion src/server/tunnel/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::{
event,
helper::create_tcp_listener,
io::{StreamingReader, StreamingWriter, VecWrapper},
server::tunnel::BridgeResult,
socket::create_tcp_listener,
};
use anyhow::Context as _;
use tokio::{
Expand Down
2 changes: 1 addition & 1 deletion src/server/tunnel/udp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{net::SocketAddr, sync::Arc};

use crate::{bridge::BridgeData, event, helper::create_udp_socket, server::tunnel::BridgeResult};
use crate::{bridge::BridgeData, event, server::tunnel::BridgeResult, socket::create_udp_socket};
use dashmap::DashMap;
use tokio::{net::UdpSocket, select, sync::mpsc};
use tokio_util::sync::CancellationToken;
Expand Down
Loading
Loading