Skip to content

Commit

Permalink
refactor!: use strong types for outputs of DispersionParameters trait…
Browse files Browse the repository at this point in the history
… fns
  • Loading branch information
IceTDrinker committed Dec 3, 2024
1 parent 38a7e4f commit c731615
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn lwe_encrypt_decrypt_noise_distribution_custom_mod<Scalar: UnsignedTorus + Cas
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);

let expected_variance = Variance(lwe_noise_distribution.gaussian_std_dev().get_variance());
let expected_variance = lwe_noise_distribution.gaussian_std_dev().get_variance();

let mut rsc = TestResources::new();

Expand Down Expand Up @@ -93,7 +93,7 @@ fn lwe_compact_public_key_encryption_expected_variance(
lwe_dimension: LweDimension,
) -> Variance {
let input_variance = input_noise.get_variance();
Variance(input_variance * (lwe_dimension.to_lwe_size().0 as f64))
Variance(input_variance.0 * (lwe_dimension.to_lwe_size().0 as f64))
}

#[test]
Expand All @@ -104,7 +104,8 @@ fn test_variance_increase_cpk_formula() {
);

assert!(
(predicted_variance.get_standard_dev().log2() - 44.000704097196405f64).abs() < f64::EPSILON
(predicted_variance.get_standard_dev().0.log2() - 44.000704097196405f64).abs()
< f64::EPSILON
);
}

Expand All @@ -119,7 +120,7 @@ fn lwe_compact_public_encrypt_noise_distribution_custom_mod<
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);

let glwe_variance = Variance(glwe_noise_distribution.gaussian_std_dev().get_variance());
let glwe_variance = glwe_noise_distribution.gaussian_std_dev().get_variance();

let expected_variance =
lwe_compact_public_key_encryption_expected_variance(glwe_variance, lwe_dimension);
Expand Down Expand Up @@ -208,7 +209,7 @@ fn random_noise_roundtrip<Scalar: UnsignedTorus + CastInto<usize>>(

assert!(matches!(noise, DynamicDistribution::Gaussian(_)));

let expected_variance = Variance(noise.gaussian_std_dev().get_variance());
let expected_variance = noise.gaussian_std_dev().get_variance();

let num_outputs = 100_000;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fn lwe_encrypt_ks_decrypt_noise_distribution_custom_mod<Scalar: UnsignedTorus +
ciphertext_modulus.get_custom_modulus() as f64
};

let encryption_variance = Variance(glwe_noise_distribution.gaussian_std_dev().get_variance());
let encryption_variance = glwe_noise_distribution.gaussian_std_dev().get_variance();
let expected_variance = Variance(
encryption_variance.0
+ keyswitch_additive_variance_132_bits_security_gaussian(
Expand Down
180 changes: 119 additions & 61 deletions tfhe/src/core_crypto/commons/dispersion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,24 @@ use crate::core_crypto::backward_compatibility::commons::dispersion::StandardDev
// Clone because f64 is itself Copy and stored in register.
pub trait DispersionParameter: Copy {
/// Return the standard deviation of the distribution, i.e. $\sigma = 2^p$.
fn get_standard_dev(&self) -> f64;
fn get_standard_dev(&self) -> StandardDev;
/// Return the variance of the distribution, i.e. $\sigma^2 = 2^{2p}$.
fn get_variance(&self) -> f64;
fn get_variance(&self) -> Variance;
/// Return base 2 logarithm of the standard deviation of the distribution, i.e.
/// $\log\_2(\sigma)=p$
fn get_log_standard_dev(&self) -> f64;
fn get_log_standard_dev(&self) -> LogStandardDev;
/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$.
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64;
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev;

/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$.
fn get_modular_variance(&self, log2_modulus: u32) -> f64;
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance;

/// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$.
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64;
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev;
}

fn log2_modulus_to_modulus(log2_modulus: u32) -> f64 {
2.0f64.powi(log2_modulus as i32)
}

/// A distribution parameter that uses the base-2 logarithm of the standard deviation as
Expand All @@ -49,22 +53,31 @@ pub trait DispersionParameter: Copy {
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, LogStandardDev};
/// let params = LogStandardDev::from_log_standard_dev(-25.);
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.));
/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_variance(32),
/// params.get_modular_standard_dev(32).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
///
/// let modular_params = LogStandardDev::from_modular_log_standard_dev(22., 32);
/// assert_eq!(modular_params.get_standard_dev(), 2_f64.powf(-10.));
/// assert_eq!(modular_params.get_standard_dev().0, 2_f64.powf(-10.));
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct LogStandardDev(pub f64);

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularLogStandardDev {
pub value: f64,
pub modulus: f64,
}

impl LogStandardDev {
pub fn from_log_standard_dev(log_std: f64) -> Self {
Self(log_std)
Expand All @@ -76,23 +89,32 @@ impl LogStandardDev {
}

impl DispersionParameter for LogStandardDev {
fn get_standard_dev(&self) -> f64 {
f64::powf(2., self.0)
fn get_standard_dev(&self) -> StandardDev {
StandardDev(f64::powf(2., self.0))
}
fn get_variance(&self) -> f64 {
f64::powf(2., self.0 * 2.)
fn get_variance(&self) -> Variance {
Variance(f64::powf(2., self.0 * 2.))
}
fn get_log_standard_dev(&self) -> f64 {
self.0
fn get_log_standard_dev(&self) -> Self {
Self(self.0)
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 {
f64::powf(2., log2_modulus as f64 + self.0)
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
ModularStandardDev {
value: f64::powf(2., log2_modulus as f64 + self.0),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> f64 {
f64::powf(2., (log2_modulus as f64 + self.0) * 2.)
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
ModularVariance {
value: f64::powf(2., (log2_modulus as f64 + self.0) * 2.),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 {
log2_modulus as f64 + self.0
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0,
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
}

Expand All @@ -103,20 +125,29 @@ impl DispersionParameter for LogStandardDev {
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, StandardDev};
/// let params = StandardDev::from_standard_dev(2_f64.powf(-25.));
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.));
/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_standard_dev(32).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32),
/// params.get_modular_variance(32).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Serialize, Deserialize, Versionize)]
#[versionize(StandardDevVersions)]
pub struct StandardDev(pub f64);

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularStandardDev {
pub value: f64,
pub modulus: f64,
}

impl StandardDev {
pub fn from_standard_dev(std: f64) -> Self {
Self(std)
Expand All @@ -128,23 +159,32 @@ impl StandardDev {
}

impl DispersionParameter for StandardDev {
fn get_standard_dev(&self) -> f64 {
self.0
fn get_standard_dev(&self) -> Self {
Self(self.0)
}
fn get_variance(&self) -> f64 {
self.0.powi(2)
fn get_variance(&self) -> Variance {
Variance(self.0.powi(2))
}
fn get_log_standard_dev(&self) -> f64 {
self.0.log2()
fn get_log_standard_dev(&self) -> LogStandardDev {
LogStandardDev(self.0.log2())
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 {
2_f64.powf(log2_modulus as f64 + self.0.log2())
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
ModularStandardDev {
value: 2_f64.powf(log2_modulus as f64 + self.0.log2()),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> f64 {
2_f64.powf(2. * (log2_modulus as f64 + self.0.log2()))
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
ModularVariance {
value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.log2())),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 {
log2_modulus as f64 + self.0.log2()
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0.log2(),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
}

Expand All @@ -155,19 +195,28 @@ impl DispersionParameter for StandardDev {
/// ```rust
/// use tfhe::core_crypto::commons::dispersion::{DispersionParameter, Variance};
/// let params = Variance::from_variance(2_f64.powi(-50));
/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev(), -25.);
/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2));
/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.));
/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.);
/// assert_eq!(params.get_standard_dev().0, 2_f64.powf(-25.));
/// assert_eq!(params.get_log_standard_dev().0, -25.);
/// assert_eq!(params.get_variance().0, 2_f64.powf(-25.).powi(2));
/// assert_eq!(
/// params.get_modular_standard_dev(32).value,
/// 2_f64.powf(32. - 25.)
/// );
/// assert_eq!(params.get_modular_log_standard_dev(32).value, 32. - 25.);
/// assert_eq!(
/// params.get_modular_variance(32),
/// params.get_modular_variance(32).value,
/// 2_f64.powf(32. - 25.).powi(2)
/// );
/// ```
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct Variance(pub f64);

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)]
pub struct ModularVariance {
pub value: f64,
pub modulus: f64,
}

impl Variance {
pub fn from_variance(var: f64) -> Self {
Self(var)
Expand All @@ -179,22 +228,31 @@ impl Variance {
}

impl DispersionParameter for Variance {
fn get_standard_dev(&self) -> f64 {
self.0.sqrt()
fn get_standard_dev(&self) -> StandardDev {
StandardDev(self.0.sqrt())
}
fn get_variance(&self) -> f64 {
self.0
fn get_variance(&self) -> Self {
Self(self.0)
}
fn get_log_standard_dev(&self) -> f64 {
self.0.sqrt().log2()
fn get_log_standard_dev(&self) -> LogStandardDev {
LogStandardDev(self.0.sqrt().log2())
}
fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 {
2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2())
fn get_modular_standard_dev(&self, log2_modulus: u32) -> ModularStandardDev {
ModularStandardDev {
value: 2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2()),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_variance(&self, log2_modulus: u32) -> f64 {
2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2()))
fn get_modular_variance(&self, log2_modulus: u32) -> ModularVariance {
ModularVariance {
value: 2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2())),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 {
log2_modulus as f64 + self.0.sqrt().log2()
fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> ModularLogStandardDev {
ModularLogStandardDev {
value: log2_modulus as f64 + self.0.sqrt().log2(),
modulus: log2_modulus_to_modulus(log2_modulus),
}
}
}
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/commons/math/random/gaussian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl Gaussian<f64> {

pub fn from_dispersion_parameter(dispersion: impl DispersionParameter, mean: f64) -> Self {
Self {
std: dispersion.get_standard_dev(),
std: dispersion.get_standard_dev().0,
mean,
}
}
Expand Down
4 changes: 1 addition & 3 deletions tfhe/src/core_crypto/commons/math/random/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,7 @@ impl<T: UnsignedInteger> DynamicDistribution<T> {
#[track_caller]
pub fn gaussian_variance(&self) -> Variance {
match self {
Self::Gaussian(gaussian) => {
Variance(StandardDev::from_standard_dev(gaussian.std).get_variance())
}
Self::Gaussian(gaussian) => StandardDev::from_standard_dev(gaussian.std).get_variance(),
Self::TUniform(_) => {
panic!("Tried to get gaussian variance from a non gaussian distribution")
}
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/core_crypto/commons/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ pub mod test_tools {
{
for (x, y) in first.as_ref().iter().zip(second.as_ref().iter()) {
println!("{:?}, {:?}", *x, *y);
println!("{}", dist.get_standard_dev());
println!("{:?}", dist.get_standard_dev());
let distance: f64 = modular_distance(*x, *y).cast_into();
let torus_distance = distance / 2_f64.powi(Element::BITS as i32);
assert!(
torus_distance <= 5. * dist.get_standard_dev(),
torus_distance <= 5. * dist.get_standard_dev().0,
"{x} != {y} "
);
}
Expand Down

0 comments on commit c731615

Please sign in to comment.