Removing Device generic from Gradients & optimizers #402

Merged (3 commits) on Jan 26, 2023
2 changes: 1 addition & 1 deletion examples/04-gradients.rs
@@ -31,7 +31,7 @@ fn main() {
// finally you can use .backward() to extract the gradients!
// NOTE: that this method is only available on tensors that **own**
// the tape!
- let gradients: Gradients<Cpu> = e.backward();
+ let gradients: Gradients = e.backward();

// now you can extract gradients for specific tensors
// by querying with them
2 changes: 1 addition & 1 deletion examples/06-mnist.rs
@@ -99,7 +99,7 @@ fn main() {

// initialize model and optimizer
let mut model: Mlp = dev.build_module();
- let mut opt: Adam<Mlp, Dev> = Default::default();
+ let mut opt: Adam<Mlp> = Default::default();

// initialize dataset
let dataset = MnistDataset::train(&mnist_path);
95 changes: 42 additions & 53 deletions src/gradients.rs
@@ -1,11 +1,9 @@
//! Implementations of [GradientTape] and generic Nd array containers via [Gradients].
#![allow(clippy::type_complexity)]

- use core::marker::PhantomData;
use std::collections::HashMap;
use std::{boxed::Box, vec::Vec};

- use crate::shapes::{HasDtype, HasShape};
use crate::tensor::storage_traits::{AllocGrad, DeviceStorage};
use crate::unique_id::{HasUniqueId, UniqueId};

@@ -24,28 +22,24 @@ use crate::unique_id::{HasUniqueId, UniqueId};
/// important part of key's implementing [HasShape], and [HasDtype] is that the associated type
/// of that trait is used to downcast the box to the expected value.
#[derive(Debug, Default)]
- pub struct Gradients<D: DeviceStorage> {
+ pub struct Gradients {
gradient_by_id: HashMap<UniqueId, Box<dyn std::any::Any>>,
- device: PhantomData<*const D>,
}

- impl<D: DeviceStorage> Gradients<D> {
+ impl Gradients {
/// Retrieves mutable gradient for `t`, allocating one if it isn't present.
- pub(crate) fn get_or_alloc_mut<T>(
- &mut self,
- t: &T,
- ) -> Result<&mut D::Storage<T::Shape, T::Dtype>, D::Err>
+ pub(crate) fn get_or_alloc_mut<T>(&mut self, t: &T) -> Result<&mut T::Gradient, T::Err>
where
- T: HasUniqueId + AllocGrad<D>,
+ T: HasUniqueId + AllocGrad,
{
self.try_alloc_for(t)?;
Ok(self.get_mut(t))
}

/// Inserts a gradient for `t`
- pub(crate) fn try_alloc_for<T>(&mut self, t: &T) -> Result<(), D::Err>
+ pub(crate) fn try_alloc_for<T>(&mut self, t: &T) -> Result<(), T::Err>
where
- T: HasUniqueId + AllocGrad<D>,
+ T: HasUniqueId + AllocGrad,
{
if !self.gradient_by_id.contains_key(t.id()) {
let grad = t.try_alloc_grad()?;
@@ -57,10 +51,10 @@ impl<D: DeviceStorage> Gradients<D> {
/// Removes and returns the data associated with `t.id()`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
- pub(crate) fn remove<T: HasUniqueId + HasShape + HasDtype>(
- &mut self,
- t: &T,
- ) -> Option<D::Storage<T::Shape, T::Dtype>> {
+ pub(crate) fn remove<T>(&mut self, t: &T) -> Option<T::Gradient>
+ where
+ T: HasUniqueId + AllocGrad,
+ {
self.gradient_by_id
.remove_entry(t.id())
.map(|e| *e.1.downcast().unwrap())
@@ -69,9 +63,9 @@ impl<D: DeviceStorage> Gradients<D> {
/// Returns a mutable reference to the data associated with `t`.
///
/// **Panics** if data associated with `t` is not found. This indicates an unrecoverable bug.
- pub(crate) fn get_mut<T>(&mut self, t: &T) -> &mut D::Storage<T::Shape, T::Dtype>
+ pub(crate) fn get_mut<T>(&mut self, t: &T) -> &mut T::Gradient
where
- T: HasUniqueId + HasDtype + HasShape,
+ T: HasUniqueId + AllocGrad,
{
self.gradient_by_id
.get_mut(t.id())
@@ -86,10 +80,10 @@ impl<D: DeviceStorage> Gradients<D> {
///
/// If no data is associated with `t` yet, this will panic due to an unwrap()
/// on a .get() to the underlying hashmap.
- pub fn get<T: HasUniqueId + HasDtype + HasShape>(
- &self,
- t: &T,
- ) -> &D::Storage<T::Shape, T::Dtype> {
+ pub fn get<T>(&self, t: &T) -> &T::Gradient
+ where
+ T: HasUniqueId + AllocGrad,
+ {
self.gradient_by_id
.get(t.id())
.unwrap()
@@ -102,17 +96,10 @@ impl<D: DeviceStorage> Gradients<D> {
/// `l` is the gradient to update, and `r` is the gradient to backprop.
///
/// **Panics** if `l` and `r` have the same id.
- pub(crate) fn mut_and_ref<L, R>(
- &mut self,
- l: &L,
- r: &R,
- ) -> (
- &mut D::Storage<L::Shape, L::Dtype>,
- &D::Storage<R::Shape, R::Dtype>,
- )
+ pub(crate) fn mut_and_ref<L, R>(&mut self, l: &L, r: &R) -> (&mut L::Gradient, &R::Gradient)
where
- L: HasUniqueId + HasShape + HasDtype,
- R: HasUniqueId + HasShape + HasDtype,
+ L: HasUniqueId + AllocGrad,
+ R: HasUniqueId + AllocGrad,
{
assert_ne!(l.id(), r.id());
let l_ptr = self.get_mut(l) as *mut _;
@@ -128,15 +115,11 @@ impl<D: DeviceStorage> Gradients<D> {
l1: &L1,
l2: &L2,
r: &R,
- ) -> (
- &mut D::Storage<L1::Shape, L1::Dtype>,
- &mut D::Storage<L2::Shape, L2::Dtype>,
- &D::Storage<R::Shape, R::Dtype>,
- )
+ ) -> (&mut L1::Gradient, &mut L2::Gradient, &R::Gradient)
where
- L1: HasUniqueId + HasShape + HasDtype,
- L2: HasUniqueId + HasShape + HasDtype,
- R: HasUniqueId + HasShape + HasDtype,
+ L1: HasUniqueId + AllocGrad,
+ L2: HasUniqueId + AllocGrad,
+ R: HasUniqueId + AllocGrad,
{
assert_ne!(l1.id(), l2.id());
assert_ne!(l1.id(), r.id());
@@ -183,8 +166,8 @@ impl<D: DeviceStorage> Gradients<D> {
/// This would not be possible if these chain rule operations were inside of GradientTape!
#[allow(clippy::type_complexity)]
pub struct GradientTape<D: DeviceStorage> {
- operations: Vec<Box<dyn FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>>,
- gradients: Gradients<D>,
+ operations: Vec<Box<dyn FnOnce(&mut Gradients) -> Result<(), D::Err>>>,
+ gradients: Gradients,
}

impl<D: DeviceStorage> Default for GradientTape<D> {
@@ -212,7 +195,7 @@ impl<D: DeviceStorage> GradientTape<D> {
/// * `operation` - A FnOnce that acts on [Gradients].
///
/// See src/tensor_ops for implementation examples.
- pub(crate) fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
+ pub(crate) fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(
&mut self,
operation: F,
) {
@@ -222,7 +205,7 @@ impl<D: DeviceStorage> GradientTape<D> {
/// Compute the [Gradients]! This just runs all the operations on a new [Gradients] struct.
///
/// Note that this method takes ownership of self, so it can't be called twice!
- pub(crate) fn execute(mut self) -> Result<Gradients<D>, D::Err> {
+ pub(crate) fn execute(mut self) -> Result<Gradients, D::Err> {
for operation in self.operations.drain(..).rev() {
(operation)(&mut self.gradients)?;
}
@@ -251,34 +234,40 @@ pub struct NoneTape;
pub trait Tape<D: DeviceStorage>: Default + Merge<Self> + Merge<NoneTape> {
/// Whether this object currently owns the [GradientTape]. This is known at compile time.
const OWNS_TAPE: bool;
- fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
+ fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(
&mut self,
operation: F,
);
- fn try_alloc_grad<T: HasUniqueId + AllocGrad<D>>(&mut self, t: &T) -> Result<(), D::Err>;
+ fn try_alloc_grad<T: HasUniqueId + AllocGrad<Err = D::Err>>(
+ &mut self,
+ t: &T,
+ ) -> Result<(), D::Err>;
}

impl<D: DeviceStorage> Tape<D> for OwnedTape<D> {
const OWNS_TAPE: bool = true;
- fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
+ fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(
&mut self,
operation: F,
) {
self.0.add_backward_op(operation)
}
- fn try_alloc_grad<T: HasUniqueId + AllocGrad<D>>(&mut self, t: &T) -> Result<(), D::Err> {
+ fn try_alloc_grad<T: HasUniqueId + AllocGrad<Err = D::Err>>(
+ &mut self,
+ t: &T,
+ ) -> Result<(), D::Err> {
self.0.gradients.try_alloc_for(t)
}
}

impl<D: DeviceStorage> Tape<D> for NoneTape {
const OWNS_TAPE: bool = false;
- fn add_backward_op<F: 'static + FnOnce(&mut Gradients<D>) -> Result<(), D::Err>>(
- &mut self,
- _: F,
- ) {
+ fn add_backward_op<F: 'static + FnOnce(&mut Gradients) -> Result<(), D::Err>>(&mut self, _: F) {
}
- fn try_alloc_grad<T: HasUniqueId + AllocGrad<D>>(&mut self, _: &T) -> Result<(), D::Err> {
+ fn try_alloc_grad<T: HasUniqueId + AllocGrad<Err = D::Err>>(
+ &mut self,
+ _: &T,
+ ) -> Result<(), D::Err> {
Ok(())
}
}
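The device generic can be dropped from `Gradients` because the gradients were already stored type-erased: the map holds `Box<dyn std::any::Any>` keyed by each tensor's `UniqueId`, and the concrete storage type is recovered by downcasting to the `Gradient` associated type that `AllocGrad` now carries. Below is a minimal, self-contained sketch of that pattern; the names (`ErasedGradients`, `HasGradient`, `CpuTensor`) are illustrative stand-ins, not the dfdx API.

```rust
use std::any::Any;
use std::collections::HashMap;

// Illustrative stand-ins for dfdx's UniqueId and AllocGrad.
type Id = u64;

trait HasGradient {
    /// Concrete gradient storage; the box below is downcast to this type.
    type Gradient: Default + 'static;
    fn id(&self) -> Id;
}

/// No device parameter: values are type-erased, so the container itself
/// never needs to know what kind of storage it holds.
#[derive(Default)]
struct ErasedGradients {
    by_id: HashMap<Id, Box<dyn Any>>,
}

impl ErasedGradients {
    fn get_or_alloc_mut<T: HasGradient>(&mut self, t: &T) -> &mut T::Gradient {
        self.by_id
            .entry(t.id())
            .or_insert_with(|| Box::new(T::Gradient::default()))
            .downcast_mut::<T::Gradient>()
            // A type mismatch for a stored id would be an unrecoverable bug.
            .unwrap()
    }
}

struct CpuTensor {
    id: Id,
}

impl HasGradient for CpuTensor {
    type Gradient = Vec<f32>; // e.g. a flat CPU buffer
    fn id(&self) -> Id {
        self.id
    }
}

fn main() {
    let t = CpuTensor { id: 0 };
    let mut grads = ErasedGradients::default();
    grads.get_or_alloc_mut(&t).push(1.0);
    assert_eq!(grads.get_or_alloc_mut(&t).len(), 1);
}
```

The `unwrap()` on the downcast mirrors the invariant stated in the doc comments above: if the data stored for an id has the wrong type, that indicates an unrecoverable bug.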
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -74,7 +74,7 @@
//! let loss = cross_entropy_with_logits_loss(y, y_true);
//!
//! // call `backward()` to compute gradients. The tensor *must* have `OwnedTape`!
- //! let gradients: Gradients<Cpu> = loss.backward();
+ //! let gradients: Gradients = loss.backward();
//! ```
//! 7. Use an optimizer from [crate::optim] to optimize your network!
//! ```rust
@@ -84,7 +84,7 @@
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! # let loss = cross_entropy_with_logits_loss(y, y_true);
- //! # let gradients: Gradients<Cpu> = loss.backward();
+ //! # let gradients: Gradients = loss.backward();
//! // Use stochastic gradient descent (Sgd), with a learning rate of 1e-2, and 0.9 momentum.
//! let mut opt = Sgd::new(SgdConfig {
//! lr: 1e-2,
2 changes: 1 addition & 1 deletion src/nn/add_into.rs
@@ -206,7 +206,7 @@ mod tests {
fn test_missing_gradients() {
let dev: TestDevice = Default::default();
let mut model: AddInto<(Linear<5, 3, _>, Linear<5, 3, _>)> = dev.build_module();
- let mut g: SimpleUpdater<_> = Default::default();
+ let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
2 changes: 1 addition & 1 deletion src/nn/impl_module_for_tuples.rs
@@ -239,7 +239,7 @@ mod tests {
fn test_tuple_missing_gradients() {
let dev: TestDevice = Default::default();
let mut model: (Linear<5, 3, _>, Linear<5, 3, _>, Linear<5, 3, _>) = dev.build_module();
- let mut g: SimpleUpdater<_> = Default::default();
+ let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused: UnusedTensors = Default::default();
2 changes: 1 addition & 1 deletion src/nn/layer_norm.rs
@@ -167,7 +167,7 @@ mod tests {
let dev: TestDevice = Default::default();

let mut model: LayerNorm1D<5, _> = dev.build_module();
- let mut g: SimpleUpdater<_> = Default::default();
+ let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
2 changes: 1 addition & 1 deletion src/nn/linear.rs
@@ -250,7 +250,7 @@ mod tests {
let dev: TestDevice = Default::default();

let mut model: Linear<5, 3, _> = dev.build_module();
- let mut g: SimpleUpdater<_> = Default::default();
+ let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
4 changes: 2 additions & 2 deletions src/nn/mod.rs
@@ -133,9 +133,9 @@ mod tests {
use crate::{gradients::Gradients, optim::ParamUpdater, shapes::Dtype, tensor::DeviceStorage};

#[derive(Default)]
- pub struct SimpleUpdater<D: DeviceStorage>(pub Gradients<D>);
+ pub struct SimpleUpdater(pub Gradients);

- impl<D: DeviceStorage, E: Dtype> ParamUpdater<D, E> for SimpleUpdater<D> {
+ impl<D: DeviceStorage, E: Dtype> ParamUpdater<D, E> for SimpleUpdater {
fn update_param<S: crate::shapes::Shape>(
&mut self,
p: &mut crate::tensor::Tensor<S, E, D>,
2 changes: 1 addition & 1 deletion src/nn/repeated.rs
@@ -121,7 +121,7 @@ mod tests {
let dev: TestDevice = Default::default();

let mut model: Repeated<Linear<5, 5, _>, 3> = dev.build_module();
- let mut g: SimpleUpdater<_> = Default::default();
+ let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
2 changes: 1 addition & 1 deletion src/nn/split_into.rs
@@ -228,7 +228,7 @@ mod tests {
fn test_missing_gradients() {
let dev: TestDevice = Default::default();
let mut model: SplitInto<(Linear<5, 3, _>, Linear<5, 3, _>)> = dev.build_module();
- let mut g: SimpleUpdater<_> = Default::default();
+ let mut g: SimpleUpdater = Default::default();

// no gradients present
let mut unused = Default::default();
20 changes: 10 additions & 10 deletions src/optim/adam/mod.rs
@@ -8,7 +8,7 @@ use std::marker::PhantomData;
use crate::{
gradients::Gradients,
shapes::{Dtype, Shape},
- tensor::{Cpu, DeviceStorage},
+ tensor::DeviceStorage,
};

use super::{GradientUpdate, Optimizer, OptimizerUpdateError, ParamUpdater, WeightDecay};
@@ -77,19 +77,19 @@ impl Default for AdamConfig<f32> {
///
/// See module level documentation at [crate::optim] for examples of how to actually use an optimizer.
#[derive(Debug)]
- pub struct Adam<M, D: DeviceStorage = Cpu, E: Dtype = f32> {
+ pub struct Adam<M, E: Dtype = f32> {
/// Hyperparameter configuration
pub cfg: AdamConfig<E>,

t: i32,
- gradients: Gradients<D>,
- moment1: Gradients<D>,
- moment2: Gradients<D>,
+ gradients: Gradients,
+ moment1: Gradients,
+ moment2: Gradients,

marker: PhantomData<*const M>,
}

- impl<M, D: DeviceStorage, E: Dtype> Default for Adam<M, D, E>
+ impl<M, E: Dtype> Default for Adam<M, E>
where
AdamConfig<E>: Default,
{
@@ -99,7 +99,7 @@ where
}
}

- impl<M, D: DeviceStorage, E: Dtype> Adam<M, D, E> {
+ impl<M, E: Dtype> Adam<M, E> {
/// Constructs using hyperparameters from `cfg`.
pub fn new(cfg: AdamConfig<E>) -> Self {
Self {
@@ -125,7 +125,7 @@ pub(super) trait AdamKernel<E: Dtype>: DeviceStorage {
) -> Result<(), Self::Err>;
}

- impl<M, D: DeviceStorage + AdamKernel<E>, E: Dtype> ParamUpdater<D, E> for Adam<M, D, E> {
+ impl<M, D: DeviceStorage + AdamKernel<E>, E: Dtype> ParamUpdater<D, E> for Adam<M, E> {
fn update_param<S: Shape>(
&mut self,
p: &mut crate::tensor::Tensor<S, E, D>,
@@ -145,14 +145,14 @@ impl<M, D: DeviceStorage + AdamKernel<E>, E: Dtype> ParamUpdater<D, E> for Adam<
}
}

- impl<E: Dtype, D: DeviceStorage, M: GradientUpdate<D, E>> Optimizer<M, D, E> for Adam<M, D, E>
+ impl<M: GradientUpdate<D, E>, D: AdamKernel<E>, E: Dtype> Optimizer<M, D, E> for Adam<M, E>
where
Self: ParamUpdater<D, E>,
{
fn update(
&mut self,
module: &mut M,
- gradients: Gradients<D>,
+ gradients: Gradients,
) -> Result<(), OptimizerUpdateError<D>> {
self.t = self.t.checked_add(1).unwrap();
self.gradients = gradients;
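The same idea applies to `Adam`: its state (`gradients`, `moment1`, `moment2`) no longer names a device, so the struct keeps only the model marker and dtype, and the device parameter `D` reappears only on the `ParamUpdater`/`Optimizer` impls, where the compiler can infer it from the model being updated. Here is a small sketch of that shape with illustrative names (`Optim`, `Module`, `Cpu`), not the dfdx traits:

```rust
use std::marker::PhantomData;

// Illustrative stand-ins, not the dfdx traits.
trait Device {
    type Err: std::fmt::Debug;
}

struct Cpu;
impl Device for Cpu {
    type Err = std::convert::Infallible;
}

trait Module<D: Device> {}

/// Optimizer state without a device parameter, mirroring `Adam<M, E>`.
struct Optim<M> {
    step: usize,
    marker: PhantomData<*const M>,
}

impl<M> Default for Optim<M> {
    fn default() -> Self {
        Self { step: 0, marker: PhantomData }
    }
}

impl<M> Optim<M> {
    /// The device shows up only here and is inferred from `M: Module<D>`.
    fn update<D: Device>(&mut self, _model: &mut M) -> Result<(), D::Err>
    where
        M: Module<D>,
    {
        self.step += 1;
        Ok(())
    }
}

struct Mlp;
impl Module<Cpu> for Mlp {}

fn main() {
    // No device named in the optimizer's type, as with `Adam<Mlp>`.
    let mut opt: Optim<Mlp> = Default::default();
    let mut model = Mlp;
    opt.update(&mut model).unwrap(); // D = Cpu inferred from the single impl
    assert_eq!(opt.step, 1);
}
```

As with `Adam<Mlp>` in the mnist example above, the optimizer's type names only the model; the device is pinned down at the `update` call.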
2 changes: 1 addition & 1 deletion src/optim/mod.rs
@@ -23,7 +23,7 @@
//! # let loss = losses::mse_loss(y, dev.zeros());
//! // -- snip loss computation --
//!
- //! let gradients: Gradients<Cpu> = loss.backward();
+ //! let gradients: Gradients = loss.backward();
//! opt.update(&mut model, gradients);
//! ```
