From 513470e08285b73277e73e8e4304e73b39110acd Mon Sep 17 00:00:00 2001 From: Corey Lowman Date: Thu, 26 Jan 2023 10:46:10 -0500 Subject: [PATCH] Removing Device generic from Gradients & optimizers (#402) * Removing Device generic from Gradients & optimizers * Formatting --- examples/04-gradients.rs | 2 +- examples/06-mnist.rs | 2 +- src/gradients.rs | 95 ++++++++++++---------------- src/lib.rs | 4 +- src/nn/add_into.rs | 2 +- src/nn/impl_module_for_tuples.rs | 2 +- src/nn/layer_norm.rs | 2 +- src/nn/linear.rs | 2 +- src/nn/mod.rs | 4 +- src/nn/repeated.rs | 2 +- src/nn/split_into.rs | 2 +- src/optim/adam/mod.rs | 20 +++--- src/optim/mod.rs | 2 +- src/optim/optimizer.rs | 44 +++++++------ src/optim/rmsprop/mod.rs | 22 +++---- src/optim/sgd/mod.rs | 18 +++--- src/tensor/storage_traits.rs | 12 ++-- src/tensor_ops/utilities/backward.rs | 12 ++-- 18 files changed, 121 insertions(+), 128 deletions(-) diff --git a/examples/04-gradients.rs b/examples/04-gradients.rs index 8a30830e0..ad3536b19 100644 --- a/examples/04-gradients.rs +++ b/examples/04-gradients.rs @@ -31,7 +31,7 @@ fn main() { // finally you can use .backward() to extract the gradients! // NOTE: that this method is only available on tensors that **own** // the tape! - let gradients: Gradients = e.backward(); + let gradients: Gradients = e.backward(); // now you can extract gradients for specific tensors // by querying with them diff --git a/examples/06-mnist.rs b/examples/06-mnist.rs index 81a94d28a..a924550ca 100644 --- a/examples/06-mnist.rs +++ b/examples/06-mnist.rs @@ -99,7 +99,7 @@ fn main() { // initialize model and optimizer let mut model: Mlp = dev.build_module(); - let mut opt: Adam = Default::default(); + let mut opt: Adam = Default::default(); // initialize dataset let dataset = MnistDataset::train(&mnist_path); diff --git a/src/gradients.rs b/src/gradients.rs index 63d33611d..36b8efc71 100644 --- a/src/gradients.rs +++ b/src/gradients.rs @@ -1,11 +1,9 @@ //! Implementations of [GradientTape] and generic Nd array containers via [Gradients]. #![allow(clippy::type_complexity)] -use core::marker::PhantomData; use std::collections::HashMap; use std::{boxed::Box, vec::Vec}; -use crate::shapes::{HasDtype, HasShape}; use crate::tensor::storage_traits::{AllocGrad, DeviceStorage}; use crate::unique_id::{HasUniqueId, UniqueId}; @@ -24,28 +22,24 @@ use crate::unique_id::{HasUniqueId, UniqueId}; /// important part of key's implementing [HasShape], and [HasDtype] is that the associated type /// of that trait is used to downcast the box to the expected value. #[derive(Debug, Default)] -pub struct Gradients { +pub struct Gradients { gradient_by_id: HashMap>, - device: PhantomData<*const D>, } -impl Gradients { +impl Gradients { /// Retrieves mutable gradient for `t`, allocating one if it isn't present. - pub(crate) fn get_or_alloc_mut( - &mut self, - t: &T, - ) -> Result<&mut D::Storage, D::Err> + pub(crate) fn get_or_alloc_mut(&mut self, t: &T) -> Result<&mut T::Gradient, T::Err> where - T: HasUniqueId + AllocGrad, + T: HasUniqueId + AllocGrad, { self.try_alloc_for(t)?; Ok(self.get_mut(t)) } /// Inserts a gradient for `t` - pub(crate) fn try_alloc_for(&mut self, t: &T) -> Result<(), D::Err> + pub(crate) fn try_alloc_for(&mut self, t: &T) -> Result<(), T::Err> where - T: HasUniqueId + AllocGrad, + T: HasUniqueId + AllocGrad, { if !self.gradient_by_id.contains_key(t.id()) { let grad = t.try_alloc_grad()?; @@ -57,10 +51,10 @@ impl Gradients { /// Removes and returns the data associated with `t.id()`. 
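[Editor's note] The hunk above is the heart of the change: `Gradients` can drop its `D: DeviceStorage` parameter because the stored values are kept type-erased (the `downcast()` calls visible below imply a `Box<dyn Any>`-style map keyed by each tensor's unique id) and are only downcast back to a concrete storage type at the access site. Below is a minimal standalone sketch of that mechanism, not dfdx's actual types: the `u64` id and `Vec<f32>` storage are simplified stand-ins.

```rust
// Sketch: device-agnostic gradient storage via type erasure.
use std::any::Any;
use std::collections::HashMap;

type UniqueId = u64;

#[derive(Default)]
struct Gradients {
    // Each gradient is boxed as `dyn Any`, so the map needs no device/dtype generics.
    gradient_by_id: HashMap<UniqueId, Box<dyn Any>>,
}

impl Gradients {
    // Store a gradient of any concrete ('static) storage type.
    fn insert<G: 'static>(&mut self, id: UniqueId, grad: G) {
        self.gradient_by_id.insert(id, Box::new(grad));
    }

    // Retrieve it again; the caller names the expected concrete type and the
    // `Box<dyn Any>` is downcast back to it.
    fn get<G: 'static>(&self, id: UniqueId) -> &G {
        self.gradient_by_id
            .get(&id)
            .unwrap()
            .downcast_ref::<G>()
            .unwrap()
    }
}

fn main() {
    let mut grads = Gradients::default();
    // A "cpu" gradient represented as a plain Vec<f32> for illustration.
    grads.insert(0, vec![0.0f32; 4]);
    let g: &Vec<f32> = grads.get(0);
    assert_eq!(g.len(), 4);
}
```

The trade-off is that a wrong downcast panics at runtime instead of failing to compile, which is why the patch keeps the "unrecoverable bug" wording on `get`/`get_mut`/`remove`.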
/// /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug. - pub(crate) fn remove( - &mut self, - t: &T, - ) -> Option> { + pub(crate) fn remove(&mut self, t: &T) -> Option + where + T: HasUniqueId + AllocGrad, + { self.gradient_by_id .remove_entry(t.id()) .map(|e| *e.1.downcast().unwrap()) @@ -69,9 +63,9 @@ impl Gradients { /// Returns a mutable reference to the data associated with `t`. /// /// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug. - pub(crate) fn get_mut(&mut self, t: &T) -> &mut D::Storage + pub(crate) fn get_mut(&mut self, t: &T) -> &mut T::Gradient where - T: HasUniqueId + HasDtype + HasShape, + T: HasUniqueId + AllocGrad, { self.gradient_by_id .get_mut(t.id()) @@ -86,10 +80,10 @@ impl Gradients { /// /// If no data is associated with `t` yet, this will panic due to an unwrap() /// on a .get() to the underlying hashmap. - pub fn get( - &self, - t: &T, - ) -> &D::Storage { + pub fn get(&self, t: &T) -> &T::Gradient + where + T: HasUniqueId + AllocGrad, + { self.gradient_by_id .get(t.id()) .unwrap() @@ -102,17 +96,10 @@ impl Gradients { /// `l` is the gradient to update, and `r` is the gradient to backprop. /// /// **Panics** if `l` and `r` have the same id. - pub(crate) fn mut_and_ref( - &mut self, - l: &L, - r: &R, - ) -> ( - &mut D::Storage, - &D::Storage, - ) + pub(crate) fn mut_and_ref(&mut self, l: &L, r: &R) -> (&mut L::Gradient, &R::Gradient) where - L: HasUniqueId + HasShape + HasDtype, - R: HasUniqueId + HasShape + HasDtype, + L: HasUniqueId + AllocGrad, + R: HasUniqueId + AllocGrad, { assert_ne!(l.id(), r.id()); let l_ptr = self.get_mut(l) as *mut _; @@ -128,15 +115,11 @@ impl Gradients { l1: &L1, l2: &L2, r: &R, - ) -> ( - &mut D::Storage, - &mut D::Storage, - &D::Storage, - ) + ) -> (&mut L1::Gradient, &mut L2::Gradient, &R::Gradient) where - L1: HasUniqueId + HasShape + HasDtype, - L2: HasUniqueId + HasShape + HasDtype, - R: HasUniqueId + HasShape + HasDtype, + L1: HasUniqueId + AllocGrad, + L2: HasUniqueId + AllocGrad, + R: HasUniqueId + AllocGrad, { assert_ne!(l1.id(), l2.id()); assert_ne!(l1.id(), r.id()); @@ -183,8 +166,8 @@ impl Gradients { /// This would not be possible if these chain rule operations were inside of GradientTape! #[allow(clippy::type_complexity)] pub struct GradientTape { - operations: Vec) -> Result<(), D::Err>>>, - gradients: Gradients, + operations: Vec Result<(), D::Err>>>, + gradients: Gradients, } impl Default for GradientTape { @@ -212,7 +195,7 @@ impl GradientTape { /// * `operation` - A FnOnce that acts on [Gradients]. /// /// See src/tensor_ops for implementation examples. - pub(crate) fn add_backward_op) -> Result<(), D::Err>>( + pub(crate) fn add_backward_op Result<(), D::Err>>( &mut self, operation: F, ) { @@ -222,7 +205,7 @@ impl GradientTape { /// Compute the [Gradients]! This just runs all the operations on a new [Gradients] struct. /// /// Note that this method takes ownership of self, so it can't be called twice! - pub(crate) fn execute(mut self) -> Result, D::Err> { + pub(crate) fn execute(mut self) -> Result { for operation in self.operations.drain(..).rev() { (operation)(&mut self.gradients)?; } @@ -251,34 +234,40 @@ pub struct NoneTape; pub trait Tape: Default + Merge + Merge { /// Whether this object currently owns the [GradientTape]. This is known at compile time. 
const OWNS_TAPE: bool; - fn add_backward_op) -> Result<(), D::Err>>( + fn add_backward_op Result<(), D::Err>>( &mut self, operation: F, ); - fn try_alloc_grad>(&mut self, t: &T) -> Result<(), D::Err>; + fn try_alloc_grad>( + &mut self, + t: &T, + ) -> Result<(), D::Err>; } impl Tape for OwnedTape { const OWNS_TAPE: bool = true; - fn add_backward_op) -> Result<(), D::Err>>( + fn add_backward_op Result<(), D::Err>>( &mut self, operation: F, ) { self.0.add_backward_op(operation) } - fn try_alloc_grad>(&mut self, t: &T) -> Result<(), D::Err> { + fn try_alloc_grad>( + &mut self, + t: &T, + ) -> Result<(), D::Err> { self.0.gradients.try_alloc_for(t) } } impl Tape for NoneTape { const OWNS_TAPE: bool = false; - fn add_backward_op) -> Result<(), D::Err>>( - &mut self, - _: F, - ) { + fn add_backward_op Result<(), D::Err>>(&mut self, _: F) { } - fn try_alloc_grad>(&mut self, _: &T) -> Result<(), D::Err> { + fn try_alloc_grad>( + &mut self, + _: &T, + ) -> Result<(), D::Err> { Ok(()) } } diff --git a/src/lib.rs b/src/lib.rs index c4192fd1a..544243aa1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,7 @@ //! let loss = cross_entropy_with_logits_loss(y, y_true); //! //! // call `backward()` to compute gradients. The tensor *must* have `OwnedTape`! -//! let gradients: Gradients = loss.backward(); +//! let gradients: Gradients = loss.backward(); //! ``` //! 7. Use an optimizer from [crate::optim] to optimize your network! //! ```rust @@ -84,7 +84,7 @@ //! # let y_true = dev.sample_normal::>().softmax(); //! # let y = model.forward(dev.zeros::>().trace()); //! # let loss = cross_entropy_with_logits_loss(y, y_true); -//! # let gradients: Gradients = loss.backward(); +//! # let gradients: Gradients = loss.backward(); //! // Use stochastic gradient descent (Sgd), with a learning rate of 1e-2, and 0.9 momentum. //! let mut opt = Sgd::new(SgdConfig { //! 
lr: 1e-2, diff --git a/src/nn/add_into.rs b/src/nn/add_into.rs index c5f7092ce..788ea018e 100644 --- a/src/nn/add_into.rs +++ b/src/nn/add_into.rs @@ -206,7 +206,7 @@ mod tests { fn test_missing_gradients() { let dev: TestDevice = Default::default(); let mut model: AddInto<(Linear<5, 3, _>, Linear<5, 3, _>)> = dev.build_module(); - let mut g: SimpleUpdater<_> = Default::default(); + let mut g: SimpleUpdater = Default::default(); // no gradients present let mut unused = Default::default(); diff --git a/src/nn/impl_module_for_tuples.rs b/src/nn/impl_module_for_tuples.rs index 1f9d1fb38..f10a66061 100644 --- a/src/nn/impl_module_for_tuples.rs +++ b/src/nn/impl_module_for_tuples.rs @@ -239,7 +239,7 @@ mod tests { fn test_tuple_missing_gradients() { let dev: TestDevice = Default::default(); let mut model: (Linear<5, 3, _>, Linear<5, 3, _>, Linear<5, 3, _>) = dev.build_module(); - let mut g: SimpleUpdater<_> = Default::default(); + let mut g: SimpleUpdater = Default::default(); // no gradients present let mut unused: UnusedTensors = Default::default(); diff --git a/src/nn/layer_norm.rs b/src/nn/layer_norm.rs index 96e00bfd2..842cb99a3 100644 --- a/src/nn/layer_norm.rs +++ b/src/nn/layer_norm.rs @@ -167,7 +167,7 @@ mod tests { let dev: TestDevice = Default::default(); let mut model: LayerNorm1D<5, _> = dev.build_module(); - let mut g: SimpleUpdater<_> = Default::default(); + let mut g: SimpleUpdater = Default::default(); // no gradients present let mut unused = Default::default(); diff --git a/src/nn/linear.rs b/src/nn/linear.rs index d3b33295f..66aa8bbb0 100644 --- a/src/nn/linear.rs +++ b/src/nn/linear.rs @@ -250,7 +250,7 @@ mod tests { let dev: TestDevice = Default::default(); let mut model: Linear<5, 3, _> = dev.build_module(); - let mut g: SimpleUpdater<_> = Default::default(); + let mut g: SimpleUpdater = Default::default(); // no gradients present let mut unused = Default::default(); diff --git a/src/nn/mod.rs b/src/nn/mod.rs index 58335aa1c..c18e45c63 100644 --- a/src/nn/mod.rs +++ b/src/nn/mod.rs @@ -133,9 +133,9 @@ mod tests { use crate::{gradients::Gradients, optim::ParamUpdater, shapes::Dtype, tensor::DeviceStorage}; #[derive(Default)] - pub struct SimpleUpdater(pub Gradients); + pub struct SimpleUpdater(pub Gradients); - impl ParamUpdater for SimpleUpdater { + impl ParamUpdater for SimpleUpdater { fn update_param( &mut self, p: &mut crate::tensor::Tensor, diff --git a/src/nn/repeated.rs b/src/nn/repeated.rs index fd22aac15..c7d94552b 100644 --- a/src/nn/repeated.rs +++ b/src/nn/repeated.rs @@ -121,7 +121,7 @@ mod tests { let dev: TestDevice = Default::default(); let mut model: Repeated, 3> = dev.build_module(); - let mut g: SimpleUpdater<_> = Default::default(); + let mut g: SimpleUpdater = Default::default(); // no gradients present let mut unused = Default::default(); diff --git a/src/nn/split_into.rs b/src/nn/split_into.rs index 4097b38ec..05a81583d 100644 --- a/src/nn/split_into.rs +++ b/src/nn/split_into.rs @@ -228,7 +228,7 @@ mod tests { fn test_missing_gradients() { let dev: TestDevice = Default::default(); let mut model: SplitInto<(Linear<5, 3, _>, Linear<5, 3, _>)> = dev.build_module(); - let mut g: SimpleUpdater<_> = Default::default(); + let mut g: SimpleUpdater = Default::default(); // no gradients present let mut unused = Default::default(); diff --git a/src/optim/adam/mod.rs b/src/optim/adam/mod.rs index 8d93cc9be..4148a8a4e 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam/mod.rs @@ -8,7 +8,7 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, 
shapes::{Dtype, Shape}, - tensor::{Cpu, DeviceStorage}, + tensor::DeviceStorage, }; use super::{GradientUpdate, Optimizer, OptimizerUpdateError, ParamUpdater, WeightDecay}; @@ -77,19 +77,19 @@ impl Default for AdamConfig { /// /// See module level documentation at [crate::optim] for examples of how to actually use an optimizer. #[derive(Debug)] -pub struct Adam { +pub struct Adam { /// Hyperparameter configuration pub cfg: AdamConfig, t: i32, - gradients: Gradients, - moment1: Gradients, - moment2: Gradients, + gradients: Gradients, + moment1: Gradients, + moment2: Gradients, marker: PhantomData<*const M>, } -impl Default for Adam +impl Default for Adam where AdamConfig: Default, { @@ -99,7 +99,7 @@ where } } -impl Adam { +impl Adam { /// Constructs using hyperparameters from `cfg`. pub fn new(cfg: AdamConfig) -> Self { Self { @@ -125,7 +125,7 @@ pub(super) trait AdamKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, E: Dtype> ParamUpdater for Adam { +impl, E: Dtype> ParamUpdater for Adam { fn update_param( &mut self, p: &mut crate::tensor::Tensor, @@ -145,14 +145,14 @@ impl, E: Dtype> ParamUpdater for Adam< } } -impl> Optimizer for Adam +impl, D: AdamKernel, E: Dtype> Optimizer for Adam where Self: ParamUpdater, { fn update( &mut self, module: &mut M, - gradients: Gradients, + gradients: Gradients, ) -> Result<(), OptimizerUpdateError> { self.t = self.t.checked_add(1).unwrap(); self.gradients = gradients; diff --git a/src/optim/mod.rs b/src/optim/mod.rs index 199ca890b..f8787a779 100644 --- a/src/optim/mod.rs +++ b/src/optim/mod.rs @@ -23,7 +23,7 @@ //! # let loss = losses::mse_loss(y, dev.zeros()); //! // -- snip loss computation -- //! -//! let gradients: Gradients = loss.backward(); +//! let gradients: Gradients = loss.backward(); //! opt.update(&mut model, gradients); //! ``` diff --git a/src/optim/optimizer.rs b/src/optim/optimizer.rs index 2c31f19a2..6d8fe4ef9 100644 --- a/src/optim/optimizer.rs +++ b/src/optim/optimizer.rs @@ -1,7 +1,7 @@ use crate::{ gradients::Gradients, shapes::{Dtype, Shape}, - tensor::{DeviceStorage, Tensor}, + tensor::{DeviceStorage, HasErr, Tensor}, unique_id::{HasUniqueId, UniqueId}, }; @@ -79,7 +79,7 @@ pub(super) fn momentum_to_cuda(wd: Option>) -> (Momentum /// /// 3. Optimizer itself is generic over M, not the update method. This means a single optimizer object /// can only work on objects of type `M`. This also requires you to specify the model up front for the optimizer. -pub trait Optimizer, D: DeviceStorage, E: Dtype> { +pub trait Optimizer { /// Updates all of `module`'s parameters using `gradients`. /// /// Requires a `&mut self` because the optimizer may change some internally @@ -87,10 +87,30 @@ pub trait Optimizer, D: DeviceStorage, E: Dtype> { fn update( &mut self, module: &mut M, - gradients: Gradients, + gradients: Gradients, ) -> Result<(), OptimizerUpdateError>; } +/// Represents something that can be updated with a [ParamUpdater]. +pub trait GradientUpdate { + /// Updates self given the [ParamUpdater]. + fn update>( + &mut self, + updater: &mut U, + unused: &mut UnusedTensors, + ) -> Result<(), D::Err>; +} + +impl GradientUpdate for Tensor { + fn update>( + &mut self, + updater: &mut U, + unused: &mut UnusedTensors, + ) -> Result<(), ::Err> { + updater.update_param(self, unused) + } +} + /// Represents something that can update a tensor. /// /// See [crate::optim::Sgd] and [crate::optim::Adam] for examples on implementing this. 
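[Editor's note] The `GradientUpdate` / `ParamUpdater` pair that the optimizer hunks lean on is a small visitor pattern: anything implementing `GradientUpdate` knows how to walk its parameters and hand each one to a `ParamUpdater`, with `Tensor` as the leaf case. A rough sketch with hypothetical, simplified types (the real traits also thread `UnusedTensors` and a `Result`):

```rust
// Sketch: optimizer-as-visitor over a module's parameters.
struct Tensor(Vec<f32>);

trait ParamUpdater {
    fn update_param(&mut self, p: &mut Tensor);
}

trait GradientUpdate {
    fn update<U: ParamUpdater>(&mut self, updater: &mut U);
}

// A leaf parameter hands itself to the updater...
impl GradientUpdate for Tensor {
    fn update<U: ParamUpdater>(&mut self, updater: &mut U) {
        updater.update_param(self);
    }
}

// ...while a composite module recurses into its children.
struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl GradientUpdate for Linear {
    fn update<U: ParamUpdater>(&mut self, updater: &mut U) {
        self.weight.update(updater);
        self.bias.update(updater);
    }
}

// Toy "optimizer": subtract a fixed step from every element.
struct FixedStep(f32);

impl ParamUpdater for FixedStep {
    fn update_param(&mut self, p: &mut Tensor) {
        for x in p.0.iter_mut() {
            *x -= self.0;
        }
    }
}

fn main() {
    let mut layer = Linear {
        weight: Tensor(vec![1.0; 6]),
        bias: Tensor(vec![1.0; 3]),
    };
    let mut opt = FixedStep(0.1);
    layer.update(&mut opt);
    assert!((layer.bias.0[0] - 0.9).abs() < 1e-6);
}
```

This split is why `Optimizer` can shed its `D: DeviceStorage` parameter: the device shows up only on the tensors being visited, not on the visitor's public interface.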
@@ -124,24 +144,6 @@ impl UnusedTensors { } } -/// Represents something that can be updated with a [ParamUpdater]. -pub trait GradientUpdate: Sized { - /// Updates self given the [ParamUpdater]. - fn update(&mut self, updater: &mut U, unused: &mut UnusedTensors) -> Result<(), D::Err> - where - U: ParamUpdater; -} - -impl GradientUpdate for Tensor { - fn update>( - &mut self, - opt: &mut U, - unused: &mut UnusedTensors, - ) -> Result<(), D::Err> { - opt.update_param(self, unused) - } -} - /// An error indicating that a parameter was not used in gradient /// computation, and was therefore not present in [Gradients] /// while a [GradientUpdate] was trying to update it. diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop/mod.rs index a989fa866..4cc1f3f90 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop/mod.rs @@ -8,7 +8,7 @@ use std::marker::PhantomData; use crate::{ gradients::Gradients, shapes::{Dtype, Shape}, - tensor::{Cpu, DeviceStorage, OneFillStorage, Tensor}, + tensor::{DeviceStorage, OneFillStorage, Tensor}, }; use super::{ @@ -87,20 +87,20 @@ impl Default for RMSpropConfig { /// /// See module level documentation at [crate::optim] for examples of how to actually use an optimizer. #[derive(Debug)] -pub struct RMSprop { +pub struct RMSprop { /// Hyperparameter configuration pub cfg: RMSpropConfig, step: usize, - momentums: Gradients, - square_avg: Gradients, - grad_avg: Gradients, - gradients: Gradients, + momentums: Gradients, + square_avg: Gradients, + grad_avg: Gradients, + gradients: Gradients, marker: PhantomData<*const M>, } -impl Default for RMSprop +impl Default for RMSprop where RMSpropConfig: Default, { @@ -110,7 +110,7 @@ where } } -impl RMSprop { +impl RMSprop { /// Constructs using hyperparameters from `cfg`. pub fn new(cfg: RMSpropConfig) -> Self { Self { @@ -137,7 +137,7 @@ pub(super) trait RMSpropKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl + OneFillStorage> ParamUpdater for RMSprop { +impl + OneFillStorage> ParamUpdater for RMSprop { fn update_param( &mut self, p: &mut Tensor, @@ -162,14 +162,14 @@ impl + OneFillStorage> ParamUpdater for RM } } -impl> Optimizer for RMSprop +impl, D: RMSpropKernel, E: Dtype> Optimizer for RMSprop where Self: ParamUpdater, { fn update( &mut self, module: &mut M, - gradients: Gradients, + gradients: Gradients, ) -> Result<(), OptimizerUpdateError> { self.gradients = gradients; let mut unused = Default::default(); diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd/mod.rs index b68d063b7..d5613f3f7 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd/mod.rs @@ -7,7 +7,7 @@ use std::marker::PhantomData; use crate::gradients::Gradients; use crate::shapes::{Dtype, Shape}; -use crate::tensor::{Cpu, DeviceStorage, Tensor}; +use crate::tensor::{DeviceStorage, Tensor}; use super::optimizer::*; @@ -115,17 +115,17 @@ impl Default for SgdConfig { /// /// See module level documentation at [crate::optim] for examples of how to actually use an optimizer. 
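[Editor's note] After this refactor the optimizer structs (`Sgd`, `Adam`, `RMSprop`) are generic over the model `M` (plus dtype) rather than a device, because all per-parameter state now lives in device-agnostic `Gradients` maps. A simplified, hypothetical sketch of that shape, using a plain map for velocity and the classic momentum update that the SGD kernel applies per element:

```rust
// Sketch: an SGD-like optimizer generic over the model type only.
use std::collections::HashMap;
use std::marker::PhantomData;

#[derive(Debug, Clone, Copy)]
struct SgdConfig {
    lr: f32,
    momentum: f32,
}

impl Default for SgdConfig {
    fn default() -> Self {
        Self { lr: 1e-2, momentum: 0.9 }
    }
}

struct Sgd<M> {
    cfg: SgdConfig,
    // Stand-in for `Gradients`: velocity per parameter id, type-erased in the real code.
    velocity: HashMap<u64, Vec<f32>>,
    marker: PhantomData<*const M>,
}

impl<M> Sgd<M> {
    fn new(cfg: SgdConfig) -> Self {
        Self { cfg, velocity: HashMap::new(), marker: PhantomData }
    }

    // Momentum update for one parameter: v = momentum * v + g; p -= lr * v
    fn step(&mut self, id: u64, param: &mut [f32], grad: &[f32]) {
        let v = self.velocity.entry(id).or_insert_with(|| vec![0.0; param.len()]);
        for ((p, g), v) in param.iter_mut().zip(grad).zip(v.iter_mut()) {
            *v = self.cfg.momentum * *v + *g;
            *p -= self.cfg.lr * *v;
        }
    }
}

struct MyModel;

fn main() {
    let mut opt: Sgd<MyModel> = Sgd::new(SgdConfig::default());
    let mut w = vec![1.0f32; 3];
    opt.step(0, &mut w, &[1.0, 1.0, 1.0]);
    assert!((w[0] - 0.99).abs() < 1e-6);
}
```

The `PhantomData<*const M>` mirrors the patch: it ties the optimizer to one model type without requiring `M` to be `Send`/`Sync` or otherwise constrained.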
#[derive(Debug)] -pub struct Sgd { +pub struct Sgd { /// Hyperparameter configuration pub cfg: SgdConfig, - velocity: Gradients, - gradients: Gradients, + velocity: Gradients, + gradients: Gradients, marker: PhantomData<*const M>, } -impl Default for Sgd +impl Default for Sgd where SgdConfig: Default, { @@ -135,7 +135,7 @@ where } } -impl Sgd { +impl Sgd { /// Constructs using hyperparameters from `cfg` pub fn new(cfg: SgdConfig) -> Self { Self { @@ -157,7 +157,7 @@ pub(super) trait SgdKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -impl, E: Dtype> ParamUpdater for Sgd { +impl, E: Dtype> ParamUpdater for Sgd { fn update_param( &mut self, p: &mut Tensor, @@ -175,14 +175,14 @@ impl, E: Dtype> ParamUpdater for Sgd { } } -impl> Optimizer for Sgd +impl, D: SgdKernel, E: Dtype> Optimizer for Sgd where Self: ParamUpdater, { fn update( &mut self, module: &mut M, - gradients: Gradients, + gradients: Gradients, ) -> Result<(), OptimizerUpdateError> { self.gradients = gradients; let mut unused = Default::default(); diff --git a/src/tensor/storage_traits.rs b/src/tensor/storage_traits.rs index 5857c2121..b5ab11e25 100644 --- a/src/tensor/storage_traits.rs +++ b/src/tensor/storage_traits.rs @@ -2,7 +2,7 @@ use rand::distributions::Distribution; use rand_distr::{Standard, StandardNormal}; use crate::{ - shapes::{ConstShape, Dtype, HasDtype, HasShape, HasUnitType, Shape, Unit}, + shapes::{ConstShape, Dtype, HasShape, HasUnitType, Shape, Unit}, unique_id::unique_id, }; @@ -44,12 +44,14 @@ pub trait DeviceStorage: 'static + Default + Clone + HasErr { } /// Internal trait - Represents something that can allocate its own gradient. -pub trait AllocGrad: HasShape + HasDtype { - fn try_alloc_grad(&self) -> Result, D::Err>; +pub trait AllocGrad: HasErr { + type Gradient: 'static; + fn try_alloc_grad(&self) -> Result; } -impl AllocGrad for Tensor { - fn try_alloc_grad(&self) -> Result, D::Err> { +impl AllocGrad for Tensor { + type Gradient = D::Storage; + fn try_alloc_grad(&self) -> Result { self.device.try_alloc_grad(&self.storage) } } diff --git a/src/tensor_ops/utilities/backward.rs b/src/tensor_ops/utilities/backward.rs index c99a1690c..bcc6343ab 100644 --- a/src/tensor_ops/utilities/backward.rs +++ b/src/tensor_ops/utilities/backward.rs @@ -1,21 +1,21 @@ use crate::gradients::{Gradients, OwnedTape, Tape}; use crate::shapes::{Dtype, Rank0}; -use crate::tensor::{DeviceStorage, OneFillStorage, SplitTape, Tensor}; +use crate::tensor::{HasErr, OneFillStorage, SplitTape, Tensor}; /// Runs backprop algorithm with all operations contained in the tape that `t` has. /// /// This function takes ownership of `self` and returns [Gradients]. -pub trait Backward: Sized { +pub trait Backward: HasErr { /// Runs backprop - fn backward(self) -> Gradients { + fn backward(self) -> Gradients { self.try_backward().unwrap() } /// Fallible version of [Backward::backward] - fn try_backward(self) -> Result, D::Err>; + fn try_backward(self) -> Result; } -impl> Backward for Tensor> { - fn try_backward(self) -> Result, D::Err> { +impl> Backward for Tensor> { + fn try_backward(self) -> Result { let (t, mut tape) = self.split_tape(); tape.add_backward_op(move |grads| t.device.try_fill_with_ones(grads.get_mut(&t))); tape.0.execute()
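[Editor's note] The final two hunks carry the piece that makes all of the above hold together: `AllocGrad` gains an associated `Gradient` type (the tensor's `D::Storage<S, E>`), so callers of `Gradients::get`/`get_mut` know exactly which concrete type to downcast to, and `Backward::backward()` can return a plain `Gradients` after seeding the output gradient with ones and executing the tape. A minimal sketch of the associated-type trick, with hypothetical names and a `Vec<f32>` gradient standing in for real device storage:

```rust
// Sketch: each tensor-like value names its own concrete gradient type,
// and the device-free store downcasts to it at the call site.
use std::any::Any;
use std::collections::HashMap;

trait AllocGrad {
    type Gradient: 'static;
    fn id(&self) -> u64;
    fn alloc_grad(&self) -> Self::Gradient;
}

#[derive(Default)]
struct Gradients {
    by_id: HashMap<u64, Box<dyn Any>>,
}

impl Gradients {
    fn get_or_alloc_mut<T: AllocGrad>(&mut self, t: &T) -> &mut T::Gradient {
        self.by_id
            .entry(t.id())
            .or_insert_with(|| Box::new(t.alloc_grad()))
            .downcast_mut::<T::Gradient>()
            .unwrap()
    }
}

// A "cpu tensor" whose gradient is a zeroed Vec<f32> of the same length.
struct CpuTensor {
    id: u64,
    data: Vec<f32>,
}

impl AllocGrad for CpuTensor {
    type Gradient = Vec<f32>;
    fn id(&self) -> u64 {
        self.id
    }
    fn alloc_grad(&self) -> Vec<f32> {
        vec![0.0; self.data.len()]
    }
}

fn main() {
    let t = CpuTensor { id: 7, data: vec![1.0, 2.0, 3.0] };
    let mut grads = Gradients::default();
    grads.get_or_alloc_mut(&t)[0] = 1.0; // seed, like `try_fill_with_ones` on the output
    assert_eq!(grads.get_or_alloc_mut(&t), &vec![1.0, 0.0, 0.0]);
}
```

Because the downcast target comes from `T::Gradient` rather than from a `D` parameter on the store, `Gradients`, `backward()`, and the optimizers can all be written without mentioning the device at the type level.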