Skip to content

Commit

Permalink
Make sure we use the provided CSPRNG everywhere (#342)
Browse files Browse the repository at this point in the history
This includes a revert of commit b0fcd8b to use a different approach.
  • Loading branch information
gferon authored Nov 5, 2024
1 parent f4313db commit 6401dc3
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 187 deletions.
76 changes: 39 additions & 37 deletions src/account_manager.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use base64::prelude::*;
use phonenumber::PhoneNumber;
use rand::Rng;
use rand::{CryptoRng, Rng};
use reqwest::Method;
use std::collections::HashMap;
use std::convert::{TryFrom, TryInto};
Expand Down Expand Up @@ -88,16 +88,13 @@ impl AccountManager {
///
/// Equivalent to Java's RefreshPreKeysJob
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip(self, protocol_store, csprng))]
pub async fn update_pre_key_bundle<
R: rand::Rng + rand::CryptoRng,
P: PreKeysStore,
>(
#[tracing::instrument(skip(self, csprng, protocol_store))]
pub async fn update_pre_key_bundle<R: Rng + CryptoRng, P: PreKeysStore>(
&mut self,
protocol_store: &mut P,
service_id_kind: ServiceIdKind,
csprng: &mut R,
use_last_resort_key: bool,
csprng: &mut R,
) -> Result<(), ServiceError> {
let prekey_status = match self
.service
Expand Down Expand Up @@ -154,8 +151,8 @@ impl AccountManager {
let (pre_keys, signed_pre_key, pq_pre_keys, pq_last_resort_key) =
crate::pre_keys::replenish_pre_keys(
protocol_store,
&identity_key_pair,
csprng,
&identity_key_pair,
use_last_resort_key && !has_last_resort_key,
PRE_KEY_BATCH_SIZE,
PRE_KEY_BATCH_SIZE,
Expand Down Expand Up @@ -281,8 +278,9 @@ impl AccountManager {
/// ```java
/// TextSecurePreferences.setIsUnidentifiedDeliveryEnabled(context, false);
/// ```
pub async fn link_device(
pub async fn link_device<R: Rng + CryptoRng>(
&mut self,
csprng: &mut R,
url: url::Url,
aci_identity_store: &dyn IdentityKeyStore,
pni_identity_store: &dyn IdentityKeyStore,
Expand Down Expand Up @@ -344,7 +342,7 @@ impl AccountManager {

let cipher = ProvisioningCipher::from_public(pub_key);

let encrypted = cipher.encrypt(msg)?;
let encrypted = cipher.encrypt(csprng, msg)?;
self.send_provisioning_message(ephemeral_id, encrypted)
.await?;
Ok(())
Expand Down Expand Up @@ -380,7 +378,7 @@ impl AccountManager {
}

pub async fn register_account<
R: rand::Rng + rand::CryptoRng,
R: Rng + CryptoRng,
Aci: PreKeysStore + IdentityKeyStore,
Pni: PreKeysStore + IdentityKeyStore,
>(
Expand Down Expand Up @@ -408,8 +406,8 @@ impl AccountManager {
aci_last_resort_kyber_prekey,
) = crate::pre_keys::replenish_pre_keys(
aci_protocol_store,
&aci_identity_key_pair,
csprng,
&aci_identity_key_pair,
true,
0,
0,
Expand All @@ -423,8 +421,8 @@ impl AccountManager {
pni_last_resort_kyber_prekey,
) = crate::pre_keys::replenish_pre_keys(
pni_protocol_store,
&pni_identity_key_pair,
csprng,
&pni_identity_key_pair,
true,
0,
0,
Expand Down Expand Up @@ -470,15 +468,19 @@ impl AccountManager {
/// ```
/// in which the `retain_avatar` parameter sets whether to remove (`false`) or retain (`true`) the
/// currently set avatar.
pub async fn upload_versioned_profile_without_avatar<S: AsRef<str>>(
pub async fn upload_versioned_profile_without_avatar<
R: Rng + CryptoRng,
S: AsRef<str>,
>(
&mut self,
aci: libsignal_protocol::Aci,
name: ProfileName<S>,
about: Option<String>,
about_emoji: Option<String>,
retain_avatar: bool,
csprng: &mut R,
) -> Result<(), ProfileManagerError> {
self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _>(
self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _, _>(
aci,
name,
about,
Expand All @@ -488,6 +490,7 @@ impl AccountManager {
} else {
AvatarWrite::NoAvatar
},
csprng,
)
.await?;
Ok(())
Expand All @@ -505,8 +508,8 @@ impl AccountManager {
.retrieve_profile_by_id(address, Some(profile_key))
.await?;

let profile_cipher = ProfileCipher::from(profile_key);
Ok(encrypted_profile.decrypt(profile_cipher)?)
let profile_cipher = ProfileCipher::new(profile_key);
Ok(profile_cipher.decrypt(encrypted_profile)?)
}

/// Upload a profile
Expand All @@ -517,6 +520,7 @@ impl AccountManager {
pub async fn upload_versioned_profile<
's,
C: std::io::Read + Send + 's,
R: Rng + CryptoRng,
S: AsRef<str>,
>(
&mut self,
Expand All @@ -525,17 +529,18 @@ impl AccountManager {
about: Option<String>,
about_emoji: Option<String>,
avatar: AvatarWrite<&'s mut C>,
csprng: &mut R,
) -> Result<Option<String>, ProfileManagerError> {
let profile_key =
self.profile_key.expect("set profile key in AccountManager");
let profile_cipher = ProfileCipher::from(profile_key);
let profile_cipher = ProfileCipher::new(profile_key);

// Profile encryption
let name = profile_cipher.encrypt_name(name.as_ref())?;
let name = profile_cipher.encrypt_name(name.as_ref(), csprng)?;
let about = about.unwrap_or_default();
let about = profile_cipher.encrypt_about(about)?;
let about = profile_cipher.encrypt_about(about, csprng)?;
let about_emoji = about_emoji.unwrap_or_default();
let about_emoji = profile_cipher.encrypt_emoji(about_emoji)?;
let about_emoji = profile_cipher.encrypt_emoji(about_emoji, csprng)?;

// If avatar -> upload
if matches!(avatar, AvatarWrite::NewAvatar(_)) {
Expand Down Expand Up @@ -572,16 +577,14 @@ impl AccountManager {
}

/// Update (encrypted) device name
pub async fn update_device_name(
pub async fn update_device_name<R: Rng + CryptoRng>(
&mut self,
device_name: &str,
public_key: &IdentityKey,
csprng: &mut R,
) -> Result<(), ServiceError> {
let encrypted_device_name = encrypt_device_name(
&mut rand::thread_rng(),
device_name,
public_key,
)?;
let encrypted_device_name =
encrypt_device_name(csprng, device_name, public_key)?;

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
Expand Down Expand Up @@ -642,22 +645,21 @@ impl AccountManager {
/// Should be called as the primary device to migrate from pre-PNI to PNI.
///
/// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager.
#[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci), fields(local_aci = %local_aci.service_id_string()))]
#[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = local_aci.service_id_string()))]
pub async fn pnp_initialize_devices<
// XXX So many constraints here, all imposed by the MessageSender
R: rand::Rng + rand::CryptoRng,
R: Rng + CryptoRng,
AciStore: PreKeysStore + SessionStoreExt,
PniStore: PreKeysStore,
AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
>(
&mut self,
aci_protocol_store: &mut AciStore,
pni_protocol_store: &mut PniStore,
mut sender: MessageSender<AciOrPni>,
mut sender: MessageSender<AciOrPni, R>,
local_aci: Aci,
e164: PhoneNumber,
csprng: &mut R,
) -> Result<(), MessageSenderError> {
let mut csprng = rand::thread_rng();
let pni_identity_key_pair =
pni_protocol_store.get_identity_key_pair().await?;

Expand Down Expand Up @@ -713,21 +715,21 @@ impl AccountManager {
) = if local_device_id == DEFAULT_DEVICE_ID {
crate::pre_keys::replenish_pre_keys(
pni_protocol_store,
csprng,
&pni_identity_key_pair,
&mut csprng,
true,
0,
0,
)
.await?
} else {
// Generate a signed prekey
let signed_pre_key_pair = KeyPair::generate(&mut csprng);
let signed_pre_key_pair = KeyPair::generate(csprng);
let signed_pre_key_public = signed_pre_key_pair.public_key;
let signed_pre_key_signature =
pni_identity_key_pair.private_key().calculate_signature(
&signed_pre_key_public.serialize(),
&mut csprng,
csprng,
)?;

let signed_prekey_record = SignedPreKeyRecord::new(
Expand Down Expand Up @@ -755,7 +757,7 @@ impl AccountManager {
pni_protocol_store.get_local_registration_id().await?
} else {
loop {
let regid = generate_registration_id(&mut csprng);
let regid = generate_registration_id(csprng);
if !pni_registration_ids.iter().any(|(_k, v)| *v == regid) {
break regid;
}
Expand Down Expand Up @@ -803,7 +805,7 @@ impl AccountManager {
e164.format().mode(phonenumber::Mode::E164).to_string(),
),
}),
padding: Some(random_length_padding(&mut csprng, 512)),
padding: Some(random_length_padding(csprng, 512)),
..SyncMessage::default()
};
let content: ContentBody = msg.into();
Expand Down
25 changes: 14 additions & 11 deletions src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use libsignal_protocol::{
SignalMessage, SignalProtocolError, SignedPreKeyStore, Timestamp,
};
use prost::Message;
use rand::{CryptoRng, Rng};
use uuid::Uuid;

use crate::{
Expand All @@ -39,7 +40,6 @@ impl<S> fmt::Debug for ServiceCipher<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ServiceCipher")
.field("protocol_store", &"...")
.field("csprng", &"...")
.field("trust_root", &"...")
.field("local_uuid", &self.local_uuid)
.field("local_device_id", &self.local_device_id)
Expand Down Expand Up @@ -89,13 +89,14 @@ where
/// Opens ("decrypts") an envelope.
///
/// Envelopes may be empty, in which case this method returns `Ok(None)`
#[tracing::instrument(skip(envelope), fields(envelope = debug_envelope(&envelope)))]
pub async fn open_envelope(
#[tracing::instrument(skip(envelope, csprng), fields(envelope = debug_envelope(&envelope)))]
pub async fn open_envelope<R: Rng + CryptoRng>(
&mut self,
envelope: Envelope,
csprng: &mut R,
) -> Result<Option<Content>, ServiceError> {
if envelope.content.is_some() {
let plaintext = self.decrypt(&envelope).await?;
let plaintext = self.decrypt(&envelope, csprng).await?;
let message =
crate::proto::Content::decode(plaintext.data.as_slice())?;
if let Some(bytes) = message.sender_key_distribution_message {
Expand All @@ -121,10 +122,11 @@ where
/// Triage of legacy messages happens inside this method, as opposed to the
/// Java implementation, because it makes the borrow checker and the
/// author happier.
#[tracing::instrument(skip(envelope), fields(envelope = debug_envelope(envelope)))]
async fn decrypt(
#[tracing::instrument(skip(envelope, csprng), fields(envelope = debug_envelope(envelope)))]
async fn decrypt<R: Rng + CryptoRng>(
&mut self,
envelope: &Envelope,
csprng: &mut R,
) -> Result<Plaintext, ServiceError> {
let ciphertext = if let Some(msg) = envelope.content.as_ref() {
msg
Expand Down Expand Up @@ -175,7 +177,7 @@ where
&mut self.protocol_store.clone(),
&self.protocol_store.clone(),
&mut self.protocol_store.clone(),
&mut rand::thread_rng(),
csprng,
)
.await?
.as_slice()
Expand Down Expand Up @@ -231,7 +233,7 @@ where
&sender,
&mut self.protocol_store.clone(),
&mut self.protocol_store.clone(),
&mut rand::thread_rng(),
csprng,
)
.await?
.as_slice()
Expand Down Expand Up @@ -317,18 +319,19 @@ where
}

#[tracing::instrument(
skip(address, unidentified_access, content),
skip(address, unidentified_access, content, csprng),
fields(
address = %address,
with_unidentified_access = unidentified_access.is_some(),
content_length = content.len(),
)
)]
pub(crate) async fn encrypt(
pub(crate) async fn encrypt<R: Rng + CryptoRng>(
&mut self,
address: &ProtocolAddress,
unidentified_access: Option<&SenderCertificate>,
content: &[u8],
csprng: &mut R,
) -> Result<OutgoingPushMessage, ServiceError> {
let session_record = self
.protocol_store
Expand All @@ -352,7 +355,7 @@ where
&mut self.protocol_store.clone(),
&mut self.protocol_store,
SystemTime::now(),
&mut rand::thread_rng(),
csprng,
)
.await?;

Expand Down
Loading

0 comments on commit 6401dc3

Please sign in to comment.