Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(NODE-5550): set AWS region from environment variable for STSClient #3851

Merged
merged 2 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 52 additions & 12 deletions src/cmap/auth/mongodb_aws.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as crypto from 'crypto';
import * as process from 'process';
import { promisify } from 'util';

import type { Binary, BSONSerializeOptions } from '../../bson';
Expand All @@ -15,6 +16,28 @@ import { type AuthContext, AuthProvider } from './auth_provider';
import { MongoCredentials } from './mongo_credentials';
import { AuthMechanism } from './providers';

/**
* The following regions use the global AWS STS endpoint, sts.amazonaws.com, by default
* https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html
*/
const LEGACY_REGIONS = new Set([
'ap-northeast-1',
'ap-south-1',
'ap-southeast-1',
'ap-southeast-2',
'aws-global',
'ca-central-1',
'eu-central-1',
'eu-north-1',
'eu-west-1',
'eu-west-2',
'eu-west-3',
'sa-east-1',
'us-east-1',
'us-east-2',
'us-west-1',
'us-west-2'
]);
const ASCII_N = 110;
const AWS_RELATIVE_URI = 'http://169.254.170.2';
const AWS_EC2_URI = 'http://169.254.169.254';
Expand All @@ -34,6 +57,7 @@ interface AWSSaslContinuePayload {
}

export class MongoDBAWS extends AuthProvider {
static credentialProvider: ReturnType<typeof getAwsCredentialProvider> | null = null;
randomBytesAsync: (size: number) => Promise<Buffer>;

constructor() {
Expand Down Expand Up @@ -157,14 +181,6 @@ interface AWSTempCredentials {
Expiration?: Date;
}

/* @internal */
export interface AWSCredentials {
accessKeyId?: string;
secretAccessKey?: string;
sessionToken?: string;
expiration?: Date;
}

async function makeTempCredentials(credentials: MongoCredentials): Promise<MongoCredentials> {
function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) {
if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) {
Expand All @@ -182,11 +198,11 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
});
}

const credentialProvider = getAwsCredentialProvider();
MongoDBAWS.credentialProvider ??= getAwsCredentialProvider();

// Check if the AWS credential provider from the SDK is present. If not,
// use the old method.
if ('kModuleError' in credentialProvider) {
if ('kModuleError' in MongoDBAWS.credentialProvider) {
// If the environment variable AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
// is set then drivers MUST assume that it was set by an AWS ECS agent
if (process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI) {
Expand Down Expand Up @@ -217,6 +233,32 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo

return makeMongoCredentialsFromAWSTemp(creds);
} else {
let { AWS_STS_REGIONAL_ENDPOINTS = '', AWS_REGION = '' } = process.env;
AWS_STS_REGIONAL_ENDPOINTS = AWS_STS_REGIONAL_ENDPOINTS.toLowerCase();
AWS_REGION = AWS_REGION.toLowerCase();

/** The option setting should work only for users who have explicit settings in their environment, the driver should not encode "defaults" */
const awsRegionSettingsExist =
AWS_REGION.length !== 0 && AWS_STS_REGIONAL_ENDPOINTS.length !== 0;

/**
* If AWS_STS_REGIONAL_ENDPOINTS is set to regional, users are opting into the new behavior of respecting the region settings
*
* If AWS_STS_REGIONAL_ENDPOINTS is set to legacy, then "old" regions need to keep using the global setting.
* Technically the SDK gets this wrong, it reaches out to 'sts.us-east-1.amazonaws.com' when it should be 'sts.amazonaws.com'.
* That is not our bug to fix here. We leave that up to the SDK.
*/
const useRegionalSts =
AWS_STS_REGIONAL_ENDPOINTS === 'regional' ||
(AWS_STS_REGIONAL_ENDPOINTS === 'legacy' && !LEGACY_REGIONS.has(AWS_REGION));

const provider =
awsRegionSettingsExist && useRegionalSts
? MongoDBAWS.credentialProvider.fromNodeProviderChain({
clientConfig: { region: AWS_REGION }
})
: MongoDBAWS.credentialProvider.fromNodeProviderChain();

/*
* Creates a credential provider that will attempt to find credentials from the
* following sources (listed in order of precedence):
Expand All @@ -227,8 +269,6 @@ async function makeTempCredentials(credentials: MongoCredentials): Promise<Mongo
* - Shared credentials and config ini files
* - The EC2/ECS Instance Metadata Service
*/
const { fromNodeProviderChain } = credentialProvider;
const provider = fromNodeProviderChain();
try {
const creds = await provider();
return makeMongoCredentialsFromAWSTemp({
Expand Down
17 changes: 16 additions & 1 deletion src/deps.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/* eslint-disable @typescript-eslint/no-var-requires */
import type { Document } from './bson';
import type { AWSCredentials } from './cmap/auth/mongodb_aws';
import type { ProxyOptions } from './cmap/connection';
import { MongoMissingDependencyError } from './error';
import type { MongoClient } from './mongo_client';
Expand Down Expand Up @@ -76,7 +75,23 @@ export function getZstdLibrary(): typeof ZStandard | { kModuleError: MongoMissin
}
}

/**
* @internal
* Copy of the AwsCredentialIdentityProvider interface from [`smithy/types`](https://socket.dev/npm/package/\@smithy/types/files/1.1.1/dist-types/identity/awsCredentialIdentity.d.ts),
* the return type of the aws-sdk's `fromNodeProviderChain().provider()`.
*/
export interface AWSCredentials {
accessKeyId: string;
secretAccessKey: string;
sessionToken: string;
expiration?: Date;
}

type CredentialProvider = {
fromNodeProviderChain(
this: void,
options: { clientConfig: { region: string } }
): () => Promise<AWSCredentials>;
fromNodeProviderChain(this: void): () => Promise<AWSCredentials>;
};

Expand Down
150 changes: 149 additions & 1 deletion test/integration/auth/mongodb_aws.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as http from 'http';
import { performance } from 'perf_hooks';
import * as sinon from 'sinon';

import { MongoAWSError, type MongoClient, MongoServerError } from '../../mongodb';
import { MongoAWSError, type MongoClient, MongoDBAWS, MongoServerError } from '../../mongodb';

describe('MONGODB-AWS', function () {
let client: MongoClient;
Expand Down Expand Up @@ -88,4 +88,152 @@ describe('MONGODB-AWS', function () {
expect(timeTaken).to.be.below(12000);
});
});

describe('when using AssumeRoleWithWebIdentity', () => {
const tests = [
{
ctx: 'when no AWS region settings are set',
title: 'uses the default region',
env: {
AWS_STS_REGIONAL_ENDPOINTS: undefined,
AWS_REGION: undefined
},
calledWith: []
},
{
ctx: 'when only AWS_STS_REGIONAL_ENDPOINTS is set',
title: 'uses the default region',
env: {
AWS_STS_REGIONAL_ENDPOINTS: 'regional',
AWS_REGION: undefined
},
calledWith: []
},
{
ctx: 'when only AWS_REGION is set',
title: 'uses the default region',
env: {
AWS_STS_REGIONAL_ENDPOINTS: undefined,
AWS_REGION: 'us-west-2'
},
calledWith: []
},

{
ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to regional and region is legacy',
title: 'uses the region from the environment',
env: {
AWS_STS_REGIONAL_ENDPOINTS: 'regional',
AWS_REGION: 'us-west-2'
},
calledWith: [{ clientConfig: { region: 'us-west-2' } }]
},
{
ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to regional and region is new',
title: 'uses the region from the environment',
env: {
AWS_STS_REGIONAL_ENDPOINTS: 'regional',
AWS_REGION: 'sa-east-1'
},
calledWith: [{ clientConfig: { region: 'sa-east-1' } }]
},

{
ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to legacy and region is legacy',
title: 'uses the region from the environment',
env: {
AWS_STS_REGIONAL_ENDPOINTS: 'legacy',
AWS_REGION: 'us-west-2'
},
calledWith: []
},
{
ctx: 'when AWS_STS_REGIONAL_ENDPOINTS is set to legacy and region is new',
title: 'uses the default region',
env: {
AWS_STS_REGIONAL_ENDPOINTS: 'legacy',
AWS_REGION: 'sa-east-1'
},
calledWith: []
}
];

for (const test of tests) {
context(test.ctx, () => {
let credentialProvider;
let storedEnv;
let calledArguments;
let shouldSkip = false;

const envCheck = () => {
const { AWS_WEB_IDENTITY_TOKEN_FILE = '' } = process.env;
credentialProvider = (() => {
try {
return require('@aws-sdk/credential-providers');
} catch {
return null;
}
})();
return AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 || credentialProvider == null;
};

beforeEach(function () {
shouldSkip = envCheck();
if (shouldSkip) {
this.skipReason = 'only relevant to AssumeRoleWithWebIdentity with SDK installed';
return this.skip();
}

client = this.configuration.newClient(process.env.MONGODB_URI);

storedEnv = process.env;
if (test.env.AWS_STS_REGIONAL_ENDPOINTS === undefined) {
delete process.env.AWS_STS_REGIONAL_ENDPOINTS;
} else {
process.env.AWS_STS_REGIONAL_ENDPOINTS = test.env.AWS_STS_REGIONAL_ENDPOINTS;
}
if (test.env.AWS_REGION === undefined) {
delete process.env.AWS_REGION;
} else {
process.env.AWS_REGION = test.env.AWS_REGION;
}

calledArguments = [];
MongoDBAWS.credentialProvider = {
fromNodeProviderChain(...args) {
calledArguments = args;
return credentialProvider.fromNodeProviderChain(...args);
}
};
});

afterEach(() => {
if (shouldSkip) {
return;
}
if (typeof storedEnv.AWS_STS_REGIONAL_ENDPOINTS === 'string') {
process.env.AWS_STS_REGIONAL_ENDPOINTS = storedEnv.AWS_STS_REGIONAL_ENDPOINTS;
}
if (typeof storedEnv.AWS_STS_REGIONAL_ENDPOINTS === 'string') {
process.env.AWS_REGION = storedEnv.AWS_REGION;
}
MongoDBAWS.credentialProvider = credentialProvider;
calledArguments = [];
});

it(test.title, async function () {
const result = await client
.db('aws')
.collection('aws_test')
.estimatedDocumentCount()
.catch(error => error);

expect(result).to.not.be.instanceOf(MongoServerError);
expect(result).to.be.a('number');

expect(calledArguments).to.deep.equal(test.calledWith);
});
});
}
});
});