Skip to content

Commit

Permalink
feat: implement connection state recovery
Browse files Browse the repository at this point in the history
Connection state recovery allows a client to reconnect after a
temporary disconnection and restore its state:

- id
- rooms
- data
- missed packets

Usage:

```js
import { Server } from "socket.io";

const io = new Server({
  connectionStateRecovery: {
    // default values
    maxDisconnectionDuration: 2 * 60 * 1000,
    skipMiddlewares: true,
  },
});

io.on("connection", (socket) => {
  console.log(socket.recovered); // whether the state was recovered or not
});
```

Here's how it works:

- the server sends a session ID during the handshake (which is
different from the current `id` attribute, which is public and can be
freely shared)

- the server also includes an offset in each packet (added at the end
of the data array, for backward compatibility)

- upon temporary disconnection, the server stores the client state for
a given delay (implemented at the adapter level)

- upon reconnection, the client sends both the session ID and the last
offset it has processed, and the server tries to restore the state

A few notes:

- the base adapter exposes two additional methods, persistSession() and
restoreSession(), that must be implemented by the other adapters in
order to allow the feature to work within a cluster

See: socketio/socket.io-adapter@f529412

- acknowledgements are not affected, because it won't work if the
client reconnects on another server (as the ack id is local)

- any disconnection that lasts longer than the
`maxDisconnectionDuration` value will result in a new session, so users
will still need to care for the state reconciliation between the server
and the client

Related: #4510
  • Loading branch information
darrachequesne committed Jan 12, 2023
1 parent da2b542 commit 54d5ee0
Show file tree
Hide file tree
Showing 8 changed files with 426 additions and 81 deletions.
10 changes: 5 additions & 5 deletions lib/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ export class Client<
* @param {Object} auth - the auth parameters
* @private
*/
private connect(name: string, auth: object = {}): void {
private connect(name: string, auth: Record<string, unknown> = {}): void {
if (this.server._nsps.has(name)) {
debug("connecting to namespace %s", name);
return this.doConnect(name, auth);
Expand Down Expand Up @@ -152,10 +152,10 @@ export class Client<
*
* @private
*/
private doConnect(name: string, auth: object): void {
private doConnect(name: string, auth: Record<string, unknown>): void {
const nsp = this.server.of(name);

const socket = nsp._add(this, auth, () => {
nsp._add(this, auth, (socket) => {
this.sockets.set(socket.id, socket);
this.nsps.set(nsp.name, socket);

Expand Down Expand Up @@ -228,7 +228,7 @@ export class Client<
}

private writeToEngine(
encodedPackets: Array<String | Buffer>,
encodedPackets: Array<string | Buffer>,
opts: WriteOptions
): void {
if (opts.volatile && !this.conn.transport.writable) {
Expand Down Expand Up @@ -267,7 +267,7 @@ export class Client<
*/
private ondecoded(packet: Packet): void {
let namespace: string;
let authPayload;
let authPayload: Record<string, unknown>;
if (this.conn.protocol === 3) {
const parsed = url.parse(packet.nsp, true);
namespace = parsed.pathname!;
Expand Down
43 changes: 39 additions & 4 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ import { Client } from "./client";
import { EventEmitter } from "events";
import { ExtendedError, Namespace, ServerReservedEventsMap } from "./namespace";
import { ParentNamespace } from "./parent-namespace";
import { Adapter, Room, SocketId } from "socket.io-adapter";
import {
Adapter,
SessionAwareAdapter,
Room,
SocketId,
} from "socket.io-adapter";
import * as parser from "socket.io-parser";
import type { Encoder } from "socket.io-parser";
import debugModule from "debug";
Expand Down Expand Up @@ -72,6 +77,25 @@ interface ServerOptions extends EngineOptions, AttachOptions {
* @default 45000
*/
connectTimeout: number;
/**
* Whether to enable the recovery of connection state when a client temporarily disconnects.
*
* The connection state includes the missed packets, the rooms the socket was in and the `data` attribute.
*/
connectionStateRecovery: {
/**
* The backup duration of the sessions and the packets.
*
* @default 120000 (2 minutes)
*/
maxDisconnectionDuration?: number;
/**
* Whether to skip middlewares upon successful connection state recovery.
*
* @default true
*/
skipMiddlewares?: boolean;
};
}

/**
Expand Down Expand Up @@ -148,7 +172,7 @@ export class Server<
> = new Map();
private _adapter?: AdapterConstructor;
private _serveClient: boolean;
private opts: Partial<EngineOptions>;
private readonly opts: Partial<ServerOptions>;
private eio: Engine;
private _path: string;
private clientPathRegex: RegExp;
Expand Down Expand Up @@ -204,9 +228,20 @@ export class Server<
this.serveClient(false !== opts.serveClient);
this._parser = opts.parser || parser;
this.encoder = new this._parser.Encoder();
this.adapter(opts.adapter || Adapter);
this.sockets = this.of("/");
this.opts = opts;
if (opts.connectionStateRecovery) {
opts.connectionStateRecovery = Object.assign(
{
maxDisconnectionDuration: 2 * 60 * 1000,
skipMiddlewares: true,
},
opts.connectionStateRecovery
);
this.adapter(opts.adapter || SessionAwareAdapter);
} else {
this.adapter(opts.adapter || Adapter);
}
this.sockets = this.of("/");
if (srv || typeof srv == "number")
this.attach(
srv as http.Server | HTTPSServer | Http2SecureServer | number
Expand Down
81 changes: 62 additions & 19 deletions lib/namespace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,25 @@ export class Namespace<
* @return {Socket}
* @private
*/
_add(
async _add(
client: Client<ListenEvents, EmitEvents, ServerSideEvents>,
query,
fn?: () => void
): Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData> {
auth: Record<string, unknown>,
fn: (
socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>
) => void
) {
debug("adding socket to nsp %s", this.name);
const socket = new Socket(this, client, query);
const socket = await this._createSocket(client, auth);

if (
// @ts-ignore
this.server.opts.connectionStateRecovery?.skipMiddlewares &&
socket.recovered &&
client.conn.readyState === "open"
) {
return this._doConnect(socket, fn);
}

this.run(socket, (err) => {
process.nextTick(() => {
if ("open" !== client.conn.readyState) {
Expand All @@ -324,22 +336,53 @@ export class Namespace<
}
}

// track socket
this.sockets.set(socket.id, socket);

// it's paramount that the internal `onconnect` logic
// fires before user-set events to prevent state order
// violations (such as a disconnection before the connection
// logic is complete)
socket._onconnect();
if (fn) fn();

// fire user-set events
this.emitReserved("connect", socket);
this.emitReserved("connection", socket);
this._doConnect(socket, fn);
});
});
return socket;
}

private async _createSocket(
client: Client<ListenEvents, EmitEvents, ServerSideEvents>,
auth: Record<string, unknown>
) {
const sessionId = auth.pid;
const offset = auth.offset;
if (
// @ts-ignore
this.server.opts.connectionStateRecovery &&
typeof sessionId === "string" &&
typeof offset === "string"
) {
const session = await this.adapter.restoreSession(sessionId, offset);
if (session) {
debug("connection state recovered for sid %s", session.sid);
return new Socket(this, client, auth, session);
} else {
debug("unable to restore session state");
}
}
return new Socket(this, client, auth);
}

private _doConnect(
socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>,
fn: (
socket: Socket<ListenEvents, EmitEvents, ServerSideEvents, SocketData>
) => void
) {
// track socket
this.sockets.set(socket.id, socket);

// it's paramount that the internal `onconnect` logic
// fires before user-set events to prevent state order
// violations (such as a disconnection before the connection
// logic is complete)
socket._onconnect();
if (fn) fn(socket);

// fire user-set events
this.emitReserved("connect", socket);
this.emitReserved("connection", socket);
}

/**
Expand Down
86 changes: 75 additions & 11 deletions lib/socket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ import { Packet, PacketType } from "socket.io-parser";
import debugModule from "debug";
import type { Server } from "./index";
import {
EventParams,
DefaultEventsMap,
EventNames,
EventParams,
EventsMap,
StrictEventEmitter,
DefaultEventsMap,
} from "./typed-events";
import type { Client } from "./client";
import type { Namespace, NamespaceReservedEventsMap } from "./namespace";
import type { IncomingMessage, IncomingHttpHeaders } from "http";
import type { IncomingHttpHeaders, IncomingMessage } from "http";
import type {
Adapter,
BroadcastFlags,
PrivateSessionId,
Room,
Session,
SocketId,
} from "socket.io-adapter";
import base64id from "base64id";
Expand All @@ -39,6 +41,15 @@ export type DisconnectReason =
| "client namespace disconnect"
| "server namespace disconnect";

const RECOVERABLE_DISCONNECT_REASONS: ReadonlySet<DisconnectReason> = new Set([
"transport error",
"transport close",
"forced close",
"ping timeout",
"server shutting down",
"forced server close",
]);

export interface SocketReservedEventsMap {
disconnect: (reason: DisconnectReason) => void;
disconnecting: (reason: DisconnectReason) => void;
Expand Down Expand Up @@ -173,6 +184,11 @@ export class Socket<
* An unique identifier for the session.
*/
public readonly id: SocketId;
/**
* Whether the connection state was recovered after a temporary disconnection. In that case, any missed packets will
* be transmitted to the client, the data attribute and the rooms will be restored.
*/
public readonly recovered: boolean = false;
/**
* The handshake details.
*/
Expand All @@ -197,6 +213,14 @@ export class Socket<
*/
public connected: boolean = false;

/**
* The session ID, which must not be shared (unlike {@link id}).
*
* @private
*/
private readonly pid: PrivateSessionId;

// TODO: remove this unused reference
private readonly server: Server<
ListenEvents,
EmitEvents,
Expand All @@ -221,16 +245,32 @@ export class Socket<
constructor(
readonly nsp: Namespace<ListenEvents, EmitEvents, ServerSideEvents>,
readonly client: Client<ListenEvents, EmitEvents, ServerSideEvents>,
auth: object
auth: Record<string, unknown>,
previousSession?: Session
) {
super();
this.server = nsp.server;
this.adapter = this.nsp.adapter;
if (client.conn.protocol === 3) {
// @ts-ignore
this.id = nsp.name !== "/" ? nsp.name + "#" + client.id : client.id;
if (previousSession) {
this.id = previousSession.sid;
this.pid = previousSession.pid;
previousSession.rooms.forEach((room) => this.join(room));
this.data = previousSession.data as Partial<SocketData>;
previousSession.missedPackets.forEach((packet) => {
this.packet({
type: PacketType.EVENT,
data: packet,
});
});
this.recovered = true;
} else {
this.id = base64id.generateId(); // don't reuse the Engine.IO id because it's sensitive information
if (client.conn.protocol === 3) {
// @ts-ignore
this.id = nsp.name !== "/" ? nsp.name + "#" + client.id : client.id;
} else {
this.id = base64id.generateId(); // don't reuse the Engine.IO id because it's sensitive information
}
this.pid = base64id.generateId();
}
this.handshake = this.buildHandshake(auth);
}
Expand Down Expand Up @@ -299,8 +339,18 @@ export class Socket<
const flags = Object.assign({}, this.flags);
this.flags = {};

this.notifyOutgoingListeners(packet);
this.packet(packet, flags);
// @ts-ignore
if (this.nsp.server.opts.connectionStateRecovery) {
// this ensures the packet is stored and can be transmitted upon reconnection
this.adapter.broadcast(packet, {
rooms: new Set([this.id]),
except: new Set(),
flags,
});
} else {
this.notifyOutgoingListeners(packet);
this.packet(packet, flags);
}

return true;
}
Expand Down Expand Up @@ -508,7 +558,10 @@ export class Socket<
if (this.conn.protocol === 3) {
this.packet({ type: PacketType.CONNECT });
} else {
this.packet({ type: PacketType.CONNECT, data: { sid: this.id } });
this.packet({
type: PacketType.CONNECT,
data: { sid: this.id, pid: this.pid },
});
}
}

Expand Down Expand Up @@ -644,6 +697,17 @@ export class Socket<
if (!this.connected) return this;
debug("closing socket - reason %s", reason);
this.emitReserved("disconnecting", reason);

if (RECOVERABLE_DISCONNECT_REASONS.has(reason)) {
debug("connection state recovery is enabled for sid %s", this.id);
this.adapter.persistSession({
sid: this.id,
pid: this.pid,
rooms: [...this.rooms],
data: this.data,
});
}

this._cleanup();
this.nsp._remove(this);
this.client._remove(this);
Expand Down
Loading

0 comments on commit 54d5ee0

Please sign in to comment.