Skip to content

Commit

Permalink
feat(ws): custom workers (#9004)
Browse files Browse the repository at this point in the history
* feat(ws): custom workers

* chore: typo

* refactor(WebSocketShard): expose shard id

* chore: remove outdated readme comment

* chore: nits

* chore: remove unnecessary mutation

* feat: fancier resolution

* chore: remove unnecessary exports

* chore: apply suggestions

* refactor: use range errors

Co-authored-by: Aura Román <kyradiscord@gmail.com>
  • Loading branch information
didinele and kyranet authored Jan 10, 2023
1 parent 39c4de2 commit 828a13b
Show file tree
Hide file tree
Showing 11 changed files with 343 additions and 159 deletions.
47 changes: 47 additions & 0 deletions packages/ws/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ You can also have the shards spawn in worker threads:

```ts
import { WebSocketManager, WorkerShardingStrategy } from '@discordjs/ws';
import { REST } from '@discordjs/rest';

const rest = new REST().setToken(process.env.DISCORD_TOKEN);
const manager = new WebSocketManager({
token: process.env.DISCORD_TOKEN,
intents: 0,
Expand All @@ -113,6 +115,51 @@ manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 2 }))
manager.setStrategy(new WorkerShardingStrategy(manager, { shardsPerWorker: 'all' }));
```

**Note**: By default, this will cause the workers to effectively only be responsible for the WebSocket connection, they simply pass up all the events back to the main process for the manager to emit. If you want to have the workers handle events as well, you can pass in a `workerPath` option to the `WorkerShardingStrategy` constructor:

```ts
import { WebSocketManager, WorkerShardingStrategy } from '@discordjs/ws';
import { REST } from '@discordjs/rest';

const rest = new REST().setToken(process.env.DISCORD_TOKEN);
const manager = new WebSocketManager({
token: process.env.DISCORD_TOKEN,
intents: 0,
rest,
});

manager.setStrategy(
new WorkerShardingStrategy(manager, {
shardsPerWorker: 2,
workerPath: './worker.js',
}),
);
```

And your `worker.ts` file:

```ts
import { WorkerBootstrapper, WebSocketShardEvents } from '@discordjs/ws';

const bootstrapper = new WorkerBootstrapper();
void bootstrapper.bootstrap({
// Those will be sent to the main thread for the manager to emit
forwardEvents: [
WebSocketShardEvents.Closed,
WebSocketShardEvents.Debug,
WebSocketShardEvents.Hello,
WebSocketShardEvents.Ready,
WebSocketShardEvents.Resumed,
],
shardCallback: (shard) => {
shard.on(WebSocketShardEvents.Dispatch, (event) => {
// Process gateway events here however you want (e.g. send them through a message broker)
// You also have access to shard.id if you need it
});
},
});
```

## Links

- [Website][website] ([source][website-source])
Expand Down
33 changes: 21 additions & 12 deletions packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,45 +53,54 @@ vi.mock('node:worker_threads', async () => {
super();
mockConstructor(...args);
// need to delay this by an event loop cycle to allow the strategy to attach a listener
setImmediate(() => this.emit('online'));
setImmediate(() => {
this.emit('online');
// same deal here
setImmediate(() => {
const message = {
op: WorkerRecievePayloadOp.WorkerReady,
} satisfies WorkerRecievePayload;
this.emit('message', message);
});
});
}

public postMessage(message: WorkerSendPayload) {
switch (message.op) {
case WorkerSendPayloadOp.Connect: {
const response: WorkerRecievePayload = {
const response = {
op: WorkerRecievePayloadOp.Connected,
shardId: message.shardId,
};
} satisfies WorkerRecievePayload;
this.emit('message', response);
break;
}

case WorkerSendPayloadOp.Destroy: {
const response: WorkerRecievePayload = {
const response = {
op: WorkerRecievePayloadOp.Destroyed,
shardId: message.shardId,
};
} satisfies WorkerRecievePayload;
this.emit('message', response);
break;
}

case WorkerSendPayloadOp.Send: {
if (message.payload.op === GatewayOpcodes.RequestGuildMembers) {
const response: WorkerRecievePayload = {
const response = {
op: WorkerRecievePayloadOp.Event,
shardId: message.shardId,
event: WebSocketShardEvents.Dispatch,
data: memberChunkData,
};
} satisfies WorkerRecievePayload;
this.emit('message', response);

// Fetch session info
const sessionFetch: WorkerRecievePayload = {
const sessionFetch = {
op: WorkerRecievePayloadOp.RetrieveSessionInfo,
shardId: message.shardId,
nonce: Math.random(),
};
} satisfies WorkerRecievePayload;
this.emit('message', sessionFetch);
}

Expand All @@ -102,11 +111,11 @@ vi.mock('node:worker_threads', async () => {
case WorkerSendPayloadOp.SessionInfoResponse: {
message.session ??= sessionInfo;

const session: WorkerRecievePayload = {
const session = {
op: WorkerRecievePayloadOp.UpdateSessionInfo,
shardId: message.session.shardId,
session: { ...message.session, sequence: message.session.sequence + 1 },
};
} satisfies WorkerRecievePayload;
this.emit('message', session);
break;
}
Expand Down Expand Up @@ -186,7 +195,7 @@ test('spawn, connect, send a message, session info, and destroy', async () => {

await manager.connect();
expect(mockConstructor).toHaveBeenCalledWith(
expect.stringContaining('worker.js'),
expect.stringContaining('defaultWorker.js'),
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
);

Expand Down
12 changes: 9 additions & 3 deletions packages/ws/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
"module": "./dist/index.mjs",
"typings": "./dist/index.d.ts",
"exports": {
"import": "./dist/index.mjs",
"require": "./dist/index.js",
"types": "./dist/index.d.ts"
".": {
"import": "./dist/index.mjs",
"require": "./dist/index.js",
"types": "./dist/index.d.ts"
},
"./defaultWorker": {
"import": "./dist/defaultWorker.mjs",
"require": "./dist/defaultWorker.js"
}
},
"directories": {
"lib": "src",
Expand Down
1 change: 1 addition & 0 deletions packages/ws/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export * from './strategies/sharding/WorkerShardingStrategy.js';

export * from './utils/constants.js';
export * from './utils/IdentifyThrottler.js';
export * from './utils/WorkerBootstrapper.js';

export * from './ws/WebSocketManager.js';
export * from './ws/WebSocketShard.js';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ export class SimpleShardingStrategy implements IShardingStrategy {
*/
public async send(shardId: number, payload: GatewaySendPayload) {
const shard = this.shards.get(shardId);
if (!shard) throw new Error(`Shard ${shardId} not found`);
if (!shard) {
throw new RangeError(`Shard ${shardId} not found`);
}

return shard.send(payload);
}

Expand Down
99 changes: 77 additions & 22 deletions packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { once } from 'node:events';
import { join } from 'node:path';
import { join, isAbsolute, resolve } from 'node:path';
import { Worker } from 'node:worker_threads';
import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10';
Expand Down Expand Up @@ -38,6 +38,7 @@ export enum WorkerRecievePayloadOp {
UpdateSessionInfo,
WaitForIdentify,
FetchStatusResponse,
WorkerReady,
}

export type WorkerRecievePayload =
Expand All @@ -48,7 +49,8 @@ export type WorkerRecievePayload =
| { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify }
| { op: WorkerRecievePayloadOp.Connected; shardId: number }
| { op: WorkerRecievePayloadOp.Destroyed; shardId: number }
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number };
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number }
| { op: WorkerRecievePayloadOp.WorkerReady };

/**
* Options for a {@link WorkerShardingStrategy}
Expand All @@ -58,6 +60,10 @@ export interface WorkerShardingStrategyOptions {
* Dictates how many shards should be spawned per worker thread.
*/
shardsPerWorker: number | 'all';
/**
* Path to the worker file to use. The worker requires quite a bit of setup, it is recommended you leverage the {@link WorkerBootstrapper} class.
*/
workerPath?: string;
}

/**
Expand Down Expand Up @@ -93,32 +99,20 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const shardsPerWorker = this.options.shardsPerWorker === 'all' ? shardIds.length : this.options.shardsPerWorker;
const strategyOptions = await managerToFetchingStrategyOptions(this.manager);

let shards = 0;
while (shards !== shardIds.length) {
const slice = shardIds.slice(shards, shardsPerWorker + shards);
const loops = Math.ceil(shardIds.length / shardsPerWorker);
const promises: Promise<void>[] = [];

for (let idx = 0; idx < loops; idx++) {
const slice = shardIds.slice(idx * shardsPerWorker, (idx + 1) * shardsPerWorker);
const workerData: WorkerData = {
...strategyOptions,
shardIds: slice,
};

const worker = new Worker(join(__dirname, 'worker.js'), { workerData });
await once(worker, 'online');
worker
.on('error', (err) => {
throw err;
})
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload));

this.#workers.push(worker);
for (const shardId of slice) {
this.#workerByShardId.set(shardId, worker);
}

shards += slice.length;
promises.push(this.setupWorker(workerData));
}

await Promise.all(promises);
}

/**
Expand Down Expand Up @@ -210,6 +204,63 @@ export class WorkerShardingStrategy implements IShardingStrategy {
return statuses;
}

private async setupWorker(workerData: WorkerData) {
const worker = new Worker(this.resolveWorkerPath(), { workerData });

await once(worker, 'online');
// We do this in case the user has any potentially long running code in their worker
await this.waitForWorkerReady(worker);

worker
.on('error', (err) => {
throw err;
})
.on('messageerror', (err) => {
throw err;
})
.on('message', async (payload: WorkerRecievePayload) => this.onMessage(worker, payload));

this.#workers.push(worker);
for (const shardId of workerData.shardIds) {
this.#workerByShardId.set(shardId, worker);
}
}

private resolveWorkerPath(): string {
const path = this.options.workerPath;

if (!path) {
return join(__dirname, 'defaultWorker.js');
}

if (isAbsolute(path)) {
return path;
}

if (/^\.\.?[/\\]/.test(path)) {
return resolve(path);
}

try {
return require.resolve(path);
} catch {
return resolve(path);
}
}

private async waitForWorkerReady(worker: Worker): Promise<void> {
return new Promise((resolve) => {
const handler = (payload: WorkerRecievePayload) => {
if (payload.op === WorkerRecievePayloadOp.WorkerReady) {
resolve();
worker.off('message', handler);
}
};

worker.on('message', handler);
});
}

private async onMessage(worker: Worker, payload: WorkerRecievePayload) {
switch (payload.op) {
case WorkerRecievePayloadOp.Connected: {
Expand Down Expand Up @@ -260,6 +311,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
this.fetchStatusPromises.delete(payload.nonce);
break;
}

case WorkerRecievePayloadOp.WorkerReady: {
break;
}
}
}
}
4 changes: 4 additions & 0 deletions packages/ws/src/strategies/sharding/defaultWorker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import { WorkerBootstrapper } from '../../utils/WorkerBootstrapper.js';

const bootstrapper = new WorkerBootstrapper();
void bootstrapper.bootstrap();
Loading

2 comments on commit 828a13b

@vercel
Copy link

@vercel vercel bot commented on 828a13b Jan 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vercel
Copy link

@vercel vercel bot commented on 828a13b Jan 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.