diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs
new file mode 100644
index 000000000..cd07633cb
--- /dev/null
+++ b/src/nn/generalized_residual.rs
@@ -0,0 +1,385 @@
+use crate::prelude::*;
+
+/// A generalized residual connection around `F`, where the skip branch is an arbitrary
+/// module `R` instead of the identity: computes `F(x) + R(x)`. Generalizes the residual
+/// connections introduced in
+/// [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
+///
+/// # Generics
+/// - `F`: The main module, whose output is added to `R`'s output.
+/// - `R`: The module applied in the residual (skip) branch.
+///
+/// # Examples
+/// ```rust
+/// # use dfdx::prelude::*;
+/// let module: GeneralizedResidual<ReLU, Square> = Default::default();
+/// let x = Tensor1D::new([-2.0, -1.0, 0.0, 1.0, 2.0]);
+/// let y = module.forward(x);
+/// assert_eq!(y.data(), &[4.0, 1.0, 0.0, 2.0, 6.0]);
+/// ```
+#[derive(Debug, Clone, Default)]
+pub struct GeneralizedResidual<F, R>(F, R);
+
+impl<F: CanUpdateWithGradients, R: CanUpdateWithGradients> CanUpdateWithGradients
+    for GeneralizedResidual<F, R>
+{
+    /// Pass through to `F`'s and `R`'s [CanUpdateWithGradients].
+    fn update<G: GradientProvider>(&mut self, grads: &mut G) {
+        self.0.update(grads);
+        self.1.update(grads);
+    }
+}
+
+impl<F: ResetParams, R: ResetParams> ResetParams for GeneralizedResidual<F, R> {
+    /// Pass through to `F`'s and `R`'s [ResetParams].
+    fn reset_params<RNG: rand::Rng>(&mut self, rng: &mut RNG) {
+        self.0.reset_params(rng);
+        self.1.reset_params(rng);
+    }
+}
+
+impl<F, R, T, O> Module<T> for GeneralizedResidual<F, R>
+where
+    T: Tensor<Dtype = f32>,
+    O: Tensor<Dtype = T::Dtype, Tape = T::Tape>,
+    F: Module<T, Output = O>,
+    R: Module<T, Output = O>,
+{
+    type Output = O;
+
+    /// Calls forward on `F` and `R` and then sums their results: `F(x) + R(x)`
+    fn forward(&self, x: T) -> Self::Output {
+        let (x, tape) = x.split_tape();
+
+        // do R(x) on the tape
+        let (r_x, tape) = self.1.forward(x.duplicate().put_tape(tape)).split_tape();
+
+        // do F(x) on the tape
+        let f_x = self.0.forward(x.put_tape(tape));
+
+        add(f_x, &r_x)
+    }
+}
+
+impl<F: SaveToNpz, R: SaveToNpz> SaveToNpz for GeneralizedResidual<F, R> {
+    /// Pass through to `F`'s and `R`'s [SaveToNpz], using the
+    /// `{prefix}_main` and `{prefix}_residual` filename prefixes respectively.
+    fn write<W>(
+        &self,
+        filename_prefix: &str,
+        w: &mut zip::ZipWriter<W>,
+    ) -> zip::result::ZipResult<()>
+    where
+        W: std::io::Write + std::io::Seek,
+    {
+        self.0.write(&format!("{}_main", filename_prefix), w)?;
+        self.1.write(&format!("{}_residual", filename_prefix), w)?;
+        Ok(())
+    }
+}
+
+impl<F: LoadFromNpz, R: LoadFromNpz> LoadFromNpz for GeneralizedResidual<F, R> {
+    /// Pass through to `F`'s and `R`'s [LoadFromNpz].
+    fn read<READ>(
+        &mut self,
+        filename_prefix: &str,
+        r: &mut zip::ZipArchive<READ>,
+    ) -> Result<(), NpzError>
+    where
+        READ: std::io::Read + std::io::Seek,
+    {
+        self.0.read(&format!("{}_main", filename_prefix), r)?;
+        self.1.read(&format!("{}_residual", filename_prefix), r)?;
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tests::assert_close;
+    use rand::{prelude::StdRng, SeedableRng};
+    use std::fs::File;
+    use tempfile::NamedTempFile;
+    use zip::ZipArchive;
+
+    #[test]
+    fn test_reset() {
+        let mut rng = StdRng::seed_from_u64(0);
+        let mut model: GeneralizedResidual<Linear<2, 5>, Linear<2, 5>> = Default::default();
+        assert_eq!(model.0.weight.data(), &[[0.0; 2]; 5]);
+        assert_eq!(model.0.bias.data(), &[0.0; 5]);
+        assert_eq!(model.1.weight.data(), &[[0.0; 2]; 5]);
+        assert_eq!(model.1.bias.data(), &[0.0; 5]);
+
+        model.reset_params(&mut rng);
+        assert_ne!(model.0.weight.data(), &[[0.0; 2]; 5]);
+        assert_ne!(model.0.bias.data(), &[0.0; 5]);
+        assert_ne!(model.1.weight.data(), &[[0.0; 2]; 5]);
+        assert_ne!(model.1.bias.data(), &[0.0; 5]);
+    }
+
+    const W0: [[f32; 2]; 5] = [
+        [0.63315326, 0.3361526],
+        [0.60201937, 0.30927354],
+        [0.39831632, 0.29526848],
+        [-0.4730785, -0.10664469],
+        [0.5074884, -0.08458644],
+    ];
+    const B0: [f32; 5] = [-0.7014593, 0.01725882, 0.67181975, -0.61593556, 0.27809095];
+
+    const W2: [[f32; 5]; 2] = [
+        [0.37967658, -0.30938417, -0.4046409, 0.34131002, -0.36532],
+        [0.01010674, 0.2922417, -0.28791183, 0.09316397, 0.00722069],
+    ];
+    const B2: [f32; 2] = [-0.01353309, 0.19437504];
+    const X: [[f32; 2]; 10] = [
+        [0.9706649, -0.50246257],
+        [0.36609784, 0.22519696],
+        [-0.26957038, -2.4395447],
+        [0.729607, 0.06136635],
+        [1.0758572, -0.6158074],
+        [1.844528, -0.7769507],
+        [-0.83232504, 0.26263165],
+        [-0.18690403, 0.5396985],
+        [-1.0891576, 0.09805013],
+        [-0.63034505, 2.4173584],
+    ];
+    const Y: [[f32; 2]; 10] = [
+        [0.15374291, -0.43383744],
+        [-0.26277426, 0.25803787],
+        [-0.41010314, -2.2426596],
+        [-0.062764645, 0.117026225],
+        [0.2237711, -0.54089284],
+        [0.69048953, -0.6508272],
+        [-1.0149324, 0.33670622],
+        [-0.57907265, 0.53813595],
+        [-1.2107061, 0.21556953],
+        [-1.2221863, 2.3977249],
+    ];
+
+    const W0G: [[f32; 2]; 5] = [
+        [0.035948314, -0.015142122],
+        [-0.0035737813, -0.001155745],
+        [-0.07784372, -0.059181444],
+        [0.0, 0.0],
+        [-0.081114516, 0.06281963],
+    ];
+    const B0G: [f32; 5] = [0.019489167, -0.005999865, -0.3116488, 0.0, -0.12533475];
+    const W2G: [[f32; 5]; 2] = [[0.010261777, 0.15239798, 0.37232202, 0.0, 0.22712366]; 2];
+    const B2G: [f32; 2] = [0.50000006; 2];
+
+    #[test]
+    fn test_residual_forward_backward_resadd_as_main() {
+        type SubModel = (Linear<2, 5>, ReLU, Linear<5, 2>);
+        type Model = GeneralizedResidual<SubModel, ReLU>;
+
+        let mut model: Model = Default::default();
+        *model.0 .0.weight.mut_data() = W0;
+        *model.0 .0.bias.mut_data() = B0;
+        *model.0 .2.weight.mut_data() = W2;
+        *model.0 .2.bias.mut_data() = B2;
+
+        let x = Tensor2D::new(X);
+        let y = model.forward(x.traced());
+        // Y = s(x) + x (the plain residual output), including negative x
+        // Y2 is what this model computes: s(x) + r(x) with r == ReLU
+        // Y2 = s(x) + r(x) = Y - x + r(x) = Y - (x - r(x))
+        // x - r(x) = 0 if x >= 0 (because r(x) = x)
+        // x - r(x) = x if x < 0  (because r(x) = 0)
+        // so x - r(x) = -r(-x), since r(-x) = -x if x < 0 and 0 elsewhere
+        // => Y2 = Y + r(-x)
+        assert_close(
+            y.data(),
+            add(Tensor2D::new(Y), &(-Tensor2D::new(X)).relu()).data(),
+        );
+
+        let gradients = y.mean().backward();
+
+        assert_close(gradients.ref_gradient(&model.0 .0.weight), &W0G);
+        assert_close(gradients.ref_gradient(&model.0 .0.bias), &B0G);
+        assert_close(gradients.ref_gradient(&model.0 .2.weight), &W2G);
+        assert_close(gradients.ref_gradient(&model.0 .2.bias), &B2G);
+    }
+
+    #[test]
+    fn test_residual_forward_backward_with_update() {
+        type SubModel = (Linear<2, 5>, ReLU, Linear<5, 2>);
+        type Model = GeneralizedResidual<SubModel, SubModel>;
+
+        let mut model: Model = Default::default();
+        *model.0 .0.weight.mut_data() = W0;
+        *model.0 .0.bias.mut_data() = B0;
+        *model.0 .2.weight.mut_data() = W2;
+        *model.0 .2.bias.mut_data() = B2;
+        *model.1 .0.weight.mut_data() = W0;
+        *model.1 .0.bias.mut_data() = B0;
+        *model.1 .2.weight.mut_data() = W2;
+        *model.1 .2.bias.mut_data() = B2;
+
+        let mut model2: SubModel = Default::default();
+        *model2.0.weight.mut_data() = W0;
+        *model2.0.bias.mut_data() = B0;
+        // The submodel ends in a linear layer l(h) = a*h + b.
+        // model2 has to compute model(x) = s(x) + s(x) = 2 * s(x), so its last
+        // layer must be l2(h) = 2a*h + 2b => a' = 2a; b' = 2b
+        *model2.2.weight.mut_data() = W2;
+        model2.2.weight = model2.2.weight * 2.0;
+        *model2.2.bias.mut_data() = B2;
+        model2.2.bias = model2.2.bias * 2.0;
+
+        let x = Tensor2D::new(X);
+        let y = model.forward(x.traced());
+        let x2 = Tensor2D::new(X);
+        let y2 = model2.forward(x2.traced());
+        assert_close(y.data(), y2.data());
+
+        let gradients = y.mean().backward();
+        let gradients2 = y2.mean().backward();
+
+        assert_close(gradients.ref_gradient(&model.0 .0.weight), &W0G);
+        assert_close(gradients.ref_gradient(&model.0 .0.bias), &B0G);
+        assert_close(gradients.ref_gradient(&model.0 .2.weight), &W2G);
+        assert_close(gradients.ref_gradient(&model.0 .2.bias), &B2G);
+        assert_close(gradients.ref_gradient(&model.1 .0.weight), &W0G);
+        assert_close(gradients.ref_gradient(&model.1 .0.bias), &B0G);
+        assert_close(gradients.ref_gradient(&model.1 .2.weight), &W2G);
+        assert_close(gradients.ref_gradient(&model.1 .2.bias), &B2G);
+        assert_close(
+            gradients2.ref_gradient(&model2.0.weight),
+            (Tensor2D::new(W0G) * 2.0).data(),
+        );
+        assert_close(
+            gradients2.ref_gradient(&model2.0.bias),
+            (Tensor1D::new(B0G) * 2.0).data(),
+        );
+        // No factor of 2 for the last layer's own gradients: a linear layer's weight/bias
+        // gradient is (upstream gradient) * (layer input), and neither changes when the
+        // layer's parameters are doubled. The doubling only scales the gradient flowing
+        // back to the earlier layers, which is why model2.0's gradients above are doubled.
+        assert_close(gradients2.ref_gradient(&model2.2.weight), &W2G);
+        assert_close(gradients2.ref_gradient(&model2.2.bias), &B2G);
+
+        // with lr = 1, the SGD update is w <- w - grad(w)
+        let sgd_config = SgdConfig {
+            lr: 1.0,
+            momentum: None,
+        };
+        Sgd::new(sgd_config).update(&mut model, gradients);
+        Sgd::new(sgd_config).update(&mut model2, gradients2);
+
+        assert_close(
+            model.0 .0.weight.data(),
+            sub(Tensor2D::new(W0), &Tensor2D::new(W0G)).data(),
+        );
+        assert_close(
+            model.0 .0.bias.data(),
+            sub(Tensor1D::new(B0), &Tensor1D::new(B0G)).data(),
+        );
+        assert_close(
+            model.0 .2.weight.data(),
+            sub(Tensor2D::new(W2), &Tensor2D::new(W2G)).data(),
+        );
+        assert_close(
+            model.0 .2.bias.data(),
+            sub(Tensor1D::new(B2), &Tensor1D::new(B2G)).data(),
+        );
+        assert_close(
+            model.1 .0.weight.data(),
+            sub(Tensor2D::new(W0), &Tensor2D::new(W0G)).data(),
+        );
+        assert_close(
+            model.1 .0.bias.data(),
+            sub(Tensor1D::new(B0), &Tensor1D::new(B0G)).data(),
+        );
+        assert_close(
+            model.1 .2.weight.data(),
+            sub(Tensor2D::new(W2), &Tensor2D::new(W2G)).data(),
+        );
+        assert_close(
+            model.1 .2.bias.data(),
+            sub(Tensor1D::new(B2), &Tensor1D::new(B2G)).data(),
+        );
+    }
+
+    // Gradients of the two branches have to be summed: m(x) = g(x) + h(x) => m'(x) = g'(x) + h'(x)
+    #[test]
+    fn test_residual_gradients_correctly_added() {
+        type Model = (Linear<1, 1>, GeneralizedResidual<ReLU, ReLU>);
+        // give the Linear<1, 1> layer a weight of one; its bias stays zeroed (Default)
+        let mut model: Model = Default::default();
+        *model.0.weight.mut_data() = [[1.0]];
+
+        let x = Tensor2D::new([[-1.0], [1.0]]);
+        let y = model.forward(x.traced());
+
+        assert_close(y.data(), &[[0.0], [2.0]]);
+
+        let grads = y.mean().backward();
+
+        // m(x) = r(x) + r(x) = 2r(x); m'(x) = 2r'(x)
+        // y_mean(x_1, x_2) = (m(x_1) + m(x_2)) / 2 = (2r(x_1) + 2r(x_2)) / 2 = r(x_1) + r(x_2)
+        // x_1 = -1; x_2 = 1 => r(-1) + r(1) = 0 + 1 = 1
+        // y_mean'(x_1, x_2) = r'(x_1) + r'(x_2) = r'(-1) + r'(1) = 0 + 1 = 1
+        // x_1 and x_2 come out of the Linear layer (x_i = a * input_i + b with a = 1, b = 0),
+        // so the derivatives w.r.t. a and b work out to the same value here
+        assert_close(grads.ref_gradient(&model.0.weight), &[[1.0]]);
+        assert_close(grads.ref_gradient(&model.0.bias), &[1.0]);
+    }
+
+    #[test]
+    fn test_save_residual() {
+        let model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
+        let file = NamedTempFile::new().expect("failed to create tempfile");
+        model
+            .save(file.path().to_str().unwrap())
+            .expect("failed to save model");
+        let f = File::open(file.path()).expect("failed to open resulting file");
+        let mut zip = ZipArchive::new(f).expect("failed to create zip archive from file");
+        {
+            let weight_file = zip
+                .by_name("_mainweight.npy")
+                .expect("failed to find _mainweight.npy file");
+            assert!(weight_file.size() > 0);
+        }
+        {
+            let bias_file = zip
+                .by_name("_mainbias.npy")
+                .expect("failed to find _mainbias.npy file");
+            assert!(bias_file.size() > 0);
+        }
+        {
+            let weight_file = zip
+                .by_name("_residualweight.npy")
+                .expect("failed to find _residualweight.npy file");
+            assert!(weight_file.size() > 0);
+        }
+        {
+            let bias_file = zip
+                .by_name("_residualbias.npy")
+                .expect("failed to find _residualbias.npy file");
+            assert!(bias_file.size() > 0);
+        }
+    }
+
+    #[test]
+    fn test_load_residual() {
+        let mut rng = StdRng::seed_from_u64(0);
+        let mut saved_model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
+        saved_model.reset_params(&mut rng);
+
+        let file = NamedTempFile::new().expect("failed to create tempfile");
+        assert!(saved_model.save(file.path().to_str().unwrap()).is_ok());
+
+        let mut loaded_model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
+        assert_ne!(loaded_model.0.weight.data(), saved_model.0.weight.data());
+        assert_ne!(loaded_model.0.bias.data(), saved_model.0.bias.data());
+
+        assert_ne!(loaded_model.1.weight.data(), saved_model.1.weight.data());
+        assert_ne!(loaded_model.1.bias.data(), saved_model.1.bias.data());
+
+        assert!(loaded_model.load(file.path().to_str().unwrap()).is_ok());
+        assert_eq!(loaded_model.0.weight.data(), saved_model.0.weight.data());
+        assert_eq!(loaded_model.0.bias.data(), saved_model.0.bias.data());
+
+        assert_eq!(loaded_model.1.weight.data(), saved_model.1.weight.data());
+        assert_eq!(loaded_model.1.bias.data(), saved_model.1.bias.data());
+    }
+}
diff --git a/src/nn/mod.rs b/src/nn/mod.rs
index b0afff676..e8b12fdee 100644
--- a/src/nn/mod.rs
+++ b/src/nn/mod.rs
@@ -56,6 +56,7 @@
 
 mod activations;
 mod dropout;
+mod generalized_residual;
 mod impl_module_for_tuples;
 mod layer_norm;
 mod linear;
@@ -67,6 +68,7 @@
 
 pub use activations::*;
 pub use dropout::*;
+pub use generalized_residual::*;
 pub use impl_module_for_tuples::*;
 pub use layer_norm::*;
 pub use linear::*;
diff --git a/src/nn/residual.rs b/src/nn/residual.rs
index c6269e648..3d8a4e51f 100644
--- a/src/nn/residual.rs
+++ b/src/nn/residual.rs
@@ -198,8 +198,8 @@ mod tests {
         assert!(saved_model.save(file.path().to_str().unwrap()).is_ok());
 
         let mut loaded_model: Residual<Linear<5, 3>> = Default::default();
-        assert!(loaded_model.0.weight.data() != saved_model.0.weight.data());
-        assert!(loaded_model.0.bias.data() != saved_model.0.bias.data());
+        assert_ne!(loaded_model.0.weight.data(), saved_model.0.weight.data());
+        assert_ne!(loaded_model.0.bias.data(), saved_model.0.bias.data());
 
         assert!(loaded_model.load(file.path().to_str().unwrap()).is_ok());
         assert_eq!(loaded_model.0.weight.data(), saved_model.0.weight.data());
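
Usage sketch (not part of the diff): the example below shows how `GeneralizedResidual` composes with the existing `nn` building blocks — a small MLP on the main branch and a single `Linear` projection on the skip branch, so the block computes `mlp(x) + skip(x)`. It assumes the dfdx API exercised above (`Tensor1D`, `Linear`, `ReLU`, `reset_params`, `SaveToNpz::save`); the type aliases, the shapes, and the `block.npz` filename are purely illustrative.

```rust
use dfdx::prelude::*;
use rand::{prelude::StdRng, SeedableRng};

// Main branch: a small MLP. Skip branch: a single Linear projection.
// The block's output is mlp(x) + skip(x); both branches must produce the same shape.
type MainPath = (Linear<4, 8>, ReLU, Linear<8, 4>);
type SkipPath = Linear<4, 4>;
type Block = GeneralizedResidual<MainPath, SkipPath>;

fn main() {
    let mut rng = StdRng::seed_from_u64(0);

    // Parameters start zeroed (Default) and are randomized by reset_params,
    // which passes through to both branches.
    let mut block: Block = Default::default();
    block.reset_params(&mut rng);

    // Forward pass: the shape stays [4] because both branches map [4] -> [4].
    let x = Tensor1D::new([0.5, -1.0, 2.0, 0.0]);
    let y = block.forward(x);
    println!("{:?}", y.data());

    // SaveToNpz writes the two branches under the "_main" / "_residual" prefixes.
    block.save("block.npz").expect("failed to save block");
}
```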