Moving optim kernels to tensor ops (#828)

* Moving optim kernels to tensor ops

* Re-exporting stuff in optim

* Using make_mut instead of get_mut
coreylowman authored Jul 24, 2023
1 parent bdeec8a commit d48d6f8
Showing 20 changed files with 371 additions and 313 deletions.
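
The "make_mut instead of get_mut" bullet refers to `std::sync::Arc`, which dfdx uses to share tensor buffers. A minimal standalone sketch of the difference (plain std code, not from this commit): `get_mut` fails when the buffer is shared, while `make_mut` clones it copy-on-write and always yields a mutable reference, which is what an in-place parameter update wants.

```rust
use std::sync::Arc;

fn main() {
    // Two handles to the same buffer, as when a tensor's data is shared.
    let mut param = Arc::new(vec![1.0f32, 2.0, 3.0]);
    let shared = Arc::clone(&param);

    // get_mut refuses to hand out a mutable reference while other handles exist.
    assert!(Arc::get_mut(&mut param).is_none());

    // make_mut clones the data (copy-on-write) and always succeeds, so the
    // optimizer can update the parameter in place even if the tensor is shared.
    Arc::make_mut(&mut param)[0] += 0.5;
    assert_eq!(param[0], 1.5);
    assert_eq!(shared[0], 1.0); // the other handle still sees the old values
}
```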
71 changes: 4 additions & 67 deletions src/optim/adam/mod.rs → src/optim/adam.rs
@@ -1,56 +1,13 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use std::{marker::PhantomData, sync::Arc};
use std::marker::PhantomData;

use crate::{
nn::tensor_collection::*,
shapes::{Dtype, Shape},
tensor::{Gradients, Storage, Tensor},
tensor_ops::Device,
tensor_ops::{AdamConfig, Device},
};

use super::{Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay};

/// Configuration of hyperparameters for [Adam].
///
/// Changing all default parameters:
/// ```rust
/// # use dfdx::{prelude::*, optim::*};
/// AdamConfig {
/// lr: 1e-2,
/// betas: [0.1, 0.2],
/// eps: 1e-6,
/// weight_decay: Some(WeightDecay::L2(1e-1)),
/// };
/// ```
#[derive(Debug, Clone, Copy)]
pub struct AdamConfig {
/// Learning rate. Defaults to `1e-3`.
pub lr: f64,

/// Betas from Adam paper. Defaults to `[0.9, 0.999]`.
pub betas: [f64; 2],

/// Epsilon for numerical stability. Defaults to `1e-8`.
pub eps: f64,

/// Optional weight decay. Defaults to `None`.
pub weight_decay: Option<WeightDecay>,
}

impl Default for AdamConfig {
fn default() -> Self {
Self {
lr: 1e-3,
betas: [0.9, 0.999],
eps: 1e-8,
weight_decay: None,
}
}
}
use super::{Optimizer, OptimizerUpdateError, UnusedTensors};

/// An implementation of the Adam optimizer from
/// [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
@@ -95,18 +52,6 @@ impl<M, E: Dtype, D: Storage<E>> Adam<M, E, D> {
}
}

pub trait AdamKernel<E: Dtype>: Storage<E> {
fn update(
&self,
t: i32,
cfg: &AdamConfig,
param: &mut Self::Vec,
moment1: &mut Self::Vec,
moment2: &mut Self::Vec,
grad: &Self::Vec,
) -> Result<(), Self::Err>;
}

impl<M, D: Device<E>, E: Dtype> TensorVisitor<E, D>
for (&mut Adam<M, E, D>, &Gradients<E, D>, UnusedTensors)
{
@@ -129,15 +74,7 @@ impl<M, D: Device<E>, E: Dtype> TensorVisitor<E, D>
Some(g) => {
let m_t = self.0.moment1.get_or_alloc_mut(p)?;
let v_t = self.0.moment2.get_or_alloc_mut(p)?;
AdamKernel::update(
&p.device,
self.0.t,
&self.0.cfg,
Arc::make_mut(&mut p.data),
m_t,
v_t,
g,
)?;
self.0.cfg.try_update(self.0.t, p, m_t, v_t, g)?;
}
}
Ok(None)
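
The deleted `AdamConfig` definition and its kernel trait now live in `tensor_ops`; the fields and defaults are unchanged. A short construction sketch in the style of the removed doc example, assuming the `optim` re-export shown in the next file keeps the old import path working:

```rust
use dfdx::{prelude::*, optim::*};

// Defaults per the definition above: lr 1e-3, betas [0.9, 0.999],
// eps 1e-8, weight_decay None.
let cfg = AdamConfig {
    lr: 1e-2,
    weight_decay: Some(WeightDecay::L2(1e-1)),
    ..Default::default()
};
```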
10 changes: 6 additions & 4 deletions src/optim/mod.rs
@@ -35,11 +35,13 @@ mod optimizer;
mod rmsprop;
mod sgd;

pub use adam::{Adam, AdamConfig, AdamKernel};
pub use optimizer::{Momentum, WeightDecay};
pub use adam::Adam;
pub use optimizer::{Optimizer, OptimizerUpdateError, UnusedTensors};
pub use rmsprop::{RMSprop, RMSpropConfig, RMSpropKernel};
pub use sgd::{Sgd, SgdConfig, SgdKernel};
pub use rmsprop::RMSprop;
pub use sgd::Sgd;

// re-exports
pub use crate::tensor_ops::{AdamConfig, Momentum, RMSpropConfig, SgdConfig, WeightDecay};

pub mod prelude {
pub use super::{Optimizer, OptimizerUpdateError, UnusedTensors};
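
With the `pub use` block above, `dfdx::optim` and `dfdx::tensor_ops` name the same config types, so downstream imports through `optim` keep compiling. A minimal sketch of that guarantee, assuming `tensor_ops` remains a public module as the re-export implies:

```rust
// Both paths resolve to the same item after the re-export.
use dfdx::optim::AdamConfig as FromOptim;
use dfdx::tensor_ops::AdamConfig as FromTensorOps;

fn same_type(cfg: FromOptim) -> FromTensorOps {
    cfg // type-checks only because the two paths are one type
}
```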
59 changes: 0 additions & 59 deletions src/optim/optimizer.rs
@@ -3,65 +3,6 @@ use crate::{
tensor::{Gradients, Storage, Tensor, UniqueId},
};

/// L2 and decoupled regularization methods
#[derive(Debug, Clone, Copy)]
pub enum WeightDecay {
/// Weight decay applied to the gradients before any momentum updates. Equivalent to L2 regularization.
L2(f64),

/// Weight decay applied after any momentum updates, without modifying the gradients.
/// See [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
Decoupled(f64),
}

/// Used to communicate the "WeightDecay" enum to cuda kernels
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub(super) enum WeightDecayType {
None,
L2,
Decoupled,
}

#[cfg(feature = "cuda")]
pub(super) fn weight_decay_to_cuda(wd: Option<WeightDecay>) -> (WeightDecayType, f64) {
match wd {
None => (WeightDecayType::None, Default::default()),
Some(WeightDecay::L2(x)) => (WeightDecayType::L2, x),
Some(WeightDecay::Decoupled(x)) => (WeightDecayType::Decoupled, x),
}
}

/// Momentum used for [super::Sgd] and others
#[derive(Debug, Clone, Copy)]
pub enum Momentum {
/// Momentum that is applied to the velocity of a parameter directly.
Classic(f64),

/// Momentum that is applied to both velocity and gradients. See [super::Sgd] nesterov paper for more.
Nesterov(f64),
}

/// Used to communicate the "Momentum" enum to cuda kernels
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub(super) enum MomentumType {
None,
Classic,
Nesterov,
}

#[cfg(feature = "cuda")]
pub(super) fn momentum_to_cuda(wd: Option<Momentum>) -> (MomentumType, f64) {
match wd {
None => (MomentumType::None, Default::default()),
Some(Momentum::Classic(x)) => (MomentumType::Classic, x),
Some(Momentum::Nesterov(x)) => (MomentumType::Nesterov, x),
}
}

/// All optimizers must implement the update function, which takes a `M`
/// and updates all of its parameters.
///
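
The deleted doc comments distinguish `WeightDecay::L2` (folded into the gradient before any momentum update) from `WeightDecay::Decoupled` (applied to the parameter afterwards, leaving the gradient alone). A dependency-free sketch of that ordering for a single SGD-with-momentum step, a simplification for illustration rather than the actual kernel code:

```rust
/// Simplified single-parameter SGD-with-momentum step. With momentum in
/// between, the two modes differ: L2 decay is smoothed into the velocity,
/// decoupled decay is not.
fn sgd_momentum_step(
    param: &mut f64,
    velocity: &mut f64,
    grad: f64,
    lr: f64,
    momentum: f64,
    wd: f64,
    decoupled: bool,
) {
    let mut g = grad;
    if !decoupled {
        // WeightDecay::L2: add the penalty to the gradient *before* momentum.
        g += wd * *param;
    }
    *velocity = momentum * *velocity + g;
    let prev = *param;
    *param -= lr * *velocity;
    if decoupled {
        // WeightDecay::Decoupled: shrink the parameter directly *after* the
        // momentum update, without touching the gradient or velocity.
        *param -= lr * wd * prev;
    }
}
```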
69 changes: 4 additions & 65 deletions src/optim/rmsprop/mod.rs → src/optim/rmsprop.rs
@@ -1,54 +1,13 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use std::{marker::PhantomData, sync::Arc};
use std::marker::PhantomData;

use crate::{
nn::tensor_collection::*,
shapes::{Dtype, Shape},
tensor::*,
tensor_ops::Device,
tensor_ops::{Device, RMSpropConfig},
};

use super::{Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay};

/// Configuration of hyperparameters for [RMSprop].
#[derive(Debug, Clone, Copy)]
pub struct RMSpropConfig {
/// Learning rate. Defaults to `1e-2`.
pub lr: f64,

/// Value for exponential moving average. Defaults to `0.9`.
pub alpha: f64,

/// Epsilon for stability. Defaults to `1e-8`.
pub eps: f64,

/// Optional momentum. Defaults to `None`.
pub momentum: Option<f64>,

/// Whether the avg should be centered by the grad's avg value.
/// Defaults to `false`.
pub centered: bool,

/// Optional weight decay. Defaults to `None`.
pub weight_decay: Option<WeightDecay>,
}

impl Default for RMSpropConfig {
fn default() -> Self {
Self {
lr: 1e-2,
alpha: 0.9,
eps: 1e-8,
momentum: None,
centered: false,
weight_decay: None,
}
}
}
use super::{Optimizer, OptimizerUpdateError, UnusedTensors};

/// RMSprop As described in [Hinton, 2012](http://www.cs.toronto.edu/%7Etijmen/csc321/slides/lecture_slides_lec6.pdf).
///
@@ -104,18 +63,6 @@ impl<M, E: Dtype, D: Storage<E>> RMSprop<M, E, D> {
}
}

pub trait RMSpropKernel<E: Dtype>: Storage<E> {
fn update(
&self,
cfg: &RMSpropConfig,
param: &mut Self::Vec,
momentum: &mut Self::Vec,
square_avg: &mut Self::Vec,
grad_avg: &mut Self::Vec,
grad: &Self::Vec,
) -> Result<(), Self::Err>;
}

impl<M, E: Dtype, D: Device<E>> TensorVisitor<E, D>
for (&mut RMSprop<M, E, D>, &Gradients<E, D>, UnusedTensors)
{
@@ -144,15 +91,7 @@ impl<M, E: Dtype, D: Device<E>> TensorVisitor<E, D>
p.device.try_fill_with_ones(sa)?;
}

RMSpropKernel::update(
&p.device,
&self.0.cfg,
Arc::make_mut(&mut p.data),
m,
sa,
ga,
g,
)?;
self.0.cfg.try_update(p, m, sa, ga, g)?;
}
}
Ok(None)
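
The removed `RMSpropConfig` keeps the same fields and defaults in its new home (lr `1e-2`, alpha `0.9`, eps `1e-8`, momentum `None`, centered `false`, weight_decay `None`). A short construction sketch in the removed doc-example style, assuming the `optim` re-export:

```rust
use dfdx::{prelude::*, optim::*};

// Override only the fields of interest; the rest keep the defaults listed above.
let cfg = RMSpropConfig {
    momentum: Some(0.9),
    centered: true,
    ..Default::default()
};
```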
98 changes: 2 additions & 96 deletions src/optim/sgd/mod.rs → src/optim/sgd.rs
@@ -1,92 +1,14 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use std::marker::PhantomData;

use crate::{
nn::tensor_collection::*,
shapes::{Dtype, Shape},
tensor::{Gradients, Storage, Tensor},
tensor_ops::Device,
tensor_ops::{Device, SgdConfig},
};

use super::optimizer::*;

/// Configuration of hyperparameters for [Sgd].
///
/// Using different learning rate:
/// ```rust
/// # use dfdx::{prelude::*, optim::*};
/// SgdConfig {
/// lr: 1e-1,
/// momentum: None,
/// weight_decay: None,
/// };
/// ```
///
/// Using classic momentum:
/// ```rust
/// # use dfdx::{prelude::*, optim::*};
/// SgdConfig {
/// lr: 1e-2,
/// momentum: Some(Momentum::Classic(0.5)),
/// weight_decay: None,
/// };
/// ```
///
/// Using nesterov momentum:
/// ```rust
/// # use dfdx::{prelude::*, optim::*};
/// SgdConfig {
/// lr: 1e-3,
/// momentum: Some(Momentum::Nesterov(0.25)),
/// weight_decay: None,
/// };
/// ```
///
/// Using L2 weight decay:
/// ```rust
/// # use dfdx::{prelude::*, optim::*};
/// SgdConfig {
/// lr: 1e-3,
/// momentum: None,
/// weight_decay: Some(WeightDecay::L2(1e-2)),
/// };
/// ```
///
/// Using decoupled weight decay:
/// ```rust
/// # use dfdx::{prelude::*, optim::*};
/// SgdConfig {
/// lr: 1e-3,
/// momentum: None,
/// weight_decay: Some(WeightDecay::Decoupled(1e-2)),
/// };
/// ```
#[derive(Debug, Clone, Copy)]
pub struct SgdConfig {
/// Learning rate. Defaults to `1e-2`
pub lr: f64,

/// Optional momentum. Defaults to `None`.
pub momentum: Option<Momentum>,

/// Optional weight decay. Defaults to `None`.
pub weight_decay: Option<WeightDecay>,
}

impl Default for SgdConfig {
fn default() -> Self {
Self {
lr: 1e-2,
momentum: None,
weight_decay: None,
}
}
}

/// Implementation of Stochastic Gradient Descent. Based on [pytorch's implementation](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html)
///
/// Nesterov Momentum is implemented as described in
@@ -132,16 +54,6 @@ impl<M, E: Dtype, D: Storage<E>> Sgd<M, E, D> {
}
}

pub trait SgdKernel<E: Dtype>: Storage<E> {
fn update(
&self,
cfg: &SgdConfig,
param: &mut Self::Vec,
velocity: &mut Self::Vec,
grad: &Self::Vec,
) -> Result<(), Self::Err>;
}

impl<E: Dtype, D: Device<E>, M> TensorVisitor<E, D>
for (&mut Sgd<M, E, D>, &Gradients<E, D>, UnusedTensors)
{
@@ -163,13 +75,7 @@ impl<E: Dtype, D: Device<E>, M> TensorVisitor<E, D>
None => self.2.add(p),
Some(g) => {
let v = self.0.velocity.get_or_alloc_mut(p)?;
SgdKernel::update(
&p.device,
&self.0.cfg,
std::sync::Arc::make_mut(&mut p.data),
v,
g,
)?;
self.0.cfg.try_update(p, v, g)?;
}
}
Ok(None)
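
For orientation, a rough end-to-end sketch of how an `SgdConfig` reaches an optimizer. `Cpu::default`, `build_module`, `Sgd::new`, and the commented `update` call are assumptions based on dfdx's usual API and the `Optimizer` trait description ("takes a `M` and updates all of its parameters"), not lines from this diff:

```rust
use dfdx::{prelude::*, optim::*};

// Hypothetical tiny architecture, used only to have something to optimize.
type Mlp = (Linear<2, 4>, ReLU, Linear<4, 1>);

let dev = Cpu::default();
let mut model = dev.build_module::<Mlp, f32>();

// SgdConfig is the same struct as before, now defined in tensor_ops.
let mut opt = Sgd::new(
    &model,
    SgdConfig {
        lr: 1e-2,
        momentum: Some(Momentum::Nesterov(0.9)),
        weight_decay: None,
    },
);

// After a forward/backward pass that produced `grads: Gradients<f32, Cpu>`:
// opt.update(&mut model, &grads)?;
```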
File renamed without changes.