diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs index f3fce4a406..367f2a21ad 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/operator.rs @@ -3,6 +3,7 @@ use std::iter::{empty, once}; use std::ops::Deref; use crate::dag::operator::tensor::{ClearTensor, Shape}; +use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition; use super::DotKind; @@ -104,6 +105,11 @@ pub enum Operator { input: OperatorIndex, out_precision: Precision, }, + ChangePartition { + input: OperatorIndex, + src_partition: Option, + dst_partition: Option, + }, } impl Operator { @@ -114,7 +120,8 @@ impl Operator { Self::LevelledOp { inputs, .. } | Self::Dot { inputs, .. } => Box::new(inputs.iter()), Self::UnsafeCast { input, .. } | Self::Lut { input, .. } - | Self::Round { input, .. } => Box::new(once(input)), + | Self::Round { input, .. } + | Self::ChangePartition { input, .. } => Box::new(once(input)), } } } @@ -190,6 +197,23 @@ impl fmt::Display for Operator { } => { write!(f, "ROUND[%{}] : u{out_precision}", input.0)?; } + Self::ChangePartition { + input, + src_partition, + dst_partition, + } => { + write!(f, "CHANGE_PARTITION[%{}] : {{", input.0)?; + if let Some(partition) = src_partition { + write!(f, "src_partition: {}", partition.name)?; + } + if let Some(partition) = dst_partition { + if src_partition.is_some() { + write!(f, ", ")?; + } + write!(f, "dst_partition: {}", partition.name)?; + } + write!(f, "}}")?; + } } Ok(()) } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs index dc9403dcbd..24c38a08c9 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs @@ -8,7 +8,8 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { Operator::Input { .. } => (), Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } - | Operator::Round { input, .. } => input.0 = old_index_to_new[input.0], + | Operator::Round { input, .. } + | Operator::ChangePartition { input, .. } => input.0 = old_index_to_new[input.0], Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { for input in inputs { input.0 = old_index_to_new[input.0]; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index 4877253464..a59e0bd4c7 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -1,6 +1,7 @@ use crate::dag::operator::{ FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights, }; +use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition; use std::{ collections::{HashMap, HashSet}, fmt, @@ -259,6 +260,27 @@ impl<'dag> DagBuilder<'dag> { ) } + pub fn add_change_partition( + &mut self, + input: OperatorIndex, + src_partition: Option, + dst_partition: Option, + location: Location, + ) -> OperatorIndex { + assert!( + src_partition.is_some() || dst_partition.is_some(), + "change_partition: src or dest partition need to be set" + ); + self.add_operator( + Operator::ChangePartition { + input, + src_partition, + dst_partition, + }, + location, + ) + } + pub fn add_round_op( &mut self, input: OperatorIndex, @@ -416,7 +438,8 @@ impl<'dag> DagBuilder<'dag> { } Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } - | Operator::Round { input, .. } => self.dag.out_shapes[input.0].clone(), + | Operator::Round { input, .. } + | Operator::ChangePartition { input, .. } => self.dag.out_shapes[input.0].clone(), Operator::Dot { kind: DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor, .. @@ -451,6 +474,7 @@ impl<'dag> DagBuilder<'dag> { Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { self.dag.out_precisions[inputs[0].0] } + Operator::ChangePartition { input, .. } => self.dag.out_precisions[input.0], } } } @@ -577,6 +601,20 @@ impl Dag { .add_lut(input, table, out_precision, Location::Unknown) } + pub fn add_change_partition( + &mut self, + input: OperatorIndex, + src_partition: Option, + dst_partition: Option, + ) -> OperatorIndex { + self.builder(DEFAULT_CIRCUIT).add_change_partition( + input, + src_partition, + dst_partition, + Location::Unknown, + ) + } + pub fn add_dot( &mut self, inputs: impl Into>, @@ -834,8 +872,20 @@ impl Dag { #[cfg(test)] mod tests { + use crate::{ + optimization::dag::multi_parameters::optimize::MacroParameters, parameters::GlweParameters, + }; + use super::*; + const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters { + glwe_params: GlweParameters { + log2_polynomial_size: 0, + glwe_dimension: 0, + }, + internal_dim: 0, + }; + #[test] fn output_marking() { let mut graph = Dag::new(); @@ -852,16 +902,25 @@ mod tests { #[allow(clippy::many_single_char_names)] fn graph_builder() { let mut graph = Dag::new(); + let tfhers_part = ExternalPartition { + name: String::from("tfhers"), + macro_params: DUMMY_MACRO_PARAM, + max_variance: 0.0_f64, + variance: 0.0_f64, + }; let mut builder = graph.builder("main1"); let a = builder.add_input(1, Shape::number(), Location::Unknown); let b = builder.add_input(1, Shape::number(), Location::Unknown); let c = builder.add_dot([a, b], [1, 1], Location::Unknown); - let _d = builder.add_lut(c, FunctionTable::UNKWOWN, 1, Location::Unknown); + let d = builder.add_lut(c, FunctionTable::UNKWOWN, 1, Location::Unknown); + let _d = + builder.add_change_partition(d, Some(tfhers_part.clone()), None, Location::Unknown); let mut builder = graph.builder("main2"); let e = builder.add_input(2, Shape::number(), Location::Unknown); let f = builder.add_input(2, Shape::number(), Location::Unknown); let g = builder.add_dot([e, f], [2, 2], Location::Unknown); - let _h = builder.add_lut(g, FunctionTable::UNKWOWN, 2, Location::Unknown); + let h = builder.add_lut(g, FunctionTable::UNKWOWN, 2, Location::Unknown); + let _h = builder.add_change_partition(h, None, Some(tfhers_part), Location::Unknown); graph.tag_operator_as_output(c); } @@ -897,7 +956,16 @@ mod tests { let lut2 = builder.add_lut(dot, FunctionTable::UNKWOWN, 2, Location::Unknown); - let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2]; + let tfhers_part = ExternalPartition { + name: String::from("tfhers"), + macro_params: DUMMY_MACRO_PARAM, + max_variance: 0.0_f64, + variance: 0.0_f64, + }; + let change_part = + builder.add_change_partition(lut2, Some(tfhers_part.clone()), None, Location::Unknown); + + let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2, change_part]; for (expected_i, op_index) in ops_index.iter().enumerate() { assert_eq!(expected_i, op_index.0); } @@ -944,6 +1012,11 @@ mod tests { input: dot, table: FunctionTable::UNKWOWN, out_precision: 2, + }, + Operator::ChangePartition { + input: lut2, + src_partition: Some(tfhers_part.clone()), + dst_partition: None } ] ); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index ce66c2c68a..6e7662edf2 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -121,6 +121,7 @@ pub struct VariancedDag { pub(crate) dag: Dag, pub(crate) partitions: Partitions, pub(crate) variances: Variances, + pub(crate) external_variance_constraints: Vec, } impl VariancedDag { @@ -133,9 +134,11 @@ impl VariancedDag { dag, partitions, variances, + external_variance_constraints: vec![], }; // We forward the noise once to verify the composability. + varianced.apply_external_partition_input_variance(); let _ = varianced.forward_noise(); varianced.check_composability()?; varianced.apply_composition_rules(); @@ -145,6 +148,8 @@ impl VariancedDag { // The noise gets computed from inputs down to outputs. if varianced.forward_noise() { // Noise settled, we return the varianced dag. + varianced.collect_external_input_constraint(); + varianced.collect_external_output_constraint(); return Ok(varianced); } // The noise of the inputs gets updated following the composition rules @@ -181,6 +186,106 @@ impl VariancedDag { } } + fn apply_external_partition_input_variance(&mut self) { + let p_cut = self.partitions.p_cut.clone(); + for (i, op) in self.dag.operators.clone().iter().enumerate() { + if let Operator::Input { .. } = op { + let partition_index = self.partitions.instrs_partition[i].instruction_partition; + if p_cut.is_external_partition(&partition_index) { + let partitions = self.partitions.clone(); + let external_partition = + &p_cut.external_partitions[p_cut.external_partition_index(partition_index)]; + let max_variance = external_partition.max_variance; + let variance = external_partition.variance; + + let mut input = self.get_operator_mut(OperatorIndex(i)); + let mut variances = input.variance().clone(); + variances.vars[partition_index.0] = SymbolicVariance::from_external_partition( + partitions.nb_partitions, + partition_index, + max_variance / variance, + ); + *(input.variance_mut()) = variances; + } + } + } + } + + fn collect_external_input_constraint(&mut self) { + let p_cut = &self.partitions.p_cut; + for (i, op) in self.dag.operators.clone().iter().enumerate() { + if let Operator::Input { + out_precision, + out_shape, + } = op + { + let partition_index = self.partitions.instrs_partition[i].instruction_partition; + if !p_cut.is_external_partition(&partition_index) { + continue; + } + + let max_variance = p_cut.external_partitions + [p_cut.external_partition_index(partition_index)] + .max_variance; + + let variances = &self.get_operator(OperatorIndex(i)).variance().vars.clone(); + for (i, variance) in variances.iter().enumerate() { + if variance.coeffs.is_nan() { + assert!(i != partition_index.0); + continue; + } + let constraint = VarianceConstraint { + precision: *out_precision, + partition: partition_index, + nb_constraints: out_shape.flat_size(), + safe_variance_bound: max_variance, + variance: variance.clone(), + }; + self.external_variance_constraints.push(constraint); + } + } + } + } + + fn collect_external_output_constraint(&mut self) { + let p_cut = self.partitions.p_cut.clone(); + for dag_op in self.dag.get_output_operators_iter() { + let DagOperator { + id: op_index, + shape: out_shape, + precision: out_precision, + .. + } = dag_op; + let optional_partition_index = p_cut.partition(&self.dag, op_index); + if optional_partition_index.is_none() { + continue; + } + let partition_index = optional_partition_index.unwrap(); + if !p_cut.is_external_partition(&partition_index) { + continue; + } + let max_variance = p_cut.external_partitions + [p_cut.external_partition_index(partition_index)] + .max_variance; + + let variances = &self.get_operator(op_index).variance().vars.clone(); + for (i, variance) in variances.iter().enumerate() { + if variance.coeffs.is_nan() { + assert!(i != partition_index.0); + continue; + } + let constraint = VarianceConstraint { + precision: *out_precision, + partition: partition_index, + nb_constraints: out_shape.flat_size(), + safe_variance_bound: max_variance, + variance: variance.clone(), + }; + self.external_variance_constraints.push(constraint); + } + } + } + /// Propagates the noise downward in the graph. fn forward_noise(&mut self) -> bool { // We save the old variance to compute the diff at the end. @@ -252,7 +357,7 @@ impl VariancedDag { acc + var[operator.partition().instruction_partition].clone() * square(*weight as f64) }), - Operator::UnsafeCast { .. } => { + Operator::UnsafeCast { .. } | Operator::ChangePartition { .. } => { operator.get_inputs_iter().next().unwrap().variance() [operator.partition().instruction_partition] .clone() @@ -343,7 +448,9 @@ pub fn analyze( let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition); let partitioned_dag = PartitionedDag { dag, partitions }; let varianced_dag = VariancedDag::try_from_partitioned(partitioned_dag)?; - let variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config); + let mut variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config); + // add external variance constraints + variance_constraints.extend_from_slice(varianced_dag.external_variance_constraints.as_slice()); let undominated_variance_constraints = VarianceConstraint::remove_dominated(&variance_constraints); let operations_count_per_instrs = collect_operations_count(&varianced_dag); @@ -560,6 +667,7 @@ fn collect_all_variance_constraints( dag, partitions, variances, + .. } = dag; let mut constraints = vec![]; for op in dag.get_operators_iter() { @@ -1322,6 +1430,6 @@ pub mod tests { let p_cut = PartitionCut::from_precisions(&precisions); let dag = super::analyze(&dag, &CONFIG, &Some(p_cut.clone()), LOW_PRECISION_PARTITION).unwrap(); - assert!(dag.nb_partitions == p_cut.p_cut.len() + 1); + assert!(dag.nb_partitions == p_cut.n_partitions()); } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs index e70b214266..7b0b90b268 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs @@ -193,6 +193,15 @@ impl OperationsValue { } } + pub fn is_nan(&self) -> bool { + for val in self.values.iter() { + if !val.is_nan() { + return false; + } + } + true + } + pub fn input(&mut self, partition: PartitionIndex) -> &mut f64 { &mut self.values[self.index.input(partition)] } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs index c146803ab9..a5b2183848 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs @@ -45,7 +45,7 @@ struct PartialMicroParameters { complexity: f64, } -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Hash, Debug, Copy, Clone, Eq, PartialEq)] pub struct MacroParameters { pub glwe_params: GlweParameters, pub internal_dim: u64, @@ -920,6 +920,11 @@ pub fn optimize( ciphertext_modulus_log, }; + let dag_p_cut = p_cut.as_ref().map_or_else( + || PartitionCut::for_each_precision(dag), + std::clone::Clone::clone, + ); + let dag = analyze(dag, &noise_config, p_cut, default_partition)?; let kappa = error::sigma_scale_of_error_probability(config.maximum_acceptable_error_probability); @@ -954,11 +959,27 @@ pub fn optimize( let mut best_params: Option = None; for iter in 0..=10 { for partition in PartitionIndex::range(0, nb_partitions).rev() { + // reduce search space to the parameters of external partitions + let partition_search_space = if dag_p_cut.is_external_partition(&partition) { + let external_part = + &dag_p_cut.external_partitions[partition.0 - dag_p_cut.n_internal_partitions()]; + let mut reduced_search_space = search_space.clone(); + reduced_search_space.glwe_dimensions = + [external_part.macro_params.glwe_params.glwe_dimension].to_vec(); + reduced_search_space.glwe_log_polynomial_sizes = + [external_part.macro_params.glwe_params.log2_polynomial_size].to_vec(); + reduced_search_space.internal_lwe_dimensions = + [external_part.macro_params.internal_dim].to_vec(); + reduced_search_space + } else { + search_space.clone() + }; + let new_params = optimize_macro( security_level, ciphertext_modulus_log, fft_precision, - search_space, + &partition_search_space, partition, &used_tlu_keyswitch, &used_conversion_keyswitch, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index e82683ca66..5e532874af 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -1,30 +1,17 @@ #![allow(clippy::float_cmp)] use once_cell::sync::Lazy; +use optimization::dag::multi_parameters::partition_cut::ExternalPartition; use super::*; use crate::computing_cost::cpu::CpuComplexity; -use crate::config; use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape}; use crate::dag::unparametrized; +use crate::optimization::dag::multi_parameters::partitionning::tests::{ + get_tfhers_noise_br, SHARED_CACHES, TFHERS_MACRO_PARAMS, +}; use crate::optimization::dag::solo_key; use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag}; -use crate::optimization::decomposition; - -const CIPHERTEXT_MODULUS_LOG: u32 = 64; -const FFT_PRECISION: u32 = 53; - -static SHARED_CACHES: Lazy = Lazy::new(|| { - let processing_unit = config::ProcessingUnit::Cpu; - decomposition::cache( - 128, - processing_unit, - None, - true, - CIPHERTEXT_MODULUS_LOG, - FFT_PRECISION, - ) -}); const _4_SIGMA: f64 = 0.000_063_342_483_999_973; @@ -489,7 +476,7 @@ fn test_partition_chain(decreasing: bool) { let sol = optimize(&dag, &Some(p_cut.clone()), PartitionIndex(0)).unwrap(); let nb_partitions = sol.macro_params.len(); assert!( - nb_partitions == (p_cut.p_cut.len() + 1), + nb_partitions == p_cut.n_partitions(), "bad nb partitions {} {p_cut}", sol.macro_params.len() ); @@ -851,3 +838,171 @@ fn test_bug_with_zero_noise() { let sol = optimize(&dag, &None, PartitionIndex(0)); assert!(sol.is_some()); } + +#[test] +fn test_optimize_tfhers_in_out_dot_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 2.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let input1 = dag.add_input(16, Shape::number()); + let change_part1 = dag.add_change_partition(input1, Some(tfhers_partition.clone()), None); + let dot = dag.add_dot([change_part1], [2]); + _ = dag.add_change_partition(dot, None, Some(tfhers_partition.clone())); + + let sol = optimize(&dag, &None, PartitionIndex(0)); + assert!(sol.is_some()); + println!("solution: {:?}", sol.unwrap()); +} + +#[test] +fn test_optimize_tfhers_2lut_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition_in = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let tfhers_partition_out = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let tfhers_precision = 11; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(tfhers_precision, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in), None); + let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); + let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); + let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition_out)); + + let sol = optimize(&dag, &None, PartitionIndex(0)); + assert!(sol.is_some()); +} + +#[test] +fn test_optimize_tfhers_different_in_out_2lut_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition_in = ExternalPartition { + name: String::from("tfhers_in"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 6.0, + variance, + }; + let tfhers_partition_out = ExternalPartition { + name: String::from("tfhers_out"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 6.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let tfhers_precision = 8; + let input = dag.add_input(tfhers_precision, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in), None); + let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); + let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); + let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition_out)); + + let sol = optimize(&dag, &None, PartitionIndex(0)); + assert!(sol.is_some()); +} + +#[test] +fn test_optimize_tfhers_input_constraints() { + let variances = [1.0, 6.14e-14, 2.14e-16]; + let dag_builder = |variance: f64| -> Dag { + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let tfhers_precision = 4; + let input = dag.add_input(tfhers_precision, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition), None); + let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, tfhers_precision); + let out = dag.add_dot([lut], [128]); + dag.add_composition(out, input); + dag + }; + + let sol = optimize(&dag_builder(variances[0]), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let mut last_complexity = sol.unwrap().complexity; + for variance in &variances[1..] { + let sol = optimize(&dag_builder(*variance), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let complexity = sol.unwrap().complexity; + assert!(complexity > last_complexity); + last_complexity = complexity; + } +} + +#[test] +fn test_optimize_tfhers_output_constraints() { + let variances = [1.0, 6.14e-14, 2.14e-16]; + let dag_builder = |variance: f64| -> Dag { + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let tfhers_precision = 4; + let input = dag.add_input(tfhers_precision, Shape::number()); + let lut = dag.add_lut(input, FunctionTable::UNKWOWN, tfhers_precision); + let dot = dag.add_dot([lut], [128]); + let out = dag.add_change_partition(dot, None, Some(tfhers_partition.clone())); + dag.add_composition(out, input); + dag + }; + + let sol = optimize(&dag_builder(variances[0]), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let mut last_complexity = sol.unwrap().complexity; + for variance in &variances[1..] { + let sol = optimize(&dag_builder(*variance), &None, PartitionIndex(0)); + assert!(sol.is_some()); + let complexity = sol.unwrap().complexity; + assert!(complexity > last_complexity); + last_complexity = complexity; + } +} + +#[test] +fn test_optimize_tfhers_to_concrete_and_back_example() { + let variance = get_tfhers_noise_br(); + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 8.0, + variance, + }; + let concrete_precision = 8; + let msg_width = 2; + let carry_width = 2; + let tfhers_precision = msg_width + carry_width; + + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input( + tfhers_precision, + Shape::vector((concrete_precision / msg_width).into()), + ); + // to concrete + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None); + let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, concrete_precision); + // from concrete + let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, tfhers_precision); + let _ = dag.add_change_partition(lut2, None, Some(tfhers_partition.clone())); + + let sol = optimize(&dag, &None, PartitionIndex(0)); + assert!(sol.is_some()); +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs index c3e607bb5e..dbd5edc22f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs @@ -9,21 +9,62 @@ use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; use crate::optimization::dag::solo_key::analyze::out_variances; use crate::optimization::dag::solo_key::symbolic_variance::SymbolicVariance; +use super::optimize::MacroParameters; + const ROUND_INNER_MULTI_PARAMETER: bool = false; const ROUND_EXTERNAL_MULTI_PARAMETER: bool = !ROUND_INNER_MULTI_PARAMETER && true; +#[derive(Clone, Debug)] +pub struct ExternalPartition { + pub name: String, + pub macro_params: MacroParameters, + pub max_variance: f64, + pub variance: f64, +} + +impl Eq for ExternalPartition {} + +impl PartialEq for ExternalPartition { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.macro_params == other.macro_params + && self.max_variance == other.max_variance + } +} + +impl std::hash::Hash for ExternalPartition { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.macro_params.hash(state); + } +} + +impl std::fmt::Display for ExternalPartition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{{ name: {} }}", self.name)?; + Ok(()) + } +} + // TODO: keep both precisions // TODO: rounding lut should have its own partition based on max norm2 and precisions #[derive(Clone, Debug)] pub struct PartitionCut { + // TODO: add name to partitions + // partition0 precision <= p_cut[0] < partition 1 precision <= p_cut[1] ... // precision are in the sens of Lut input precision and are sorted pub p_cut: Vec<(Precision, f64)>, + // Whether it has internal partitions or not + pub has_internal_partitions: bool, + // # TODO RELATIVE NORM2 // # HIGHER NORM2 MEANS HIGHER VARIANCE IN CONSTRAINT // norm2 * 2 ** (out precision - in precision) pub rnorm2: Vec, + + pub external_partitions: Vec, } impl PartitionCut { @@ -31,17 +72,63 @@ impl PartitionCut { Self { p_cut: vec![], rnorm2: vec![], + external_partitions: vec![], + has_internal_partitions: true, } } + pub fn n_partitions(&self) -> usize { + self.n_internal_partitions() + self.n_external_partitions() + } + + pub fn n_internal_partitions(&self) -> usize { + self.p_cut.len() + self.has_internal_partitions as usize + } + + pub fn n_external_partitions(&self) -> usize { + self.external_partitions.len() + } + + pub fn external_partition_index(&self, partition: PartitionIndex) -> usize { + partition.0 - self.n_internal_partitions() + } + + pub fn is_external_partition(&self, partition: &PartitionIndex) -> bool { + partition.0 >= self.n_internal_partitions() && partition.0 < self.n_partitions() + } + + pub fn is_internal_partition(&self, partition: &PartitionIndex) -> bool { + partition.0 < self.n_internal_partitions() + } + pub fn from_precisions(precisions: &[Precision]) -> Self { let mut precisions: Vec<_> = precisions.to_vec(); + let has_internal_partitions = !precisions.is_empty(); + precisions.sort_by(|a, b| a.partial_cmp(b).unwrap()); + _ = precisions.pop(); + + Self { + p_cut: precisions.iter().map(|p| (*p, f64::MAX)).collect(), + rnorm2: vec![], + external_partitions: vec![], + has_internal_partitions, + } + } + + pub fn from_precisions_and_external_partitions( + precisions: &[Precision], + external_partitions: &[ExternalPartition], + ) -> Self { + let mut precisions: Vec<_> = precisions.to_vec(); + let has_internal_partitions = !precisions.is_empty(); precisions.sort_by(|a, b| a.partial_cmp(b).unwrap()); _ = precisions.pop(); Self { p_cut: precisions.iter().map(|p| (*p, f64::MAX)).collect(), rnorm2: vec![], + external_partitions: external_partitions.to_vec(), + has_internal_partitions, } } @@ -61,7 +148,7 @@ impl PartitionCut { let op = &dag.operators[op_i.0]; match op { Operator::Lut { input, .. } => { - assert!(!self.p_cut.is_empty()); + assert!(self.has_internal_partitions); for (partition, &(precision_cut, norm2_cut)) in self.p_cut.iter().enumerate() { if dag.out_precisions[input.0] <= precision_cut && self.rnorm2(op_i) <= norm2_cut @@ -71,6 +158,23 @@ impl PartitionCut { } Some(PartitionIndex(self.p_cut.len())) } + Operator::ChangePartition { + src_partition: Some(partition), + dst_partition: None, + .. + } + | Operator::ChangePartition { + src_partition: None, + dst_partition: Some(partition), + .. + } => { + for (i, external_partition) in self.external_partitions.iter().enumerate() { + if partition == external_partition { + return Some(PartitionIndex(self.n_internal_partitions() + i)); + } + } + None + } _ => None, } } @@ -78,15 +182,33 @@ impl PartitionCut { pub fn for_each_precision(dag: &unparametrized::Dag) -> Self { let (dag, _) = expand_round_and_index_map(dag); let mut lut_in_precisions: HashSet<_> = HashSet::default(); + let mut partitions: HashSet = HashSet::default(); for op in &dag.operators { if let Operator::Lut { input, .. } = op { _ = lut_in_precisions.insert(dag.out_precisions[input.0]); } } + for op in &dag.operators { + if let Operator::ChangePartition { + src_partition, + dst_partition, + .. + } = op + { + if let Some(partition) = src_partition { + _ = partitions.insert(partition.clone()); + } + if let Some(partition) = dst_partition { + _ = partitions.insert(partition.clone()); + } + } + } let precisions: Vec<_> = lut_in_precisions.iter().copied().collect(); - Self::from_precisions(&precisions) + let external_partitions = Vec::from_iter(partitions); + Self::from_precisions_and_external_partitions(&precisions, &external_partitions) } + #[allow(clippy::too_many_lines)] pub fn maximal_partitionning(original_dag: &unparametrized::Dag) -> Self { // Note: only keep one 0-bits, partition as the compiler will not support multi-parameter round // partition based on input precision and output log norm2 @@ -103,6 +225,7 @@ impl PartitionCut { let out_variances: Vec = out_variances(&dag); let mut noise_origins: Vec> = vec![HashSet::default(); out_variances.len()]; let mut max_output_norm2 = vec![f64::NAN; out_variances.len()]; + let mut external_partitions: Vec = vec![]; assert!(out_variances.len() == dag.operators.len()); // Find input lut log norm2 and lut as origins @@ -128,6 +251,21 @@ impl PartitionCut { max_output_norm2[op_i] = 1.0; // initial value that can be maxed noise_origins[op_i] = std::iter::once(op_i).collect(); } + Operator::ChangePartition { + src_partition: Some(partition), + dst_partition: None, + .. + } + | Operator::ChangePartition { + src_partition: None, + dst_partition: Some(partition), + .. + } => { + external_partitions.push(partition.clone()); + } + Operator::ChangePartition { .. } => { + panic!("change_partition not supported when src and dest partition are both set or unset"); + } // unreachable Operator::Round { .. } => panic!("expand_round failed"), } @@ -181,12 +319,16 @@ impl PartitionCut { } } let mut p_cut: Vec<_> = lut_partition.iter().copied().collect(); + let has_internal_partitions = !lut_partition.is_empty(); p_cut.sort_by(|a, b| a.partial_cmp(b).unwrap()); _ = p_cut.pop(); let p_cut = p_cut.iter().map(|(p, n)| (*p, n.into_inner())).collect(); + Self { p_cut, rnorm2: max_output_norm2, + external_partitions, + has_internal_partitions, } } @@ -197,9 +339,15 @@ impl PartitionCut { p_cut.push(cut); } } + let has_internal_partitions = self.has_internal_partitions + && self.is_internal_partition(&PartitionIndex( + used.iter().map(|u| u.0).min().unwrap_or(usize::MAX), + )); Self { p_cut, rnorm2: self.rnorm2.clone(), + external_partitions: self.external_partitions.clone(), + has_internal_partitions, } } } @@ -229,6 +377,13 @@ impl std::fmt::Display for PartitionCut { "partition {}: {prev_precision_cut} bits and higher", self.p_cut.len() )?; + for (i, e_partition) in self.external_partitions.iter().enumerate() { + writeln!( + f, + "partition {} (external): {e_partition}", + i + self.n_internal_partitions() + )?; + } Ok(()) } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index fb4d7d66d1..e01648db12 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -55,7 +55,9 @@ fn extract_levelled_block(dag: &unparametrized::Dag) -> Blocks { // Block entry point and pre-exit point Op::Lut { .. } => (), // Connectors - Op::UnsafeCast { input, .. } => uf.union(input.0, op_i), + Op::UnsafeCast { input, .. } | Op::ChangePartition { input, .. } => { + uf.union(input.0, op_i); + } Op::LevelledOp { inputs, .. } | Op::Dot { inputs, .. } => { for input in inputs { uf.union(input.0, op_i); @@ -120,7 +122,7 @@ fn only_1_partition(dag: &unparametrized::Dag) -> Partitions { Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => { instrs_partition[op_i].inputs_transition = vec![None; inputs.len()]; } - Op::Lut { .. } | Op::UnsafeCast { .. } => { + Op::Lut { .. } | Op::UnsafeCast { .. } | Operator::ChangePartition { .. } => { instrs_partition[op_i].inputs_transition = vec![None]; } Op::Input { .. } => (), @@ -130,6 +132,7 @@ fn only_1_partition(dag: &unparametrized::Dag) -> Partitions { Partitions { nb_partitions: 1, instrs_partition, + p_cut: PartitionCut::empty(), } } @@ -146,7 +149,7 @@ fn resolve_by_levelled_block( .copied() .collect(); let nb_partitions = present_partitions.len().max(1); // no tlu = no constraints - if p_cut.p_cut.len() + 1 != nb_partitions { + if p_cut.n_partitions() != nb_partitions { return resolve_by_levelled_block( dag, &p_cut.delete_unused_cut(&present_partitions), @@ -169,7 +172,15 @@ fn resolve_by_levelled_block( 1 => get_singleton_value(&constraints.forced), _ => { let forced = constraints.forced; - if forced.contains(&default_partition) { + // in case of conflict, prioritize the external partition, then the default partition + let externals: Vec<_> = forced + .iter() + .filter(|p| p_cut.is_external_partition(p)) + .collect(); + assert!(externals.len() <= 1); + if externals.len() == 1 { + *externals[0] + } else if forced.contains(&default_partition) { default_partition } else { *forced.iter().min().unwrap() @@ -217,7 +228,7 @@ fn resolve_by_levelled_block( } } } - Op::UnsafeCast { input, .. } => { + Op::UnsafeCast { input, .. } | Op::ChangePartition { input, .. } => { instrs_p[op_i].instruction_partition = group_partition; let input_partition = instrs_p[input.0].instruction_partition; instrs_p[op_i].inputs_transition = if group_partition == input_partition { @@ -235,6 +246,7 @@ fn resolve_by_levelled_block( Partitions { nb_partitions, instrs_partition: instrs_p, + p_cut: p_cut.clone(), } // Now we can generate transitions // Input has no transtions @@ -247,7 +259,7 @@ pub fn partitionning_with_preferred( p_cut: &PartitionCut, default_partition: PartitionIndex, ) -> Partitions { - if p_cut.p_cut.is_empty() { + if p_cut.n_partitions() <= 1 { only_1_partition(dag) } else { resolve_by_levelled_block(dag, p_cut, default_partition) @@ -261,9 +273,32 @@ pub mod tests { pub const LOW_PRECISION_PARTITION: PartitionIndex = PartitionIndex(0); pub const HIGH_PRECISION_PARTITION: PartitionIndex = PartitionIndex(1); + use once_cell::sync::Lazy; + use super::*; - use crate::dag::operator::{FunctionTable, Shape, Weights}; + use crate::config; + use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights}; use crate::dag::unparametrized; + use crate::optimization::dag::multi_parameters::optimize::MacroParameters; + use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition; + use crate::optimization::decomposition::cmux::get_noise_br; + use crate::optimization::decomposition::{self, PersistDecompCaches}; + use crate::parameters::GlweParameters; + + const CIPHERTEXT_MODULUS_LOG: u32 = 64; + const FFT_PRECISION: u32 = 53; + + pub static SHARED_CACHES: Lazy = Lazy::new(|| { + let processing_unit = config::ProcessingUnit::Cpu; + decomposition::cache( + 128, + processing_unit, + None, + true, + CIPHERTEXT_MODULUS_LOG, + FFT_PRECISION, + ) + }); fn default_p_cut() -> PartitionCut { PartitionCut::from_precisions(&[2, 128]) @@ -333,6 +368,316 @@ pub mod tests { } } + fn get_external_partition_index_or( + external_partition: &ExternalPartition, + p_cut: &PartitionCut, + default: usize, + ) -> PartitionIndex { + for (i, e_partition) in p_cut.external_partitions.iter().enumerate() { + if external_partition == e_partition { + return PartitionIndex(i + p_cut.n_internal_partitions()); + } + } + PartitionIndex(default) + } + + pub const GLWE_PARAMS: GlweParameters = GlweParameters { + log2_polynomial_size: 11, + glwe_dimension: 1, + }; + + pub const TFHERS_PBS_LEVEL: u64 = 1; + + pub const TFHERS_MACRO_PARAMS: MacroParameters = MacroParameters { + glwe_params: GLWE_PARAMS, + internal_dim: 841, + }; + + pub fn get_tfhers_noise_br() -> f64 { + get_noise_br( + SHARED_CACHES.caches(), + GLWE_PARAMS.log2_polynomial_size, + GLWE_PARAMS.glwe_dimension, + TFHERS_MACRO_PARAMS.internal_dim, + TFHERS_PBS_LEVEL, + None, + ) + .unwrap() + } + + #[test] + fn test_tfhers_in_out_dot_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let input1 = dag.add_input(16, Shape::number()); + let change_part1 = dag.add_change_partition(input1, Some(tfhers_partition.clone()), None); + let dot = dag.add_dot([change_part1], [2]); + _ = dag.add_change_partition(dot, None, Some(tfhers_partition)); + + let partitions = partitionning(&dag); + assert!(partitions.nb_partitions == 1); + let instrs_partition = partitions.instrs_partition; + show_partitionning(&dag, &instrs_partition); + assert!(instrs_partition[0].instruction_partition == LOW_PRECISION_PARTITION); + } + + #[test] + fn test_tfhers_in_out_lut_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 6.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None); + let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16); + let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition)); + + let partitions = partitionning(&dag); + assert!(partitions.nb_partitions == 2); + let tfhers_partition_index = PartitionIndex(1); + let instrs_partition = partitions.instrs_partition; + show_partitionning(&dag, &instrs_partition); + + let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0]; + assert!(consider(input).instruction_partition == tfhers_partition_index); + assert!(consider(change_part1).instruction_partition == tfhers_partition_index); + assert!(consider(change_part1).inputs_transition == [None]); + assert!(consider(lut).instruction_partition == LOW_PRECISION_PARTITION); + assert!(consider(lut).alternative_output_representation.len() == 1); + assert!(consider(lut) + .alternative_output_representation + .contains(&tfhers_partition_index)); + assert!(consider(change_part2).instruction_partition == tfhers_partition_index); + assert!( + consider(change_part2).inputs_transition + == [Some(Transition::Additional { + src_partition: LOW_PRECISION_PARTITION + })] + ); + } + + #[test] + fn test_tfhers_in_out_lut_compute_mix_external() { + let variance = get_tfhers_noise_br(); + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 6.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let external_input = dag.add_input(16, Shape::number()); + let other_input = dag.add_input(16, Shape::number()); + let other_lut = dag.add_lut(other_input, FunctionTable::UNKWOWN, 16); + let mix_add = dag.add_levelled_op( + [external_input, other_lut], + LevelledComplexity::ADDITION, + [1.0, 1.0], + Shape::number(), + "add", + ); + let change_part1 = dag.add_change_partition(mix_add, Some(tfhers_partition.clone()), None); + let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16); + let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition)); + + let partitions = partitionning(&dag); + assert!(partitions.nb_partitions == 2); + let tfhers_partition_index = PartitionIndex(1); + let instrs_partition = partitions.instrs_partition; + show_partitionning(&dag, &instrs_partition); + + let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0]; + assert!(consider(external_input).instruction_partition == tfhers_partition_index); + assert!(consider(other_input).instruction_partition == LOW_PRECISION_PARTITION); + assert!(consider(other_lut).alternative_output_representation.len() == 1); + assert!(consider(other_lut) + .alternative_output_representation + .contains(&tfhers_partition_index)); + assert!(consider(change_part1).instruction_partition == tfhers_partition_index); + assert!(consider(change_part1).inputs_transition == [None]); + assert!(consider(lut).instruction_partition == LOW_PRECISION_PARTITION); + assert!(consider(lut).alternative_output_representation.len() == 1); + assert!(consider(lut) + .alternative_output_representation + .contains(&tfhers_partition_index)); + assert!(consider(change_part2).instruction_partition == tfhers_partition_index); + } + + #[test] + fn test_tfhers_different_in_out_lut_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition_in = ExternalPartition { + name: String::from("tfhers_in"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 2.0, + variance, + }; + let tfhers_partition_out = ExternalPartition { + name: String::from("tfhers_out"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 2.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in.clone()), None); + let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16); + let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition_out.clone())); + + let p_cut = PartitionCut::for_each_precision(&dag); + let partitions = partitionning_with_preferred(&dag, &p_cut, LOW_PRECISION_PARTITION); + assert!(partitions.nb_partitions == 3); + let tfhers_partition_index_in = + get_external_partition_index_or(&tfhers_partition_in, &p_cut, 1); + let tfhers_partition_index_out = + get_external_partition_index_or(&tfhers_partition_out, &p_cut, 2); + let instrs_partition = partitions.instrs_partition; + show_partitionning(&dag, &instrs_partition); + + let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0]; + assert!(consider(input).instruction_partition == tfhers_partition_index_in); + assert!(consider(change_part1).instruction_partition == tfhers_partition_index_in); + assert!(consider(change_part1).inputs_transition == [None]); + assert!(consider(lut).instruction_partition == LOW_PRECISION_PARTITION); + assert!(consider(lut).alternative_output_representation.len() == 1); + assert!(consider(lut) + .alternative_output_representation + .contains(&tfhers_partition_index_out)); + assert!(consider(change_part2).instruction_partition == tfhers_partition_index_out); + assert!( + consider(change_part2).inputs_transition + == [Some(Transition::Additional { + src_partition: LOW_PRECISION_PARTITION + })] + ); + } + + #[test] + fn test_tfhers_in_out_2lut_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition = ExternalPartition { + name: String::from("tfhers"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition.clone()), None); + let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); + let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16); + let change_part2 = dag.add_change_partition(lut2, None, Some(tfhers_partition)); + + let partitions = partitionning(&dag); + assert!(partitions.nb_partitions == 3); + let tfhers_partition_index = PartitionIndex(2); + let instrs_partition = partitions.instrs_partition; + show_partitionning(&dag, &instrs_partition); + + let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0]; + assert!(consider(input).instruction_partition == tfhers_partition_index); + assert!(consider(change_part1).instruction_partition == tfhers_partition_index); + assert!(consider(change_part1).inputs_transition == [None]); + assert!(consider(lut1).instruction_partition == HIGH_PRECISION_PARTITION); + assert!( + consider(lut1).inputs_transition + == [Some(Transition::Internal { + src_partition: tfhers_partition_index + })] + ); + assert!(consider(lut2).alternative_output_representation.len() == 1); + assert!(consider(lut2) + .alternative_output_representation + .contains(&tfhers_partition_index)); + assert!(consider(lut2).instruction_partition == LOW_PRECISION_PARTITION); + assert!( + consider(lut2).inputs_transition + == [Some(Transition::Internal { + src_partition: HIGH_PRECISION_PARTITION + })] + ); + assert!(consider(change_part2).instruction_partition == tfhers_partition_index); + assert!( + consider(change_part2).inputs_transition + == [Some(Transition::Additional { + src_partition: LOW_PRECISION_PARTITION + })] + ); + } + + #[test] + fn test_tfhers_different_in_out_2lut_compute() { + let variance = get_tfhers_noise_br(); + let tfhers_partition_in = ExternalPartition { + name: String::from("tfhers_in"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let tfhers_partition_out = ExternalPartition { + name: String::from("tfhers_out"), + macro_params: TFHERS_MACRO_PARAMS, + max_variance: variance * 4.0, + variance, + }; + let mut dag = unparametrized::Dag::new(); + let input = dag.add_input(16, Shape::number()); + let change_part1 = dag.add_change_partition(input, Some(tfhers_partition_in.clone()), None); + let lut1 = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 4); + let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 16); + let change_part2 = dag.add_change_partition(lut2, None, Some(tfhers_partition_out.clone())); + + let p_cut = PartitionCut::for_each_precision(&dag); + let partitions = partitionning_with_preferred(&dag, &p_cut, LOW_PRECISION_PARTITION); + assert!(partitions.nb_partitions == 4); + let tfhers_partition_index_in = + get_external_partition_index_or(&tfhers_partition_in, &p_cut, 2); + let tfhers_partition_index_out = + get_external_partition_index_or(&tfhers_partition_out, &p_cut, 3); + let instrs_partition = partitions.instrs_partition; + show_partitionning(&dag, &instrs_partition); + + let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0]; + assert!(consider(input).instruction_partition == tfhers_partition_index_in); + assert!(consider(change_part1).instruction_partition == tfhers_partition_index_in); + assert!(consider(change_part1).inputs_transition == [None]); + assert!(consider(lut1).instruction_partition == HIGH_PRECISION_PARTITION); + assert!( + consider(lut1).inputs_transition + == [Some(Transition::Internal { + src_partition: tfhers_partition_index_in + })] + ); + assert!(consider(lut2).alternative_output_representation.len() == 1); + assert!(consider(lut2) + .alternative_output_representation + .contains(&tfhers_partition_index_out)); + assert!(consider(lut2).instruction_partition == LOW_PRECISION_PARTITION); + assert!( + consider(lut2).inputs_transition + == [Some(Transition::Internal { + src_partition: HIGH_PRECISION_PARTITION + })] + ); + assert!(consider(change_part2).instruction_partition == tfhers_partition_index_out); + assert!( + consider(change_part2).inputs_transition + == [Some(Transition::Additional { + src_partition: LOW_PRECISION_PARTITION + })] + ); + } + #[test] fn test_1_input_2_partitions() { let mut dag = unparametrized::Dag::new(); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs index 28ee90c61b..18dd8050da 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitions.rs @@ -6,6 +6,8 @@ use std::{ use crate::dag::operator::OperatorIndex; +use super::partition_cut::PartitionCut; + #[derive(Clone, Debug, PartialEq, Eq, Default, PartialOrd, Ord, Hash, Copy)] pub struct PartitionIndex(pub(crate) usize); @@ -78,6 +80,7 @@ impl InstructionPartition { pub struct Partitions { pub nb_partitions: usize, pub instrs_partition: Vec, + pub p_cut: PartitionCut, } impl Index for Partitions { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs index 664088bda2..7b120a48a6 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs @@ -54,6 +54,20 @@ impl SymbolicVariance { r } + pub fn from_external_partition( + nb_partitions: usize, + partition: PartitionIndex, + max_variance: f64, + ) -> Self { + let mut r = Self { + partition, + coeffs: OperationsValue::zero(nb_partitions), + }; + // rust ..., offset cannot be inlined + *r.coeffs.pbs(partition) = max_variance; + r + } + pub fn coeff_input(&self, partition: PartitionIndex) -> f64 { self.coeffs[self.coeffs.index.input(partition)] } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 756d1b2b48..484338b123 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -47,7 +47,8 @@ fn assert_inputs_index(op: &Operator, first_bad_index: usize) { Operator::Input { .. } => true, Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } - | Operator::Round { input, .. } => input.0 < first_bad_index, + | Operator::Round { input, .. } + | Operator::ChangePartition { input, .. } => input.0 < first_bad_index, Operator::LevelledOp { inputs, .. } | Operator::Dot { inputs, .. } => { inputs.iter().all(|input| input.0 < first_bad_index) } @@ -73,6 +74,15 @@ pub fn has_round(dag: &Dag) -> bool { false } +pub fn has_change_partition(dag: &Dag) -> bool { + for op in &dag.operators { + if matches!(op, Operator::ChangePartition { .. }) { + return true; + } + } + false +} + pub fn has_unsafe_cast(dag: &Dag) -> bool { for op in &dag.operators { if matches!(op, Operator::UnsafeCast { .. }) { @@ -86,6 +96,10 @@ pub fn assert_no_round(dag: &Dag) { assert!(!has_round(dag)); } +pub fn assert_no_change_partition(dag: &Dag) { + assert!(!has_change_partition(dag)); +} + fn assert_valid_variances(dag: &SoloKeyDag) { for &out_variance in &dag.out_variances { assert!( @@ -176,7 +190,9 @@ fn out_variance( .fold(SymbolicVariance::ZERO, |acc, (weight, var)| { acc + var * square(*weight as f64) }), - Operator::UnsafeCast { input, .. } => out_variances[input.0], + Operator::UnsafeCast { input, .. } | Operator::ChangePartition { input, .. } => { + out_variances[input.0] + } Operator::Round { .. } => { unreachable!("Round should have been either expanded or integrated to a lut") } @@ -236,9 +252,10 @@ fn op_levelled_complexity(op: &Operator, out_shapes: &[Shape]) -> LevelledComple } Operator::LevelledOp { complexity, .. } => *complexity, - Operator::Input { .. } | Operator::Lut { .. } | Operator::UnsafeCast { .. } => { - LevelledComplexity::ZERO - } + Operator::Input { .. } + | Operator::Lut { .. } + | Operator::UnsafeCast { .. } + | Operator::ChangePartition { .. } => LevelledComplexity::ZERO, Operator::Round { .. } => { unreachable!("Round should have been either expanded or integrated to a lut") } @@ -374,6 +391,7 @@ pub fn analyze(dag: &Dag, noise_config: &NoiseBoundConfig) -> SoloKeyDag { assert_dag_correctness(dag); let dag = &expand_round(dag); assert_no_round(dag); + assert_no_change_partition(dag); let out_variances = out_variances(dag); let in_luts_variance = in_luts_variance(dag, &out_variances); let nb_luts = lut_count_from_dag(dag); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs index fc28b5f89e..41a44e6edf 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use std::sync::Arc; use super::common::VERSION; +use super::DecompCaches; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] pub struct CmuxComplexityNoise { @@ -147,3 +148,32 @@ pub fn cache( }; PersistentCacheHashMap::new_no_read(&path, VERSION, function) } + +#[derive(Debug)] +pub enum MaxVarianceError { + PbsBaseLogNotFound, + PbsLevelNotFound, +} + +pub fn get_noise_br( + mut cache: DecompCaches, + log2_polynomial_size: u64, + glwe_dimension: u64, + lwe_dim: u64, + pbs_level: u64, + pbs_log2_base: Option, +) -> Result { + let cmux_quantities = cache.cmux.pareto_quantities(GlweParameters { + log2_polynomial_size, + glwe_dimension, + }); + for cmux_quantity in cmux_quantities { + if cmux_quantity.decomp.level == pbs_level { + if pbs_log2_base.is_some() && cmux_quantity.decomp.log2_base == pbs_log2_base.unwrap() { + return Err(MaxVarianceError::PbsBaseLogNotFound); + } + return Ok(cmux_quantity.noise_br(lwe_dim)); + } + } + Err(MaxVarianceError::PbsLevelNotFound) +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs index ce5da8010a..a54068ede6 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs @@ -107,6 +107,9 @@ impl<'dag> Viz for crate::dag::unparametrized::DagOperator<'dag> { Operator::Round { out_precision, .. } => { format!("{index} [label = \"{{%{index} = Round({input_string}) |{{out_precision:|{out_precision:?}}}| {{loc:|{location}}}}}\" fillcolor={color}];",) } + Operator::ChangePartition { .. } => { + format!("{index} [label = \"{{%{index} = ChangePartition({input_string})}}\" fillcolor={color}];",) + } } }