diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 1213e158ad..e1ad9a0293 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -62,7 +62,11 @@ import type { ClientMetadata } from './handshake/client_metadata'; import { StreamDescription, type StreamDescriptionOptions } from './stream_description'; import { type CompressorName, decompressResponse } from './wire_protocol/compression'; import { onData } from './wire_protocol/on_data'; -import { MongoDBResponse, type MongoDBResponseConstructor } from './wire_protocol/responses'; +import { + isErrorResponse, + MongoDBResponse, + type MongoDBResponseConstructor +} from './wire_protocol/responses'; import { getReadPreference, isSharded } from './wire_protocol/shared'; /** @internal */ @@ -443,7 +447,12 @@ export class Connection extends TypedEventEmitter { this.socket.setTimeout(0); const bson = response.parse(); - const document = new (responseType ?? MongoDBResponse)(bson, 0, false); + const document = + responseType == null + ? new MongoDBResponse(bson) + : isErrorResponse(bson) + ? new MongoDBResponse(bson) + : new responseType(bson); yield document; this.throwIfAborted(); @@ -739,7 +748,7 @@ export class CryptoConnection extends Connection { ns: MongoDBNamespace, cmd: Document, options?: CommandOptions, - responseType?: T | undefined + _responseType?: T | undefined ): Promise { const { autoEncrypter } = this; if (!autoEncrypter) { @@ -753,7 +762,7 @@ export class CryptoConnection extends Connection { const serverWireVersion = maxWireVersion(this); if (serverWireVersion === 0) { // This means the initial handshake hasn't happened yet - return await super.command(ns, cmd, options, responseType); + return await super.command(ns, cmd, options, undefined); } if (serverWireVersion < 8) { @@ -787,7 +796,7 @@ export class CryptoConnection extends Connection { } } - const response = await super.command(ns, encrypted, options, responseType); + const response = await super.command(ns, encrypted, options, undefined); return await autoEncrypter.decrypt(response, options); } diff --git a/src/cmap/wire_protocol/on_demand/document.ts b/src/cmap/wire_protocol/on_demand/document.ts index ba8804fc6c..638946d647 100644 --- a/src/cmap/wire_protocol/on_demand/document.ts +++ b/src/cmap/wire_protocol/on_demand/document.ts @@ -58,7 +58,7 @@ export class OnDemandDocument { private readonly indexFound: Record = Object.create(null); /** All bson elements in this document */ - private readonly elements: BSONElement[]; + private readonly elements: ReadonlyArray; constructor( /** BSON bytes, this document begins at offset */ @@ -97,7 +97,7 @@ export class OnDemandDocument { * @param name - a basic latin string name of a BSON element * @returns */ - private getElement(name: string): CachedBSONElement | null { + private getElement(name: string | number): CachedBSONElement | null { const cachedElement = this.cache[name]; if (cachedElement === false) return null; @@ -105,6 +105,22 @@ export class OnDemandDocument { return cachedElement; } + if (typeof name === 'number') { + if (this.isArray) { + if (name < this.elements.length) { + const element = this.elements[name]; + const cachedElement = { element, value: undefined }; + this.cache[name] = cachedElement; + this.indexFound[name] = true; + return cachedElement; + } else { + return null; + } + } else { + return null; + } + } + for (let index = 0; index < this.elements.length; index++) { const element = this.elements[index]; @@ -197,6 +213,13 @@ export class OnDemandDocument { } } + /** + * Returns the number of elements in this BSON document + */ + public size() { + return this.elements.length; + } + /** * Checks for the existence of an element by name. * @@ -222,16 +245,20 @@ export class OnDemandDocument { * @param required - whether or not the element is expected to exist, if true this function will throw if it is not present */ public get( - name: string, + name: string | number, as: T, required?: false | undefined ): JSTypeOf[T] | null; /** `required` will make `get` throw if name does not exist or is null/undefined */ - public get(name: string, as: T, required: true): JSTypeOf[T]; + public get( + name: string | number, + as: T, + required: true + ): JSTypeOf[T]; public get( - name: string, + name: string | number, as: T, required?: boolean ): JSTypeOf[T] | null { @@ -303,21 +330,9 @@ export class OnDemandDocument { }); } - /** - * Iterates through the elements of a document reviving them using the `as` BSONType. - * - * @param as - The type to revive all elements as - */ - public *valuesAs(as: T): Generator { - if (!this.isArray) { - throw new BSONError('Unexpected conversion of non-array value to array'); - } - let counter = 0; - for (const element of this.elements) { - const value = this.toJSValue(element, as); - this.cache[counter] = { element, value }; - yield value; - counter += 1; - } + /** Returns this document's bytes only */ + toBytes() { + const size = getInt32LE(this.bson, this.offset); + return this.bson.subarray(this.offset, this.offset + size); } } diff --git a/src/cmap/wire_protocol/responses.ts b/src/cmap/wire_protocol/responses.ts index b776a4de56..65515cbb31 100644 --- a/src/cmap/wire_protocol/responses.ts +++ b/src/cmap/wire_protocol/responses.ts @@ -1,7 +1,62 @@ -import { type BSONSerializeOptions, BSONType, type Document, type Timestamp } from '../../bson'; +import { + type BSONSerializeOptions, + BSONType, + type Document, + Long, + parseToElementsToArray, + type Timestamp +} from '../../bson'; +import { MongoUnexpectedServerResponseError } from '../../error'; import { type ClusterTime } from '../../sdam/common'; +import { type MongoDBNamespace, ns } from '../../utils'; import { OnDemandDocument } from './on_demand/document'; +// eslint-disable-next-line no-restricted-syntax +const enum BSONElementOffset { + type = 0, + nameOffset = 1, + nameLength = 2, + offset = 3, + length = 4 +} +/** + * Accepts a BSON payload and checks for na "ok: 0" element. + * This utility is intended to prevent calling response class constructors + * that expect the result to be a success and demand certain properties to exist. + * + * For example, a cursor response always expects a cursor embedded document. + * In order to write the class such that the properties reflect that assertion (non-null) + * we cannot invoke the subclass constructor if the BSON represents an error. + * + * @param bytes - BSON document returned from the server + */ +export function isErrorResponse(bson: Uint8Array): boolean { + const elements = parseToElementsToArray(bson, 0); + for (let eIdx = 0; eIdx < elements.length; eIdx++) { + const element = elements[eIdx]; + + if (element[BSONElementOffset.nameLength] === 2) { + const nameOffset = element[BSONElementOffset.nameOffset]; + + // 111 == "o", 107 == "k" + if (bson[nameOffset] === 111 && bson[nameOffset + 1] === 107) { + const valueOffset = element[BSONElementOffset.offset]; + const valueLength = element[BSONElementOffset.length]; + + // If any byte in the length of the ok number (works for any type) is non zero, + // then it is considered "ok: 1" + for (let i = valueOffset; i < valueOffset + valueLength; i++) { + if (bson[i] !== 0x00) return false; + } + + return true; + } + } + } + + return true; +} + /** @internal */ export type MongoDBResponseConstructor = { new (bson: Uint8Array, offset?: number, isArray?: boolean): MongoDBResponse; @@ -9,6 +64,10 @@ export type MongoDBResponseConstructor = { /** @internal */ export class MongoDBResponse extends OnDemandDocument { + static is(value: unknown): value is MongoDBResponse { + return value instanceof MongoDBResponse; + } + // {ok:1} static empty = new MongoDBResponse(new Uint8Array([13, 0, 0, 0, 16, 111, 107, 0, 1, 0, 0, 0, 0])); @@ -83,27 +142,96 @@ export class MongoDBResponse extends OnDemandDocument { return this.clusterTime ?? null; } - public override toObject(options: BSONSerializeOptions = {}): Record { + public override toObject(options?: BSONSerializeOptions): Record { const exactBSONOptions = { - useBigInt64: options.useBigInt64, - promoteLongs: options.promoteLongs, - promoteValues: options.promoteValues, - promoteBuffers: options.promoteBuffers, - bsonRegExp: options.bsonRegExp, - raw: options.raw ?? false, - fieldsAsRaw: options.fieldsAsRaw ?? {}, + useBigInt64: options?.useBigInt64, + promoteLongs: options?.promoteLongs, + promoteValues: options?.promoteValues, + promoteBuffers: options?.promoteBuffers, + bsonRegExp: options?.bsonRegExp, + raw: options?.raw ?? false, + fieldsAsRaw: options?.fieldsAsRaw ?? {}, validation: this.parseBsonSerializationOptions(options) }; return super.toObject(exactBSONOptions); } - private parseBsonSerializationOptions({ enableUtf8Validation }: BSONSerializeOptions): { + private parseBsonSerializationOptions(options?: { enableUtf8Validation?: boolean }): { utf8: { writeErrors: false } | false; } { + const enableUtf8Validation = options?.enableUtf8Validation; if (enableUtf8Validation === false) { return { utf8: false }; } - return { utf8: { writeErrors: false } }; } } + +/** @internal */ +export class CursorResponse extends MongoDBResponse { + /** + * This supports a feature of the FindCursor. + * It is an optimization to avoid an extra getMore when the limit has been reached + */ + static emptyGetMore = { id: new Long(0), length: 0, shift: () => null }; + + static override is(value: unknown): value is CursorResponse { + return value instanceof CursorResponse || value === CursorResponse.emptyGetMore; + } + + public id: Long; + public ns: MongoDBNamespace | null = null; + public batchSize = 0; + + private batch: OnDemandDocument; + private iterated = 0; + + constructor(bytes: Uint8Array, offset?: number, isArray?: boolean) { + super(bytes, offset, isArray); + + const cursor = this.get('cursor', BSONType.object, true); + + const id = cursor.get('id', BSONType.long, true); + this.id = new Long(Number(id & 0xffff_ffffn), Number((id >> 32n) & 0xffff_ffffn)); + + const namespace = cursor.get('ns', BSONType.string); + if (namespace != null) this.ns = ns(namespace); + + if (cursor.has('firstBatch')) this.batch = cursor.get('firstBatch', BSONType.array, true); + else if (cursor.has('nextBatch')) this.batch = cursor.get('nextBatch', BSONType.array, true); + else throw new MongoUnexpectedServerResponseError('Cursor document did not contain a batch'); + + this.batchSize = this.batch.size(); + } + + get length() { + return Math.max(this.batchSize - this.iterated, 0); + } + + shift(options?: BSONSerializeOptions): any { + if (this.iterated >= this.batchSize) { + return null; + } + + const result = this.batch.get(this.iterated, BSONType.object, true) ?? null; + this.iterated += 1; + + if (options?.raw) { + return result.toBytes(); + } else { + return result.toObject(options); + } + } + + clear() { + this.iterated = this.batchSize; + } + + pushMany() { + throw new Error('pushMany Unsupported method'); + } + + push() { + throw new Error('push Unsupported method'); + } +} diff --git a/src/cursor/abstract_cursor.ts b/src/cursor/abstract_cursor.ts index 10aa5eea5f..c4f349500a 100644 --- a/src/cursor/abstract_cursor.ts +++ b/src/cursor/abstract_cursor.ts @@ -1,6 +1,7 @@ import { Readable, Transform } from 'stream'; import { type BSONSerializeOptions, type Document, Long, pluckBSONSerializeOptions } from '../bson'; +import { CursorResponse } from '../cmap/wire_protocol/responses'; import { type AnyError, MongoAPIError, @@ -144,7 +145,13 @@ export abstract class AbstractCursor< /** @internal */ [kNamespace]: MongoDBNamespace; /** @internal */ - [kDocuments]: List; + [kDocuments]: { + length: number; + shift(bsonOptions?: any): TSchema | null; + clear(): void; + pushMany(many: Iterable): void; + push(item: TSchema): void; + }; /** @internal */ [kClient]: MongoClient; /** @internal */ @@ -286,7 +293,7 @@ export abstract class AbstractCursor< const documentsToRead = Math.min(number ?? this[kDocuments].length, this[kDocuments].length); for (let count = 0; count < documentsToRead; count++) { - const document = this[kDocuments].shift(); + const document = this[kDocuments].shift(this[kOptions]); if (document != null) { bufferedDocs.push(document); } @@ -382,14 +389,7 @@ export abstract class AbstractCursor< return true; } - const doc = await next(this, { blocking: true, transform: false }); - - if (doc) { - this[kDocuments].unshift(doc); - return true; - } - - return false; + return await next(this, { blocking: true, transform: false, shift: false }); } /** Get the next available document from the cursor, returns null if no more documents are available. */ @@ -398,7 +398,7 @@ export abstract class AbstractCursor< throw new MongoCursorExhaustedError(); } - return await next(this, { blocking: true, transform: true }); + return await next(this, { blocking: true, transform: true, shift: true }); } /** @@ -409,7 +409,7 @@ export abstract class AbstractCursor< throw new MongoCursorExhaustedError(); } - return await next(this, { blocking: false, transform: true }); + return await next(this, { blocking: false, transform: true, shift: true }); } /** @@ -633,12 +633,13 @@ export abstract class AbstractCursor< protected abstract _initialize(session: ClientSession | undefined): Promise; /** @internal */ - async getMore(batchSize: number): Promise { + async getMore(batchSize: number, useCursorResponse = false): Promise { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const getMoreOperation = new GetMoreOperation(this[kNamespace], this[kId]!, this[kServer]!, { ...this[kOptions], session: this[kSession], - batchSize + batchSize, + useCursorResponse }); return await executeOperation(this[kClient], getMoreOperation); @@ -656,7 +657,11 @@ export abstract class AbstractCursor< const state = await this._initialize(this[kSession]); const response = state.response; this[kServer] = state.server; - if (response.cursor) { + if (CursorResponse.is(response)) { + this[kId] = response.id; + if (response.ns) this[kNamespace] = response.ns; + this[kDocuments] = response; + } else if (response.cursor) { // TODO(NODE-2674): Preserve int64 sent from MongoDB this[kId] = typeof response.cursor.id === 'number' @@ -713,13 +718,42 @@ async function next( cursor: AbstractCursor, { blocking, - transform + transform, + shift }: { blocking: boolean; transform: boolean; + shift: false; } -): Promise { +): Promise; + +async function next( + cursor: AbstractCursor, + { + blocking, + transform, + shift + }: { + blocking: boolean; + transform: boolean; + shift: true; + } +): Promise; + +async function next( + cursor: AbstractCursor, + { + blocking, + transform, + shift + }: { + blocking: boolean; + transform: boolean; + shift: boolean; + } +): Promise { if (cursor.closed) { + if (!shift) return false; return null; } @@ -730,7 +764,8 @@ async function next( } if (cursor[kDocuments].length !== 0) { - const doc = cursor[kDocuments].shift(); + if (!shift) return true; + const doc = cursor[kDocuments].shift(cursor[kOptions]); if (doc != null && transform && cursor[kTransform]) { try { @@ -754,6 +789,7 @@ async function next( // cleanupCursor should never throw, but if it does it indicates a bug in the driver // and we should surface the error await cleanupCursor(cursor, {}); + if (!shift) return false; return null; } @@ -762,8 +798,10 @@ async function next( try { const response = await cursor.getMore(batchSize); - - if (response) { + if (CursorResponse.is(response)) { + cursor[kId] = response.id; + cursor[kDocuments] = response; + } else if (response) { const cursorId = typeof response.cursor.id === 'number' ? Long.fromNumber(response.cursor.id) @@ -796,10 +834,12 @@ async function next( } if (cursor[kDocuments].length === 0 && blocking === false) { + if (!shift) return false; return null; } } while (!cursor.isDead || cursor[kDocuments].length !== 0); + if (!shift) return false; return null; } @@ -921,7 +961,7 @@ class ReadableCursorStream extends Readable { private _readNext() { // eslint-disable-next-line github/no-then - next(this._cursor, { blocking: true, transform: true }).then( + next(this._cursor, { blocking: true, transform: true, shift: true }).then( result => { if (result == null) { this.push(null); diff --git a/src/cursor/find_cursor.ts b/src/cursor/find_cursor.ts index b76af197e1..19c599fe8c 100644 --- a/src/cursor/find_cursor.ts +++ b/src/cursor/find_cursor.ts @@ -1,4 +1,5 @@ -import { type Document, Long } from '../bson'; +import { type Document } from '../bson'; +import { CursorResponse } from '../cmap/wire_protocol/responses'; import { MongoInvalidArgumentError, MongoTailableCursorError } from '../error'; import { type ExplainVerbosityLike } from '../explain'; import type { MongoClient } from '../mongo_client'; @@ -34,7 +35,7 @@ export class FindCursor extends AbstractCursor { /** @internal */ [kFilter]: Document; /** @internal */ - [kNumReturned]?: number; + [kNumReturned] = 0; /** @internal */ [kBuiltOptions]: FindOptions; @@ -69,7 +70,7 @@ export class FindCursor extends AbstractCursor { /** @internal */ async _initialize(session: ClientSession): Promise { - const findOperation = new FindOperation(undefined, this.namespace, this[kFilter], { + const findOperation = new FindOperation(this.namespace, this[kFilter], { ...this[kBuiltOptions], // NOTE: order matters here, we may need to refine this ...this.cursorOptions, session @@ -78,7 +79,9 @@ export class FindCursor extends AbstractCursor { const response = await executeOperation(this.client, findOperation); // the response is not a cursor when `explain` is enabled - this[kNumReturned] = response.cursor?.firstBatch?.length; + if (CursorResponse.is(response)) { + this[kNumReturned] = response.batchSize; + } // TODO: NODE-2882 return { server: findOperation.server, session, response }; @@ -107,14 +110,14 @@ export class FindCursor extends AbstractCursor { // instead, if we determine there are no more documents to request from the server, we preemptively // close the cursor } - return { cursor: { id: Long.ZERO, nextBatch: [] } }; + return CursorResponse.emptyGetMore; } } - const response = await super.getMore(batchSize); + const response = await super.getMore(batchSize, this.client.autoEncrypter ? false : true); // TODO: wrap this in some logic to prevent it from happening if we don't need this support - if (response) { - this[kNumReturned] = this[kNumReturned] + response.cursor.nextBatch.length; + if (CursorResponse.is(response)) { + this[kNumReturned] = this[kNumReturned] + response.batchSize; } return response; @@ -145,7 +148,7 @@ export class FindCursor extends AbstractCursor { async explain(verbosity?: ExplainVerbosityLike): Promise { return await executeOperation( this.client, - new FindOperation(undefined, this.namespace, this[kFilter], { + new FindOperation(this.namespace, this[kFilter], { ...this[kBuiltOptions], // NOTE: order matters here, we may need to refine this ...this.cursorOptions, explain: verbosity ?? true diff --git a/src/cursor/run_command_cursor.ts b/src/cursor/run_command_cursor.ts index 4f88dc2db5..553041492f 100644 --- a/src/cursor/run_command_cursor.ts +++ b/src/cursor/run_command_cursor.ts @@ -125,7 +125,8 @@ export class RunCommandCursor extends AbstractCursor { const getMoreOperation = new GetMoreOperation(this.namespace, this.id!, this.server!, { ...this.cursorOptions, session: this.session, - ...this.getMoreOptions + ...this.getMoreOptions, + useCursorResponse: false }); return await executeOperation(this.client, getMoreOperation); diff --git a/src/index.ts b/src/index.ts index bb83a774bf..812d045ba6 100644 --- a/src/index.ts +++ b/src/index.ts @@ -290,7 +290,11 @@ export type { ConnectionPoolMetrics } from './cmap/metrics'; export type { StreamDescription, StreamDescriptionOptions } from './cmap/stream_description'; export type { CompressorName } from './cmap/wire_protocol/compression'; export type { JSTypeOf, OnDemandDocument } from './cmap/wire_protocol/on_demand/document'; -export type { MongoDBResponse, MongoDBResponseConstructor } from './cmap/wire_protocol/responses'; +export type { + CursorResponse, + MongoDBResponse, + MongoDBResponseConstructor +} from './cmap/wire_protocol/responses'; export type { CollectionOptions, CollectionPrivate, ModifyResult } from './collection'; export type { COMMAND_FAILED, diff --git a/src/operations/execute_operation.ts b/src/operations/execute_operation.ts index 6e1b569a7d..4faf4fd95a 100644 --- a/src/operations/execute_operation.ts +++ b/src/operations/execute_operation.ts @@ -1,4 +1,5 @@ import type { Document } from '../bson'; +import { type CursorResponse } from '../cmap/wire_protocol/responses'; import { isRetryableReadError, isRetryableWriteError, @@ -44,7 +45,7 @@ export interface ExecutionResult { /** The session used for this operation, may be implicitly created */ session?: ClientSession; /** The raw server response for the operation */ - response: Document; + response: Document | CursorResponse; } /** diff --git a/src/operations/find.ts b/src/operations/find.ts index 3841142c4e..1c2ccdb1ca 100644 --- a/src/operations/find.ts +++ b/src/operations/find.ts @@ -1,5 +1,5 @@ import type { Document } from '../bson'; -import type { Collection } from '../collection'; +import { CursorResponse } from '../cmap/wire_protocol/responses'; import { MongoInvalidArgumentError } from '../error'; import { ReadConcern } from '../read_concern'; import type { Server } from '../sdam/server'; @@ -77,13 +77,8 @@ export class FindOperation extends CommandOperation { override options: FindOptions & { writeConcern?: never }; filter: Document; - constructor( - collection: Collection | undefined, - ns: MongoDBNamespace, - filter: Document = {}, - options: FindOptions = {} - ) { - super(collection, options); + constructor(ns: MongoDBNamespace, filter: Document = {}, options: FindOptions = {}) { + super(undefined, options); this.options = { ...options }; delete this.options.writeConcern; @@ -111,12 +106,17 @@ export class FindOperation extends CommandOperation { findCommand = decorateWithExplain(findCommand, this.explain); } - return await server.command(this.ns, findCommand, { - ...this.options, - ...this.bsonOptions, - documentsReturnedIn: 'firstBatch', - session - }); + return await server.command( + this.ns, + findCommand, + { + ...this.options, + ...this.bsonOptions, + documentsReturnedIn: 'firstBatch', + session + }, + this.explain ? undefined : CursorResponse + ); } } diff --git a/src/operations/get_more.ts b/src/operations/get_more.ts index ada371c956..05f54b0b57 100644 --- a/src/operations/get_more.ts +++ b/src/operations/get_more.ts @@ -1,4 +1,5 @@ import type { Document, Long } from '../bson'; +import { CursorResponse } from '../cmap/wire_protocol/responses'; import { MongoRuntimeError } from '../error'; import type { Server } from '../sdam/server'; import type { ClientSession } from '../sessions'; @@ -19,6 +20,8 @@ export interface GetMoreOptions extends OperationOptions { maxTimeMS?: number; /** TODO(NODE-4413): Address bug with maxAwaitTimeMS not being passed in from the cursor correctly */ maxAwaitTimeMS?: number; + + useCursorResponse: boolean; } /** @@ -96,7 +99,12 @@ export class GetMoreOperation extends AbstractOperation { ...this.options }; - return await server.command(this.ns, getMoreCmd, commandOptions); + return await server.command( + this.ns, + getMoreCmd, + commandOptions, + this.options.useCursorResponse ? CursorResponse : undefined + ); } } diff --git a/src/sdam/server.ts b/src/sdam/server.ts index 6dbc31df7d..8ea91815c6 100644 --- a/src/sdam/server.ts +++ b/src/sdam/server.ts @@ -7,6 +7,7 @@ import { type ConnectionPoolOptions } from '../cmap/connection_pool'; import { PoolClearedError } from '../cmap/errors'; +import { type MongoDBResponseConstructor } from '../cmap/wire_protocol/responses'; import { APM_EVENTS, CLOSED, @@ -262,11 +263,25 @@ export class Server extends TypedEventEmitter { } } - /** - * Execute a command - * @internal - */ - async command(ns: MongoDBNamespace, cmd: Document, options: CommandOptions): Promise { + public async command( + ns: MongoDBNamespace, + command: Document, + options: CommandOptions | undefined, + responseType: T | undefined + ): Promise>; + + public async command( + ns: MongoDBNamespace, + command: Document, + options?: CommandOptions + ): Promise; + + public async command( + ns: MongoDBNamespace, + cmd: Document, + options: CommandOptions, + responseType?: MongoDBResponseConstructor + ): Promise { if (ns.db == null || typeof ns === 'string') { throw new MongoInvalidArgumentError('Namespace must not be a string'); } @@ -308,7 +323,7 @@ export class Server extends TypedEventEmitter { try { try { - return await conn.command(ns, cmd, finalOptions); + return await conn.command(ns, cmd, finalOptions, responseType); } catch (commandError) { throw this.decorateCommandError(conn, cmd, finalOptions, commandError); } @@ -319,7 +334,7 @@ export class Server extends TypedEventEmitter { ) { await this.pool.reauthenticate(conn); try { - return await conn.command(ns, cmd, finalOptions); + return await conn.command(ns, cmd, finalOptions, responseType); } catch (commandError) { throw this.decorateCommandError(conn, cmd, finalOptions, commandError); } diff --git a/src/utils.ts b/src/utils.ts index bf34a3d519..57079b1f63 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -64,6 +64,19 @@ export const ByteUtils = { } }; +/** + * Returns true if value is a Uint8Array or a Buffer + * @param value - any value that may be a Uint8Array + */ +export function isUint8Array(value: unknown): value is Uint8Array { + return ( + value != null && + typeof value === 'object' && + Symbol.toStringTag in value && + value[Symbol.toStringTag] === 'Uint8Array' + ); +} + /** * Determines if a connection's address matches a user provided list * of domain wildcards. diff --git a/test/integration/client-side-encryption/driver.test.ts b/test/integration/client-side-encryption/driver.test.ts index d22033c890..342f5adf1c 100644 --- a/test/integration/client-side-encryption/driver.test.ts +++ b/test/integration/client-side-encryption/driver.test.ts @@ -297,4 +297,111 @@ describe('Client Side Encryption Functional', function () { }); }); }); + + describe( + 'when @@mdb.decorateDecryptionResult is set on autoEncrypter', + { requires: { clientSideEncryption: true, mongodb: '>=4.4' } }, + () => { + let client: MongoClient; + let encryptedClient: MongoClient; + + beforeEach(async function () { + client = this.configuration.newClient(); + + const encryptSchema = (keyId: unknown, bsonType: string) => ({ + encrypt: { + bsonType, + algorithm: 'AEAD_AES_256_CBC_HMAC_SHA_512-Random', + keyId: [keyId] + } + }); + + const kmsProviders = this.configuration.kmsProviders(crypto.randomBytes(96)); + + await client.connect(); + + const encryption = new ClientEncryption(client, { + keyVaultNamespace, + kmsProviders, + extraOptions: getEncryptExtraOptions() + }); + + const dataDb = client.db(dataDbName); + const keyVaultDb = client.db(keyVaultDbName); + + await dataDb.dropCollection(dataCollName).catch(() => null); + await keyVaultDb.dropCollection(keyVaultCollName).catch(() => null); + await keyVaultDb.createCollection(keyVaultCollName); + const dataKey = await encryption.createDataKey('local'); + + const $jsonSchema = { + bsonType: 'object', + properties: { + a: encryptSchema(dataKey, 'int'), + b: encryptSchema(dataKey, 'string'), + c: { + bsonType: 'object', + properties: { + d: { + encrypt: { + keyId: [dataKey], + algorithm: 'AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic', + bsonType: 'string' + } + } + } + } + } + }; + + await dataDb.createCollection(dataCollName, { + validator: { $jsonSchema } + }); + + encryptedClient = this.configuration.newClient( + {}, + { + autoEncryption: { + keyVaultNamespace, + kmsProviders, + extraOptions: getEncryptExtraOptions() + } + } + ); + + encryptedClient.autoEncrypter[Symbol.for('@@mdb.decorateDecryptionResult')] = true; + await encryptedClient.connect(); + }); + + afterEach(async function () { + await encryptedClient?.close(); + await client?.close(); + }); + + it('adds decrypted keys to result at @@mdb.decryptedKeys', async function () { + const coll = encryptedClient.db(dataDbName).collection(dataCollName); + + const data = { + _id: new BSON.ObjectId(), + a: 1, + b: 'abc', + c: { d: 'def' } + }; + + const result = await coll.insertOne(data); + const decrypted = await coll.findOne({ _id: result.insertedId }); + + expect(decrypted).to.deep.equal(data); + expect(decrypted) + .to.have.property(Symbol.for('@@mdb.decryptedKeys')) + .that.deep.equals(['a', 'b']); + + // Nested + expect(decrypted).to.have.property('c'); + expect(decrypted.c) + .to.have.property(Symbol.for('@@mdb.decryptedKeys')) + .that.deep.equals(['d']); + }); + } + ); }); diff --git a/test/integration/crud/abstract_operation.test.ts b/test/integration/crud/abstract_operation.test.ts index fcac3e6ffe..77c45890e2 100644 --- a/test/integration/crud/abstract_operation.test.ts +++ b/test/integration/crud/abstract_operation.test.ts @@ -102,7 +102,7 @@ describe('abstract operation', function () { correctCommandName: 'count' }, { - subclassCreator: () => new mongodb.FindOperation(collection, collection.fullNamespace), + subclassCreator: () => new mongodb.FindOperation(collection.fullNamespace), subclassType: mongodb.FindOperation, correctCommandName: 'find' }, diff --git a/test/unit/assorted/collations.test.js b/test/unit/assorted/collations.test.js index 0c597a2ce7..124d020afc 100644 --- a/test/unit/assorted/collations.test.js +++ b/test/unit/assorted/collations.test.js @@ -55,7 +55,7 @@ describe('Collation', function () { request.reply(primary[0]); } else if (doc.aggregate) { commandResult = doc; - request.reply({ ok: 1, cursor: { id: 0, firstBatch: [], ns: 'collation_test' } }); + request.reply({ ok: 1, cursor: { id: 0n, firstBatch: [], ns: 'collation_test' } }); } else if (doc.endSessions) { request.reply({ ok: 1 }); } @@ -183,7 +183,7 @@ describe('Collation', function () { request.reply(primary[0]); } else if (doc.find) { commandResult = doc; - request.reply({ ok: 1, cursor: { id: 0, firstBatch: [] } }); + request.reply({ ok: 1, cursor: { id: 0n, firstBatch: [] } }); } else if (doc.endSessions) { request.reply({ ok: 1 }); } @@ -215,7 +215,7 @@ describe('Collation', function () { request.reply(primary[0]); } else if (doc.find) { commandResult = doc; - request.reply({ ok: 1, cursor: { id: 0, firstBatch: [] } }); + request.reply({ ok: 1, cursor: { id: 0n, firstBatch: [] } }); } else if (doc.endSessions) { request.reply({ ok: 1 }); } @@ -249,7 +249,7 @@ describe('Collation', function () { request.reply(primary[0]); } else if (doc.find) { commandResult = doc; - request.reply({ ok: 1, cursor: { id: 0, firstBatch: [] } }); + request.reply({ ok: 1, cursor: { id: 0n, firstBatch: [] } }); } else if (doc.endSessions) { request.reply({ ok: 1 }); } diff --git a/test/unit/assorted/sessions_collection.test.js b/test/unit/assorted/sessions_collection.test.js index eee1a76ec9..409d818d13 100644 --- a/test/unit/assorted/sessions_collection.test.js +++ b/test/unit/assorted/sessions_collection.test.js @@ -27,7 +27,7 @@ describe('Sessions - unit/sessions', function () { request.reply({ ok: 1, operationTime: insertOperationTime }); } else if (doc.find) { findCommand = doc; - request.reply({ ok: 1, cursor: { id: 0, firstBatch: [] } }); + request.reply({ ok: 1, cursor: { id: 0n, firstBatch: [] } }); } else if (doc.endSessions) { request.reply({ ok: 1 }); } diff --git a/test/unit/cmap/wire_protocol/on_demand/document.test.ts b/test/unit/cmap/wire_protocol/on_demand/document.test.ts index 6a7d5bb10c..82ed4040f6 100644 --- a/test/unit/cmap/wire_protocol/on_demand/document.test.ts +++ b/test/unit/cmap/wire_protocol/on_demand/document.test.ts @@ -73,6 +73,7 @@ describe('class OnDemandDocument', () => { context('get()', () => { let document: OnDemandDocument; + let array: OnDemandDocument; const input = { int: 1, double: 1.2, @@ -86,12 +87,27 @@ describe('class OnDemandDocument', () => { date: new Date(0), object: { a: 1 }, array: [1, 2], - unsupportedType: /abc/ + unsupportedType: /abc/, + [233]: 3 }; beforeEach(async function () { const bytes = BSON.serialize(input); document = new OnDemandDocument(bytes); + array = new OnDemandDocument( + BSON.serialize(Object.fromEntries(Object.values(input).entries())), + 0, + true + ); + }); + + it('supports access by number for arrays', () => { + expect(array.get(1, BSONType.int)).to.equal(1); + }); + + it('does not support access by number for objects', () => { + expect(document.get(233, BSONType.int)).to.be.null; + expect(document.get('233', BSONType.int)).to.equal(3); }); it('returns null if the element does not exist', () => { @@ -277,39 +293,4 @@ describe('class OnDemandDocument', () => { expect(document.getNumber('boolTrue')).to.equal(1); }); }); - - context('*valuesAs()', () => { - let array: OnDemandDocument; - beforeEach(async function () { - const bytes = BSON.serialize( - Object.fromEntries(Array.from({ length: 10 }, () => 1).entries()) - ); - array = new OnDemandDocument(bytes, 0, true); - }); - - it('throws if document is not an array', () => { - const bytes = BSON.serialize( - Object.fromEntries(Array.from({ length: 10 }, () => 1).entries()) - ); - array = new OnDemandDocument(bytes, 0, false); - expect(() => array.valuesAs(BSONType.int).next()).to.throw(); - }); - - it('returns a generator that yields values matching the as BSONType parameter', () => { - let didRun = false; - for (const item of array.valuesAs(BSONType.int)) { - didRun = true; - expect(item).to.equal(1); - } - expect(didRun).to.be.true; - }); - - it('caches the results of array', () => { - const generator = array.valuesAs(BSONType.int); - generator.next(); - generator.next(); - expect(array).to.have.nested.property('cache.0.value', 1); - expect(array).to.have.nested.property('cache.1.value', 1); - }); - }); }); diff --git a/test/unit/cmap/wire_protocol/responses.test.ts b/test/unit/cmap/wire_protocol/responses.test.ts index fc5ee88ae1..91e052da84 100644 --- a/test/unit/cmap/wire_protocol/responses.test.ts +++ b/test/unit/cmap/wire_protocol/responses.test.ts @@ -1,7 +1,15 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; -import { BSON, MongoDBResponse, OnDemandDocument } from '../../../mongodb'; +import { + BSON, + BSONError, + CursorResponse, + Int32, + MongoDBResponse, + MongoUnexpectedServerResponseError, + OnDemandDocument +} from '../../../mongodb'; describe('class MongoDBResponse', () => { it('is a subclass of OnDemandDocument', () => { @@ -76,3 +84,90 @@ describe('class MongoDBResponse', () => { }); }); }); + +describe('class CursorResponse', () => { + describe('constructor()', () => { + it('throws if input does not contain cursor embedded document', () => { + expect(() => new CursorResponse(BSON.serialize({ ok: 1 }))).to.throw(BSONError); + }); + + it('throws if input does not contain cursor.id int64', () => { + expect(() => new CursorResponse(BSON.serialize({ ok: 1, cursor: {} }))).to.throw(BSONError); + }); + + it('sets namespace to null if input does not contain cursor.ns', () => { + expect(new CursorResponse(BSON.serialize({ ok: 1, cursor: { id: 0n, firstBatch: [] } })).ns) + .to.be.null; + }); + + it('throws if input does not contain firstBatch nor nextBatch', () => { + expect( + () => new CursorResponse(BSON.serialize({ ok: 1, cursor: { id: 0n, batch: [] } })) + ).to.throw(MongoUnexpectedServerResponseError); + }); + + it('reports a length equal to the batch', () => { + expect( + new CursorResponse(BSON.serialize({ ok: 1, cursor: { id: 0n, nextBatch: [1, 2, 3] } })) + ).to.have.lengthOf(3); + }); + }); + + describe('shift()', () => { + let response; + + beforeEach(async function () { + response = new CursorResponse( + BSON.serialize({ + ok: 1, + cursor: { id: 0n, nextBatch: [{ _id: 1 }, { _id: 2 }, { _id: 3 }] } + }) + ); + }); + + it('returns a document from the batch', () => { + expect(response.shift()).to.deep.equal({ _id: 1 }); + expect(response.shift()).to.deep.equal({ _id: 2 }); + expect(response.shift()).to.deep.equal({ _id: 3 }); + expect(response.shift()).to.deep.equal(null); + }); + + it('passes BSON options to deserialization', () => { + expect(response.shift({ promoteValues: false })).to.deep.equal({ _id: new Int32(1) }); + expect(response.shift({ promoteValues: true })).to.deep.equal({ _id: 2 }); + expect(response.shift({ promoteValues: false })).to.deep.equal({ _id: new Int32(3) }); + expect(response.shift()).to.deep.equal(null); + }); + }); + + describe('clear()', () => { + let response; + + beforeEach(async function () { + response = new CursorResponse( + BSON.serialize({ + ok: 1, + cursor: { id: 0n, nextBatch: [{ _id: 1 }, { _id: 2 }, { _id: 3 }] } + }) + ); + }); + + it('makes length equal to 0', () => { + expect(response.clear()).to.be.undefined; + expect(response).to.have.lengthOf(0); + }); + + it('makes shift return null', () => { + expect(response.clear()).to.be.undefined; + expect(response.shift()).to.be.null; + }); + }); + + describe('pushMany()', () => + it('throws unsupported error', () => + expect(CursorResponse.prototype.pushMany).to.throw(/Unsupported/i))); + + describe('push()', () => + it('throws unsupported error', () => + expect(CursorResponse.prototype.push).to.throw(/Unsupported/i))); +}); diff --git a/test/unit/operations/find.test.ts b/test/unit/operations/find.test.ts index bfb67d8e81..f208636d23 100644 --- a/test/unit/operations/find.test.ts +++ b/test/unit/operations/find.test.ts @@ -18,7 +18,7 @@ describe('FindOperation', function () { }); describe('#constructor', function () { - const operation = new FindOperation(undefined, namespace, filter, options); + const operation = new FindOperation(namespace, filter, options); it('sets the namespace', function () { expect(operation.ns).to.deep.equal(namespace); @@ -40,7 +40,7 @@ describe('FindOperation', function () { const server = new Server(topology, new ServerDescription('a:1'), {} as any); it('should build basic find command with filter', async () => { - const findOperation = new FindOperation(undefined, namespace, filter); + const findOperation = new FindOperation(namespace, filter); const stub = sinon.stub(server, 'command').resolves({}); await findOperation.execute(server, undefined); expect(stub).to.have.been.calledOnceWith(namespace, { @@ -53,7 +53,7 @@ describe('FindOperation', function () { const options = { oplogReplay: true }; - const findOperation = new FindOperation(undefined, namespace, {}, options); + const findOperation = new FindOperation(namespace, {}, options); const stub = sinon.stub(server, 'command').resolves({}); await findOperation.execute(server, undefined); expect(stub).to.have.been.calledOnceWith( diff --git a/test/unit/utils.test.ts b/test/unit/utils.test.ts index 802b9bc564..0184d44c5b 100644 --- a/test/unit/utils.test.ts +++ b/test/unit/utils.test.ts @@ -7,6 +7,7 @@ import { HostAddress, hostMatchesWildcards, isHello, + isUint8Array, LEGACY_HELLO_COMMAND, List, matchesParentDomain, @@ -981,4 +982,29 @@ describe('driver utils', function () { }); }); }); + + describe('isUint8Array()', () => { + describe('when given a UintArray', () => + it('returns true', () => expect(isUint8Array(Uint8Array.from([1]))).to.be.true)); + + describe('when given a Buffer', () => + it('returns true', () => expect(isUint8Array(Buffer.from([1]))).to.be.true)); + + describe('when given a value that does not have `Uint8Array` at Symbol.toStringTag', () => { + it('returns false', () => { + const weirdArray = Uint8Array.from([1]); + Object.defineProperty(weirdArray, Symbol.toStringTag, { value: 'blah' }); + expect(isUint8Array(weirdArray)).to.be.false; + }); + }); + + describe('when given null', () => + it('returns false', () => expect(isUint8Array(null)).to.be.false)); + + describe('when given a non object', () => + it('returns false', () => expect(isUint8Array('')).to.be.false)); + + describe('when given an object that does not respond to Symbol.toStringTag', () => + it('returns false', () => expect(isUint8Array(Object.create(null))).to.be.false)); + }); });