Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-9824 - feat: PubSub Connection state tracking for MQTT and IoT providers #10136

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import { Reachability, Credentials, Logger, Signer } from '@aws-amplify/core';
import { Auth } from '@aws-amplify/auth';
import Cache from '@aws-amplify/cache';

import { MESSAGE_TYPES } from '../src/Providers/AWSAppSyncRealTimeProvider/constants';
import * as constants from '../src/Providers/AWSAppSyncRealTimeProvider/constants';
import { MESSAGE_TYPES } from '../src/Providers/constants';
import * as constants from '../src/Providers/constants';

import { delay, FakeWebSocketInterface, replaceConstant } from './helpers';
import { ConnectionState as CS } from '../src';
Expand Down
127 changes: 127 additions & 0 deletions packages/pubsub/__tests__/PubSub-unit-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,19 @@ import {
// import Amplify from '../../src/';
import {
Credentials,
Hub,
INTERNAL_AWS_APPSYNC_PUBSUB_PROVIDER,
Logger,
Reachability,
} from '@aws-amplify/core';
import * as Paho from 'paho-mqtt';
import {
ConnectionState,
ConnectionState,
CONNECTION_STATE_CHANGE,
} from '../src';
import { HubConnectionListener } from './helpers';
import Observable from 'zen-observable-ts';

const pahoClientMockCache = {};

Expand Down Expand Up @@ -306,6 +316,123 @@ describe('PubSub', () => {
expect(originalProvider.publish).not.toHaveBeenCalled();
expect(newProvider.publish).toHaveBeenCalled();
});

describe('Hub connection state changes', () => {
let hubConnectionListener: HubConnectionListener;

let reachabilityObserver: ZenObservable.Observer<{ online: boolean }>;

beforeEach(() => {
// Maintain the Hub connection listener, used to monitor the connection messages sent through Hub
hubConnectionListener?.teardown();
hubConnectionListener = new HubConnectionListener('pubsub');

// Setup a mock of the reachability monitor where the initial value is online.
const spyon = jest
.spyOn(Reachability.prototype, 'networkMonitor')
.mockImplementationOnce(
() =>
new Observable(observer => {
reachabilityObserver = observer;
})
);
reachabilityObserver?.next?.({ online: true });
});

test('test happy case connect -> disconnect cycle', async () => {
const pubsub = new PubSub();

const awsIotProvider = new AWSIoTProvider({
aws_pubsub_region: 'region',
aws_pubsub_endpoint: 'wss://iot.mymockendpoint.org:443/notrealmqtt',
});
pubsub.addPluggable(awsIotProvider);

const sub = pubsub.subscribe('topic', { clientId: '123' }).subscribe({
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why the subscribe calls are chained here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PubSub API uses the .subscribe keyword to setup an Observable where these zen-observable objects also have a .subscribe method that allows you to setup independent callback for next events, complete and error events. Both error and complete events close the observable, so will only ever be called once.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! thanks for both the explanations.

error: () => {},
});

await hubConnectionListener.waitUntilConnectionStateIn(['Connected']);
sub.unsubscribe();
awsIotProvider.onDisconnect({ errorCode: 1, clientId: '123' });
await hubConnectionListener.waitUntilConnectionStateIn([
'Disconnected',
]);
expect(hubConnectionListener.observedConnectionStates).toEqual([
'Disconnected',
'Connecting',
'Connected',
'ConnectedPendingDisconnect',
'Disconnected',
]);
});

test('test network disconnection and recovery', async () => {
const pubsub = new PubSub();

const awsIotProvider = new AWSIoTProvider({
aws_pubsub_region: 'region',
aws_pubsub_endpoint: 'wss://iot.mymockendpoint.org:443/notrealmqtt',
});
pubsub.addPluggable(awsIotProvider);

const sub = pubsub.subscribe('topic', { clientId: '123' }).subscribe({
error: () => {},
});

await hubConnectionListener.waitUntilConnectionStateIn(['Connected']);

reachabilityObserver?.next?.({ online: false });
await hubConnectionListener.waitUntilConnectionStateIn([
'ConnectedPendingNetwork',
]);

reachabilityObserver?.next?.({ online: true });
await hubConnectionListener.waitUntilConnectionStateIn(['Connected']);

expect(hubConnectionListener.observedConnectionStates).toEqual([
'Disconnected',
'Connecting',
'Connected',
'ConnectedPendingNetwork',
'Connected',
]);
});

test('test network disconnection followed by connection disruption', async () => {
const pubsub = new PubSub();

const awsIotProvider = new AWSIoTProvider({
aws_pubsub_region: 'region',
aws_pubsub_endpoint: 'wss://iot.mymockendpoint.org:443/notrealmqtt',
});
pubsub.addPluggable(awsIotProvider);

const sub = pubsub.subscribe('topic', { clientId: '123' }).subscribe({
error: () => {},
});

await hubConnectionListener.waitUntilConnectionStateIn(['Connected']);

reachabilityObserver?.next?.({ online: false });
await hubConnectionListener.waitUntilConnectionStateIn([
'ConnectedPendingNetwork',
]);

awsIotProvider.onDisconnect({ errorCode: 1, clientId: '123' });
await hubConnectionListener.waitUntilConnectionStateIn([
'Disconnected',
]);

expect(hubConnectionListener.observedConnectionStates).toEqual([
'Disconnected',
'Connecting',
'Connected',
'ConnectedPendingNetwork',
'Disconnected',
]);
});
});
});

describe('MqttOverWSProvider local testing config', () => {
Expand Down
93 changes: 65 additions & 28 deletions packages/pubsub/__tests__/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Hub } from '@aws-amplify/core';
import Observable from 'zen-observable-ts';
import { ConnectionState as CS, CONNECTION_STATE_CHANGE } from '../src';
import * as constants from '../src/Providers/AWSAppSyncRealTimeProvider/constants';
import * as constants from '../src/Providers/constants';

export function delay(timeout) {
return new Promise(resolve => {
Expand All @@ -11,28 +11,17 @@ export function delay(timeout) {
});
}

export class FakeWebSocketInterface {
readonly webSocket: FakeWebSocket;
readyForUse: Promise<void>;
hasClosed: Promise<undefined>;
export class HubConnectionListener {
teardownHubListener: () => void;
observedConnectionStates: CS[] = [];
currentConnectionState: CS;

private readyResolve: (value: PromiseLike<any>) => void;
private connectionStateObservers: ZenObservable.Observer<CS>[] = [];

constructor() {
this.readyForUse = new Promise((res, rej) => {
this.readyResolve = res;
});
constructor(channel: string) {
let closeResolver: (value: PromiseLike<any>) => void;
this.hasClosed = new Promise((res, rej) => {
closeResolver = res;
});
this.webSocket = new FakeWebSocket(() => closeResolver);

this.teardownHubListener = Hub.listen('api', (data: any) => {
this.teardownHubListener = Hub.listen(channel, (data: any) => {
const { payload } = data;
if (payload.event === CONNECTION_STATE_CHANGE) {
const connectionState = payload.data.connectionState as CS;
Expand Down Expand Up @@ -65,6 +54,7 @@ export class FakeWebSocketInterface {
this.connectionStateObservers.push(observer);
});
}

/**
* Tear down the Fake Socket state
*/
Expand All @@ -75,6 +65,62 @@ export class FakeWebSocketInterface {
});
}

async waitForConnectionState(connectionStates: CS[]) {
return new Promise<void>((res, rej) => {
this.connectionStateObserver().subscribe(value => {
if (connectionStates.includes(String(value) as CS)) {
res(undefined);
}
});
});
}

async waitUntilConnectionStateIn(connectionStates: CS[]) {
return new Promise<void>((res, rej) => {
if (connectionStates.includes(this.currentConnectionState)) {
res(undefined);
}
res(this.waitForConnectionState(connectionStates));
});
}
}

export class FakeWebSocketInterface {
readonly webSocket: FakeWebSocket;
readyForUse: Promise<void>;
hasClosed: Promise<undefined>;
hubConnectionListener: HubConnectionListener;

private readyResolve: (value: PromiseLike<any>) => void;

constructor() {
this.hubConnectionListener = new HubConnectionListener('api');
this.readyForUse = new Promise((res, rej) => {
this.readyResolve = res;
});
let closeResolver: (value: PromiseLike<any>) => void;
this.hasClosed = new Promise((res, rej) => {
closeResolver = res;
});
this.webSocket = new FakeWebSocket(() => closeResolver);
}

get observedConnectionStates() {
return this.hubConnectionListener.observedConnectionStates;
}

allConnectionStateObserver() {
return this.hubConnectionListener.allConnectionStateObserver();
}

connectionStateObserver() {
return this.hubConnectionListener.connectionStateObserver();
}

teardown() {
this.hubConnectionListener.teardown();
}

/**
* Once ready for use, send onOpen and the connection_ack
*/
Expand Down Expand Up @@ -207,25 +253,16 @@ export class FakeWebSocketInterface {
* @returns a Promise that will wait for one of the provided states to be observed
*/
async waitForConnectionState(connectionStates: CS[]) {
return new Promise<void>((res, rej) => {
this.connectionStateObserver().subscribe(value => {
if (connectionStates.includes(String(value) as CS)) {
res(undefined);
}
});
});
return this.hubConnectionListener.waitForConnectionState(connectionStates);
}

/**
* @returns a Promise that will wait until the current state is one of the provided states
*/
async waitUntilConnectionStateIn(connectionStates: CS[]) {
return new Promise<void>((res, rej) => {
if (connectionStates.includes(this.currentConnectionState)) {
res(undefined);
}
res(this.waitForConnectionState(connectionStates));
});
return this.hubConnectionListener.waitUntilConnectionStateIn(
connectionStates
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import {
SOCKET_STATUS,
START_ACK_TIMEOUT,
SUBSCRIPTION_STATUS,
} from './constants';
} from '../constants';
import {
ConnectionStateMonitor,
CONNECTION_CHANGE,
Expand Down
Loading