Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sending user-defined encrypted to-device messages #2528

Merged
merged 12 commits into from
Aug 3, 2022
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
"jest-localstorage-mock": "^2.4.6",
"jest-sonar-reporter": "^2.0.0",
"jsdoc": "^3.6.6",
"matrix-mock-request": "^2.1.1",
"matrix-mock-request": "^2.1.2",
"rimraf": "^3.0.2",
"terser": "^5.5.1",
"tsify": "^5.0.2",
Expand Down
128 changes: 117 additions & 11 deletions spec/unit/crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import '../olm-loader';
// eslint-disable-next-line no-restricted-imports
import { EventEmitter } from "events";

import { MatrixClient } from "../../src/client";
import { Crypto } from "../../src/crypto";
import { MemoryCryptoStore } from "../../src/crypto/store/memory-crypto-store";
import { MockStorageApi } from "../MockStorageApi";
Expand Down Expand Up @@ -225,8 +226,8 @@ describe("Crypto", function() {
});

describe('Key requests', function() {
let aliceClient;
let bobClient;
let aliceClient: MatrixClient;
let bobClient: MatrixClient;

beforeEach(async function() {
aliceClient = (new TestClient(
Expand Down Expand Up @@ -313,7 +314,7 @@ describe("Crypto", function() {
expect(events[0].getContent().msgtype).toBe("m.bad.encrypted");
expect(events[1].getContent().msgtype).not.toBe("m.bad.encrypted");

const cryptoStore = bobClient.cryptoStore;
const cryptoStore = bobClient.crypto.cryptoStore;
const eventContent = events[0].getWireContent();
const senderKey = eventContent.sender_key;
const sessionId = eventContent.session_id;
Expand Down Expand Up @@ -383,9 +384,9 @@ describe("Crypto", function() {

const ksEvent = await keyshareEventForEvent(aliceClient, event, 1);
ksEvent.getContent().sender_key = undefined; // test
bobClient.crypto.addInboundGroupSession = jest.fn();
bobClient.crypto.olmDevice.addInboundGroupSession = jest.fn();
await bobDecryptor.onRoomKeyEvent(ksEvent);
expect(bobClient.crypto.addInboundGroupSession).not.toHaveBeenCalled();
expect(bobClient.crypto.olmDevice.addInboundGroupSession).not.toHaveBeenCalled();
});

it("creates a new keyshare request if we request a keyshare", async function() {
Expand All @@ -401,7 +402,7 @@ describe("Crypto", function() {
},
});
await aliceClient.cancelAndResendEventRoomKeyRequest(event);
const cryptoStore = aliceClient.cryptoStore;
const cryptoStore = aliceClient.crypto.cryptoStore;
const roomKeyRequestBody = {
algorithm: olmlib.MEGOLM_ALGORITHM,
room_id: "!someroom",
Expand All @@ -425,7 +426,8 @@ describe("Crypto", function() {
},
});
// replace Alice's sendToDevice function with a mock
aliceClient.sendToDevice = jest.fn().mockResolvedValue(undefined);
const aliceSendToDevice = jest.fn().mockResolvedValue(undefined);
aliceClient.sendToDevice = aliceSendToDevice;
aliceClient.startClient();

// make a room key request, and record the transaction ID for the
Expand All @@ -434,11 +436,12 @@ describe("Crypto", function() {
// key requests get queued until the sync has finished, but we don't
// let the client set up enough for that to happen, so gut-wrench a bit
// to force it to send now.
// @ts-ignore
aliceClient.crypto.outgoingRoomKeyRequestManager.sendQueuedRequests();
jest.runAllTimers();
await Promise.resolve();
expect(aliceClient.sendToDevice).toBeCalledTimes(1);
const txnId = aliceClient.sendToDevice.mock.calls[0][2];
expect(aliceSendToDevice).toBeCalledTimes(1);
const txnId = aliceSendToDevice.mock.calls[0][2];

// give the room key request manager time to update the state
// of the request
Expand All @@ -451,8 +454,10 @@ describe("Crypto", function() {
// cancelAndResend will call sendToDevice twice:
// the first call to sendToDevice will be the cancellation
// the second call to sendToDevice will be the key request
expect(aliceClient.sendToDevice).toBeCalledTimes(3);
expect(aliceClient.sendToDevice.mock.calls[2][2]).not.toBe(txnId);
expect(aliceSendToDevice).toBeCalledTimes(3);
expect(aliceSendToDevice.mock.calls[2][2]).not.toBe(txnId);

jest.useRealTimers();
robintown marked this conversation as resolved.
Show resolved Hide resolved
});
});

Expand Down Expand Up @@ -480,4 +485,105 @@ describe("Crypto", function() {
client.stopClient();
});
});

describe("encryptAndSendToDevices", () => {
let client: TestClient;
let ensureOlmSessionsForDevices: jest.SpiedFunction<typeof olmlib.ensureOlmSessionsForDevices>;
let encryptMessageForDevice: jest.SpiedFunction<typeof olmlib.encryptMessageForDevice>;
const payload = { hello: "world" };
let encryptedPayload: object;

beforeEach(async () => {
ensureOlmSessionsForDevices = jest.spyOn(olmlib, "ensureOlmSessionsForDevices");
ensureOlmSessionsForDevices.mockResolvedValue({});
encryptMessageForDevice = jest.spyOn(olmlib, "encryptMessageForDevice");
encryptMessageForDevice.mockImplementation(async (...[result,,,,,, payload]) => {
result.plaintext = JSON.stringify(payload);
});

client = new TestClient("@alice:example.org", "aliceweb");
await client.client.initCrypto();

encryptedPayload = {
algorithm: "m.olm.v1.curve25519-aes-sha2",
sender_key: client.client.crypto.olmDevice.deviceCurve25519Key,
ciphertext: { plaintext: JSON.stringify(payload) },
};
});

afterEach(async () => {
ensureOlmSessionsForDevices.mockRestore();
encryptMessageForDevice.mockRestore();
await client.stop();
});

it("encrypts and sends to devices", async () => {
client.httpBackend
.when("PUT", "/sendToDevice/m.room.encrypted", {
messages: {
"@bob:example.org": {
bobweb: encryptedPayload,
bobmobile: encryptedPayload,
},
"@carol:example.org": {
caroldesktop: encryptedPayload,
},
},
})
.respond(200, {});

await Promise.all([
client.client.encryptAndSendToDevices(
[
{ userId: "@bob:example.org", deviceInfo: new DeviceInfo("bobweb") },
{ userId: "@bob:example.org", deviceInfo: new DeviceInfo("bobmobile") },
{ userId: "@carol:example.org", deviceInfo: new DeviceInfo("caroldesktop") },
],
payload,
),
client.httpBackend.flushAllExpected(),
]);
});

it("sends nothing to devices that couldn't be encrypted to", async () => {
encryptMessageForDevice.mockImplementation(async (...[result,,,, userId, device, payload]) => {
// Refuse to encrypt to Carol's desktop device
if (userId === "@carol:example.org" && device.deviceId === "caroldesktop") return;
result.plaintext = JSON.stringify(payload);
});

client.httpBackend
.when("PUT", "/sendToDevice/m.room.encrypted", {
// Carol is nowhere to be seen
messages: { "@bob:example.org": { bobweb: encryptedPayload } },
})
.respond(200, {});

await Promise.all([
client.client.encryptAndSendToDevices(
[
{ userId: "@bob:example.org", deviceInfo: new DeviceInfo("bobweb") },
{ userId: "@carol:example.org", deviceInfo: new DeviceInfo("caroldesktop") },
],
payload,
),
client.httpBackend.flushAllExpected(),
]);
});

it("no-ops if no devices can be encrypted to", async () => {
// Refuse to encrypt to anybody
encryptMessageForDevice.mockResolvedValue(undefined);

// Get the room keys version request out of the way
client.httpBackend.when("GET", "/room_keys/version").respond(404, {});
await client.httpBackend.flush("/room_keys/version", 1);

await client.client.encryptAndSendToDevices(
[{ userId: "@bob:example.org", deviceInfo: new DeviceInfo("bobweb") }],
payload,
);
client.httpBackend.verifyNoOutstandingRequests();
});
});
});
10 changes: 10 additions & 0 deletions spec/unit/crypto/algorithms/megolm.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,16 @@ describe("MegolmDecryption", function() {
rotation_period_ms: rotationPeriodMs,
},
});

// Splice the real method onto the mock object as megolm uses this method
// on the crypto class in order to encrypt / start sessions
// @ts-ignore Mock
mockCrypto.encryptAndSendToDevices = Crypto.prototype.encryptAndSendToDevices;
// @ts-ignore Mock
mockCrypto.olmDevice = olmDevice;
// @ts-ignore Mock
mockCrypto.baseApis = mockBaseApis;

mockRoom = {
getEncryptionTargetMembers: jest.fn().mockReturnValue(
[{ userId: "@alice:home.server" }],
Expand Down
16 changes: 16 additions & 0 deletions spec/unit/matrix-client.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
UNSTABLE_MSC3089_TREE_SUBTYPE,
} from "../../src/@types/event";
import { MEGOLM_ALGORITHM } from "../../src/crypto/olmlib";
import { Crypto } from "../../src/crypto";
import { EventStatus, MatrixEvent } from "../../src/models/event";
import { Preset } from "../../src/@types/partials";
import { ReceiptType } from "../../src/@types/read_receipts";
Expand Down Expand Up @@ -1297,4 +1298,19 @@ describe("MatrixClient", function() {
expect(result!.aliases).toEqual(response.aliases);
});
});

describe("encryptAndSendToDevices", () => {
it("throws an error if crypto is unavailable", () => {
client.crypto = undefined;
expect(() => client.encryptAndSendToDevices([], {})).toThrow();
});

it("is an alias for the crypto method", async () => {
client.crypto = testUtils.mock(Crypto, "Crypto");
const deviceInfos = [];
const payload = {};
await client.encryptAndSendToDevices(deviceInfos, payload);
expect(client.crypto.encryptAndSendToDevices).toHaveBeenLastCalledWith(deviceInfos, payload);
});
});
});
32 changes: 29 additions & 3 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ import { sleep } from './utils';
import { Direction, EventTimeline } from "./models/event-timeline";
import { IActionsObject, PushProcessor } from "./pushprocessor";
import { AutoDiscovery, AutoDiscoveryAction } from "./autodiscovery";
import { IEncryptAndSendToDevicesResult } from "./crypto";
import * as olmlib from "./crypto/olmlib";
import { decodeBase64, encodeBase64 } from "./crypto/olmlib";
import { IExportedDevice as IOlmDevice } from "./crypto/OlmDevice";
import { IExportedDevice as IExportedOlmDevice } from "./crypto/OlmDevice";
import { IOlmDevice } from "./crypto/algorithms/megolm";
import { TypedReEmitter } from './ReEmitter';
import { IRoomEncryption, RoomList } from './crypto/RoomList';
import { logger } from './logger';
Expand Down Expand Up @@ -208,7 +210,7 @@ const CAPABILITIES_CACHE_MS = 21600000; // 6 hours - an arbitrary value
const TURN_CHECK_INTERVAL = 10 * 60 * 1000; // poll for turn credentials every 10 minutes

interface IExportedDevice {
olmDevice: IOlmDevice;
olmDevice: IExportedOlmDevice;
userId: string;
deviceId: string;
}
Expand Down Expand Up @@ -936,7 +938,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
protected turnServers: ITurnServer[] = [];
protected turnServersExpiry = 0;
protected checkTurnServersIntervalID: ReturnType<typeof setInterval>;
protected exportedOlmDeviceToImport: IOlmDevice;
protected exportedOlmDeviceToImport: IExportedOlmDevice;
protected txnCtr = 0;
protected mediaHandler = new MediaHandler(this);
protected pendingEventEncryption = new Map<string, Promise<void>>();
Expand Down Expand Up @@ -2558,6 +2560,30 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
return this.roomList.isRoomEncrypted(roomId);
}

/**
* Encrypts and sends a given object via Olm to-device messages to a given
* set of devices.
*
* @param {object[]} userDeviceInfoArr
* mapping from userId to deviceInfo
*
* @param {object} payload fields to include in the encrypted payload
* *
* @return {Promise<{contentMap, deviceInfoByDeviceId}>} Promise which
* resolves once the message has been encrypted and sent to the given
* userDeviceMap, and returns the { contentMap, deviceInfoByDeviceId }
* of the successfully sent messages.
*/
public encryptAndSendToDevices(
userDeviceInfoArr: IOlmDevice<DeviceInfo>[],
payload: object,
): Promise<IEncryptAndSendToDevicesResult> {
if (!this.crypto) {
throw new Error("End-to-End encryption disabled");
}
return this.crypto.encryptAndSendToDevices(userDeviceInfoArr, payload);
}

/**
* Forces the current outbound group session to be discarded such
* that another one will be created next time an event is sent.
Expand Down
Loading