Skip to content

Commit

Permalink
test(optimizer): mix tfhers and low_precision partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Aug 9, 2024
1 parent 8b36b06 commit b81e13d
Showing 1 changed file with 48 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ pub mod tests {

use super::*;
use crate::config;
use crate::dag::operator::{FunctionTable, Shape, Weights};
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
use crate::dag::unparametrized;
use crate::optimization::dag::multi_parameters::optimize::MacroParameters;
use crate::optimization::dag::multi_parameters::partition_cut::ExternalPartition;
Expand Down Expand Up @@ -466,6 +466,53 @@ pub mod tests {
);
}

#[test]
fn test_tfhers_in_out_lut_compute_mix_external() {
let variance = get_tfhers_noise_br();
let tfhers_partition = ExternalPartition {
name: String::from("tfhers"),
macro_params: TFHERS_MACRO_PARAMS,
max_variance: variance * 6.0,
variance,
};
let mut dag = unparametrized::Dag::new();
let external_input = dag.add_input(16, Shape::number());
let other_input = dag.add_input(16, Shape::number());
let other_lut = dag.add_lut(other_input, FunctionTable::UNKWOWN, 16);
let mix_add = dag.add_levelled_op(
[external_input, other_lut],
LevelledComplexity::ADDITION,
[1.0, 1.0],
Shape::number(),
"add",
);
let change_part1 = dag.add_change_partition(mix_add, Some(tfhers_partition.clone()), None);
let lut = dag.add_lut(change_part1, FunctionTable::UNKWOWN, 16);
let change_part2 = dag.add_change_partition(lut, None, Some(tfhers_partition));

let partitions = partitionning(&dag);
assert!(partitions.nb_partitions == 2);
let tfhers_partition_index = PartitionIndex(1);
let instrs_partition = partitions.instrs_partition;
show_partitionning(&dag, &instrs_partition);

let consider = |op_i: OperatorIndex| &instrs_partition[op_i.0];
assert!(consider(external_input).instruction_partition == tfhers_partition_index);
assert!(consider(other_input).instruction_partition == LOW_PRECISION_PARTITION);
assert!(consider(other_lut).alternative_output_representation.len() == 1);
assert!(consider(other_lut)
.alternative_output_representation
.contains(&tfhers_partition_index));
assert!(consider(change_part1).instruction_partition == tfhers_partition_index);
assert!(consider(change_part1).inputs_transition == [None]);
assert!(consider(lut).instruction_partition == LOW_PRECISION_PARTITION);
assert!(consider(lut).alternative_output_representation.len() == 1);
assert!(consider(lut)
.alternative_output_representation
.contains(&tfhers_partition_index));
assert!(consider(change_part2).instruction_partition == tfhers_partition_index);
}

#[test]
fn test_tfhers_different_in_out_lut_compute() {
let variance = get_tfhers_noise_br();
Expand Down

0 comments on commit b81e13d

Please sign in to comment.