Skip to content

Commit

Permalink
feat(api-graphql): pass authToken via subprotocol (aws-amplify#13727)
Browse files Browse the repository at this point in the history
  • Loading branch information
iartemiev authored Aug 27, 2024
1 parent 798c135 commit ced891c
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 53 deletions.
88 changes: 82 additions & 6 deletions packages/api-graphql/__tests__/AWSAppSyncRealTimeProvider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ describe('AWSAppSyncRealTimeProvider', () => {

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'ws://localhost:8080/realtime?header=&payload=e30=',
'graphql-ws',
'ws://localhost:8080/realtime',
['graphql-ws', 'header-'],
);
});

Expand All @@ -271,8 +271,8 @@ describe('AWSAppSyncRealTimeProvider', () => {

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://localhost:8080/realtime?header=&payload=e30=',
'graphql-ws',
'wss://localhost:8080/realtime',
['graphql-ws', 'header-'],
);
});

Expand All @@ -298,8 +298,84 @@ describe('AWSAppSyncRealTimeProvider', () => {

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=&payload=e30=',
'graphql-ws',
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql',
['graphql-ws', 'header-'],
);
});

test('subscription generates expected auth token', async () => {
expect.assertions(1);

const newSocketSpy = jest
.spyOn(provider, 'getNewWebSocket')
.mockImplementation(() => {
fakeWebSocketInterface.newWebSocket();
return fakeWebSocketInterface.webSocket;
});

provider
.subscribe({
appSyncGraphqlEndpoint:
'https://testaccounturl123456789123.appsync-api.us-east-1.amazonaws.com/graphql',
// using custom auth instead of apiKey, because the latter inserts a timestamp header => expected value changes
authenticationType: 'lambda',
additionalHeaders: {
Authorization: 'my-custom-auth-token',
},
})
.subscribe({ error: () => {} });

// Wait for the socket to be initialize
await fakeWebSocketInterface.readyForUse;

/*
Regular base64 encoding of auth header {"Authorization":"my-custom-auth-token","host":"testaccounturl123456789123.appsync-api.us-east-1.amazonaws.com"}
Is: `eyJBdXRob3JpemF0aW9uIjoibXktY3VzdG9tLWF1dGgtdG9rZW4iLCJob3N0IjoidGVzdGFjY291bnR1cmwxMjM0NTY3ODkxMjMuYXBwc3luYy1hcGkudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20ifQ==`
(note `==` at the end of the string)
base64url encoding is expected to drop padding chars `=`
*/

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql',
[
'graphql-ws',
'header-eyJBdXRob3JpemF0aW9uIjoibXktY3VzdG9tLWF1dGgtdG9rZW4iLCJob3N0IjoidGVzdGFjY291bnR1cmwxMjM0NTY3ODkxMjMuYXBwc3luYy1hcGkudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20ifQ',
],
);
});

test('subscription generates expected auth token - custom domain', async () => {
expect.assertions(1);

const newSocketSpy = jest
.spyOn(provider, 'getNewWebSocket')
.mockImplementation(() => {
fakeWebSocketInterface.newWebSocket();
return fakeWebSocketInterface.webSocket;
});

provider
.subscribe({
appSyncGraphqlEndpoint: 'https://unit-test.testurl.com',
// using custom auth instead of apiKey, because the latter inserts a timestamp header => expected value changes
authenticationType: 'lambda',
additionalHeaders: {
Authorization: 'my-custom-auth-token',
},
})
.subscribe({ error: () => {} });

// Wait for the socket to be initialize
await fakeWebSocketInterface.readyForUse;

expect(newSocketSpy).toHaveBeenNthCalledWith(
1,
'wss://unit-test.testurl.com/realtime',
[
'graphql-ws',
'header-eyJBdXRob3JpemF0aW9uIjoibXktY3VzdG9tLWF1dGgtdG9rZW4iLCJob3N0IjoidW5pdC10ZXN0LnRlc3R1cmwuY29tIn0',
],
);
});

Expand Down
2 changes: 1 addition & 1 deletion packages/api-graphql/__tests__/GraphQLAPI.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1594,7 +1594,7 @@ describe('API test', () => {
`;

const resolvedUrl =
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?header=eyJBdXRob3JpemF0aW9uIjoiYWJjMTIzNDUiLCJob3N0IjoidGVzdGFjY291bnR1cmwxMjM0NTY3ODkxMjMuYXBwc3luYy1hcGkudXMtZWFzdC0xLmFtYXpvbmF3cy5jb20ifQ==&payload=e30=&x-amz-user-agent=aws-amplify%2F6.4.0%20api%2F1%20framework%2F2&ex-machina=is%20a%20good%20movie';
'wss://testaccounturl123456789123.appsync-realtime-api.us-east-1.amazonaws.com/graphql?x-amz-user-agent=aws-amplify%2F6.4.0+api%2F1+framework%2F2&ex-machina=is+a+good+movie';

(
client.graphql(
Expand Down
123 changes: 85 additions & 38 deletions packages/api-graphql/src/Providers/AWSAppSyncRealTimeProvider/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
import { signRequest } from '@aws-amplify/core/internals/aws-client-utils';
import {
AmplifyUrl,
AmplifyUrlSearchParams,
CustomUserAgentDetails,
DocumentType,
GraphQLAuthMode,
Expand Down Expand Up @@ -181,7 +182,7 @@ export class AWSAppSyncRealTimeProvider {
this.reconnectionMonitor.close();
}

getNewWebSocket(url: string, protocol: string) {
getNewWebSocket(url: string, protocol: string[]) {
return new WebSocket(url, protocol);
}

Expand Down Expand Up @@ -734,20 +735,63 @@ export class AWSAppSyncRealTimeProvider {
/**
*
* @param headers - http headers
* @returns query string of uri-encoded parameters derived from custom headers
* @returns uri-encoded query parameters derived from custom headers
*/
private _queryStringFromCustomHeaders(
private _queryParamsFromCustomHeaders(
headers?: AWSAppSyncRealTimeProviderOptions['additionalCustomHeaders'],
): string {
): URLSearchParams {
const nonAuthHeaders = this._extractNonAuthHeaders(headers);

const queryParams: string[] = Object.entries(nonAuthHeaders).map(
([key, val]) => `${encodeURIComponent(key)}=${encodeURIComponent(val)}`,
const params = new AmplifyUrlSearchParams();

Object.entries(nonAuthHeaders).forEach(([k, v]) => {
params.append(k, v);
});

return params;
}

/**
* Normalizes AppSync realtime endpoint URL
*
* @param appSyncGraphqlEndpoint - AppSync endpointUri from config
* @param urlParams - URLSearchParams
* @returns fully resolved string realtime endpoint URL
*/
private _realtimeUrlWithQueryString(
appSyncGraphqlEndpoint: string | undefined,
urlParams: URLSearchParams,
): string {
const protocol = 'wss://';

let realtimeEndpoint = appSyncGraphqlEndpoint ?? '';

if (this.isCustomDomain(realtimeEndpoint)) {
realtimeEndpoint = realtimeEndpoint.concat(customDomainPath);
} else {
realtimeEndpoint = realtimeEndpoint
.replace('appsync-api', 'appsync-realtime-api')
.replace('gogi-beta', 'grt-beta');
}

realtimeEndpoint = realtimeEndpoint
.replace('https://', protocol)
.replace('http://', protocol);

const realtimeEndpointUrl = new AmplifyUrl(realtimeEndpoint);

// preserves any query params a customer might manually set in the configuration
const existingParams = new AmplifyUrlSearchParams(
realtimeEndpointUrl.search,
);

const queryString = queryParams.join('&');
for (const [k, v] of urlParams.entries()) {
existingParams.append(k, v);
}

return queryString;
realtimeEndpointUrl.search = existingParams.toString();

return realtimeEndpointUrl.toString();
}

private _initializeWebSocketConnection({
Expand Down Expand Up @@ -783,38 +827,27 @@ export class AWSAppSyncRealTimeProvider {
});

const headerString = authHeader ? JSON.stringify(authHeader) : '';
const headerQs = base64Encoder.convert(headerString);
// base64url-encoded string
const encodedHeader = base64Encoder.convert(headerString, {
urlSafe: true,
skipPadding: true,
});

const payloadQs = base64Encoder.convert(payloadString);
const authTokenSubprotocol = `header-${encodedHeader}`;

const queryString = this._queryStringFromCustomHeaders(
const queryParams = this._queryParamsFromCustomHeaders(
additionalCustomHeaders,
);

let discoverableEndpoint = appSyncGraphqlEndpoint ?? '';

if (this.isCustomDomain(discoverableEndpoint)) {
discoverableEndpoint =
discoverableEndpoint.concat(customDomainPath);
} else {
discoverableEndpoint = discoverableEndpoint
.replace('appsync-api', 'appsync-realtime-api')
.replace('gogi-beta', 'grt-beta');
}

// Creating websocket url with required query strings
const protocol = 'wss://';
discoverableEndpoint = discoverableEndpoint
.replace('https://', protocol)
.replace('http://', protocol);

let awsRealTimeUrl = `${discoverableEndpoint}?header=${headerQs}&payload=${payloadQs}`;

if (queryString !== '') {
awsRealTimeUrl += `&${queryString}`;
}
const awsRealTimeUrl = this._realtimeUrlWithQueryString(
appSyncGraphqlEndpoint,
queryParams,
);

await this._initializeRetryableHandshake(awsRealTimeUrl);
await this._initializeRetryableHandshake(
awsRealTimeUrl,
authTokenSubprotocol,
);

this.promiseArray.forEach(({ res }) => {
logger.debug('Notifying connection successful');
Expand All @@ -841,23 +874,37 @@ export class AWSAppSyncRealTimeProvider {
});
}

private async _initializeRetryableHandshake(awsRealTimeUrl: string) {
private async _initializeRetryableHandshake(
awsRealTimeUrl: string,
subprotocol: string,
) {
logger.debug(`Initializaling retryable Handshake`);
await jitteredExponentialRetry(
this._initializeHandshake.bind(this),
[awsRealTimeUrl],
[awsRealTimeUrl, subprotocol],
MAX_DELAY_MS,
);
}

private async _initializeHandshake(awsRealTimeUrl: string) {
/**
*
* @param subprotocol -
*/
private async _initializeHandshake(
awsRealTimeUrl: string,
subprotocol: string,
) {
logger.debug(`Initializing handshake ${awsRealTimeUrl}`);
// Because connecting the socket is async, is waiting until connection is open
// Step 1: connect websocket
try {
await (() => {
return new Promise<void>((resolve, reject) => {
const newSocket = this.getNewWebSocket(awsRealTimeUrl, 'graphql-ws');
const newSocket = this.getNewWebSocket(awsRealTimeUrl, [
'graphql-ws',
subprotocol,
]);

newSocket.onerror = () => {
logger.debug(`WebSocket connection error`);
};
Expand Down
2 changes: 1 addition & 1 deletion packages/aws-amplify/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@
"name": "[API] generateClient (AppSync)",
"path": "./dist/esm/api/index.mjs",
"import": "{ generateClient }",
"limit": "41 kB"
"limit": "41.5 kB"
},
{
"name": "[API] REST API handlers",
Expand Down
8 changes: 8 additions & 0 deletions packages/core/__tests__/utils/convert/base64Encoder.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,12 @@ describe('base64Encoder (non-native)', () => {
'test-test_test',
);
});

it('makes the result a base64url string with no padding chars', () => {
const mockResult = 'test+test/test=='; // = is the base64 padding char
mockBtoa.mockReturnValue(mockResult);
expect(
base64Encoder.convert('test', { urlSafe: true, skipPadding: true }),
).toBe('test-test_test');
});
});
33 changes: 26 additions & 7 deletions packages/core/src/utils/convert/base64/base64Encoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,37 @@
// SPDX-License-Identifier: Apache-2.0

import { getBtoa } from '../../globalHelpers';
import { Base64Encoder } from '../types';
import type { Base64Encoder, Base64EncoderConvertOptions } from '../types';

import { bytesToString } from './bytesToString';

export const base64Encoder: Base64Encoder = {
convert(input, { urlSafe } = { urlSafe: false }) {
/**
* Convert input to base64-encoded string
* @param input - string to convert to base64
* @param options - encoding options that can optionally produce a base64url string
* @returns base64-encoded string
*/
convert(
input,
options: Base64EncoderConvertOptions = {
urlSafe: false,
skipPadding: false,
},
) {
const inputStr = typeof input === 'string' ? input : bytesToString(input);
const encodedStr = getBtoa()(inputStr);
let encodedStr = getBtoa()(inputStr);

// see details about the char replacing at https://datatracker.ietf.org/doc/html/rfc4648#section-5
return urlSafe
? encodedStr.replace(/\+/g, '-').replace(/\//g, '_')
: encodedStr;
// urlSafe char replacement and skipPadding options conform to the base64url spec
// https://datatracker.ietf.org/doc/html/rfc4648#section-5
if (options.urlSafe) {
encodedStr = encodedStr.replace(/\+/g, '-').replace(/\//g, '_');
}

if (options.skipPadding) {
encodedStr = encodedStr.replace(/=/g, '');
}

return encodedStr;
},
};
1 change: 1 addition & 0 deletions packages/core/src/utils/convert/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

export interface Base64EncoderConvertOptions {
urlSafe: boolean;
skipPadding?: boolean;
}

export interface Base64Encoder {
Expand Down

0 comments on commit ced891c

Please sign in to comment.