Skip to content

Commit

Permalink
feat(optimizer): partition with external partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Jul 8, 2024
1 parent ce65cc6 commit 9f8d1b3
Show file tree
Hide file tree
Showing 8 changed files with 436 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::iter::{empty, once};
use std::ops::Deref;

use crate::dag::operator::tensor::{ClearTensor, Shape};
use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition;

use super::DotKind;

Expand Down Expand Up @@ -106,6 +107,8 @@ pub enum Operator {
},
ChangePartition {
input: OperatorIndex,
src_partition: Option<ExternalPartition>,
dst_partition: Option<ExternalPartition>,
},
}

Expand All @@ -118,7 +121,7 @@ impl Operator {
Self::UnsafeCast { input, .. }
| Self::Lut { input, .. }
| Self::Round { input, .. }
| Self::ChangePartition { input } => Box::new(once(input)),
| Self::ChangePartition { input, .. } => Box::new(once(input)),
}
}
}
Expand Down Expand Up @@ -194,8 +197,24 @@ impl fmt::Display for Operator {
} => {
write!(f, "ROUND[%{}] : u{out_precision}", input.0)?;
}
Self::ChangePartition { input } => {
write!(f, "ChangePartition[%{}]", input.0)?;
Self::ChangePartition {
input,
src_partition,
dst_partition,
} => {
write!(f, "CHANGE_PARTITION[%{}] : {{", input.0)?;
let mut src_printed = false;
if let Some(partition) = src_partition {
src_printed = true;
write!(f, "src_partition: {}", partition.name)?;
}
if let Some(partition) = dst_partition {
if src_printed {
write!(f, ", ")?;
}
write!(f, "dst_partition: {}", partition.name)?;
}
write!(f, "}}")?;
}
}
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. }
| Operator::ChangePartition { input } => input.0 = old_index_to_new[input.0],
| Operator::ChangePartition { input, .. } => input.0 = old_index_to_new[input.0],
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
for input in inputs {
input.0 = old_index_to_new[input.0];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::dag::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
use crate::{
dag::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
},
optimization::dag::multi_parameters::partition_cut::ExternalPartition,
};
use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -226,8 +229,21 @@ impl<'dag> DagBuilder<'dag> {
})
}

pub fn add_change_partition(&mut self, input: OperatorIndex) -> OperatorIndex {
self.add_operator(Operator::ChangePartition { input })
pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<&ExternalPartition>,
dst_partition: Option<&ExternalPartition>,
) -> OperatorIndex {
assert!(
src_partition.is_some() || dst_partition.is_some(),
"change_partition: src or dest partition need to be set"
);
self.add_operator(Operator::ChangePartition {
input,
src_partition: src_partition.cloned(),
dst_partition: dst_partition.cloned(),
})
}

pub fn add_round_op(
Expand Down Expand Up @@ -367,7 +383,7 @@ impl<'dag> DagBuilder<'dag> {
Operator::Lut { input, .. }
| Operator::UnsafeCast { input, .. }
| Operator::Round { input, .. }
| Operator::ChangePartition { input } => self.dag.out_shapes[input.0].clone(),
| Operator::ChangePartition { input, .. } => self.dag.out_shapes[input.0].clone(),
Operator::Dot {
kind: DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor,
..
Expand Down Expand Up @@ -402,7 +418,7 @@ impl<'dag> DagBuilder<'dag> {
Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => {
self.dag.out_precisions[inputs[0].0]
}
Operator::ChangePartition { input } => self.dag.out_precisions[input.0],
Operator::ChangePartition { input, .. } => self.dag.out_precisions[input.0],
}
}
}
Expand Down Expand Up @@ -526,6 +542,16 @@ impl Dag {
.add_lut(input, table, out_precision)
}

pub fn add_change_partition(
&mut self,
input: OperatorIndex,
src_partition: Option<&ExternalPartition>,
dst_partition: Option<&ExternalPartition>,
) -> OperatorIndex {
self.builder(DEFAULT_CIRCUIT)
.add_change_partition(input, src_partition, dst_partition)
}

pub fn add_dot(
&mut self,
inputs: impl Into<Vec<OperatorIndex>>,
Expand Down Expand Up @@ -782,18 +808,21 @@ mod tests {
#[allow(clippy::many_single_char_names)]
fn graph_builder() {
let mut graph = Dag::new();
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
};
let mut builder = graph.builder("main1");
let a = builder.add_input(1, Shape::number());
let b = builder.add_input(1, Shape::number());
let c = builder.add_dot([a, b], [1, 1]);
let d = builder.add_lut(c, FunctionTable::UNKWOWN, 1);
let _e = builder.add_change_partition(d);
let _e = builder.add_change_partition(d, Some(&tfhers_part), None);
let mut builder = graph.builder("main2");
let e = builder.add_input(2, Shape::number());
let f = builder.add_input(2, Shape::number());
let g = builder.add_dot([e, f], [2, 2]);
let h = builder.add_lut(g, FunctionTable::UNKWOWN, 2);
let _h = builder.add_change_partition(h);
let _h = builder.add_change_partition(h, None, Some(&tfhers_part));
graph.tag_operator_as_output(c);
}

Expand All @@ -816,7 +845,10 @@ mod tests {

let lut2 = builder.add_lut(dot, FunctionTable::UNKWOWN, 2);

let change_part = builder.add_change_partition(lut2);
let tfhers_part = ExternalPartition {
name: String::from("tfhers"),
};
let change_part = builder.add_change_partition(lut2, Some(&tfhers_part), None);

let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2, change_part];
for (expected_i, op_index) in ops_index.iter().enumerate() {
Expand Down Expand Up @@ -866,7 +898,11 @@ mod tests {
table: FunctionTable::UNKWOWN,
out_precision: 2,
},
Operator::ChangePartition { input: lut2 }
Operator::ChangePartition {
input: lut2,
src_partition: Some(tfhers_part.clone()),
dst_partition: None
}
]
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,18 +254,14 @@ impl VariancedDag {
acc + var[operator.partition().instruction_partition].clone()
* square(*weight as f64)
}),
Operator::UnsafeCast { .. } => {
Operator::UnsafeCast { .. } | Operator::ChangePartition { .. } => {
operator.get_inputs_iter().next().unwrap().variance()
[operator.partition().instruction_partition]
.clone()
}
Operator::Round { .. } => {
unreachable!("Round should have been either expanded or integrated to a lut")
}
Operator::ChangePartition { .. } => {
// TODO
todo!("TODO")
}
};
// We add the noise for the transitions to alternative representations
operator
Expand Down Expand Up @@ -1308,6 +1304,6 @@ pub mod tests {
let p_cut = PartitionCut::from_precisions(&precisions);
let dag =
super::analyze(&dag, &CONFIG, &Some(p_cut.clone()), LOW_PRECISION_PARTITION).unwrap();
assert!(dag.nb_partitions == p_cut.p_cut.len() + 1);
assert!(dag.nb_partitions == p_cut.n_partitions());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ fn test_partition_chain(decreasing: bool) {
let sol = optimize(&dag, &Some(p_cut.clone()), PartitionIndex(0)).unwrap();
let nb_partitions = sol.macro_params.len();
assert!(
nb_partitions == (p_cut.p_cut.len() + 1),
nb_partitions == p_cut.n_partitions(),
"bad nb partitions {} {p_cut}",
sol.macro_params.len()
);
Expand Down
Loading

0 comments on commit 9f8d1b3

Please sign in to comment.