Skip to content

Commit

Permalink
udp会话添加超时机制
Browse files Browse the repository at this point in the history
  • Loading branch information
tkzcfc committed Jul 23, 2024
1 parent a29e96d commit fcea957
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 78 deletions.
7 changes: 2 additions & 5 deletions np_base/src/net/udp_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,8 @@ pub async fn run_server(
};

select! {
_= recv_task => {
},
_ = shutdown => {
info!("UDP Server shutting down");
}
_= recv_task => {},
_ = shutdown => { info!("UDP Server shutting down"); }
};

// 销毁notify_shutdown 是为了触发 udp_session即将停止服务,立即停止其他操作
Expand Down
33 changes: 30 additions & 3 deletions np_base/src/net/udp_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@ use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::broadcast;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
use tokio::sync::{broadcast, RwLock};
use tokio::task::yield_now;
use tokio::time::sleep;
use tokio::time::{Duration, Instant};

async fn poll_read(
addr: SocketAddr,
delegate: &mut Box<dyn SessionDelegate>,
mut udp_recv_receiver: UnboundedReceiver<Vec<u8>>,
last_active_time: Arc<RwLock<Instant>>,
) {
let write_timeout = Duration::from_secs(1);
while let Some(data) = udp_recv_receiver.recv().await {
if last_active_time.read().await.elapsed() >= write_timeout {
let mut instant_write = last_active_time.write().await;
*instant_write = Instant::now();
}
if let Err(err) = delegate.on_recv_frame(data).await {
error!("[{addr}] on_recv_frame error: {err}");
break;
Expand All @@ -28,8 +35,15 @@ async fn poll_write(
addr: SocketAddr,
mut delegate_receiver: UnboundedReceiver<WriterMessage>,
socket: Arc<UdpSocket>,
last_active_time: Arc<RwLock<Instant>>,
) {
let write_timeout = Duration::from_secs(1);
while let Some(message) = delegate_receiver.recv().await {
if last_active_time.read().await.elapsed() >= write_timeout {
let mut instant_write = last_active_time.write().await;
*instant_write = Instant::now();
}

match message {
WriterMessage::Close => break,
WriterMessage::CloseDelayed(duration) => {
Expand All @@ -54,6 +68,16 @@ async fn poll_write(
delegate_receiver.close();
}

async fn poll_timeout(last_active_time: Arc<RwLock<Instant>>) {
let timeout = Duration::from_secs(10);
loop {
sleep(Duration::from_secs(1)).await;
if last_active_time.read().await.elapsed() > timeout {
break;
}
}
}

/// run
///
/// [`session_id`] 会话id
Expand Down Expand Up @@ -85,9 +109,12 @@ pub async fn run(
return;
}

let last_active_time = Arc::new(RwLock::new(Instant::now()));

select! {
_= poll_read(addr, &mut delegate, udp_recv_receiver) => {},
_= poll_write(addr, delegate_receiver, socket) => {},
_= poll_read(addr, &mut delegate, udp_recv_receiver, last_active_time.clone()) => {},
_= poll_write(addr, delegate_receiver, socket, last_active_time.clone()) => {},
_= poll_timeout(last_active_time) => {},
_ = shutdown.recv() => {}
}

Expand Down
4 changes: 3 additions & 1 deletion np_base/src/proxy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ mod tests {
})
});

let mut inlet = Inlet::new(InletProxyType::TCP, "".into());
let mut inlet = Inlet::new("".into());
inlet
.start(
InletProxyType::TCP,
"0.0.0.0:4000".into(),
"www.baidu.com:80".into(),
false,
Expand All @@ -73,6 +74,7 @@ mod tests {
inlet.stop().await;
inlet
.start(
InletProxyType::TCP,
"0.0.0.0:4000".into(),
"www.baidu.com:80".into(),
false,
Expand Down
120 changes: 80 additions & 40 deletions np_base/src/proxy/outlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ use bytes::BytesMut;
use log::{debug, error, info};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, WriteHalf};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::Mutex;
use tokio::select;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{sleep, Instant};

type TcpWriter = Mutex<WriteHalf<TcpStream>>;

enum ClientType {
TCP(TcpWriter, bool, EncryptionMethod, Vec<u8>),
UDP(Arc<UdpSocket>, bool, EncryptionMethod, Vec<u8>),
UDP(Arc<UdpSocket>, Arc<RwLock<Instant>>, bool, EncryptionMethod, Vec<u8>),
}

pub struct Outlet {
Expand Down Expand Up @@ -47,7 +50,7 @@ impl Outlet {
addr,
encryption_method,
encryption_key,
client_addr
client_addr,
) => {
let output_callback = self.on_output_callback.clone();
if let Err(err) = self
Expand All @@ -61,7 +64,12 @@ impl Outlet {
)
.await
{
error!("Failed to connect to {}, error: {}, remote client addr {}", addr, err.to_string(), client_addr);
error!(
"Failed to connect to {}, error: {}, remote client addr {}",
addr,
err.to_string(),
client_addr
);
tokio::spawn(async move {
output_callback(ProxyMessage::O2iConnect(
session_id,
Expand All @@ -71,7 +79,10 @@ impl Outlet {
.await
});
} else {
info!("Successfully connected to {}, remote client addr {}", addr, client_addr);
info!(
"Successfully connected to {}, remote client addr {}",
addr, client_addr
);
tokio::spawn(async move {
output_callback(ProxyMessage::O2iConnect(session_id, true, "".into()))
.await;
Expand Down Expand Up @@ -124,7 +135,7 @@ impl Outlet {
ClientType::TCP(client, is_compressed, encryption_method, encryption_key),
);
} else {
let client = udp_connect(
let (client, last_active_time) = udp_connect(
addr,
session_id,
is_compressed,
Expand All @@ -133,11 +144,11 @@ impl Outlet {
self.on_output_callback.clone(),
self.client_map.clone(),
)
.await?;
.await?;

self.client_map.lock().await.insert(
session_id,
ClientType::UDP(client, is_compressed, encryption_method, encryption_key),
ClientType::UDP(client, last_active_time, is_compressed, encryption_method, encryption_key),
);
}

Expand All @@ -156,21 +167,26 @@ impl Outlet {
)?;
writer.lock().await.write_all(&data).await?;
}
ClientType::UDP(socket, is_compressed, encryption_method, encryption_key) => {
ClientType::UDP(socket, last_active_time, is_compressed, encryption_method, encryption_key) => {
data = decode_data(
data,
is_compressed.clone(),
encryption_method,
encryption_key,
)?;
socket.send(&data).await?;

if last_active_time.read().await.elapsed() >= Duration::from_secs(1) {
let mut instant_write = last_active_time.write().await;
*instant_write = Instant::now();
}
}
}
}
Ok(())
}

async fn on_i2o_disconnect(&self, session_id: u32) -> anyhow::Result<()>{
async fn on_i2o_disconnect(&self, session_id: u32) -> anyhow::Result<()> {
info!("disconnect session: {session_id}");

if let Some(client_type) = self.client_map.lock().await.remove(&session_id) {
Expand Down Expand Up @@ -236,7 +252,6 @@ async fn tcp_connect(
Ok(Mutex::new(writer))
}


async fn udp_connect(
addr: String,
session_id: u32,
Expand All @@ -245,53 +260,78 @@ async fn udp_connect(
encryption_key: Vec<u8>,
on_output_callback: OutputFuncType,
client_map: Arc<Mutex<HashMap<u32, ClientType>>>,
) -> anyhow::Result<Arc<UdpSocket>> {
) -> anyhow::Result<(Arc<UdpSocket>, Arc<RwLock<Instant>>)> {
let socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);
socket.connect(addr).await?;

let last_active_time = Arc::new(RwLock::new(Instant::now()));
let last_active_time_cloned = last_active_time.clone();

let socket_cloned = socket.clone();
tokio::spawn(async move {
let mut buffer = vec![0u8; 65536];
loop {
match socket.recv_buf(&mut buffer).await {
Ok(size)=>{
let received_data = &buffer[..size];
if size == 0 {
break;
}
match encode_data(
received_data.to_vec(),
is_compressed,
&encryption_method,
&encryption_key,
) {
Ok(data) => {
on_output_callback(ProxyMessage::O2iRecvData(
session_id,
data,
))
.await;
}
Err(err) => {
error!("Data encryption error: {}", err);
let last_active_time_1 = last_active_time.clone();
let recv_task = async {
let write_timeout = Duration::from_secs(1);
let mut buffer = vec![0u8; 65536];
loop {
match socket.recv_buf(&mut buffer).await {
Ok(size) => {
let received_data = &buffer[..size];
if size == 0 {
break;
}

if last_active_time_1.read().await.elapsed() >= write_timeout {
let mut instant_write = last_active_time_1.write().await;
*instant_write = Instant::now();
}

match encode_data(
received_data.to_vec(),
is_compressed,
&encryption_method,
&encryption_key,
) {
Ok(data) => {
on_output_callback(ProxyMessage::O2iRecvData(session_id, data))
.await;
}
Err(err) => {
error!("Data encryption error: {}", err);
break;
}
}
}
},
Err(err)=>{
error!("Udp recv error: {}", err);
Err(err) => {
error!("Udp recv error: {}", err);
break;
}
}
}
};

let timeout_task = async {
let timeout = Duration::from_secs(10);
loop {
sleep(Duration::from_secs(1)).await;
if last_active_time.read().await.elapsed() > timeout {
break;
}
}
};

select! {
_=recv_task => {},
_=timeout_task => {},
}

on_output_callback(ProxyMessage::O2iDisconnect(session_id)).await;
client_map.lock().await.remove(&session_id);
});

Ok(socket_cloned)
Ok((socket_cloned, last_active_time_cloned))
}


fn decode_data(
mut data: Vec<u8>,
is_compressed: bool,
Expand Down
Loading

0 comments on commit fcea957

Please sign in to comment.