Removing Device generic from Gradients & optimizers (#402)
* Removing Device generic from Gradients & optimizers

* Formatting
coreylowman authored Jan 26, 2023
1 parent 1211cd3 commit 513470e
Showing 18 changed files with 121 additions and 128 deletions.
2 changes: 1 addition & 1 deletion examples/04-gradients.rs
@@ -31,7 +31,7 @@ fn main() {
// finally you can use .backward() to extract the gradients!
// NOTE: that this method is only available on tensors that **own**
// the tape!
let gradients: Gradients<Cpu> = e.backward();
let gradients: Gradients = e.backward();

// now you can extract gradients for specific tensors
// by querying with them
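For context, here is a minimal sketch of how a call site reads after this change. It is assembled from the doc example in src/lib.rs and the usual dfdx prelude; the model/loss setup is illustrative, and the `weight` field name is an assumption about `Linear`, not something this commit touches.

```rust
use dfdx::{gradients::Gradients, losses::cross_entropy_with_logits_loss, prelude::*};

fn main() {
    let dev: Cpu = Default::default();
    let model: Linear<10, 5, _> = dev.build_module();

    // Trace the input so the resulting loss tensor owns the gradient tape.
    let y = model.forward(dev.zeros::<Rank1<10>>().trace());
    let y_true = dev.sample_normal::<Rank1<5>>().softmax();
    let loss = cross_entropy_with_logits_loss(y, y_true);

    // `Gradients` no longer carries a device type parameter.
    let gradients: Gradients = loss.backward();

    // Query the gradient of a specific parameter (field name assumed).
    let _w_grad = gradients.get(&model.weight);
}
```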
2 changes: 1 addition & 1 deletion examples/06-mnist.rs
@@ -99,7 +99,7 @@ fn main() {

// initialize model and optimizer
let mut model: Mlp = dev.build_module();
let mut opt: Adam<Mlp, Dev> = Default::default();
let mut opt: Adam<Mlp> = Default::default();

// initialize dataset
let dataset = MnistDataset::train(&mnist_path);
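The optimizer side of the same change, as a self-contained sketch: the device parameter disappears from `Adam`, leaving only the module type (dtype defaults to f32). The module alias below is a hypothetical stand-in for the example's `Mlp`; layer sizes are illustrative, and the `Linear<I, O, Device>` generic order follows the test code in this diff.

```rust
use dfdx::{
    gradients::Gradients,
    losses::cross_entropy_with_logits_loss,
    optim::{Adam, Optimizer},
    prelude::*,
};

// Small stand-in for the example's `Mlp` alias (sizes are illustrative).
type Model = (Linear<10, 8, Cpu>, ReLU, Linear<8, 5, Cpu>);

fn main() {
    let dev: Cpu = Default::default();
    let mut model: Model = dev.build_module();

    // Previously `Adam<Mlp, Dev>`; the device generic is gone.
    let mut opt: Adam<Model> = Default::default();

    // One illustrative step: forward, loss, backward, update.
    let y = model.forward(dev.zeros::<Rank1<10>>().trace());
    let y_true = dev.sample_normal::<Rank1<5>>().softmax();
    let loss = cross_entropy_with_logits_loss(y, y_true);
    let gradients: Gradients = loss.backward();
    opt.update(&mut model, gradients).expect("optimizer update failed");
}
```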
95 changes: 42 additions & 53 deletions src/gradients.rs
@@ -1,11 +1,9 @@
//! Implementations of [GradientTape] and generic Nd array containers via [Gradients].
#![allow(clippy::type_complexity)]

use core::marker::PhantomData;
use std::collections::HashMap;
use std::{boxed::Box, vec::Vec};

use crate::shapes::{HasDtype, HasShape};
use crate::tensor::storage_traits::{AllocGrad, DeviceStorage};
use crate::unique_id::{HasUniqueId, UniqueId};

@@ -24,28 +22,24 @@ use crate::unique_id::{HasUniqueId, UniqueId};
/// important part of key's implementing [HasShape], and [HasDtype] is that the associated type
/// of that trait is used to downcast the box to the expected value.
#[derive(Debug, Default)]
pub struct Gradients<D: DeviceStorage> {
pub struct Gradients {
gradient_by_id: HashMap<UniqueId, Box<dyn std::any::Any>>,
device: PhantomData<*const D>,
}

impl<D: DeviceStorage> Gradients<D> {
impl Gradients {
/// Retrieves mutable gradient for `t`, allocating one if it isn't present.
pub(crate) fn get_or_alloc_mut<T>(
&mut self,
t: &T,
) -> Result<&mut D::Storage<T::Shape, T::Dtype>, D::Err>
pub(crate) fn get_or_alloc_mut<T>(&mut self, t: &T) -> Result<&mut T::Gradient, T::Err>
where
T: HasUniqueId + AllocGrad<D>,
T: HasUniqueId + AllocGrad,
{
self.try_alloc_for(t)?;
Ok(self.get_mut(t))
}

/// Inserts a gradient for `t`
pub(crate) fn try_alloc_for<T>(&mut self, t: &T) -> Result<(), D::Err>
pub(crate) fn try_alloc_for<T>(&mut self, t: &T) -> Result<(), T::Err>
where
T: HasUniqueId + AllocGrad<D>,
T: HasUniqueId + AllocGrad,
{
if !self.gradient_by_id.contains_key(t.id()) {
let grad = t.try_alloc_grad()?;
@@ -57,10 +51,10 @@ impl<D: DeviceStorage> Gradients<D> {
/// Removes and returns the data associated with `t.id()`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn remove<T: HasUniqueId + HasShape + HasDtype>(
&mut self,
t: &T,
) -> Option<D::Storage<T::Shape, T::Dtype>> {
pub(crate) fn remove<T>(&mut self, t: &T) -> Option<T::Gradient>
where
T: HasUniqueId + AllocGrad,
{
self.gradient_by_id
.remove_entry(t.id())
.map(|e| *e.1.downcast().unwrap())
@@ -69,9 +63,9 @@ impl<D: DeviceStorage> Gradients<D> {
/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
pub(crate) fn get_mut<T>(&mut self, t: &T) -> &mut D::Storage<T::Shape, T::Dtype>
pub(crate) fn get_mut<T>(&mut self, t: &T) -> &mut T::Gradient
where
T: HasUniqueId + HasDtype + HasShape,
T: HasUniqueId + AllocGrad,
{
self.gradient_by_id
.get_mut(t.id())
@@ -86,10 +80,10 @@ impl<D: DeviceStorage> Gradients<D> {
///
/// If no data is associated with `t` yet, this will panic due to an unwrap()
/// on a .get() to the underlying hashmap.
pub fn get<T: HasUniqueId + HasDtype + HasShape>(
&self,
t: &T,
) -> &D::Storage<T::Shape, T::Dtype> {
pub fn get<T>(&self, t: &T) -> &T::Gradient
where
T: HasUniqueId + AllocGrad,
{
self.gradient_by_id
.get(t.id())
.unwrap()
@@ -102,17 +96,10 @@ impl<D: DeviceStorage> Gradients<D> {
/// `l` is the gradient to update, and `r` is the gradient to backprop.
///
/// **Panics** if `l` and `r` have the same id.
pub(crate) fn mut_and_ref<L, R>(
&mut self,
l: &L,
r: &R,
) -> (
&mut D::Storage<L::Shape, L::Dtype>,
&D::Storage<R::Shape, R::Dtype>,
)
pub(crate) fn mut_and_ref<L, R>(&mut self, l: &L, r: &R) -> (&mut L::Gradient, &R::Gradient)
where
L: HasUniqueId + HasShape + HasDtype,
R: HasUniqueId + HasShape + HasDtype,
L: HasUniqueId + AllocGrad,
R: HasUniqueId + AllocGrad,
{
assert_ne!(l.id(), r.id());
let l_ptr = self.get_mut(l) as *mut _;
@@ -128,15 +115,11 @@ impl<D: DeviceStorage> Gradients<D> {
l1: &L1,
l2: &L2,
r: &R,
) -> (
&mut D::Storage<L1::Shape, L1::Dtype>,
&mut D::Storage<L2::Shape, L2::Dtype>,
&D::Storage<R::Shape, R::Dtype>,
)
) -> (&mut L1::Gradient, &mut L2::Gradient, &R::Gradient)
where
L1: HasUniqueId + HasShape + HasDtype,
L2: HasUniqueId + HasShape + HasDtype,
R: HasUniqueId + HasShape + HasDtype,
L1: HasUniqueId + AllocGrad,
L2: HasUniqueId + AllocGrad,
R: HasUniqueId + AllocGrad,
{
assert_ne!(l1.id(), l2.id());
assert_ne!(l1.id(), r.id());
@@ -183,8 +166,8 @@ impl<D: DeviceStorage> Gradients<D> {
/// This would not be possible if these chain rule operations were inside of GradientTape!
#[allow(clippy::type_complexity)]
pub struct GradientTape<D: DeviceStorage> {
operations: Vec<Box<dyn FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>>,
gradients: Gradients<D>,
operations: Vec<Box<dyn FnOnce(&mut Gradients) -> Result<(), D::Err>>>,
gradients: Gradients,
}

impl<D: DeviceStorage> Default for GradientTape<D> {
@@ -212,7 +195,7 @@ impl<D: DeviceStorage> GradientTape<D> {
/// * `operation` - A FnOnce that acts on [Gradients].
///
/// See src/tensor_ops for implementation examples.
pub(crate) fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
pub(crate) fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(
&mut self,
operation: F,
) {
@@ -222,7 +205,7 @@ impl<D: DeviceStorage> GradientTape<D> {
/// Compute the [Gradients]! This just runs all the operations on a new [Gradients] struct.
///
/// Note that this method takes ownership of self, so it can't be called twice!
pub(crate) fn execute(mut self) -> Result<Gradients<D>, D::Err> {
pub(crate) fn execute(mut self) -> Result<Gradients, D::Err> {
for operation in self.operations.drain(..).rev() {
(operation)(&mut self.gradients)?;
}
@@ -251,34 +234,40 @@ pub struct NoneTape;
pub trait Tape<D: DeviceStorage>: Default + Merge<Self> + Merge<NoneTape> {
/// Whether this object currently owns the [GradientTape]. This is known at compile time.
const OWNS_TAPE: bool;
fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(
&mut self,
operation: F,
);
fn try_alloc_grad<T: HasUniqueId + AllocGrad<D>>(&mut self, t: &T) -> Result<(), D::Err>;
fn try_alloc_grad<T: HasUniqueId + AllocGrad<Err = D::Err>>(
&mut self,
t: &T,
) -> Result<(), D::Err>;
}

impl<D: DeviceStorage> Tape<D> for OwnedTape<D> {
const OWNS_TAPE: bool = true;
fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(
&mut self,
operation: F,
) {
self.0.add_backward_op(operation)
}
fn try_alloc_grad<T: HasUniqueId + AllocGrad<D>>(&mut self, t: &T) -> Result<(), D::Err> {
fn try_alloc_grad<T: HasUniqueId + AllocGrad<Err = D::Err>>(
&mut self,
t: &T,
) -> Result<(), D::Err> {
self.0.gradients.try_alloc_for(t)
}
}

impl<D: DeviceStorage> Tape<D> for NoneTape {
const OWNS_TAPE: bool = false;
fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
&mut self,
_: F,
) {
fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(&mut self, _: F) {
}
fn try_alloc_grad<T: HasUniqueId + AllocGrad<D>>(&mut self, _: &T) -> Result<(), D::Err> {
fn try_alloc_grad<T: HasUniqueId + AllocGrad<Err = D::Err>>(
&mut self,
_: &T,
) -> Result<(), D::Err> {
Ok(())
}
}
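The device parameter can be dropped because the gradient's storage type now comes from the key itself: the bounds above use `AllocGrad`, whose associated `Gradient` and `Err` types replace `D::Storage<S, E>` and `D::Err`. The trait lives in src/tensor/storage_traits.rs and is not part of this diff; inferred from the call sites above, it presumably has roughly the following shape. This is a sketch, not the actual definition.

```rust
// Inferred sketch only; the real trait is in `crate::tensor::storage_traits`
// and may differ in naming and bounds.
pub trait AllocGrad {
    /// For a `Tensor<S, E, D>` this would be `D::Storage<S, E>`. The `'static`
    /// bound reflects that `Gradients` stores values as `Box<dyn Any>`.
    type Gradient: 'static;
    /// The device's error type (`D::Err` for a tensor on device `D`).
    type Err;

    /// Allocates a zero-filled gradient for `self`; used by `Gradients::try_alloc_for`.
    fn try_alloc_grad(&self) -> Result<Self::Gradient, Self::Err>;
}
```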
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -74,7 +74,7 @@
//! let loss = cross_entropy_with_logits_loss(y, y_true);
//!
//! // call `backward()` to compute gradients. The tensor *must* have `OwnedTape`!
//! let gradients: Gradients<Cpu> = loss.backward();
//! let gradients: Gradients = loss.backward();
//! ```
//! 7. Use an optimizer from [crate::optim] to optimize your network!
//! ```rust
@@ -84,7 +84,7 @@
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! # let loss = cross_entropy_with_logits_loss(y, y_true);
//! # let gradients: Gradients<Cpu> = loss.backward();
//! # let gradients: Gradients = loss.backward();
//! // Use stochastic gradient descent (Sgd), with a learning rate of 1e-2, and 0.9 momentum.
//! let mut opt = Sgd::new(SgdConfig {
//! lr: 1e-2,
2 changes: 1 addition & 1 deletion src/nn/add_into.rs
@@ -206,7 +206,7 @@ mod tests {
fn test_missing_gradients() {
let dev: TestDevice = Default::default();
let mut model: AddInto<(Linear<5, 3, _>, Linear<5, 3, _>)> = dev.build_module();
let mut g: SimpleUpdater<_> = Default::default();
let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
2 changes: 1 addition & 1 deletion src/nn/impl_module_for_tuples.rs
@@ -239,7 +239,7 @@ mod tests {
fn test_tuple_missing_gradients() {
let dev: TestDevice = Default::default();
let mut model: (Linear<5, 3, _>, Linear<5, 3, _>, Linear<5, 3, _>) = dev.build_module();
let mut g: SimpleUpdater<_> = Default::default();
let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused: UnusedTensors = Default::default();
2 changes: 1 addition & 1 deletion src/nn/layer_norm.rs
@@ -167,7 +167,7 @@ mod tests {
let dev: TestDevice = Default::default();

let mut model: LayerNorm1D<5, _> = dev.build_module();
let mut g: SimpleUpdater<_> = Default::default();
let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
2 changes: 1 addition & 1 deletion src/nn/linear.rs
@@ -250,7 +250,7 @@ mod tests {
let dev: TestDevice = Default::default();

let mut model: Linear<5, 3, _> = dev.build_module();
let mut g: SimpleUpdater<_> = Default::default();
let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
4 changes: 2 additions & 2 deletions src/nn/mod.rs
@@ -133,9 +133,9 @@ mod tests {
use crate::{gradients::Gradients, optim::ParamUpdater, shapes::Dtype, tensor::DeviceStorage};

#[derive(Default)]
pub struct SimpleUpdater<D: DeviceStorage>(pub Gradients<D>);
pub struct SimpleUpdater(pub Gradients);

impl<D: DeviceStorage, E: Dtype> ParamUpdater<D, E> for SimpleUpdater<D> {
impl<D: DeviceStorage, E: Dtype> ParamUpdater<D, E> for SimpleUpdater {
fn update_param<S: crate::shapes::Shape>(
&mut self,
p: &mut crate::tensor::Tensor<S, E, D>,
2 changes: 1 addition & 1 deletion src/nn/repeated.rs
@@ -121,7 +121,7 @@ mod tests {
let dev: TestDevice = Default::default();

let mut model: Repeated<Linear<5, 5, _>, 3> = dev.build_module();
let mut g: SimpleUpdater<_> = Default::default();
let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
2 changes: 1 addition & 1 deletion src/nn/split_into.rs
@@ -228,7 +228,7 @@ mod tests {
fn test_missing_gradients() {
let dev: TestDevice = Default::default();
let mut model: SplitInto<(Linear<5, 3, _>, Linear<5, 3, _>)> = dev.build_module();
let mut g: SimpleUpdater<_> = Default::default();
let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
20 changes: 10 additions & 10 deletions src/optim/adam/mod.rs
@@ -8,7 +8,7 @@ use std::marker::PhantomData;
use crate::{
gradients::Gradients,
shapes::{Dtype, Shape},
tensor::{Cpu, DeviceStorage},
tensor::DeviceStorage,
};

use super::{GradientUpdate, Optimizer, OptimizerUpdateError, ParamUpdater, WeightDecay};
@@ -77,19 +77,19 @@ impl Default for AdamConfig<f32> {
///
/// See module level documentation at [crate::optim] for examples of how to actually use an optimizer.
#[derive(Debug)]
pub struct Adam<M, D: DeviceStorage = Cpu, E: Dtype = f32> {
pub struct Adam<M, E: Dtype = f32> {
/// Hyperparameter configuration
pub cfg: AdamConfig<E>,

t: i32,
gradients: Gradients<D>,
moment1: Gradients<D>,
moment2: Gradients<D>,
gradients: Gradients,
moment1: Gradients,
moment2: Gradients,

marker: PhantomData<*const M>,
}

impl<M, D: DeviceStorage, E: Dtype> Default for Adam<M, D, E>
impl<M, E: Dtype> Default for Adam<M, E>
where
AdamConfig<E>: Default,
{
@@ -99,7 +99,7 @@ where
}
}

impl<M, D: DeviceStorage, E: Dtype> Adam<M, D, E> {
impl<M, E: Dtype> Adam<M, E> {
/// Constructs using hyperparameters from `cfg`.
pub fn new(cfg: AdamConfig<E>) -> Self {
Self {
@@ -125,7 +125,7 @@ pub(super) trait AdamKernel<E: Dtype>: DeviceStorage {
) -> Result<(), Self::Err>;
}

impl<M, D: DeviceStorage + AdamKernel<E>, E: Dtype> ParamUpdater<D, E> for Adam<M, D, E> {
impl<M, D: DeviceStorage + AdamKernel<E>, E: Dtype> ParamUpdater<D, E> for Adam<M, E> {
fn update_param<S: Shape>(
&mut self,
p: &mut crate::tensor::Tensor<S, E, D>,
@@ -145,14 +145,14 @@ impl<M, D: DeviceStorage + AdamKernel<E>, E: Dtype> ParamUpdater<D, E> for Adam<
}
}

impl<E: Dtype, D: DeviceStorage, M: GradientUpdate<D, E>> Optimizer<M, D, E> for Adam<M, D, E>
impl<M: GradientUpdate<D, E>, D: AdamKernel<E>, E: Dtype> Optimizer<M, D, E> for Adam<M, E>
where
Self: ParamUpdater<D, E>,
{
fn update(
&mut self,
module: &mut M,
gradients: Gradients<D>,
gradients: Gradients,
) -> Result<(), OptimizerUpdateError<D>> {
self.t = self.t.checked_add(1).unwrap();
self.gradients = gradients;
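Constructing Adam with explicit hyperparameters follows the same pattern. In the sketch below, `Adam::new` and the `Default` impl for `AdamConfig<f32>` come from this diff; the `lr` field name (mirroring `SgdConfig`), the `dfdx::optim::AdamConfig` re-export, and the module type are assumptions for illustration.

```rust
use dfdx::{optim::{Adam, AdamConfig}, prelude::*};

// Illustrative module type; any module built for the device works here.
type Model = Linear<4, 2, Cpu>;

fn main() {
    // `Adam<M, E = f32>` replaces the old `Adam<M, D, E>`.
    let _opt: Adam<Model> = Adam::new(AdamConfig {
        lr: 1e-3, // field name assumed
        ..Default::default()
    });
}
```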
2 changes: 1 addition & 1 deletion src/optim/mod.rs
@@ -23,7 +23,7 @@
//! # let loss = losses::mse_loss(y, dev.zeros());
//! // -- snip loss computation --
//!
//! let gradients: Gradients<Cpu> = loss.backward();
//! let gradients: Gradients = loss.backward();
//! opt.update(&mut model, gradients);
//! ```
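For completeness, a self-contained sketch of the training step these docs describe, using the device-free `Gradients`. The `Sgd::new(SgdConfig { lr: ... })` construction and `losses::mse_loss` call mirror the doc examples in this diff; everything else (the model shape, the assumption that `SgdConfig` implements `Default`, the `expect` message) is illustrative.

```rust
use dfdx::{
    gradients::Gradients,
    losses::mse_loss,
    optim::{Optimizer, Sgd, SgdConfig},
    prelude::*,
};

fn main() {
    let dev: Cpu = Default::default();
    let mut model: Linear<5, 2, _> = dev.build_module();

    // Learning rate as in the src/lib.rs doc example; other fields left at defaults.
    let mut opt = Sgd::new(SgdConfig {
        lr: 1e-2,
        ..Default::default()
    });

    let y = model.forward(dev.sample_normal::<Rank1<5>>().trace());
    let loss = mse_loss(y, dev.zeros());
    let gradients: Gradients = loss.backward();
    opt.update(&mut model, gradients).expect("all params should have gradients");
}
```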
(Diffs for the remaining changed files are not shown.)
