Skip to content

Commit

Permalink
Introduce DisconnectReason enum
Browse files Browse the repository at this point in the history
The enum replaces the need for multiple `AtomicBool`'s to maintain the
disconnection reason. This makes the code easier to read and more
ergonomic to maintain the state.
  • Loading branch information
rageshkrishna committed Mar 22, 2024
1 parent 795cdbd commit 210fe95
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions socketio/src/asynchronous/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ use crate::{
Event, Payload,
};

#[derive(Default)]
enum DisconnectReason {
/// There is no known reason for the disconnect; likely a network error
#[default]
Unknown,
/// The user disconnected manually
Manual,
/// The server disconnected
Server,
}

/// A socket which handles communication with the server. It's initialized with
/// a specific address as well as an optional namespace to connect to. If `None`
/// is given the client will connect to the default namespace `"/"`.
Expand All @@ -42,8 +53,7 @@ pub struct Client {
// Data send in the opening packet (commonly used as for auth)
auth: Option<serde_json::Value>,
builder: Arc<RwLock<ClientBuilder>>,
manually_disconnected: Arc<AtomicBool>,
server_disconnected: Arc<AtomicBool>,
disconnect_reason: Arc<RwLock<DisconnectReason>>,
}

impl Client {
Expand All @@ -58,8 +68,7 @@ impl Client {
outstanding_acks: Arc::new(RwLock::new(Vec::new())),
auth: builder.auth.clone(),
builder: Arc::new(RwLock::new(builder)),
manually_disconnected: Arc::new(AtomicBool::new(false)),
server_disconnected: Arc::new(AtomicBool::new(false)),
disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())),
})
}

Expand All @@ -85,6 +94,9 @@ impl Client {
// New inner socket that can be connected
let mut client_socket = self.socket.write().await;
*client_socket = socket;

// Now that we have replaced `self.socket`, we drop the write lock
// because the `connect` method we call below will need to use it
drop(client_socket);

self.connect().await?;
Expand All @@ -98,6 +110,8 @@ impl Client {
let reconnect_delay_min = builder.reconnect_delay_min;
let reconnect_delay_max = builder.reconnect_delay_max;
let max_reconnect_attempts = builder.max_reconnect_attempts;
let reconnect = builder.reconnect;
let reconnect_on_disconnect = builder.reconnect_on_disconnect;
drop(builder);

let mut client_clone = self.clone();
Expand All @@ -115,7 +129,13 @@ impl Client {
// Drop the stream so we can once again use `socket_clone` as mutable
drop(stream);

if client_clone.should_reconnect().await {
let should_reconnect = match *(client_clone.disconnect_reason.read().await) {
DisconnectReason::Unknown => reconnect,
DisconnectReason::Manual => false,
DisconnectReason::Server => reconnect_on_disconnect,
};

if should_reconnect {
let mut reconnect_attempts = 0;
let mut backoff = ExponentialBackoffBuilder::new()
.with_initial_interval(Duration::from_millis(reconnect_delay_min))
Expand Down Expand Up @@ -233,7 +253,7 @@ impl Client {
/// }
/// ```
pub async fn disconnect(&self) -> Result<()> {
self.manually_disconnected.store(true, Ordering::Release);
*(self.disconnect_reason.write().await) = DisconnectReason::Manual;

let disconnect_packet =
Packet::new(PacketId::Disconnect, self.nsp.clone(), None, None, 0, None);
Expand Down Expand Up @@ -456,11 +476,11 @@ impl Client {
}
}
PacketId::Connect => {
self.server_disconnected.store(false, Ordering::Release);
*(self.disconnect_reason.write().await) = DisconnectReason::default();
self.callback(&Event::Connect, "").await?;
}
PacketId::Disconnect => {
self.server_disconnected.store(true, Ordering::Release);
*(self.disconnect_reason.write().await) = DisconnectReason::Server;
self.callback(&Event::Close, "").await?;
}
PacketId::ConnectError => {
Expand All @@ -484,31 +504,15 @@ impl Client {
Ok(())
}

/// Indicates whether the client should try to reconnect
pub(crate) async fn should_reconnect(&self) -> bool {
let manually_disconnected = self.manually_disconnected.load(Ordering::Acquire);
let server_disconnected = self.server_disconnected.load(Ordering::Acquire);

if server_disconnected {
self.builder.read().await.reconnect_on_disconnect
} else {
!manually_disconnected
}
}

/// Returns the packet stream for the client.
pub(crate) fn as_stream<'a>(
&'a self,
) -> Pin<Box<dyn Stream<Item = Result<Packet>> + Send + 'a>> {
let socket_clone = self.socket.clone();
let socket_clone = (*self.socket.blocking_read()).clone();

stream::unfold(socket_clone, |socket| async {
let mut socket_read = {
let s = socket.read().await;
(*s).clone()
};
stream::unfold(socket_clone, |mut socket| async {
// wait for the next payload
let packet: Option<std::result::Result<Packet, Error>> = socket_read.next().await;
let packet: Option<std::result::Result<Packet, Error>> = socket.next().await;
match packet {
// end the stream if the underlying one is closed
None => None,
Expand Down

0 comments on commit 210fe95

Please sign in to comment.