[Optimizer] support external partitions #915

Merged: 6 commits, Aug 9, 2024
@@ -3,6 +3,7 @@ use std::iter::{empty, once};
use std::ops::Deref;

use crate::dag::operator::tensor::{ClearTensor, Shape};
use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition;

use super::DotKind;

@@ -104,6 +105,11 @@ pub enum Operator {
input: OperatorIndex,
out_precision: Precision,
},
ChangePartition {
input: OperatorIndex,
src_partition: Option<ExternalPartition>,
dst_partition: Option<ExternalPartition>,
Comment on lines +110 to +111

Collaborator:
Is it really possible to have (None, None)? Or (Some(a), Some(a))?

If so, it might be worth adding an assert when the op is added!

Member (author):
I agree that (None, None) doesn't make sense, but allowing both to be set is something we decided not to restrict for now; we will see. I will add an assert for the first case.

Collaborator:
Hmm, actually it is not clear to me how this is supposed to work. Is only one of them expected to be set (though you support having both)? And is the other (None) supposed to signal that it is free to take any partition?

My original point was checking that both were set to something different. Maybe we could have a quick chat, so that I am sure I got everything right?

Contributor:
I think "change partition" is just an informative flag giving the origin/destination partition for circuit inputs and results. We chose to keep it more generic, compatible with eventual future use in the middle of a circuit.

The forms currently used are:
- ChangePartition(Some, None) on inputs
- ChangePartition(None, Some) on results

Possible future uses:
- ChangePartition(None, None) could hint that, if a conflict exists, the partition change should happen here and not somewhere else.
- ChangePartition(None, Some) could be used to express manually where a transition occurs; today the optimizer policy forces it automatically, but a better manual decision could be taken.

Also, this is translated from MLIR, so it could be used there to write multi-partition annotated code instead of relying on the optimizer to choose partitions and parameters.
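The convention described above (Some/None marks the external side of the transition, and (None, None) is rejected) can be sketched with a minimal stand-alone model. The types below are simplified stand-ins for illustration only, not the real concrete-optimizer API:

```rust
// Simplified stand-ins for the optimizer's ExternalPartition / Operator / Dag types.
#[derive(Clone, Debug, PartialEq)]
struct ExternalPartition {
    name: String,
}

#[derive(Debug)]
enum Operator {
    Input,
    ChangePartition {
        input: usize,
        src_partition: Option<ExternalPartition>,
        dst_partition: Option<ExternalPartition>,
    },
}

struct Dag {
    operators: Vec<Operator>,
}

impl Dag {
    fn add_change_partition(
        &mut self,
        input: usize,
        src_partition: Option<ExternalPartition>,
        dst_partition: Option<ExternalPartition>,
    ) -> usize {
        // (None, None) carries no information today, so it is rejected,
        // mirroring the assert added in this PR.
        assert!(
            src_partition.is_some() || dst_partition.is_some(),
            "change_partition: src or dst partition needs to be set"
        );
        self.operators.push(Operator::ChangePartition {
            input,
            src_partition,
            dst_partition,
        });
        self.operators.len() - 1
    }
}

fn main() {
    let tfhers = ExternalPartition { name: String::from("tfhers") };
    let mut dag = Dag { operators: vec![Operator::Input] };
    // Input coming *from* an external partition: (Some, None).
    let marked_input = dag.add_change_partition(0, Some(tfhers.clone()), None);
    // Result going *to* an external partition: (None, Some).
    let marked_output = dag.add_change_partition(marked_input, None, Some(tfhers));
    println!("{marked_input} {marked_output}");
}
```

Both hypothetical future uses then reduce to relaxing or extending this assert, which is why the builder, rather than the type, enforces the restriction.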

},
}

impl Operator {
@@ -114,7 +120,8 @@ impl Operator {
Self::LevelledOp { inputs, .. } | Self::Dot { inputs, .. } => Box::new(inputs.iter()),
Self::UnsafeCast { input, .. }
| Self::Lut { input, .. }
| Self::Round { input, .. } => Box::new(once(input)),
| Self::Round { input, .. }
| Self::ChangePartition { input, .. } => Box::new(once(input)),
}
}
}
@@ -190,6 +197,23 @@ impl fmt::Display for Operator {
} => {
write!(f, "ROUND[%{}] : u{out_precision}", input.0)?;
}
Self::ChangePartition {
input,
src_partition,
dst_partition,
} => {
write!(f, "CHANGE_PARTITION[%{}] : {{", input.0)?;
if let Some(partition) = src_partition {
write!(f, "src_partition: {}", partition.name)?;
}
if let Some(partition) = dst_partition {
if src_partition.is_some() {
write!(f, ", ")?;
}
write!(f, "dst_partition: {}", partition.name)?;
}
write!(f, "}}")?;
}
}
Ok(())
}
@@ -8,7 +8,8 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {
Operator::Input { .. } => (),
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. } => input.0 = old_index_to_new[input.0],
| Operator::Round { input, .. }
| Operator::ChangePartition { input, .. } => input.0 = old_index_to_new[input.0],
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
for input in inputs {
input.0 = old_index_to_new[input.0];
@@ -1,6 +1,7 @@
use crate::dag::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
};
use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition;
use std::{
collections::{HashMap, HashSet},
fmt,
@@ -259,6 +260,27 @@ impl<'dag> DagBuilder<'dag> {
)
}

pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<ExternalPartition>,
dst_partition: Option<ExternalPartition>,
location: Location,
) -> OperatorIndex {
assert!(
src_partition.is_some() || dst_partition.is_some(),
"change_partition: src or dst partition needs to be set"
);
self.add_operator(
Operator::ChangePartition {
input,
src_partition,
dst_partition,
},
location,
)
}

pub fn add_round_op(
&mut self,
input: OperatorIndex,
@@ -416,7 +438,8 @@ impl<'dag> DagBuilder<'dag> {
}
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. } => self.dag.out_shapes[input.0].clone(),
| Operator::Round { input, .. }
| Operator::ChangePartition { input, .. } => self.dag.out_shapes[input.0].clone(),
Operator::Dot {
kind: DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor,
..
@@ -451,6 +474,7 @@ impl<'dag> DagBuilder<'dag> {
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
self.dag.out_precisions[inputs[0].0]
}
Operator::ChangePartition { input, .. } => self.dag.out_precisions[input.0],
}
}
}
@@ -577,6 +601,20 @@ impl Dag {
.add_lut(input, table, out_precision, Location::Unknown)
}

pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<ExternalPartition>,
dst_partition: Option<ExternalPartition>,
) -> OperatorIndex {
self.builder(DEFAULT_CIRCUIT).add_change_partition(
input,
src_partition,
dst_partition,
Location::Unknown,
)
}

pub fn add_dot(
&mut self,
inputs: impl Into<Vec<OperatorIndex>>,
@@ -834,8 +872,20 @@ impl Dag {

#[cfg(test)]
mod tests {
use crate::{
optimization::dag::multi_parameters::optimize::MacroParameters, parameters::GlweParameters,
};

use super::*;

const DUMMY_MACRO_PARAM: MacroParameters = MacroParameters {
glwe_params: GlweParameters {
log2_polynomial_size: 0,
glwe_dimension: 0,
},
internal_dim: 0,
};

#[test]
fn output_marking() {
let mut graph = Dag::new();
@@ -852,16 +902,25 @@ mod tests {
#[allow(clippy::many_single_char_names)]
fn graph_builder() {
let mut graph = Dag::new();
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
max_variance: 0.0_f64,
variance: 0.0_f64,
};
let mut builder = graph.builder("main1");
let a = builder.add_input(1, Shape::number(), Location::Unknown);
let b = builder.add_input(1, Shape::number(), Location::Unknown);
let c = builder.add_dot([a, b], [1, 1], Location::Unknown);
let _d = builder.add_lut(c, FunctionTable::UNKWOWN, 1, Location::Unknown);
let d = builder.add_lut(c, FunctionTable::UNKWOWN, 1, Location::Unknown);
let _d =
builder.add_change_partition(d, Some(tfhers_part.clone()), None, Location::Unknown);
let mut builder = graph.builder("main2");
let e = builder.add_input(2, Shape::number(), Location::Unknown);
let f = builder.add_input(2, Shape::number(), Location::Unknown);
let g = builder.add_dot([e, f], [2, 2], Location::Unknown);
let _h = builder.add_lut(g, FunctionTable::UNKWOWN, 2, Location::Unknown);
let h = builder.add_lut(g, FunctionTable::UNKWOWN, 2, Location::Unknown);
let _h = builder.add_change_partition(h, None, Some(tfhers_part), Location::Unknown);
graph.tag_operator_as_output(c);
}

@@ -897,7 +956,16 @@

let lut2 = builder.add_lut(dot, FunctionTable::UNKWOWN, 2, Location::Unknown);

let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2];
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
macro_params: DUMMY_MACRO_PARAM,
max_variance: 0.0_f64,
variance: 0.0_f64,
};
let change_part =
builder.add_change_partition(lut2, Some(tfhers_part.clone()), None, Location::Unknown);

let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2, change_part];
for (expected_i, op_index) in ops_index.iter().enumerate() {
assert_eq!(expected_i, op_index.0);
}
@@ -944,6 +1012,11 @@ mod tests {
input: dot,
table: FunctionTable::UNKWOWN,
out_precision: 2,
},
Operator::ChangePartition {
input: lut2,
src_partition: Some(tfhers_part.clone()),
dst_partition: None
}
]
);
@@ -121,6 +121,7 @@ pub struct VariancedDag {
pub(crate) dag: Dag,
pub(crate) partitions: Partitions,
pub(crate) variances: Variances,
pub(crate) external_variance_constraints: Vec<VarianceConstraint>,
}

impl VariancedDag {
@@ -133,9 +134,11 @@ impl VariancedDag {
dag,
partitions,
variances,
external_variance_constraints: vec![],
};

// We forward the noise once to verify the composability.
varianced.apply_external_partition_input_variance();
let _ = varianced.forward_noise();
varianced.check_composability()?;
varianced.apply_composition_rules();
@@ -145,6 +148,8 @@
// The noise gets computed from inputs down to outputs.
if varianced.forward_noise() {
// Noise settled, we return the varianced dag.
varianced.collect_external_input_constraint();
varianced.collect_external_output_constraint();
return Ok(varianced);
}
// The noise of the inputs gets updated following the composition rules
@@ -181,6 +186,106 @@
}
}

fn apply_external_partition_input_variance(&mut self) {
let p_cut = self.partitions.p_cut.clone();
for (i, op) in self.dag.operators.clone().iter().enumerate() {
if let Operator::Input { .. } = op {
let partition_index = self.partitions.instrs_partition[i].instruction_partition;
if p_cut.is_external_partition(&partition_index) {
let partitions = self.partitions.clone();
let external_partition =
&p_cut.external_partitions[p_cut.external_partition_index(partition_index)];
let max_variance = external_partition.max_variance;
let variance = external_partition.variance;

let mut input = self.get_operator_mut(OperatorIndex(i));
let mut variances = input.variance().clone();
variances.vars[partition_index.0] = SymbolicVariance::from_external_partition(
partitions.nb_partitions,
partition_index,
max_variance / variance,
);
*(input.variance_mut()) = variances;
}
}
}
}

fn collect_external_input_constraint(&mut self) {
let p_cut = &self.partitions.p_cut;
for (i, op) in self.dag.operators.clone().iter().enumerate() {
if let Operator::Input {
out_precision,
out_shape,
} = op
{
let partition_index = self.partitions.instrs_partition[i].instruction_partition;
if !p_cut.is_external_partition(&partition_index) {
continue;
}

let max_variance = p_cut.external_partitions
[p_cut.external_partition_index(partition_index)]
.max_variance;

let variances = &self.get_operator(OperatorIndex(i)).variance().vars.clone();
for (i, variance) in variances.iter().enumerate() {
if variance.coeffs.is_nan() {
assert!(i != partition_index.0);
continue;
}
let constraint = VarianceConstraint {
precision: *out_precision,
partition: partition_index,
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
variance: variance.clone(),
};
self.external_variance_constraints.push(constraint);
}
}
}
}

fn collect_external_output_constraint(&mut self) {
let p_cut = self.partitions.p_cut.clone();
for dag_op in self.dag.get_output_operators_iter() {
let DagOperator {
id: op_index,
shape: out_shape,
precision: out_precision,
..
} = dag_op;
let optional_partition_index = p_cut.partition(&self.dag, op_index);
if optional_partition_index.is_none() {
continue;
}
let partition_index = optional_partition_index.unwrap();
if !p_cut.is_external_partition(&partition_index) {
continue;
}
let max_variance = p_cut.external_partitions
[p_cut.external_partition_index(partition_index)]
.max_variance;

let variances = &self.get_operator(op_index).variance().vars.clone();
for (i, variance) in variances.iter().enumerate() {
if variance.coeffs.is_nan() {
assert!(i != partition_index.0);
continue;
}
let constraint = VarianceConstraint {
precision: *out_precision,
partition: partition_index,
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
variance: variance.clone(),
};
self.external_variance_constraints.push(constraint);
}
}
}

/// Propagates the noise downward in the graph.
fn forward_noise(&mut self) -> bool {
// We save the old variance to compute the diff at the end.
@@ -252,7 +357,7 @@ impl VariancedDag {
acc + var[operator.partition().instruction_partition].clone()
* square(*weight as f64)
}),
Operator::UnsafeCast { .. } => {
Operator::UnsafeCast { .. } | Operator::ChangePartition { .. } => {
operator.get_inputs_iter().next().unwrap().variance()
[operator.partition().instruction_partition]
.clone()
@@ -343,7 +448,9 @@ pub fn analyze(
let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition);
let partitioned_dag = PartitionedDag { dag, partitions };
let varianced_dag = VariancedDag::try_from_partitioned(partitioned_dag)?;
let variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config);
let mut variance_constraints = collect_all_variance_constraints(&varianced_dag, noise_config);
// add external variance constraints
variance_constraints.extend_from_slice(varianced_dag.external_variance_constraints.as_slice());
let undominated_variance_constraints =
VarianceConstraint::remove_dominated(&variance_constraints);
let operations_count_per_instrs = collect_operations_count(&varianced_dag);
@@ -560,6 +667,7 @@ fn collect_all_variance_constraints(
dag,
partitions,
variances,
..
} = dag;
let mut constraints = vec![];
for op in dag.get_operators_iter() {
@@ -1322,6 +1430,6 @@ pub mod tests {
let p_cut = PartitionCut::from_precisions(&precisions);
let dag =
super::analyze(&dag, &CONFIG, &Some(p_cut.clone()), LOW_PRECISION_PARTITION).unwrap();
assert!(dag.nb_partitions == p_cut.p_cut.len() + 1);
assert!(dag.nb_partitions == p_cut.n_partitions());
}
}