Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: ensure the SessionStore is cleared when regenerating the OlmMachine #3338

Merged
merged 5 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
322 changes: 321 additions & 1 deletion crates/matrix-sdk-base/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ impl BaseClient {
tracing::debug!("regenerating OlmMachine");
let session_meta = self.session_meta().ok_or(Error::OlmError(OlmError::MissingSession))?;

// Recreate it.
// Recreate the `OlmMachine` and wipe the in-memory cache in the store
// because we suspect it has stale data.
self.crypto_store.clear_caches().await;
let olm_machine = OlmMachine::with_store(
&session_meta.user_id,
&session_meta.device_id,
Expand Down Expand Up @@ -1695,6 +1697,36 @@ mod tests {
client.get_room(room_id).expect("Just-created room not found!")
}

#[cfg(feature = "e2e-encryption")]
#[async_test]
async fn test_regerating_olm_clears_store_caches() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this test is really useful? This is using a fake crypto store impl, and just checking that a method has been called, which is better tested in the other test. How do y'all feel about removing this test instead and all the bootstrapping that it requires, which is a bit noisey to me?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agreed as much in my final review, it is a bit useful since it's the only place which checks that the method gets called.

Porting the complement crypto test is still in the pinkie promise so that might get rid of this noisy code.

// See https://github.com/matrix-org/matrix-rust-sdk/issues/3110
// We must clear the store cache when we regenerate the OlmMachine
// to ensure we really get the new state.

use ruma::{owned_device_id, owned_user_id};

use crate::store::StoreConfig;

// Given a client using a fake store
let user_id = owned_user_id!("@u:m.o");
let device_id = owned_device_id!("DEVICE");
let fake_store = fake_crypto_store::FakeCryptoStore::default();
let store_config = StoreConfig::new().crypto_store(fake_store.clone());
let client = BaseClient::with_store_config(store_config);
client.set_session_meta(SessionMeta { user_id, device_id }).await.unwrap();
fake_store.clear_method_calls();

// When we regenerate the OlmMachine
client.regenerate_olm().await.expect("Failed to regenerate olm");

// Then we cleared the store cache
assert!(
fake_store.method_calls().contains(&"clear_caches".to_owned()),
"No clear_caches call!"
);
}

#[async_test]
async fn test_deserialization_failure() {
let user_id = user_id!("@alice:example.org");
Expand Down Expand Up @@ -1884,4 +1916,292 @@ mod tests {
assert_eq!(member.display_name().unwrap(), "Invited Alice");
assert_eq!(member.avatar_url().unwrap().to_string(), "mxc://localhost/fewjilfewjil42");
}

#[cfg(feature = "e2e-encryption")]
mod fake_crypto_store {
use std::{
collections::HashMap,
convert::Infallible,
sync::{Arc, Mutex},
};

use async_trait::async_trait;
use matrix_sdk_crypto::{
olm::{InboundGroupSession, OutboundGroupSession, PrivateCrossSigningIdentity},
store::{
BackupKeys, Changes, CryptoStore, PendingChanges, RoomKeyCounts, RoomSettings,
},
types::events::room_key_withheld::RoomKeyWithheldEvent,
Account, GossipRequest, GossippedSecret, ReadOnlyDevice, ReadOnlyUserIdentities,
SecretInfo, Session, TrackedUser,
};
use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId, TransactionId,
UserId,
};

#[derive(Clone, Debug, Default)]
pub(crate) struct FakeCryptoStore {
pub method_calls: Arc<Mutex<Vec<String>>>,
}

impl FakeCryptoStore {
pub fn method_calls(&self) -> Vec<String> {
self.method_calls.lock().unwrap().clone()
}

pub fn clear_method_calls(&self) {
self.method_calls.lock().unwrap().clear();
}

fn call(&self, method_name: &str) {
self.method_calls.lock().unwrap().push(method_name.to_owned());
}
}

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl CryptoStore for FakeCryptoStore {
type Error = Infallible;

async fn clear_caches(&self) {
self.call("clear_caches");
}

async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
self.call("load_account");
Ok(None)
}

async fn load_identity(
&self,
) -> Result<Option<PrivateCrossSigningIdentity>, Self::Error> {
self.call("load_identity");
Ok(None)
}

async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
self.call("next_batch_token");
Ok(None)
}

async fn save_pending_changes(
&self,
_changes: PendingChanges,
) -> Result<(), Self::Error> {
self.call("save_pending_changes");
Ok(())
}

async fn save_changes(&self, _changes: Changes) -> Result<(), Self::Error> {
self.call("save_changes");
Ok(())
}

async fn get_sessions(
&self,
_sender_key: &str,
) -> Result<Option<Arc<tokio::sync::Mutex<Vec<Session>>>>, Self::Error> {
self.call("get_sessions");
Ok(None)
}

async fn get_inbound_group_session(
&self,
_room_id: &RoomId,
_session_id: &str,
) -> Result<Option<InboundGroupSession>, Self::Error> {
self.call("get_inbound_group_session");
Ok(None)
}

async fn get_withheld_info(
&self,
_room_id: &RoomId,
_session_id: &str,
) -> Result<Option<RoomKeyWithheldEvent>, Self::Error> {
self.call("get_withheld_info");
Ok(None)
}

async fn get_inbound_group_sessions(
&self,
) -> Result<Vec<InboundGroupSession>, Self::Error> {
self.call("get_inbound_group_sessions");
Ok(vec![])
}

async fn inbound_group_session_counts(
&self,
_backup_version: Option<&str>,
) -> Result<RoomKeyCounts, Self::Error> {
self.call("inbound_group_session_counts");
Ok(RoomKeyCounts { total: 0, backed_up: 0 })
}

async fn inbound_group_sessions_for_backup(
&self,
_backup_version: &str,
_limit: usize,
) -> Result<Vec<InboundGroupSession>, Self::Error> {
self.call("inbound_group_sessions_for_backup");
Ok(vec![])
}

async fn mark_inbound_group_sessions_as_backed_up(
&self,
_backup_version: &str,
_room_and_session_ids: &[(&RoomId, &str)],
) -> Result<(), Self::Error> {
self.call("mark_inbound_group_sessions_as_backed_up");
Ok(())
}

async fn reset_backup_state(&self) -> Result<(), Self::Error> {
self.call("reset_backup_state");
Ok(())
}

async fn load_backup_keys(&self) -> Result<BackupKeys, Self::Error> {
self.call("load_backup_keys");
Ok(BackupKeys::default())
}

async fn get_outbound_group_session(
&self,
_room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>, Self::Error> {
self.call("get_outbound_group_session");
Ok(None)
}

async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error> {
self.call("load_tracked_users");
Ok(vec![])
}

async fn save_tracked_users(
&self,
_tracked_users: &[(&UserId, bool)],
) -> Result<(), Self::Error> {
self.call("save_tracked_users");
Ok(())
}

async fn get_device(
&self,
_user_id: &UserId,
_device_id: &DeviceId,
) -> Result<Option<ReadOnlyDevice>, Self::Error> {
self.call("get_device");
Ok(None)
}

async fn get_user_devices(
&self,
_user_id: &UserId,
) -> Result<HashMap<OwnedDeviceId, ReadOnlyDevice>, Self::Error> {
self.call("get_user_devices");
Ok(HashMap::default())
}

async fn get_user_identity(
&self,
_user_id: &UserId,
) -> Result<Option<ReadOnlyUserIdentities>, Self::Error> {
self.call("get_user_identity");
Ok(None)
}

async fn is_message_known(
&self,
_message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
) -> Result<bool, Self::Error> {
self.call("is_message_known");
Ok(false)
}

async fn get_outgoing_secret_requests(
&self,
_request_id: &TransactionId,
) -> Result<Option<GossipRequest>, Self::Error> {
self.call("get_outgoing_secret_requests");
Ok(None)
}

async fn get_secret_request_by_info(
&self,
_key_info: &SecretInfo,
) -> Result<Option<GossipRequest>, Self::Error> {
self.call("get_secret_request_by_info");
Ok(None)
}

async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>, Self::Error> {
self.call("get_unsent_secret_requests");
Ok(vec![])
}

async fn delete_outgoing_secret_requests(
&self,
_request_id: &TransactionId,
) -> Result<(), Self::Error> {
self.call("delete_outgoing_secret_requests");
Ok(())
}

async fn get_secrets_from_inbox(
&self,
_secret_name: &SecretName,
) -> Result<Vec<GossippedSecret>, Self::Error> {
self.call("get_secrets_from_inbox");
Ok(vec![])
}

async fn delete_secrets_from_inbox(
&self,
_secret_name: &SecretName,
) -> Result<(), Self::Error> {
self.call("delete_secrets_from_inbox");
Ok(())
}

async fn get_room_settings(
&self,
_room_id: &RoomId,
) -> Result<Option<RoomSettings>, Self::Error> {
self.call("get_room_settings");
Ok(None)
}

async fn get_custom_value(&self, _key: &str) -> Result<Option<Vec<u8>>, Self::Error> {
self.call("get_custom_value");
Ok(None)
}

async fn set_custom_value(
&self,
_key: &str,
_value: Vec<u8>,
) -> Result<(), Self::Error> {
self.call("set_custom_value");
Ok(())
}

async fn remove_custom_value(&self, _key: &str) -> Result<(), Self::Error> {
self.call("remove_custom_value");
Ok(())
}

async fn try_take_leased_lock(
&self,
_lease_duration_ms: u32,
_key: &str,
_holder: &str,
) -> Result<bool, Self::Error> {
self.call("try_take_leased_lock");
Ok(true)
}
}
}
}
3 changes: 3 additions & 0 deletions crates/matrix-sdk-crypto/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Breaking changes:

Additions:

- Expose new method `CryptoStore::clear_caches`.
([#3338](https://github.com/matrix-org/matrix-rust-sdk/pull/3338))

- Expose new method `OlmMachine::device_creation_time`.
([#3275](https://github.com/matrix-org/matrix-rust-sdk/pull/3275))

Expand Down
7 changes: 7 additions & 0 deletions crates/matrix-sdk-crypto/src/store/caches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ impl SessionStore {
Self::default()
}

/// Clear all entries in the session store.
///
/// This is intended to be used when regenerating olm machines.
pub fn clear(&self) {
self.entries.write().unwrap().clear()
}

/// Add a session to the store.
///
/// Returns true if the session was added, false if the session was
Expand Down
2 changes: 1 addition & 1 deletion crates/matrix-sdk-crypto/src/store/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ macro_rules! cryptostore_integration_tests {
Account::with_device_id(alice_id(), alice_device_id())
}

async fn get_account_and_session() -> (Account, Session) {
pub(crate) async fn get_account_and_session() -> (Account, Session) {
let alice = Account::with_device_id(alice_id(), alice_device_id());
let mut bob = Account::with_device_id(bob_id(), bob_device_id());

Expand Down
12 changes: 12 additions & 0 deletions crates/matrix-sdk-crypto/src/store/memorystore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ type Result<T> = std::result::Result<T, Infallible>;
impl CryptoStore for MemoryStore {
type Error = Infallible;

async fn clear_caches(&self) {
// no-op: it makes no sense to delete fields here as we would forget our
// identity, etc Effectively we have no caches as the fields
// *are* the underlying store. Calling this method only makes
// sense if there is some other layer (e.g disk) persistence
// happening.
}

async fn load_account(&self) -> Result<Option<Account>> {
Ok(self.account.read().unwrap().as_ref().map(|acc| acc.deep_clone()))
}
Expand Down Expand Up @@ -718,6 +726,10 @@ mod integration_tests {
impl CryptoStore for PersistentMemoryStore {
type Error = <MemoryStore as CryptoStore>::Error;

async fn clear_caches(&self) {
self.0.clear_caches().await
}

async fn load_account(&self) -> Result<Option<Account>, Self::Error> {
self.0.load_account().await
}
Expand Down
Loading
Loading