Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #96 #118

Merged
merged 14 commits into from
Aug 3, 2022
379 changes: 379 additions & 0 deletions src/nn/generalized_residual.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,379 @@
use crate::prelude::*;

/// A residual connection `R` around `F`: `F(x) + R(x)`,
/// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385).
///
/// # Generics
/// - `F`: The underlying module to do a skip connection around.
/// - `R`: The underlying residual module
///
/// # Examples
/// ```rust
/// # use dfdx::prelude::*;
/// let module: Residual<ReLU, Square> = Default::default();
/// let x = Tensor1D::new([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = module.forward(x);
/// assert_eq!(y.data(), &[4.0, 1.0, 0.0, 2.0, 6.0]);
/// ```
#[derive(Debug, Clone, Default)]
pub struct GeneralizedResidual<F, R>(F, R);

impl<F: CanUpdateWithGradients, R: CanUpdateWithGradients> CanUpdateWithGradients
for GeneralizedResidual<F, R>
{
/// Pass through to `F`'s [CanUpdateWithGradients].
fn update<G: GradientProvider>(&mut self, grads: &mut G) {
self.0.update(grads);
self.1.update(grads);
}
}

impl<F: ResetParams, R: ResetParams> ResetParams for GeneralizedResidual<F, R> {
/// Pass through to `F`'s [ResetParams].
fn reset_params<RNG: rand::Rng>(&mut self, rng: &mut RNG) {
self.0.reset_params(rng);
self.1.reset_params(rng);
}
}

impl<F, R, T, O> Module<T> for GeneralizedResidual<F, R>
where
T: Tensor<Dtype = f32>,
O: Tensor<Dtype = T::Dtype, Tape = T::Tape>,
F: Module<T, Output = O>,
R: Module<T, Output = O>,
{
type Output = O;

/// Calls forward on `F` and `R` and then sums their result: `F(x) + R(x)`
fn forward(&self, x: T) -> Self::Output {
let (x, tape) = x.split_tape();

// do R(x) on the tape
let (r_x, tape) = self.1.forward(x.duplicate().put_tape(tape)).split_tape();

// do F(x) on the tape
let f_x = self.0.forward(x.put_tape(tape));

add(f_x, &r_x)
}
}

impl<F: SaveToNpz, R: SaveToNpz> SaveToNpz for GeneralizedResidual<F, R> {
/// Pass through to `F`/`R`'s [SaveToNpz].
fn write<W>(
&self,
filename_prefix: &str,
w: &mut zip::ZipWriter<W>,
) -> zip::result::ZipResult<()>
where
W: std::io::Write + std::io::Seek,
{
// I have no idea what to put here
self.0.write(filename_prefix, w)?;
self.1.write(filename_prefix, w)?;
coreylowman marked this conversation as resolved.
Show resolved Hide resolved
Ok(())
}
}

impl<F: LoadFromNpz, R: LoadFromNpz> LoadFromNpz for GeneralizedResidual<F, R> {
/// Pass through to `F`/`R`'s [LoadFromNpz].
fn read<READ>(
&mut self,
filename_prefix: &str,
r: &mut zip::ZipArchive<READ>,
) -> Result<(), NpzError>
where
READ: std::io::Read + std::io::Seek,
{
// I have no idea what to put here, reverse order since it's like a stack?
self.1.read(filename_prefix, r)?;
self.0.read(filename_prefix, r)?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::tests::assert_close;
use rand::{prelude::StdRng, SeedableRng};
use std::fs::File;
use tempfile::NamedTempFile;
use zip::ZipArchive;

#[test]
fn test_reset() {
let mut rng = StdRng::seed_from_u64(0);
let mut model: GeneralizedResidual<Linear<2, 5>, Linear<2, 5>> = Default::default();
assert_eq!(model.0.weight.data(), &[[0.0; 2]; 5]);
assert_eq!(model.0.bias.data(), &[0.0; 5]);
assert_eq!(model.1.weight.data(), &[[0.0; 2]; 5]);
assert_eq!(model.1.bias.data(), &[0.0; 5]);

model.reset_params(&mut rng);
assert_ne!(model.0.weight.data(), &[[0.0; 2]; 5]);
assert_ne!(model.0.bias.data(), &[0.0; 5]);
assert_ne!(model.1.weight.data(), &[[0.0; 2]; 5]);
assert_ne!(model.1.bias.data(), &[0.0; 5]);
}

const W0: [[f32; 2]; 5] = [
[0.63315326, 0.3361526],
[0.60201937, 0.30927354],
[0.39831632, 0.29526848],
[-0.4730785, -0.10664469],
[0.5074884, -0.08458644],
];
const B0: [f32; 5] = [-0.7014593, 0.01725882, 0.67181975, -0.61593556, 0.27809095];

const W2: [[f32; 5]; 2] = [
[0.37967658, -0.30938417, -0.4046409, 0.34131002, -0.36532],
[0.01010674, 0.2922417, -0.28791183, 0.09316397, 0.00722069],
];
const B2: [f32; 2] = [-0.01353309, 0.19437504];
const X: [[f32; 2]; 10] = [
[0.9706649, -0.50246257],
[0.36609784, 0.22519696],
[-0.26957038, -2.4395447],
[0.729607, 0.06136635],
[1.0758572, -0.6158074],
[1.844528, -0.7769507],
[-0.83232504, 0.26263165],
[-0.18690403, 0.5396985],
[-1.0891576, 0.09805013],
[-0.63034505, 2.4173584],
];
const Y: [[f32; 2]; 10] = [
[0.15374291, -0.43383744],
[-0.26277426, 0.25803787],
[-0.41010314, -2.2426596],
[-0.062764645, 0.117026225],
[0.2237711, -0.54089284],
[0.69048953, -0.6508272],
[-1.0149324, 0.33670622],
[-0.57907265, 0.53813595],
[-1.2107061, 0.21556953],
[-1.2221863, 2.3977249],
];

const W0G: [[f32; 2]; 5] = [
[0.035948314, -0.015142122],
[-0.0035737813, -0.001155745],
[-0.07784372, -0.059181444],
[0.0, 0.0],
[-0.081114516, 0.06281963],
];
const B0G: [f32; 5] = [0.019489167, -0.005999865, -0.3116488, 0.0, -0.12533475];
const W2G: [[f32; 5]; 2] = [[0.010261777, 0.15239798, 0.37232202, 0.0, 0.22712366]; 2];
const B2G: [f32; 2] = [0.50000006; 2];

#[test]
fn test_residual_forward_backward_resadd_as_main() {
type SubModel = (Linear<2, 5>, ReLU, Linear<5, 2>);
type Model = GeneralizedResidual<SubModel, ReLU>;

let mut model: Model = Default::default();
*model.0 .0.weight.mut_data() = W0;
*model.0 .0.bias.mut_data() = B0;
*model.0 .2.weight.mut_data() = W2;
*model.0 .2.bias.mut_data() = B2;

let x = Tensor2D::new(X);
let y = model.forward(x.traced());
// Y = s(x) + x, including negative x
// Y2 would be s(x) + r(x) [r == ReLU]
// Y2 = s(x) + r(x) = Y - x + r(x)
// Y2 = Y - (x - r(x))
// x - r(x) = {0, if x >= 0, because r(x) = x
// x - r(x) = {x, if x < 0, because r(x) = 0 => x - 0 = x
// this is r(-x), since this returns x if x < 0 and 0 elsewhere
// => Y2 = Y - r(-x)
assert_close(
y.data(),
add(Tensor2D::new(Y), &(-Tensor2D::new(X)).relu()).data(),
);

let gradients = y.mean().backward();

assert_close(gradients.ref_gradient(&model.0 .0.weight), &W0G);
assert_close(gradients.ref_gradient(&model.0 .0.bias), &B0G);
assert_close(gradients.ref_gradient(&model.0 .2.weight), &W2G);
assert_close(gradients.ref_gradient(&model.0 .2.bias), &B2G);
}

#[test]
fn test_residual_forward_backward_with_update() {
type SubModel = (Linear<2, 5>, ReLU, Linear<5, 2>);
type Model = GeneralizedResidual<SubModel, SubModel>;

let mut model: Model = Default::default();
*model.0 .0.weight.mut_data() = W0;
*model.0 .0.bias.mut_data() = B0;
*model.0 .2.weight.mut_data() = W2;
*model.0 .2.bias.mut_data() = B2;
*model.1 .0.weight.mut_data() = W0;
*model.1 .0.bias.mut_data() = B0;
*model.1 .2.weight.mut_data() = W2;
*model.1 .2.bias.mut_data() = B2;

let mut model2: SubModel = Default::default();
*model2.0.weight.mut_data() = W0;
*model2.0.bias.mut_data() = B0;
// The submodel s(x) = l(x) with l(x) = ax + b and is the last linear layer
// model2 has to be model + model = 2 * model => s2(x) = 2 * s(x) => s2(x) = 2ax + 2b
// => a' = 2a; b' = 2b
*model2.2.weight.mut_data() = W2;
model2.2.weight = model2.2.weight * 2.0;
*model2.2.bias.mut_data() = B2;
model2.2.bias = model2.2.bias * 2.0;

let x = Tensor2D::new(X);
let y = model.forward(x.traced());
let x2 = Tensor2D::new(X);
let y2 = model2.forward(x2.traced());
assert_close(y.data(), y2.data());

let gradients = y.mean().backward();
let gradients2 = y2.mean().backward();

assert_close(gradients.ref_gradient(&model.0 .0.weight), &W0G);
assert_close(gradients.ref_gradient(&model.0 .0.bias), &B0G);
assert_close(gradients.ref_gradient(&model.0 .2.weight), &W2G);
assert_close(gradients.ref_gradient(&model.0 .2.bias), &B2G);
assert_close(gradients.ref_gradient(&model.1 .0.weight), &W0G);
assert_close(gradients.ref_gradient(&model.1 .0.bias), &B0G);
assert_close(gradients.ref_gradient(&model.1 .2.weight), &W2G);
assert_close(gradients.ref_gradient(&model.1 .2.bias), &B2G);
assert_close(
gradients2.ref_gradient(&model2.0.weight),
(Tensor2D::new(W0G) * 2.0).data(),
);
assert_close(
gradients2.ref_gradient(&model2.0.bias),
(Tensor1D::new(B0G) * 2.0).data(),
);
// no multiplication with 2 here since f'(x) = h'(x) * g'(h(j(x))) with f(x) = g(h(j(x)))
// In this case, it's f(x) = g(h(j(2x))) => f'(x) = h'(j(2x)) * g'(h(j(x))),
// while g(x) = h(j(2x)) => g'(x) = 2 * j'(x) * h'(j(x))
assert_close(gradients2.ref_gradient(&model2.2.weight), &W2G);
assert_close(gradients2.ref_gradient(&model2.2.bias), &B2G);

// with lr = 1, w* = w - w'
let sgd_config = SgdConfig {
lr: 1.0,
momentum: None,
};
Sgd::new(sgd_config).update(&mut model, gradients);
Sgd::new(sgd_config).update(&mut model2, gradients2);

assert_close(
model.0 .0.weight.data(),
sub(Tensor2D::new(W0), &Tensor2D::new(W0G)).data(),
);
assert_close(
model.0 .0.bias.data(),
sub(Tensor1D::new(B0), &Tensor1D::new(B0G)).data(),
);
assert_close(
model.0 .2.weight.data(),
sub(Tensor2D::new(W2), &Tensor2D::new(W2G)).data(),
);
assert_close(
model.0 .2.bias.data(),
sub(Tensor1D::new(B2), &Tensor1D::new(B2G)).data(),
);
assert_close(
model.1 .0.weight.data(),
sub(Tensor2D::new(W0), &Tensor2D::new(W0G)).data(),
);
assert_close(
model.1 .0.bias.data(),
sub(Tensor1D::new(B0), &Tensor1D::new(B0G)).data(),
);
assert_close(
model.1 .2.weight.data(),
sub(Tensor2D::new(W2), &Tensor2D::new(W2G)).data(),
);
assert_close(
model.1 .2.bias.data(),
sub(Tensor1D::new(B2), &Tensor1D::new(B2G)).data(),
);
}

// gradients have to be summed, r(x) = g(x) + h(x) => r'(x) = g'(x) + h'(x)
#[test]
fn test_residual_gradients_correctly_added() {
type Model = (Linear<1, 1>, GeneralizedResidual<ReLU, ReLU>);
// Linear<1, 2>-layer has weights with one and bias zeroed
let mut model: Model = Default::default();
*model.0.weight.mut_data() = [[1.0]];

let x = Tensor2D::new([[-1.0], [1.0]]);
let y = model.forward(x.traced());

assert_close(y.data(), &[[0.0], [2.0]]);

let grads = y.mean().backward();

// m(x) = r(x) + r(x) = 2r(x); m'(x) = 2r'(x)
// y_mean(x_1, x_2) = (m(x_1) + m(x_2)) / 2 = (2r(x_1) + 2r(x_2)) / 2 = r(x_1) + r(x_2)
// x_1 = -1; x_2 = 1 => r(-1) + r(1) = 0 + 1 = 1
// y_mean'(x_1, x_2) = r'(x_1) + r'(x_2) = r'(-1) + r'(1) = 0 + 1 = 1
// x_1 and x_2 are from the Linear layer, so x_1 = ax + b with a = 1 and b = 0
// => derivates for a and b are the same
assert_close(grads.ref_gradient(&model.0.weight), &[[1.0]]);
assert_close(grads.ref_gradient(&model.0.bias), &[1.0]);
}

// test for different input and output shapes?

#[test]
fn test_save_residual() {
let model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
let file = NamedTempFile::new().expect("failed to create tempfile");
model
.save(file.path().to_str().unwrap())
.expect("failed to save model");
let f = File::open(file.path()).expect("failed to open resulting file");
let mut zip = ZipArchive::new(f).expect("failed to create zip archive from file");
{
let weight_file = zip
.by_name("weight.npy")
.expect("failed to find weight.npy file");
assert!(weight_file.size() > 0);
}
{
let bias_file = zip
.by_name("bias.npy")
.expect("failed to find bias.npy file");
assert!(bias_file.size() > 0);
}
}

/// TODO test fails
#[test]
fn test_load_residual() {
let mut rng = StdRng::seed_from_u64(0);
let mut saved_model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
saved_model.reset_params(&mut rng);

let file = NamedTempFile::new().expect("failed to create tempfile");
assert!(saved_model.save(file.path().to_str().unwrap()).is_ok());

let mut loaded_model: GeneralizedResidual<Linear<5, 3>, Linear<5, 3>> = Default::default();
assert_ne!(loaded_model.0.weight.data(), saved_model.0.weight.data());
assert_ne!(loaded_model.0.bias.data(), saved_model.0.bias.data());

assert_ne!(loaded_model.1.weight.data(), saved_model.1.weight.data());
assert_ne!(loaded_model.1.bias.data(), saved_model.1.bias.data());

assert!(loaded_model.load(file.path().to_str().unwrap()).is_ok());
// only the next 2 lines are 'failing
assert_eq!(loaded_model.0.weight.data(), saved_model.0.weight.data());
assert_eq!(loaded_model.0.bias.data(), saved_model.0.bias.data());

assert_eq!(loaded_model.1.weight.data(), saved_model.1.weight.data());
assert_eq!(loaded_model.1.bias.data(), saved_model.1.bias.data());
}
}
2 changes: 2 additions & 0 deletions src/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@

mod activations;
mod dropout;
mod generalized_residual;
mod impl_module_for_tuples;
mod layer_norm;
mod linear;
Expand All @@ -67,6 +68,7 @@ mod split_into;

pub use activations::*;
pub use dropout::*;
pub use generalized_residual::*;
pub use impl_module_for_tuples::*;
pub use layer_norm::*;
pub use linear::*;
Expand Down
Loading