Skip to content

Commit

Permalink
Fix BAD_MESSAGE_FORMAT issue with e2ee web notifs
Browse files Browse the repository at this point in the history
Summary: This differential fixes this [bug](https://linear.app/comm/issue/ENG-9266/olmbad-message-format-notification-on-web). All decryption methods take message type as argument. The only exception are the top level methods that decrypt notifs from the keyserver

Test Plan: Execute steps [here](https://linear.app/comm/issue/ENG-9266/olmbad-message-format-notification-on-web#comment-6a15ca6c). Ensure the issue does not reproduce.

Reviewers: tomek, kamil, ashoat

Reviewed By: tomek, kamil

Differential Revision: https://phab.comm.dev/D13353
  • Loading branch information
marcinwasowicz committed Sep 18, 2024
1 parent 189c1fd commit 819abce
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
7 changes: 6 additions & 1 deletion lib/types/crypto-types.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// @flow

import t, { type TInterface } from 'tcomb';
import t, { type TInterface, type TEnums } from 'tcomb';

import type { OlmSessionInitializationInfo } from './olm-session-types.js';
import { type AuthMetadata } from '../shared/identity-client-context.js';
import { values } from '../utils/objects.js';
import { tShape } from '../utils/validation-utils.js';

export type OLMIdentityKeys = {
Expand Down Expand Up @@ -114,6 +115,10 @@ export const olmEncryptedMessageTypes = Object.freeze({

export type OlmEncryptedMessageTypes = $Values<typeof olmEncryptedMessageTypes>;

export const olmEncryptedMessageTypesValidator: TEnums = t.enums.of(
values(olmEncryptedMessageTypes),
);

export type EncryptedData = {
+message: string,
+messageType: OlmEncryptedMessageTypes,
Expand Down
55 changes: 33 additions & 22 deletions web/push-notif/notif-crypto-utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import {
olmEncryptedMessageTypes,
type NotificationsOlmDataType,
type PickledOLMAccount,
type OlmEncryptedMessageTypes,
} from 'lib/types/crypto-types.js';
import { olmEncryptedMessageTypesValidator } from 'lib/types/crypto-types.js';
import type {
PlainTextWebNotification,
EncryptedWebNotification,
Expand Down Expand Up @@ -94,6 +96,15 @@ const INDEXED_DB_UNREAD_THICK_THREAD_IDS_ENCRYPTION_KEY_DB_LABEL =
'unreadThickThreadIDsEncryptionKey';
const INDEXED_DB_UNREAD_THICK_THREADS_SYNC_KEY = 'unreadThickThreadIDsSyncKey';

function stringToOlmEncryptedMessageType(
messageType: string,
): OlmEncryptedMessageTypes {
const messageTypeAsNumber = Number(messageType);
return assertWithValidator(
messageTypeAsNumber,
olmEncryptedMessageTypesValidator,
);
}
async function deserializeEncryptedData<T>(
encryptedData: EncryptedData,
encryptionKey: CryptoKey,
Expand Down Expand Up @@ -346,11 +357,11 @@ async function decryptWebNotification(
const {
id,
encryptedPayload,
type: messageType,
type: rawMessageType,
...rest
} = encryptedNotification;
const senderDeviceDescriptor: SenderDeviceDescriptor = rest;

const messageType = stringToOlmEncryptedMessageType(rawMessageType);
const utilsData = await localforage.getItem<WebNotifsServiceUtilsData>(
WEB_NOTIFS_SERVICE_UTILS_KEY,
);
Expand Down Expand Up @@ -414,6 +425,7 @@ async function decryptWebNotification(
} = await commonDecrypt<PlainTextWebNotification>(
notificationsOlmData,
encryptedPayload,
messageType,
);

decryptedNotification = resultDecryptedNotification;
Expand Down Expand Up @@ -495,12 +507,12 @@ async function decryptWebNotification(

async function decryptDesktopNotification(
encryptedPayload: string,
messageType: string,
rawMessageType: string,
staffCanSee: boolean,
senderDeviceDescriptor: SenderDeviceDescriptor,
): Promise<{ +[string]: mixed }> {
const { keyserverID, senderDeviceID } = senderDeviceDescriptor;

const messageType = stringToOlmEncryptedMessageType(rawMessageType);
let notifsAccountWithOlmData;
try {
[notifsAccountWithOlmData] = await Promise.all([
Expand Down Expand Up @@ -544,7 +556,7 @@ async function decryptDesktopNotification(

const { decryptedNotification, updatedOlmData } = await commonDecrypt<{
+[string]: mixed,
}>(notificationsOlmData, encryptedPayload);
}>(notificationsOlmData, encryptedPayload, olmEncryptedMessageTypes.TEXT);

const updatedOlmDataPersistencePromise = persistNotifsAccountWithOlmData({
olmDataKey,
Expand Down Expand Up @@ -634,6 +646,7 @@ async function decryptDesktopNotification(
async function commonDecrypt<T>(
notificationsOlmData: NotificationsOlmDataType,
encryptedPayload: string,
type: OlmEncryptedMessageTypes,
): Promise<{
+decryptedNotification: T,
+updatedOlmData: NotificationsOlmDataType,
Expand All @@ -655,6 +668,7 @@ async function commonDecrypt<T>(
pendingSessionUpdate,
picklingKey,
encryptedPayload,
type,
);

if (decryptionWithPendingSessionResult.decryptedNotification) {
Expand All @@ -675,7 +689,7 @@ async function commonDecrypt<T>(
const {
newUpdateCreationTimestamp,
decryptedNotification: notifDecryptedWithMainSession,
} = decryptWithSession<T>(mainSession, picklingKey, encryptedPayload);
} = decryptWithSession<T>(mainSession, picklingKey, encryptedPayload, type);

decryptedNotification = notifDecryptedWithMainSession;
updatedOlmData = {
Expand All @@ -693,22 +707,13 @@ async function commonPeerDecrypt<T>(
senderDeviceID: string,
notificationsOlmData: ?NotificationsOlmDataType,
notificationAccount: PickledOLMAccount,
messageType: string,
messageType: OlmEncryptedMessageTypes,
encryptedPayload: string,
): Promise<{
+decryptedNotification: T,
+updatedOlmData?: NotificationsOlmDataType,
+updatedNotifsAccount?: PickledOLMAccount,
}> {
if (
messageType !== olmEncryptedMessageTypes.PREKEY.toString() &&
messageType !== olmEncryptedMessageTypes.TEXT.toString()
) {
throw new Error(
`Received message of invalid type from device: ${senderDeviceID}`,
);
}

let isSenderChainEmpty = true;
let hasReceivedMessage = false;
const sessionExists = !!notificationsOlmData;
Expand All @@ -726,17 +731,20 @@ async function commonPeerDecrypt<T>(

// regular message
const isRegularMessage =
!!notificationsOlmData &&
messageType === olmEncryptedMessageTypes.TEXT.toString();
!!notificationsOlmData && messageType === olmEncryptedMessageTypes.TEXT;

const isRegularPrekeyMessage =
!!notificationsOlmData &&
messageType === olmEncryptedMessageTypes.PREKEY.toString() &&
messageType === olmEncryptedMessageTypes.PREKEY &&
isSenderChainEmpty &&
hasReceivedMessage;

if (!!notificationsOlmData && (isRegularMessage || isRegularPrekeyMessage)) {
return await commonDecrypt<T>(notificationsOlmData, encryptedPayload);
return await commonDecrypt<T>(
notificationsOlmData,
encryptedPayload,
messageType,
);
}

// At this point we either face race condition or session reset attempt or
Expand Down Expand Up @@ -772,7 +780,7 @@ async function commonPeerDecrypt<T>(
);

const decryptedNotification: T = JSON.parse(
session.decrypt(Number(messageType), encryptedPayload),
session.decrypt(messageType, encryptedPayload),
);

// session reset attempt or session initialization - handled the same
Expand Down Expand Up @@ -820,12 +828,13 @@ function decryptWithSession<T>(
pickledSession: string,
picklingKey: string,
encryptedPayload: string,
type: OlmEncryptedMessageTypes,
): DecryptionResult<T> {
const session = new olm.Session();

session.unpickle(picklingKey, pickledSession);
const decryptedNotification: T = JSON.parse(
session.decrypt(olmEncryptedMessageTypes.TEXT, encryptedPayload),
session.decrypt(type, encryptedPayload),
);

const newPendingSessionUpdate = session.pickle(picklingKey);
Expand All @@ -842,6 +851,7 @@ function decryptWithPendingSession<T>(
pendingSessionUpdate: string,
picklingKey: string,
encryptedPayload: string,
type: OlmEncryptedMessageTypes,
): DecryptionResult<T> | { +error: string } {
try {
const {
Expand All @@ -852,6 +862,7 @@ function decryptWithPendingSession<T>(
pendingSessionUpdate,
picklingKey,
encryptedPayload,
type,
);
return {
newPendingSessionUpdate,
Expand Down

0 comments on commit 819abce

Please sign in to comment.