diff --git a/packages/ws/README.md b/packages/ws/README.md index 6c9108cbf35c..cfeb39cd13fa 100644 --- a/packages/ws/README.md +++ b/packages/ws/README.md @@ -99,7 +99,9 @@ You can also have the shards spawn in worker threads: ```ts import { WebSocketManager, WorkerShardingStrategy } from '@discordjs/ws'; +import { REST } from '@discordjs/rest'; +const rest = new REST().setToken(process.env.DISCORD_TOKEN); const manager = new WebSocketManager({ token: process.env.DISCORD_TOKEN, intents: 0, @@ -113,6 +115,51 @@ manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 2 })) manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 'all' })); ``` +**Note**: By default, this will cause the workers to effectively only be responsible for the WebSocket connection, they simply pass up all the events back to the main process for the manager to emit. If you want to have the workers handle events as well, you can pass in a `workerPath` option to the `WorkerShardingStrategy` constructor: + +```ts +import { WebSocketManager, WorkerShardingStrategy } from '@discordjs/ws'; +import { REST } from '@discordjs/rest'; + +const rest = new REST().setToken(process.env.DISCORD_TOKEN); +const manager = new WebSocketManager({ + token: process.env.DISCORD_TOKEN, + intents: 0, + rest, +}); + +manager.setStrategy( + new WorkerShardingStrategy(manager, { + shardsPerWorker: 2, + workerPath: './worker.js', + }), +); +``` + +And your `worker.ts` file: + +```ts +import { WorkerBootstrapper, WebSocketShardEvents } from '@discordjs/ws'; + +const bootstrapper = new WorkerBootstrapper(); +void bootstrapper.bootstrap({ + // Those will be sent to the main thread for the manager to emit + forwardEvents: [ + WebSocketShardEvents.Closed, + WebSocketShardEvents.Debug, + WebSocketShardEvents.Hello, + WebSocketShardEvents.Ready, + WebSocketShardEvents.Resumed, + ], + shardCallback: (shard) => { + shard.on(WebSocketShardEvents.Dispatch, (event) => { + // Process gateway events here however you want (e.g. send them through a message broker) + // You also have access to shard.id if you need it + }); + }, +}); +``` + ## Links - [Website][website] ([source][website-source]) diff --git a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts index 18809afe7eab..474de313aa59 100644 --- a/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts +++ b/packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts @@ -53,45 +53,54 @@ vi.mock('node:worker_threads', async () => { super(); mockConstructor(...args); // need to delay this by an event loop cycle to allow the strategy to attach a listener - setImmediate(() => this.emit('online')); + setImmediate(() => { + this.emit('online'); + // same deal here + setImmediate(() => { + const message = { + op: WorkerRecievePayloadOp.WorkerReady, + } satisfies WorkerRecievePayload; + this.emit('message', message); + }); + }); } public postMessage(message: WorkerSendPayload) { switch (message.op) { case WorkerSendPayloadOp.Connect: { - const response: WorkerRecievePayload = { + const response = { op: WorkerRecievePayloadOp.Connected, shardId: message.shardId, - }; + } satisfies WorkerRecievePayload; this.emit('message', response); break; } case WorkerSendPayloadOp.Destroy: { - const response: WorkerRecievePayload = { + const response = { op: WorkerRecievePayloadOp.Destroyed, shardId: message.shardId, - }; + } satisfies WorkerRecievePayload; this.emit('message', response); break; } case WorkerSendPayloadOp.Send: { if (message.payload.op === GatewayOpcodes.RequestGuildMembers) { - const response: WorkerRecievePayload = { + const response = { op: WorkerRecievePayloadOp.Event, shardId: message.shardId, event: WebSocketShardEvents.Dispatch, data: memberChunkData, - }; + } satisfies WorkerRecievePayload; this.emit('message', response); // Fetch session info - const sessionFetch: WorkerRecievePayload = { + const sessionFetch = { op: WorkerRecievePayloadOp.RetrieveSessionInfo, shardId: message.shardId, nonce: Math.random(), - }; + } satisfies WorkerRecievePayload; this.emit('message', sessionFetch); } @@ -102,11 +111,11 @@ vi.mock('node:worker_threads', async () => { case WorkerSendPayloadOp.SessionInfoResponse: { message.session ??= sessionInfo; - const session: WorkerRecievePayload = { + const session = { op: WorkerRecievePayloadOp.UpdateSessionInfo, shardId: message.session.shardId, session: { ...message.session, sequence: message.session.sequence + 1 }, - }; + } satisfies WorkerRecievePayload; this.emit('message', session); break; } @@ -186,7 +195,7 @@ test('spawn, connect, send a message, session info, and destroy', async () => { await manager.connect(); expect(mockConstructor).toHaveBeenCalledWith( - expect.stringContaining('worker.js'), + expect.stringContaining('defaultWorker.js'), expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }), ); diff --git a/packages/ws/package.json b/packages/ws/package.json index 34daef79e41c..a7d4d2c840b2 100644 --- a/packages/ws/package.json +++ b/packages/ws/package.json @@ -16,9 +16,15 @@ "module": "./dist/index.mjs", "typings": "./dist/index.d.ts", "exports": { - "import": "./dist/index.mjs", - "require": "./dist/index.js", - "types": "./dist/index.d.ts" + ".": { + "import": "./dist/index.mjs", + "require": "./dist/index.js", + "types": "./dist/index.d.ts" + }, + "./defaultWorker": { + "import": "./dist/defaultWorker.mjs", + "require": "./dist/defaultWorker.js" + } }, "directories": { "lib": "src", diff --git a/packages/ws/src/index.ts b/packages/ws/src/index.ts index 948cd2a3e6a5..1c1e05848653 100644 --- a/packages/ws/src/index.ts +++ b/packages/ws/src/index.ts @@ -8,6 +8,7 @@ export * from './strategies/sharding/WorkerShardingStrategy.js'; export * from './utils/constants.js'; export * from './utils/IdentifyThrottler.js'; +export * from './utils/WorkerBootstrapper.js'; export * from './ws/WebSocketManager.js'; export * from './ws/WebSocketShard.js'; diff --git a/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts b/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts index 2480f25dfe46..17b5ffe07c49 100644 --- a/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts +++ b/packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts @@ -67,7 +67,10 @@ export class SimpleShardingStrategy implements IShardingStrategy { */ public async send(shardId: number, payload: GatewaySendPayload) { const shard = this.shards.get(shardId); - if (!shard) throw new Error(`Shard ${shardId} not found`); + if (!shard) { + throw new RangeError(`Shard ${shardId} not found`); + } + return shard.send(payload); } diff --git a/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts b/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts index aade79ece6e6..d216c2a37604 100644 --- a/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts +++ b/packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts @@ -1,5 +1,5 @@ import { once } from 'node:events'; -import { join } from 'node:path'; +import { join, isAbsolute, resolve } from 'node:path'; import { Worker } from 'node:worker_threads'; import { Collection } from '@discordjs/collection'; import type { GatewaySendPayload } from 'discord-api-types/v10'; @@ -38,6 +38,7 @@ export enum WorkerRecievePayloadOp { UpdateSessionInfo, WaitForIdentify, FetchStatusResponse, + WorkerReady, } export type WorkerRecievePayload = @@ -48,7 +49,8 @@ export type WorkerRecievePayload = | { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify } | { op: WorkerRecievePayloadOp.Connected; shardId: number } | { op: WorkerRecievePayloadOp.Destroyed; shardId: number } - | { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }; + | { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number } + | { op: WorkerRecievePayloadOp.WorkerReady }; /** * Options for a {@link WorkerShardingStrategy} @@ -58,6 +60,10 @@ export interface WorkerShardingStrategyOptions { * Dictates how many shards should be spawned per worker thread. */ shardsPerWorker: number | 'all'; + /** + * Path to the worker file to use. The worker requires quite a bit of setup, it is recommended you leverage the {@link WorkerBootstrapper} class. + */ + workerPath?: string; } /** @@ -93,32 +99,20 @@ export class WorkerShardingStrategy implements IShardingStrategy { const shardsPerWorker = this.options.shardsPerWorker === 'all' ? shardIds.length : this.options.shardsPerWorker; const strategyOptions = await managerToFetchingStrategyOptions(this.manager); - let shards = 0; - while (shards !== shardIds.length) { - const slice = shardIds.slice(shards, shardsPerWorker + shards); + const loops = Math.ceil(shardIds.length / shardsPerWorker); + const promises: Promise[] = []; + + for (let idx = 0; idx < loops; idx++) { + const slice = shardIds.slice(idx * shardsPerWorker, (idx + 1) * shardsPerWorker); const workerData: WorkerData = { ...strategyOptions, shardIds: slice, }; - const worker = new Worker(join(__dirname, 'worker.js'), { workerData }); - await once(worker, 'online'); - worker - .on('error', (err) => { - throw err; - }) - .on('messageerror', (err) => { - throw err; - }) - .on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload)); - - this.#workers.push(worker); - for (const shardId of slice) { - this.#workerByShardId.set(shardId, worker); - } - - shards += slice.length; + promises.push(this.setupWorker(workerData)); } + + await Promise.all(promises); } /** @@ -210,6 +204,63 @@ export class WorkerShardingStrategy implements IShardingStrategy { return statuses; } + private async setupWorker(workerData: WorkerData) { + const worker = new Worker(this.resolveWorkerPath(), { workerData }); + + await once(worker, 'online'); + // We do this in case the user has any potentially long running code in their worker + await this.waitForWorkerReady(worker); + + worker + .on('error', (err) => { + throw err; + }) + .on('messageerror', (err) => { + throw err; + }) + .on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload)); + + this.#workers.push(worker); + for (const shardId of workerData.shardIds) { + this.#workerByShardId.set(shardId, worker); + } + } + + private resolveWorkerPath(): string { + const path = this.options.workerPath; + + if (!path) { + return join(__dirname, 'defaultWorker.js'); + } + + if (isAbsolute(path)) { + return path; + } + + if (/^\.\.?[/\\]/.test(path)) { + return resolve(path); + } + + try { + return require.resolve(path); + } catch { + return resolve(path); + } + } + + private async waitForWorkerReady(worker: Worker): Promise { + return new Promise((resolve) => { + const handler = (payload: WorkerRecievePayload) => { + if (payload.op === WorkerRecievePayloadOp.WorkerReady) { + resolve(); + worker.off('message', handler); + } + }; + + worker.on('message', handler); + }); + } + private async onMessage(worker: Worker, payload: WorkerRecievePayload) { switch (payload.op) { case WorkerRecievePayloadOp.Connected: { @@ -260,6 +311,10 @@ export class WorkerShardingStrategy implements IShardingStrategy { this.fetchStatusPromises.delete(payload.nonce); break; } + + case WorkerRecievePayloadOp.WorkerReady: { + break; + } } } } diff --git a/packages/ws/src/strategies/sharding/defaultWorker.ts b/packages/ws/src/strategies/sharding/defaultWorker.ts new file mode 100644 index 000000000000..324bd20bc016 --- /dev/null +++ b/packages/ws/src/strategies/sharding/defaultWorker.ts @@ -0,0 +1,4 @@ +import { WorkerBootstrapper } from '../../utils/WorkerBootstrapper.js'; + +const bootstrapper = new WorkerBootstrapper(); +void bootstrapper.bootstrap(); diff --git a/packages/ws/src/strategies/sharding/worker.ts b/packages/ws/src/strategies/sharding/worker.ts deleted file mode 100644 index 19345343cfb4..000000000000 --- a/packages/ws/src/strategies/sharding/worker.ts +++ /dev/null @@ -1,117 +0,0 @@ -import { isMainThread, workerData, parentPort } from 'node:worker_threads'; -import { Collection } from '@discordjs/collection'; -import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js'; -import { WorkerContextFetchingStrategy } from '../context/WorkerContextFetchingStrategy.js'; -import { - WorkerRecievePayloadOp, - WorkerSendPayloadOp, - type WorkerData, - type WorkerRecievePayload, - type WorkerSendPayload, -} from './WorkerShardingStrategy.js'; - -if (isMainThread) { - throw new Error('Expected worker script to not be ran within the main thread'); -} - -const data = workerData as WorkerData; -const shards = new Collection(); - -async function connect(shardId: number) { - const shard = shards.get(shardId); - if (!shard) { - throw new Error(`Shard ${shardId} does not exist`); - } - - await shard.connect(); -} - -async function destroy(shardId: number, options?: WebSocketShardDestroyOptions) { - const shard = shards.get(shardId); - if (!shard) { - throw new Error(`Shard ${shardId} does not exist`); - } - - await shard.destroy(options); -} - -for (const shardId of data.shardIds) { - const shard = new WebSocketShard(new WorkerContextFetchingStrategy(data), shardId); - for (const event of Object.values(WebSocketShardEvents)) { - // @ts-expect-error: Event types incompatible - shard.on(event, (data) => { - const payload = { - op: WorkerRecievePayloadOp.Event, - event, - data, - shardId, - } satisfies WorkerRecievePayload; - parentPort!.postMessage(payload); - }); - } - - shards.set(shardId, shard); -} - -parentPort! - .on('messageerror', (err) => { - throw err; - }) - .on('message', async (payload: WorkerSendPayload) => { - switch (payload.op) { - case WorkerSendPayloadOp.Connect: { - await connect(payload.shardId); - const response: WorkerRecievePayload = { - op: WorkerRecievePayloadOp.Connected, - shardId: payload.shardId, - }; - parentPort!.postMessage(response); - break; - } - - case WorkerSendPayloadOp.Destroy: { - await destroy(payload.shardId, payload.options); - const response: WorkerRecievePayload = { - op: WorkerRecievePayloadOp.Destroyed, - shardId: payload.shardId, - }; - - parentPort!.postMessage(response); - break; - } - - case WorkerSendPayloadOp.Send: { - const shard = shards.get(payload.shardId); - if (!shard) { - throw new Error(`Shard ${payload.shardId} does not exist`); - } - - await shard.send(payload.payload); - break; - } - - case WorkerSendPayloadOp.SessionInfoResponse: { - break; - } - - case WorkerSendPayloadOp.ShardCanIdentify: { - break; - } - - case WorkerSendPayloadOp.FetchStatus: { - const shard = shards.get(payload.shardId); - if (!shard) { - throw new Error(`Shard ${payload.shardId} does not exist`); - } - - const response = { - op: WorkerRecievePayloadOp.FetchStatusResponse, - status: shard.status, - nonce: payload.nonce, - } satisfies WorkerRecievePayload; - - parentPort!.postMessage(response); - break; - } - } - }); diff --git a/packages/ws/src/utils/WorkerBootstrapper.ts b/packages/ws/src/utils/WorkerBootstrapper.ts new file mode 100644 index 000000000000..033f52c618c6 --- /dev/null +++ b/packages/ws/src/utils/WorkerBootstrapper.ts @@ -0,0 +1,176 @@ +import { isMainThread, parentPort, workerData } from 'node:worker_threads'; +import { Collection } from '@discordjs/collection'; +import type { Awaitable } from '@discordjs/util'; +import { WorkerContextFetchingStrategy } from '../strategies/context/WorkerContextFetchingStrategy.js'; +import { + WorkerRecievePayloadOp, + WorkerSendPayloadOp, + type WorkerData, + type WorkerRecievePayload, + type WorkerSendPayload, +} from '../strategies/sharding/WorkerShardingStrategy.js'; +import type { WebSocketShardDestroyOptions } from '../ws/WebSocketShard.js'; +import { WebSocketShardEvents, WebSocketShard } from '../ws/WebSocketShard.js'; + +/** + * Options for bootstrapping the worker + */ +export interface BootstrapOptions { + /** + * Shard events to just arbitrarily forward to the parent thread for the manager to emit + * Note: By default, this will include ALL events + * you most likely want to handle dispatch within the worker itself + */ + forwardEvents?: WebSocketShardEvents[]; + /** + * Function to call when a shard is created for additional setup + */ + shardCallback?(shard: WebSocketShard): Awaitable; +} + +/** + * Utility class for bootstraping a worker thread to be used for sharding + */ +export class WorkerBootstrapper { + /** + * The data passed to the worker thread + */ + protected readonly data = workerData as WorkerData; + + /** + * The shards that are managed by this worker + */ + protected readonly shards = new Collection(); + + public constructor() { + if (isMainThread) { + throw new Error('Expected WorkerBootstrap to not be used within the main thread'); + } + } + + /** + * Helper method to initiate a shard's connection process + */ + protected async connect(shardId: number): Promise { + const shard = this.shards.get(shardId); + if (!shard) { + throw new RangeError(`Shard ${shardId} does not exist`); + } + + await shard.connect(); + } + + /** + * Helper method to destroy a shard + */ + protected async destroy(shardId: number, options?: WebSocketShardDestroyOptions): Promise { + const shard = this.shards.get(shardId); + if (!shard) { + throw new RangeError(`Shard ${shardId} does not exist`); + } + + await shard.destroy(options); + } + + /** + * Helper method to attach event listeners to the parentPort + */ + protected setupThreadEvents(): void { + parentPort! + .on('messageerror', (err) => { + throw err; + }) + .on('message', async (payload: WorkerSendPayload) => { + switch (payload.op) { + case WorkerSendPayloadOp.Connect: { + await this.connect(payload.shardId); + const response: WorkerRecievePayload = { + op: WorkerRecievePayloadOp.Connected, + shardId: payload.shardId, + }; + parentPort!.postMessage(response); + break; + } + + case WorkerSendPayloadOp.Destroy: { + await this.destroy(payload.shardId, payload.options); + const response: WorkerRecievePayload = { + op: WorkerRecievePayloadOp.Destroyed, + shardId: payload.shardId, + }; + + parentPort!.postMessage(response); + break; + } + + case WorkerSendPayloadOp.Send: { + const shard = this.shards.get(payload.shardId); + if (!shard) { + throw new RangeError(`Shard ${payload.shardId} does not exist`); + } + + await shard.send(payload.payload); + break; + } + + case WorkerSendPayloadOp.SessionInfoResponse: { + break; + } + + case WorkerSendPayloadOp.ShardCanIdentify: { + break; + } + + case WorkerSendPayloadOp.FetchStatus: { + const shard = this.shards.get(payload.shardId); + if (!shard) { + throw new Error(`Shard ${payload.shardId} does not exist`); + } + + const response = { + op: WorkerRecievePayloadOp.FetchStatusResponse, + status: shard.status, + nonce: payload.nonce, + } satisfies WorkerRecievePayload; + + parentPort!.postMessage(response); + break; + } + } + }); + } + + /** + * Bootstraps the worker thread with the provided options + */ + public async bootstrap(options: Readonly = {}): Promise { + // Start by initializing the shards + for (const shardId of this.data.shardIds) { + const shard = new WebSocketShard(new WorkerContextFetchingStrategy(this.data), shardId); + for (const event of options.forwardEvents ?? Object.values(WebSocketShardEvents)) { + // @ts-expect-error: Event types incompatible + shard.on(event, (data) => { + const payload = { + op: WorkerRecievePayloadOp.Event, + event, + data, + shardId, + } satisfies WorkerRecievePayload; + parentPort!.postMessage(payload); + }); + } + + // Any additional setup the user might want to do + await options.shardCallback?.(shard); + this.shards.set(shardId, shard); + } + + // Lastly, start listening to messages from the parent thread + this.setupThreadEvents(); + + const message = { + op: WorkerRecievePayloadOp.WorkerReady, + } satisfies WorkerRecievePayload; + parentPort!.postMessage(message); + } +} diff --git a/packages/ws/src/ws/WebSocketShard.ts b/packages/ws/src/ws/WebSocketShard.ts index 023d6182d4e6..9685667dd73f 100644 --- a/packages/ws/src/ws/WebSocketShard.ts +++ b/packages/ws/src/ws/WebSocketShard.ts @@ -81,8 +81,6 @@ export interface SendRateLimitState { export class WebSocketShard extends AsyncEventEmitter { private connection: WebSocket | null = null; - private readonly id: number; - private useIdentifyCompress = false; private inflate: Inflate | null = null; @@ -105,7 +103,9 @@ export class WebSocketShard extends AsyncEventEmitter { private readonly timeouts = new Collection(); - public readonly strategy: IContextFetchingStrategy; + private readonly strategy: IContextFetchingStrategy; + + public readonly id: number; #status: WebSocketShardStatus = WebSocketShardStatus.Idle; diff --git a/packages/ws/tsup.config.js b/packages/ws/tsup.config.js index 3245f14302b2..8020700fa1c9 100644 --- a/packages/ws/tsup.config.js +++ b/packages/ws/tsup.config.js @@ -4,7 +4,7 @@ import { createTsupConfig } from '../../tsup.config.js'; export default createTsupConfig({ entry: { index: 'src/index.ts', - worker: 'src/strategies/sharding/worker.ts', + defaultWorker: 'src/strategies/sharding/defaultWorker.ts', }, external: ['zlib-sync'], esbuildPlugins: [esbuildPluginVersionInjector()],