diff --git a/src/cmap/auth/mongodb_aws.ts b/src/cmap/auth/mongodb_aws.ts index 57e3a028ff..cfaf8e6f9b 100644 --- a/src/cmap/auth/mongodb_aws.ts +++ b/src/cmap/auth/mongodb_aws.ts @@ -1,4 +1,5 @@ import * as crypto from 'crypto'; +import * as process from 'process'; import { promisify } from 'util'; import type { Binary, BSONSerializeOptions } from '../../bson'; @@ -15,6 +16,28 @@ import { type AuthContext, AuthProvider } from './auth_provider'; import { MongoCredentials } from './mongo_credentials'; import { AuthMechanism } from './providers'; +/** + * The following regions use the global AWS STS endpoint, sts.amazonaws.com, by default + * https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html + */ +const LEGACY_REGIONS = new Set([ + 'ap-northeast-1', + 'ap-south-1', + 'ap-southeast-1', + 'ap-southeast-2', + 'aws-global', + 'ca-central-1', + 'eu-central-1', + 'eu-north-1', + 'eu-west-1', + 'eu-west-2', + 'eu-west-3', + 'sa-east-1', + 'us-east-1', + 'us-east-2', + 'us-west-1', + 'us-west-2' +]); const ASCII_N = 110; const AWS_RELATIVE_URI = 'http://169.254.170.2'; const AWS_EC2_URI = 'http://169.254.169.254'; @@ -34,6 +57,7 @@ interface AWSSaslContinuePayload { } export class MongoDBAWS extends AuthProvider { + static credentialProvider: ReturnType | null = null; randomBytesAsync: (size: number) => Promise; constructor() { @@ -157,14 +181,6 @@ interface AWSTempCredentials { Expiration?: Date; } -/* @internal */ -export interface AWSCredentials { - accessKeyId?: string; - secretAccessKey?: string; - sessionToken?: string; - expiration?: Date; -} - async function makeTempCredentials(credentials: MongoCredentials): Promise { function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) { if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) { @@ -182,11 +198,11 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise Promise; fromNodeProviderChain(this: void): () => Promise; }; diff --git a/test/integration/auth/mongodb_aws.test.ts b/test/integration/auth/mongodb_aws.test.ts index 81eb0373c5..287c9b50c6 100644 --- a/test/integration/auth/mongodb_aws.test.ts +++ b/test/integration/auth/mongodb_aws.test.ts @@ -3,7 +3,7 @@ import * as http from 'http'; import { performance } from 'perf_hooks'; import * as sinon from 'sinon'; -import { MongoAWSError, type MongoClient, MongoServerError } from '../../mongodb'; +import { MongoAWSError, type MongoClient, MongoDBAWS, MongoServerError } from '../../mongodb'; describe('MONGODB-AWS', function () { let client: MongoClient; @@ -88,4 +88,152 @@ describe('MONGODB-AWS', function () { expect(timeTaken).to.be.below(12000); }); }); + + describe('when using AssumeRoleWithWebIdentity', () => { + const tests = [ + { + ctx: 'when no AWS region settings are set', + title: 'uses the default region', + env: { + AWS_STS_REGIONAL_ENDPOINTS: undefined, + AWS_REGION: undefined + }, + calledWith: [] + }, + { + ctx: 'when only AWS_STS_REGIONAL_ENDPOINTS is set', + title: 'uses the default region', + env: { + AWS_STS_REGIONAL_ENDPOINTS: 'regional', + AWS_REGION: undefined + }, + calledWith: [] + }, + { + ctx: 'when only AWS_REGION is set', + title: 'uses the default region', + env: { + AWS_STS_REGIONAL_ENDPOINTS: undefined, + AWS_REGION: 'us-west-2' + }, + calledWith: [] + }, + + { + ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to regional and region is legacy', + title: 'uses the region from the environment', + env: { + AWS_STS_REGIONAL_ENDPOINTS: 'regional', + AWS_REGION: 'us-west-2' + }, + calledWith: [{ clientConfig: { region: 'us-west-2' } }] + }, + { + ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to regional and region is new', + title: 'uses the region from the environment', + env: { + AWS_STS_REGIONAL_ENDPOINTS: 'regional', + AWS_REGION: 'sa-east-1' + }, + calledWith: [{ clientConfig: { region: 'sa-east-1' } }] + }, + + { + ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to legacy and region is legacy', + title: 'uses the region from the environment', + env: { + AWS_STS_REGIONAL_ENDPOINTS: 'legacy', + AWS_REGION: 'us-west-2' + }, + calledWith: [] + }, + { + ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to legacy and region is new', + title: 'uses the default region', + env: { + AWS_STS_REGIONAL_ENDPOINTS: 'legacy', + AWS_REGION: 'sa-east-1' + }, + calledWith: [] + } + ]; + + for (const test of tests) { + context(test.ctx, () => { + let credentialProvider; + let storedEnv; + let calledArguments; + let shouldSkip = false; + + const envCheck = () => { + const { AWS_WEB_IDENTITY_TOKEN_FILE = '' } = process.env; + credentialProvider = (() => { + try { + return require('@aws-sdk/credential-providers'); + } catch { + return null; + } + })(); + return AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 || credentialProvider == null; + }; + + beforeEach(function () { + shouldSkip = envCheck(); + if (shouldSkip) { + this.skipReason = 'only relevant to AssumeRoleWithWebIdentity with SDK installed'; + return this.skip(); + } + + client = this.configuration.newClient(process.env.MONGODB_URI); + + storedEnv = process.env; + if (test.env.AWS_STS_REGIONAL_ENDPOINTS === undefined) { + delete process.env.AWS_STS_REGIONAL_ENDPOINTS; + } else { + process.env.AWS_STS_REGIONAL_ENDPOINTS = test.env.AWS_STS_REGIONAL_ENDPOINTS; + } + if (test.env.AWS_REGION === undefined) { + delete process.env.AWS_REGION; + } else { + process.env.AWS_REGION = test.env.AWS_REGION; + } + + calledArguments = []; + MongoDBAWS.credentialProvider = { + fromNodeProviderChain(...args) { + calledArguments = args; + return credentialProvider.fromNodeProviderChain(...args); + } + }; + }); + + afterEach(() => { + if (shouldSkip) { + return; + } + if (typeof storedEnv.AWS_STS_REGIONAL_ENDPOINTS === 'string') { + process.env.AWS_STS_REGIONAL_ENDPOINTS = storedEnv.AWS_STS_REGIONAL_ENDPOINTS; + } + if (typeof storedEnv.AWS_STS_REGIONAL_ENDPOINTS === 'string') { + process.env.AWS_REGION = storedEnv.AWS_REGION; + } + MongoDBAWS.credentialProvider = credentialProvider; + calledArguments = []; + }); + + it(test.title, async function () { + const result = await client + .db('aws') + .collection('aws_test') + .estimatedDocumentCount() + .catch(error => error); + + expect(result).to.not.be.instanceOf(MongoServerError); + expect(result).to.be.a('number'); + + expect(calledArguments).to.deep.equal(test.calledWith); + }); + }); + } + }); });