Skip to content

Commit

Permalink
fix(NODE-5993): memory leak in the Connection class (#4022)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbbeeken authored Mar 7, 2024
1 parent 28b7040 commit 69de253
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 59 deletions.
68 changes: 25 additions & 43 deletions src/cmap/connection.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { type Readable, Transform, type TransformCallback } from 'stream';
import { clearTimeout, setTimeout } from 'timers';
import { promisify } from 'util';

import type { BSONSerializeOptions, Document, ObjectId } from '../bson';
import type { AutoEncrypter } from '../client-side-encryption/auto_encrypter';
Expand Down Expand Up @@ -37,7 +36,7 @@ import {
maxWireVersion,
type MongoDBNamespace,
now,
promiseWithResolvers,
once,
uuidV4
} from '../utils';
import type { WriteConcern } from '../write_concern';
Expand Down Expand Up @@ -182,18 +181,18 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
* Once connection is established, command logging can log events (if enabled)
*/
public established: boolean;
/** Indicates that the connection (including underlying TCP socket) has been closed. */
public closed = false;

private lastUseTime: number;
private clusterTime: Document | null = null;
private error: Error | null = null;
private dataEvents: AsyncGenerator<Buffer, void, void> | null = null;

private readonly socketTimeoutMS: number;
private readonly monitorCommands: boolean;
private readonly socket: Stream;
private readonly controller: AbortController;
private readonly signal: AbortSignal;
private readonly messageStream: Readable;
private readonly socketWrite: (buffer: Uint8Array) => Promise<void>;
private readonly aborted: Promise<never>;

/** @event */
static readonly COMMAND_STARTED = COMMAND_STARTED;
Expand All @@ -213,6 +212,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
constructor(stream: Stream, options: ConnectionOptions) {
super();

this.socket = stream;
this.id = options.id;
this.address = streamIdentifier(stream, options);
this.socketTimeoutMS = options.socketTimeoutMS ?? 0;
Expand All @@ -225,39 +225,12 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
this.generation = options.generation;
this.lastUseTime = now();

this.socket = stream;

// TODO: Remove signal from connection layer
this.controller = new AbortController();
const { signal } = this.controller;
this.signal = signal;
const { promise: aborted, reject } = promiseWithResolvers<never>();
aborted.then(undefined, () => null); // Prevent unhandled rejection
this.signal.addEventListener(
'abort',
function onAbort() {
reject(signal.reason);
},
{ once: true }
);
this.aborted = aborted;

this.messageStream = this.socket
.on('error', this.onError.bind(this))
.pipe(new SizedMessageTransform({ connection: this }))
.on('error', this.onError.bind(this));
this.socket.on('close', this.onClose.bind(this));
this.socket.on('timeout', this.onTimeout.bind(this));

const socketWrite = promisify(this.socket.write.bind(this.socket));
this.socketWrite = async buffer => {
return Promise.race([socketWrite(buffer), this.aborted]);
};
}

/** Indicates that the connection (including underlying TCP socket) has been closed. */
public get closed(): boolean {
return this.signal.aborted;
}

public get hello() {
Expand Down Expand Up @@ -308,7 +281,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
this.lastUseTime = now();
}

public onError(error?: Error) {
public onError(error: Error) {
this.cleanup(error);
}

Expand Down Expand Up @@ -351,13 +324,15 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
*
* This method does nothing if the connection is already closed.
*/
private cleanup(error?: Error): void {
private cleanup(error: Error): void {
if (this.closed) {
return;
}

this.socket.destroy();
this.controller.abort(error);
this.error = error;
this.dataEvents?.throw(error).then(undefined, () => null); // squash unhandled rejection
this.closed = true;
this.emit(Connection.CLOSE);
}

Expand Down Expand Up @@ -598,7 +573,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
}

private throwIfAborted() {
this.signal.throwIfAborted();
if (this.error) throw this.error;
}

/**
Expand All @@ -621,7 +596,8 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {

const buffer = Buffer.concat(await finalCommand.toBin());

return this.socketWrite(buffer);
if (this.socket.write(buffer)) return;
return once(this.socket, 'drain');
}

/**
Expand All @@ -634,13 +610,19 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
* Note that `for-await` loops call `return` automatically when the loop is exited.
*/
private async *readMany(): AsyncGenerator<OpMsgResponse | OpQueryResponse> {
for await (const message of onData(this.messageStream, { signal: this.signal })) {
const response = await decompressResponse(message);
yield response;
try {
this.dataEvents = onData(this.messageStream);
for await (const message of this.dataEvents) {
const response = await decompressResponse(message);
yield response;

if (!response.moreToCome) {
return;
if (!response.moreToCome) {
return;
}
}
} finally {
this.dataEvents = null;
this.throwIfAborted();
}
}
}
Expand Down
18 changes: 2 additions & 16 deletions src/cmap/wire_protocol/on_data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ type PendingPromises = Omit<
* https://nodejs.org/api/events.html#eventsonemitter-eventname-options
*
* Returns an AsyncIterator that iterates each 'data' event emitted from emitter.
* It will reject upon an error event or if the provided signal is aborted.
* It will reject upon an error event.
*/
export function onData(emitter: EventEmitter, options: { signal: AbortSignal }) {
const signal = options.signal;

export function onData(emitter: EventEmitter) {
// Setup pending events and pending promise lists
/**
* When the caller has not yet called .next(), we store the
Expand Down Expand Up @@ -89,19 +87,8 @@ export function onData(emitter: EventEmitter, options: { signal: AbortSignal })
emitter.on('data', eventHandler);
emitter.on('error', errorHandler);

if (signal.aborted) {
// If the signal is aborted, set up the first .next() call to be a rejection
queueMicrotask(abortListener);
} else {
signal.addEventListener('abort', abortListener, { once: true });
}

return iterator;

function abortListener() {
errorHandler(signal.reason);
}

function eventHandler(value: Buffer) {
const promise = unconsumedPromises.shift();
if (promise != null) promise.resolve({ value, done: false });
Expand All @@ -119,7 +106,6 @@ export function onData(emitter: EventEmitter, options: { signal: AbortSignal })
// Adding event handlers
emitter.off('data', eventHandler);
emitter.off('error', errorHandler);
signal.removeEventListener('abort', abortListener);
finished = true;
const doneResult = { value: undefined, done: finished } as const;

Expand Down
25 changes: 25 additions & 0 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import * as crypto from 'crypto';
import type { SrvRecord } from 'dns';
import { type EventEmitter } from 'events';
import * as http from 'http';
import { clearTimeout, setTimeout } from 'timers';
import * as url from 'url';
Expand Down Expand Up @@ -1295,3 +1296,27 @@ export function promiseWithResolvers<T>() {
}

export const randomBytes = promisify(crypto.randomBytes);

/**
* Replicates the events.once helper.
*
* Removes unused signal logic and It **only** supports 0 or 1 argument events.
*
* @param ee - An event emitter that may emit `ev`
* @param name - An event name to wait for
*/
export async function once<T>(ee: EventEmitter, name: string): Promise<T> {
const { promise, resolve, reject } = promiseWithResolvers<T>();
const onEvent = (data: T) => resolve(data);
const onError = (error: Error) => reject(error);

ee.once(name, onEvent).once('error', onError);
try {
const res = await promise;
ee.off('error', onError);
return res;
} catch (error) {
ee.off(name, onEvent);
throw error;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { expect } from 'chai';
import { type EventEmitter, once } from 'events';
import * as sinon from 'sinon';
import { setTimeout } from 'timers';

import {
addContainerMetadata,
Binary,
connect,
Connection,
type ConnectionOptions,
Expand All @@ -15,7 +19,9 @@ import {
ServerHeartbeatStartedEvent,
Topology
} from '../../mongodb';
import * as mock from '../../tools/mongodb-mock/index';
import { skipBrokenAuthTestBeforeEachHook } from '../../tools/runner/hooks/configuration';
import { getSymbolFrom, sleep } from '../../tools/utils';
import { assert as test, setupDatabase } from '../shared';

const commonConnectOptions = {
Expand Down Expand Up @@ -200,6 +206,84 @@ describe('Connection', function () {
client.connect();
});

context(
'when a large message is written to the socket',
{ requires: { topology: 'single', auth: 'disabled' } },
() => {
let client, mockServer: import('../../tools/mongodb-mock/src/server').MockServer;

beforeEach(async function () {
mockServer = await mock.createServer();

mockServer
.addMessageHandler('insert', req => {
setTimeout(() => {
req.reply({ ok: 1 });
}, 800);
})
.addMessageHandler('hello', req => {
req.reply(Object.assign({}, mock.HELLO));
})
.addMessageHandler(LEGACY_HELLO_COMMAND, req => {
req.reply(Object.assign({}, mock.HELLO));
});

client = new MongoClient(`mongodb://${mockServer.uri()}`, {
minPoolSize: 1,
maxPoolSize: 1
});
});

afterEach(async function () {
await client.close();
mockServer.destroy();
sinon.restore();
});

it('waits for an async drain event because the write was buffered', async () => {
const connectionReady = once(client, 'connectionReady');
await client.connect();
await connectionReady;

// Get the only connection
const pool = [...client.topology.s.servers.values()][0].pool;

const connections = pool[getSymbolFrom(pool, 'connections')];
expect(connections).to.have.lengthOf(1);

const connection = connections.first();
const socket: EventEmitter = connection.socket;

// Spy on the socket event listeners
const addedListeners: string[] = [];
const removedListeners: string[] = [];
socket
.on('removeListener', name => removedListeners.push(name))
.on('newListener', name => addedListeners.push(name));

// Make server sockets block
for (const s of mockServer.sockets) s.pause();

const insert = client
.db('test')
.collection('test')
// Anything above 16Kb should work I think (10mb to be extra sure)
.insertOne({ a: new Binary(Buffer.alloc(10 * (2 ** 10) ** 2), 127) });

// Sleep a bit and unblock server sockets
await sleep(10);
for (const s of mockServer.sockets) s.resume();

// Let the operation finish
await insert;

// Ensure that we used the drain event for this write
expect(addedListeners).to.deep.equal(['drain', 'error']);
expect(removedListeners).to.deep.equal(['drain', 'error']);
});
}
);

context('when connecting with a username and password', () => {
let utilClient: MongoClient;
let client: MongoClient;
Expand Down
33 changes: 33 additions & 0 deletions test/integration/node-specific/resource_clean_up.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import * as v8 from 'node:v8';

import { expect } from 'chai';

import { sleep } from '../../tools/utils';
import { runScript } from './resource_tracking_script_builder';

/**
Expand Down Expand Up @@ -86,4 +89,34 @@ describe('Driver Resources', () => {
});
});
});

context('when 100s of operations are executed and complete', () => {
beforeEach(function () {
if (this.currentTest && typeof v8.queryObjects !== 'function') {
this.currentTest.skipReason = 'Test requires v8.queryObjects API to count Promises';
this.currentTest?.skip();
}
});

let client;
beforeEach(async function () {
client = this.configuration.newClient();
});

afterEach(async function () {
await client.close();
});

it('does not leave behind additional promises', async () => {
const test = client.db('test').collection('test');
const promiseCountBefore = v8.queryObjects(Promise, { format: 'count' });
for (let i = 0; i < 100; i++) {
await test.findOne();
}
await sleep(10);
const promiseCountAfter = v8.queryObjects(Promise, { format: 'count' });

expect(promiseCountAfter).to.be.within(promiseCountBefore - 5, promiseCountBefore + 5);
});
});
});

0 comments on commit 69de253

Please sign in to comment.