Skip to content

Commit

Permalink
Improve custom proposal filtering (#205)
Browse files Browse the repository at this point in the history
* Filter invalid custom proposals before passing them to MlsRules

* Hard-fail sending invalid local proposals

* Fixup

* Bump version

---------

Co-authored-by: Marta Mularczyk <mulmarta@amazon.com>
  • Loading branch information
mulmarta and Marta Mularczyk authored Nov 6, 2024
1 parent 87c5dce commit df87607
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 45 deletions.
2 changes: 1 addition & 1 deletion mls-rs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mls-rs"
version = "0.42.0"
version = "0.42.1"
edition = "2021"
description = "An implementation of Messaging Layer Security (RFC 9420)"
homepage = "https://github.com/awslabs/mls-rs"
Expand Down
59 changes: 34 additions & 25 deletions mls-rs/src/group/proposal_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use alloc::vec::Vec;
use super::{
message_processor::ProvisionalState,
mls_rules::{CommitDirection, CommitSource, MlsRules},
proposal_filter::prepare_proposals_for_mls_rules,
GroupState, ProposalOrRef,
};
use crate::{
Expand All @@ -20,10 +21,7 @@ use crate::{

#[cfg(feature = "by_ref_proposal")]
use crate::{
group::{
message_hash::MessageHash, proposal_filter::FilterStrategy, ProposalMessageDescription,
ProposalRef, ProtocolVersion,
},
group::{message_hash::MessageHash, ProposalMessageDescription, ProposalRef, ProtocolVersion},
MlsMessage,
};

Expand Down Expand Up @@ -286,6 +284,8 @@ impl GroupState {
)),
}?;

prepare_proposals_for_mls_rules(&mut proposals, direction, &self.public_tree)?;

proposals = user_rules
.filter_proposals(direction, origin, &roster, group_extensions, proposals)
.await
Expand All @@ -304,18 +304,9 @@ impl GroupState {
);

#[cfg(feature = "by_ref_proposal")]
let applier_output = match direction {
CommitDirection::Send => {
applier
.apply_proposals(FilterStrategy::IgnoreByRef, &sender, proposals, commit_time)
.await?
}
CommitDirection::Receive => {
applier
.apply_proposals(FilterStrategy::IgnoreNone, &sender, proposals, commit_time)
.await?
}
};
let applier_output = applier
.apply_proposals(direction.into(), &sender, proposals, commit_time)
.await?;

#[cfg(not(feature = "by_ref_proposal"))]
let applier_output = applier
Expand Down Expand Up @@ -3973,7 +3964,7 @@ mod tests {
}

struct InjectMlsRules {
to_inject: Proposal,
to_inject: Vec<Proposal>,
source: ProposalSource,
}

Expand All @@ -3990,11 +3981,10 @@ mod tests {
_: &ExtensionList,
mut proposals: ProposalBundle,
) -> Result<ProposalBundle, Self::Error> {
proposals.add(
self.to_inject.clone(),
Sender::Member(0),
self.source.clone(),
);
for proposal in self.to_inject.iter().cloned() {
proposals.add(proposal, Sender::Member(0), self.source.clone());
}

Ok(proposals)
}

Expand Down Expand Up @@ -4027,7 +4017,7 @@ mod tests {
let (committed, _) =
CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules {
to_inject: test_proposal.clone(),
to_inject: vec![test_proposal.clone()],
source: ProposalSource::ByValue,
})
.send()
Expand All @@ -4049,7 +4039,7 @@ mod tests {
let (committed, _) =
CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules {
to_inject: test_proposal.clone(),
to_inject: vec![test_proposal.clone()],
source: ProposalSource::Local,
})
.send()
Expand All @@ -4069,7 +4059,7 @@ mod tests {

let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules {
to_inject: test_proposal.clone(),
to_inject: vec![test_proposal.clone()],
source: ProposalSource::ByValue,
})
.send()
Expand All @@ -4078,6 +4068,25 @@ mod tests {
assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender { .. }))
}

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn sending_invalid_local_proposal_fails() {
let (alice, tree) = new_tree("alice").await;
let gce_proposal = Proposal::GroupContextExtensions(Default::default());

let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
.with_user_rules(InjectMlsRules {
to_inject: vec![gce_proposal.clone(), gce_proposal],
source: ProposalSource::Local,
})
.send()
.await;

assert_matches!(
res,
Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
);
}

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn user_defined_filter_can_refuse_to_send_commit() {
let (alice, tree) = new_tree("alice").await;
Expand Down
5 changes: 1 addition & 4 deletions mls-rs/src/group/proposal_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ use filtering_lite as filtering;

pub use bundle::{ProposalBundle, ProposalInfo, ProposalSource};

#[cfg(feature = "by_ref_proposal")]
pub(crate) use filtering::FilterStrategy;

pub(crate) use filtering_common::ProposalApplier;
pub(crate) use filtering_common::{prepare_proposals_for_mls_rules, ProposalApplier};

#[cfg(all(feature = "by_ref_proposal", test))]
pub(crate) use filtering::proposer_can_propose;
11 changes: 9 additions & 2 deletions mls-rs/src/group/proposal_filter/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,19 @@ impl<T> ProposalInfo<T> {

#[inline(always)]
pub fn is_by_value(&self) -> bool {
self.source == ProposalSource::ByValue
!self.is_by_reference()
}

#[cfg(feature = "by_ref_proposal")]
#[inline(always)]
pub fn is_by_reference(&self) -> bool {
matches!(self.source, ProposalSource::ByReference(_))
}

#[cfg(not(feature = "by_ref_proposal"))]
#[inline(always)]
pub fn is_by_reference(&self) -> bool {
!self.is_by_value()
false
}

/// The [`ProposalRef`] of this proposal if its source is [`ProposalSource::ByReference`]
Expand Down
10 changes: 10 additions & 0 deletions mls-rs/src/group/proposal_filter/filtering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
AddProposal, ProposalType, RemoveProposal, Sender, UpdateProposal,
},
iter::wrap_iter,
mls_rules::CommitDirection,
protocol_version::ProtocolVersion,
time::MlsTime,
tree_kem::{
Expand Down Expand Up @@ -249,6 +250,15 @@ pub enum FilterStrategy {
IgnoreNone,
}

impl From<CommitDirection> for FilterStrategy {
fn from(value: CommitDirection) -> Self {
match value {
CommitDirection::Send => FilterStrategy::IgnoreByRef,
CommitDirection::Receive => FilterStrategy::IgnoreNone,
}
}
}

impl FilterStrategy {
pub(super) fn ignore(self, by_ref: bool) -> bool {
match self {
Expand Down
41 changes: 28 additions & 13 deletions mls-rs/src/group/proposal_filter/filtering_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
client::MlsError,
group::{proposal_filter::ProposalBundle, Sender},
key_package::{validate_key_package_properties, KeyPackage},
mls_rules::CommitDirection,
protocol_version::ProtocolVersion,
time::MlsTime,
tree_kem::{
Expand Down Expand Up @@ -130,19 +131,6 @@ where
Sender::NewMemberProposal => Err(MlsError::ExternalSenderCannotCommit),
}?;

#[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
let mut output = output;

#[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
filter_out_unsupported_custom_proposals(
&mut output.applied_proposals,
&output.new_tree,
strategy,
)?;

#[cfg(all(not(feature = "by_ref_proposal"), feature = "custom_proposal"))]
filter_out_unsupported_custom_proposals(proposals, &output.new_tree)?;

Ok(output)
}

Expand Down Expand Up @@ -359,6 +347,33 @@ where
}
}

#[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
pub(crate) fn prepare_proposals_for_mls_rules(
proposals: &mut ProposalBundle,
direction: CommitDirection,
tree: &TreeKemPublic,
) -> Result<(), MlsError> {
filter_out_unsupported_custom_proposals(proposals, tree, direction.into())
}

#[cfg(all(feature = "custom_proposal", not(feature = "by_ref_proposal")))]
pub(crate) fn prepare_proposals_for_mls_rules(
proposals: &mut ProposalBundle,
_direction: CommitDirection,
tree: &TreeKemPublic,
) -> Result<(), MlsError> {
filter_out_unsupported_custom_proposals(&proposals, tree)
}

#[cfg(not(feature = "custom_proposal"))]
pub(crate) fn prepare_proposals_for_mls_rules(
_: &mut ProposalBundle,
_: CommitDirection,
_: &TreeKemPublic,
) -> Result<(), MlsError> {
Ok(())
}

#[cfg(feature = "psk")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(crate) async fn filter_out_invalid_psks<P, CP>(
Expand Down

0 comments on commit df87607

Please sign in to comment.