From 5a59b473aa8765bbe658a6187d2a3da41debc433 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 09:00:12 -0500 Subject: [PATCH 01/20] Temp commit --- examples/02-ops.rs | 4 +- examples/03-nn.rs | 4 +- src/nn/add_into.rs | 45 ++---- src/nn/batchnorm2d.rs | 44 +++--- src/nn/conv.rs | 50 +++--- src/nn/embedding.rs | 96 +++++------ src/nn/generalized_residual.rs | 34 ++-- src/nn/impl_module_for_tuples.rs | 65 +++----- src/nn/layer_norm.rs | 74 ++++----- src/nn/linear.rs | 74 ++++----- src/nn/mod.rs | 29 +++- src/nn/module.rs | 32 +--- src/nn/repeated.rs | 60 +++---- src/nn/residual.rs | 25 +-- src/nn/split_into.rs | 47 +++--- src/nn/transformer/decoder.rs | 53 ++----- src/nn/transformer/encoder.rs | 30 +--- src/nn/transformer/mha.rs | 45 ++---- src/nn/transformer/mod.rs | 40 ++--- src/optim/adam/mod.rs | 27 ++-- src/optim/mod.rs | 4 +- src/optim/optimizer.rs | 46 ++---- src/optim/rmsprop/mod.rs | 31 ++-- src/optim/sgd/mod.rs | 23 +-- src/tensor/mod.rs | 8 +- src/tensor/visitors/base.rs | 236 ++++++++++++++++++++++++++++ src/tensor/visitors/mod.rs | 7 + src/tensor/visitors/num_params.rs | 33 ++++ src/tensor/visitors/reset_params.rs | 30 ++++ 29 files changed, 700 insertions(+), 596 deletions(-) create mode 100644 src/tensor/visitors/base.rs create mode 100644 src/tensor/visitors/mod.rs create mode 100644 src/tensor/visitors/num_params.rs create mode 100644 src/tensor/visitors/reset_params.rs diff --git a/examples/02-ops.rs b/examples/02-ops.rs index a6d3ea23b..bd7859131 100644 --- a/examples/02-ops.rs +++ b/examples/02-ops.rs @@ -2,7 +2,7 @@ use dfdx::{ shapes::{Rank0, Rank1, Rank2}, - tensor::{AsArray, Cpu, SampleTensor, Tensor, ToDevice}, + tensor::{AsArray, Cpu, SampleTensor, Tensor}, tensor_ops::{MeanTo, TryMatMul}, }; @@ -63,6 +63,8 @@ fn main() { // these operations are equal across devices #[cfg(feature = "cuda")] { + use dfdx::tensor::ToDevice; + let cpu = Cpu::default(); let a: Tensor, f32, _> = dev.sample_normal(); diff --git a/examples/03-nn.rs b/examples/03-nn.rs index c14836f20..e1c2e10ad 100644 --- a/examples/03-nn.rs +++ b/examples/03-nn.rs @@ -1,9 +1,9 @@ //! Intro to dfdx::nn use dfdx::{ - nn::{builders::*, BuildOnDevice, DeviceBuildExt, Module, ModuleMut, ResetParams}, + nn::{builders::*, BuildOnDevice, DeviceBuildExt, Module, ModuleMut}, shapes::{Const, Rank1, Rank2}, - tensor::{AsArray, SampleTensor, Tensor, ZerosTensor}, + tensor::{AsArray, ResetParams, SampleTensor, Tensor, ZerosTensor}, }; #[cfg(not(feature = "cuda"))] diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index c3674177b..d12d0fd69 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -1,6 +1,6 @@ -use crate::{optim::*, shapes::Dtype, tensor_ops::Device}; +use crate::{shapes::Dtype, tensor::visitors::*, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Add inputs together into a single tensor. 
`T` should be a tuple //// where every element of the tuple has the same output type @@ -23,28 +23,19 @@ use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice #[derive(Debug, Default, Clone)] pub struct AddInto(pub T); -impl, D: Device, E: Dtype> GradientUpdate for AddInto { - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.0.update(updater, unused) - } -} - -impl, D: Device, E: Dtype> BuildOnDevice for AddInto { +impl, D: DeviceStorage, E: Dtype> BuildOnDevice for AddInto { type Built = AddInto; } -impl, D: Device, E: Dtype> BuildModule for AddInto { +impl, D: DeviceStorage, E: Dtype> BuildModule for AddInto { fn try_build(device: &D) -> Result::Err> { Ok(Self(BuildModule::try_build(device)?)) } } -impl, D: Device, E: Dtype> ResetParams for AddInto { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - self.0.try_reset_params() +impl> TensorCollection for AddInto { + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } @@ -103,8 +94,8 @@ mod tests { use crate::{ gradients::OwnedTape, nn::{builders::*, tests::SimpleUpdater, DeviceBuildExt}, + optim::*, shapes::*, - tensor::*, tests::{TestDevice, TestDtype}, unique_id::HasUniqueId, }; @@ -229,10 +220,9 @@ mod tests { let mut g: SimpleUpdater = Default::default(); // no gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); + model.update(&mut g).unwrap(); assert_eq!( - &unused.ids, + &g.unused.ids, &[ *model.0 .0.weight.id(), *model.0 .0.bias.id(), @@ -242,14 +232,13 @@ mod tests { ); // weight gradient is present - g.0.try_alloc_for(&model.0 .0.weight).unwrap(); - g.0.try_alloc_for(&model.0 .0.bias).unwrap(); - g.0.try_alloc_for(&model.0 .1.weight).unwrap(); - g.0.try_alloc_for(&model.0 .1.bias).unwrap(); - - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + g.grads.try_alloc_for(&model.0 .0.weight).unwrap(); + g.grads.try_alloc_for(&model.0 .0.bias).unwrap(); + g.grads.try_alloc_for(&model.0 .1.weight).unwrap(); + g.grads.try_alloc_for(&model.0 .1.bias).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } #[test] diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index a3d99f532..3acd5c7b7 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -1,6 +1,6 @@ -use crate::{gradients::*, optim::*, shapes::*, tensor::*, tensor_ops::*}; +use crate::{gradients::*, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug, Copy, Clone, Eq, PartialEq)] @@ -183,13 +183,28 @@ impl> BuildModule for BatchNorm2D> ResetParams for BatchNorm2D { - fn try_reset_params(&mut self) -> Result<(), D::Err> { - self.scale.try_fill_with_ones()?; - self.bias.try_fill_with_zeros()?; - self.running_mean.try_fill_with_zeros()?; - self.running_var.try_fill_with_ones()?; - Ok(()) +impl> TensorCollection for BatchNorm2D { + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_tensor( + |s| &s.scale, + |s| &mut s.scale, + TensorOptions::named("scale", |t| t.try_fill_with_ones()), + )?; + visitor.visit_tensor( + |s| &s.bias, + |s| &mut s.bias, + TensorOptions::named("bias", |t| t.try_fill_with_zeros()), + )?; + 
visitor.visit_tensor( + |s| &s.running_mean, + |s| &mut s.running_mean, + TensorOptions::no_grad("running_mean", |t| t.try_fill_with_zeros()), + )?; + visitor.visit_tensor( + |s| &s.running_var, + |s| &mut s.running_var, + TensorOptions::no_grad("running_var", |t| t.try_fill_with_ones()), + ) } } @@ -209,17 +224,6 @@ impl, D2: Device> ToDevice } } -impl> GradientUpdate for BatchNorm2D { - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.scale.update(updater, unused)?; - self.bias.update(updater, unused)?; - Ok(()) - } -} - #[cfg(test)] mod tests { use super::builder::BatchNorm2D; diff --git a/src/nn/conv.rs b/src/nn/conv.rs index a035be606..3f7a25381 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -1,9 +1,9 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; -use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug)] @@ -54,18 +54,28 @@ pub struct Conv2D< } impl - GradientUpdate for Conv2D + TensorCollection for Conv2D where - E: Dtype, + E: Dtype + Float + SampleUniform, D: Device, { - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.weight.update(updater, unused)?; - self.bias.update(updater, unused)?; - Ok(()) + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_tensor( + |s| &s.weight, + |s| &mut s.weight, + TensorOptions::named("weight", |t| { + let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); + t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) + }), + )?; + visitor.visit_tensor( + |s| &s.bias, + |s| &mut s.bias, + TensorOptions::named("bias", |t| { + let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); + t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) + }), + ) } } @@ -85,23 +95,6 @@ where } } -impl - ResetParams for Conv2D -where - E: Dtype + Float + SampleUniform, - D: Device, -{ - fn try_reset_params(&mut self) -> Result<(), ::Err> { - let k = E::from_usize(I * K * K).unwrap(); - let bound = E::ONE / k.sqrt(); - self.weight - .try_fill_with_distr(rand_distr::Uniform::new(-bound, bound))?; - self.bias - .try_fill_with_distr(rand_distr::Uniform::new(-bound, bound))?; - Ok(()) - } -} - impl ToDevice for Conv2D where @@ -175,6 +168,7 @@ impl<'a, B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device, T: Tape mod tests { use crate::{ nn::DeviceBuildExt, + optim::*, tensor::{AsArray, SampleTensor, ZerosTensor}, tests::*, }; diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index 09196c4ac..3703cba3c 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -1,9 +1,9 @@ use num_traits::Float; -use rand_distr::uniform::SampleUniform; +use rand_distr::{uniform::SampleUniform, Uniform}; -use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; -use super::module::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::module::{BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; pub mod builder { #[derive(Debug)] @@ -52,6 +52,36 @@ pub struct Embedding, E, D>, } +impl NonMutableModule + for Embedding +{ +} + +impl> + 
BuildModule for Embedding +{ + fn try_build(device: &D) -> Result { + let bound = E::ONE / E::from_usize(V).unwrap().sqrt(); + let weight = device.try_sample(Uniform::new(-bound, bound))?; + Ok(Self { weight }) + } +} + +impl> + TensorCollection for Embedding +{ + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_tensor( + |s| &s.weight, + |s| &mut s.weight, + TensorOptions::named("weight", |t| { + let b: E = E::ONE / E::from_usize(C).unwrap().sqrt(); + t.try_fill_with_distr(Uniform::new(-b, b)) + }), + ) + } +} + impl, T: Tape> Module, usize, D, T>> for Embedding { @@ -79,51 +109,6 @@ impl< } } -impl> ModuleMut - for Embedding -where - Self: Module, -{ - type Output = >::Output; - fn forward_mut(&mut self, input: T) -> Self::Output { - self.forward(input) - } -} - -impl> GradientUpdate - for Embedding -{ - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), D::Err> - where - U: ParamUpdater, - { - self.weight.update(updater, unused)?; - Ok(()) - } -} - -impl> - BuildModule for Embedding -{ - fn try_build(device: &D) -> Result { - let bound = E::ONE / E::from_usize(V).unwrap().sqrt(); - let distr = rand_distr::Uniform::new(-bound, bound); - let weight = device.try_sample(distr)?; - Ok(Self { weight }) - } -} - -impl> - ResetParams for Embedding -{ - fn try_reset_params(&mut self) -> Result<(), D::Err> { - let bound = E::ONE / E::from_usize(VOCAB).unwrap().sqrt(); - let distr = rand_distr::Uniform::new(-bound, bound); - self.weight.try_fill_with_distr(distr)?; - Ok(()) - } -} - impl, D2: Device> ToDevice for Embedding { @@ -140,6 +125,7 @@ mod tests { use super::*; use crate::{ nn::{tests::SimpleUpdater, DeviceBuildExt}, + optim::*, tests::*, unique_id::HasUniqueId, }; @@ -254,15 +240,13 @@ mod tests { let mut g: SimpleUpdater = Default::default(); // no gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert_eq!(&unused.ids, &[*model.weight.id()]); - - g.0.try_alloc_for(&model.weight).unwrap(); + model.update(&mut g).unwrap(); + assert_eq!(&g.unused.ids, &[*model.weight.id()]); // weight gradient is present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + g.grads.try_alloc_for(&model.weight).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } } diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index c55d92e61..8a1e35322 100644 --- a/src/nn/generalized_residual.rs +++ b/src/nn/generalized_residual.rs @@ -1,6 +1,6 @@ -use crate::{optim::*, shapes::*, tensor::*, tensor_ops::*}; +use crate::{shapes::*, tensor::visitors::*, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// A residual connection `R` around `F`: `F(x) + R(x)`, /// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). 
@@ -25,26 +25,13 @@ pub struct GeneralizedResidual { pub r: R, } -impl, E: Dtype, F: GradientUpdate, R: GradientUpdate> GradientUpdate - for GeneralizedResidual -{ - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.f.update(updater, unused)?; - self.r.update(updater, unused)?; - Ok(()) - } -} - -impl, E: Dtype, F: BuildOnDevice, R: BuildOnDevice> BuildOnDevice +impl, R: BuildOnDevice> BuildOnDevice for GeneralizedResidual { type Built = GeneralizedResidual; } -impl, E: Dtype, F: BuildModule, R: BuildModule> BuildModule +impl, R: BuildModule> BuildModule for GeneralizedResidual { fn try_build(device: &D) -> Result::Err> { @@ -55,13 +42,12 @@ impl, E: Dtype, F: BuildModule, R: BuildModule> BuildMo } } -impl, E: Dtype, F: ResetParams, R: ResetParams> ResetParams - for GeneralizedResidual +impl, R: TensorCollection> + TensorCollection for GeneralizedResidual { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - self.f.try_reset_params()?; - self.r.try_reset_params()?; - Ok(()) + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_module(|s| &s.f, |s| &mut s.f, "f")?; + visitor.visit_module(|s| &s.r, |s| &mut s.r, "r") } } @@ -102,7 +88,7 @@ mod tests { use super::*; use crate::nn::builders::Linear; use crate::nn::DeviceBuildExt; - use crate::tests::*; + use crate::{tensor_ops::*, tests::*}; #[test] fn test_reset_generalized_residual() { diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 510b4ae1e..68fea1b29 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -1,17 +1,12 @@ -use crate::{optim::*, shapes::*, tensor_ops::*}; +use crate::{shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; -use super::module::{ - BuildModule, BuildOnDevice, Module, ModuleMut, OnDevice, ResetParams, ToDevice, -}; +use super::module::{BuildModule, BuildOnDevice, Module, ModuleMut, OnDevice, ToDevice}; macro_rules! tuple_impls { ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { - impl, E: Dtype, $($name: GradientUpdate),+> GradientUpdate for ($($name,)+) { - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), D::Err> - where - U: ParamUpdater - { - $(self.$idx.update(updater, unused)?;)+ + impl),+> TensorCollection for ($($name,)+) { + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + $(visitor.visit_module(|s| &s.$idx, |s| &mut s.$idx, &std::format!("{}", $idx))?;)+ Ok(()) } } @@ -28,14 +23,6 @@ macro_rules! 
tuple_impls { } } - impl, E: Dtype, $($name: ResetParams),+> ResetParams for ($($name,)+) { - #[allow(non_snake_case)] - fn try_reset_params(&mut self) -> Result<(), D::Err> { - $(self.$idx.try_reset_params()?;)+ - Ok(()) - } - } - impl<$($name: ToDevice,)+ D> ToDevice for ($($name,)+) { type Output = ($(OnDevice<$name, D>,)+); fn to_device(&self, device: &D) -> Self::Output { @@ -110,12 +97,11 @@ tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5], M6, [M5, M4, M3, M2, M mod tests { use super::*; use crate::nn::tests::SimpleUpdater; - use crate::tests::TestDtype; use crate::unique_id::HasUniqueId; use crate::{ nn::{builders::*, *}, - tensor::*, - tests::TestDevice, + optim::*, + tests::*, }; #[test] @@ -268,10 +254,9 @@ mod tests { let mut g: SimpleUpdater = Default::default(); // no gradients present - let mut unused: UnusedTensors = Default::default(); - model.update(&mut g, &mut unused).unwrap(); + model.update(&mut g).unwrap(); assert_eq!( - &unused.ids, + &g.unused.ids, &[ *model.0.weight.id(), *model.0.bias.id(), @@ -283,26 +268,24 @@ mod tests { ); // weight gradient is present - g.0.try_alloc_for(&model.0.weight).unwrap(); - g.0.try_alloc_for(&model.1.weight).unwrap(); - g.0.try_alloc_for(&model.2.weight).unwrap(); - - let mut unused: UnusedTensors = Default::default(); - model.update(&mut g, &mut unused).unwrap(); + g.grads.try_alloc_for(&model.0.weight).unwrap(); + g.grads.try_alloc_for(&model.1.weight).unwrap(); + g.grads.try_alloc_for(&model.2.weight).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); assert_eq!( - &unused.ids, + &g.unused.ids, &[*model.0.bias.id(), *model.1.bias.id(), *model.2.bias.id(),] ); - g.0.try_alloc_for(&model.0.weight).unwrap(); - g.0.try_alloc_for(&model.0.bias).unwrap(); - g.0.try_alloc_for(&model.1.weight).unwrap(); - g.0.try_alloc_for(&model.1.bias).unwrap(); - g.0.try_alloc_for(&model.2.weight).unwrap(); - g.0.try_alloc_for(&model.2.bias).unwrap(); - - let mut unused: UnusedTensors = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + g.grads.try_alloc_for(&model.0.weight).unwrap(); + g.grads.try_alloc_for(&model.0.bias).unwrap(); + g.grads.try_alloc_for(&model.1.weight).unwrap(); + g.grads.try_alloc_for(&model.1.bias).unwrap(); + g.grads.try_alloc_for(&model.2.weight).unwrap(); + g.grads.try_alloc_for(&model.2.bias).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index 0dabf14d4..8dd088859 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -1,6 +1,6 @@ -use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; pub mod builder { #[derive(Debug)] @@ -42,6 +42,8 @@ pub struct LayerNorm1D { pub epsilon: E, } +impl NonMutableModule for LayerNorm1D {} + impl> BuildModule for LayerNorm1D { /// Fills [Self::gamma] with 1s and [Self::beta] with 0s and sets [Self::epsilon] to `1e-5`. 
fn try_build(device: &D) -> Result { @@ -53,11 +55,18 @@ impl> BuildModule for LayerNorm1D> ResetParams for LayerNorm1D { - fn try_reset_params(&mut self) -> Result<(), D::Err> { - self.gamma.try_fill_with_ones()?; - self.beta.try_fill_with_zeros()?; - Ok(()) +impl> TensorCollection for LayerNorm1D { + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_tensor( + |s| &s.gamma, + |s| &mut s.gamma, + TensorOptions::named("gamma", |t| t.try_fill_with_ones()), + )?; + visitor.visit_tensor( + |s| &s.beta, + |s| &mut s.beta, + TensorOptions::named("beta", |t| t.try_fill_with_zeros()), + ) } } @@ -75,17 +84,6 @@ impl, D2: Device> ToDevice } } -impl> GradientUpdate for LayerNorm1D { - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.gamma.update(updater, unused)?; - self.beta.update(updater, unused)?; - Ok(()) - } -} - impl, T: Tape> Module, E, D, T>> for LayerNorm1D { @@ -117,23 +115,12 @@ impl, T: Tape> } } -impl> ModuleMut for LayerNorm1D -where - Self: Module, -{ - type Output = >::Output; - fn forward_mut(&mut self, input: T) -> Self::Output { - self.forward(input) - } -} - #[cfg(test)] mod tests { use super::*; use crate::nn::tests::SimpleUpdater; - use crate::nn::DeviceBuildExt; - use crate::tests::{assert_close, TestDevice, TestDtype}; - use crate::unique_id::HasUniqueId; + use crate::nn::{DeviceBuildExt, ModuleMut}; + use crate::{optim::*, tests::*, unique_id::HasUniqueId}; #[test] fn test_layer_norm_reset() { @@ -203,23 +190,20 @@ mod tests { let mut g: SimpleUpdater = Default::default(); // no gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert_eq!(&unused.ids, &[*model.gamma.id(), *model.beta.id()]); - - g.0.try_alloc_for(&model.gamma).unwrap(); + model.update(&mut g).unwrap(); + assert_eq!(&g.unused.ids, &[*model.gamma.id(), *model.beta.id()]); // weight gradient is present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert_eq!(&unused.ids, &[*model.beta.id()]); - - g.0.try_alloc_for(&model.gamma).unwrap(); - g.0.try_alloc_for(&model.beta).unwrap(); + g.grads.try_alloc_for(&model.gamma).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert_eq!(&g.unused.ids, &[*model.beta.id()]); // all gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + g.grads.try_alloc_for(&model.gamma).unwrap(); + g.grads.try_alloc_for(&model.beta).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } } diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 495df1faf..a4ba5faa8 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -1,6 +1,6 @@ -use crate::{gradients::Tape, optim::*, shapes::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; -use super::module::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::module::{BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; use num_traits::Float; use rand_distr::{uniform::SampleUniform, Uniform}; @@ -52,17 +52,9 @@ pub struct Linear { pub bias: Tensor, E, D>, } -impl GradientUpdate +impl NonMutableModule for Linear { - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), D::Err> - where - U: ParamUpdater, - { - self.weight.update(updater, unused)?; - self.bias.update(updater, unused)?; - 
Ok(()) - } } impl> @@ -77,13 +69,25 @@ impl> - ResetParams for Linear + TensorCollection for Linear { - fn try_reset_params(&mut self) -> Result<(), D::Err> { - let b: E = E::ONE / E::from_usize(I).unwrap().sqrt(); - self.weight.try_fill_with_distr(Uniform::new(-b, b))?; - self.bias.try_fill_with_distr(Uniform::new(-b, b))?; - Ok(()) + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_tensor( + |s| &s.weight, + |s| &mut s.weight, + TensorOptions::named("weight", |t| { + let b: E = E::ONE / E::from_usize(I).unwrap().sqrt(); + t.try_fill_with_distr(Uniform::new(-b, b)) + }), + )?; + visitor.visit_tensor( + |s| &s.bias, + |s| &mut s.bias, + TensorOptions::named("bias", |t| { + let b: E = E::ONE / E::from_usize(I).unwrap().sqrt(); + t.try_fill_with_distr(Uniform::new(-b, b)) + }), + ) } } @@ -114,16 +118,6 @@ where } } -impl> ModuleMut for Linear -where - Self: Module, -{ - type Output = >::Output; - fn forward_mut(&mut self, input: T) -> Self::Output { - self.forward(input) - } -} - #[derive(Clone, Debug)] struct Bias1D<'a, const M: usize, E: Dtype, D: DeviceStorage> { beta: &'a Tensor, E, D>, @@ -161,6 +155,7 @@ mod tests { use super::*; use crate::{ nn::{tests::SimpleUpdater, DeviceBuildExt}, + optim::*, tests::*, unique_id::HasUniqueId, }; @@ -301,23 +296,20 @@ mod tests { let mut g: SimpleUpdater = Default::default(); // no gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert_eq!(&unused.ids, &[*model.weight.id(), *model.bias.id()]); - - g.0.try_alloc_for(&model.weight).unwrap(); + model.update(&mut g).unwrap(); + assert_eq!(&g.unused.ids, &[*model.weight.id(), *model.bias.id()]); // weight gradient is present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert_eq!(&unused.ids, &[*model.bias.id()]); - - g.0.try_alloc_for(&model.weight).unwrap(); - g.0.try_alloc_for(&model.bias).unwrap(); + g.grads.try_alloc_for(&model.weight).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert_eq!(&g.unused.ids, &[*model.bias.id()]); // both gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + g.grads.try_alloc_for(&model.weight).unwrap(); + g.grads.try_alloc_for(&model.bias).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } } diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 10a14019b..993a62548 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -185,19 +185,32 @@ pub mod builders { #[cfg(test)] mod tests { - use crate::{gradients::Gradients, optim::ParamUpdater, shapes::Dtype, tensor::DeviceStorage}; + use crate::{ + gradients::Gradients, optim::UnusedTensors, shapes::Dtype, tensor::visitors::*, + tensor::DeviceStorage, + }; #[derive(Default)] - pub struct SimpleUpdater(pub Gradients); + pub struct SimpleUpdater { + pub grads: Gradients, + pub unused: UnusedTensors, + } + + impl SimpleUpdater { + pub(crate) fn clear_unused(&mut self) { + self.unused.clear(); + } + } - impl ParamUpdater for SimpleUpdater { - fn update_param( + impl VisitTensorMut for SimpleUpdater { + fn visit( &mut self, - p: &mut crate::tensor::Tensor, - unused: &mut crate::optim::UnusedTensors, + _: alloc::string::String, + _: TensorOptions, + p: &mut crate::prelude::Tensor, ) -> Result<(), ::Err> { - if self.0.remove(p).is_none() { - unused.add(p); + if self.grads.remove(p).is_none() { + self.unused.add(p); } Ok(()) } diff --git a/src/nn/module.rs b/src/nn/module.rs 
index 7c3b69ce9..6f85e297b 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -1,8 +1,10 @@ -use crate::{optim::GradientUpdate, shapes::Dtype}; - #[cfg(feature = "cuda")] pub use crate::tensor::OnCuda; pub use crate::tensor::{DeviceStorage, OnCpu, OnDevice, ToDevice}; +use crate::{ + shapes::Dtype, + tensor::visitors::{ModuleWalker, TensorCollection}, +}; /// Immutable forward of `Input` that produces [Module::Output]. /// See [ModuleMut] for mutable forward. @@ -64,19 +66,6 @@ pub trait DeviceBuildExt: DeviceStorage { } impl DeviceBuildExt for D {} -/// Something that can reset it's parameters. -pub trait ResetParams { - /// Mutates parameters. Each implementor - /// of this trait decides how the parameters are initialized. In - /// fact, some impls may not even use randomness. - fn reset_params(&mut self) { - self.try_reset_params().unwrap(); - } - - /// Fallible version of [ResetParams::reset_params]. - fn try_reset_params(&mut self) -> Result<(), D::Err>; -} - /// Marker trait for modules with no updatable parameters. These have /// blanket impls for [ResetParams], [GradientUpdate], and [ModuleMut] pub trait ZeroSizedModule: Default {} @@ -85,8 +74,8 @@ impl, D: DeviceStorage, E: Dtype> BuildOn type Built = T; } -impl ResetParams for T { - fn try_reset_params(&mut self) -> Result<(), ::Err> { +impl TensorCollection for T { + fn iter_tensors>(_: &mut V) -> Result<(), ::Err> { Ok(()) } } @@ -98,15 +87,6 @@ impl ToDevice for T { } } -impl GradientUpdate for T { - fn update(&mut self, _: &mut U, _: &mut crate::optim::UnusedTensors) -> Result<(), ::Err> - where - U: crate::optim::ParamUpdater, - { - Ok(()) - } -} - /// Marker trait for modules that don't have different behavior between /// mutable forwards and non-mutable forwards pub trait NonMutableModule {} diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index 315d8b3b7..a23927338 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -1,6 +1,6 @@ -use crate::{optim::*, shapes::Dtype, tensor_ops::Device}; +use crate::{shapes::Dtype, tensor::visitors::*, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Repeats `T` `N` times. This requires that `T`'s input is the same as it's output. 
/// @@ -21,13 +21,13 @@ pub struct Repeated { pub modules: std::vec::Vec, } -impl, E: Dtype, T: BuildOnDevice, const N: usize> BuildOnDevice +impl, const N: usize> BuildOnDevice for Repeated { type Built = Repeated; } -impl, E: Dtype, T: BuildModule, const N: usize> BuildModule +impl, const N: usize> BuildModule for Repeated { fn try_build(device: &D) -> Result::Err> { @@ -39,12 +39,16 @@ impl, E: Dtype, T: BuildModule, const N: usize> BuildModule, E: Dtype, T: ResetParams, const N: usize> ResetParams +impl, const N: usize> TensorCollection for Repeated { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - for m in self.modules.iter_mut() { - m.try_reset_params()?; + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + for i in 0..N { + visitor.visit_module( + |s| &s.modules[i], + |s| &mut s.modules[i], + &std::format!("{i}"), + )?; } Ok(()) } @@ -70,20 +74,6 @@ impl std::ops::Index for Repeated { } } -impl, E: Dtype, T: GradientUpdate, const N: usize> GradientUpdate - for Repeated -{ - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - for m in self.modules.iter_mut() { - m.update(updater, unused)?; - } - Ok(()) - } -} - impl, const N: usize> Module for Repeated { type Output = T::Output; fn forward(&self, mut x: Input) -> Self::Output { @@ -111,7 +101,7 @@ mod tests { use super::*; use crate::nn::DeviceBuildExt; use crate::tests::TestDtype; - use crate::{nn::builders::*, shapes::*, tensor::*, unique_id::HasUniqueId}; + use crate::{nn::builders::*, optim::*, shapes::*, unique_id::HasUniqueId}; use crate::{nn::tests::SimpleUpdater, tests::TestDevice}; #[test] @@ -153,10 +143,9 @@ mod tests { let mut g: SimpleUpdater = Default::default(); // no gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); + model.update(&mut g).unwrap(); assert_eq!( - &unused.ids, + &g.unused.ids, &[ *model[0].weight.id(), *model[0].bias.id(), @@ -169,13 +158,13 @@ mod tests { // weight gradient is present for i in 0..3 { - g.0.try_alloc_for(&model[i].weight).unwrap(); + g.grads.try_alloc_for(&model[i].weight).unwrap(); } - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); assert_eq!( - &unused.ids, + &g.unused.ids, &[ *model[0].bias.id(), *model[1].bias.id(), @@ -185,12 +174,11 @@ mod tests { // all gradients present for i in 0..3 { - g.0.try_alloc_for(&model[i].weight).unwrap(); - g.0.try_alloc_for(&model[i].bias).unwrap(); + g.grads.try_alloc_for(&model[i].weight).unwrap(); + g.grads.try_alloc_for(&model[i].bias).unwrap(); } - - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } } diff --git a/src/nn/residual.rs b/src/nn/residual.rs index f3e07764a..abdda0c06 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -1,6 +1,6 @@ -use crate::{optim::*, shapes::*, tensor::SplitTape, tensor_ops::Device}; +use crate::{shapes::*, tensor::visitors::*, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; use std::ops::Add; @@ -23,28 +23,19 @@ use std::ops::Add; #[derive(Debug, Clone, Default)] pub struct Residual(pub F); -impl, E: Dtype, F: GradientUpdate> GradientUpdate for Residual { - fn update(&mut self, updater: &mut U, unused: &mut 
UnusedTensors) -> Result<(), D::Err> - where - U: ParamUpdater, - { - self.0.update(updater, unused) - } -} - -impl, E: Dtype, F: BuildOnDevice> BuildOnDevice for Residual { +impl> BuildOnDevice for Residual { type Built = Residual; } -impl, E: Dtype, F: BuildModule> BuildModule for Residual { +impl> BuildModule for Residual { fn try_build(device: &D) -> Result::Err> { Ok(Self(BuildModule::try_build(device)?)) } } -impl, E: Dtype, F: ResetParams> ResetParams for Residual { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - self.0.try_reset_params() +impl> TensorCollection for Residual { + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } @@ -74,7 +65,7 @@ mod tests { use super::*; use crate::nn::DeviceBuildExt; use crate::tests::*; - use crate::{nn::builders::Linear, tensor::*, tensor_ops::*}; + use crate::{nn::builders::Linear, tensor_ops::*}; #[test] fn test_residual_reset() { diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 89dbccef1..a0ed81bea 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -1,6 +1,6 @@ -use crate::{optim::*, shapes::Dtype, tensor::*, tensor_ops::Device}; +use crate::{shapes::Dtype, tensor::visitors::*, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Splits input into multiple heads. `T` should be a tuple, /// where every element of the tuple accepts the same input type. @@ -22,28 +22,21 @@ use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice #[derive(Debug, Default, Clone)] pub struct SplitInto(pub T); -impl, D: Device, E: Dtype> GradientUpdate for SplitInto { - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.0.update(updater, unused) - } -} - -impl, D: Device, E: Dtype> BuildOnDevice for SplitInto { +impl, D: DeviceStorage, E: Dtype> BuildOnDevice for SplitInto { type Built = SplitInto; } -impl, D: Device, E: Dtype> BuildModule for SplitInto { +impl, D: DeviceStorage, E: Dtype> BuildModule for SplitInto { fn try_build(device: &D) -> Result::Err> { Ok(Self(BuildModule::try_build(device)?)) } } -impl, D: Device, E: Dtype> ResetParams for SplitInto { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - self.0.try_reset_params() +impl> TensorCollection + for SplitInto +{ + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } @@ -117,7 +110,7 @@ mod tests { use super::*; use crate::nn::DeviceBuildExt; - use crate::{gradients::*, shapes::*, tensor_ops::*}; + use crate::{gradients::*, optim::*, shapes::*, tensor_ops::*}; use crate::{ nn::{builders::Linear, tests::SimpleUpdater}, tests::*, @@ -250,10 +243,9 @@ mod tests { let mut g: SimpleUpdater = Default::default(); // no gradients present - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); + model.update(&mut g).unwrap(); assert_eq!( - &unused.ids, + &g.unused.ids, &[ *model.0 .0.weight.id(), *model.0 .0.bias.id(), @@ -263,13 +255,12 @@ mod tests { ); // weight gradient is present - g.0.try_alloc_for(&model.0 .0.weight).unwrap(); - g.0.try_alloc_for(&model.0 .0.bias).unwrap(); - g.0.try_alloc_for(&model.0 .1.weight).unwrap(); - g.0.try_alloc_for(&model.0 .1.bias).unwrap(); - - let mut unused = Default::default(); - model.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + 
g.grads.try_alloc_for(&model.0 .0.weight).unwrap(); + g.grads.try_alloc_for(&model.0 .0.bias).unwrap(); + g.grads.try_alloc_for(&model.0 .1.weight).unwrap(); + g.grads.try_alloc_for(&model.0 .1.bias).unwrap(); + g.clear_unused(); + model.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } } diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index 8e6a0ef46..3fafc92c9 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -3,8 +3,8 @@ use rand_distr::uniform::SampleUniform; use crate::{ nn::{modules::*, *}, - optim::{GradientUpdate, ParamUpdater, UnusedTensors}, shapes::Dtype, + tensor::visitors::*, tensor::{PutTape, SplitTape}, tensor_ops::Device, }; @@ -80,23 +80,12 @@ where } impl> - ResetParams for TransformerDecoder + TensorCollection for TransformerDecoder where E: Dtype + Float + SampleUniform, { - fn try_reset_params(&mut self) -> Result<(), D::Err> { - self.0.try_reset_params() - } -} - -impl> - GradientUpdate for TransformerDecoder -{ - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), D::Err> - where - U: ParamUpdater, - { - self.0.update(updater, unused) + fn iter_tensors>(visitor: &mut V) -> Result<(), ::Err> { + visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } @@ -193,36 +182,18 @@ where } } -impl> ResetParams +impl> TensorCollection for TransformerDecoderBlock where E: Dtype + Float + SampleUniform, { - fn try_reset_params(&mut self) -> Result<(), D::Err> { - self.self_attn.try_reset_params()?; - self.norm1.try_reset_params()?; - self.mh_attn.try_reset_params()?; - self.norm2.try_reset_params()?; - self.ff.try_reset_params()?; - self.norm3.try_reset_params()?; - Ok(()) - } -} - -impl> GradientUpdate - for TransformerDecoderBlock -{ - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.self_attn.update(updater, unused)?; - self.norm1.update(updater, unused)?; - self.mh_attn.update(updater, unused)?; - self.norm2.update(updater, unused)?; - self.ff.update(updater, unused)?; - self.norm3.update(updater, unused)?; - Ok(()) + fn iter_tensors>(visitor: &mut V) -> Result<(), ::Err> { + visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; + visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; + visitor.visit_module(|s| &s.mh_attn, |s| &mut s.mh_attn, "mh_attn")?; + visitor.visit_module(|s| &s.norm2, |s| &mut s.norm2, "norm2")?; + visitor.visit_module(|s| &s.ff, |s| &mut s.ff, "ff")?; + visitor.visit_module(|s| &s.norm3, |s| &mut s.norm3, "norm3") } } diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index 407c7e371..2c2b683f7 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -3,8 +3,8 @@ use rand_distr::uniform::SampleUniform; use crate::{ nn::{modules::*, *}, - optim::{GradientUpdate, ParamUpdater, UnusedTensors}, shapes::Dtype, + tensor::visitors::*, tensor::{PutTape, SplitTape}, tensor_ops::Device, }; @@ -114,32 +114,16 @@ where } } -impl> ResetParams +impl> TensorCollection for TransformerEncoderBlock where E: Dtype + Float + SampleUniform, { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - self.self_attn.try_reset_params()?; - self.norm1.try_reset_params()?; - self.ff.try_reset_params()?; - self.norm2.try_reset_params()?; - Ok(()) - } -} - -impl> GradientUpdate - for TransformerEncoderBlock -{ - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - 
self.self_attn.update(updater, unused)?; - self.norm1.update(updater, unused)?; - self.ff.update(updater, unused)?; - self.norm2.update(updater, unused)?; - Ok(()) + fn iter_tensors>(visitor: &mut V) -> Result<(), ::Err> { + visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; + visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; + visitor.visit_module(|s| &s.ff, |s| &mut s.ff, "ff")?; + visitor.visit_module(|s| &s.norm2, |s| &mut s.norm2, "norm2") } } diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index 204df2740..7f5af2e52 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -3,8 +3,8 @@ use rand_distr::uniform::SampleUniform; use crate::{ nn::{modules::*, *}, - optim::*, shapes::Dtype, + tensor::visitors::*, tensor::*, tensor_ops::*, }; @@ -79,34 +79,15 @@ where } impl> - ResetParams for MultiHeadAttention + TensorCollection for MultiHeadAttention where E: Dtype + Float + SampleUniform, { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - self.w_q.try_reset_params()?; - self.w_k.try_reset_params()?; - self.w_v.try_reset_params()?; - self.w_o.try_reset_params()?; - Ok(()) - } -} - -impl GradientUpdate - for MultiHeadAttention -where - E: Dtype, - D: DeviceStorage, -{ - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.w_q.update(updater, unused)?; - self.w_k.update(updater, unused)?; - self.w_v.update(updater, unused)?; - self.w_o.update(updater, unused)?; - Ok(()) + fn iter_tensors>(visitor: &mut W) -> Result<(), ::Err> { + visitor.visit_module(|s| &s.w_q, |s| &mut s.w_q, "w_q")?; + visitor.visit_module(|s| &s.w_k, |s| &mut s.w_k, "w_k")?; + visitor.visit_module(|s| &s.w_v, |s| &mut s.w_v, "w_v")?; + visitor.visit_module(|s| &s.w_o, |s| &mut s.w_o, "w_o") } } @@ -281,7 +262,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{nn::tests::SimpleUpdater, tests::*}; + use crate::{nn::tests::SimpleUpdater, optim::*, tests::*}; #[test] fn test_mha_unbatched() { @@ -387,9 +368,11 @@ mod tests { let v: Tensor, TestDtype, _> = dev.sample_normal(); let y = mha.forward((q.trace(), k, v)); - let mut g = SimpleUpdater(y.mean().backward()); - let mut unused = Default::default(); - mha.update(&mut g, &mut unused).unwrap(); - assert!(unused.is_empty()); + let mut g = SimpleUpdater { + grads: y.mean().backward(), + unused: Default::default(), + }; + mha.update(&mut g).unwrap(); + assert!(g.unused.is_empty()); } } diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index 32decdc08..059087d7e 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -9,13 +9,13 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; use crate::{ - optim::{GradientUpdate, ParamUpdater, UnusedTensors}, shapes::Dtype, + tensor::visitors::*, tensor::{DeviceStorage, PutTape, SplitTape}, tensor_ops::Device, }; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ResetParams, ToDevice}; +use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug, Clone)] @@ -97,31 +97,14 @@ where } impl - ResetParams for Transformer + TensorCollection for Transformer where E: Dtype + Float + SampleUniform, D: Device, { - fn try_reset_params(&mut self) -> Result<(), ::Err> { - self.encoder.try_reset_params()?; - self.decoder.try_reset_params()?; - Ok(()) - } -} - -impl - GradientUpdate for Transformer -where - E: Dtype, - D: Device, -{ - fn update(&mut self, updater: &mut U, unused: &mut 
UnusedTensors) -> Result<(), ::Err> - where - U: ParamUpdater, - { - self.encoder.update(updater, unused)?; - self.decoder.update(updater, unused)?; - Ok(()) + fn iter_tensors>(visitor: &mut V) -> Result<(), ::Err> { + visitor.visit_module(|s| &s.encoder, |s| &mut s.encoder, "encoder")?; + visitor.visit_module(|s| &s.decoder, |s| &mut s.decoder, "decoder") } } @@ -187,6 +170,7 @@ mod tests { use super::*; use crate::{ nn::{tests::SimpleUpdater, DeviceBuildExt}, + optim::*, shapes::*, tensor::*, tensor_ops::*, @@ -221,9 +205,11 @@ mod tests { let out: Tensor, _, _, _> = t.forward_mut((src.trace(), tgt)); let g = out.mean().backward(); - let mut gs = SimpleUpdater(g); - let mut unused: UnusedTensors = Default::default(); - t.update(&mut gs, &mut unused).unwrap(); - assert!(unused.is_empty()); + let mut gs = SimpleUpdater { + grads: g, + unused: Default::default(), + }; + t.update(&mut gs).unwrap(); + assert!(gs.unused.is_empty()); } } diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 50fd2ee2b..5fc302749 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -8,10 +8,11 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, shapes::{Dtype, Shape}, + tensor::visitors::*, tensor::DeviceStorage, }; -use super::{GradientUpdate, Optimizer, OptimizerUpdateError, ParamUpdater, WeightDecay}; +use super::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; /// Configuration of hyperparameters for [Adam]. /// @@ -79,6 +80,8 @@ pub struct Adam { moment1: Gradients, moment2: Gradients, + unused: UnusedTensors, + marker: PhantomData<*const M>, } @@ -91,6 +94,7 @@ impl Adam { gradients: Default::default(), moment1: Default::default(), moment2: Default::default(), + unused: Default::default(), marker: PhantomData, } } @@ -108,15 +112,16 @@ pub(super) trait AdamKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, E: Dtype> ParamUpdater for Adam { - fn update_param( +impl, E: Dtype> VisitTensorMut for Adam { + fn visit( &mut self, - p: &mut crate::tensor::Tensor, - unused: &mut super::UnusedTensors, + _: alloc::string::String, + _: TensorOptions, + p: &mut crate::prelude::Tensor, ) -> Result<(), ::Err> { let g = self.gradients.remove(p); match g { - None => unused.add(p), + None => self.unused.add(p), Some(g) => { let m_t = self.moment1.get_or_alloc_mut(p)?; let v_t = self.moment2.get_or_alloc_mut(p)?; @@ -128,10 +133,7 @@ impl, E: Dtype> ParamUpdater for Adam< } } -impl, D: AdamKernel, E: Dtype> Optimizer for Adam -where - Self: ParamUpdater, -{ +impl, D: AdamKernel, E: Dtype> Optimizer for Adam { fn update( &mut self, module: &mut M, @@ -139,8 +141,9 @@ where ) -> Result<(), OptimizerUpdateError> { self.t = self.t.checked_add(1).unwrap(); self.gradients = gradients; - let mut unused = Default::default(); - match module.update(self, &mut unused) { + let result = module.update(self); + let unused = std::mem::take(&mut self.unused); + match result { Ok(_) => unused.into(), Err(e) => Err(OptimizerUpdateError::DeviceError(e)), } diff --git a/src/optim/mod.rs b/src/optim/mod.rs index 2c9a1f0cc..fe6daeb13 100644 --- a/src/optim/mod.rs +++ b/src/optim/mod.rs @@ -34,11 +34,11 @@ mod rmsprop; mod sgd; pub use adam::{Adam, AdamConfig}; -pub use optimizer::{GradientUpdate, Optimizer, OptimizerUpdateError, ParamUpdater, UnusedTensors}; +pub use optimizer::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors}; pub use optimizer::{Momentum, WeightDecay}; pub use rmsprop::{RMSprop, RMSpropConfig}; pub use sgd::{Sgd, SgdConfig}; pub mod 
prelude { - pub use super::{GradientUpdate, Optimizer, OptimizerUpdateError, ParamUpdater, UnusedTensors}; + pub use super::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors}; } diff --git a/src/optim/optimizer.rs b/src/optim/optimizer.rs index 6d8fe4ef9..f056c2c8d 100644 --- a/src/optim/optimizer.rs +++ b/src/optim/optimizer.rs @@ -1,7 +1,8 @@ use crate::{ gradients::Gradients, - shapes::{Dtype, Shape}, - tensor::{DeviceStorage, HasErr, Tensor}, + shapes::Dtype, + tensor::visitors::{RecursiveWalker, TensorCollection, VisitTensorMut}, + tensor::DeviceStorage, unique_id::{HasUniqueId, UniqueId}, }; @@ -92,38 +93,17 @@ pub trait Optimizer { } /// Represents something that can be updated with a [ParamUpdater]. -pub trait GradientUpdate { +pub trait GradientUpdate: TensorCollection { /// Updates self given the [ParamUpdater]. - fn update>( - &mut self, - updater: &mut U, - unused: &mut UnusedTensors, - ) -> Result<(), D::Err>; -} - -impl GradientUpdate for Tensor { - fn update>( - &mut self, - updater: &mut U, - unused: &mut UnusedTensors, - ) -> Result<(), ::Err> { - updater.update_param(self, unused) + fn update>(&mut self, updater: &mut V) -> Result<(), D::Err> { + Self::iter_tensors(&mut RecursiveWalker { + m: self, + f: updater, + path: &mut std::vec::Vec::new(), + }) } } - -/// Represents something that can update a tensor. -/// -/// See [crate::optim::Sgd] and [crate::optim::Adam] for examples on implementing this. -pub trait ParamUpdater { - /// Retrieves the data associated with `p` if there is any. - /// This can modify `self`, for instance if velocities are calculated - /// based on the associated data! - fn update_param( - &mut self, - p: &mut Tensor, - unused: &mut UnusedTensors, - ) -> Result<(), D::Err>; -} +impl> GradientUpdate for M {} /// Holds [UniqueId] of tensors that were missing gradients during /// [GradientUpdate::update()], and therefore are unused @@ -142,6 +122,10 @@ impl UnusedTensors { pub fn is_empty(&self) -> bool { self.ids.is_empty() } + + pub fn clear(&mut self) { + self.ids.clear(); + } } /// An error indicating that a parameter was not used in gradient diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index aba54c956..af1c09d64 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -8,12 +8,11 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, shapes::{Dtype, Shape}, - tensor::{DeviceStorage, OneFillStorage, Tensor}, + tensor::visitors::*, + tensor::*, }; -use super::{ - GradientUpdate, Optimizer, OptimizerUpdateError, ParamUpdater, UnusedTensors, WeightDecay, -}; +use super::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; /// Configuration of hyperparameters for [RMSprop]. 
#[derive(Debug, Clone, Copy)] @@ -89,6 +88,8 @@ pub struct RMSprop { grad_avg: Gradients, gradients: Gradients, + unused: UnusedTensors, + marker: PhantomData<*const M>, } @@ -102,6 +103,7 @@ impl RMSprop { square_avg: Default::default(), grad_avg: Default::default(), gradients: Default::default(), + unused: Default::default(), marker: PhantomData, } } @@ -119,15 +121,16 @@ pub(super) trait RMSpropKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl + OneFillStorage> ParamUpdater for RMSprop { - fn update_param( +impl + OneFillStorage> VisitTensorMut for RMSprop { + fn visit( &mut self, + _: alloc::string::String, + _: TensorOptions, p: &mut Tensor, - unused: &mut UnusedTensors, ) -> Result<(), ::Err> { let g = self.gradients.remove(p); match g { - None => unused.add(p), + None => self.unused.add(p), Some(g) => { let m = self.momentums.get_or_alloc_mut(p)?; let sa = self.square_avg.get_or_alloc_mut(p)?; @@ -144,9 +147,8 @@ impl + OneFillStorage> ParamUpdater fo } } -impl, D: RMSpropKernel, E: Dtype> Optimizer for RMSprop -where - Self: ParamUpdater, +impl, D: RMSpropKernel + OneFillStorage, E: Dtype> + Optimizer for RMSprop { fn update( &mut self, @@ -154,8 +156,9 @@ where gradients: Gradients, ) -> Result<(), OptimizerUpdateError> { self.gradients = gradients; - let mut unused = Default::default(); - let r = match module.update(self, &mut unused) { + let result = module.update(self); + let unused = std::mem::take(&mut self.unused); + let r = match result { Ok(_) => unused.into(), Err(e) => Err(OptimizerUpdateError::DeviceError(e)), }; @@ -167,7 +170,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{shapes::*, tensor::*, tensor_ops::*, tests::*}; + use crate::{shapes::*, tensor_ops::*, tests::*}; fn test_matches_expected(cfg: RMSpropConfig, expected: [[TestDtype; 5]; 5]) { let dev: TestDevice = Default::default(); diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index 91d391533..607b7aa8e 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -7,6 +7,7 @@ use std::marker::PhantomData; use crate::gradients::Gradients; use crate::shapes::{Dtype, Shape}; +use crate::tensor::visitors::*; use crate::tensor::{DeviceStorage, Tensor}; use super::optimizer::*; @@ -116,6 +117,8 @@ pub struct Sgd { velocity: Gradients, gradients: Gradients, + unused: UnusedTensors, + marker: PhantomData<*const M>, } @@ -126,6 +129,7 @@ impl Sgd { cfg, velocity: Default::default(), gradients: Default::default(), + unused: Default::default(), marker: PhantomData, } } @@ -141,15 +145,16 @@ pub(super) trait SgdKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, E: Dtype> ParamUpdater for Sgd { - fn update_param( +impl, M> VisitTensorMut for Sgd { + fn visit( &mut self, + _: alloc::string::String, + _: TensorOptions, p: &mut Tensor, - unused: &mut UnusedTensors, ) -> Result<(), D::Err> { let g = self.gradients.remove(p); match g { - None => unused.add(p), + None => self.unused.add(p), Some(g) => { let v = self.velocity.get_or_alloc_mut(p)?; p.device.update(&self.cfg, &mut p.storage, v, g)?; @@ -159,18 +164,16 @@ impl, E: Dtype> ParamUpdater for Sgd { } } -impl, D: SgdKernel, E: Dtype> Optimizer for Sgd -where - Self: ParamUpdater, -{ +impl, D: SgdKernel, E: Dtype> Optimizer for Sgd { fn update( &mut self, module: &mut M, gradients: Gradients, ) -> Result<(), OptimizerUpdateError> { self.gradients = gradients; - let mut unused = Default::default(); - match module.update(self, &mut unused) { + let result = module.update(self); + let unused = std::mem::take(&mut self.unused); + 
match result { Ok(_) => unused.into(), Err(e) => Err(OptimizerUpdateError::DeviceError(e)), } diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 341e6126f..7567b34f4 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -111,15 +111,13 @@ //! zip archives. pub(crate) mod cpu; -mod tensor_impls; - #[cfg(feature = "cuda")] pub(crate) mod cuda; - #[cfg(feature = "numpy")] pub(crate) mod numpy; - pub(crate) mod storage_traits; +mod tensor_impls; +pub(crate) mod visitors; // TODO pub? pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage}; @@ -137,6 +135,8 @@ pub use tensor_impls::OnCuda; pub use tensor_impls::{OnCpu, OnDevice, PutTape, SplitTape, Tensor, ToDevice}; pub use tensor_impls::{Tensor0D, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D}; +pub use visitors::{NumParams, ResetParams}; + #[cfg(test)] mod tests { use super::*; diff --git a/src/tensor/visitors/base.rs b/src/tensor/visitors/base.rs new file mode 100644 index 000000000..9d741078f --- /dev/null +++ b/src/tensor/visitors/base.rs @@ -0,0 +1,236 @@ +use crate::{ + shapes::{Dtype, Shape}, + tensor::{DeviceStorage, Tensor}, +}; + +use std::{string::String, vec::Vec}; + +pub struct TensorOptions { + pub name: &'static str, + pub update: bool, + + #[allow(clippy::type_complexity)] + pub reset: fn(&mut Tensor) -> Result<(), D::Err>, +} + +impl TensorOptions { + pub fn named( + name: &'static str, + reset: fn(&mut Tensor) -> Result<(), D::Err>, + ) -> Self { + Self { + name, + update: true, + reset, + } + } + + pub fn no_grad( + name: &'static str, + reset: fn(&mut Tensor) -> Result<(), D::Err>, + ) -> Self { + Self { + name, + update: false, + reset, + } + } +} + +pub trait VisitTensorRef { + fn visit( + &mut self, + full_path: String, + opts: TensorOptions, + t: &Tensor, + ) -> Result<(), D::Err>; +} + +pub trait VisitTensorMut { + fn visit( + &mut self, + full_path: String, + opts: TensorOptions, + t: &mut Tensor, + ) -> Result<(), D::Err>; +} + +pub trait VisitTensorMutRef { + fn visit( + &mut self, + full_path: String, + opts: TensorOptions, + ts: (&mut Tensor, &Tensor), + ) -> Result<(), D::Err>; +} + +pub trait TensorCollection: Sized { + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err>; +} + +impl TensorCollection for Tensor { + fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + visitor.visit_tensor( + |s| s, + |s| s, + TensorOptions { + name: "", + update: true, + reset: |_| Ok(()), + }, + ) + } +} + +pub trait ModuleWalker: Sized { + fn visit_module( + &mut self, + get_refs: GetRef, + get_muts: GetMut, + name: &str, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&T) -> &Field, + GetMut: FnMut(&mut T) -> &mut Field, + Field: TensorCollection; + + fn visit_tensor( + &mut self, + get_refs: GetRef, + get_muts: GetMut, + opts: TensorOptions, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&T) -> &Tensor, + GetMut: FnMut(&mut T) -> &mut Tensor; +} + +pub(crate) struct RecursiveWalker<'a, M, F> { + pub(crate) m: M, + pub(crate) f: &'a mut F, + pub(crate) path: &'a mut Vec, +} + +impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> ModuleWalker + for RecursiveWalker<'a, &'a M, F> +{ + fn visit_module( + &mut self, + mut get_refs: GetRef, + _: GetMut, + name: &str, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&M) -> &Field, + GetMut: FnMut(&mut M) -> &mut Field, + Field: TensorCollection, + { + self.path.push(name.into()); + let mut walker = RecursiveWalker { + m: get_refs(self.m), + f: self.f, + path: self.path, + }; + Field::iter_tensors(&mut walker)?; + self.path.pop(); 
+ Ok(()) + } + fn visit_tensor( + &mut self, + mut get_refs: GetRef, + _: GetMut, + opts: TensorOptions, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&M) -> &Tensor, + GetMut: FnMut(&mut M) -> &mut Tensor, + { + self.path.push(opts.name.into()); + self.f.visit(self.path.join("."), opts, get_refs(self.m))?; + self.path.pop(); + Ok(()) + } +} + +impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut> ModuleWalker + for RecursiveWalker<'a, &'a mut M, F> +{ + fn visit_module( + &mut self, + _: GetRef, + mut get_muts: GetMut, + name: &str, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&M) -> &Field, + GetMut: FnMut(&mut M) -> &mut Field, + Field: TensorCollection, + { + self.path.push(name.into()); + let mut walker = RecursiveWalker { + m: get_muts(self.m), + f: self.f, + path: self.path, + }; + Field::iter_tensors(&mut walker)?; + self.path.pop(); + Ok(()) + } + fn visit_tensor( + &mut self, + _: GetRef, + mut get_muts: GetMut, + opts: TensorOptions, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&M) -> &Tensor, + GetMut: FnMut(&mut M) -> &mut Tensor, + { + self.path.push(opts.name.into()); + self.f.visit(self.path.join("."), opts, get_muts(self.m))?; + self.path.pop(); + Ok(()) + } +} + +impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef> ModuleWalker + for RecursiveWalker<'a, (&'a mut M, &'a M), F> +{ + fn visit_module( + &mut self, + mut get_refs: GetRef, + mut get_muts: GetMut, + name: &str, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&M) -> &Field, + GetMut: FnMut(&mut M) -> &mut Field, + Field: TensorCollection, + { + self.path.push(name.into()); + let mut walker = RecursiveWalker { + m: (get_muts(self.m.0), get_refs(self.m.1)), + f: self.f, + path: self.path, + }; + Field::iter_tensors(&mut walker)?; + self.path.pop(); + Ok(()) + } + fn visit_tensor( + &mut self, + mut get_refs: GetRef, + mut get_muts: GetMut, + opts: TensorOptions, + ) -> Result<(), D::Err> + where + GetRef: FnMut(&M) -> &Tensor, + GetMut: FnMut(&mut M) -> &mut Tensor, + { + self.path.push(opts.name.into()); + let tensors = (get_muts(self.m.0), get_refs(self.m.1)); + self.f.visit(self.path.join("."), opts, tensors)?; + self.path.pop(); + Ok(()) + } +} diff --git a/src/tensor/visitors/mod.rs b/src/tensor/visitors/mod.rs new file mode 100644 index 000000000..8fed4ab31 --- /dev/null +++ b/src/tensor/visitors/mod.rs @@ -0,0 +1,7 @@ +mod base; +mod num_params; +mod reset_params; + +pub use base::*; +pub use num_params::NumParams; +pub use reset_params::ResetParams; diff --git a/src/tensor/visitors/num_params.rs b/src/tensor/visitors/num_params.rs new file mode 100644 index 000000000..3861118bf --- /dev/null +++ b/src/tensor/visitors/num_params.rs @@ -0,0 +1,33 @@ +use super::base::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorRef}; + +use crate::{shapes::*, tensor::*}; + +use std::{string::String, vec::Vec}; + +struct Counter(usize); +impl VisitTensorRef for Counter { + fn visit( + &mut self, + _: String, + opts: TensorOptions, + t: &Tensor, + ) -> Result<(), D::Err> { + if opts.update { + self.0 += t.shape().num_elements(); + } + Ok(()) + } +} +pub trait NumParams: TensorCollection { + fn num_trainable_params(&self) -> usize { + let mut op = Counter(0); + Self::iter_tensors(&mut RecursiveWalker { + m: self, + f: &mut op, + path: &mut Vec::new(), + }) + .unwrap(); + op.0 + } +} +impl> NumParams for M {} diff --git a/src/tensor/visitors/reset_params.rs b/src/tensor/visitors/reset_params.rs new file mode 100644 index 000000000..41ac063db --- /dev/null +++ 
b/src/tensor/visitors/reset_params.rs @@ -0,0 +1,30 @@ +use super::base::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorMut}; + +use crate::{shapes::*, tensor::*}; + +use std::{string::String, vec::Vec}; + +struct Resetter; +impl VisitTensorMut for Resetter { + fn visit( + &mut self, + _: String, + opts: TensorOptions, + t: &mut Tensor, + ) -> Result<(), D::Err> { + (opts.reset)(t) + } +} +pub trait ResetParams: TensorCollection { + fn reset_params(&mut self) { + self.try_reset_params().unwrap(); + } + fn try_reset_params(&mut self) -> Result<(), D::Err> { + Self::iter_tensors(&mut RecursiveWalker { + m: self, + f: &mut Resetter, + path: &mut Vec::new(), + }) + } +} +impl> ResetParams for M {} From a5ea12b52570a4d791ee74ee870bc9f31ae23398 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 09:25:41 -0500 Subject: [PATCH 02/20] Using visitors for npz --- src/nn/add_into.rs | 2 +- src/nn/batchnorm2d.rs | 2 +- src/nn/conv.rs | 2 +- src/nn/embedding.rs | 2 +- src/nn/generalized_residual.rs | 2 +- src/nn/impl_module_for_tuples.rs | 2 +- src/nn/layer_norm.rs | 2 +- src/nn/linear.rs | 2 +- src/nn/mod.rs | 3 +- src/nn/module.rs | 2 +- src/nn/npz.rs | 256 ++++++++++++- src/nn/npz_impls.rs | 533 ---------------------------- src/nn/repeated.rs | 2 +- src/nn/residual.rs | 2 +- src/nn/split_into.rs | 2 +- src/nn/transformer/decoder.rs | 4 +- src/nn/transformer/encoder.rs | 2 +- src/nn/transformer/mha.rs | 2 +- src/nn/transformer/mod.rs | 2 +- src/optim/adam/mod.rs | 1 + src/optim/optimizer.rs | 2 +- src/optim/rmsprop/mod.rs | 1 + src/optim/sgd/mod.rs | 1 + src/tensor/visitors/base.rs | 33 +- src/tensor/visitors/num_params.rs | 1 + src/tensor/visitors/reset_params.rs | 1 + 26 files changed, 290 insertions(+), 576 deletions(-) delete mode 100644 src/nn/npz_impls.rs diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index d12d0fd69..ed825170b 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -34,7 +34,7 @@ impl, D: DeviceStorage, E: Dtype> BuildModule for Add } impl> TensorCollection for AddInto { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index 3acd5c7b7..2084bea0d 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -184,7 +184,7 @@ impl> BuildModule for BatchNorm2D> TensorCollection for BatchNorm2D { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.scale, |s| &mut s.scale, diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 3f7a25381..be3b1af0f 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -59,7 +59,7 @@ where E: Dtype + Float + SampleUniform, D: Device, { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index 3703cba3c..5b01ec047 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -70,7 +70,7 @@ impl> TensorCollection for Embedding { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index 8a1e35322..8129eb095 100644 --- a/src/nn/generalized_residual.rs +++ 
b/src/nn/generalized_residual.rs @@ -45,7 +45,7 @@ impl, R: BuildModule> Bui impl, R: TensorCollection> TensorCollection for GeneralizedResidual { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.f, |s| &mut s.f, "f")?; visitor.visit_module(|s| &s.r, |s| &mut s.r, "r") } diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 68fea1b29..dbb4b8d0e 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -5,7 +5,7 @@ use super::module::{BuildModule, BuildOnDevice, Module, ModuleMut, OnDevice, ToD macro_rules! tuple_impls { ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { impl),+> TensorCollection for ($($name,)+) { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { $(visitor.visit_module(|s| &s.$idx, |s| &mut s.$idx, &std::format!("{}", $idx))?;)+ Ok(()) } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index 8dd088859..9716b3997 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -56,7 +56,7 @@ impl> BuildModule for LayerNorm1D> TensorCollection for LayerNorm1D { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.gamma, |s| &mut s.gamma, diff --git a/src/nn/linear.rs b/src/nn/linear.rs index a4ba5faa8..4e7f4ee56 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -71,7 +71,7 @@ impl> TensorCollection for Linear { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 993a62548..1c3367d52 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -118,8 +118,6 @@ mod linear; mod module; #[cfg(feature = "numpy")] mod npz; -#[cfg(feature = "numpy")] -mod npz_impls; mod pool2d; mod pool_global; mod repeated; @@ -203,6 +201,7 @@ mod tests { } impl VisitTensorMut for SimpleUpdater { + type Err = D::Err; fn visit( &mut self, _: alloc::string::String, diff --git a/src/nn/module.rs b/src/nn/module.rs index 6f85e297b..6ab0d6fc5 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -75,7 +75,7 @@ impl, D: DeviceStorage, E: Dtype> BuildOn } impl TensorCollection for T { - fn iter_tensors>(_: &mut V) -> Result<(), ::Err> { + fn iter_tensors>(_: &mut V) -> Result<(), V::Err> { Ok(()) } } diff --git a/src/nn/npz.rs b/src/nn/npz.rs index 7896e516d..34d32513b 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -1,14 +1,25 @@ -use crate::tensor::numpy::NpzError; +use crate::{ + shapes::{Dtype, Shape}, + tensor::visitors::*, + tensor::{ + numpy::{NpzError, NumpyDtype}, + CopySlice, Tensor, + }, +}; use std::{ io::{BufReader, BufWriter, Read, Seek, Write}, path::Path, + string::String, +}; +use zip::{ + result::{ZipError, ZipResult}, + ZipArchive, ZipWriter, }; -use zip::{result::ZipResult, ZipArchive, ZipWriter}; /// Something that can be saved to a `.npz` (which is a `.zip`). /// /// All [super::Module]s in nn implement SaveToNpz, and the zips are formatted in a `.npz` fashion. -pub trait SaveToNpz { +pub trait SaveToNpz>: TensorCollection { /// Save this object into the `.npz` file determined located at `path`. 
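// A rough usage sketch of the save/load round trip that the blanket impls in this file
// provide for any TensorCollection (the Linear model, dtype, and file name below are
// illustrative only; the tests at the bottom of this file exercise the same pattern):
//
//     let dev: Cpu = Default::default();
//     let saved = dev.build_module::<Linear<2, 5>, f32>();
//     saved.save("model.npz")?;    // entries named after the module tree (the linear's weight and bias)
//     let mut loaded = dev.build_module::<Linear<2, 5>, f32>();
//     loaded.load("model.npz")?;   // reads the same entries back into `loaded`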
/// /// Example: @@ -21,7 +32,7 @@ pub trait SaveToNpz { let f = std::fs::File::create(path)?; let f = BufWriter::new(f); let mut zip = ZipWriter::new(f); - self.write("", &mut zip)?; + self.write(&mut zip)?; zip.finish()?; Ok(()) } @@ -41,18 +52,23 @@ pub trait SaveToNpz { /// - `0.bias.npy` /// - `1.weight.npy` /// - `1.bias.npy` - fn write(&self, _filename_prefix: &str, _w: &mut ZipWriter) -> ZipResult<()> + fn write(&self, w: &mut ZipWriter) -> ZipResult<()> where W: Write + Seek, { - Ok(()) + Self::iter_tensors(&mut RecursiveWalker { + m: self, + f: w, + path: &mut std::vec::Vec::new(), + }) } } +impl, T: TensorCollection> SaveToNpz for T {} /// Something that can be loaded from a `.npz` file (which is a `zip` file). /// /// All [super::Module]s in nn implement LoadFromNpz, and the zips are formatted in a `.npz` fashion. -pub trait LoadFromNpz { +pub trait LoadFromNpz>: TensorCollection { /// Loads data from a `.npz` zip archive at the specified `path`. /// /// Example: @@ -65,7 +81,7 @@ pub trait LoadFromNpz { let f = std::fs::File::open(path)?; let f = BufReader::new(f); let mut zip = ZipArchive::new(f)?; - self.read("", &mut zip)?; + self.read(&mut zip)?; Ok(()) } @@ -81,10 +97,230 @@ pub trait LoadFromNpz { /// Will try to read data from the following files: /// - `0.weight.npy` /// - `0.bias.npy` - fn read(&mut self, _filename_prefix: &str, _r: &mut ZipArchive) -> Result<(), NpzError> + fn read(&mut self, r: &mut ZipArchive) -> Result<(), NpzError> where R: Read + Seek, { - Ok(()) + Self::iter_tensors(&mut RecursiveWalker { + m: self, + f: r, + path: &mut std::vec::Vec::new(), + }) + } +} +impl, T: TensorCollection> LoadFromNpz for T {} + +impl> VisitTensorRef + for zip::ZipWriter +{ + type Err = ZipError; + fn visit( + &mut self, + full_path: String, + _: TensorOptions, + t: &Tensor, + ) -> Result<(), Self::Err> { + t.write_to_npz(self, full_path) + } +} + +impl> VisitTensorMut + for zip::ZipArchive +{ + type Err = NpzError; + fn visit( + &mut self, + full_path: String, + _: TensorOptions, + t: &mut Tensor, + ) -> Result<(), Self::Err> { + t.read_from_npz(self, full_path) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + nn::{builders::*, *}, + shapes::*, + tensor::{numpy::NumpyDtype, AsArray, SampleTensor, Tensor}, + tensor_ops::Device, + tests::{TestDevice, TestDtype}, + }; + use rand_distr::{Distribution, Standard, StandardNormal}; + use tempfile::NamedTempFile; + + fn test_save_load, M: BuildOnDevice>( + dev: &D, + ) where + M::Built: Module> + SaveToNpz + LoadFromNpz, + >>::Output: AsArray, + StandardNormal: Distribution, + { + let x = dev.sample_normal(); + let file = NamedTempFile::new().expect("failed to create tempfile"); + + let saved: M::Built = M::build_on_device(dev); + let mut loaded: M::Built = M::build_on_device(dev); + + let y = saved.forward(x.clone()); + + assert_ne!(loaded.forward(x.clone()).array(), y.array()); + + saved.save(file.path()).expect(""); + loaded.load(file.path()).expect(""); + + assert_eq!(loaded.forward(x).array(), y.array()); + } + + #[test] + fn test_batchnorm2d_save_load() { + let dev: TestDevice = Default::default(); + type Model = BatchNorm2D<3>; + + let x: Tensor, TestDtype, _> = dev.sample_normal(); + let file = NamedTempFile::new().expect("failed to create tempfile"); + + let mut saved = Model::build_on_device(&dev); + let mut loaded = Model::build_on_device(&dev); + + saved.running_mean.fill_with_distr(Standard); + saved.running_var.fill_with_distr(Standard); + saved.scale.fill_with_distr(Standard); + 
saved.bias.fill_with_distr(Standard); + let y = saved.forward(x.clone()); + + assert_ne!(loaded.forward(x.clone()).array(), y.array()); + + saved.save(file.path()).expect(""); + loaded.load(file.path()).expect(""); + + assert_eq!(loaded.forward(x).array(), y.array()); + } + + #[cfg(feature = "nightly")] + #[test] + fn test_save_load_conv() { + type T = Conv2D<2, 4, 3>; + let dev: TestDevice = Default::default(); + test_save_load::, TestDtype, TestDevice, T>(&dev); + } + + #[test] + fn test_save_load_generalized_residual() { + let dev: TestDevice = Default::default(); + type T = GeneralizedResidual, Linear<5, 5>>; + test_save_load::, TestDtype, TestDevice, T>(&dev); + test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); + } + + #[test] + fn test_save_load_linear() { + let dev: TestDevice = Default::default(); + type T = Linear<5, 5>; + test_save_load::, TestDtype, TestDevice, T>(&dev); + test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); + } + + #[test] + fn test_save_load_tuple() { + let dev: TestDevice = Default::default(); + type T = ( + (Linear<1, 2>, ReLU, Linear<2, 3>), + (Dropout, Linear<3, 3>, Linear<3, 4>), + ); + test_save_load::, TestDtype, TestDevice, T>(&dev); + } + + #[test] + fn test_save_load_layer_norm() { + type M = LayerNorm1D<3>; + let dev: TestDevice = Default::default(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); + + let file = NamedTempFile::new().expect("failed to create tempfile"); + + let mut saved = M::build_on_device(&dev); + let mut loaded = M::build_on_device(&dev); + + saved.gamma.fill_with_distr(Standard); + saved.beta.fill_with_distr(Standard); + let y = saved.forward(x.clone()); + + assert_ne!(loaded.forward(x.clone()).array(), y.array()); + + saved.save(file.path()).expect(""); + loaded.load(file.path()).expect(""); + + assert_eq!(loaded.forward(x).array(), y.array()); + } + + #[test] + fn test_save_load_repeated() { + type T = Repeated, 4>; + let dev: TestDevice = Default::default(); + test_save_load::, TestDtype, TestDevice, T>(&dev); + test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); + } + + #[test] + fn test_save_load_residual() { + type T = Residual>; + let dev: TestDevice = Default::default(); + test_save_load::, TestDtype, TestDevice, T>(&dev); + test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); + } + + #[cfg(feature = "nightly")] + #[test] + fn test_save_load_mha() { + let dev: TestDevice = Default::default(); + type Model = MultiHeadAttention<12, 4>; + + let saved = Model::build_on_device(&dev); + + let file = NamedTempFile::new().expect("failed to create tempfile"); + saved.save(file.path()).expect(""); + + let mut loaded = Model::build_on_device(&dev); + + let q: Tensor, TestDtype, _> = dev.sample_normal(); + let k: Tensor, TestDtype, _> = dev.sample_normal(); + let v: Tensor, TestDtype, _> = dev.sample_normal(); + let y1 = saved.forward((q.clone(), k.clone(), v.clone())); + + let y2 = loaded.forward((q.clone(), k.clone(), v.clone())); + assert_ne!(y1.array(), y2.array()); + + loaded.load(file.path()).expect(""); + + let y2 = loaded.forward((q.clone(), k.clone(), v.clone())); + assert_eq!(y1.array(), y2.array()); + } + + #[cfg(feature = "nightly")] + #[test] + fn test_save_load_transformer() { + let dev: TestDevice = Default::default(); + type Model = Transformer<16, 4, 3, 4, 8>; + + let mut saved = Model::build_on_device(&dev); + + let file = NamedTempFile::new().expect("failed to create tempfile"); + saved.save(file.path()).expect(""); + + let mut loaded = Model::build_on_device(&dev); + + let src: Tensor, 
TestDtype, _> = dev.sample_normal(); + let tgt: Tensor, TestDtype, _> = dev.sample_normal(); + let y1 = saved.forward_mut((src.clone(), tgt.clone())); + + let y2 = loaded.forward_mut((src.clone(), tgt.clone())); + assert_ne!(y1.array(), y2.array()); + + loaded.load(file.path()).expect(""); + + let y2 = loaded.forward_mut((src.clone(), tgt.clone())); + assert_eq!(y1.array(), y2.array()); } } diff --git a/src/nn/npz_impls.rs b/src/nn/npz_impls.rs deleted file mode 100644 index 6d48b6594..000000000 --- a/src/nn/npz_impls.rs +++ /dev/null @@ -1,533 +0,0 @@ -use super::{ - modules::*, - npz::{LoadFromNpz, SaveToNpz}, - *, -}; -use crate::{ - shapes::Dtype, - tensor::{ - numpy::{NpzError, NumpyDtype}, - CopySlice, - }, -}; -use std::format; -use std::io::{Read, Seek, Write}; -use zip::{result::ZipResult, ZipArchive, ZipWriter}; - -impl SaveToNpz for T {} -impl LoadFromNpz for T {} - -impl> SaveToNpz for BatchNorm2D { - fn write(&self, p: &str, w: &mut zip::ZipWriter) -> ZipResult<()> { - self.scale.write_to_npz(w, format!("{p}scale.npy"))?; - self.bias.write_to_npz(w, format!("{p}bias.npy"))?; - self.running_mean - .write_to_npz(w, format!("{p}running_mean.npy"))?; - self.running_var - .write_to_npz(w, format!("{p}running_var.npy"))?; - Ok(()) - } -} - -impl> LoadFromNpz for BatchNorm2D { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.scale.read_from_npz(r, format!("{p}scale.npy"))?; - self.bias.read_from_npz(r, format!("{p}bias.npy"))?; - self.running_mean - .read_from_npz(r, format!("{p}running_mean.npy"))?; - self.running_var - .read_from_npz(r, format!("{p}running_var.npy"))?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl< - const I: usize, - const O: usize, - const K: usize, - const S: usize, - const P: usize, - E: Dtype + NumpyDtype, - D: CopySlice, - > SaveToNpz for Conv2D -{ - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.weight.write_to_npz(w, format!("{p}weight.npy"))?; - self.bias.write_to_npz(w, format!("{p}bias.npy"))?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl< - const I: usize, - const O: usize, - const K: usize, - const S: usize, - const P: usize, - E: Dtype + NumpyDtype, - D: CopySlice, - > LoadFromNpz for Conv2D -{ - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.weight.read_from_npz(r, format!("{p}weight.npy"))?; - self.bias.read_from_npz(r, format!("{p}bias.npy"))?; - Ok(()) - } -} - -impl SaveToNpz for GeneralizedResidual { - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.f.write(&format!("{p}.f"), w)?; - self.r.write(&format!("{p}.r"), w) - } -} - -impl LoadFromNpz for GeneralizedResidual { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.f.read(&format!("{p}.f"), r)?; - self.r.read(&format!("{p}.r"), r) - } -} - -impl> SaveToNpz for LayerNorm1D { - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.gamma.write_to_npz(w, format!("{p}gamma.npy"))?; - self.beta.write_to_npz(w, format!("{p}beta.npy"))?; - Ok(()) - } -} - -impl> LoadFromNpz for LayerNorm1D { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.gamma.read_from_npz(r, format!("{p}gamma.npy"))?; - self.beta.read_from_npz(r, format!("{p}beta.npy"))?; - Ok(()) - } -} - -impl> SaveToNpz - for Linear -{ - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.weight.write_to_npz(w, format!("{p}weight.npy"))?; - self.bias.write_to_npz(w, format!("{p}bias.npy"))?; - Ok(()) - } -} - -impl> 
LoadFromNpz - for Linear -{ - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.weight.read_from_npz(r, format!("{p}weight.npy"))?; - self.bias.read_from_npz(r, format!("{p}bias.npy"))?; - Ok(()) - } -} - -macro_rules! tuple_npz_impl { - ([$($name:ident),+], [$($idx:tt),+]) => { -impl<$($name: SaveToNpz),+> SaveToNpz for ($($name,)+) { - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - $(self.$idx.write(&format!("{p}{}.", $idx), w)?;)+ - Ok(()) - } -} - -impl<$($name: LoadFromNpz),+> LoadFromNpz for ($($name,)+) { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - $(self.$idx.read(&format!("{p}{}.", $idx), r)?;)+ - Ok(()) - } -} - }; -} - -tuple_npz_impl!([A, B], [0, 1]); -tuple_npz_impl!([A, B, C], [0, 1, 2]); -tuple_npz_impl!([A, B, C, D], [0, 1, 2, 3]); -tuple_npz_impl!([A, B, C, D, E], [0, 1, 2, 3, 4]); -tuple_npz_impl!([A, B, C, D, E, F], [0, 1, 2, 3, 4, 5]); - -impl SaveToNpz for Repeated { - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - for i in 0..N { - self.modules[i].write(&format!("{p}{i}."), w)?; - } - Ok(()) - } -} - -impl LoadFromNpz for Repeated { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - for i in 0..N { - self.modules[i].read(&format!("{p}{i}."), r)?; - } - Ok(()) - } -} - -impl SaveToNpz for Residual { - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.0.write(&format!("{p}.0"), w) - } -} - -impl LoadFromNpz for Residual { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.0.read(&format!("{p}.0"), r) - } -} - -impl SaveToNpz for SplitInto { - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.0.write(&format!("{p}.0"), w) - } -} - -impl LoadFromNpz for SplitInto { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.0.read(&format!("{p}.0"), r) - } -} - -impl SaveToNpz for AddInto { - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.0.write(&format!("{p}.0"), w) - } -} - -impl LoadFromNpz for AddInto { - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.0.read(&format!("{p}.0"), r) - } -} - -#[cfg(feature = "nightly")] -impl> SaveToNpz - for TransformerDecoder -where - E: Dtype + NumpyDtype, -{ - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.0.write(&format!("{p}.0"), w) - } -} - -#[cfg(feature = "nightly")] -impl> - SaveToNpz for TransformerDecoderBlock -{ - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.self_attn.write(&format!("{p}self_attn."), w)?; - self.norm1.write(&format!("{p}norm1."), w)?; - self.mh_attn.write(&format!("{p}mh_attn."), w)?; - self.norm2.write(&format!("{p}norm2."), w)?; - self.ff.0 .0.write(&format!("{p}linear1."), w)?; - self.ff.0 .2.write(&format!("{p}linear2."), w)?; - self.norm3.write(&format!("{p}norm3."), w)?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl> - LoadFromNpz for TransformerDecoderBlock -{ - fn read(&mut self, pre: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.self_attn.read(&format!("{pre}self_attn."), r)?; - self.norm1.read(&format!("{pre}norm1."), r)?; - self.mh_attn.read(&format!("{pre}mh_attn."), r)?; - self.norm2.read(&format!("{pre}norm2."), r)?; - self.ff.0 .0.read(&format!("{pre}linear1."), r)?; - self.ff.0 .2.read(&format!("{pre}linear2."), r)?; - self.norm3.read(&format!("{pre}norm3."), r)?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl> LoadFromNpz - for 
TransformerDecoder -where - E: Dtype + NumpyDtype, -{ - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.0.read(&format!("{p}.0"), r) - } -} - -#[cfg(feature = "nightly")] -impl> - SaveToNpz for TransformerEncoderBlock -{ - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.self_attn.write(&format!("{p}self_attn."), w)?; - self.norm1.write(&format!("{p}norm1."), w)?; - self.norm2.write(&format!("{p}norm2."), w)?; - self.ff.0 .0.write(&format!("{p}linear1."), w)?; - self.ff.0 .2.write(&format!("{p}linear2."), w)?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl> - LoadFromNpz for TransformerEncoderBlock -{ - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.self_attn.read(&format!("{p}self_attn."), r)?; - self.norm1.read(&format!("{p}norm1."), r)?; - self.norm2.read(&format!("{p}norm2."), r)?; - self.ff.0 .0.read(&format!("{p}linear1."), r)?; - self.ff.0 .2.read(&format!("{p}linear2."), r)?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl> SaveToNpz - for MultiHeadAttention -where - E: Dtype + NumpyDtype, -{ - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.w_q.write(&format!("{p}w_q."), w)?; - self.w_k.write(&format!("{p}w_k."), w)?; - self.w_v.write(&format!("{p}w_v."), w)?; - self.w_o.write(&format!("{p}w_o."), w)?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl> LoadFromNpz - for MultiHeadAttention -where - E: Dtype + NumpyDtype, -{ - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.w_q.read(&format!("{p}w_q."), r)?; - self.w_k.read(&format!("{p}w_k."), r)?; - self.w_v.read(&format!("{p}w_v."), r)?; - self.w_o.read(&format!("{p}w_o."), r)?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl SaveToNpz - for Transformer -where - E: Dtype + NumpyDtype, - D: CopySlice, -{ - fn write(&self, p: &str, w: &mut ZipWriter) -> ZipResult<()> { - self.encoder.write(&format!("{p}encoder."), w)?; - self.decoder.write(&format!("{p}decoder."), w)?; - Ok(()) - } -} - -#[cfg(feature = "nightly")] -impl - LoadFromNpz for Transformer -where - E: Dtype + NumpyDtype, - D: CopySlice, -{ - fn read(&mut self, p: &str, r: &mut ZipArchive) -> Result<(), NpzError> { - self.encoder.read(&format!("{p}encoder."), r)?; - self.decoder.read(&format!("{p}decoder."), r)?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - nn::{builders::*, *}, - shapes::*, - tensor::{numpy::NumpyDtype, AsArray, SampleTensor, Tensor}, - tensor_ops::Device, - tests::{TestDevice, TestDtype}, - }; - use rand_distr::{Distribution, Standard, StandardNormal}; - use tempfile::NamedTempFile; - - fn test_save_load, M: BuildOnDevice>( - dev: &D, - ) where - M::Built: Module> + SaveToNpz + LoadFromNpz, - >>::Output: AsArray, - StandardNormal: Distribution, - { - let x = dev.sample_normal(); - let file = NamedTempFile::new().expect("failed to create tempfile"); - - let saved: M::Built = M::build_on_device(dev); - let mut loaded: M::Built = M::build_on_device(dev); - - let y = saved.forward(x.clone()); - - assert_ne!(loaded.forward(x.clone()).array(), y.array()); - - saved.save(file.path()).expect(""); - loaded.load(file.path()).expect(""); - - assert_eq!(loaded.forward(x).array(), y.array()); - } - - #[test] - fn test_batchnorm2d_save_load() { - let dev: TestDevice = Default::default(); - type Model = BatchNorm2D<3>; - - let x: Tensor, TestDtype, _> = dev.sample_normal(); - let file = NamedTempFile::new().expect("failed to create tempfile"); - - let mut saved = Model::build_on_device(&dev); 
- let mut loaded = Model::build_on_device(&dev); - - saved.running_mean.fill_with_distr(Standard); - saved.running_var.fill_with_distr(Standard); - saved.scale.fill_with_distr(Standard); - saved.bias.fill_with_distr(Standard); - let y = saved.forward(x.clone()); - - assert_ne!(loaded.forward(x.clone()).array(), y.array()); - - saved.save(file.path()).expect(""); - loaded.load(file.path()).expect(""); - - assert_eq!(loaded.forward(x).array(), y.array()); - } - - #[cfg(feature = "nightly")] - #[test] - fn test_save_load_conv() { - type T = Conv2D<2, 4, 3>; - let dev: TestDevice = Default::default(); - test_save_load::, TestDtype, TestDevice, T>(&dev); - } - - #[test] - fn test_save_load_generalized_residual() { - let dev: TestDevice = Default::default(); - type T = GeneralizedResidual, Linear<5, 5>>; - test_save_load::, TestDtype, TestDevice, T>(&dev); - test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); - } - - #[test] - fn test_save_load_linear() { - let dev: TestDevice = Default::default(); - type T = Linear<5, 5>; - test_save_load::, TestDtype, TestDevice, T>(&dev); - test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); - } - - #[test] - fn test_save_load_tuple() { - let dev: TestDevice = Default::default(); - type T = ( - (Linear<1, 2>, ReLU, Linear<2, 3>), - (Dropout, Linear<3, 3>, Linear<3, 4>), - ); - test_save_load::, TestDtype, TestDevice, T>(&dev); - } - - #[test] - fn test_save_load_layer_norm() { - type M = LayerNorm1D<3>; - let dev: TestDevice = Default::default(); - let x: Tensor, TestDtype, _> = dev.sample_normal(); - - let file = NamedTempFile::new().expect("failed to create tempfile"); - - let mut saved = M::build_on_device(&dev); - let mut loaded = M::build_on_device(&dev); - - saved.gamma.fill_with_distr(Standard); - saved.beta.fill_with_distr(Standard); - let y = saved.forward(x.clone()); - - assert_ne!(loaded.forward(x.clone()).array(), y.array()); - - saved.save(file.path()).expect(""); - loaded.load(file.path()).expect(""); - - assert_eq!(loaded.forward(x).array(), y.array()); - } - - #[test] - fn test_save_load_repeated() { - type T = Repeated, 4>; - let dev: TestDevice = Default::default(); - test_save_load::, TestDtype, TestDevice, T>(&dev); - test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); - } - - #[test] - fn test_save_load_residual() { - type T = Residual>; - let dev: TestDevice = Default::default(); - test_save_load::, TestDtype, TestDevice, T>(&dev); - test_save_load::, TestDtype, TestDevice, (T, T)>(&dev); - } - - #[cfg(feature = "nightly")] - #[test] - fn test_save_load_mha() { - let dev: TestDevice = Default::default(); - type Model = MultiHeadAttention<12, 4>; - - let saved = Model::build_on_device(&dev); - - let file = NamedTempFile::new().expect("failed to create tempfile"); - saved.save(file.path()).expect(""); - - let mut loaded = Model::build_on_device(&dev); - - let q: Tensor, TestDtype, _> = dev.sample_normal(); - let k: Tensor, TestDtype, _> = dev.sample_normal(); - let v: Tensor, TestDtype, _> = dev.sample_normal(); - let y1 = saved.forward((q.clone(), k.clone(), v.clone())); - - let y2 = loaded.forward((q.clone(), k.clone(), v.clone())); - assert_ne!(y1.array(), y2.array()); - - loaded.load(file.path()).expect(""); - - let y2 = loaded.forward((q.clone(), k.clone(), v.clone())); - assert_eq!(y1.array(), y2.array()); - } - - #[cfg(feature = "nightly")] - #[test] - fn test_save_load_transformer() { - let dev: TestDevice = Default::default(); - type Model = Transformer<16, 4, 3, 4, 8>; - - let mut saved = Model::build_on_device(&dev); 
- - let file = NamedTempFile::new().expect("failed to create tempfile"); - saved.save(file.path()).expect(""); - - let mut loaded = Model::build_on_device(&dev); - - let src: Tensor, TestDtype, _> = dev.sample_normal(); - let tgt: Tensor, TestDtype, _> = dev.sample_normal(); - let y1 = saved.forward_mut((src.clone(), tgt.clone())); - - let y2 = loaded.forward_mut((src.clone(), tgt.clone())); - assert_ne!(y1.array(), y2.array()); - - loaded.load(file.path()).expect(""); - - let y2 = loaded.forward_mut((src.clone(), tgt.clone())); - assert_eq!(y1.array(), y2.array()); - } -} diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index a23927338..e7a1d8133 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -42,7 +42,7 @@ impl, const N: usize> BuildModu impl, const N: usize> TensorCollection for Repeated { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { for i in 0..N { visitor.visit_module( |s| &s.modules[i], diff --git a/src/nn/residual.rs b/src/nn/residual.rs index abdda0c06..a6ef128cd 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -34,7 +34,7 @@ impl> BuildModule for Res } impl> TensorCollection for Residual { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index a0ed81bea..521be82c8 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -35,7 +35,7 @@ impl, D: DeviceStorage, E: Dtype> BuildModule for Spl impl> TensorCollection for SplitInto { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index 3fafc92c9..5adc736d5 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -84,7 +84,7 @@ impl>(visitor: &mut V) -> Result<(), ::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } @@ -187,7 +187,7 @@ impl> TensorColl where E: Dtype + Float + SampleUniform, { - fn iter_tensors>(visitor: &mut V) -> Result<(), ::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; visitor.visit_module(|s| &s.mh_attn, |s| &mut s.mh_attn, "mh_attn")?; diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index 2c2b683f7..520a09bed 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -119,7 +119,7 @@ impl> TensorColl where E: Dtype + Float + SampleUniform, { - fn iter_tensors>(visitor: &mut V) -> Result<(), ::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; visitor.visit_module(|s| &s.ff, |s| &mut s.ff, "ff")?; diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index 7f5af2e52..6a82e3dba 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -83,7 +83,7 @@ impl>(visitor: &mut W) -> Result<(), ::Err> { + fn iter_tensors>(visitor: &mut W) -> Result<(), W::Err> { visitor.visit_module(|s| &s.w_q, |s| &mut s.w_q, "w_q")?; visitor.visit_module(|s| &s.w_k, |s| &mut 
s.w_k, "w_k")?; visitor.visit_module(|s| &s.w_v, |s| &mut s.w_v, "w_v")?; diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index 059087d7e..046a6a867 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -102,7 +102,7 @@ where E: Dtype + Float + SampleUniform, D: Device, { - fn iter_tensors>(visitor: &mut V) -> Result<(), ::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.encoder, |s| &mut s.encoder, "encoder")?; visitor.visit_module(|s| &s.decoder, |s| &mut s.decoder, "decoder") } diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 5fc302749..91431d6ff 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -113,6 +113,7 @@ pub(super) trait AdamKernel: DeviceStorage { } impl, E: Dtype> VisitTensorMut for Adam { + type Err = D::Err; fn visit( &mut self, _: alloc::string::String, diff --git a/src/optim/optimizer.rs b/src/optim/optimizer.rs index f056c2c8d..fb2c723ab 100644 --- a/src/optim/optimizer.rs +++ b/src/optim/optimizer.rs @@ -95,7 +95,7 @@ pub trait Optimizer { /// Represents something that can be updated with a [ParamUpdater]. pub trait GradientUpdate: TensorCollection { /// Updates self given the [ParamUpdater]. - fn update>(&mut self, updater: &mut V) -> Result<(), D::Err> { + fn update>(&mut self, updater: &mut V) -> Result<(), V::Err> { Self::iter_tensors(&mut RecursiveWalker { m: self, f: updater, diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index af1c09d64..0d0868029 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -122,6 +122,7 @@ pub(super) trait RMSpropKernel: DeviceStorage { } impl + OneFillStorage> VisitTensorMut for RMSprop { + type Err = D::Err; fn visit( &mut self, _: alloc::string::String, diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index 607b7aa8e..9903b7562 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -146,6 +146,7 @@ pub(super) trait SgdKernel: DeviceStorage { } impl, M> VisitTensorMut for Sgd { + type Err = D::Err; fn visit( &mut self, _: alloc::string::String, diff --git a/src/tensor/visitors/base.rs b/src/tensor/visitors/base.rs index 9d741078f..417afaeff 100644 --- a/src/tensor/visitors/base.rs +++ b/src/tensor/visitors/base.rs @@ -38,38 +38,41 @@ impl TensorOptions { } pub trait VisitTensorRef { + type Err; fn visit( &mut self, full_path: String, opts: TensorOptions, t: &Tensor, - ) -> Result<(), D::Err>; + ) -> Result<(), Self::Err>; } pub trait VisitTensorMut { + type Err; fn visit( &mut self, full_path: String, opts: TensorOptions, t: &mut Tensor, - ) -> Result<(), D::Err>; + ) -> Result<(), Self::Err>; } pub trait VisitTensorMutRef { + type Err; fn visit( &mut self, full_path: String, opts: TensorOptions, ts: (&mut Tensor, &Tensor), - ) -> Result<(), D::Err>; + ) -> Result<(), Self::Err>; } pub trait TensorCollection: Sized { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err>; + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; } impl TensorCollection for Tensor { - fn iter_tensors>(visitor: &mut V) -> Result<(), D::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| s, |s| s, @@ -83,12 +86,13 @@ impl TensorCollection for Tensor: Sized { + type Err; fn visit_module( &mut self, get_refs: GetRef, get_muts: GetMut, name: &str, - ) -> Result<(), D::Err> + ) -> Result<(), Self::Err> where GetRef: FnMut(&T) -> &Field, GetMut: FnMut(&mut T) -> &mut Field, @@ -99,7 +103,7 @@ pub trait ModuleWalker: Sized { get_refs: GetRef, 
get_muts: GetMut, opts: TensorOptions, - ) -> Result<(), D::Err> + ) -> Result<(), Self::Err> where GetRef: FnMut(&T) -> &Tensor, GetMut: FnMut(&mut T) -> &mut Tensor; @@ -114,12 +118,13 @@ pub(crate) struct RecursiveWalker<'a, M, F> { impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> ModuleWalker for RecursiveWalker<'a, &'a M, F> { + type Err = F::Err; fn visit_module( &mut self, mut get_refs: GetRef, _: GetMut, name: &str, - ) -> Result<(), D::Err> + ) -> Result<(), Self::Err> where GetRef: FnMut(&M) -> &Field, GetMut: FnMut(&mut M) -> &mut Field, @@ -140,7 +145,7 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> ModuleWalker, - ) -> Result<(), D::Err> + ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Tensor, GetMut: FnMut(&mut M) -> &mut Tensor, @@ -155,12 +160,13 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> ModuleWalker> ModuleWalker for RecursiveWalker<'a, &'a mut M, F> { + type Err = F::Err; fn visit_module( &mut self, _: GetRef, mut get_muts: GetMut, name: &str, - ) -> Result<(), D::Err> + ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Field, GetMut: FnMut(&mut M) -> &mut Field, @@ -181,7 +187,7 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut> ModuleWalker, - ) -> Result<(), D::Err> + ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Tensor, GetMut: FnMut(&mut M) -> &mut Tensor, @@ -196,12 +202,13 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut> ModuleWalker> ModuleWalker for RecursiveWalker<'a, (&'a mut M, &'a M), F> { + type Err = F::Err; fn visit_module( &mut self, mut get_refs: GetRef, mut get_muts: GetMut, name: &str, - ) -> Result<(), D::Err> + ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Field, GetMut: FnMut(&mut M) -> &mut Field, @@ -222,7 +229,7 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef> ModuleWalker mut get_refs: GetRef, mut get_muts: GetMut, opts: TensorOptions, - ) -> Result<(), D::Err> + ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Tensor, GetMut: FnMut(&mut M) -> &mut Tensor, diff --git a/src/tensor/visitors/num_params.rs b/src/tensor/visitors/num_params.rs index 3861118bf..397b80ffb 100644 --- a/src/tensor/visitors/num_params.rs +++ b/src/tensor/visitors/num_params.rs @@ -6,6 +6,7 @@ use std::{string::String, vec::Vec}; struct Counter(usize); impl VisitTensorRef for Counter { + type Err = D::Err; fn visit( &mut self, _: String, diff --git a/src/tensor/visitors/reset_params.rs b/src/tensor/visitors/reset_params.rs index 41ac063db..209fde007 100644 --- a/src/tensor/visitors/reset_params.rs +++ b/src/tensor/visitors/reset_params.rs @@ -6,6 +6,7 @@ use std::{string::String, vec::Vec}; struct Resetter; impl VisitTensorMut for Resetter { + type Err = D::Err; fn visit( &mut self, _: String, From c32c2eb1b737a989b21d54ba0647e48a5444c9af Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 09:46:47 -0500 Subject: [PATCH 03/20] Moving missing grads test to optimizers --- src/nn/add_into.rs | 33 +------------------- src/nn/embedding.rs | 25 +-------------- src/nn/impl_module_for_tuples.rs | 45 --------------------------- src/nn/layer_norm.rs | 28 +---------------- src/nn/linear.rs | 32 +------------------- src/nn/mod.rs | 35 --------------------- src/nn/repeated.rs | 52 ++------------------------------ src/nn/split_into.rs | 37 ++--------------------- src/nn/transformer/mha.rs | 11 +++---- src/nn/transformer/mod.rs | 10 ++---- src/optim/adam/mod.rs | 16 ++++++++-- src/optim/mod.rs | 4 +-- src/optim/optimizer.rs | 14 --------- src/optim/rmsprop/mod.rs | 16 
++++++++-- src/optim/sgd/mod.rs | 14 ++++++++- src/tensor/visitors/base.rs | 4 +-- 16 files changed, 60 insertions(+), 316 deletions(-) diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index ed825170b..0c9a08c59 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -93,11 +93,9 @@ mod tests { use super::*; use crate::{ gradients::OwnedTape, - nn::{builders::*, tests::SimpleUpdater, DeviceBuildExt}, - optim::*, + nn::{builders::*, DeviceBuildExt}, shapes::*, tests::{TestDevice, TestDtype}, - unique_id::HasUniqueId, }; type TestAddIntoCpu = AddInto<(Linear<2, 5>, Linear<3, 5>)>; @@ -212,35 +210,6 @@ mod tests { )); } - #[test] - fn test_missing_gradients() { - let dev: TestDevice = Default::default(); - type Model = AddInto<(Linear<5, 3>, Linear<5, 3>)>; - let mut model = dev.build_module::(); - let mut g: SimpleUpdater = Default::default(); - - // no gradients present - model.update(&mut g).unwrap(); - assert_eq!( - &g.unused.ids, - &[ - *model.0 .0.weight.id(), - *model.0 .0.bias.id(), - *model.0 .1.weight.id(), - *model.0 .1.bias.id() - ] - ); - - // weight gradient is present - g.grads.try_alloc_for(&model.0 .0.weight).unwrap(); - g.grads.try_alloc_for(&model.0 .0.bias).unwrap(); - g.grads.try_alloc_for(&model.0 .1.weight).unwrap(); - g.grads.try_alloc_for(&model.0 .1.bias).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); - } - #[test] fn longer_network() { let dev: TestDevice = Default::default(); diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index 5b01ec047..e288254a5 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -123,12 +123,7 @@ impl, D2: Device, TestDtype>(); - let mut g: SimpleUpdater = Default::default(); - - // no gradients present - model.update(&mut g).unwrap(); - assert_eq!(&g.unused.ids, &[*model.weight.id()]); - - // weight gradient is present - g.grads.try_alloc_for(&model.weight).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); - } } diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index dbb4b8d0e..c9d1c9500 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -96,8 +96,6 @@ tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5], M6, [M5, M4, M3, M2, M #[cfg(test)] mod tests { use super::*; - use crate::nn::tests::SimpleUpdater; - use crate::unique_id::HasUniqueId; use crate::{ nn::{builders::*, *}, optim::*, @@ -245,47 +243,4 @@ mod tests { let y = model.forward(dev.zeros()); assert_eq!(y.array(), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]); } - - #[test] - fn test_tuple_missing_gradients() { - let dev: TestDevice = Default::default(); - type Model = (Linear<5, 3>, Linear<5, 3>, Linear<5, 3>); - let mut model = dev.build_module::(); - let mut g: SimpleUpdater = Default::default(); - - // no gradients present - model.update(&mut g).unwrap(); - assert_eq!( - &g.unused.ids, - &[ - *model.0.weight.id(), - *model.0.bias.id(), - *model.1.weight.id(), - *model.1.bias.id(), - *model.2.weight.id(), - *model.2.bias.id(), - ] - ); - - // weight gradient is present - g.grads.try_alloc_for(&model.0.weight).unwrap(); - g.grads.try_alloc_for(&model.1.weight).unwrap(); - g.grads.try_alloc_for(&model.2.weight).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert_eq!( - &g.unused.ids, - &[*model.0.bias.id(), *model.1.bias.id(), *model.2.bias.id(),] - ); - - g.grads.try_alloc_for(&model.0.weight).unwrap(); - g.grads.try_alloc_for(&model.0.bias).unwrap(); - 
g.grads.try_alloc_for(&model.1.weight).unwrap(); - g.grads.try_alloc_for(&model.1.bias).unwrap(); - g.grads.try_alloc_for(&model.2.weight).unwrap(); - g.grads.try_alloc_for(&model.2.bias).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); - } } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index 9716b3997..f0966deb6 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -118,9 +118,8 @@ impl, T: Tape> #[cfg(test)] mod tests { use super::*; - use crate::nn::tests::SimpleUpdater; use crate::nn::{DeviceBuildExt, ModuleMut}; - use crate::{optim::*, tests::*, unique_id::HasUniqueId}; + use crate::tests::*; #[test] fn test_layer_norm_reset() { @@ -181,29 +180,4 @@ mod tests { ); assert_close(&g.get(&m.beta).array(), &[0.2; 5]); } - - #[test] - fn test_layer_norm_missing_gradients() { - let dev: TestDevice = Default::default(); - - let mut model = dev.build_module::, TestDtype>(); - let mut g: SimpleUpdater = Default::default(); - - // no gradients present - model.update(&mut g).unwrap(); - assert_eq!(&g.unused.ids, &[*model.gamma.id(), *model.beta.id()]); - - // weight gradient is present - g.grads.try_alloc_for(&model.gamma).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert_eq!(&g.unused.ids, &[*model.beta.id()]); - - // all gradients present - g.grads.try_alloc_for(&model.gamma).unwrap(); - g.grads.try_alloc_for(&model.beta).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); - } } diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 4e7f4ee56..0c84ec8bc 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -153,12 +153,7 @@ impl<'a, B: Dim, S: Dim, const M: usize, E: Dtype, D: Device, T: Tape> #[cfg(test)] mod tests { use super::*; - use crate::{ - nn::{tests::SimpleUpdater, DeviceBuildExt}, - optim::*, - tests::*, - unique_id::HasUniqueId, - }; + use crate::{nn::DeviceBuildExt, tests::*}; const W: [[TestDtype; 5]; 2] = [ [-0.3458893, -0.30371523, -0.3712057, 0.14303583, -0.0268966], @@ -287,29 +282,4 @@ mod tests { ); assert_close(&g.get(&model.bias).array(), &[0.40265593, -0.2874091]); } - - #[test] - fn test_linear_missing_gradients() { - let dev: TestDevice = Default::default(); - - let mut model = dev.build_module::, TestDtype>(); - let mut g: SimpleUpdater = Default::default(); - - // no gradients present - model.update(&mut g).unwrap(); - assert_eq!(&g.unused.ids, &[*model.weight.id(), *model.bias.id()]); - - // weight gradient is present - g.grads.try_alloc_for(&model.weight).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert_eq!(&g.unused.ids, &[*model.bias.id()]); - - // both gradients present - g.grads.try_alloc_for(&model.weight).unwrap(); - g.grads.try_alloc_for(&model.bias).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); - } } diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 1c3367d52..922ed69c4 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -180,38 +180,3 @@ pub mod builders { #[cfg(feature = "nightly")] pub use super::transformer::builder::*; } - -#[cfg(test)] -mod tests { - use crate::{ - gradients::Gradients, optim::UnusedTensors, shapes::Dtype, tensor::visitors::*, - tensor::DeviceStorage, - }; - - #[derive(Default)] - pub struct SimpleUpdater { - pub grads: Gradients, - pub unused: UnusedTensors, - } - - impl SimpleUpdater { - pub(crate) fn clear_unused(&mut self) { - self.unused.clear(); - } - } - - impl VisitTensorMut for SimpleUpdater { - type Err = D::Err; - fn visit( 
- &mut self, - _: alloc::string::String, - _: TensorOptions, - p: &mut crate::prelude::Tensor, - ) -> Result<(), ::Err> { - if self.grads.remove(p).is_none() { - self.unused.add(p); - } - Ok(()) - } - } -} diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index e7a1d8133..1acd3ef21 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -100,9 +100,9 @@ impl, const N: usize> ModuleMut>()).array()); } - - #[test] - fn test_repeated_missing_gradients() { - let dev: TestDevice = Default::default(); - - type Model = Repeated, 3>; - let mut model = dev.build_module::(); - let mut g: SimpleUpdater = Default::default(); - - // no gradients present - model.update(&mut g).unwrap(); - assert_eq!( - &g.unused.ids, - &[ - *model[0].weight.id(), - *model[0].bias.id(), - *model[1].weight.id(), - *model[1].bias.id(), - *model[2].weight.id(), - *model[2].bias.id(), - ] - ); - - // weight gradient is present - for i in 0..3 { - g.grads.try_alloc_for(&model[i].weight).unwrap(); - } - - g.clear_unused(); - model.update(&mut g).unwrap(); - assert_eq!( - &g.unused.ids, - &[ - *model[0].bias.id(), - *model[1].bias.id(), - *model[2].bias.id() - ] - ); - - // all gradients present - for i in 0..3 { - g.grads.try_alloc_for(&model[i].weight).unwrap(); - g.grads.try_alloc_for(&model[i].bias).unwrap(); - } - g.clear_unused(); - model.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); - } } diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 521be82c8..61e58e3e8 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -110,12 +110,8 @@ mod tests { use super::*; use crate::nn::DeviceBuildExt; - use crate::{gradients::*, optim::*, shapes::*, tensor_ops::*}; - use crate::{ - nn::{builders::Linear, tests::SimpleUpdater}, - tests::*, - unique_id::HasUniqueId, - }; + use crate::{gradients::*, shapes::*, tensor_ops::*}; + use crate::{nn::builders::Linear, tests::*}; #[test] fn test_unused() { @@ -234,33 +230,4 @@ mod tests { Tensor, _, _, OwnedTape<_>>, ) = m.forward(dev.zeros::>().traced()); } - - #[test] - fn test_missing_gradients() { - let dev: TestDevice = Default::default(); - type Model = SplitInto<(Linear<5, 3>, Linear<5, 3>)>; - let mut model = dev.build_module::(); - let mut g: SimpleUpdater = Default::default(); - - // no gradients present - model.update(&mut g).unwrap(); - assert_eq!( - &g.unused.ids, - &[ - *model.0 .0.weight.id(), - *model.0 .0.bias.id(), - *model.0 .1.weight.id(), - *model.0 .1.bias.id() - ] - ); - - // weight gradient is present - g.grads.try_alloc_for(&model.0 .0.weight).unwrap(); - g.grads.try_alloc_for(&model.0 .0.bias).unwrap(); - g.grads.try_alloc_for(&model.0 .1.weight).unwrap(); - g.grads.try_alloc_for(&model.0 .1.bias).unwrap(); - g.clear_unused(); - model.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); - } } diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index 6a82e3dba..2bb3bcab2 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -262,7 +262,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{nn::tests::SimpleUpdater, optim::*, tests::*}; + use crate::{optim::*, tests::*}; #[test] fn test_mha_unbatched() { @@ -367,12 +367,9 @@ mod tests { let k: Tensor, TestDtype, _> = dev.sample_normal(); let v: Tensor, TestDtype, _> = dev.sample_normal(); let y = mha.forward((q.trace(), k, v)); + let g = y.square().mean().backward(); - let mut g = SimpleUpdater { - grads: y.mean().backward(), - unused: Default::default(), - }; - mha.update(&mut g).unwrap(); - assert!(g.unused.is_empty()); + let mut opt = 
Sgd::new(&mha, Default::default()); + opt.update(&mut mha, g).expect(""); } } diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index 046a6a867..5b92d9256 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -169,7 +169,7 @@ where mod tests { use super::*; use crate::{ - nn::{tests::SimpleUpdater, DeviceBuildExt}, + nn::{DeviceBuildExt}, optim::*, shapes::*, tensor::*, @@ -205,11 +205,7 @@ mod tests { let out: Tensor, _, _, _> = t.forward_mut((src.trace(), tgt)); let g = out.mean().backward(); - let mut gs = SimpleUpdater { - grads: g, - unused: Default::default(), - }; - t.update(&mut gs).unwrap(); - assert!(gs.unused.is_empty()); + let mut opt = Sgd::new(&t, Default::default()); + opt.update(&mut t, g).expect(""); } } diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 91431d6ff..7a204d59f 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -12,7 +12,7 @@ use crate::{ tensor::DeviceStorage, }; -use super::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; +use super::{Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; /// Configuration of hyperparameters for [Adam]. /// @@ -142,7 +142,11 @@ impl, D: AdamKernel, E: Dtype> Optimizer f ) -> Result<(), OptimizerUpdateError> { self.t = self.t.checked_add(1).unwrap(); self.gradients = gradients; - let result = module.update(self); + let result = M::iter_tensors(&mut RecursiveWalker { + m: module, + f: self, + path: &mut std::vec::Vec::new(), + }); let unused = std::mem::take(&mut self.unused); match result { Ok(_) => unused.into(), @@ -281,4 +285,12 @@ mod tests { assert_close(&t.array(), e); } } + + #[test] + fn test_unused_tensors() { + let dev: TestDevice = Default::default(); + let mut t: Tensor, TestDtype, _> = dev.sample_normal(); + let mut opt = Adam::new(&t, Default::default()); + opt.update(&mut t, Default::default()).expect_err(""); + } } diff --git a/src/optim/mod.rs b/src/optim/mod.rs index fe6daeb13..8f4930260 100644 --- a/src/optim/mod.rs +++ b/src/optim/mod.rs @@ -34,11 +34,11 @@ mod rmsprop; mod sgd; pub use adam::{Adam, AdamConfig}; -pub use optimizer::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors}; pub use optimizer::{Momentum, WeightDecay}; +pub use optimizer::{Optimizer, OptimizerUpdateError, UnusedTensors}; pub use rmsprop::{RMSprop, RMSpropConfig}; pub use sgd::{Sgd, SgdConfig}; pub mod prelude { - pub use super::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors}; + pub use super::{Optimizer, OptimizerUpdateError, UnusedTensors}; } diff --git a/src/optim/optimizer.rs b/src/optim/optimizer.rs index fb2c723ab..1d2787c04 100644 --- a/src/optim/optimizer.rs +++ b/src/optim/optimizer.rs @@ -1,7 +1,6 @@ use crate::{ gradients::Gradients, shapes::Dtype, - tensor::visitors::{RecursiveWalker, TensorCollection, VisitTensorMut}, tensor::DeviceStorage, unique_id::{HasUniqueId, UniqueId}, }; @@ -92,19 +91,6 @@ pub trait Optimizer { ) -> Result<(), OptimizerUpdateError>; } -/// Represents something that can be updated with a [ParamUpdater]. -pub trait GradientUpdate: TensorCollection { - /// Updates self given the [ParamUpdater]. 
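// This trait and its blanket impl are removed because, as the Adam/RMSprop/Sgd diffs in
// this patch show, each optimizer now calls M::iter_tensors with a RecursiveWalker
// directly, acting as the VisitTensorMut visitor itself, so the extra indirection is gone.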
- fn update>(&mut self, updater: &mut V) -> Result<(), V::Err> { - Self::iter_tensors(&mut RecursiveWalker { - m: self, - f: updater, - path: &mut std::vec::Vec::new(), - }) - } -} -impl> GradientUpdate for M {} - /// Holds [UniqueId] of tensors that were missing gradients during /// [GradientUpdate::update()], and therefore are unused #[derive(Debug, Default)] diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index 0d0868029..d20aa3623 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -12,7 +12,7 @@ use crate::{ tensor::*, }; -use super::{GradientUpdate, Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; +use super::{Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; /// Configuration of hyperparameters for [RMSprop]. #[derive(Debug, Clone, Copy)] @@ -157,7 +157,11 @@ impl, D: RMSpropKernel + OneFillStorage, E: Dtyp gradients: Gradients, ) -> Result<(), OptimizerUpdateError> { self.gradients = gradients; - let result = module.update(self); + let result = M::iter_tensors(&mut RecursiveWalker { + m: module, + f: self, + path: &mut std::vec::Vec::new(), + }); let unused = std::mem::take(&mut self.unused); let r = match result { Ok(_) => unused.into(), @@ -317,4 +321,12 @@ mod tests { ]; test_matches_expected(cfg, EXPECTED); } + + #[test] + fn test_unused_tensors() { + let dev: TestDevice = Default::default(); + let mut t: Tensor, TestDtype, _> = dev.sample_normal(); + let mut opt = RMSprop::new(&t, Default::default()); + opt.update(&mut t, Default::default()).expect_err(""); + } } diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index 9903b7562..672ef9408 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -172,7 +172,11 @@ impl, D: SgdKernel, E: Dtype> Optimizer fo gradients: Gradients, ) -> Result<(), OptimizerUpdateError> { self.gradients = gradients; - let result = module.update(self); + let result = M::iter_tensors(&mut RecursiveWalker { + m: module, + f: self, + path: &mut std::vec::Vec::new(), + }); let unused = std::mem::take(&mut self.unused); match result { Ok(_) => unused.into(), @@ -414,4 +418,12 @@ mod tests { assert_close(&t.array(), e); } } + + #[test] + fn test_unused_tensors() { + let dev: TestDevice = Default::default(); + let mut t: Tensor, TestDtype, _> = dev.sample_normal(); + let mut opt = Sgd::new(&t, Default::default()); + opt.update(&mut t, Default::default()).expect_err(""); + } } diff --git a/src/tensor/visitors/base.rs b/src/tensor/visitors/base.rs index 417afaeff..5316f14b6 100644 --- a/src/tensor/visitors/base.rs +++ b/src/tensor/visitors/base.rs @@ -1,3 +1,5 @@ +#![allow(clippy::type_complexity)] + use crate::{ shapes::{Dtype, Shape}, tensor::{DeviceStorage, Tensor}, @@ -8,8 +10,6 @@ use std::{string::String, vec::Vec}; pub struct TensorOptions { pub name: &'static str, pub update: bool, - - #[allow(clippy::type_complexity)] pub reset: fn(&mut Tensor) -> Result<(), D::Err>, } From dd129a130acfd5068153074341e0d06a975a82a4 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 09:54:22 -0500 Subject: [PATCH 04/20] Format --- src/nn/transformer/mod.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index 5b92d9256..ed76a4452 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -168,14 +168,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{ - nn::{DeviceBuildExt}, - optim::*, - shapes::*, - tensor::*, - tensor_ops::*, - tests::*, - }; + use 
crate::{nn::DeviceBuildExt, optim::*, shapes::*, tensor::*, tensor_ops::*, tests::*}; #[test] fn test_forward() { From 8c37bf054863b618d5df07ce32fa3dfd9acb0592 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 10:36:04 -0500 Subject: [PATCH 05/20] Rename ModuleWalker to TensorVisitor --- src/nn/add_into.rs | 2 +- src/nn/batchnorm2d.rs | 2 +- src/nn/conv.rs | 2 +- src/nn/embedding.rs | 2 +- src/nn/generalized_residual.rs | 2 +- src/nn/layer_norm.rs | 2 +- src/nn/linear.rs | 2 +- src/nn/module.rs | 4 ++-- src/nn/repeated.rs | 2 +- src/nn/residual.rs | 2 +- src/nn/split_into.rs | 2 +- src/nn/transformer/decoder.rs | 4 ++-- src/nn/transformer/encoder.rs | 2 +- src/nn/transformer/mha.rs | 2 +- src/nn/transformer/mod.rs | 2 +- src/tensor/visitors/base.rs | 12 ++++++------ 16 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index 0c9a08c59..96cc05c3d 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -34,7 +34,7 @@ impl, D: DeviceStorage, E: Dtype> BuildModule for Add } impl> TensorCollection for AddInto { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index 2084bea0d..662dc43d5 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -184,7 +184,7 @@ impl> BuildModule for BatchNorm2D> TensorCollection for BatchNorm2D { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.scale, |s| &mut s.scale, diff --git a/src/nn/conv.rs b/src/nn/conv.rs index be3b1af0f..25c39efe9 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -59,7 +59,7 @@ where E: Dtype + Float + SampleUniform, D: Device, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index e288254a5..c38a2dada 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -70,7 +70,7 @@ impl> TensorCollection for Embedding { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index 8129eb095..24fd3560e 100644 --- a/src/nn/generalized_residual.rs +++ b/src/nn/generalized_residual.rs @@ -45,7 +45,7 @@ impl, R: BuildModule> Bui impl, R: TensorCollection> TensorCollection for GeneralizedResidual { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.f, |s| &mut s.f, "f")?; visitor.visit_module(|s| &s.r, |s| &mut s.r, "r") } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index f0966deb6..dd7cc80e4 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -56,7 +56,7 @@ impl> BuildModule for LayerNorm1D> TensorCollection for LayerNorm1D { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.gamma, |s| &mut s.gamma, diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 0c84ec8bc..1f2dfab46 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -71,7 +71,7 @@ impl> TensorCollection for Linear { - fn 
iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/module.rs b/src/nn/module.rs index 6ab0d6fc5..519b3957b 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -3,7 +3,7 @@ pub use crate::tensor::OnCuda; pub use crate::tensor::{DeviceStorage, OnCpu, OnDevice, ToDevice}; use crate::{ shapes::Dtype, - tensor::visitors::{ModuleWalker, TensorCollection}, + tensor::visitors::{TensorCollection, TensorVisitor}, }; /// Immutable forward of `Input` that produces [Module::Output]. @@ -75,7 +75,7 @@ impl, D: DeviceStorage, E: Dtype> BuildOn } impl TensorCollection for T { - fn iter_tensors>(_: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(_: &mut V) -> Result<(), V::Err> { Ok(()) } } diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index 1acd3ef21..1ac7b0fa2 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -42,7 +42,7 @@ impl, const N: usize> BuildModu impl, const N: usize> TensorCollection for Repeated { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { for i in 0..N { visitor.visit_module( |s| &s.modules[i], diff --git a/src/nn/residual.rs b/src/nn/residual.rs index a6ef128cd..8731da42b 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -34,7 +34,7 @@ impl> BuildModule for Res } impl> TensorCollection for Residual { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 61e58e3e8..b152d8efa 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -35,7 +35,7 @@ impl, D: DeviceStorage, E: Dtype> BuildModule for Spl impl> TensorCollection for SplitInto { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index 5adc736d5..036ed0a24 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -84,7 +84,7 @@ impl>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } @@ -187,7 +187,7 @@ impl> TensorColl where E: Dtype + Float + SampleUniform, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; visitor.visit_module(|s| &s.mh_attn, |s| &mut s.mh_attn, "mh_attn")?; diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index 520a09bed..51c1eac50 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -119,7 +119,7 @@ impl> TensorColl where E: Dtype + Float + SampleUniform, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; visitor.visit_module(|s| &s.ff, |s| &mut s.ff, "ff")?; diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index 2bb3bcab2..88de75232 100644 --- a/src/nn/transformer/mha.rs +++ 
b/src/nn/transformer/mha.rs @@ -83,7 +83,7 @@ impl>(visitor: &mut W) -> Result<(), W::Err> { + fn iter_tensors>(visitor: &mut W) -> Result<(), W::Err> { visitor.visit_module(|s| &s.w_q, |s| &mut s.w_q, "w_q")?; visitor.visit_module(|s| &s.w_k, |s| &mut s.w_k, "w_k")?; visitor.visit_module(|s| &s.w_v, |s| &mut s.w_v, "w_v")?; diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index ed76a4452..63a4b00c5 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -102,7 +102,7 @@ where E: Dtype + Float + SampleUniform, D: Device, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.encoder, |s| &mut s.encoder, "encoder")?; visitor.visit_module(|s| &s.decoder, |s| &mut s.decoder, "decoder") } diff --git a/src/tensor/visitors/base.rs b/src/tensor/visitors/base.rs index 5316f14b6..5dd1de6f9 100644 --- a/src/tensor/visitors/base.rs +++ b/src/tensor/visitors/base.rs @@ -68,11 +68,11 @@ pub trait VisitTensorMutRef { } pub trait TensorCollection: Sized { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; } impl TensorCollection for Tensor { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| s, |s| s, @@ -85,7 +85,7 @@ impl TensorCollection for Tensor: Sized { +pub trait TensorVisitor: Sized { type Err; fn visit_module( &mut self, @@ -115,7 +115,7 @@ pub(crate) struct RecursiveWalker<'a, M, F> { pub(crate) path: &'a mut Vec, } -impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> ModuleWalker +impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> TensorVisitor for RecursiveWalker<'a, &'a M, F> { type Err = F::Err; @@ -157,7 +157,7 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> ModuleWalker> ModuleWalker +impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut> TensorVisitor for RecursiveWalker<'a, &'a mut M, F> { type Err = F::Err; @@ -199,7 +199,7 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut> ModuleWalker> ModuleWalker +impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef> TensorVisitor for RecursiveWalker<'a, (&'a mut M, &'a M), F> { type Err = F::Err; From 338da57e52a9a9babe2156827cc674b3b5d37026 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 10:37:39 -0500 Subject: [PATCH 06/20] Fix missing update of ModuleWalker --- src/nn/impl_module_for_tuples.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index c9d1c9500..14fb3706a 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -5,7 +5,7 @@ use super::module::{BuildModule, BuildOnDevice, Module, ModuleMut, OnDevice, ToD macro_rules! 
tuple_impls { ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { impl),+> TensorCollection for ($($name,)+) { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { $(visitor.visit_module(|s| &s.$idx, |s| &mut s.$idx, &std::format!("{}", $idx))?;)+ Ok(()) } From f3100627851ff26242a6c15ac504b4bf81253159 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 11:43:07 -0500 Subject: [PATCH 07/20] Fixing old docs --- src/nn/module.rs | 2 +- src/optim/optimizer.rs | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/nn/module.rs b/src/nn/module.rs index 519b3957b..365f71cf9 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -67,7 +67,7 @@ pub trait DeviceBuildExt: DeviceStorage { impl DeviceBuildExt for D {} /// Marker trait for modules with no updatable parameters. These have -/// blanket impls for [ResetParams], [GradientUpdate], and [ModuleMut] +/// blanket impls for, and [ModuleMut] pub trait ZeroSizedModule: Default {} impl, D: DeviceStorage, E: Dtype> BuildOnDevice for T { diff --git a/src/optim/optimizer.rs b/src/optim/optimizer.rs index 1d2787c04..cd45d3826 100644 --- a/src/optim/optimizer.rs +++ b/src/optim/optimizer.rs @@ -64,16 +64,11 @@ pub(super) fn momentum_to_cuda(wd: Option>) -> (Momentum } } -/// All optimizers must implement the update function, which takes an object -/// that implements [GradientUpdate], and calls [GradientUpdate::update]. +/// All optimizers must implement the update function, which takes a `M` +/// and updates all of its parameters. /// /// # Notes /// -/// 1. [GradientUpdate] requires an object that implements [crate::optim::ParamUpdater]. -/// A common implementation involves implementing both [Optimizer] and [crate::optim::ParamUpdater] -/// on one struct, and passing self to [GradientUpdate::update]. See [super::Sgd] for an example -/// of implementing this trait. -/// /// 2. Update takes ownership of [Gradients], so update cannot be called /// with the same gradients object. /// @@ -92,7 +87,7 @@ pub trait Optimizer { } /// Holds [UniqueId] of tensors that were missing gradients during -/// [GradientUpdate::update()], and therefore are unused +/// update, and therefore are unused #[derive(Debug, Default)] pub struct UnusedTensors { pub ids: std::vec::Vec, @@ -116,7 +111,7 @@ impl UnusedTensors { /// An error indicating that a parameter was not used in gradient /// computation, and was therefore not present in [Gradients] -/// while a [GradientUpdate] was trying to update it. +/// during an update. 
#[derive(Debug)] pub enum OptimizerUpdateError { UnusedParams(UnusedTensors), From 4b0df4bacdf087b9ea2bad29f3a31a56de9fe321 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 12:18:55 -0500 Subject: [PATCH 08/20] Moves name out of TensorOptions --- src/nn/batchnorm2d.rs | 30 +++++++++++++++------- src/nn/conv.rs | 6 +++-- src/nn/embedding.rs | 3 ++- src/nn/layer_norm.rs | 9 +++---- src/nn/linear.rs | 6 +++-- src/optim/adam/mod.rs | 5 +++- src/optim/rmsprop/mod.rs | 5 +++- src/optim/sgd/mod.rs | 5 +++- src/tensor/visitors/base.rs | 50 +++++++++++++++++++++++-------------- 9 files changed, 77 insertions(+), 42 deletions(-) diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index 662dc43d5..9e6f39ef4 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -188,22 +188,21 @@ impl> TensorCollection for BatchNor visitor.visit_tensor( |s| &s.scale, |s| &mut s.scale, - TensorOptions::named("scale", |t| t.try_fill_with_ones()), - )?; - visitor.visit_tensor( - |s| &s.bias, - |s| &mut s.bias, - TensorOptions::named("bias", |t| t.try_fill_with_zeros()), + "scale", + TensorOptions::ones(), )?; + visitor.visit_tensor(|s| &s.bias, |s| &mut s.bias, "bias", TensorOptions::zeros())?; visitor.visit_tensor( |s| &s.running_mean, |s| &mut s.running_mean, - TensorOptions::no_grad("running_mean", |t| t.try_fill_with_zeros()), + "running_mean", + TensorOptions::no_grad(|t| t.try_fill_with_zeros()), )?; visitor.visit_tensor( |s| &s.running_var, |s| &mut s.running_var, - TensorOptions::no_grad("running_var", |t| t.try_fill_with_ones()), + "running_var", + TensorOptions::no_grad(|t| t.try_fill_with_ones()), ) } } @@ -227,7 +226,7 @@ impl, D2: Device> ToDevice #[cfg(test)] mod tests { use super::builder::BatchNorm2D; - use crate::{nn::*, shapes::*, tensor::*, tensor_ops::*, tests::*}; + use crate::{nn::*, optim::*, shapes::*, tensor::*, tensor_ops::*, tests::*}; #[test] fn test_batchnorm2d_3d_forward_mut() { @@ -348,4 +347,17 @@ mod tests { ], ); } + + #[test] + fn test_batchnorm2d_update() { + let dev: TestDevice = Default::default(); + + let x1: Tensor, TestDtype, _> = dev.sample_normal(); + let mut bn = dev.build_module::, TestDtype>(); + let y = bn.forward_mut(x1.trace()); + let g = y.square().mean().backward(); + + let mut opt = Sgd::new(&bn, Default::default()); + opt.update(&mut bn, g).expect(""); + } } diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 25c39efe9..4c170251c 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -63,7 +63,8 @@ where visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, - TensorOptions::named("weight", |t| { + "weight", + TensorOptions::requires_grad(|t| { let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), @@ -71,7 +72,8 @@ where visitor.visit_tensor( |s| &s.bias, |s| &mut s.bias, - TensorOptions::named("bias", |t| { + "bias", + TensorOptions::requires_grad(|t| { let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index c38a2dada..2e92168a9 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -74,7 +74,8 @@ impl> TensorCollection for LayerNor visitor.visit_tensor( |s| &s.gamma, |s| &mut s.gamma, - TensorOptions::named("gamma", |t| t.try_fill_with_ones()), + "gamma", + TensorOptions::ones(), )?; - visitor.visit_tensor( - |s| &s.beta, - |s| &mut s.beta, - TensorOptions::named("beta", |t| t.try_fill_with_zeros()), - ) + visitor.visit_tensor(|s| &s.beta, |s| 
&mut s.beta, "beta", TensorOptions::zeros()) } } diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 1f2dfab46..097f2243a 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -75,7 +75,8 @@ impl, E: Dtype> VisitTensorMut for Adam { fn visit( &mut self, _: alloc::string::String, - _: TensorOptions, + opts: TensorOptions, p: &mut crate::prelude::Tensor, ) -> Result<(), ::Err> { + if !opts.update { + return Ok(()); + } let g = self.gradients.remove(p); match g { None => self.unused.add(p), diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index d20aa3623..2859d2b04 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -126,9 +126,12 @@ impl + OneFillStorage> VisitTensorMut fn visit( &mut self, _: alloc::string::String, - _: TensorOptions, + opts: TensorOptions, p: &mut Tensor, ) -> Result<(), ::Err> { + if !opts.update { + return Ok(()); + } let g = self.gradients.remove(p); match g { None => self.unused.add(p), diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index 672ef9408..bd61f47ee 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -150,9 +150,12 @@ impl, M> VisitTensorMut for Sgd { fn visit( &mut self, _: alloc::string::String, - _: TensorOptions, + opts: TensorOptions, p: &mut Tensor, ) -> Result<(), D::Err> { + if !opts.update { + return Ok(()); + } let g = self.gradients.remove(p); match g { None => self.unused.add(p), diff --git a/src/tensor/visitors/base.rs b/src/tensor/visitors/base.rs index 5dd1de6f9..1c7ca9733 100644 --- a/src/tensor/visitors/base.rs +++ b/src/tensor/visitors/base.rs @@ -2,35 +2,43 @@ use crate::{ shapes::{Dtype, Shape}, - tensor::{DeviceStorage, Tensor}, + tensor::{DeviceStorage, OneFillStorage, Tensor, ZeroFillStorage}, }; use std::{string::String, vec::Vec}; pub struct TensorOptions { - pub name: &'static str, pub update: bool, pub reset: fn(&mut Tensor) -> Result<(), D::Err>, } impl TensorOptions { - pub fn named( - name: &'static str, - reset: fn(&mut Tensor) -> Result<(), D::Err>, - ) -> Self { - Self { - name, + pub fn zeros() -> Self + where + D: ZeroFillStorage, + { + TensorOptions { + update: true, + reset: |t| t.try_fill_with_zeros(), + } + } + pub fn ones() -> Self + where + D: OneFillStorage, + { + TensorOptions { + update: true, + reset: |t| t.try_fill_with_ones(), + } + } + pub fn requires_grad(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { + TensorOptions { update: true, reset, } } - - pub fn no_grad( - name: &'static str, - reset: fn(&mut Tensor) -> Result<(), D::Err>, - ) -> Self { - Self { - name, + pub fn no_grad(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { + TensorOptions { update: false, reset, } @@ -76,8 +84,8 @@ impl TensorCollection for Tensor: Sized { &mut self, get_refs: GetRef, get_muts: GetMut, + name: &str, opts: TensorOptions, ) -> Result<(), Self::Err> where @@ -144,13 +153,14 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> TensorVisitor, ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Tensor, GetMut: FnMut(&mut M) -> &mut Tensor, { - self.path.push(opts.name.into()); + self.path.push(name.into()); self.f.visit(self.path.join("."), opts, get_refs(self.m))?; self.path.pop(); Ok(()) @@ -186,13 +196,14 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut> TensorVisitor, ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Tensor, GetMut: FnMut(&mut M) -> &mut Tensor, { - self.path.push(opts.name.into()); + self.path.push(name.into()); self.f.visit(self.path.join("."), opts, get_muts(self.m))?; self.path.pop(); Ok(()) @@ -228,13 
+239,14 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef> TensorVisito &mut self, mut get_refs: GetRef, mut get_muts: GetMut, + name: &str, opts: TensorOptions, ) -> Result<(), F::Err> where GetRef: FnMut(&M) -> &Tensor, GetMut: FnMut(&mut M) -> &mut Tensor, { - self.path.push(opts.name.into()); + self.path.push(name.into()); let tensors = (get_muts(self.m.0), get_refs(self.m.1)); self.f.visit(self.path.join("."), opts, tensors)?; self.path.pop(); From c4f0481a040e6c15636888c9fe7d586bebd63cf3 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 12:20:45 -0500 Subject: [PATCH 09/20] Rename TensorOptions helper methods --- src/nn/batchnorm2d.rs | 13 +++++++++---- src/nn/conv.rs | 4 ++-- src/nn/embedding.rs | 2 +- src/nn/layer_norm.rs | 9 +++++++-- src/nn/linear.rs | 4 ++-- src/tensor/visitors/base.rs | 8 ++++---- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index 9e6f39ef4..7dc09c31a 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -189,20 +189,25 @@ impl> TensorCollection for BatchNor |s| &s.scale, |s| &mut s.scale, "scale", - TensorOptions::ones(), + TensorOptions::reset_to_ones(), + )?; + visitor.visit_tensor( + |s| &s.bias, + |s| &mut s.bias, + "bias", + TensorOptions::reset_to_zeros(), )?; - visitor.visit_tensor(|s| &s.bias, |s| &mut s.bias, "bias", TensorOptions::zeros())?; visitor.visit_tensor( |s| &s.running_mean, |s| &mut s.running_mean, "running_mean", - TensorOptions::no_grad(|t| t.try_fill_with_zeros()), + TensorOptions::detached(|t| t.try_fill_with_zeros()), )?; visitor.visit_tensor( |s| &s.running_var, |s| &mut s.running_var, "running_var", - TensorOptions::no_grad(|t| t.try_fill_with_ones()), + TensorOptions::detached(|t| t.try_fill_with_ones()), ) } } diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 4c170251c..1ae2de4bd 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -64,7 +64,7 @@ where |s| &s.weight, |s| &mut s.weight, "weight", - TensorOptions::requires_grad(|t| { + TensorOptions::reset_with(|t| { let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), @@ -73,7 +73,7 @@ where |s| &s.bias, |s| &mut s.bias, "bias", - TensorOptions::requires_grad(|t| { + TensorOptions::reset_with(|t| { let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index 2e92168a9..6a4fc412a 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -75,7 +75,7 @@ impl> TensorCollection for LayerNor |s| &s.gamma, |s| &mut s.gamma, "gamma", - TensorOptions::ones(), + TensorOptions::reset_to_ones(), )?; - visitor.visit_tensor(|s| &s.beta, |s| &mut s.beta, "beta", TensorOptions::zeros()) + visitor.visit_tensor( + |s| &s.beta, + |s| &mut s.beta, + "beta", + TensorOptions::reset_to_zeros(), + ) } } diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 097f2243a..7351b6eff 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -76,7 +76,7 @@ impl { } impl TensorOptions { - pub fn zeros() -> Self + pub fn reset_to_zeros() -> Self where D: ZeroFillStorage, { @@ -22,7 +22,7 @@ impl TensorOptions { reset: |t| t.try_fill_with_zeros(), } } - pub fn ones() -> Self + pub fn reset_to_ones() -> Self where D: OneFillStorage, { @@ -31,13 +31,13 @@ impl TensorOptions { reset: |t| t.try_fill_with_ones(), } } - pub fn requires_grad(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { + pub fn reset_with(reset: fn(&mut 
Tensor) -> Result<(), D::Err>) -> Self { TensorOptions { update: true, reset, } } - pub fn no_grad(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { + pub fn detached(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { TensorOptions { update: false, reset, From c3e442d6bb80ce8d1ece8c3def808b020e3ca626 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 17:02:20 -0500 Subject: [PATCH 10/20] move visitors into nn --- src/nn/add_into.rs | 4 ++-- src/nn/batchnorm2d.rs | 4 ++-- src/nn/conv.rs | 4 ++-- src/nn/embedding.rs | 4 ++-- src/nn/generalized_residual.rs | 4 ++-- src/nn/impl_module_for_tuples.rs | 6 +++--- src/nn/layer_norm.rs | 6 +++--- src/nn/linear.rs | 4 ++-- src/nn/mod.rs | 8 +++++++- src/nn/module.rs | 7 +++---- src/nn/npz.rs | 4 +++- src/{tensor/visitors => nn}/num_params.rs | 2 +- src/nn/repeated.rs | 4 ++-- src/{tensor/visitors => nn}/reset_params.rs | 2 +- src/nn/residual.rs | 4 ++-- src/nn/split_into.rs | 4 ++-- src/nn/transformer/decoder.rs | 3 +-- src/nn/transformer/encoder.rs | 3 +-- src/nn/transformer/mha.rs | 3 +-- src/nn/transformer/mod.rs | 12 ++++-------- src/{tensor/visitors/base.rs => nn/visitors.rs} | 0 src/optim/adam/mod.rs | 2 +- src/optim/rmsprop/mod.rs | 2 +- src/optim/sgd/mod.rs | 10 ++++++---- src/tensor/mod.rs | 3 --- src/tensor/visitors/mod.rs | 7 ------- 26 files changed, 54 insertions(+), 62 deletions(-) rename src/{tensor/visitors => nn}/num_params.rs (90%) rename src/{tensor/visitors => nn}/reset_params.rs (90%) rename src/{tensor/visitors/base.rs => nn/visitors.rs} (100%) delete mode 100644 src/tensor/visitors/mod.rs diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index 96cc05c3d..00a99c401 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -1,6 +1,6 @@ -use crate::{shapes::Dtype, tensor::visitors::*, tensor::*}; +use crate::{shapes::Dtype, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Add inputs together into a single tensor. 
`T` should be a tuple //// where every element of the tuple has the same output type diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index 7dc09c31a..ac855223e 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -1,6 +1,6 @@ -use crate::{gradients::*, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; +use crate::{gradients::*, shapes::*, tensor::*, tensor_ops::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug, Copy, Clone, Eq, PartialEq)] diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 1ae2de4bd..1291b61f5 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -1,9 +1,9 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; -use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug)] diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index 6a4fc412a..c48349ce9 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -1,9 +1,9 @@ use num_traits::Float; use rand_distr::{uniform::SampleUniform, Uniform}; -use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::module::{BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; pub mod builder { #[derive(Debug)] diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index 24fd3560e..0a1503341 100644 --- a/src/nn/generalized_residual.rs +++ b/src/nn/generalized_residual.rs @@ -1,6 +1,6 @@ -use crate::{shapes::*, tensor::visitors::*, tensor::*}; +use crate::{shapes::*, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// A residual connection `R` around `F`: `F(x) + R(x)`, /// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 14fb3706a..7f7447f06 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -1,6 +1,6 @@ -use crate::{shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; +use crate::{shapes::*, tensor::*, tensor_ops::*}; -use super::module::{BuildModule, BuildOnDevice, Module, ModuleMut, OnDevice, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; macro_rules! tuple_impls { ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { @@ -24,7 +24,7 @@ macro_rules! 
tuple_impls { } impl<$($name: ToDevice,)+ D> ToDevice for ($($name,)+) { - type Output = ($(OnDevice<$name, D>,)+); + type Output = ($(<$name as ToDevice>::Output,)+); fn to_device(&self, device: &D) -> Self::Output { ($(self.$idx.to_device(device)),+) } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index fa219d49b..b6013586a 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -1,6 +1,6 @@ -use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::{BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; pub mod builder { #[derive(Debug)] @@ -120,7 +120,7 @@ impl, T: Tape> #[cfg(test)] mod tests { use super::*; - use crate::nn::{DeviceBuildExt, ModuleMut}; + use crate::nn::{DeviceBuildExt, ModuleMut, ResetParams}; use crate::tests::*; #[test] diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 7351b6eff..490b84d49 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -1,6 +1,6 @@ -use crate::{gradients::Tape, shapes::*, tensor::visitors::*, tensor::*, tensor_ops::*}; +use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::module::{BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; use num_traits::Float; use rand_distr::{uniform::SampleUniform, Uniform}; diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 922ed69c4..3dbd44250 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -104,6 +104,10 @@ //! mlp.load_state_dict(state_dict) //! ``` +mod num_params; +mod reset_params; +pub mod visitors; + mod activations; mod add_into; mod batchnorm2d; @@ -128,7 +132,9 @@ mod transformer; pub use module::*; #[cfg(feature = "numpy")] -pub use npz::*; +pub use npz::{LoadFromNpz, SaveToNpz}; +pub use num_params::NumParams; +pub use reset_params::ResetParams; pub mod modules { /// Structs containing initialized Tensors & impls for [super::Module]. See diff --git a/src/nn/module.rs b/src/nn/module.rs index 8b006854b..f8dcc4c3c 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -1,10 +1,9 @@ +use crate::shapes::Dtype; #[cfg(feature = "cuda")] pub use crate::tensor::OnCuda; pub use crate::tensor::{DeviceStorage, OnCpu, OnDevice, ToDevice}; -use crate::{ - shapes::Dtype, - tensor::visitors::{TensorCollection, TensorVisitor}, -}; + +use super::visitors::{TensorCollection, TensorVisitor}; /// Immutable forward of `Input` that produces [Module::Output]. /// See [ModuleMut] for mutable forward. 
diff --git a/src/nn/npz.rs b/src/nn/npz.rs index 34d32513b..2d28d3509 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -1,11 +1,13 @@ use crate::{ shapes::{Dtype, Shape}, - tensor::visitors::*, tensor::{ numpy::{NpzError, NumpyDtype}, CopySlice, Tensor, }, }; + +use super::visitors::*; + use std::{ io::{BufReader, BufWriter, Read, Seek, Write}, path::Path, diff --git a/src/tensor/visitors/num_params.rs b/src/nn/num_params.rs similarity index 90% rename from src/tensor/visitors/num_params.rs rename to src/nn/num_params.rs index 397b80ffb..bfc76b565 100644 --- a/src/tensor/visitors/num_params.rs +++ b/src/nn/num_params.rs @@ -1,4 +1,4 @@ -use super::base::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorRef}; +use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorRef}; use crate::{shapes::*, tensor::*}; diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index 1ac7b0fa2..a7ff28c54 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -1,6 +1,6 @@ -use crate::{shapes::Dtype, tensor::visitors::*, tensor::*}; +use crate::{shapes::Dtype, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Repeats `T` `N` times. This requires that `T`'s input is the same as it's output. /// diff --git a/src/tensor/visitors/reset_params.rs b/src/nn/reset_params.rs similarity index 90% rename from src/tensor/visitors/reset_params.rs rename to src/nn/reset_params.rs index 209fde007..2d4f059ad 100644 --- a/src/tensor/visitors/reset_params.rs +++ b/src/nn/reset_params.rs @@ -1,4 +1,4 @@ -use super::base::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorMut}; +use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorMut}; use crate::{shapes::*, tensor::*}; diff --git a/src/nn/residual.rs b/src/nn/residual.rs index 8731da42b..7b502ddcd 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -1,6 +1,6 @@ -use crate::{shapes::*, tensor::visitors::*, tensor::*}; +use crate::{shapes::*, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; use std::ops::Add; diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index b152d8efa..1ffda2a7e 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -1,6 +1,6 @@ -use crate::{shapes::Dtype, tensor::visitors::*, tensor::*}; +use crate::{shapes::Dtype, tensor::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Splits input into multiple heads. `T` should be a tuple, /// where every element of the tuple accepts the same input type. 
diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index 036ed0a24..99bd629ce 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -2,9 +2,8 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; use crate::{ - nn::{modules::*, *}, + nn::{modules::*, visitors::*, *}, shapes::Dtype, - tensor::visitors::*, tensor::{PutTape, SplitTape}, tensor_ops::Device, }; diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index 51c1eac50..65bb0d5fa 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -2,9 +2,8 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; use crate::{ - nn::{modules::*, *}, + nn::{modules::*, visitors::*, *}, shapes::Dtype, - tensor::visitors::*, tensor::{PutTape, SplitTape}, tensor_ops::Device, }; diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index 88de75232..c2a5d03dc 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -2,9 +2,8 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; use crate::{ - nn::{modules::*, *}, + nn::{modules::*, visitors::*, *}, shapes::Dtype, - tensor::visitors::*, tensor::*, tensor_ops::*, }; diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index 63a4b00c5..9da5ea198 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -5,17 +5,13 @@ mod mha; pub use decoder::*; pub use encoder::*; pub use mha::*; + use num_traits::Float; use rand_distr::uniform::SampleUniform; -use crate::{ - shapes::Dtype, - tensor::visitors::*, - tensor::{DeviceStorage, PutTape, SplitTape}, - tensor_ops::Device, -}; +use crate::{shapes::*, tensor::*, tensor_ops::*}; -use super::{BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug, Clone)] @@ -168,7 +164,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{nn::DeviceBuildExt, optim::*, shapes::*, tensor::*, tensor_ops::*, tests::*}; + use crate::{nn::DeviceBuildExt, optim::*, tests::*}; #[test] fn test_forward() { diff --git a/src/tensor/visitors/base.rs b/src/nn/visitors.rs similarity index 100% rename from src/tensor/visitors/base.rs rename to src/nn/visitors.rs diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 1a49be23d..c4f41e924 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -7,8 +7,8 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, + nn::visitors::*, shapes::{Dtype, Shape}, - tensor::visitors::*, tensor::DeviceStorage, }; diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index 2859d2b04..be9f9a02c 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -7,8 +7,8 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, + nn::visitors::*, shapes::{Dtype, Shape}, - tensor::visitors::*, tensor::*, }; diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index bd61f47ee..ef2e98927 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -5,10 +5,12 @@ mod cuda_kernel; use std::marker::PhantomData; -use crate::gradients::Gradients; -use crate::shapes::{Dtype, Shape}; -use crate::tensor::visitors::*; -use crate::tensor::{DeviceStorage, Tensor}; +use crate::{ + gradients::Gradients, + nn::visitors::*, + shapes::{Dtype, Shape}, + tensor::{DeviceStorage, Tensor}, +}; use super::optimizer::*; diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 7567b34f4..58a17ee62 100644 --- 
a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -117,7 +117,6 @@ pub(crate) mod cuda; pub(crate) mod numpy; pub(crate) mod storage_traits; mod tensor_impls; -pub(crate) mod visitors; // TODO pub? pub(crate) use storage_traits::{OneFillStorage, ZeroFillStorage}; @@ -135,8 +134,6 @@ pub use tensor_impls::OnCuda; pub use tensor_impls::{OnCpu, OnDevice, PutTape, SplitTape, Tensor, ToDevice}; pub use tensor_impls::{Tensor0D, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D}; -pub use visitors::{NumParams, ResetParams}; - #[cfg(test)] mod tests { use super::*; diff --git a/src/tensor/visitors/mod.rs b/src/tensor/visitors/mod.rs deleted file mode 100644 index 8fed4ab31..000000000 --- a/src/tensor/visitors/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod base; -mod num_params; -mod reset_params; - -pub use base::*; -pub use num_params::NumParams; -pub use reset_params::ResetParams; From c7c1577a96da2d50f2ae6a0c64b6d2d29671e3df Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Tue, 21 Feb 2023 17:03:22 -0500 Subject: [PATCH 11/20] Fixing example --- examples/03-nn.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/03-nn.rs b/examples/03-nn.rs index e1c2e10ad..c14836f20 100644 --- a/examples/03-nn.rs +++ b/examples/03-nn.rs @@ -1,9 +1,9 @@ //! Intro to dfdx::nn use dfdx::{ - nn::{builders::*, BuildOnDevice, DeviceBuildExt, Module, ModuleMut}, + nn::{builders::*, BuildOnDevice, DeviceBuildExt, Module, ModuleMut, ResetParams}, shapes::{Const, Rank1, Rank2}, - tensor::{AsArray, ResetParams, SampleTensor, Tensor, ZerosTensor}, + tensor::{AsArray, SampleTensor, Tensor, ZerosTensor}, }; #[cfg(not(feature = "cuda"))] From dd7977e357e4d6c19014ed13f402eec32c21538e Mon Sep 17 00:00:00 2001 From: nkoppel Date: Wed, 22 Feb 2023 07:44:23 -0600 Subject: [PATCH 12/20] Add TensorContainer trait to allow more argument types for TensorVisitors in #469 (#472) * Add TensorContainer trait; deduplicate code in visitors/base.rs * Implement TensorContainer for tuples, Option, and Vec * run cargo fmt --- src/nn/impl_tensor_container.rs | 95 +++++++++++++++++++++ src/nn/mod.rs | 1 + src/nn/npz.rs | 8 +- src/nn/num_params.rs | 6 +- src/nn/reset_params.rs | 6 +- src/nn/visitors.rs | 141 +++++++------------------------- src/optim/adam/mod.rs | 4 +- src/optim/rmsprop/mod.rs | 4 +- src/optim/sgd/mod.rs | 4 +- 9 files changed, 149 insertions(+), 120 deletions(-) create mode 100644 src/nn/impl_tensor_container.rs diff --git a/src/nn/impl_tensor_container.rs b/src/nn/impl_tensor_container.rs new file mode 100644 index 000000000..4e0e5f436 --- /dev/null +++ b/src/nn/impl_tensor_container.rs @@ -0,0 +1,95 @@ +use super::visitors::TensorContainer; + +impl TensorContainer for &'static () { + type WithModule<'a, Mod: 'a> = &'a Mod; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + _get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + get_ref(*module) + } +} + +impl TensorContainer for &'static mut () { + type WithModule<'a, Mod: 'a> = &'a mut Mod; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + _get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + get_mut(*module) + } +} + +macro_rules! 
tuple_impls { + ([$($name:ident),+] [$($idx:tt),+]) => { + impl<$($name: TensorContainer),+> TensorContainer for ($($name,)+) { + type WithModule<'a, Mod: 'a> = ($($name::WithModule<'a, Mod>,)+); + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + ($($name::get_field(&mut module.$idx, get_ref, get_mut),)+) + } + } + } +} + +tuple_impls!([M1][0]); +tuple_impls!([M1, M2] [0, 1]); +tuple_impls!([M1, M2, M3] [0, 1, 2]); +tuple_impls!([M1, M2, M3, M4] [0, 1, 2, 3]); +tuple_impls!([M1, M2, M3, M4, M5] [0, 1, 2, 3, 4]); +tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5]); + +impl TensorContainer for std::vec::Vec { + type WithModule<'a, Mod: 'a> = std::vec::Vec>; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + module + .iter_mut() + .map(|x| T::get_field(x, get_ref, get_mut)) + .collect() + } +} + +impl TensorContainer for Option { + type WithModule<'a, Mod: 'a> = Option>; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + module.as_mut().map(|x| T::get_field(x, get_ref, get_mut)) + } +} diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 3dbd44250..e8b78da1b 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -117,6 +117,7 @@ mod embedding; mod flatten; mod generalized_residual; mod impl_module_for_tuples; +mod impl_tensor_container; mod layer_norm; mod linear; mod module; diff --git a/src/nn/npz.rs b/src/nn/npz.rs index 2d28d3509..f3062a7c5 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -112,10 +112,12 @@ pub trait LoadFromNpz>: TensorCollection< } impl, T: TensorCollection> LoadFromNpz for T {} -impl> VisitTensorRef +impl> VisitTensors for zip::ZipWriter { + type Container = &'static (); type Err = ZipError; + fn visit( &mut self, full_path: String, @@ -126,10 +128,12 @@ impl> VisitTensorRef> VisitTensorMut +impl> VisitTensors for zip::ZipArchive { + type Container = &'static mut (); type Err = NpzError; + fn visit( &mut self, full_path: String, diff --git a/src/nn/num_params.rs b/src/nn/num_params.rs index bfc76b565..95610873f 100644 --- a/src/nn/num_params.rs +++ b/src/nn/num_params.rs @@ -1,12 +1,14 @@ -use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorRef}; +use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensors}; use crate::{shapes::*, tensor::*}; use std::{string::String, vec::Vec}; struct Counter(usize); -impl VisitTensorRef for Counter { +impl VisitTensors for Counter { + type Container = &'static (); type Err = D::Err; + fn visit( &mut self, _: String, diff --git a/src/nn/reset_params.rs b/src/nn/reset_params.rs index 2d4f059ad..db8277194 100644 --- a/src/nn/reset_params.rs +++ b/src/nn/reset_params.rs @@ -1,12 +1,14 @@ -use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensorMut}; +use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensors}; use crate::{shapes::*, tensor::*}; use std::{string::String, vec::Vec}; struct Resetter; 
-impl VisitTensorMut for Resetter { +impl VisitTensors for Resetter { + type Container = &'static mut (); type Err = D::Err; + fn visit( &mut self, _: String, diff --git a/src/nn/visitors.rs b/src/nn/visitors.rs index 2fb417b85..c4f80619f 100644 --- a/src/nn/visitors.rs +++ b/src/nn/visitors.rs @@ -45,34 +45,33 @@ impl TensorOptions { } } -pub trait VisitTensorRef { +pub trait VisitTensors { + type Container: TensorContainer; type Err; - fn visit( - &mut self, - full_path: String, - opts: TensorOptions, - t: &Tensor, - ) -> Result<(), Self::Err>; -} -pub trait VisitTensorMut { - type Err; fn visit( &mut self, full_path: String, opts: TensorOptions, - t: &mut Tensor, + t: ::WithModule<'_, Tensor>, ) -> Result<(), Self::Err>; } -pub trait VisitTensorMutRef { - type Err; - fn visit( - &mut self, - full_path: String, - opts: TensorOptions, - ts: (&mut Tensor, &Tensor), - ) -> Result<(), Self::Err>; +type ContainerWithModule<'a, C, M> = ::WithModule<'a, M>; + +pub trait TensorContainer: 'static { + type WithModule<'a, Mod: 'a> + where + Self: 'a; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field; } pub trait TensorCollection: Sized { @@ -124,102 +123,17 @@ pub(crate) struct RecursiveWalker<'a, M, F> { pub(crate) path: &'a mut Vec, } -impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorRef> TensorVisitor - for RecursiveWalker<'a, &'a M, F> -{ - type Err = F::Err; - fn visit_module( - &mut self, - mut get_refs: GetRef, - _: GetMut, - name: &str, - ) -> Result<(), Self::Err> - where - GetRef: FnMut(&M) -> &Field, - GetMut: FnMut(&mut M) -> &mut Field, - Field: TensorCollection, - { - self.path.push(name.into()); - let mut walker = RecursiveWalker { - m: get_refs(self.m), - f: self.f, - path: self.path, - }; - Field::iter_tensors(&mut walker)?; - self.path.pop(); - Ok(()) - } - fn visit_tensor( - &mut self, - mut get_refs: GetRef, - _: GetMut, - name: &str, - opts: TensorOptions, - ) -> Result<(), F::Err> - where - GetRef: FnMut(&M) -> &Tensor, - GetMut: FnMut(&mut M) -> &mut Tensor, - { - self.path.push(name.into()); - self.f.visit(self.path.join("."), opts, get_refs(self.m))?; - self.path.pop(); - Ok(()) - } -} - -impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMut> TensorVisitor - for RecursiveWalker<'a, &'a mut M, F> +impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors> TensorVisitor + for RecursiveWalker<'a, ContainerWithModule<'a, F::Container, M>, F> { type Err = F::Err; - fn visit_module( - &mut self, - _: GetRef, - mut get_muts: GetMut, - name: &str, - ) -> Result<(), F::Err> - where - GetRef: FnMut(&M) -> &Field, - GetMut: FnMut(&mut M) -> &mut Field, - Field: TensorCollection, - { - self.path.push(name.into()); - let mut walker = RecursiveWalker { - m: get_muts(self.m), - f: self.f, - path: self.path, - }; - Field::iter_tensors(&mut walker)?; - self.path.pop(); - Ok(()) - } - fn visit_tensor( - &mut self, - _: GetRef, - mut get_muts: GetMut, - name: &str, - opts: TensorOptions, - ) -> Result<(), F::Err> - where - GetRef: FnMut(&M) -> &Tensor, - GetMut: FnMut(&mut M) -> &mut Tensor, - { - self.path.push(name.into()); - self.f.visit(self.path.join("."), opts, get_muts(self.m))?; - self.path.pop(); - Ok(()) - } -} -impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef> TensorVisitor - for RecursiveWalker<'a, (&'a mut M, &'a M), F> -{ - type Err = F::Err; fn 
visit_module( &mut self, mut get_refs: GetRef, mut get_muts: GetMut, name: &str, - ) -> Result<(), F::Err> + ) -> Result<(), Self::Err> where GetRef: FnMut(&M) -> &Field, GetMut: FnMut(&mut M) -> &mut Field, @@ -227,28 +141,33 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensorMutRef> TensorVisito { self.path.push(name.into()); let mut walker = RecursiveWalker { - m: (get_muts(self.m.0), get_refs(self.m.1)), + m: F::Container::get_field(&mut self.m, &mut get_refs, &mut get_muts), f: self.f, path: self.path, }; Field::iter_tensors(&mut walker)?; + std::mem::drop(walker); self.path.pop(); Ok(()) } + fn visit_tensor( &mut self, mut get_refs: GetRef, mut get_muts: GetMut, name: &str, opts: TensorOptions, - ) -> Result<(), F::Err> + ) -> Result<(), Self::Err> where GetRef: FnMut(&M) -> &Tensor, GetMut: FnMut(&mut M) -> &mut Tensor, { self.path.push(name.into()); - let tensors = (get_muts(self.m.0), get_refs(self.m.1)); - self.f.visit(self.path.join("."), opts, tensors)?; + self.f.visit( + self.path.join("."), + opts, + F::Container::get_field(&mut self.m, &mut get_refs, &mut get_muts), + )?; self.path.pop(); Ok(()) } diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index c4f41e924..ca88e6dd7 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -112,8 +112,10 @@ pub(super) trait AdamKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, E: Dtype> VisitTensorMut for Adam { +impl, E: Dtype> VisitTensors for Adam { + type Container = &'static mut (); type Err = D::Err; + fn visit( &mut self, _: alloc::string::String, diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index be9f9a02c..6128738e5 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -121,8 +121,10 @@ pub(super) trait RMSpropKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl + OneFillStorage> VisitTensorMut for RMSprop { +impl + OneFillStorage> VisitTensors for RMSprop { + type Container = &'static mut (); type Err = D::Err; + fn visit( &mut self, _: alloc::string::String, diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index ef2e98927..9fd369da1 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -147,8 +147,10 @@ pub(super) trait SgdKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, M> VisitTensorMut for Sgd { +impl, M> VisitTensors for Sgd { + type Container = &'static mut (); type Err = D::Err; + fn visit( &mut self, _: alloc::string::String, From 9cc732283bcd0cc3e926c7312368873f1c225099 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 08:47:41 -0500 Subject: [PATCH 13/20] Add TensorMut and TensorRef tensor contains --- src/nn/impl_tensor_container.rs | 95 -------------------------------- src/nn/mod.rs | 1 - src/nn/npz.rs | 4 +- src/nn/num_params.rs | 4 +- src/nn/reset_params.rs | 4 +- src/nn/visitors.rs | 97 +++++++++++++++++++++++++++++++++ src/optim/adam/mod.rs | 2 +- src/optim/rmsprop/mod.rs | 2 +- src/optim/sgd/mod.rs | 2 +- 9 files changed, 106 insertions(+), 105 deletions(-) delete mode 100644 src/nn/impl_tensor_container.rs diff --git a/src/nn/impl_tensor_container.rs b/src/nn/impl_tensor_container.rs deleted file mode 100644 index 4e0e5f436..000000000 --- a/src/nn/impl_tensor_container.rs +++ /dev/null @@ -1,95 +0,0 @@ -use super::visitors::TensorContainer; - -impl TensorContainer for &'static () { - type WithModule<'a, Mod: 'a> = &'a Mod; - - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, - get_ref: &mut GetRef, - _get_mut: &mut GetMut, - ) -> 
Self::WithModule<'a, Field> - where - GetRef: FnMut(&Mod) -> &Field, - GetMut: FnMut(&mut Mod) -> &mut Field, - { - get_ref(*module) - } -} - -impl TensorContainer for &'static mut () { - type WithModule<'a, Mod: 'a> = &'a mut Mod; - - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, - _get_ref: &mut GetRef, - get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> - where - GetRef: FnMut(&Mod) -> &Field, - GetMut: FnMut(&mut Mod) -> &mut Field, - { - get_mut(*module) - } -} - -macro_rules! tuple_impls { - ([$($name:ident),+] [$($idx:tt),+]) => { - impl<$($name: TensorContainer),+> TensorContainer for ($($name,)+) { - type WithModule<'a, Mod: 'a> = ($($name::WithModule<'a, Mod>,)+); - - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, - get_ref: &mut GetRef, - get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> - where - GetRef: FnMut(&Mod) -> &Field, - GetMut: FnMut(&mut Mod) -> &mut Field, - { - ($($name::get_field(&mut module.$idx, get_ref, get_mut),)+) - } - } - } -} - -tuple_impls!([M1][0]); -tuple_impls!([M1, M2] [0, 1]); -tuple_impls!([M1, M2, M3] [0, 1, 2]); -tuple_impls!([M1, M2, M3, M4] [0, 1, 2, 3]); -tuple_impls!([M1, M2, M3, M4, M5] [0, 1, 2, 3, 4]); -tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5]); - -impl TensorContainer for std::vec::Vec { - type WithModule<'a, Mod: 'a> = std::vec::Vec>; - - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, - get_ref: &mut GetRef, - get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> - where - GetRef: FnMut(&Mod) -> &Field, - GetMut: FnMut(&mut Mod) -> &mut Field, - { - module - .iter_mut() - .map(|x| T::get_field(x, get_ref, get_mut)) - .collect() - } -} - -impl TensorContainer for Option { - type WithModule<'a, Mod: 'a> = Option>; - - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, - get_ref: &mut GetRef, - get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> - where - GetRef: FnMut(&Mod) -> &Field, - GetMut: FnMut(&mut Mod) -> &mut Field, - { - module.as_mut().map(|x| T::get_field(x, get_ref, get_mut)) - } -} diff --git a/src/nn/mod.rs b/src/nn/mod.rs index e8b78da1b..3dbd44250 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -117,7 +117,6 @@ mod embedding; mod flatten; mod generalized_residual; mod impl_module_for_tuples; -mod impl_tensor_container; mod layer_norm; mod linear; mod module; diff --git a/src/nn/npz.rs b/src/nn/npz.rs index f3062a7c5..6e433809d 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -115,7 +115,7 @@ impl, T: TensorCollection> LoadFrom impl> VisitTensors for zip::ZipWriter { - type Container = &'static (); + type Container = TensorRef; type Err = ZipError; fn visit( @@ -131,7 +131,7 @@ impl> VisitTensors impl> VisitTensors for zip::ZipArchive { - type Container = &'static mut (); + type Container = TensorMut; type Err = NpzError; fn visit( diff --git a/src/nn/num_params.rs b/src/nn/num_params.rs index 95610873f..766c3bf72 100644 --- a/src/nn/num_params.rs +++ b/src/nn/num_params.rs @@ -1,4 +1,4 @@ -use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensors}; +use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, TensorRef, VisitTensors}; use crate::{shapes::*, tensor::*}; @@ -6,7 +6,7 @@ use std::{string::String, vec::Vec}; struct Counter(usize); impl VisitTensors for Counter { - type Container = &'static (); + type Container = TensorRef; type Err = D::Err; fn visit( diff --git 
a/src/nn/reset_params.rs b/src/nn/reset_params.rs index db8277194..a47770452 100644 --- a/src/nn/reset_params.rs +++ b/src/nn/reset_params.rs @@ -1,4 +1,4 @@ -use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, VisitTensors}; +use super::visitors::{RecursiveWalker, TensorCollection, TensorMut, TensorOptions, VisitTensors}; use crate::{shapes::*, tensor::*}; @@ -6,7 +6,7 @@ use std::{string::String, vec::Vec}; struct Resetter; impl VisitTensors for Resetter { - type Container = &'static mut (); + type Container = TensorMut; type Err = D::Err; fn visit( diff --git a/src/nn/visitors.rs b/src/nn/visitors.rs index c4f80619f..8efffe2de 100644 --- a/src/nn/visitors.rs +++ b/src/nn/visitors.rs @@ -172,3 +172,100 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors> TensorVisitor = &'a Mod; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + _get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + get_ref(*module) + } +} + +impl TensorContainer for TensorMut { + type WithModule<'a, Mod: 'a> = &'a mut Mod; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + _get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + get_mut(*module) + } +} + +macro_rules! tuple_impls { + ([$($name:ident),+] [$($idx:tt),+]) => { + impl<$($name: TensorContainer),+> TensorContainer for ($($name,)+) { + type WithModule<'a, Mod: 'a> = ($($name::WithModule<'a, Mod>,)+); + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + ($($name::get_field(&mut module.$idx, get_ref, get_mut),)+) + } + } + } +} + +tuple_impls!([M1][0]); +tuple_impls!([M1, M2] [0, 1]); +tuple_impls!([M1, M2, M3] [0, 1, 2]); +tuple_impls!([M1, M2, M3, M4] [0, 1, 2, 3]); +tuple_impls!([M1, M2, M3, M4, M5] [0, 1, 2, 3, 4]); +tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5]); + +impl TensorContainer for std::vec::Vec { + type WithModule<'a, Mod: 'a> = std::vec::Vec>; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + module + .iter_mut() + .map(|x| T::get_field(x, get_ref, get_mut)) + .collect() + } +} + +impl TensorContainer for Option { + type WithModule<'a, Mod: 'a> = Option>; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field, + { + module.as_mut().map(|x| T::get_field(x, get_ref, get_mut)) + } +} diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index ca88e6dd7..392078669 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -113,7 +113,7 @@ pub(super) trait AdamKernel: DeviceStorage { } impl, E: Dtype> VisitTensors for Adam { - type Container = &'static mut (); + type Container = TensorMut; type Err = D::Err; fn visit( diff --git a/src/optim/rmsprop/mod.rs 
b/src/optim/rmsprop/mod.rs index 6128738e5..360c378d7 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -122,7 +122,7 @@ pub(super) trait RMSpropKernel: DeviceStorage { } impl + OneFillStorage> VisitTensors for RMSprop { - type Container = &'static mut (); + type Container = TensorMut; type Err = D::Err; fn visit( diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index 9fd369da1..0947bd208 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -148,7 +148,7 @@ pub(super) trait SgdKernel: DeviceStorage { } impl, M> VisitTensors for Sgd { - type Container = &'static mut (); + type Container = TensorMut; type Err = D::Err; fn visit( From 901ac92f52ecf84cdb4b88aa7f771e20b3f21cbc Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 08:50:04 -0500 Subject: [PATCH 14/20] Renamign visitors -> tensor_collection --- src/nn/add_into.rs | 2 +- src/nn/batchnorm2d.rs | 2 +- src/nn/conv.rs | 2 +- src/nn/embedding.rs | 2 +- src/nn/generalized_residual.rs | 2 +- src/nn/impl_module_for_tuples.rs | 2 +- src/nn/layer_norm.rs | 2 +- src/nn/linear.rs | 2 +- src/nn/mod.rs | 2 +- src/nn/module.rs | 2 +- src/nn/npz.rs | 2 +- src/nn/num_params.rs | 4 +++- src/nn/repeated.rs | 2 +- src/nn/reset_params.rs | 4 +++- src/nn/residual.rs | 2 +- src/nn/split_into.rs | 2 +- src/nn/{visitors.rs => tensor_collection.rs} | 0 src/nn/transformer/decoder.rs | 2 +- src/nn/transformer/encoder.rs | 2 +- src/nn/transformer/mha.rs | 2 +- src/nn/transformer/mod.rs | 2 +- src/optim/adam/mod.rs | 2 +- src/optim/rmsprop/mod.rs | 2 +- src/optim/sgd/mod.rs | 2 +- 24 files changed, 27 insertions(+), 23 deletions(-) rename src/nn/{visitors.rs => tensor_collection.rs} (100%) diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index 00a99c401..842a6f55c 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -1,6 +1,6 @@ use crate::{shapes::Dtype, tensor::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Add inputs together into a single tensor. 
`T` should be a tuple //// where every element of the tuple has the same output type diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index ac855223e..e9b52af74 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -1,6 +1,6 @@ use crate::{gradients::*, shapes::*, tensor::*, tensor_ops::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug, Copy, Clone, Eq, PartialEq)] diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 1291b61f5..92b703e3c 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -3,7 +3,7 @@ use rand_distr::uniform::SampleUniform; use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug)] diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index c48349ce9..e5edd8739 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -3,7 +3,7 @@ use rand_distr::{uniform::SampleUniform, Uniform}; use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; pub mod builder { #[derive(Debug)] diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index 0a1503341..4c398d8ac 100644 --- a/src/nn/generalized_residual.rs +++ b/src/nn/generalized_residual.rs @@ -1,6 +1,6 @@ use crate::{shapes::*, tensor::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// A residual connection `R` around `F`: `F(x) + R(x)`, /// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 7f7447f06..0191e0d14 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -1,6 +1,6 @@ use crate::{shapes::*, tensor::*, tensor_ops::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; macro_rules! 
tuple_impls { ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index b6013586a..6482d772b 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -1,6 +1,6 @@ use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; pub mod builder { #[derive(Debug)] diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 490b84d49..fbd9005b0 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -1,6 +1,6 @@ use crate::{gradients::Tape, shapes::*, tensor::*, tensor_ops::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, NonMutableModule, ToDevice}; use num_traits::Float; use rand_distr::{uniform::SampleUniform, Uniform}; diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 3dbd44250..5ac4b2ac4 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -106,7 +106,7 @@ mod num_params; mod reset_params; -pub mod visitors; +pub mod tensor_collection; mod activations; mod add_into; diff --git a/src/nn/module.rs b/src/nn/module.rs index f8dcc4c3c..1c30bf447 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -3,7 +3,7 @@ use crate::shapes::Dtype; pub use crate::tensor::OnCuda; pub use crate::tensor::{DeviceStorage, OnCpu, OnDevice, ToDevice}; -use super::visitors::{TensorCollection, TensorVisitor}; +use super::tensor_collection::{TensorCollection, TensorVisitor}; /// Immutable forward of `Input` that produces [Module::Output]. /// See [ModuleMut] for mutable forward. diff --git a/src/nn/npz.rs b/src/nn/npz.rs index 6e433809d..312b8c72a 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -6,7 +6,7 @@ use crate::{ }, }; -use super::visitors::*; +use super::tensor_collection::*; use std::{ io::{BufReader, BufWriter, Read, Seek, Write}, diff --git a/src/nn/num_params.rs b/src/nn/num_params.rs index 766c3bf72..7a5ac6a1e 100644 --- a/src/nn/num_params.rs +++ b/src/nn/num_params.rs @@ -1,4 +1,6 @@ -use super::visitors::{RecursiveWalker, TensorCollection, TensorOptions, TensorRef, VisitTensors}; +use super::tensor_collection::{ + RecursiveWalker, TensorCollection, TensorOptions, TensorRef, VisitTensors, +}; use crate::{shapes::*, tensor::*}; diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index a7ff28c54..885f96189 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -1,6 +1,6 @@ use crate::{shapes::Dtype, tensor::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Repeats `T` `N` times. This requires that `T`'s input is the same as it's output. 
/// diff --git a/src/nn/reset_params.rs b/src/nn/reset_params.rs index a47770452..94de9104f 100644 --- a/src/nn/reset_params.rs +++ b/src/nn/reset_params.rs @@ -1,4 +1,6 @@ -use super::visitors::{RecursiveWalker, TensorCollection, TensorMut, TensorOptions, VisitTensors}; +use super::tensor_collection::{ + RecursiveWalker, TensorCollection, TensorMut, TensorOptions, VisitTensors, +}; use crate::{shapes::*, tensor::*}; diff --git a/src/nn/residual.rs b/src/nn/residual.rs index 7b502ddcd..62f157307 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -1,6 +1,6 @@ use crate::{shapes::*, tensor::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; use std::ops::Add; diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 1ffda2a7e..26b297d39 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -1,6 +1,6 @@ use crate::{shapes::Dtype, tensor::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; /// Splits input into multiple heads. `T` should be a tuple, /// where every element of the tuple accepts the same input type. diff --git a/src/nn/visitors.rs b/src/nn/tensor_collection.rs similarity index 100% rename from src/nn/visitors.rs rename to src/nn/tensor_collection.rs diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index 99bd629ce..43a65e2e1 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -2,7 +2,7 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; use crate::{ - nn::{modules::*, visitors::*, *}, + nn::{modules::*, tensor_collection::*, *}, shapes::Dtype, tensor::{PutTape, SplitTape}, tensor_ops::Device, diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index 65bb0d5fa..1227e2be7 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -2,7 +2,7 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; use crate::{ - nn::{modules::*, visitors::*, *}, + nn::{modules::*, tensor_collection::*, *}, shapes::Dtype, tensor::{PutTape, SplitTape}, tensor_ops::Device, diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index c2a5d03dc..c0451cf8f 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -2,7 +2,7 @@ use num_traits::Float; use rand_distr::uniform::SampleUniform; use crate::{ - nn::{modules::*, visitors::*, *}, + nn::{modules::*, tensor_collection::*, *}, shapes::Dtype, tensor::*, tensor_ops::*, diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index 9da5ea198..a5b4eaa27 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -11,7 +11,7 @@ use rand_distr::uniform::SampleUniform; use crate::{shapes::*, tensor::*, tensor_ops::*}; -use super::{visitors::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; +use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; pub mod builder { #[derive(Debug, Clone)] diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 392078669..7ddd01eaf 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, - nn::visitors::*, + nn::tensor_collection::*, shapes::{Dtype, Shape}, tensor::DeviceStorage, }; diff --git a/src/optim/rmsprop/mod.rs 
b/src/optim/rmsprop/mod.rs index 360c378d7..46ee3526e 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, - nn::visitors::*, + nn::tensor_collection::*, shapes::{Dtype, Shape}, tensor::*, }; diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index 0947bd208..18b92ff85 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, - nn::visitors::*, + nn::tensor_collection::*, shapes::{Dtype, Shape}, tensor::{DeviceStorage, Tensor}, }; From 1deb94480fb516df7b9056333840cac85365ebd7 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 09:01:16 -0500 Subject: [PATCH 15/20] Reorg tensor collection --- src/nn/tensor_collection/collection.rs | 85 ++++++++++++ src/nn/tensor_collection/mod.rs | 5 + .../visitor.rs} | 125 +++--------------- 3 files changed, 112 insertions(+), 103 deletions(-) create mode 100644 src/nn/tensor_collection/collection.rs create mode 100644 src/nn/tensor_collection/mod.rs rename src/nn/{tensor_collection.rs => tensor_collection/visitor.rs} (68%) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs new file mode 100644 index 000000000..3f3e0e7cc --- /dev/null +++ b/src/nn/tensor_collection/collection.rs @@ -0,0 +1,85 @@ +use crate::{ + shapes::{Dtype, Shape}, + tensor::{DeviceStorage, OneFillStorage, Tensor, ZeroFillStorage}, +}; + +pub trait TensorCollection: Sized { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; +} + +pub trait TensorVisitor: Sized { + type Err; + fn visit_module( + &mut self, + get_refs: GetRef, + get_muts: GetMut, + name: &str, + ) -> Result<(), Self::Err> + where + GetRef: FnMut(&T) -> &Field, + GetMut: FnMut(&mut T) -> &mut Field, + Field: TensorCollection; + + fn visit_tensor( + &mut self, + get_refs: GetRef, + get_muts: GetMut, + name: &str, + opts: TensorOptions, + ) -> Result<(), Self::Err> + where + GetRef: FnMut(&T) -> &Tensor, + GetMut: FnMut(&mut T) -> &mut Tensor; +} + +impl TensorCollection for Tensor { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + visitor.visit_tensor( + |s| s, + |s| s, + "", + TensorOptions { + update: true, + reset: |_| Ok(()), + }, + ) + } +} + +pub struct TensorOptions { + pub update: bool, + pub reset: fn(&mut Tensor) -> Result<(), D::Err>, +} + +impl TensorOptions { + pub fn reset_to_zeros() -> Self + where + D: ZeroFillStorage, + { + TensorOptions { + update: true, + reset: |t| t.try_fill_with_zeros(), + } + } + pub fn reset_to_ones() -> Self + where + D: OneFillStorage, + { + TensorOptions { + update: true, + reset: |t| t.try_fill_with_ones(), + } + } + pub fn reset_with(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { + TensorOptions { + update: true, + reset, + } + } + pub fn detached(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { + TensorOptions { + update: false, + reset, + } + } +} diff --git a/src/nn/tensor_collection/mod.rs b/src/nn/tensor_collection/mod.rs new file mode 100644 index 000000000..8a427f8f4 --- /dev/null +++ b/src/nn/tensor_collection/mod.rs @@ -0,0 +1,5 @@ +mod collection; +mod visitor; + +pub use collection::{TensorCollection, TensorOptions, TensorVisitor}; +pub use visitor::{RecursiveWalker, TensorContainer, TensorMut, TensorRef, VisitTensors}; diff --git a/src/nn/tensor_collection.rs b/src/nn/tensor_collection/visitor.rs similarity index 68% rename from src/nn/tensor_collection.rs rename to 
src/nn/tensor_collection/visitor.rs index 8efffe2de..957e89fa3 100644 --- a/src/nn/tensor_collection.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -1,48 +1,16 @@ -#![allow(clippy::type_complexity)] - use crate::{ shapes::{Dtype, Shape}, - tensor::{DeviceStorage, OneFillStorage, Tensor, ZeroFillStorage}, + tensor::{DeviceStorage, Tensor}, }; +use super::collection::{TensorCollection, TensorOptions, TensorVisitor}; + use std::{string::String, vec::Vec}; -pub struct TensorOptions { - pub update: bool, - pub reset: fn(&mut Tensor) -> Result<(), D::Err>, -} - -impl TensorOptions { - pub fn reset_to_zeros() -> Self - where - D: ZeroFillStorage, - { - TensorOptions { - update: true, - reset: |t| t.try_fill_with_zeros(), - } - } - pub fn reset_to_ones() -> Self - where - D: OneFillStorage, - { - TensorOptions { - update: true, - reset: |t| t.try_fill_with_ones(), - } - } - pub fn reset_with(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { - TensorOptions { - update: true, - reset, - } - } - pub fn detached(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { - TensorOptions { - update: false, - reset, - } - } +pub struct RecursiveWalker<'a, M, F> { + pub m: M, + pub f: &'a mut F, + pub path: &'a mut Vec, } pub trait VisitTensors { @@ -59,70 +27,6 @@ pub trait VisitTensors { type ContainerWithModule<'a, C, M> = ::WithModule<'a, M>; -pub trait TensorContainer: 'static { - type WithModule<'a, Mod: 'a> - where - Self: 'a; - - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, - get_ref: &mut GetRef, - get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> - where - GetRef: FnMut(&Mod) -> &Field, - GetMut: FnMut(&mut Mod) -> &mut Field; -} - -pub trait TensorCollection: Sized { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; -} - -impl TensorCollection for Tensor { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_tensor( - |s| s, - |s| s, - "", - TensorOptions { - update: true, - reset: |_| Ok(()), - }, - ) - } -} - -pub trait TensorVisitor: Sized { - type Err; - fn visit_module( - &mut self, - get_refs: GetRef, - get_muts: GetMut, - name: &str, - ) -> Result<(), Self::Err> - where - GetRef: FnMut(&T) -> &Field, - GetMut: FnMut(&mut T) -> &mut Field, - Field: TensorCollection; - - fn visit_tensor( - &mut self, - get_refs: GetRef, - get_muts: GetMut, - name: &str, - opts: TensorOptions, - ) -> Result<(), Self::Err> - where - GetRef: FnMut(&T) -> &Tensor, - GetMut: FnMut(&mut T) -> &mut Tensor; -} - -pub(crate) struct RecursiveWalker<'a, M, F> { - pub(crate) m: M, - pub(crate) f: &'a mut F, - pub(crate) path: &'a mut Vec, -} - impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors> TensorVisitor for RecursiveWalker<'a, ContainerWithModule<'a, F::Container, M>, F> { @@ -173,6 +77,21 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors> TensorVisitor + where + Self: 'a; + + fn get_field<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::WithModule<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::WithModule<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field; +} + pub enum TensorRef {} pub enum TensorMut {} From 18a56d432d58892e0f8a911e791dbfff06a20295 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 09:36:39 -0500 Subject: [PATCH 16/20] Renaming --- src/nn/add_into.rs | 2 +- src/nn/batchnorm2d.rs | 2 +- src/nn/conv.rs | 2 +- src/nn/embedding.rs | 2 +- src/nn/generalized_residual.rs | 2 +- src/nn/impl_module_for_tuples.rs | 2 
+- src/nn/layer_norm.rs | 2 +- src/nn/linear.rs | 2 +- src/nn/module.rs | 4 +- src/nn/npz.rs | 8 +- src/nn/num_params.rs | 8 +- src/nn/repeated.rs | 2 +- src/nn/reset_params.rs | 6 +- src/nn/residual.rs | 2 +- src/nn/split_into.rs | 2 +- src/nn/tensor_collection/collection.rs | 26 ++++-- src/nn/tensor_collection/mod.rs | 7 +- src/nn/tensor_collection/visitor.rs | 122 +++++++++++++------------ src/nn/transformer/decoder.rs | 4 +- src/nn/transformer/encoder.rs | 2 +- src/nn/transformer/mha.rs | 2 +- src/nn/transformer/mod.rs | 2 +- src/optim/adam/mod.rs | 6 +- src/optim/rmsprop/mod.rs | 6 +- src/optim/sgd/mod.rs | 6 +- 25 files changed, 125 insertions(+), 106 deletions(-) diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index 842a6f55c..63a01a1fc 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -34,7 +34,7 @@ impl, D: DeviceStorage, E: Dtype> BuildModule for Add } impl> TensorCollection for AddInto { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index e9b52af74..dc2dda7d3 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -184,7 +184,7 @@ impl> BuildModule for BatchNorm2D> TensorCollection for BatchNorm2D { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.scale, |s| &mut s.scale, diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 92b703e3c..295e154b9 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -59,7 +59,7 @@ where E: Dtype + Float + SampleUniform, D: Device, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index e5edd8739..cf0045ef5 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -70,7 +70,7 @@ impl> TensorCollection for Embedding { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index 4c398d8ac..df6ef62b7 100644 --- a/src/nn/generalized_residual.rs +++ b/src/nn/generalized_residual.rs @@ -45,7 +45,7 @@ impl, R: BuildModule> Bui impl, R: TensorCollection> TensorCollection for GeneralizedResidual { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.f, |s| &mut s.f, "f")?; visitor.visit_module(|s| &s.r, |s| &mut s.r, "r") } diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 0191e0d14..44e47605e 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -5,7 +5,7 @@ use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, macro_rules! 
tuple_impls { ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { impl),+> TensorCollection for ($($name,)+) { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { $(visitor.visit_module(|s| &s.$idx, |s| &mut s.$idx, &std::format!("{}", $idx))?;)+ Ok(()) } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index 6482d772b..a615970b9 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -56,7 +56,7 @@ impl> BuildModule for LayerNorm1D> TensorCollection for LayerNorm1D { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.gamma, |s| &mut s.gamma, diff --git a/src/nn/linear.rs b/src/nn/linear.rs index fbd9005b0..69798f7af 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -71,7 +71,7 @@ impl> TensorCollection for Linear { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| &s.weight, |s| &mut s.weight, diff --git a/src/nn/module.rs b/src/nn/module.rs index 1c30bf447..8b1f0f8e5 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -3,7 +3,7 @@ use crate::shapes::Dtype; pub use crate::tensor::OnCuda; pub use crate::tensor::{DeviceStorage, OnCpu, OnDevice, ToDevice}; -use super::tensor_collection::{TensorCollection, TensorVisitor}; +use super::tensor_collection::{ModuleVisitor, TensorCollection}; /// Immutable forward of `Input` that produces [Module::Output]. /// See [ModuleMut] for mutable forward. @@ -80,7 +80,7 @@ impl BuildModule for T { } impl TensorCollection for T { - fn iter_tensors>(_: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(_: &mut V) -> Result<(), V::Err> { Ok(()) } } diff --git a/src/nn/npz.rs b/src/nn/npz.rs index 312b8c72a..8114da3f8 100644 --- a/src/nn/npz.rs +++ b/src/nn/npz.rs @@ -112,10 +112,10 @@ pub trait LoadFromNpz>: TensorCollection< } impl, T: TensorCollection> LoadFromNpz for T {} -impl> VisitTensors +impl> TensorVisitor for zip::ZipWriter { - type Container = TensorRef; + type Viewer = ViewTensorRef; type Err = ZipError; fn visit( @@ -128,10 +128,10 @@ impl> VisitTensors } } -impl> VisitTensors +impl> TensorVisitor for zip::ZipArchive { - type Container = TensorMut; + type Viewer = ViewTensorMut; type Err = NpzError; fn visit( diff --git a/src/nn/num_params.rs b/src/nn/num_params.rs index 7a5ac6a1e..7dc2df403 100644 --- a/src/nn/num_params.rs +++ b/src/nn/num_params.rs @@ -1,5 +1,5 @@ use super::tensor_collection::{ - RecursiveWalker, TensorCollection, TensorOptions, TensorRef, VisitTensors, + RecursiveWalker, TensorCollection, TensorOptions, TensorVisitor, ViewTensorRef, }; use crate::{shapes::*, tensor::*}; @@ -7,8 +7,8 @@ use crate::{shapes::*, tensor::*}; use std::{string::String, vec::Vec}; struct Counter(usize); -impl VisitTensors for Counter { - type Container = TensorRef; +impl TensorVisitor for Counter { + type Viewer = ViewTensorRef; type Err = D::Err; fn visit( @@ -17,7 +17,7 @@ impl VisitTensors for Counter { opts: TensorOptions, t: &Tensor, ) -> Result<(), D::Err> { - if opts.update { + if opts.do_gradient_update { self.0 += t.shape().num_elements(); } Ok(()) diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index 885f96189..88f1123fb 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -42,7 +42,7 @@ impl, const N: usize> BuildModu impl, const N: usize> TensorCollection for Repeated { - fn iter_tensors>(visitor: &mut V) -> Result<(), 
V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { for i in 0..N { visitor.visit_module( |s| &s.modules[i], diff --git a/src/nn/reset_params.rs b/src/nn/reset_params.rs index 94de9104f..d4ae3a628 100644 --- a/src/nn/reset_params.rs +++ b/src/nn/reset_params.rs @@ -1,5 +1,5 @@ use super::tensor_collection::{ - RecursiveWalker, TensorCollection, TensorMut, TensorOptions, VisitTensors, + RecursiveWalker, TensorCollection, TensorOptions, TensorVisitor, ViewTensorMut, }; use crate::{shapes::*, tensor::*}; @@ -7,8 +7,8 @@ use crate::{shapes::*, tensor::*}; use std::{string::String, vec::Vec}; struct Resetter; -impl VisitTensors for Resetter { - type Container = TensorMut; +impl TensorVisitor for Resetter { + type Viewer = ViewTensorMut; type Err = D::Err; fn visit( diff --git a/src/nn/residual.rs b/src/nn/residual.rs index 62f157307..df48fdb49 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -34,7 +34,7 @@ impl> BuildModule for Res } impl> TensorCollection for Residual { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 26b297d39..070ee18a8 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -35,7 +35,7 @@ impl, D: DeviceStorage, E: Dtype> BuildModule for Spl impl> TensorCollection for SplitInto { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index 3f3e0e7cc..08661ade3 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -3,12 +3,17 @@ use crate::{ tensor::{DeviceStorage, OneFillStorage, Tensor, ZeroFillStorage}, }; +/// A collection of named tensors. Implementing this trait will enable anything +/// that operates pub trait TensorCollection: Sized { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; } -pub trait TensorVisitor: Sized { +/// An object that can visit [TensorCollection]s and [Tensor]s recursively. 
+pub trait ModuleVisitor: Sized { type Err; + + /// Visit a [TensorCollection] fn visit_module( &mut self, get_refs: GetRef, @@ -20,6 +25,7 @@ pub trait TensorVisitor: Sized { GetMut: FnMut(&mut T) -> &mut Field, Field: TensorCollection; + /// Visits an actual named [Tensor] fn visit_tensor( &mut self, get_refs: GetRef, @@ -33,21 +39,23 @@ pub trait TensorVisitor: Sized { } impl TensorCollection for Tensor { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( |s| s, |s| s, "", TensorOptions { - update: true, + do_gradient_update: true, reset: |_| Ok(()), }, ) } } +/// Options to change behavior of [TensorVisitor] pub struct TensorOptions { - pub update: bool, + /// Whether the tensor should be updated with gradients + pub do_gradient_update: bool, pub reset: fn(&mut Tensor) -> Result<(), D::Err>, } @@ -57,7 +65,7 @@ impl TensorOptions { D: ZeroFillStorage, { TensorOptions { - update: true, + do_gradient_update: true, reset: |t| t.try_fill_with_zeros(), } } @@ -66,19 +74,19 @@ impl TensorOptions { D: OneFillStorage, { TensorOptions { - update: true, + do_gradient_update: true, reset: |t| t.try_fill_with_ones(), } } pub fn reset_with(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { TensorOptions { - update: true, + do_gradient_update: true, reset, } } pub fn detached(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { TensorOptions { - update: false, + do_gradient_update: false, reset, } } diff --git a/src/nn/tensor_collection/mod.rs b/src/nn/tensor_collection/mod.rs index 8a427f8f4..0db081fda 100644 --- a/src/nn/tensor_collection/mod.rs +++ b/src/nn/tensor_collection/mod.rs @@ -1,5 +1,8 @@ +//! Traits to define a [TensorCollection] and how to iterate them using [ModuleVisitor]. +//! Use [RecursiveWalker] to do the iteration. + mod collection; mod visitor; -pub use collection::{TensorCollection, TensorOptions, TensorVisitor}; -pub use visitor::{RecursiveWalker, TensorContainer, TensorMut, TensorRef, VisitTensors}; +pub use collection::{ModuleVisitor, TensorCollection, TensorOptions}; +pub use visitor::{RecursiveWalker, TensorViewer, TensorVisitor, ViewTensorMut, ViewTensorRef}; diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 957e89fa3..3d6da38e6 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -3,32 +3,58 @@ use crate::{ tensor::{DeviceStorage, Tensor}, }; -use super::collection::{TensorCollection, TensorOptions, TensorVisitor}; +use super::collection::{ModuleVisitor, TensorCollection, TensorOptions}; use std::{string::String, vec::Vec}; +/// A standard [ModuleVisitor] that executes `F` on every [Tensor] encountered. +/// `F` must implement [TensorVisitor] pub struct RecursiveWalker<'a, M, F> { pub m: M, pub f: &'a mut F, pub path: &'a mut Vec, } -pub trait VisitTensors { - type Container: TensorContainer; +/// Something that can visit [Tensor]s. Used in conjunction with [RecursiveWalker]. +pub trait TensorVisitor { + /// The type of tensor this struct uses. E.g. [TensorMut], or [TensorRef] + type Viewer: TensorViewer; type Err; fn visit( &mut self, full_path: String, opts: TensorOptions, - t: ::WithModule<'_, Tensor>, + t: ::View<'_, Tensor>, ) -> Result<(), Self::Err>; } -type ContainerWithModule<'a, C, M> = ::WithModule<'a, M>; +/// Something that can view [Tensor]s in different ways. For example +/// [ViewTensorRef] can view `&Tensor`, and [ViewTensorMut] can view `&mut Tensor. 
+pub trait TensorViewer: 'static { + type View<'a, Mod: 'a> + where + Self: 'a; + + /// Return the view of the tensor + fn view<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::View<'_, Mod>, + get_ref: &mut GetRef, + get_mut: &mut GetMut, + ) -> Self::View<'a, Field> + where + GetRef: FnMut(&Mod) -> &Field, + GetMut: FnMut(&mut Mod) -> &mut Field; +} + +/// A [TensorViewer] that represents a `&Tensor` +pub enum ViewTensorRef {} + +/// A [TensorViewer] that represents a `&mut Tensor` +pub enum ViewTensorMut {} -impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors> TensorVisitor - for RecursiveWalker<'a, ContainerWithModule<'a, F::Container, M>, F> +impl<'a, M, E: Dtype, D: DeviceStorage, F: TensorVisitor> ModuleVisitor + for RecursiveWalker<'a, ::View<'a, M>, F> { type Err = F::Err; @@ -45,7 +71,7 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: VisitTensors> TensorVisitor> TensorVisitor - where - Self: 'a; - - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, - get_ref: &mut GetRef, - get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> - where - GetRef: FnMut(&Mod) -> &Field, - GetMut: FnMut(&mut Mod) -> &mut Field; -} - -pub enum TensorRef {} -pub enum TensorMut {} - -impl TensorContainer for TensorRef { - type WithModule<'a, Mod: 'a> = &'a Mod; +impl TensorViewer for ViewTensorRef { + type View<'a, Mod: 'a> = &'a Mod; - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, + fn view<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, _get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> + ) -> Self::View<'a, Field> where GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { - get_ref(*module) + get_ref(module) } } -impl TensorContainer for TensorMut { - type WithModule<'a, Mod: 'a> = &'a mut Mod; +impl TensorViewer for ViewTensorMut { + type View<'a, Mod: 'a> = &'a mut Mod; - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, + fn view<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::View<'_, Mod>, _get_ref: &mut GetRef, get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> + ) -> Self::View<'a, Field> where GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { - get_mut(*module) + get_mut(module) } } macro_rules! 
tuple_impls { ([$($name:ident),+] [$($idx:tt),+]) => { - impl<$($name: TensorContainer),+> TensorContainer for ($($name,)+) { - type WithModule<'a, Mod: 'a> = ($($name::WithModule<'a, Mod>,)+); + impl<$($name: TensorViewer),+> TensorViewer for ($($name,)+) { + type View<'a, Mod: 'a> = ($($name::View<'a, Mod>,)+); - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, + fn view<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> + ) -> Self::View<'a, Field> where GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { - ($($name::get_field(&mut module.$idx, get_ref, get_mut),)+) + ($($name::view(&mut module.$idx, get_ref, get_mut),)+) } } } @@ -154,37 +162,37 @@ tuple_impls!([M1, M2, M3, M4] [0, 1, 2, 3]); tuple_impls!([M1, M2, M3, M4, M5] [0, 1, 2, 3, 4]); tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5]); -impl TensorContainer for std::vec::Vec { - type WithModule<'a, Mod: 'a> = std::vec::Vec>; +impl TensorViewer for std::vec::Vec { + type View<'a, Mod: 'a> = std::vec::Vec>; - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, + fn view<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> + ) -> Self::View<'a, Field> where GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { module .iter_mut() - .map(|x| T::get_field(x, get_ref, get_mut)) + .map(|x| T::view(x, get_ref, get_mut)) .collect() } } -impl TensorContainer for Option { - type WithModule<'a, Mod: 'a> = Option>; +impl TensorViewer for Option { + type View<'a, Mod: 'a> = Option>; - fn get_field<'a, Mod, Field, GetRef, GetMut>( - module: &'a mut Self::WithModule<'_, Mod>, + fn view<'a, Mod, Field, GetRef, GetMut>( + module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, get_mut: &mut GetMut, - ) -> Self::WithModule<'a, Field> + ) -> Self::View<'a, Field> where GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { - module.as_mut().map(|x| T::get_field(x, get_ref, get_mut)) + module.as_mut().map(|x| T::view(x, get_ref, get_mut)) } } diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index 43a65e2e1..e3affa725 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -83,7 +83,7 @@ impl>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") } } @@ -186,7 +186,7 @@ impl> TensorColl where E: Dtype + Float + SampleUniform, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; visitor.visit_module(|s| &s.mh_attn, |s| &mut s.mh_attn, "mh_attn")?; diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index 1227e2be7..b6ad3ee1a 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -118,7 +118,7 @@ impl> TensorColl where E: Dtype + Float + SampleUniform, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; 
visitor.visit_module(|s| &s.ff, |s| &mut s.ff, "ff")?; diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index c0451cf8f..d554845f8 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -82,7 +82,7 @@ impl>(visitor: &mut W) -> Result<(), W::Err> { + fn iter_tensors>(visitor: &mut W) -> Result<(), W::Err> { visitor.visit_module(|s| &s.w_q, |s| &mut s.w_q, "w_q")?; visitor.visit_module(|s| &s.w_k, |s| &mut s.w_k, "w_k")?; visitor.visit_module(|s| &s.w_v, |s| &mut s.w_v, "w_v")?; diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index a5b4eaa27..462d80f62 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -98,7 +98,7 @@ where E: Dtype + Float + SampleUniform, D: Device, { - fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { + fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_module(|s| &s.encoder, |s| &mut s.encoder, "encoder")?; visitor.visit_module(|s| &s.decoder, |s| &mut s.decoder, "decoder") } diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 7ddd01eaf..04e69a107 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -112,8 +112,8 @@ pub(super) trait AdamKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, E: Dtype> VisitTensors for Adam { - type Container = TensorMut; +impl, E: Dtype> TensorVisitor for Adam { + type Viewer = ViewTensorMut; type Err = D::Err; fn visit( @@ -122,7 +122,7 @@ impl, E: Dtype> VisitTensors for Adam { opts: TensorOptions, p: &mut crate::prelude::Tensor, ) -> Result<(), ::Err> { - if !opts.update { + if !opts.do_gradient_update { return Ok(()); } let g = self.gradients.remove(p); diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index 46ee3526e..8c56b9009 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -121,8 +121,8 @@ pub(super) trait RMSpropKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl + OneFillStorage> VisitTensors for RMSprop { - type Container = TensorMut; +impl + OneFillStorage> TensorVisitor for RMSprop { + type Viewer = ViewTensorMut; type Err = D::Err; fn visit( @@ -131,7 +131,7 @@ impl + OneFillStorage> VisitTensors fo opts: TensorOptions, p: &mut Tensor, ) -> Result<(), ::Err> { - if !opts.update { + if !opts.do_gradient_update { return Ok(()); } let g = self.gradients.remove(p); diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index 18b92ff85..eb8a1acbd 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -147,8 +147,8 @@ pub(super) trait SgdKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, M> VisitTensors for Sgd { - type Container = TensorMut; +impl, M> TensorVisitor for Sgd { + type Viewer = ViewTensorMut; type Err = D::Err; fn visit( @@ -157,7 +157,7 @@ impl, M> VisitTensors for Sgd { opts: TensorOptions, p: &mut Tensor, ) -> Result<(), D::Err> { - if !opts.update { + if !opts.do_gradient_update { return Ok(()); } let g = self.gradients.remove(p); From 06920985d8636e258fde4e3a265fc7d066a0377d Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 09:45:13 -0500 Subject: [PATCH 17/20] Moving name to first arg --- src/nn/add_into.rs | 2 +- src/nn/batchnorm2d.rs | 8 ++++---- src/nn/conv.rs | 4 ++-- src/nn/embedding.rs | 2 +- src/nn/generalized_residual.rs | 4 ++-- src/nn/impl_module_for_tuples.rs | 2 +- src/nn/layer_norm.rs | 4 ++-- src/nn/linear.rs | 4 ++-- src/nn/repeated.rs | 2 +- src/nn/residual.rs | 2 +- src/nn/split_into.rs | 2 +- src/nn/tensor_collection/collection.rs | 9 +++++---- 
src/nn/tensor_collection/visitor.rs | 4 ++-- src/nn/transformer/decoder.rs | 14 +++++++------- src/nn/transformer/encoder.rs | 8 ++++---- src/nn/transformer/mha.rs | 8 ++++---- src/nn/transformer/mod.rs | 4 ++-- 17 files changed, 42 insertions(+), 41 deletions(-) diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index 63a01a1fc..00d2034d8 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -35,7 +35,7 @@ impl, D: DeviceStorage, E: Dtype> BuildModule for Add impl> TensorCollection for AddInto { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") + visitor.visit_module("0", |s| &s.0, |s| &mut s.0) } } diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index dc2dda7d3..3456513cf 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -186,27 +186,27 @@ impl> BuildModule for BatchNorm2D> TensorCollection for BatchNorm2D { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( + "scale", |s| &s.scale, |s| &mut s.scale, - "scale", TensorOptions::reset_to_ones(), )?; visitor.visit_tensor( + "bias", |s| &s.bias, |s| &mut s.bias, - "bias", TensorOptions::reset_to_zeros(), )?; visitor.visit_tensor( + "running_mean", |s| &s.running_mean, |s| &mut s.running_mean, - "running_mean", TensorOptions::detached(|t| t.try_fill_with_zeros()), )?; visitor.visit_tensor( + "running_var", |s| &s.running_var, |s| &mut s.running_var, - "running_var", TensorOptions::detached(|t| t.try_fill_with_ones()), ) } diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 295e154b9..06924daf8 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -61,18 +61,18 @@ where { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( + "weight", |s| &s.weight, |s| &mut s.weight, - "weight", TensorOptions::reset_with(|t| { let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) }), )?; visitor.visit_tensor( + "bias", |s| &s.bias, |s| &mut s.bias, - "bias", TensorOptions::reset_with(|t| { let b = E::ONE / E::from_usize(I * K * K).unwrap().sqrt(); t.try_fill_with_distr(rand_distr::Uniform::new(-b, b)) diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index cf0045ef5..fd372b5fb 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -72,9 +72,9 @@ impl>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( + "weight", |s| &s.weight, |s| &mut s.weight, - "weight", TensorOptions::reset_with(|t| { let b: E = E::ONE / E::from_usize(C).unwrap().sqrt(); t.try_fill_with_distr(Uniform::new(-b, b)) diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index df6ef62b7..1177b5100 100644 --- a/src/nn/generalized_residual.rs +++ b/src/nn/generalized_residual.rs @@ -46,8 +46,8 @@ impl, R: TensorCollection< TensorCollection for GeneralizedResidual { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.f, |s| &mut s.f, "f")?; - visitor.visit_module(|s| &s.r, |s| &mut s.r, "r") + visitor.visit_module("f", |s| &s.f, |s| &mut s.f)?; + visitor.visit_module("r", |s| &s.r, |s| &mut s.r) } } diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 44e47605e..4f03842fe 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -6,7 +6,7 @@ macro_rules! 
tuple_impls { ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { impl),+> TensorCollection for ($($name,)+) { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - $(visitor.visit_module(|s| &s.$idx, |s| &mut s.$idx, &std::format!("{}", $idx))?;)+ + $(visitor.visit_module(&std::format!("{}", $idx), |s| &s.$idx, |s| &mut s.$idx)?;)+ Ok(()) } } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index a615970b9..f049745d7 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -58,15 +58,15 @@ impl> BuildModule for LayerNorm1D> TensorCollection for LayerNorm1D { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( + "gamma", |s| &s.gamma, |s| &mut s.gamma, - "gamma", TensorOptions::reset_to_ones(), )?; visitor.visit_tensor( + "beta", |s| &s.beta, |s| &mut s.beta, - "beta", TensorOptions::reset_to_zeros(), ) } diff --git a/src/nn/linear.rs b/src/nn/linear.rs index 69798f7af..df5c58aaf 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -73,18 +73,18 @@ impl>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( + "weight", |s| &s.weight, |s| &mut s.weight, - "weight", TensorOptions::reset_with(|t| { let b: E = E::ONE / E::from_usize(I).unwrap().sqrt(); t.try_fill_with_distr(Uniform::new(-b, b)) }), )?; visitor.visit_tensor( + "bias", |s| &s.bias, |s| &mut s.bias, - "bias", TensorOptions::reset_with(|t| { let b: E = E::ONE / E::from_usize(I).unwrap().sqrt(); t.try_fill_with_distr(Uniform::new(-b, b)) diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index 88f1123fb..d02c1ee15 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -45,9 +45,9 @@ impl, const N: usize> Tens fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { for i in 0..N { visitor.visit_module( + &std::format!("{i}"), |s| &s.modules[i], |s| &mut s.modules[i], - &std::format!("{i}"), )?; } Ok(()) diff --git a/src/nn/residual.rs b/src/nn/residual.rs index df48fdb49..a6de6810a 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -35,7 +35,7 @@ impl> BuildModule for Res impl> TensorCollection for Residual { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") + visitor.visit_module("0", |s| &s.0, |s| &mut s.0) } } diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 070ee18a8..5a1847a6c 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -36,7 +36,7 @@ impl> TensorCollection { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") + visitor.visit_module("0", |s| &s.0, |s| &mut s.0) } } diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index 08661ade3..81b57e06e 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -4,7 +4,8 @@ use crate::{ }; /// A collection of named tensors. Implementing this trait will enable anything -/// that operates +/// that operates on tensors, like resetting, EMA, counting number of params, +/// gradient updates, etc. 
pub trait TensorCollection: Sized { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err>; } @@ -16,9 +17,9 @@ pub trait ModuleVisitor: Sized { /// Visit a [TensorCollection] fn visit_module( &mut self, + name: &str, get_refs: GetRef, get_muts: GetMut, - name: &str, ) -> Result<(), Self::Err> where GetRef: FnMut(&T) -> &Field, @@ -28,9 +29,9 @@ pub trait ModuleVisitor: Sized { /// Visits an actual named [Tensor] fn visit_tensor( &mut self, + name: &str, get_refs: GetRef, get_muts: GetMut, - name: &str, opts: TensorOptions, ) -> Result<(), Self::Err> where @@ -41,9 +42,9 @@ pub trait ModuleVisitor: Sized { impl TensorCollection for Tensor { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { visitor.visit_tensor( + "", |s| s, |s| s, - "", TensorOptions { do_gradient_update: true, reset: |_| Ok(()), diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 3d6da38e6..0f48a2cdb 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -60,9 +60,9 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: TensorVisitor> ModuleVisitor fn visit_module( &mut self, + name: &str, mut get_refs: GetRef, mut get_muts: GetMut, - name: &str, ) -> Result<(), Self::Err> where GetRef: FnMut(&M) -> &Field, @@ -83,9 +83,9 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: TensorVisitor> ModuleVisitor fn visit_tensor( &mut self, + name: &str, mut get_refs: GetRef, mut get_muts: GetMut, - name: &str, opts: TensorOptions, ) -> Result<(), Self::Err> where diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index e3affa725..c5faa15c6 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -84,7 +84,7 @@ where E: Dtype + Float + SampleUniform, { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.0, |s| &mut s.0, "0") + visitor.visit_module("0", |s| &s.0, |s| &mut s.0) } } @@ -186,12 +186,12 @@ where E: Dtype + Float + SampleUniform, { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; - visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; - visitor.visit_module(|s| &s.mh_attn, |s| &mut s.mh_attn, "mh_attn")?; - visitor.visit_module(|s| &s.norm2, |s| &mut s.norm2, "norm2")?; - visitor.visit_module(|s| &s.ff, |s| &mut s.ff, "ff")?; - visitor.visit_module(|s| &s.norm3, |s| &mut s.norm3, "norm3") + visitor.visit_module("self_attn", |s| &s.self_attn, |s| &mut s.self_attn)?; + visitor.visit_module("norm1", |s| &s.norm1, |s| &mut s.norm1)?; + visitor.visit_module("mh_attn", |s| &s.mh_attn, |s| &mut s.mh_attn)?; + visitor.visit_module("norm2", |s| &s.norm2, |s| &mut s.norm2)?; + visitor.visit_module("ff", |s| &s.ff, |s| &mut s.ff)?; + visitor.visit_module("norm3", |s| &s.norm3, |s| &mut s.norm3) } } diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index b6ad3ee1a..6d5cc77c4 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -119,10 +119,10 @@ where E: Dtype + Float + SampleUniform, { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.self_attn, |s| &mut s.self_attn, "self_attn")?; - visitor.visit_module(|s| &s.norm1, |s| &mut s.norm1, "norm1")?; +
visitor.visit_module("ff", |s| &s.ff, |s| &mut s.ff)?; + visitor.visit_module("norm2", |s| &s.norm2, |s| &mut s.norm2) } } diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index d554845f8..59ed289fb 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -83,10 +83,10 @@ where E: Dtype + Float + SampleUniform, { fn iter_tensors>(visitor: &mut W) -> Result<(), W::Err> { - visitor.visit_module(|s| &s.w_q, |s| &mut s.w_q, "w_q")?; - visitor.visit_module(|s| &s.w_k, |s| &mut s.w_k, "w_k")?; - visitor.visit_module(|s| &s.w_v, |s| &mut s.w_v, "w_v")?; - visitor.visit_module(|s| &s.w_o, |s| &mut s.w_o, "w_o") + visitor.visit_module("w_q", |s| &s.w_q, |s| &mut s.w_q)?; + visitor.visit_module("w_k", |s| &s.w_k, |s| &mut s.w_k)?; + visitor.visit_module("w_v", |s| &s.w_v, |s| &mut s.w_v)?; + visitor.visit_module("w_o", |s| &s.w_o, |s| &mut s.w_o) } } diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index 462d80f62..b9865dbe0 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -99,8 +99,8 @@ where D: Device, { fn iter_tensors>(visitor: &mut V) -> Result<(), V::Err> { - visitor.visit_module(|s| &s.encoder, |s| &mut s.encoder, "encoder")?; - visitor.visit_module(|s| &s.decoder, |s| &mut s.decoder, "decoder") + visitor.visit_module("encoder", |s| &s.encoder, |s| &mut s.encoder)?; + visitor.visit_module("decoder", |s| &s.decoder, |s| &mut s.decoder) } } From ae5d2742045c57bba5f5dac2f34753ba61772b10 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 09:53:01 -0500 Subject: [PATCH 18/20] Fixing clippy warnings --- src/nn/tensor_collection/collection.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index 81b57e06e..442a71e24 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -1,3 +1,5 @@ +#![allow(clippy::type_complexity)] + use crate::{ shapes::{Dtype, Shape}, tensor::{DeviceStorage, OneFillStorage, Tensor, ZeroFillStorage}, From 8e712c6ed9bf31161ef436dfa12e95afde7e1df3 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 11:52:23 -0500 Subject: [PATCH 19/20] Update src/nn/tensor_collection/visitor.rs Co-authored-by: nkoppel --- src/nn/tensor_collection/visitor.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 0f48a2cdb..50ec631c2 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -36,8 +36,8 @@ pub trait TensorViewer: 'static { where Self: 'a; - /// Return the view of the tensor - fn view<'a, Mod, Field, GetRef, GetMut>( + /// Given a view of a module, returns a view of one of that module's fields + fn view_field<'a, Mod, Field, GetRef, GetMut>( module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, get_mut: &mut GetMut, From 9e4e738a2d9e60b043168c47946b6e8653718cc1 Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Wed, 22 Feb 2023 12:02:31 -0500 Subject: [PATCH 20/20] Adding #non_exhausting to TensorOptions --- src/nn/tensor_collection/collection.rs | 12 +++++++++++- src/nn/tensor_collection/visitor.rs | 23 +++++++++++++---------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/nn/tensor_collection/collection.rs b/src/nn/tensor_collection/collection.rs index 442a71e24..6cbf584e4 100644 --- a/src/nn/tensor_collection/collection.rs +++ b/src/nn/tensor_collection/collection.rs @@ -56,13 +56,17 @@ impl 
TensorCollection for Tensor { /// Whether the tensor should be updated with gradients pub do_gradient_update: bool, - pub reset: fn(&mut Tensor) -> Result<(), D::Err>, + + /// How to reset the tensor in the future. + pub reset: fn(&'_ mut Tensor) -> Result<(), D::Err>, } impl TensorOptions { + /// A tensor that should be updated with gradients & reset to 0 pub fn reset_to_zeros() -> Self where D: ZeroFillStorage, @@ -72,6 +76,8 @@ impl TensorOptions { reset: |t| t.try_fill_with_zeros(), } } + + /// A tensor that should be updated with gradients & reset to 1 pub fn reset_to_ones() -> Self where D: OneFillStorage, @@ -81,12 +87,16 @@ impl TensorOptions { reset: |t| t.try_fill_with_ones(), } } + + /// A tensor that should be updated with gradients & reset with the fn passed in pub fn reset_with(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { TensorOptions { do_gradient_update: true, reset, } } + + /// A tensor that should **NOT** be updated with gradients & reset with the fn passed in pub fn detached(reset: fn(&mut Tensor) -> Result<(), D::Err>) -> Self { TensorOptions { do_gradient_update: false, diff --git a/src/nn/tensor_collection/visitor.rs b/src/nn/tensor_collection/visitor.rs index 50ec631c2..ad5670feb 100644 --- a/src/nn/tensor_collection/visitor.rs +++ b/src/nn/tensor_collection/visitor.rs @@ -9,6 +9,7 @@ use std::{string::String, vec::Vec}; /// A standard [ModuleVisitor] that executes `F` on every [Tensor] encountered. /// `F` must implement [TensorVisitor] +#[derive(Debug)] pub struct RecursiveWalker<'a, M, F> { pub m: M, pub f: &'a mut F, @@ -48,9 +49,11 @@ pub trait TensorViewer: 'static { } /// A [TensorViewer] that represents a `&Tensor` +#[derive(Debug)] pub enum ViewTensorRef {} /// A [TensorViewer] that represents a `&mut Tensor` +#[derive(Debug)] pub enum ViewTensorMut {} impl<'a, M, E: Dtype, D: DeviceStorage, F: TensorVisitor> ModuleVisitor @@ -71,7 +74,7 @@ impl<'a, M, E: Dtype, D: DeviceStorage, F: TensorVisitor> ModuleVisitor> ModuleVisitor> ModuleVisitor = &'a Mod; - fn view<'a, Mod, Field, GetRef, GetMut>( + fn view_field<'a, Mod, Field, GetRef, GetMut>( module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, _get_mut: &mut GetMut, @@ -122,7 +125,7 @@ impl TensorViewer for ViewTensorRef { impl TensorViewer for ViewTensorMut { type View<'a, Mod: 'a> = &'a mut Mod; - fn view<'a, Mod, Field, GetRef, GetMut>( + fn view_field<'a, Mod, Field, GetRef, GetMut>( module: &'a mut Self::View<'_, Mod>, _get_ref: &mut GetRef, get_mut: &mut GetMut, @@ -140,7 +143,7 @@ macro_rules! tuple_impls { impl<$($name: TensorViewer),+> TensorViewer for ($($name,)+) { type View<'a, Mod: 'a> = ($($name::View<'a, Mod>,)+); - fn view<'a, Mod, Field, GetRef, GetMut>( + fn view_field<'a, Mod, Field, GetRef, GetMut>( module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, get_mut: &mut GetMut, @@ -149,7 +152,7 @@ macro_rules! 
tuple_impls { GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { - ($($name::view(&mut module.$idx, get_ref, get_mut),)+) + ($($name::view_field(&mut module.$idx, get_ref, get_mut),)+) } } } @@ -165,7 +168,7 @@ tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5]); impl TensorViewer for std::vec::Vec { type View<'a, Mod: 'a> = std::vec::Vec>; - fn view<'a, Mod, Field, GetRef, GetMut>( + fn view_field<'a, Mod, Field, GetRef, GetMut>( module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, get_mut: &mut GetMut, @@ -176,7 +179,7 @@ impl TensorViewer for std::vec::Vec { { module .iter_mut() - .map(|x| T::view(x, get_ref, get_mut)) + .map(|x| T::view_field(x, get_ref, get_mut)) .collect() } } @@ -184,7 +187,7 @@ impl TensorViewer for std::vec::Vec { impl TensorViewer for Option { type View<'a, Mod: 'a> = Option>; - fn view<'a, Mod, Field, GetRef, GetMut>( + fn view_field<'a, Mod, Field, GetRef, GetMut>( module: &'a mut Self::View<'_, Mod>, get_ref: &mut GetRef, get_mut: &mut GetMut, @@ -193,6 +196,6 @@ impl TensorViewer for Option { GetRef: FnMut(&Mod) -> &Field, GetMut: FnMut(&mut Mod) -> &mut Field, { - module.as_mut().map(|x| T::view(x, get_ref, get_mut)) + module.as_mut().map(|x| T::view_field(x, get_ref, get_mut)) } }
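A note for reviewers, summarizing the API this series converges on. Below is a minimal sketch of what implementing TensorCollection for a user-defined module could look like once these patches land. The ScaledModule type, its field names, and the dfdx::nn::tensor_collection import path are illustrative assumptions for this sketch, not code from the patches; the shape of the impl mirrors the Linear, BatchNorm2D, and GeneralizedResidual changes above.

use dfdx::nn::tensor_collection::{ModuleVisitor, TensorCollection, TensorOptions};
use dfdx::shapes::{Dtype, Rank1};
use dfdx::tensor::{DeviceStorage, OneFillStorage, Tensor};

/// Hypothetical module: wraps an arbitrary sub-module `M` and adds a learned per-feature scale.
struct ScaledModule<M, const N: usize, E: Dtype, D: DeviceStorage> {
    inner: M,
    scale: Tensor<Rank1<N>, E, D>,
}

impl<M, const N: usize, E, D> TensorCollection<E, D> for ScaledModule<M, N, E, D>
where
    M: TensorCollection<E, D>,
    E: Dtype,
    D: DeviceStorage + OneFillStorage<E>,
{
    fn iter_tensors<V: ModuleVisitor<Self, E, D>>(visitor: &mut V) -> Result<(), V::Err> {
        // Sub-modules are visited recursively; the name becomes part of the path
        // that RecursiveWalker accumulates (e.g. "inner.weight").
        visitor.visit_module("inner", |s| &s.inner, |s| &mut s.inner)?;
        // Leaf tensors carry TensorOptions: this one participates in gradient
        // updates and is reset to ones by the reset_params walker.
        visitor.visit_tensor(
            "scale",
            |s| &s.scale,
            |s| &mut s.scale,
            TensorOptions::reset_to_ones(),
        )
    }
}

With that single iter_tensors definition, the same module is covered by every visitor in the series: num_params and npz saving walk it with a ViewTensorRef viewer, reset_params and the Adam/RMSprop/SGD optimizers walk it with ViewTensorMut, and tensors registered with TensorOptions::detached (like BatchNorm2D's running statistics) are skipped by the optimizers via do_gradient_update.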