diff --git a/packages/pubsub/__tests__/AWSAppSyncRealTimeProvider.test.ts b/packages/pubsub/__tests__/AWSAppSyncRealTimeProvider.test.ts index 3cd152cc32c..d30f4383981 100644 --- a/packages/pubsub/__tests__/AWSAppSyncRealTimeProvider.test.ts +++ b/packages/pubsub/__tests__/AWSAppSyncRealTimeProvider.test.ts @@ -14,8 +14,8 @@ import { Reachability, Credentials, Logger, Signer } from '@aws-amplify/core'; import { Auth } from '@aws-amplify/auth'; import Cache from '@aws-amplify/cache'; -import { MESSAGE_TYPES } from '../src/Providers/AWSAppSyncRealTimeProvider/constants'; -import * as constants from '../src/Providers/AWSAppSyncRealTimeProvider/constants'; +import { MESSAGE_TYPES } from '../src/Providers/constants'; +import * as constants from '../src/Providers/constants'; import { delay, FakeWebSocketInterface, replaceConstant } from './helpers'; import { ConnectionState as CS } from '../src'; diff --git a/packages/pubsub/__tests__/PubSub-unit-test.ts b/packages/pubsub/__tests__/PubSub-unit-test.ts index 622ceefcdf9..c24f3a62646 100644 --- a/packages/pubsub/__tests__/PubSub-unit-test.ts +++ b/packages/pubsub/__tests__/PubSub-unit-test.ts @@ -19,9 +19,19 @@ import { // import Amplify from '../../src/'; import { Credentials, + Hub, INTERNAL_AWS_APPSYNC_PUBSUB_PROVIDER, + Logger, + Reachability, } from '@aws-amplify/core'; import * as Paho from 'paho-mqtt'; +import { + ConnectionState, + ConnectionState, + CONNECTION_STATE_CHANGE, +} from '../src'; +import { HubConnectionListener } from './helpers'; +import Observable from 'zen-observable-ts'; const pahoClientMockCache = {}; @@ -306,6 +316,123 @@ describe('PubSub', () => { expect(originalProvider.publish).not.toHaveBeenCalled(); expect(newProvider.publish).toHaveBeenCalled(); }); + + describe('Hub connection state changes', () => { + let hubConnectionListener: HubConnectionListener; + + let reachabilityObserver: ZenObservable.Observer<{ online: boolean }>; + + beforeEach(() => { + // Maintain the Hub connection listener, used to monitor the connection messages sent through Hub + hubConnectionListener?.teardown(); + hubConnectionListener = new HubConnectionListener('pubsub'); + + // Setup a mock of the reachability monitor where the initial value is online. + const spyon = jest + .spyOn(Reachability.prototype, 'networkMonitor') + .mockImplementationOnce( + () => + new Observable(observer => { + reachabilityObserver = observer; + }) + ); + reachabilityObserver?.next?.({ online: true }); + }); + + test('test happy case connect -> disconnect cycle', async () => { + const pubsub = new PubSub(); + + const awsIotProvider = new AWSIoTProvider({ + aws_pubsub_region: 'region', + aws_pubsub_endpoint: 'wss://iot.mymockendpoint.org:443/notrealmqtt', + }); + pubsub.addPluggable(awsIotProvider); + + const sub = pubsub.subscribe('topic', { clientId: '123' }).subscribe({ + error: () => {}, + }); + + await hubConnectionListener.waitUntilConnectionStateIn(['Connected']); + sub.unsubscribe(); + awsIotProvider.onDisconnect({ errorCode: 1, clientId: '123' }); + await hubConnectionListener.waitUntilConnectionStateIn([ + 'Disconnected', + ]); + expect(hubConnectionListener.observedConnectionStates).toEqual([ + 'Disconnected', + 'Connecting', + 'Connected', + 'ConnectedPendingDisconnect', + 'Disconnected', + ]); + }); + + test('test network disconnection and recovery', async () => { + const pubsub = new PubSub(); + + const awsIotProvider = new AWSIoTProvider({ + aws_pubsub_region: 'region', + aws_pubsub_endpoint: 'wss://iot.mymockendpoint.org:443/notrealmqtt', + }); + pubsub.addPluggable(awsIotProvider); + + const sub = pubsub.subscribe('topic', { clientId: '123' }).subscribe({ + error: () => {}, + }); + + await hubConnectionListener.waitUntilConnectionStateIn(['Connected']); + + reachabilityObserver?.next?.({ online: false }); + await hubConnectionListener.waitUntilConnectionStateIn([ + 'ConnectedPendingNetwork', + ]); + + reachabilityObserver?.next?.({ online: true }); + await hubConnectionListener.waitUntilConnectionStateIn(['Connected']); + + expect(hubConnectionListener.observedConnectionStates).toEqual([ + 'Disconnected', + 'Connecting', + 'Connected', + 'ConnectedPendingNetwork', + 'Connected', + ]); + }); + + test('test network disconnection followed by connection disruption', async () => { + const pubsub = new PubSub(); + + const awsIotProvider = new AWSIoTProvider({ + aws_pubsub_region: 'region', + aws_pubsub_endpoint: 'wss://iot.mymockendpoint.org:443/notrealmqtt', + }); + pubsub.addPluggable(awsIotProvider); + + const sub = pubsub.subscribe('topic', { clientId: '123' }).subscribe({ + error: () => {}, + }); + + await hubConnectionListener.waitUntilConnectionStateIn(['Connected']); + + reachabilityObserver?.next?.({ online: false }); + await hubConnectionListener.waitUntilConnectionStateIn([ + 'ConnectedPendingNetwork', + ]); + + awsIotProvider.onDisconnect({ errorCode: 1, clientId: '123' }); + await hubConnectionListener.waitUntilConnectionStateIn([ + 'Disconnected', + ]); + + expect(hubConnectionListener.observedConnectionStates).toEqual([ + 'Disconnected', + 'Connecting', + 'Connected', + 'ConnectedPendingNetwork', + 'Disconnected', + ]); + }); + }); }); describe('MqttOverWSProvider local testing config', () => { diff --git a/packages/pubsub/__tests__/helpers.ts b/packages/pubsub/__tests__/helpers.ts index 959522341be..f4226c6bfdf 100644 --- a/packages/pubsub/__tests__/helpers.ts +++ b/packages/pubsub/__tests__/helpers.ts @@ -1,7 +1,7 @@ import { Hub } from '@aws-amplify/core'; import Observable from 'zen-observable-ts'; import { ConnectionState as CS, CONNECTION_STATE_CHANGE } from '../src'; -import * as constants from '../src/Providers/AWSAppSyncRealTimeProvider/constants'; +import * as constants from '../src/Providers/constants'; export function delay(timeout) { return new Promise(resolve => { @@ -11,28 +11,17 @@ export function delay(timeout) { }); } -export class FakeWebSocketInterface { - readonly webSocket: FakeWebSocket; - readyForUse: Promise; - hasClosed: Promise; +export class HubConnectionListener { teardownHubListener: () => void; observedConnectionStates: CS[] = []; currentConnectionState: CS; - private readyResolve: (value: PromiseLike) => void; private connectionStateObservers: ZenObservable.Observer[] = []; - constructor() { - this.readyForUse = new Promise((res, rej) => { - this.readyResolve = res; - }); + constructor(channel: string) { let closeResolver: (value: PromiseLike) => void; - this.hasClosed = new Promise((res, rej) => { - closeResolver = res; - }); - this.webSocket = new FakeWebSocket(() => closeResolver); - this.teardownHubListener = Hub.listen('api', (data: any) => { + this.teardownHubListener = Hub.listen(channel, (data: any) => { const { payload } = data; if (payload.event === CONNECTION_STATE_CHANGE) { const connectionState = payload.data.connectionState as CS; @@ -65,6 +54,7 @@ export class FakeWebSocketInterface { this.connectionStateObservers.push(observer); }); } + /** * Tear down the Fake Socket state */ @@ -75,6 +65,62 @@ export class FakeWebSocketInterface { }); } + async waitForConnectionState(connectionStates: CS[]) { + return new Promise((res, rej) => { + this.connectionStateObserver().subscribe(value => { + if (connectionStates.includes(String(value) as CS)) { + res(undefined); + } + }); + }); + } + + async waitUntilConnectionStateIn(connectionStates: CS[]) { + return new Promise((res, rej) => { + if (connectionStates.includes(this.currentConnectionState)) { + res(undefined); + } + res(this.waitForConnectionState(connectionStates)); + }); + } +} + +export class FakeWebSocketInterface { + readonly webSocket: FakeWebSocket; + readyForUse: Promise; + hasClosed: Promise; + hubConnectionListener: HubConnectionListener; + + private readyResolve: (value: PromiseLike) => void; + + constructor() { + this.hubConnectionListener = new HubConnectionListener('api'); + this.readyForUse = new Promise((res, rej) => { + this.readyResolve = res; + }); + let closeResolver: (value: PromiseLike) => void; + this.hasClosed = new Promise((res, rej) => { + closeResolver = res; + }); + this.webSocket = new FakeWebSocket(() => closeResolver); + } + + get observedConnectionStates() { + return this.hubConnectionListener.observedConnectionStates; + } + + allConnectionStateObserver() { + return this.hubConnectionListener.allConnectionStateObserver(); + } + + connectionStateObserver() { + return this.hubConnectionListener.connectionStateObserver(); + } + + teardown() { + this.hubConnectionListener.teardown(); + } + /** * Once ready for use, send onOpen and the connection_ack */ @@ -207,25 +253,16 @@ export class FakeWebSocketInterface { * @returns a Promise that will wait for one of the provided states to be observed */ async waitForConnectionState(connectionStates: CS[]) { - return new Promise((res, rej) => { - this.connectionStateObserver().subscribe(value => { - if (connectionStates.includes(String(value) as CS)) { - res(undefined); - } - }); - }); + return this.hubConnectionListener.waitForConnectionState(connectionStates); } /** * @returns a Promise that will wait until the current state is one of the provided states */ async waitUntilConnectionStateIn(connectionStates: CS[]) { - return new Promise((res, rej) => { - if (connectionStates.includes(this.currentConnectionState)) { - res(undefined); - } - res(this.waitForConnectionState(connectionStates)); - }); + return this.hubConnectionListener.waitUntilConnectionStateIn( + connectionStates + ); } } diff --git a/packages/pubsub/src/Providers/AWSAppSyncRealTimeProvider/index.ts b/packages/pubsub/src/Providers/AWSAppSyncRealTimeProvider/index.ts index 0049a79c9ac..da02532739e 100644 --- a/packages/pubsub/src/Providers/AWSAppSyncRealTimeProvider/index.ts +++ b/packages/pubsub/src/Providers/AWSAppSyncRealTimeProvider/index.ts @@ -44,7 +44,7 @@ import { SOCKET_STATUS, START_ACK_TIMEOUT, SUBSCRIPTION_STATUS, -} from './constants'; +} from '../constants'; import { ConnectionStateMonitor, CONNECTION_CHANGE, diff --git a/packages/pubsub/src/Providers/MqttOverWSProvider.ts b/packages/pubsub/src/Providers/MqttOverWSProvider.ts index 4b92252e0f4..0a63cc1a9c0 100644 --- a/packages/pubsub/src/Providers/MqttOverWSProvider.ts +++ b/packages/pubsub/src/Providers/MqttOverWSProvider.ts @@ -16,7 +16,13 @@ import Observable from 'zen-observable-ts'; import { AbstractPubSubProvider } from './PubSubProvider'; import { ProviderOptions, SubscriptionObserver } from '../types'; -import { ConsoleLogger as Logger } from '@aws-amplify/core'; +import { ConsoleLogger as Logger, Hub } from '@aws-amplify/core'; +import { + ConnectionStateMonitor, + CONNECTION_CHANGE, +} from '../utils/ConnectionStateMonitor'; +import { AMPLIFY_SYMBOL } from './constants'; +import { CONNECTION_STATE_CHANGE } from '..'; const logger = new Logger('MqttOverWSProvider'); @@ -72,13 +78,32 @@ class ClientsQueue { } } +const dispatchPubSubEvent = (event: string, data: any, message: string) => { + Hub.dispatch('pubsub', { event, data, message }, 'PubSub', AMPLIFY_SYMBOL); +}; + const topicSymbol = typeof Symbol !== 'undefined' ? Symbol('topic') : '@@topic'; export class MqttOverWSProvider extends AbstractPubSubProvider { private _clientsQueue = new ClientsQueue(); + private readonly connectionStateMonitor = new ConnectionStateMonitor(); constructor(options: MqttProviderOptions = {}) { super({ ...options, clientId: options.clientId || uuid() }); + + // Monitor the connection health state and pass changes along to Hub + this.connectionStateMonitor.connectionStateObservable.subscribe( + connectionStateChange => { + dispatchPubSubEvent( + CONNECTION_STATE_CHANGE, + { + provider: this, + connectionState: connectionStateChange, + }, + `Connection state is ${connectionStateChange}` + ); + } + ); } protected get clientId() { @@ -149,6 +174,7 @@ export class MqttOverWSProvider extends AbstractPubSubProvider { public async newClient({ url, clientId }: MqttProviderOptions): Promise { logger.debug('Creating new MQTT client', clientId); + this.connectionStateMonitor.record(CONNECTION_CHANGE.OPENING_CONNECTION); // @ts-ignore const client = new Paho.Client(url, clientId); // client.trace = (args) => logger.debug(clientId, JSON.stringify(args, null, 2)); @@ -168,6 +194,7 @@ export class MqttOverWSProvider extends AbstractPubSubProvider { errorCode: number; }) => { this.onDisconnect({ clientId, errorCode, ...args }); + this.connectionStateMonitor.record(CONNECTION_CHANGE.CLOSED); }; await new Promise((resolve, reject) => { @@ -175,10 +202,19 @@ export class MqttOverWSProvider extends AbstractPubSubProvider { useSSL: this.isSSLEnabled, mqttVersion: 3, onSuccess: () => resolve(client), - onFailure: reject, + onFailure: () => { + reject(); + this.connectionStateMonitor.record( + CONNECTION_CHANGE.CONNECTION_FAILED + ); + }, }); }); + this.connectionStateMonitor.record( + CONNECTION_CHANGE.CONNECTION_ESTABLISHED + ); + return client; } @@ -196,6 +232,7 @@ export class MqttOverWSProvider extends AbstractPubSubProvider { if (client && client.isConnected()) { client.disconnect(); + this.connectionStateMonitor.record(CONNECTION_CHANGE.CLOSED); } this.clientsQueue.remove(clientId); } @@ -293,6 +330,10 @@ export class MqttOverWSProvider extends AbstractPubSubProvider { this._clientIdObservers.get(clientId)?.delete(observer); // No more observers per client => client not needed anymore if (this._clientIdObservers.get(clientId)?.size === 0) { + this.connectionStateMonitor.record( + CONNECTION_CHANGE.CLOSING_CONNECTION + ); + this.disconnect(clientId); this._clientIdObservers.delete(clientId); } diff --git a/packages/pubsub/src/Providers/AWSAppSyncRealTimeProvider/constants.ts b/packages/pubsub/src/Providers/constants.ts similarity index 100% rename from packages/pubsub/src/Providers/AWSAppSyncRealTimeProvider/constants.ts rename to packages/pubsub/src/Providers/constants.ts