Commit

Improve TrainerBuilder::build (#256)
jw1912 authored Jan 18, 2025
1 parent cc122c3 commit dbc9c94
Showing 1 changed file with 78 additions and 92 deletions.
170 changes: 78 additions & 92 deletions src/trainer/default/builder.rs
@@ -188,136 +188,119 @@ impl<T: SparseInputType, U: OutputBuckets<T::RequiredDataType>, O: OptimiserType
         self
     }
 
+    fn push_saved_format(&self, layer: usize, saved_format: &mut Vec<SavedFormat>, net_quant: &mut i16) {
+        let w = format!("l{layer}w");
+        let b = format!("l{layer}b");
+
+        let layout = if self.allow_transpose && layer > 0 && U::BUCKETS > 1 {
+            Layout::Transposed
+        } else {
+            Layout::Normal
+        };
+
+        let (wquant, bquant) = if let Some(quants) = &self.quantisations {
+            let bquant = match quants[layer] {
+                QuantTarget::Float => {
+                    *net_quant = 1;
+                    QuantTarget::Float
+                }
+                QuantTarget::I16(q) => {
+                    *net_quant = net_quant.checked_mul(q).expect("Bias quantisation factor overflowed!");
+                    QuantTarget::I16(*net_quant)
+                }
+                QuantTarget::I8(q) => {
+                    *net_quant = net_quant.checked_mul(q).expect("Bias quantisation factor overflowed!");
+                    QuantTarget::I8(*net_quant)
+                }
+                QuantTarget::I32(_) => unimplemented!("i32 quant is not implemented for TrainerBuilder!"),
+            };
+
+            (quants[layer], bquant)
+        } else {
+            (QuantTarget::Float, QuantTarget::Float)
+        };
+
+        saved_format.push(SavedFormat { id: w, quant: wquant, layout });
+        saved_format.push(SavedFormat { id: b, quant: bquant, layout: Layout::Normal });
+    }
+
     pub fn build(self) -> Trainer<O::Optimiser, T, U> {
         let builder = NetworkBuilder::default();
 
         let output_buckets = U::BUCKETS > 1;
-
-        let input_getter = self.input_getter.expect("Need to set the input features!");
-
+        let input_getter = self.input_getter.clone().expect("Need to set the input features!");
         let input_size = input_getter.num_inputs();
         let input_shape = Shape::new(input_size, 1);
-        let targets = builder.new_input("targets", Shape::new(1, 1));
 
-        let buckets = if output_buckets { Some(builder.new_input("buckets", Shape::new(U::BUCKETS, 1))) } else { None };
+        let mut out = builder.new_input("stm", input_shape);
+        let targets = builder.new_input("targets", Shape::new(1, 1));
+        let buckets = output_buckets.then(|| builder.new_input("buckets", Shape::new(U::BUCKETS, 1)));
+        let l0 = builder.new_affine("l0", input_size, self.ft_out_size);
 
         let mut still_in_ft = true;
-
         let mut saved_format = Vec::new();
 
         if self.ft_out_size % 8 != 0 {
-            logger::set_colour("31");
-            println!("==================================");
-            println!(" Feature transformer size = {}", self.ft_out_size);
-            println!(" is not a multiple of 8.");
-            println!(" Why are you doing this?");
-            println!(" Please seek help.");
-            println!("==================================");
-            logger::clear_colours();
+            warning(|| {
+                println!("Feature transformer size = {}", self.ft_out_size);
+                println!(" is not a multiple of 8.");
+                println!(" Why are you doing this?");
+                println!(" Please seek help.");
+            });
         }
 
-        let l0 = builder.new_affine("l0", input_size, self.ft_out_size);
-
         let mut net_quant = 1i16;
 
-        //let input_buckets = self.input_getter.buckets();
         let mut ft_desc = format!("{} -> {}", input_getter.shorthand(), self.ft_out_size);
 
         if self.perspective {
             ft_desc = format!("({ft_desc})x2");
         }
 
-        let mut out = builder.new_input("stm", input_shape);
-
-        let pst = if self.psqt_subnet {
+        let pst = self.psqt_subnet.then(|| {
            let pst = builder.new_weights("pst", Shape::new(1, input_size), InitSettings::Zeroed);
            saved_format.push(SavedFormat { id: "pst".to_string(), quant: QuantTarget::Float, layout: Layout::Normal });
-            Some(pst.matmul(out))
-        } else {
-            None
-        };
+            pst.matmul(out)
+        });
 
-        let mut push_saved_format = |layer: usize| {
-            let w = format!("l{layer}w");
-            let b = format!("l{layer}b");
-
-            if let Some(quants) = &self.quantisations {
-                let layout = if self.allow_transpose && layer > 0 && output_buckets {
-                    Layout::Transposed
-                } else {
-                    Layout::Normal
-                };
-
-                saved_format.push(SavedFormat { id: w, quant: quants[layer], layout });
-
-                match quants[layer] {
-                    QuantTarget::Float => {
-                        net_quant = 1;
-                        saved_format.push(SavedFormat { id: b, quant: QuantTarget::Float, layout: Layout::Normal });
-                    }
-                    QuantTarget::I16(q) => {
-                        net_quant = net_quant.checked_mul(q).expect("Bias quantisation factor overflowed!");
-                        saved_format.push(SavedFormat {
-                            id: b,
-                            quant: QuantTarget::I16(net_quant),
-                            layout: Layout::Normal,
-                        });
-                    }
-                    QuantTarget::I8(q) => {
-                        net_quant = net_quant.checked_mul(q).expect("Bias quantisation factor overflowed!");
-                        saved_format.push(SavedFormat {
-                            id: b,
-                            quant: QuantTarget::I8(net_quant),
-                            layout: Layout::Normal,
-                        });
-                    }
-                    QuantTarget::I32(_) => unimplemented!("i32 quant is not implemented for TrainerBuilder!"),
-                }
-            } else {
-                saved_format.push(SavedFormat { id: w, quant: QuantTarget::Float, layout: Layout::Normal });
-                saved_format.push(SavedFormat { id: b, quant: QuantTarget::Float, layout: Layout::Normal });
-            }
-        };
-
-        push_saved_format(0);
+        self.push_saved_format(0, &mut saved_format, &mut net_quant);
 
         assert!(self.nodes.len() > 1, "Require at least 2 nodes for a working arch!");
 
-        let (skip, activation) = if self.perspective {
-            if let NodeType { op: OpType::Activate(act), .. } = self.nodes[0] {
+        let skip = if self.perspective {
+            let (skip, activation) = if let OpType::Activate(act) = self.nodes[0].op {
                 (1, act)
             } else {
+                warning(|| {
+                    println!("Feature transformer is not followed");
+                    println!(" by an activation function,");
+                    println!(" which is probably erroneous");
+                });
                 (0, Activation::Identity)
-            }
-        } else {
-            (0, Activation::Identity)
-        };
+            };
 
-        out = if self.perspective {
             let ntm = builder.new_input("nstm", input_shape);
-            l0.forward_sparse_dual_with_activation(out, ntm, activation)
+            out = l0.forward_sparse_dual_with_activation(out, ntm, activation);
+            skip
         } else {
-            l0.forward(out)
+            out = l0.forward(out);
+            0
         };
 
         let mut layer = 1;
-
         let mut layer_sizes = Vec::new();
-
         let mut prev_size = self.ft_out_size * if self.perspective { 2 } else { 1 };
 
         for &NodeType { size, op } in self.nodes.iter().skip(skip) {
             match op {
-                OpType::Activate(activation) => {
-                    out = out.activate(activation);
-                }
+                OpType::Activate(activation) => out = out.activate(activation),
                 OpType::Affine => {
                     still_in_ft = false;
                     let raw_size = size * U::BUCKETS;
 
                     let l = builder.new_affine(&format!("l{layer}"), prev_size, raw_size);
 
-                    push_saved_format(layer);
+                    self.push_saved_format(layer, &mut saved_format, &mut net_quant);
 
                     layer += 1;
 
@@ -349,7 +332,6 @@ impl<T: SparseInputType, U: OutputBuckets<T::RequiredDataType>, O: OptimiserType
         }
 
         let output_node = out.node();
-
         let predicted = out.activate(Activation::Sigmoid);
 
         match self.loss {
@@ -375,17 +357,13 @@ impl<T: SparseInputType, U: OutputBuckets<T::RequiredDataType>, O: OptimiserType
             }
         }
 
-        let factorised_weights = if input_getter.is_factorised() {
-            let mut f = vec!["l0w".to_string()];
-
+        let factorised_weights = input_getter.is_factorised().then(|| {
             if self.psqt_subnet {
-                f.push("pst".to_string());
+                vec!["l0w".to_string(), "pst".to_string()]
+            } else {
+                vec!["l0w".to_string()]
             }
-
-            Some(f)
-        } else {
-            None
-        };
+        });
 
         let mut trainer = Trainer {
             optimiser: O::Optimiser::new(graph, Default::default()),
@@ -467,3 +445,11 @@ impl<T: SparseInputType, U: OutputBuckets<T::RequiredDataType>, O: OptimiserType
         trainer
     }
 }
+
+fn warning(mut f: impl FnMut()) {
+    logger::set_colour("31");
+    println!("==================================");
+    f();
+    println!("==================================");
+    logger::clear_colours();
+}
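
A note on the quantisation bookkeeping that the extracted push_saved_format centralises: a layer's bias must be stored at the running product of every weight quantisation factor applied so far, and checked_mul panics on i16 overflow rather than silently wrapping. A sketch of that arithmetic, with made-up factors (the values below are assumptions for illustration, not from the repo):

    fn main() {
        // Hypothetical per-layer weight quantisation factors.
        let quants: [i16; 2] = [255, 64];
        let mut net_quant: i16 = 1;

        for (layer, &q) in quants.iter().enumerate() {
            // Layer N's bias factor is the cumulative product q_0 * q_1 * ... * q_N.
            net_quant = net_quant.checked_mul(q).expect("Bias quantisation factor overflowed!");
            println!("l{layer} bias quant = {net_quant}");
        }

        // 255 * 64 = 16320 still fits in an i16; one more factor of 4 would overflow and panic.
        assert_eq!(net_quant, 16320);
    }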
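
The other recurring change in this diff is replacing if cond { Some(x) } else { None } with bool::then from the standard library, which evaluates its closure only when the bool is true. A minimal standalone sketch of the equivalence (variable names here are illustrative):

    fn main() {
        let output_buckets = true;

        // Old style: an explicit if/else producing an Option.
        let old = if output_buckets { Some("buckets") } else { None };

        // New style: `bool::then` runs the closure only for `true`.
        let new = output_buckets.then(|| "buckets");

        assert_eq!(old, new);
    }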
