diff --git a/src/mongo_client.ts b/src/mongo_client.ts index 184df3a3bc..b35d2811bf 100644 --- a/src/mongo_client.ts +++ b/src/mongo_client.ts @@ -597,7 +597,14 @@ export class MongoClient extends TypedEventEmitter { return client.connect(); } - /** Starts a new session on the server */ + /** + * Creates a new ClientSession. When using the returned session in an operation + * a corresponding ServerSession will be created. + * + * @remarks + * A ClientSession instance may only be passed to operations being performed on the same + * MongoClient it was started from. + */ startSession(options?: ClientSessionOptions): ClientSession { const session = new ClientSession( this, diff --git a/src/operations/execute_operation.ts b/src/operations/execute_operation.ts index 50f77648f9..e815f091a8 100644 --- a/src/operations/execute_operation.ts +++ b/src/operations/execute_operation.ts @@ -7,6 +7,7 @@ import { MongoError, MongoErrorLabel, MongoExpiredSessionError, + MongoInvalidArgumentError, MongoNetworkError, MongoNotConnectedError, MongoRuntimeError, @@ -118,6 +119,8 @@ async function executeOperationAsync< throw new MongoExpiredSessionError('Use of expired sessions is not permitted'); } else if (session.snapshotEnabled && !topology.capabilities.supportsSnapshotReads) { throw new MongoCompatibilityError('Snapshot reads require MongoDB 5.0 or later'); + } else if (session.client !== client) { + throw new MongoInvalidArgumentError('ClientSession must be from the same MongoClient'); } const readPreference = operation.readPreference ?? ReadPreference.primary; diff --git a/test/integration/sessions/sessions.prose.test.ts b/test/integration/sessions/sessions.prose.test.ts index ee851e5f31..3fc41a5de4 100644 --- a/test/integration/sessions/sessions.prose.test.ts +++ b/test/integration/sessions/sessions.prose.test.ts @@ -6,11 +6,64 @@ import { type Collection, type CommandStartedEvent, MongoClient, - MongoDriverError + MongoDriverError, + MongoInvalidArgumentError } from '../../mongodb'; import { sleep } from '../../tools/utils'; describe('Sessions Prose Tests', () => { + describe('5. Session argument is for the right client', () => { + let client1: MongoClient; + let client2: MongoClient; + beforeEach(async function () { + client1 = this.configuration.newClient(); + client2 = this.configuration.newClient(); + }); + + afterEach(async function () { + await client1?.close(); + await client2?.close(); + }); + + /** + * Steps: + * - Create client1 and client2 + * - Get database from client1 + * - Get collection from database + * - Start session from client2 + * - Call collection.insertOne(session,...) + * - Assert that an error was reported because session was not started from client1 + * + * This validation lives in our executeOperation layer so it applies universally. + * A find and an insert provide enough coverage, we determined we do not need to enumerate every possible operation. + */ + context( + 'when session is started from a different client than operation is being run on', + () => { + it('insertOne operation throws a MongoInvalidArgumentError', async () => { + const db = client1.db(); + const collection = db.collection('test'); + const session = client2.startSession(); + const error = await collection.insertOne({}, { session }).catch(error => error); + expect(error).to.be.instanceOf(MongoInvalidArgumentError); + expect(error).to.match(/ClientSession must be from the same MongoClient/i); + }); + + it('find operation throws a MongoInvalidArgumentError', async () => { + const db = client1.db(); + const collection = db.collection('test'); + const session = client2.startSession(); + const error = await collection + .find({}, { session }) + .toArray() + .catch(error => error); + expect(error).to.be.instanceOf(MongoInvalidArgumentError); + expect(error).to.match(/ClientSession must be from the same MongoClient/i); + }); + } + ); + }); + describe('14. Implicit sessions only allocate their server session after a successful connection checkout', () => { let client: MongoClient; let testCollection: Collection<{ _id: number; a?: number }>; diff --git a/test/integration/sessions/sessions.test.ts b/test/integration/sessions/sessions.test.ts index a9629a47e5..3bb00e181c 100644 --- a/test/integration/sessions/sessions.test.ts +++ b/test/integration/sessions/sessions.test.ts @@ -1,4 +1,5 @@ import { expect } from 'chai'; +import { MongoClient as LegacyMongoClient } from 'mongodb-legacy'; import type { CommandStartedEvent, CommandSucceededEvent, MongoClient } from '../../mongodb'; import { LEGACY_HELLO_COMMAND, MongoServerError } from '../../mongodb'; @@ -422,4 +423,31 @@ describe('Sessions Spec', function () { expect(new Set(events.map(ev => ev.command.lsid.id.toString('hex'))).size).to.equal(2); }); }); + + context('when using a LegacyMongoClient', () => { + let legacyClient; + beforeEach(async function () { + const options = this.configuration.serverApi + ? { serverApi: this.configuration.serverApi } + : {}; + legacyClient = new LegacyMongoClient(this.configuration.url(), options); + }); + + afterEach(async function () { + await legacyClient?.close(); + }); + + it('insertOne accepts session started by legacy client', async () => { + const db = legacyClient.db(); + const collection = db.collection('test'); + const session = legacyClient.startSession(); + const error = await collection.insertOne({}, { session }).catch(error => error); + expect(error).to.not.be.instanceOf(Error); + }); + + it('session returned by legacy startSession has reference to legacyClient', async () => { + const session = legacyClient.startSession(); + expect(session).to.have.property('client', legacyClient); + }); + }); }); diff --git a/test/tools/runner/config.ts b/test/tools/runner/config.ts index 2f190fe3b1..375b909a74 100644 --- a/test/tools/runner/config.ts +++ b/test/tools/runner/config.ts @@ -7,6 +7,7 @@ import { type AuthMechanism, HostAddress, MongoClient, + type ServerApi, TopologyType, type WriteConcernSettings } from '../../mongodb'; @@ -71,7 +72,7 @@ export class TestConfiguration { auth?: { username: string; password: string; authSource?: string }; proxyURIParams?: ProxyParams; }; - serverApi: string; + serverApi: ServerApi; constructor(uri: string, context: Record) { const url = new ConnectionString(uri);