From f4e6f8bbd666de8ff63cd11471501e4175e21232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hiraku=20=F0=9F=8E=A9?= Date: Mon, 23 Sep 2024 02:24:06 +0200 Subject: [PATCH] fix: prevent reconnecting of client after invoking disconnect (#374) --- socketio/src/client/client.rs | 82 ++++++++++++++++++++++--------- socketio/src/client/raw_client.rs | 50 +++++++++++++++++-- 2 files changed, 104 insertions(+), 28 deletions(-) diff --git a/socketio/src/client/client.rs b/socketio/src/client/client.rs index fe924307..80301031 100644 --- a/socketio/src/client/client.rs +++ b/socketio/src/client/client.rs @@ -3,7 +3,7 @@ use std::{ time::Duration, }; -use super::{ClientBuilder, RawClient}; +use super::{raw_client::DisconnectReason, ClientBuilder, RawClient}; use crate::{ error::Result, packet::{Packet, PacketId}, @@ -165,6 +165,11 @@ impl Client { client.disconnect() } + fn do_disconnect(&self) -> Result<()> { + let client = self.client.read()?; + client.do_disconnect() + } + fn reconnect(&mut self) -> Result<()> { let mut reconnect_attempts = 0; let (reconnect, max_reconnect_attempts) = { @@ -174,6 +179,17 @@ impl Client { if reconnect { loop { + // Check if disconnect_reason is Manual + { + let disconnect_reason = { + let client = self.client.read()?; + client.get_disconnect_reason() + }; + if disconnect_reason == DisconnectReason::Manual { + // Exit the loop, stop reconnecting + break; + } + } if let Some(max_reconnect_attempts) = max_reconnect_attempts { reconnect_attempts += 1; if reconnect_attempts > max_reconnect_attempts { @@ -186,6 +202,12 @@ impl Client { } if self.do_reconnect().is_ok() { + // Reset disconnect_reason to Unknown after successful reconnection + { + let client = self.client.read()?; + let mut reason = client.disconnect_reason.write()?; + *reason = DisconnectReason::Unknown; + } break; } } @@ -213,29 +235,43 @@ impl Client { let mut self_clone = self.clone(); // Use thread to consume items in iterator in order to call callbacks std::thread::spawn(move || { - // tries to restart a poll cycle whenever a 'normal' error occurs, - // it just panics on network errors, in case the poll cycle returned - // `Result::Ok`, the server receives a close frame so it's safe to - // terminate - for packet in self_clone.iter() { - let should_reconnect = match packet { - Err(Error::IncompleteResponseFromEngineIo(_)) => { - //TODO: 0.3.X handle errors - //TODO: logging error - true + loop { + let next_item = self_clone.iter().next(); + match next_item { + Some(Ok(_packet)) => { + // Process packet normally + continue; + } + Some(Err(_)) => { + let should_reconnect = { + let disconnect_reason = { + let client = self_clone.client.read().unwrap(); + client.get_disconnect_reason() + }; + match disconnect_reason { + DisconnectReason::Unknown => { + let builder = self_clone.builder.lock().unwrap(); + builder.reconnect + } + DisconnectReason::Manual => false, + DisconnectReason::Server => { + let builder = self_clone.builder.lock().unwrap(); + builder.reconnect_on_disconnect + } + } + }; + if should_reconnect { + let _ = self_clone.do_disconnect(); + let _ = self_clone.reconnect(); + } else { + // No reconnection needed, exit the loop + break; + } + } + None => { + // Iterator has ended, exit the loop + break; } - Ok(Packet { - packet_type: PacketId::Disconnect, - .. - }) => match self_clone.builder.lock() { - Ok(builder) => builder.reconnect_on_disconnect, - Err(_) => false, - }, - _ => false, - }; - if should_reconnect { - let _ = self_clone.disconnect(); - let _ = self_clone.reconnect(); } } }); diff --git a/socketio/src/client/raw_client.rs b/socketio/src/client/raw_client.rs index 0686683f..fe6b154a 100644 --- a/socketio/src/client/raw_client.rs +++ b/socketio/src/client/raw_client.rs @@ -9,12 +9,23 @@ use crate::client::callback::{SocketAnyCallback, SocketCallback}; use crate::error::Result; use std::collections::HashMap; use std::ops::DerefMut; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, RwLock}; use std::time::Duration; use std::time::Instant; use crate::socket::Socket as InnerSocket; +#[derive(Default, Clone, Copy, PartialEq)] +pub enum DisconnectReason { + /// There is no known reason for the disconnect; likely a network error + #[default] + Unknown, + /// The user disconnected manually + Manual, + /// The server disconnected + Server, +} + /// Represents an `Ack` as given back to the caller. Holds the internal `id` as /// well as the current ack'ed state. Holds data which will be accessible as /// soon as the ack'ed state is set to true. An `Ack` that didn't get ack'ed @@ -41,6 +52,7 @@ pub struct RawClient { nsp: String, // Data send in the opening packet (commonly used as for auth) auth: Option, + pub(crate) disconnect_reason: Arc>, } impl RawClient { @@ -62,6 +74,7 @@ impl RawClient { on_any, outstanding_acks: Arc::new(Mutex::new(Vec::new())), auth, + disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())), }) } @@ -142,7 +155,14 @@ impl RawClient { /// /// ``` pub fn disconnect(&self) -> Result<()> { - let disconnect_packet = + *(self.disconnect_reason.write()?) = DisconnectReason::Manual; + self.do_disconnect() + } + + /// Disconnects this client the same way as `disconnect()` but + /// without setting the `DisconnectReason` to `DisconnectReason::Manual` + pub fn do_disconnect(&self) -> Result<()> { + let disconnect_packet = Packet::new(PacketId::Disconnect, self.nsp.clone(), None, None, 0, None); // TODO: logging @@ -153,6 +173,10 @@ impl RawClient { Ok(()) } + pub fn get_disconnect_reason(&self) -> DisconnectReason { + *self.disconnect_reason.read().unwrap() + } + /// Sends a message to the server but `alloc`s an `ack` to check whether the /// server responded in a given time span. This message takes an event, which /// could either be one of the common events like "message" or "error" or a @@ -222,18 +246,32 @@ impl RawClient { } pub(crate) fn poll(&self) -> Result> { + { + let disconnect_reason = *self.disconnect_reason.read()?; + if disconnect_reason == DisconnectReason::Manual { + // If disconnected manually, return Ok(None) to end iterator + return Ok(None); + } + } loop { match self.socket.poll() { Err(err) => { - self.callback(&Event::Error, err.to_string())?; - return Err(err); + // Check if the disconnection was manual + let disconnect_reason = *self.disconnect_reason.read()?; + if disconnect_reason == DisconnectReason::Manual { + // Return Ok(None) to signal the end of the iterator + return Ok(None); + } else { + self.callback(&Event::Error, err.to_string())?; + return Err(err); + } } Ok(Some(packet)) => { if packet.nsp == self.nsp { self.handle_socketio_packet(&packet)?; return Ok(Some(packet)); } else { - // Not our namespace continue polling + // Not our namespace, continue polling } } Ok(None) => return Ok(None), @@ -369,9 +407,11 @@ impl RawClient { } } PacketId::Connect => { + *(self.disconnect_reason.write()?) = DisconnectReason::default(); self.callback(&Event::Connect, "")?; } PacketId::Disconnect => { + *(self.disconnect_reason.write()?) = DisconnectReason::Server; self.callback(&Event::Close, "")?; } PacketId::ConnectError => {