Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protocol context is a trait #253

Merged
1 change: 1 addition & 0 deletions src/helpers/messaging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ impl Debug for ReceiveRequest {
mod tests {
use crate::ff::Fp31;
use crate::helpers::Role;
use crate::protocol::context::Context;
use crate::protocol::{QueryId, RecordId};
use crate::test_fixture::{make_contexts, make_world_with_config, TestWorldConfig};

Expand Down
7 changes: 4 additions & 3 deletions src/protocol/attribution/accumulate_credit.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use super::{AccumulateCreditInputRow, AccumulateCreditOutputRow, AttributionInputRow};
use crate::protocol::context::SemiHonestContext;
use crate::protocol::mul::SecureMul;
use crate::protocol::IterStep;
use crate::{
error::Error,
ff::Field,
protocol::{
batch::{Batch, RecordIndex},
context::ProtocolContext,
context::Context,
RecordId,
},
secret_sharing::Replicated,
Expand Down Expand Up @@ -55,7 +56,7 @@ impl<'a, F: Field> AccumulateCredit<'a, F> {
#[allow(dead_code)]
pub async fn execute(
&self,
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
) -> Result<Batch<AccumulateCreditOutputRow<F>>, Error> {
#[allow(clippy::cast_possible_truncation)]
let num_rows = self.input.len() as RecordIndex;
Expand Down Expand Up @@ -161,7 +162,7 @@ impl<'a, F: Field> AccumulateCredit<'a, F> {
}

async fn get_accumulated_credit(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
current: AccumulateCreditInputRow<F>,
successor: AccumulateCreditInputRow<F>,
Expand Down
12 changes: 7 additions & 5 deletions src/protocol/boolean/bitwise_lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use super::xor::xor;
use super::BitOpStep;
use crate::error::Error;
use crate::ff::Field;
use crate::protocol::{context::ProtocolContext, mul::SecureMul, RecordId};
use crate::protocol::context::SemiHonestContext;
use crate::protocol::{context::Context, mul::SecureMul, RecordId};
use crate::secret_sharing::Replicated;
use futures::future::try_join_all;
use std::iter::{repeat, zip};
Expand Down Expand Up @@ -39,7 +40,7 @@ impl BitwiseLessThan {
async fn step1<F: Field>(
a: &[Replicated<F>],
b: &[Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let xor = zip(a, b).enumerate().map(|(i, (a_bit, b_bit))| {
Expand Down Expand Up @@ -67,7 +68,7 @@ impl BitwiseLessThan {
/// ```
async fn step2<F: Field>(
e: &mut [Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
e.reverse();
Expand Down Expand Up @@ -107,7 +108,7 @@ impl BitwiseLessThan {
async fn step5<F: Field>(
g: &[Replicated<F>],
b: &[Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let mul = zip(repeat(ctx), zip(g, b))
Expand All @@ -130,7 +131,7 @@ impl BitwiseLessThan {
#[allow(dead_code)]
#[allow(clippy::many_single_char_names)]
pub async fn execute<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
a: &[Replicated<F>],
b: &[Replicated<F>],
Expand Down Expand Up @@ -167,6 +168,7 @@ impl AsRef<str> for Step {
#[cfg(test)]
mod tests {
use super::BitwiseLessThan;
use crate::protocol::context::Context;
use crate::{
error::Error,
ff::{Field, Fp31, Fp32BitPrime},
Expand Down
6 changes: 4 additions & 2 deletions src/protocol/boolean/bitwise_sum.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::carries::Carries;
use crate::error::BoxError;
use crate::ff::Field;
use crate::protocol::{context::ProtocolContext, RecordId};
use crate::protocol::context::SemiHonestContext;
use crate::protocol::{context::Context, RecordId};
use crate::secret_sharing::Replicated;

/// This is an implementation of Bitwise Sum on bitwise-shared numbers.
Expand Down Expand Up @@ -31,7 +32,7 @@ impl BitwiseSum {
#[allow(dead_code)]
#[allow(clippy::many_single_char_names)]
pub async fn execute<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
a: &[Replicated<F>],
b: &[Replicated<F>],
Expand Down Expand Up @@ -82,6 +83,7 @@ impl AsRef<str> for Step {
#[cfg(test)]
mod tests {
use super::BitwiseSum;
use crate::protocol::context::Context;
use crate::{
error::BoxError,
ff::{Field, Fp31, Fp32BitPrime},
Expand Down
12 changes: 7 additions & 5 deletions src/protocol/boolean/carries.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::BitOpStep;
use crate::error::Error;
use crate::ff::Field;
use crate::protocol::{context::ProtocolContext, mul::SecureMul, RecordId};
use crate::protocol::context::SemiHonestContext;
use crate::protocol::{context::Context, mul::SecureMul, RecordId};
use crate::secret_sharing::Replicated;
use futures::future::try_join_all;
use std::iter::{repeat, zip};
Expand Down Expand Up @@ -38,7 +39,7 @@ impl Carries {
#[allow(dead_code)]
#[allow(clippy::many_single_char_names)]
pub async fn execute<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
a: &[Replicated<F>],
b: &[Replicated<F>],
Expand Down Expand Up @@ -66,7 +67,7 @@ impl Carries {
async fn step1<F: Field>(
a: &[Replicated<F>],
b: &[Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let mul = zip(repeat(ctx), zip(a, b))
Expand Down Expand Up @@ -123,7 +124,7 @@ impl Carries {
/// computes `([f_0],...,[f_l]) ← PRE○([e_0],...,[e_l])`.
async fn step3<F: Field>(
e: &[CarryPropagationShares<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<CarryPropagationShares<F>>, Error> {
let l = e.len();
Expand All @@ -149,7 +150,7 @@ impl Carries {
#[allow(clippy::many_single_char_names)]
async fn fan_in_carry_propagation<F: Field>(
e: &[CarryPropagationShares<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<CarryPropagationShares<F>, Error> {
let l = e.len();
Expand Down Expand Up @@ -231,6 +232,7 @@ mod tests {
use super::Carries;
use crate::error::Error;
use crate::ff::{Field, Fp31, Fp32BitPrime};
use crate::protocol::context::Context;
use crate::protocol::{QueryId, RecordId};
use crate::secret_sharing::Replicated;
use crate::test_fixture::{
Expand Down
6 changes: 4 additions & 2 deletions src/protocol/boolean/or.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::error::Error;
use crate::ff::Field;
use crate::protocol::{context::ProtocolContext, mul::SecureMul, RecordId};
use crate::protocol::context::SemiHonestContext;
use crate::protocol::{context::Context, mul::SecureMul, RecordId};
use crate::secret_sharing::Replicated;

/// Secure XOR protocol with two inputs, `a, b ∈ {0,1} ⊆ F_p`.
/// It computes `[a] + [b] - 2[ab]`
pub async fn or<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
a: &Replicated<F>,
b: &Replicated<F>,
Expand All @@ -19,6 +20,7 @@ pub async fn or<F: Field>(
#[cfg(test)]
mod tests {
use super::or;
use crate::protocol::context::Context;
use crate::{
error::Error,
ff::{Field, Fp31},
Expand Down
18 changes: 10 additions & 8 deletions src/protocol/boolean/prefix_or.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use super::{or::or, BitOpStep};
use crate::error::Error;
use crate::ff::Field;
use crate::protocol::{context::ProtocolContext, mul::SecureMul, RecordId};
use crate::protocol::context::SemiHonestContext;
use crate::protocol::{context::Context, mul::SecureMul, RecordId};
use crate::secret_sharing::Replicated;
use futures::future::try_join_all;
use std::iter::{repeat, zip};
Expand All @@ -27,7 +28,7 @@ impl PrefixOr {
async fn block_or<F: Field>(
a: &[Replicated<F>],
k: usize,
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Replicated<F>, Error> {
#[allow(clippy::cast_possible_truncation)]
Expand All @@ -53,7 +54,7 @@ impl PrefixOr {
async fn step1<F: Field>(
a: &[Replicated<F>],
lambda: usize,
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let mut futures = Vec::with_capacity(lambda);
Expand All @@ -74,7 +75,7 @@ impl PrefixOr {
/// ```
async fn step2<F: Field>(
x: &[Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let lambda = x.len();
Expand Down Expand Up @@ -116,7 +117,7 @@ impl PrefixOr {
async fn step5<F: Field>(
f: &[Replicated<F>],
a: &[Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let lambda = f.len();
Expand Down Expand Up @@ -160,7 +161,7 @@ impl PrefixOr {
/// ```
async fn step7<F: Field>(
c: &[Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let lambda = c.len();
Expand All @@ -186,7 +187,7 @@ impl PrefixOr {
async fn step8<F: Field>(
f: &[Replicated<F>],
b: &[Replicated<F>],
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
let lambda = f.len();
Expand Down Expand Up @@ -227,7 +228,7 @@ impl PrefixOr {
#[allow(dead_code)]
#[allow(clippy::many_single_char_names)]
pub async fn execute<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
input: &[Replicated<F>],
) -> Result<Vec<Replicated<F>>, Error> {
Expand Down Expand Up @@ -303,6 +304,7 @@ impl AsRef<str> for Step {
#[cfg(test)]
mod tests {
use super::PrefixOr;
use crate::protocol::context::Context;
use crate::{
error::Error,
ff::{Field, Fp2, Fp31},
Expand Down
21 changes: 11 additions & 10 deletions src/protocol/boolean/solved_bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use crate::error::Error;
use crate::ff::{Field, Int};
use crate::helpers::Role;
use crate::protocol::boolean::BitOpStep;
use crate::protocol::context::SemiHonestContext;
use crate::protocol::modulus_conversion::convert_shares::{ConvertShares, XorShares};
use crate::protocol::reveal::Reveal;
use crate::protocol::{context::ProtocolContext, RecordId};
use crate::protocol::{context::Context, RecordId};
use crate::secret_sharing::Replicated;
use futures::future::try_join_all;
use std::iter::repeat;
Expand Down Expand Up @@ -41,7 +42,7 @@ impl SolvedBits {
// 1/2^3 =~ 13% (3 high bits being all 0's).
#[allow(dead_code)]
pub async fn execute<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Option<RandomBitsShare<F>>, Error> {
//
Expand Down Expand Up @@ -74,7 +75,7 @@ impl SolvedBits {

/// Generates a sequence of `l` random bit sharings in the target field `F`.
async fn generate_random_bits<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
) -> Result<Vec<Replicated<F>>, Error> {
// Calculate the number of bits we need to form a random number that
Expand Down Expand Up @@ -120,7 +121,7 @@ impl SolvedBits {
}

async fn is_less_than_p<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
b_b: &[Replicated<F>],
) -> Result<bool, Error> {
Expand Down Expand Up @@ -180,19 +181,19 @@ impl AsRef<str> for Step {
#[cfg(test)]
mod tests {
use super::SolvedBits;
use crate::protocol::context::SemiHonestContext;
use crate::{
error::Error,
ff::{Field, Fp31, Fp32BitPrime},
protocol::{context::ProtocolContext, QueryId, RecordId},
secret_sharing::Replicated,
protocol::{QueryId, RecordId},
test_fixture::{
bits_to_value, join3, make_contexts, make_world, validate_and_reconstruct, TestWorld,
},
};
use rand::{distributions::Standard, prelude::Distribution};

async fn random_bits<F: Field>(
ctx: [ProtocolContext<'_, Replicated<F>, F>; 3],
ctx: [SemiHonestContext<'_, F>; 3],
record_id: RecordId,
) -> Result<Option<(Vec<F>, F)>, Error>
where
Expand All @@ -202,9 +203,9 @@ mod tests {

// Execute
let [result0, result1, result2] = join3(
SolvedBits::execute(c0.bind(record_id), record_id),
SolvedBits::execute(c1.bind(record_id), record_id),
SolvedBits::execute(c2.bind(record_id), record_id),
SolvedBits::execute(c0, record_id),
SolvedBits::execute(c1, record_id),
SolvedBits::execute(c2, record_id),
)
.await;

Expand Down
6 changes: 4 additions & 2 deletions src/protocol/boolean/xor.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::error::Error;
use crate::ff::Field;
use crate::protocol::{context::ProtocolContext, mul::SecureMul, RecordId};
use crate::protocol::context::SemiHonestContext;
use crate::protocol::{mul::SecureMul, RecordId};
use crate::secret_sharing::Replicated;

/// Secure XOR protocol with two inputs, `a, b ∈ {0,1} ⊆ F_p`.
/// It computes `[a] + [b] - 2[ab]`
pub async fn xor<F: Field>(
ctx: ProtocolContext<'_, Replicated<F>, F>,
ctx: SemiHonestContext<'_, F>,
record_id: RecordId,
a: &Replicated<F>,
b: &Replicated<F>,
Expand All @@ -18,6 +19,7 @@ pub async fn xor<F: Field>(
#[cfg(test)]
mod tests {
use super::xor;
use crate::protocol::context::Context;
use crate::{
error::Error,
ff::{Field, Fp31},
Expand Down
Loading