diff --git a/examples/07-custom-module.rs b/examples/07-custom-module.rs index 456f18f52..0831aff98 100644 --- a/examples/07-custom-module.rs +++ b/examples/07-custom-module.rs @@ -16,6 +16,8 @@ type Device = dfdx::tensor::Cpu; #[cfg(feature = "cuda")] type Device = dfdx::tensor::Cuda; +type Err = ::Err; + /// Custom model struct /// This case is trivial and should be done with a tuple of linears and relus, /// but it demonstrates how to build models with custom behavior @@ -29,7 +31,7 @@ struct Mlp { impl BuildModule for Mlp { - fn try_build(device: &Device) -> Result::Err> { + fn try_build(device: &Device) -> Result { Ok(Self { l1: BuildModule::try_build(device)?, l2: BuildModule::try_build(device)?, @@ -43,11 +45,12 @@ impl Module { type Output = Tensor, f32, Device>; + type Error = Err; - fn forward(&self, x: Tensor, f32, Device>) -> Self::Output { - let x = self.l1.forward(x); - let x = self.relu.forward(x); - self.l2.forward(x) + fn try_forward(&self, x: Tensor, f32, Device>) -> Result { + let x = self.l1.try_forward(x)?; + let x = self.relu.try_forward(x)?; + self.l2.try_forward(x) } } @@ -61,11 +64,15 @@ impl< > Module, f32, Device, T>> for Mlp { type Output = Tensor, f32, Device, T>; + type Error = Err; - fn forward(&self, x: Tensor, f32, Device, T>) -> Self::Output { - let x = self.l1.forward(x); - let x = self.relu.forward(x); - self.l2.forward(x) + fn try_forward( + &self, + x: Tensor, f32, Device, T>, + ) -> Result { + let x = self.l1.try_forward(x)?; + let x = self.relu.try_forward(x)?; + self.l2.try_forward(x) } } diff --git a/src/nn/activations.rs b/src/nn/activations.rs index 92c0a97d5..d84702933 100644 --- a/src/nn/activations.rs +++ b/src/nn/activations.rs @@ -15,24 +15,26 @@ macro_rules! activation_impls { for $struct_name { type Output = Tensor; - fn forward(&self, input: Tensor) -> Self::Output { - $func_name(input) + type Error = D::Err; + + fn try_forward(&self, input: Tensor) -> Result { + input.$func_name() } } }; } -activation_impls!(ReLU, relu, #[doc="Unit struct that impls [Module] as calling [relu()] on `input`."]); -activation_impls!(GeLU, gelu, #[doc="Unit struct that impls [Module] as calling [gelu()] on `input`."]); -activation_impls!(Sin, sin, #[doc="Unit struct that impls [Module] as calling [sin()] on `input`."]); -activation_impls!(Cos, cos, #[doc="Unit struct that impls [Module] as calling [cos()] on `input`."]); -activation_impls!(Ln, ln, #[doc="Unit struct that impls [Module] as calling [ln()] on `input`."]); -activation_impls!(Exp, exp, #[doc="Unit struct that impls [Module] as calling [exp()] on `input`."]); -activation_impls!(Sigmoid, sigmoid, #[doc="Unit struct that impls [Module] as calling [sigmoid()] on `input`."]); -activation_impls!(Tanh, tanh, #[doc="Unit struct that impls [Module] as calling [tanh()] on `input`."]); -activation_impls!(Square, square, #[doc="Unit struct that impls [Module] as calling [square()] on `input`."]); -activation_impls!(Sqrt, sqrt, #[doc="Unit struct that impls [Module] as calling [sqrt()] on `input`."]); -activation_impls!(Abs, abs, #[doc="Unit struct that impls [Module] as calling [abs()] on `input`."]); +activation_impls!(ReLU, try_relu, #[doc="Unit struct that impls [Module] as calling [relu()] on `input`."]); +activation_impls!(GeLU, try_gelu, #[doc="Unit struct that impls [Module] as calling [gelu()] on `input`."]); +activation_impls!(Sin, try_sin, #[doc="Unit struct that impls [Module] as calling [sin()] on `input`."]); +activation_impls!(Cos, try_cos, #[doc="Unit struct that impls [Module] as calling 
[cos()] on `input`."]); +activation_impls!(Ln, try_ln, #[doc="Unit struct that impls [Module] as calling [ln()] on `input`."]); +activation_impls!(Exp, try_exp, #[doc="Unit struct that impls [Module] as calling [exp()] on `input`."]); +activation_impls!(Sigmoid, try_sigmoid, #[doc="Unit struct that impls [Module] as calling [sigmoid()] on `input`."]); +activation_impls!(Tanh, try_tanh, #[doc="Unit struct that impls [Module] as calling [tanh()] on `input`."]); +activation_impls!(Square, try_square, #[doc="Unit struct that impls [Module] as calling [square()] on `input`."]); +activation_impls!(Sqrt, try_sqrt, #[doc="Unit struct that impls [Module] as calling [sqrt()] on `input`."]); +activation_impls!(Abs, try_abs, #[doc="Unit struct that impls [Module] as calling [abs()] on `input`."]); /// Unit struct that impls [Module] as calling [softmax()] on `input`." #[derive(Default, Debug, Clone, Copy)] @@ -45,8 +47,10 @@ impl + ReduceShape, E: Dtype, D: Device Module> for Softmax { type Output = Tensor; - fn forward(&self, input: Tensor) -> Self::Output { - input.softmax::() + type Error = D::Err; + + fn try_forward(&self, input: Tensor) -> Result { + input.try_softmax::() } } diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index 00d2034d8..b4066430b 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -55,38 +55,48 @@ macro_rules! add_into_impls { ($([$Mod:tt $ModVar:tt $Inp:tt $InpVar:tt]),+) => { impl< Out: std::ops::Add, - $($Inp, )+ - $($Mod: Module<$Inp, Output = Out>, )+ - > Module<($($Inp, )+)> for AddInto<($($Mod, )+)> { + Ai, $($Inp, )+ + A: Module, + $($Mod: Module<$Inp, Output = Out, Error = A::Error>, )+ + > Module<(Ai, $($Inp, )+)> for AddInto<(A, $($Mod, )+)> + { type Output = Out; - fn forward(&self, x: ($($Inp, )+)) -> Self::Output { - let ($($ModVar, )+) = &self.0; - let ($($InpVar, )+) = x; - $(let $InpVar = $ModVar.forward($InpVar);)+ - sum!($($InpVar),*) + type Error = A::Error; + + fn try_forward(&self, x: (Ai, $($Inp, )+)) -> Result { + let (a, $($ModVar, )+) = &self.0; + let (a_i, $($InpVar, )+) = x; + let a_i = a.try_forward(a_i)?; + $(let $InpVar = $ModVar.try_forward($InpVar)?;)+ + Ok(sum!(a_i, $($InpVar),*)) } } impl< Out: std::ops::Add, - $($Inp, )+ - $($Mod: ModuleMut<$Inp, Output = Out>, )+ - > ModuleMut<($($Inp, )+)> for AddInto<($($Mod, )+)> { + Ai, $($Inp, )+ + A: ModuleMut, + $($Mod: ModuleMut<$Inp, Output = Out, Error = A::Error>, )+ + > ModuleMut<(Ai, $($Inp, )+)> for AddInto<(A, $($Mod, )+)> + { type Output = Out; - fn forward_mut(&mut self, x: ($($Inp, )+)) -> Self::Output { - let ($($ModVar, )+) = &mut self.0; - let ($($InpVar, )+) = x; - $(let $InpVar = $ModVar.forward_mut($InpVar);)+ - sum!($($InpVar),*) + type Error = A::Error; + + fn try_forward_mut(&mut self, x: (Ai, $($Inp, )+)) -> Result { + let (a, $($ModVar, )+) = &mut self.0; + let (a_i, $($InpVar, )+) = x; + let a_i = a.try_forward_mut(a_i)?; + $(let $InpVar = $ModVar.try_forward_mut($InpVar)?;)+ + Ok(sum!(a_i, $($InpVar),*)) } } }; } -add_into_impls!([A a Ai a_i], [B b Bi b_i]); -add_into_impls!([A a Ai a_i], [B b Bi b_i], [C c Ci c_i]); -add_into_impls!([A a Ai a_i], [B b Bi b_i], [C c Ci c_i], [D d Di d_i]); -add_into_impls!([A a Ai a_i], [B b Bi b_i], [C c Ci c_i], [D d Di d_i], [E e Ei e_i]); -add_into_impls!([A a Ai a_i], [B b Bi b_i], [C c Ci c_i], [D d Di d_i], [E e Ei e_i], [F f Fi f_i]); +add_into_impls!([B b Bi b_i]); +add_into_impls!([B b Bi b_i], [C c Ci c_i]); +add_into_impls!([B b Bi b_i], [C c Ci c_i], [D d Di d_i]); +add_into_impls!([B b Bi b_i], [C c Ci c_i], [D d Di 
d_i], [E e Ei e_i]); +add_into_impls!([B b Bi b_i], [C c Ci c_i], [D d Di d_i], [E e Ei e_i], [F f Fi f_i]); #[cfg(test)] mod tests { diff --git a/src/nn/batchnorm2d.rs b/src/nn/batchnorm2d.rs index 058bc18d1..f7adbabb5 100644 --- a/src/nn/batchnorm2d.rs +++ b/src/nn/batchnorm2d.rs @@ -71,27 +71,27 @@ pub struct BatchNorm2D { impl> BatchNorm2D { /// generic forward for inference - fn infer_fwd(&self, x: Tensor) -> Tensor + fn infer_fwd(&self, x: Tensor) -> Result, D::Err> where Rank1: BroadcastShapeTo, { let shape = *x.shape(); // statistics for normalizing - let std = (self.running_var.clone() + self.epsilon).sqrt(); + let std = (self.running_var.clone() + self.epsilon).try_sqrt()?; let mean = self.running_mean.clone(); // normalize & affine - let x = sub(x, mean.broadcast_like(&shape)); - let x = div(x, std.broadcast_like(&shape)); - let x = mul(x, self.scale.clone().broadcast_like(&shape)); - add(x, self.bias.clone().broadcast_like(&shape)) + let x = x.try_sub(mean.try_broadcast_like(&shape)?)?; + let x = x.try_div(std.try_broadcast_like(&shape)?)?; + let x = x.try_mul(self.scale.clone().try_broadcast_like(&shape)?)?; + x.try_add(self.bias.clone().try_broadcast_like(&shape)?) } fn train_fwd, Ax: Axes>( &mut self, x: Tensor, - ) -> Tensor + ) -> Result, D::Err> where S: HasAxes + ReduceShapeTo, Ax>, { @@ -99,29 +99,39 @@ impl> BatchNorm2D { let shape = *x.shape(); // compute statistics for updating running stats later - on tape - let mean_chan = x.retaped::().mean::, _>(); + let mean_chan = x.retaped::().try_mean::, _>()?; // update statistics since we are training - off tape - self.running_mean = self.running_mean.clone() * (E::ONE - self.momentum) - + mean_chan.retaped::() * self.momentum; + self.running_mean = self + .running_mean + .clone() + .try_mul(E::ONE - self.momentum)? + .try_add(mean_chan.retaped::().try_mul(self.momentum)?)?; - let centered = x - mean_chan.broadcast_like(&shape); + let centered = x - mean_chan.try_broadcast_like(&shape)?; let var_chan = centered.retaped::().square().mean::, _>(); // NOTE: uses unbiased variance in running estimate - self.running_var = self.running_var.clone() * (E::ONE - self.momentum) - + var_chan.retaped::() * (self.momentum * n / (n - E::ONE)); + self.running_var = self + .running_var + .clone() + .try_mul(E::ONE - self.momentum)? 
+ .try_add( + var_chan + .retaped::() + .try_mul(self.momentum * n / (n - E::ONE))?, + )?; // statistics for normalizing - on tape - let std = (var_chan + self.epsilon).sqrt(); + let std = (var_chan + self.epsilon).try_sqrt()?; // record broadcast of scale & bias - on tape - let scale = (self.scale.retaped::() / std).broadcast_like(&shape); - let bias = self.bias.retaped::().broadcast_like(&shape); + let scale = (self.scale.retaped::() / std).try_broadcast_like(&shape)?; + let bias = self.bias.retaped::().try_broadcast_like(&shape)?; // normalize & affine - on tape - centered * scale + bias + centered.try_mul(scale)?.try_add(bias) } } @@ -129,9 +139,13 @@ impl> Module, H, W), E, D, NoneTape>> for BatchNorm2D { type Output = Tensor<(Const, H, W), E, D, NoneTape>; + type Error = D::Err; /// Inference 3d forward - does **not** update [Self::running_mean] and [Self::running_var] - fn forward(&self, x: Tensor<(Const, H, W), E, D, NoneTape>) -> Self::Output { + fn try_forward( + &self, + x: Tensor<(Const, H, W), E, D, NoneTape>, + ) -> Result { self.infer_fwd(x) } } @@ -140,9 +154,13 @@ impl> Module, H, W), E, D, NoneTape>> for BatchNorm2D { type Output = Tensor<(B, Const, H, W), E, D, NoneTape>; + type Error = D::Err; /// Inference 4d forward - does **not** update [Self::running_mean] and [Self::running_var] - fn forward(&self, x: Tensor<(B, Const, H, W), E, D, NoneTape>) -> Self::Output { + fn try_forward( + &self, + x: Tensor<(B, Const, H, W), E, D, NoneTape>, + ) -> Result { self.infer_fwd(x) } } @@ -151,9 +169,13 @@ impl> ModuleMut, H, W), E, D, OwnedTape>> for BatchNorm2D { type Output = Tensor<(Const, H, W), E, D, OwnedTape>; + type Error = D::Err; /// Training 3d forward - updates [Self::running_mean] and [Self::running_var] - fn forward_mut(&mut self, x: Tensor<(Const, H, W), E, D, OwnedTape>) -> Self::Output { + fn try_forward_mut( + &mut self, + x: Tensor<(Const, H, W), E, D, OwnedTape>, + ) -> Result { self.train_fwd(x) } } @@ -162,9 +184,13 @@ impl> ModuleMut, H, W), E, D, OwnedTape>> for BatchNorm2D { type Output = Tensor<(B, Const, H, W), E, D, OwnedTape>; + type Error = D::Err; /// Training 4d forward - updates [Self::running_mean] and [Self::running_var] - fn forward_mut(&mut self, x: Tensor<(B, Const, H, W), E, D, OwnedTape>) -> Self::Output { + fn try_forward_mut( + &mut self, + x: Tensor<(B, Const, H, W), E, D, OwnedTape>, + ) -> Result { self.train_fwd(x) } } @@ -300,7 +326,7 @@ mod tests { } #[test] - fn test_batchform2d_3d_repeated_forward_mut() { + fn test_batchnorm2d_3d_repeated_forward_mut() { let dev = TestDevice::seed_from_u64(12); let x1: Tensor, TestDtype, _> = dev.sample_normal(); diff --git a/src/nn/conv.rs b/src/nn/conv.rs index 06924daf8..b2b52aed8 100644 --- a/src/nn/conv.rs +++ b/src/nn/conv.rs @@ -120,12 +120,14 @@ impl, - Img: TryConv2DTo, E, D>, S, P>, - for<'a> Bias2D<'a, O, E, D>: Module, + Img: TryConv2DTo, E, D>, S, P> + HasErr, + for<'a> Bias2D<'a, O, E, D>: Module, { type Output = Img::Output; - fn forward(&self, x: Img) -> Self::Output { - Bias2D { beta: &self.bias }.forward(x.conv2d_to(self.weight.clone())) + type Error = D::Err; + + fn try_forward(&self, x: Img) -> Result { + Bias2D { beta: &self.bias }.try_forward(x.try_conv2d_to(self.weight.clone())?) 
} } @@ -134,11 +136,13 @@ impl, - Self: Module, + Self: Module, { type Output = >::Output; - fn forward_mut(&mut self, input: Img) -> Self::Output { - self.forward(input) + type Error = D::Err; + + fn try_forward_mut(&mut self, input: Img) -> Result { + self.try_forward(input) } } @@ -151,8 +155,16 @@ impl<'a, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device, T: Tape> Module, H, W), E, D, T>> for Bias2D<'a, C, E, D> { type Output = Tensor<(Const, H, W), E, D, T>; - fn forward(&self, input: Tensor<(Const, H, W), E, D, T>) -> Self::Output { - self.beta.retaped::().broadcast_like(input.shape()) + input + type Error = D::Err; + + fn try_forward( + &self, + input: Tensor<(Const, H, W), E, D, T>, + ) -> Result { + self.beta + .retaped::() + .try_broadcast_like(input.shape())? + .try_add(input) } } @@ -160,8 +172,16 @@ impl<'a, B: Dim, const C: usize, H: Dim, W: Dim, E: Dtype, D: Device, T: Tape Module, H, W), E, D, T>> for Bias2D<'a, C, E, D> { type Output = Tensor<(B, Const, H, W), E, D, T>; - fn forward(&self, input: Tensor<(B, Const, H, W), E, D, T>) -> Self::Output { - self.beta.retaped::().broadcast_like(input.shape()) + input + type Error = D::Err; + + fn try_forward( + &self, + input: Tensor<(B, Const, H, W), E, D, T>, + ) -> Result { + self.beta + .retaped::() + .try_broadcast_like(input.shape())? + .try_add(input) } } diff --git a/src/nn/dropout.rs b/src/nn/dropout.rs index 591e6d5eb..a9e4ce318 100644 --- a/src/nn/dropout.rs +++ b/src/nn/dropout.rs @@ -50,8 +50,10 @@ impl> Module; /// Does nothing - fn forward(&self, input: Tensor) -> Self::Output { - input + type Error = D::Err; + + fn try_forward(&self, input: Tensor) -> Result { + Ok(input) } } @@ -59,9 +61,14 @@ impl> ModuleMut { type Output = Tensor>; + type Error = D::Err; + /// Calls [dropout()] with `p=1/N` using `self.rng`. - fn forward_mut(&mut self, input: Tensor>) -> Self::Output { - dropout(input, E::ONE / E::from_usize(N).unwrap()) + fn try_forward_mut( + &mut self, + input: Tensor>, + ) -> Result { + input.try_dropout(E::ONE / E::from_usize(N).unwrap()) } } @@ -119,17 +126,24 @@ impl ZeroSizedModule for Dropout {} impl> Module> for Dropout { type Output = Tensor; + type Error = D::Err; + /// Does nothing. 
- fn forward(&self, input: Tensor) -> Self::Output { - input + fn try_forward(&self, input: Tensor) -> Result { + Ok(input) } } impl> ModuleMut>> for Dropout { type Output = Tensor>; + type Error = D::Err; + /// Calls [dropout()] - fn forward_mut(&mut self, input: Tensor>) -> Self::Output { - dropout(input, E::from_f32(self.p).unwrap()) + fn try_forward_mut( + &mut self, + input: Tensor>, + ) -> Result { + input.try_dropout(E::from_f32(self.p).unwrap()) } } diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index fd372b5fb..27bcaaa3c 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -87,9 +87,11 @@ impl, T: Module, usize, D, T>> for Embedding { type Output = Tensor, E, D, T>; - fn forward(&self, input: Tensor, usize, D, T>) -> Self::Output { + type Error = D::Err; + + fn try_forward(&self, input: Tensor, usize, D, T>) -> Result { let (input, tape) = input.split_tape(); - self.weight.clone().put_tape(tape).gather(input) + self.weight.clone().put_tape(tape).try_gather(input) } } @@ -104,9 +106,14 @@ impl< > Module, usize, D, T>> for Embedding { type Output = Tensor, E, D, T>; - fn forward(&self, input: Tensor, usize, D, T>) -> Self::Output { + type Error = D::Err; + + fn try_forward( + &self, + input: Tensor, usize, D, T>, + ) -> Result { let (input, tape) = input.split_tape(); - self.weight.clone().put_tape(tape).gather(input) + self.weight.clone().put_tape(tape).try_gather(input) } } diff --git a/src/nn/flatten.rs b/src/nn/flatten.rs index 2659a50c6..6b4d5b23f 100644 --- a/src/nn/flatten.rs +++ b/src/nn/flatten.rs @@ -17,8 +17,10 @@ where Rank3: HasSameNumelAs>, { type Output = Tensor, E, D, T>; - fn forward(&self, input: Tensor, E, D, T>) -> Self::Output { - input.reshape() + type Error = D::Err; + + fn try_forward(&self, input: Tensor, E, D, T>) -> Result { + input.try_reshape() } } @@ -30,8 +32,13 @@ where Rank4: HasSameNumelAs>, { type Output = Tensor, E, D, T>; - fn forward(&self, input: Tensor, E, D, T>) -> Self::Output { - input.reshape() + type Error = D::Err; + + fn try_forward( + &self, + input: Tensor, E, D, T>, + ) -> Result { + input.try_reshape() } } diff --git a/src/nn/generalized_residual.rs b/src/nn/generalized_residual.rs index 1177b5100..d0203fe56 100644 --- a/src/nn/generalized_residual.rs +++ b/src/nn/generalized_residual.rs @@ -1,4 +1,4 @@ -use crate::{shapes::*, tensor::*}; +use crate::{shapes::*, tensor::*, tensor_ops::TryAdd}; use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; @@ -61,25 +61,33 @@ impl, R: ToDevice> ToDevice for GeneralizedResidual, R: Module> Module +impl, R: Module> Module for GeneralizedResidual where - F::Output: std::ops::Add, + F::Output: TryAdd + HasErr, { - type Output = >::Output; - fn forward(&self, x: T) -> Self::Output { - self.f.forward(x.with_empty_tape()) + self.r.forward(x) + type Output = F::Output; + type Error = F::Error; + + fn try_forward(&self, x: T) -> Result { + self.f + .try_forward(x.with_empty_tape())? + .try_add(self.r.try_forward(x)?) } } -impl, R: ModuleMut> ModuleMut - for GeneralizedResidual +impl, R: ModuleMut> + ModuleMut for GeneralizedResidual where - F::Output: std::ops::Add, + F::Output: TryAdd + HasErr, { - type Output = >::Output; - fn forward_mut(&mut self, x: T) -> Self::Output { - self.f.forward_mut(x.with_empty_tape()) + self.r.forward_mut(x) + type Output = F::Output; + type Error = F::Error; + + fn try_forward_mut(&mut self, x: T) -> Result { + self.f + .try_forward_mut(x.with_empty_tape())? + .try_add(self.r.try_forward_mut(x)?) 
} } diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 4f03842fe..794889cd9 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -58,30 +58,32 @@ macro_rules! tuple_impls { impl< Input, $last: - $(Module::<$rev_tail ::Output>, $rev_tail: )+ + $(Module::<$rev_tail ::Output, Error=$rev_tail::Error>, $rev_tail: )+ Module > Module for ($($name,)+) { type Output = $last ::Output; + type Error = $last ::Error; /// Calls forward sequentially on each module in the tuple. - fn forward(&self, x: Input) -> Self::Output { - $(let x = self.$idx.forward(x);)+ - x + fn try_forward(&self, x: Input) -> Result { + $(let x = self.$idx.try_forward(x)?;)+ + Ok(x) } } impl< Input, $last: - $(ModuleMut::<$rev_tail ::Output>, $rev_tail: )+ + $(ModuleMut::<$rev_tail ::Output, Error=$rev_tail::Error>, $rev_tail: )+ ModuleMut > ModuleMut for ($($name,)+) { type Output = $last ::Output; + type Error = $last ::Error; /// Calls forward sequentially on each module in the tuple. - fn forward_mut(&mut self, x: Input) -> Self::Output { - $(let x = self.$idx.forward_mut(x);)+ - x + fn try_forward_mut(&mut self, x: Input) -> Result { + $(let x = self.$idx.try_forward_mut(x)?;)+ + Ok(x) } } }; @@ -154,11 +156,17 @@ mod tests { #[derive(Debug, Default, Clone)] struct SetTo1; impl ZeroSizedModule for SetTo1 {} + impl Module, f32, Cpu>> for SetTo1 { type Output = Tensor, f32, Cpu>; - fn forward(&self, mut input: Tensor, f32, Cpu>) -> Self::Output { + type Error = ::Err; + + fn try_forward( + &self, + mut input: Tensor, f32, Cpu>, + ) -> Result { std::sync::Arc::make_mut(&mut input.storage.data)[I] = 1.0; - input + Ok(input) } } diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index f049745d7..cc5a20f0d 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -90,8 +90,12 @@ impl, T: Tape> Module, for LayerNorm1D { type Output = Tensor, E, D, T>; - fn forward(&self, x: Tensor, E, D, T>) -> Self::Output { - x.normalize(self.epsilon) * self.gamma.clone() + self.beta.clone() + type Error = D::Err; + + fn try_forward(&self, x: Tensor, E, D, T>) -> Result { + x.try_normalize(self.epsilon)? + .try_mul(self.gamma.clone())? + .try_add(self.beta.clone()) } } @@ -99,10 +103,13 @@ impl, T: Tape> Module), E, D, T>> for LayerNorm1D { type Output = Tensor<(B, Const), E, D, T>; - fn forward(&self, x: Tensor<(B, Const), E, D, T>) -> Self::Output { + type Error = D::Err; + + fn try_forward(&self, x: Tensor<(B, Const), E, D, T>) -> Result { let shape = *x.shape(); - x.normalize::>(self.epsilon) * self.gamma.retaped::().broadcast_like(&shape) - + self.beta.retaped::().broadcast_like(&shape) + x.try_normalize::>(self.epsilon)? + .try_mul(self.gamma.retaped::().try_broadcast_like(&shape)?)? + .try_add(self.beta.retaped::().try_broadcast_like(&shape)?) } } @@ -110,10 +117,13 @@ impl, T: Tape> Module), E, D, T>> for LayerNorm1D { type Output = Tensor<(B, S, Const), E, D, T>; - fn forward(&self, x: Tensor<(B, S, Const), E, D, T>) -> Self::Output { + type Error = D::Err; + + fn try_forward(&self, x: Tensor<(B, S, Const), E, D, T>) -> Result { let shape = *x.shape(); - x.normalize::>(self.epsilon) * self.gamma.retaped::().broadcast_like(&shape) - + self.beta.retaped::().broadcast_like(&shape) + x.try_normalize::>(self.epsilon)? + .try_mul(self.gamma.retaped::().try_broadcast_like(&shape)?)? + .try_add(self.beta.retaped::().try_broadcast_like(&shape)?) 
} } diff --git a/src/nn/linear.rs b/src/nn/linear.rs index df5c58aaf..4601c5a21 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -107,16 +107,17 @@ impl, D2: Device> ToD impl, T> Module for Linear where - T: SplitTape + TryMatMul, E, D, T::Tape>>, + T: SplitTape + TryMatMul, E, D, T::Tape>> + HasErr, T::Tape: Tape, - for<'a> Bias1D<'a, O, E, D>: Module, + for<'a> Bias1D<'a, O, E, D>: Module, { type Output = T::Output; + type Error = D::Err; /// 1d forward using [matmul()] and [add()]. - fn forward(&self, x: T) -> Self::Output { - let o = x.matmul(self.weight.retaped::().permute()); - Bias1D { beta: &self.bias }.forward(o) + fn try_forward(&self, x: T) -> Result { + let o = x.try_matmul(self.weight.retaped::().try_permute()?)?; + Bias1D { beta: &self.bias }.try_forward(o) } } @@ -129,8 +130,10 @@ impl<'a, const M: usize, E: Dtype, D: Device, T: Tape> Module { type Output = Tensor, E, D, T>; - fn forward(&self, input: Tensor, E, D, T>) -> Self::Output { - input + self.beta.clone() + type Error = D::Err; + + fn try_forward(&self, input: Tensor, E, D, T>) -> Result { + input.try_add(self.beta.clone()) } } @@ -138,8 +141,13 @@ impl<'a, B: Dim, const M: usize, E: Dtype, D: Device, T: Tape> Module), E, D, T>> for Bias1D<'a, M, E, D> { type Output = Tensor<(B, Const), E, D, T>; - fn forward(&self, input: Tensor<(B, Const), E, D, T>) -> Self::Output { - self.beta.retaped::().broadcast_like(input.shape()) + input + type Error = D::Err; + + fn try_forward(&self, input: Tensor<(B, Const), E, D, T>) -> Result { + self.beta + .retaped::() + .try_broadcast_like(input.shape())? + .try_add(input) } } @@ -147,8 +155,16 @@ impl<'a, B: Dim, S: Dim, const M: usize, E: Dtype, D: Device, T: Tape> Module), E, D, T>> for Bias1D<'a, M, E, D> { type Output = Tensor<(B, S, Const), E, D, T>; - fn forward(&self, input: Tensor<(B, S, Const), E, D, T>) -> Self::Output { - self.beta.retaped::().broadcast_like(input.shape()) + input + type Error = D::Err; + + fn try_forward( + &self, + input: Tensor<(B, S, Const), E, D, T>, + ) -> Result { + self.beta + .retaped::() + .try_broadcast_like(input.shape())? + .try_add(input) } } diff --git a/src/nn/module.rs b/src/nn/module.rs index 8b1f0f8e5..a21acd5f0 100644 --- a/src/nn/module.rs +++ b/src/nn/module.rs @@ -10,11 +10,16 @@ use super::tensor_collection::{ModuleVisitor, TensorCollection}; pub trait Module { /// The type that this unit produces given `Input`. type Output; + type Error: core::fmt::Debug; + + fn try_forward(&self, input: Input) -> Result; /// Forward `Input` through the module and produce [Module::Output]. /// /// **See [ModuleMut::forward_mut()] for version that can mutate `self`.** - fn forward(&self, input: Input) -> Self::Output; + fn forward(&self, input: Input) -> Self::Output { + self.try_forward(input).unwrap() + } } /// Mutable forward of `Input` that produces [ModuleMut::Output]. @@ -22,11 +27,16 @@ pub trait Module { pub trait ModuleMut { /// The type that this unit produces given `Input`. type Output; + type Error: core::fmt::Debug; + + fn try_forward_mut(&mut self, input: Input) -> Result; /// Forward `Input` through the module and produce [ModuleMut::Output]. /// /// **See [Module::forward()] for immutable version** - fn forward_mut(&mut self, input: Input) -> Self::Output; + fn forward_mut(&mut self, input: Input) -> Self::Output { + self.try_forward_mut(input).unwrap() + } } /// Something that can be built. 
Related to [BuildOnDevice] @@ -101,7 +111,9 @@ where Self: Module, { type Output = >::Output; - fn forward_mut(&mut self, input: T) -> Self::Output { - self.forward(input) + type Error = >::Error; + + fn try_forward_mut(&mut self, input: T) -> Result { + self.try_forward(input) } } diff --git a/src/nn/pool2d.rs b/src/nn/pool2d.rs index 48dd20891..ab8d5ea51 100644 --- a/src/nn/pool2d.rs +++ b/src/nn/pool2d.rs @@ -44,8 +44,10 @@ macro_rules! impl_pools { for $PoolTy { type Output = Img::Output; - fn forward(&self, x: Img) -> Self::Output { - x.try_pool2d().unwrap() + type Error = Img::Err; + + fn try_forward(&self, x: Img) -> Result { + x.try_pool2d() } } }; diff --git a/src/nn/pool_global.rs b/src/nn/pool_global.rs index 6a3cf77c0..7aa77392b 100644 --- a/src/nn/pool_global.rs +++ b/src/nn/pool_global.rs @@ -65,8 +65,13 @@ macro_rules! impl_pools { Module> for $PoolTy { type Output = Tensor<(C,), E, D, T>; - fn forward(&self, input: Tensor<(C, H, W), E, D, T>) -> Self::Output { - input.min() + type Error = D::Err; + + fn try_forward( + &self, + input: Tensor<(C, H, W), E, D, T>, + ) -> Result { + input.try_min() } } @@ -74,13 +79,18 @@ macro_rules! impl_pools { Module> for $PoolTy { type Output = Tensor<(B, C), E, D, T>; - fn forward(&self, input: Tensor<(B, C, H, W), E, D, T>) -> Self::Output { + type Error = D::Err; + + fn try_forward( + &self, + input: Tensor<(B, C, H, W), E, D, T>, + ) -> Result { input.$Method() } } }; } -impl_pools!(AvgPoolGlobal, mean); -impl_pools!(MaxPoolGlobal, max); -impl_pools!(MinPoolGlobal, min); +impl_pools!(AvgPoolGlobal, try_mean); +impl_pools!(MaxPoolGlobal, try_max); +impl_pools!(MinPoolGlobal, try_min); diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index d02c1ee15..bc71adc37 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -76,11 +76,13 @@ impl std::ops::Index for Repeated { impl, const N: usize> Module for Repeated { type Output = T::Output; - fn forward(&self, mut x: Input) -> Self::Output { + type Error = T::Error; + + fn try_forward(&self, mut x: Input) -> Result { for i in 0..N { - x = self.modules[i].forward(x); + x = self.modules[i].try_forward(x)?; } - x + Ok(x) } } @@ -88,11 +90,13 @@ impl, const N: usize> ModuleMut { type Output = T::Output; - fn forward_mut(&mut self, mut x: Input) -> Self::Output { + type Error = T::Error; + + fn try_forward_mut(&mut self, mut x: Input) -> Result { for i in 0..N { - x = self.modules[i].forward_mut(x); + x = self.modules[i].try_forward_mut(x)?; } - x + Ok(x) } } diff --git a/src/nn/residual.rs b/src/nn/residual.rs index a6de6810a..094df4b06 100644 --- a/src/nn/residual.rs +++ b/src/nn/residual.rs @@ -1,9 +1,7 @@ -use crate::{shapes::*, tensor::*}; +use crate::{shapes::*, tensor::*, tensor_ops::TryAdd}; use super::{tensor_collection::*, BuildModule, BuildOnDevice, Module, ModuleMut, ToDevice}; -use std::ops::Add; - /// A residual connection around `F`: `F(x) + x`, /// as introduced in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). 
/// @@ -46,17 +44,23 @@ impl, D> ToDevice for Residual { } } -impl, F: Module> Module for Residual { +impl, F: Module> Module for Residual { type Output = T; - fn forward(&self, x: T) -> Self::Output { - self.0.forward(x.with_empty_tape()) + x + type Error = F::Error; + + fn try_forward(&self, x: T) -> Result { + self.0.try_forward(x.with_empty_tape())?.try_add(x) } } -impl, F: ModuleMut> ModuleMut for Residual { +impl, F: ModuleMut> ModuleMut + for Residual +{ type Output = T; - fn forward_mut(&mut self, x: T) -> Self::Output { - self.0.forward_mut(x.with_empty_tape()) + x + type Error = F::Error; + + fn try_forward_mut(&mut self, x: T) -> Result { + self.0.try_forward_mut(x.with_empty_tape())?.try_add(x) } } diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 5a1847a6c..1404f94ee 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -52,7 +52,7 @@ macro_rules! tuple_impls { ([$($heads:ident),+] $tail:ident) => { impl< Input: SplitTape, - $($heads : Module,)+ + $($heads : Module,)+ $tail: Module > Module for SplitInto<($($heads,)+ $tail)> where @@ -62,20 +62,21 @@ where $(<$heads::Output as SplitTape>::NoTape, )+ $tail::Output ); + type Error = $tail::Error; #[allow(non_snake_case)] - fn forward(&self, x: Input) -> Self::Output { + fn try_forward(&self, x: Input) -> Result { let (x, tape) = x.split_tape(); let ($($heads, )+ $tail) = &self.0; - $(let ($heads, tape) = $heads.forward(x.clone().put_tape(tape)).split_tape();)+ - let $tail = $tail.forward(x.put_tape(tape)); - ($($heads,)+ $tail) + $(let ($heads, tape) = $heads.try_forward(x.clone().put_tape(tape))?.split_tape();)+ + let $tail = $tail.try_forward(x.put_tape(tape))?; + Ok(($($heads,)+ $tail)) } } impl< Input: SplitTape, - $($heads : ModuleMut,)+ + $($heads : ModuleMut,)+ $tail: ModuleMut > ModuleMut for SplitInto<($($heads,)+ $tail)> where @@ -85,14 +86,15 @@ where $(<$heads::Output as SplitTape>::NoTape, )+ $tail::Output ); + type Error = $tail::Error; #[allow(non_snake_case)] - fn forward_mut(&mut self, x: Input) -> Self::Output { + fn try_forward_mut(&mut self, x: Input) -> Result { let (x, tape) = x.split_tape(); let ($($heads, )+ $tail) = &mut self.0; - $(let ($heads, tape) = $heads.forward_mut(x.clone().put_tape(tape)).split_tape();)+ - let $tail = $tail.forward_mut(x.put_tape(tape)); - ($($heads,)+ $tail) + $(let ($heads, tape) = $heads.try_forward_mut(x.clone().put_tape(tape))?.split_tape();)+ + let $tail = $tail.try_forward_mut(x.put_tape(tape))?; + Ok(($($heads,)+ $tail)) } } } diff --git a/src/nn/transformer/decoder.rs b/src/nn/transformer/decoder.rs index c5faa15c6..87dc3bc47 100644 --- a/src/nn/transformer/decoder.rs +++ b/src/nn/transformer/decoder.rs @@ -3,9 +3,10 @@ use rand_distr::uniform::SampleUniform; use crate::{ nn::{modules::*, tensor_collection::*, *}, + prelude::storage_traits::HasErr, shapes::Dtype, tensor::{PutTape, SplitTape}, - tensor_ops::Device, + tensor_ops::{Device, TryAdd}, }; use super::mha::MultiHeadAttention; @@ -107,26 +108,30 @@ impl, - TransformerDecoderBlock: Module<(Tgt, Mem), Output = Tgt>, + TransformerDecoderBlock: Module<(Tgt, Mem), Output = Tgt, Error = D::Err>, { type Output = Tgt; - fn forward(&self, (mut tgt, mem): (Tgt, Mem)) -> Self::Output { + type Error = D::Err; + + fn try_forward(&self, (mut tgt, mem): (Tgt, Mem)) -> Result { for block in self.0.modules.iter() { - tgt = block.forward((tgt, mem.clone())); + tgt = block.try_forward((tgt, mem.clone()))?; } - tgt + Ok(tgt) } } impl, T> ModuleMut for TransformerDecoder where - Self: Module, + Self: Module, { 
type Output = >::Output; - fn forward_mut(&mut self, t: T) -> Self::Output { - self.forward(t) + type Error = D::Err; + + fn try_forward_mut(&mut self, t: T) -> Result { + self.try_forward(t) } } @@ -216,28 +221,31 @@ impl, D2 impl, Tgt, Mem> Module<(Tgt, Mem)> for TransformerDecoderBlock where - Tgt: SplitTape + std::ops::Add, + Tgt: SplitTape + TryAdd + HasErr, Mem: Clone, - MultiHeadAttention: - Module + Module<(Tgt, Mem, Mem), Output = Tgt>, - LayerNorm1D: Module, - FF: Module, + MultiHeadAttention: Module + + Module<(Tgt, Mem, Mem), Output = Tgt, Error = D::Err>, + LayerNorm1D: Module, + FF: Module, { type Output = Tgt; + type Error = D::Err; - fn forward(&self, (tgt, mem): (Tgt, Mem)) -> Self::Output { + fn try_forward(&self, (tgt, mem): (Tgt, Mem)) -> Result { let (tgt, tape) = tgt.split_tape(); - let x = self.self_attn.forward(tgt.clone().put_tape(tape)); - let x = x + tgt; - let x = self.norm1.forward(x); + let x = self.self_attn.try_forward(tgt.clone().put_tape(tape))?; + let x = x.try_add(tgt)?; + let x = self.norm1.try_forward(x)?; let (x, tape) = x.split_tape(); let x_residual = x.clone(); - let x = self.mh_attn.forward((x.put_tape(tape), mem.clone(), mem)); - let x = x + x_residual; - let x = self.norm2.forward(x); - let x = self.ff.forward(x); - self.norm3.forward(x) + let x = self + .mh_attn + .try_forward((x.put_tape(tape), mem.clone(), mem))?; + let x = x.try_add(x_residual)?; + let x = self.norm2.try_forward(x)?; + let x = self.ff.try_forward(x)?; + self.norm3.try_forward(x) } } diff --git a/src/nn/transformer/encoder.rs b/src/nn/transformer/encoder.rs index 6d5cc77c4..a90b58adf 100644 --- a/src/nn/transformer/encoder.rs +++ b/src/nn/transformer/encoder.rs @@ -144,31 +144,33 @@ impl, Src for TransformerEncoderBlock where Src: SplitTape + std::ops::Add, - MultiHeadAttention: Module, - LayerNorm1D: Module, - FF: Module, + MultiHeadAttention: Module, + LayerNorm1D: Module, + FF: Module, { type Output = Src; + type Error = D::Err; - fn forward(&self, src: Src) -> Self::Output { + fn try_forward(&self, src: Src) -> Result { let (src, tape) = src.split_tape(); - let x = self.self_attn.forward(src.clone().put_tape(tape)); + let x = self.self_attn.try_forward(src.clone().put_tape(tape))?; let x = x + src; - let x = self.norm1.forward(x); - let x = self.ff.forward(x); - self.norm2.forward(x) + let x = self.norm1.try_forward(x)?; + let x = self.ff.try_forward(x)?; + self.norm2.try_forward(x) } } impl, T> ModuleMut for TransformerEncoderBlock where - Self: Module, + Self: Module, { type Output = >::Output; + type Error = D::Err; - fn forward_mut(&mut self, t: T) -> Self::Output { - self.forward(t) + fn try_forward_mut(&mut self, t: T) -> Result { + self.try_forward(t) } } diff --git a/src/nn/transformer/mha.rs b/src/nn/transformer/mha.rs index 59ed289fb..2ccebfbda 100644 --- a/src/nn/transformer/mha.rs +++ b/src/nn/transformer/mha.rs @@ -133,39 +133,40 @@ where Assert<{ S1 * H * (V / H) == S1 * V }>: ConstTrue, { type Output = Tensor, E, D, T>; + type Error = D::Err; /// Encoder-Decoder style self attention where one set of tensors is used for values and keys, and another is used for queries - fn forward( + fn try_forward( &self, (q, k, v): ( Tensor, E, D, T>, Tensor, E, D>, Tensor, E, D>, ), - ) -> Self::Output { - let v: Tensor, _, _, _> = self.w_v.forward(v.retaped::()); - let v = v.reshape::>(); - let v = v.permute::, _>(); + ) -> Result { + let v: Tensor, _, _, _> = self.w_v.try_forward(v.retaped::())?; + let v = v.try_reshape::>()?; + let v = v.try_permute::, _>()?; - let k: 
Tensor, _, _, _> = self.w_k.forward(k.retaped::()); - let k = k.reshape::>(); - let k = k.permute::, _>(); + let k: Tensor, _, _, _> = self.w_k.try_forward(k.retaped::())?; + let k = k.try_reshape::>()?; + let k = k.try_permute::, _>()?; - let q: Tensor, _, _, _> = self.w_q.forward(q); - let q = q.reshape::>(); - let q = q.permute::, _>(); + let q: Tensor, _, _, _> = self.w_q.try_forward(q)?; + let q = q.try_reshape::>()?; + let q = q.try_permute::, _>()?; // Get weights let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt(); - let weights: Tensor, _, _, _> = q.matmul(k) * scalar; - let weights = weights.softmax::>(); + let weights: Tensor, _, _, _> = q.try_matmul(k)?.try_mul(scalar)?; + let weights = weights.try_softmax::>()?; // Get new tokens - let tokens: Tensor, _, _, _> = weights.matmul(v); - let tokens = tokens.permute::, _>(); - let tokens = tokens.reshape::>(); + let tokens: Tensor, _, _, _> = weights.try_matmul(v)?; + let tokens = tokens.try_permute::, _>()?; + let tokens = tokens.try_reshape::>()?; - self.w_o.forward(tokens) + self.w_o.try_forward(tokens) } } @@ -194,39 +195,40 @@ where Assert<{ B * S1 * H * (V / H) == B * S1 * V }>: ConstTrue, { type Output = Tensor, E, D, T>; + type Error = D::Err; /// Batched Encoder-Decoder style self attention where one set of tensors is used for values and keys, and another is used for queries - fn forward( + fn try_forward( &self, (q, k, v): ( Tensor, E, D, T>, Tensor, E, D>, Tensor, E, D>, ), - ) -> Self::Output { - let v: Tensor, _, _, _> = self.w_v.forward(v.retaped::()); - let v = v.reshape::>(); - let v = v.permute::, _>(); + ) -> Result { + let v: Tensor, _, _, _> = self.w_v.try_forward(v.retaped::())?; + let v = v.try_reshape::>()?; + let v = v.try_permute::, _>()?; - let k: Tensor, _, _, _> = self.w_k.forward(k.retaped::()); - let k = k.reshape::>(); - let k = k.permute::, _>(); + let k: Tensor, _, _, _> = self.w_k.try_forward(k.retaped::())?; + let k = k.try_reshape::>()?; + let k = k.try_permute::, _>()?; - let q: Tensor, _, _, _> = self.w_q.forward(q); - let q = q.reshape::>(); - let q = q.permute::, _>(); + let q: Tensor, _, _, _> = self.w_q.try_forward(q)?; + let q = q.try_reshape::>()?; + let q = q.try_permute::, _>()?; // Get weights let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt(); - let weights: Tensor, _, _, _> = q.matmul(k) * scalar; - let weights = weights.softmax::>(); + let weights: Tensor, _, _, _> = q.try_matmul(k)?.try_mul(scalar)?; + let weights = weights.try_softmax::>()?; // Get new tokens - let tokens: Tensor, _, _, _> = weights.matmul(v); - let tokens = tokens.permute::, _>(); - let tokens = tokens.reshape::>(); + let tokens: Tensor, _, _, _> = weights.try_matmul(v)?; + let tokens = tokens.try_permute::, _>()?; + let tokens = tokens.try_reshape::>()?; - self.w_o.forward(tokens) + self.w_o.try_forward(tokens) } } @@ -236,24 +238,27 @@ where E: Dtype, D: Device, Src: SplitTape, - Self: Module<(Src, Src::NoTape, Src::NoTape), Output = Src>, + Self: Module<(Src, Src::NoTape, Src::NoTape), Output = Src, Error = D::Err>, { type Output = Src; - fn forward(&self, src: Src) -> Self::Output { + type Error = D::Err; + + fn try_forward(&self, src: Src) -> Result { let (src, tape) = src.split_tape(); - self.forward((src.clone().put_tape(tape), src.clone(), src)) + self.try_forward((src.clone().put_tape(tape), src.clone(), src)) } } impl, T> ModuleMut for MultiHeadAttention where - Self: Module, + Self: Module, { type Output = >::Output; + type Error = D::Err; - fn forward_mut(&mut self, t: T) -> 
Self::Output { - self.forward(t) + fn try_forward_mut(&mut self, t: T) -> Result { + self.try_forward(t) } } diff --git a/src/nn/transformer/mod.rs b/src/nn/transformer/mod.rs index b9865dbe0..5702ea50f 100644 --- a/src/nn/transformer/mod.rs +++ b/src/nn/transformer/mod.rs @@ -133,17 +133,19 @@ impl< Tgt: PutTape, > Module<(Src, Tgt)> for Transformer where - TransformerEncoder: Module, + TransformerEncoder: Module, TransformerDecoder: Module< (>::Output, Src::NoTape), Output = >::Output, + Error = D::Err, >, { type Output = >::Output; + type Error = D::Err; - fn forward(&self, (src, tgt): (Src, Tgt)) -> Self::Output { - let (mem, tape) = self.encoder.forward(src).split_tape(); - self.decoder.forward((tgt.put_tape(tape), mem)) + fn try_forward(&self, (src, tgt): (Src, Tgt)) -> Result { + let (mem, tape) = self.encoder.try_forward(src)?.split_tape(); + self.decoder.try_forward((tgt.put_tape(tape), mem)) } } @@ -152,11 +154,13 @@ impl, - Self: Module, + Self: Module, { type Output = >::Output; - fn forward_mut(&mut self, t: T) -> Self::Output { - self.forward(t) + type Error = D::Err; + + fn try_forward_mut(&mut self, t: T) -> Result { + self.try_forward(t) } }
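
Note (editorial, not part of the patch): every hunk above applies the same migration — the required method becomes a fallible `try_forward`/`try_forward_mut` returning `Result<Self::Output, Self::Error>`, an associated `Error` type is added, and the old infallible `forward`/`forward_mut` become default methods that unwrap. Because the generic parameters in the patch text above were partially stripped during extraction, the sketch below uses simplified stand-in types (a local `Module` trait and toy `Scale`/`Net` layers, not dfdx's real device/tensor generics) purely to illustrate the pattern; in the real crate the error type is typically the device's `Err` type.

use core::fmt::Debug;

/// Simplified stand-in for the reworked trait: `try_forward` is the required
/// method; the infallible `forward` is a provided default that unwraps.
trait Module<Input> {
    type Output;
    type Error: Debug;

    fn try_forward(&self, input: Input) -> Result<Self::Output, Self::Error>;

    fn forward(&self, input: Input) -> Self::Output {
        self.try_forward(input).unwrap()
    }
}

/// Toy layer that can fail (standing in for e.g. a device allocation error).
struct Scale(f32);

impl Module<Vec<f32>> for Scale {
    type Output = Vec<f32>;
    type Error = String;

    fn try_forward(&self, input: Vec<f32>) -> Result<Self::Output, Self::Error> {
        if !self.0.is_finite() {
            return Err("non-finite scale factor".to_string());
        }
        Ok(input.into_iter().map(|x| x * self.0).collect())
    }
}

/// Two-layer "network": errors from either layer propagate with `?`,
/// mirroring how the migrated dfdx modules chain `try_forward` calls.
struct Net {
    a: Scale,
    b: Scale,
}

impl Module<Vec<f32>> for Net {
    type Output = Vec<f32>;
    type Error = String;

    fn try_forward(&self, input: Vec<f32>) -> Result<Self::Output, Self::Error> {
        let x = self.a.try_forward(input)?;
        self.b.try_forward(x)
    }
}

fn main() {
    let net = Net { a: Scale(2.0), b: Scale(0.5) };
    // Fallible path returns a Result the caller can handle...
    assert_eq!(net.try_forward(vec![1.0, 2.0]).unwrap(), vec![1.0, 2.0]);
    // ...while the default `forward` keeps the old panic-on-error behavior.
    assert_eq!(net.forward(vec![3.0]), vec![3.0]);
}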