Skip to content

Commit

Permalink
feat(core): add batched programmable boostraping
Browse files Browse the repository at this point in the history
  • Loading branch information
soonum committed Oct 25, 2024
1 parent a88597b commit e9af460
Show file tree
Hide file tree
Showing 5 changed files with 556 additions and 2 deletions.
122 changes: 122 additions & 0 deletions tfhe/benches/core_crypto/pbs_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,127 @@ fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
}
}

fn mem_optimized_batched_pbs<Scalar: UnsignedTorus + CastInto<usize> + Serialize>(
c: &mut Criterion,
parameters: &[(String, CryptoParametersRecord<Scalar>)],
) {
let bench_name = "core_crypto::batched_pbs_mem_optimized";
let mut bench_group = c.benchmark_group(bench_name);
bench_group
.sample_size(15)
.measurement_time(std::time::Duration::from_secs(10));

// Create the PRNG
let mut seeder = new_seeder();
let seeder = seeder.as_mut();
let mut encryption_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());

for (name, params) in parameters.iter() {
// Create the LweSecretKey
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
params.lwe_dimension.unwrap(),
&mut secret_generator,
);
let output_glwe_secret_key: GlweSecretKeyOwned<Scalar> =
allocate_and_generate_new_binary_glwe_secret_key(
params.glwe_dimension.unwrap(),
params.polynomial_size.unwrap(),
&mut secret_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key();

// Create the empty bootstrapping key in the Fourier domain
let fourier_bsk = FourierLweBootstrapKey::new(
params.lwe_dimension.unwrap(),
params.glwe_dimension.unwrap().to_glwe_size(),
params.polynomial_size.unwrap(),
params.pbs_base_log.unwrap(),
params.pbs_level.unwrap(),
);

let count = 10; // FIXME Is it a representative value (big enough?)

// Allocate a new LweCiphertext and encrypt our plaintext
let mut lwe_ciphertext_in = LweCiphertextListOwned::<Scalar>::new(
Scalar::ZERO,
input_lwe_secret_key.lwe_dimension().to_lwe_size(),
LweCiphertextCount(count),
params.ciphertext_modulus.unwrap(),
);

encrypt_lwe_ciphertext_list(
&input_lwe_secret_key,
&mut lwe_ciphertext_in,
&PlaintextList::from_container(vec![Scalar::ZERO; count]),
params.lwe_noise_distribution.unwrap(),
&mut encryption_generator,
);

let accumulator = GlweCiphertextList::new(
Scalar::ZERO,
params.glwe_dimension.unwrap().to_glwe_size(),
params.polynomial_size.unwrap(),
GlweCiphertextCount(count),
params.ciphertext_modulus.unwrap(),
);

// Allocate the LweCiphertext to store the result of the PBS
let mut out_pbs_ct = LweCiphertextList::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
LweCiphertextCount(count),
params.ciphertext_modulus.unwrap(),
);

let mut buffers = ComputationBuffers::new();

let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();

buffers.resize(
batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<Scalar>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
CiphertextCount(count),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);

let id = format!("{bench_name}::{name}");
{
bench_group.bench_function(&id, |b| {
b.iter(|| {
batch_programmable_bootstrap_lwe_ciphertext_mem_optimized(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator,
&fourier_bsk,
fft,
buffers.stack(),
);
black_box(&mut out_pbs_ct);
})
});
}

let bit_size = (params.message_modulus.unwrap_or(2) as u32).ilog2();
write_to_json(
&id,
*params,
name,
"pbs",
&OperatorType::Atomic,
bit_size,
vec![bit_size],
);
}
}

fn multi_bit_pbs<
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Default + Sync + Serialize,
>(
Expand Down Expand Up @@ -1305,6 +1426,7 @@ pub fn pbs_group() {
mem_optimized_pbs(&mut criterion, &benchmark_parameters_64bits());
mem_optimized_pbs(&mut criterion, &benchmark_parameters_32bits());
mem_optimized_pbs_ntt(&mut criterion);
mem_optimized_batched_pbs(&mut criterion, &benchmark_parameters_64bits());
}

pub fn multi_bit_pbs_group() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::{
blind_rotate_assign_scratch, bootstrap_scratch,
batch_bootstrap_scratch, blind_rotate_assign_scratch, bootstrap_scratch,
};
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
add_external_product_assign as impl_add_external_product_assign,
Expand Down Expand Up @@ -1072,3 +1072,91 @@ pub fn programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputSca
) -> Result<StackReq, SizeOverflow> {
bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, fft)
}

/// This function takes list as input and output and computes the programmable bootstrap for each
/// slot progressively loading the bootstrapping key only once. The caller must provide
/// a properly configured [`FftView`] object and a `PodStack` used as a memory buffer having a
/// capacity at least as large as the result of
/// [`batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement`].
pub fn batch_programmable_bootstrap_lwe_ciphertext_mem_optimized<
InputScalar,
OutputScalar,
InputCont,
OutputCont,
AccCont,
KeyCont,
>(
input: &LweCiphertextList<InputCont>,
output: &mut LweCiphertextList<OutputCont>,
accumulator: &GlweCiphertextList<AccCont>,
fourier_bsk: &FourierLweBootstrapKey<KeyCont>,
fft: FftView<'_>,
stack: PodStack<'_>,
) where
// CastInto required for PBS modulus switch which returns a usize
InputScalar: UnsignedTorus + CastInto<usize>,
OutputScalar: UnsignedTorus,
InputCont: Container<Element = InputScalar>,
OutputCont: ContainerMut<Element = OutputScalar>,
AccCont: Container<Element = OutputScalar>,
KeyCont: Container<Element = c64>,
{
assert_eq!(
accumulator.ciphertext_modulus(),
output.ciphertext_modulus(),
"Mismatched moduli between accumulator ({:?}) and output ({:?})",
accumulator.ciphertext_modulus(),
output.ciphertext_modulus()
);

assert_eq!(
fourier_bsk.input_lwe_dimension(),
input.lwe_size().to_lwe_dimension(),
"Mismatched input LweDimension. \
FourierLweBootstrapKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
fourier_bsk.input_lwe_dimension(),
input.lwe_size().to_lwe_dimension(),
);
assert_eq!(
fourier_bsk.output_lwe_dimension(),
output.lwe_size().to_lwe_dimension(),
"Mismatched output LweDimension. \
FourierLweBootstrapKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
fourier_bsk.output_lwe_dimension(),
output.lwe_size().to_lwe_dimension(),
);
assert_eq!(
input.lwe_ciphertext_count().0,
output.lwe_ciphertext_count().0,
"Mismatched list length. \
input LweCiphertextList length: {:?}, output LweCiphertextList length {:?}.",
input.lwe_ciphertext_count().0,
output.lwe_ciphertext_count().0,
);
assert_eq!(
input.lwe_ciphertext_count().0,
accumulator.glwe_ciphertext_count().0,
"Mismatched list length. \
input LweCiphertextList length: {:?}, accumulator GlweCiphertextList length {:?}.",
input.lwe_ciphertext_count().0,
accumulator.glwe_ciphertext_count().0,
);

fourier_bsk.as_view().batch_bootstrap(
output.as_mut_view(),
input.as_view(),
&accumulator.as_view(),
fft,
stack,
);
}

/// Return the required memory for [`batch_programmable_bootstrap_lwe_ciphertext_mem_optimized`].
pub fn batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement<OutputScalar>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
ciphertext_count: CiphertextCount,
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
batch_bootstrap_scratch::<OutputScalar>(glwe_size, polynomial_size, ciphertext_count, fft)
}
143 changes: 143 additions & 0 deletions tfhe/src/core_crypto/algorithms/test/lwe_programmable_bootstrapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,149 @@ where

create_parametrized_test!(lwe_encrypt_pbs_decrypt_custom_mod);

fn lwe_encrypt_batch_pbs_decrypt_custom_mod<Scalar>(params: ClassicTestParams<Scalar>)
where
Scalar: UnsignedTorus
+ Sync
+ Send
+ CastFrom<usize>
+ CastInto<usize>
+ Serialize
+ DeserializeOwned,
ClassicTestParams<Scalar>: KeyCacheAccess<Keys = ClassicBootstrapKeys<Scalar>>,
{
let lwe_noise_distribution = params.lwe_noise_distribution;
let ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let glwe_dimension = params.glwe_dimension;
let polynomial_size = params.polynomial_size;

let ciphertext_count = 2;

let mut rsc = TestResources::new();

let f = |x: Scalar| x;

let delta: Scalar = encoding_with_padding / msg_modulus;
let mut msg = msg_modulus;

let accumulator = generate_programmable_bootstrap_glwe_lut(
polynomial_size,
glwe_dimension.to_glwe_size(),
msg_modulus.cast_into(),
ciphertext_modulus,
delta,
f,
);

assert!(check_encrypted_content_respects_mod(
&accumulator,
ciphertext_modulus
));

while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);

let mut keys_gen = |params| generate_keys(params, &mut rsc);
let keys = gen_keys_or_get_from_cache_if_enabled(params, &mut keys_gen);
let (input_lwe_secret_key, output_lwe_secret_key, fbsk) =
(keys.small_lwe_sk, keys.big_lwe_sk, keys.fbsk);

for _ in 0..NB_TESTS {
let plaintext = msg * delta;

let mut lwe_ciphertext_in = LweCiphertextListOwned::<Scalar>::new(
Scalar::ZERO,
input_lwe_secret_key.lwe_dimension().to_lwe_size(),
LweCiphertextCount(ciphertext_count),
ciphertext_modulus,
);

encrypt_lwe_ciphertext_list(
&input_lwe_secret_key,
&mut lwe_ciphertext_in,
&PlaintextList::from_container(vec![plaintext; ciphertext_count]),
lwe_noise_distribution,
&mut rsc.encryption_random_generator,
);

assert!(lwe_ciphertext_in
.iter()
.all(|ct| check_encrypted_content_respects_mod(&ct, ciphertext_modulus)));

let mut accumulator_list = GlweCiphertextList::new(
Scalar::ZERO,
glwe_dimension.to_glwe_size(),
polynomial_size,
GlweCiphertextCount(ciphertext_count),
ciphertext_modulus,
);

for mut glwe in accumulator_list.iter_mut() {
glwe.as_mut().copy_from_slice(accumulator.as_ref());
}

// Allocate the LweCiphertext to store the result of the PBS
let mut out_pbs_ct = LweCiphertextList::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
LweCiphertextCount(ciphertext_count),
ciphertext_modulus,
);

let mut buffers = ComputationBuffers::new();

let fft = Fft::new(fbsk.polynomial_size());
let fft = fft.as_view();

buffers.resize(
batch_programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<Scalar>(
fbsk.glwe_size(),
fbsk.polynomial_size(),
CiphertextCount(ciphertext_count),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);

batch_programmable_bootstrap_lwe_ciphertext_mem_optimized(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator_list,
&fbsk,
fft,
buffers.stack(),
);

assert!(out_pbs_ct
.iter()
.all(|ct| check_encrypted_content_respects_mod(&ct, ciphertext_modulus)));

let mut decrypted_list =
PlaintextList::new(Scalar::ZERO, PlaintextCount(ciphertext_count));

decrypt_lwe_ciphertext_list(&output_lwe_secret_key, &out_pbs_ct, &mut decrypted_list);

let decoded_list = decrypted_list
.iter()
.map(|ct| round_decode(*ct.0, delta) % msg_modulus)
.collect::<Vec<Scalar>>();

assert!(decoded_list.iter().all(|ct| *ct == f(msg)));
}

// In coverage, we break after one while loop iteration, changing message values does not
// yield higher coverage
#[cfg(tarpaulin)]
break;
}
}

create_parametrized_test!(lwe_encrypt_batch_pbs_decrypt_custom_mod);

// Here we will define a helper function to generate a many lut accumulator for a PBS
fn generate_accumulator_many_lut<Scalar: UnsignedTorus + CastFrom<usize>>(
polynomial_size: PolynomialSize,
Expand Down
Loading

0 comments on commit e9af460

Please sign in to comment.