Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: manual disconnect not preventing reconnection #466

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 59 additions & 23 deletions socketio/src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
time::Duration,
};

use super::{ClientBuilder, RawClient};
use super::{raw_client::DisconnectReason, ClientBuilder, RawClient};
use crate::{
error::Result,
packet::{Packet, PacketId},
Expand Down Expand Up @@ -165,6 +165,11 @@ impl Client {
client.disconnect()
}

fn do_disconnect(&self) -> Result<()> {
let client = self.client.read()?;
client.do_disconnect()
}

fn reconnect(&mut self) -> Result<()> {
let mut reconnect_attempts = 0;
let (reconnect, max_reconnect_attempts) = {
Expand All @@ -174,6 +179,17 @@ impl Client {

if reconnect {
loop {
// Check if disconnect_reason is Manual
{
let disconnect_reason = {
let client = self.client.read()?;
client.get_disconnect_reason()
};
if disconnect_reason == DisconnectReason::Manual {
// Exit the loop, stop reconnecting
break;
}
}
if let Some(max_reconnect_attempts) = max_reconnect_attempts {
reconnect_attempts += 1;
if reconnect_attempts > max_reconnect_attempts {
Expand All @@ -186,6 +202,12 @@ impl Client {
}

if self.do_reconnect().is_ok() {
// Reset disconnect_reason to Unknown after successful reconnection
{
let client = self.client.read()?;
let mut reason = client.disconnect_reason.write()?;
*reason = DisconnectReason::Unknown;
}
break;
}
}
Expand Down Expand Up @@ -213,29 +235,43 @@ impl Client {
let mut self_clone = self.clone();
// Use thread to consume items in iterator in order to call callbacks
std::thread::spawn(move || {
// tries to restart a poll cycle whenever a 'normal' error occurs,
// it just panics on network errors, in case the poll cycle returned
// `Result::Ok`, the server receives a close frame so it's safe to
// terminate
for packet in self_clone.iter() {
let should_reconnect = match packet {
Err(Error::IncompleteResponseFromEngineIo(_)) => {
//TODO: 0.3.X handle errors
//TODO: logging error
true
loop {
let next_item = self_clone.iter().next();
match next_item {
Some(Ok(_packet)) => {
// Process packet normally
continue;
}
Some(Err(_)) => {
let should_reconnect = {
let disconnect_reason = {
let client = self_clone.client.read().unwrap();
client.get_disconnect_reason()
};
match disconnect_reason {
DisconnectReason::Unknown => {
let builder = self_clone.builder.lock().unwrap();
builder.reconnect
}
DisconnectReason::Manual => false,
DisconnectReason::Server => {
let builder = self_clone.builder.lock().unwrap();
builder.reconnect_on_disconnect
}
}
};
if should_reconnect {
let _ = self_clone.do_disconnect();
let _ = self_clone.reconnect();
} else {
// No reconnection needed, exit the loop
break;
}
}
None => {
// Iterator has ended, exit the loop
break;
}
Ok(Packet {
packet_type: PacketId::Disconnect,
..
}) => match self_clone.builder.lock() {
Ok(builder) => builder.reconnect_on_disconnect,
Err(_) => false,
},
_ => false,
};
if should_reconnect {
let _ = self_clone.disconnect();
let _ = self_clone.reconnect();
}
}
});
Expand Down
50 changes: 45 additions & 5 deletions socketio/src/client/raw_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@ use crate::client::callback::{SocketAnyCallback, SocketCallback};
use crate::error::Result;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;
use std::time::Instant;

use crate::socket::Socket as InnerSocket;

#[derive(Default, Clone, Copy, PartialEq)]
pub enum DisconnectReason {
/// There is no known reason for the disconnect; likely a network error
#[default]
Unknown,
/// The user disconnected manually
Manual,
/// The server disconnected
Server,
}

/// Represents an `Ack` as given back to the caller. Holds the internal `id` as
/// well as the current ack'ed state. Holds data which will be accessible as
/// soon as the ack'ed state is set to true. An `Ack` that didn't get ack'ed
Expand All @@ -41,6 +52,7 @@ pub struct RawClient {
nsp: String,
// Data send in the opening packet (commonly used as for auth)
auth: Option<Value>,
pub(crate) disconnect_reason: Arc<RwLock<DisconnectReason>>,
}

impl RawClient {
Expand All @@ -62,6 +74,7 @@ impl RawClient {
on_any,
outstanding_acks: Arc::new(Mutex::new(Vec::new())),
auth,
disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())),
})
}

Expand Down Expand Up @@ -142,7 +155,14 @@ impl RawClient {
///
/// ```
pub fn disconnect(&self) -> Result<()> {
let disconnect_packet =
*(self.disconnect_reason.write()?) = DisconnectReason::Manual;
self.do_disconnect()
}

/// Disconnects this client the same way as `disconnect()` but
/// without setting the `DisconnectReason` to `DisconnectReason::Manual`
pub fn do_disconnect(&self) -> Result<()> {
let disconnect_packet =
Packet::new(PacketId::Disconnect, self.nsp.clone(), None, None, 0, None);

// TODO: logging
Expand All @@ -153,6 +173,10 @@ impl RawClient {
Ok(())
}

pub fn get_disconnect_reason(&self) -> DisconnectReason {
*self.disconnect_reason.read().unwrap()
}

/// Sends a message to the server but `alloc`s an `ack` to check whether the
/// server responded in a given time span. This message takes an event, which
/// could either be one of the common events like "message" or "error" or a
Expand Down Expand Up @@ -222,18 +246,32 @@ impl RawClient {
}

pub(crate) fn poll(&self) -> Result<Option<Packet>> {
{
let disconnect_reason = *self.disconnect_reason.read()?;
if disconnect_reason == DisconnectReason::Manual {
// If disconnected manually, return Ok(None) to end iterator
return Ok(None);
}
}
loop {
match self.socket.poll() {
Err(err) => {
self.callback(&Event::Error, err.to_string())?;
return Err(err);
// Check if the disconnection was manual
let disconnect_reason = *self.disconnect_reason.read()?;
if disconnect_reason == DisconnectReason::Manual {
// Return Ok(None) to signal the end of the iterator
return Ok(None);
} else {
self.callback(&Event::Error, err.to_string())?;
return Err(err);
}
}
Ok(Some(packet)) => {
if packet.nsp == self.nsp {
self.handle_socketio_packet(&packet)?;
return Ok(Some(packet));
} else {
// Not our namespace continue polling
// Not our namespace, continue polling
}
}
Ok(None) => return Ok(None),
Expand Down Expand Up @@ -369,9 +407,11 @@ impl RawClient {
}
}
PacketId::Connect => {
*(self.disconnect_reason.write()?) = DisconnectReason::default();
self.callback(&Event::Connect, "")?;
}
PacketId::Disconnect => {
*(self.disconnect_reason.write()?) = DisconnectReason::Server;
self.callback(&Event::Close, "")?;
}
PacketId::ConnectError => {
Expand Down