diff --git a/src/optim/adam/mod.rs b/src/optim/adam.rs similarity index 84% rename from src/optim/adam/mod.rs rename to src/optim/adam.rs index 2a3ae3787..295915f97 100644 --- a/src/optim/adam/mod.rs +++ b/src/optim/adam.rs @@ -1,56 +1,13 @@ -mod cpu_kernel; - -#[cfg(feature = "cuda")] -mod cuda_kernel; - -use std::{marker::PhantomData, sync::Arc}; +use std::marker::PhantomData; use crate::{ nn::tensor_collection::*, shapes::{Dtype, Shape}, tensor::{Gradients, Storage, Tensor}, - tensor_ops::Device, + tensor_ops::{AdamConfig, Device}, }; -use super::{Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; - -/// Configuration of hyperparameters for [Adam]. -/// -/// Changing all default parameters: -/// ```rust -/// # use dfdx::{prelude::*, optim::*}; -/// AdamConfig { -/// lr: 1e-2, -/// betas: [0.1, 0.2], -/// eps: 1e-6, -/// weight_decay: Some(WeightDecay::L2(1e-1)), -/// }; -/// ``` -#[derive(Debug, Clone, Copy)] -pub struct AdamConfig { - /// Learning rate. Defaults to `1e-3`. - pub lr: f64, - - /// Betas from Adam paper. Defaults to `[0.9, 0.999]`. - pub betas: [f64; 2], - - /// Epsilon for numerical stability. Defaults to `1e-8`. - pub eps: f64, - - /// Optional weight decay. Defaults to `None`. - pub weight_decay: Option, -} - -impl Default for AdamConfig { - fn default() -> Self { - Self { - lr: 1e-3, - betas: [0.9, 0.999], - eps: 1e-8, - weight_decay: None, - } - } -} +use super::{Optimizer, OptimizerUpdateError, UnusedTensors}; /// An implementation of the Adam optimizer from /// [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980) @@ -95,18 +52,6 @@ impl> Adam { } } -pub trait AdamKernel: Storage { - fn update( - &self, - t: i32, - cfg: &AdamConfig, - param: &mut Self::Vec, - moment1: &mut Self::Vec, - moment2: &mut Self::Vec, - grad: &Self::Vec, - ) -> Result<(), Self::Err>; -} - impl, E: Dtype> TensorVisitor for (&mut Adam, &Gradients, UnusedTensors) { @@ -129,15 +74,7 @@ impl, E: Dtype> TensorVisitor Some(g) => { let m_t = self.0.moment1.get_or_alloc_mut(p)?; let v_t = self.0.moment2.get_or_alloc_mut(p)?; - AdamKernel::update( - &p.device, - self.0.t, - &self.0.cfg, - Arc::make_mut(&mut p.data), - m_t, - v_t, - g, - )?; + self.0.cfg.try_update(self.0.t, p, m_t, v_t, g)?; } } Ok(None) diff --git a/src/optim/mod.rs b/src/optim/mod.rs index 77943dd7d..25a9067dd 100644 --- a/src/optim/mod.rs +++ b/src/optim/mod.rs @@ -35,11 +35,13 @@ mod optimizer; mod rmsprop; mod sgd; -pub use adam::{Adam, AdamConfig, AdamKernel}; -pub use optimizer::{Momentum, WeightDecay}; +pub use adam::Adam; pub use optimizer::{Optimizer, OptimizerUpdateError, UnusedTensors}; -pub use rmsprop::{RMSprop, RMSpropConfig, RMSpropKernel}; -pub use sgd::{Sgd, SgdConfig, SgdKernel}; +pub use rmsprop::RMSprop; +pub use sgd::Sgd; + +// re-exports +pub use crate::tensor_ops::{AdamConfig, Momentum, RMSpropConfig, SgdConfig, WeightDecay}; pub mod prelude { pub use super::{Optimizer, OptimizerUpdateError, UnusedTensors}; diff --git a/src/optim/optimizer.rs b/src/optim/optimizer.rs index f63df4d9f..a42648194 100644 --- a/src/optim/optimizer.rs +++ b/src/optim/optimizer.rs @@ -3,65 +3,6 @@ use crate::{ tensor::{Gradients, Storage, Tensor, UniqueId}, }; -/// L2 and decoupled regularization methods -#[derive(Debug, Clone, Copy)] -pub enum WeightDecay { - /// Weight decay applied to the gradients before any momentum updates. Equivalent to L2 regularization. - L2(f64), - - /// Weight decay applied after any momentum updates, without modifying the gradients. 
- /// See [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) - Decoupled(f64), -} - -/// Used to communicate the "WeightDecay" enum to cuda kernels -#[cfg(feature = "cuda")] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(C)] -pub(super) enum WeightDecayType { - None, - L2, - Decoupled, -} - -#[cfg(feature = "cuda")] -pub(super) fn weight_decay_to_cuda(wd: Option) -> (WeightDecayType, f64) { - match wd { - None => (WeightDecayType::None, Default::default()), - Some(WeightDecay::L2(x)) => (WeightDecayType::L2, x), - Some(WeightDecay::Decoupled(x)) => (WeightDecayType::Decoupled, x), - } -} - -/// Momentum used for [super::Sgd] and others -#[derive(Debug, Clone, Copy)] -pub enum Momentum { - /// Momentum that is applied to the velocity of a parameter directly. - Classic(f64), - - /// Momentum that is applied to both velocity and gradients. See [super::Sgd] nesterov paper for more. - Nesterov(f64), -} - -/// Used to communicate the "Momentum" enum to cuda kernels -#[cfg(feature = "cuda")] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[repr(C)] -pub(super) enum MomentumType { - None, - Classic, - Nesterov, -} - -#[cfg(feature = "cuda")] -pub(super) fn momentum_to_cuda(wd: Option) -> (MomentumType, f64) { - match wd { - None => (MomentumType::None, Default::default()), - Some(Momentum::Classic(x)) => (MomentumType::Classic, x), - Some(Momentum::Nesterov(x)) => (MomentumType::Nesterov, x), - } -} - /// All optimizers must implement the update function, which takes a `M` /// and updates all of its parameters. /// diff --git a/src/optim/rmsprop/mod.rs b/src/optim/rmsprop.rs similarity index 84% rename from src/optim/rmsprop/mod.rs rename to src/optim/rmsprop.rs index 4c6803cb4..8e39d3b33 100644 --- a/src/optim/rmsprop/mod.rs +++ b/src/optim/rmsprop.rs @@ -1,54 +1,13 @@ -mod cpu_kernel; - -#[cfg(feature = "cuda")] -mod cuda_kernel; - -use std::{marker::PhantomData, sync::Arc}; +use std::marker::PhantomData; use crate::{ nn::tensor_collection::*, shapes::{Dtype, Shape}, tensor::*, - tensor_ops::Device, + tensor_ops::{Device, RMSpropConfig}, }; -use super::{Optimizer, OptimizerUpdateError, UnusedTensors, WeightDecay}; - -/// Configuration of hyperparameters for [RMSprop]. -#[derive(Debug, Clone, Copy)] -pub struct RMSpropConfig { - /// Learning rate. Defaults to `1e-2`. - pub lr: f64, - - /// Value for exponential moving average. Defaults to `0.9`. - pub alpha: f64, - - /// Epsilon for stability. Defaults to `1e-8`. - pub eps: f64, - - /// Optional momentum. Defaults to `None`. - pub momentum: Option, - - /// Whether the avg should be centered by the grad's avg value. - /// Defaults to `false`. - pub centered: bool, - - /// Optional weight decay. Defaults to `None`. - pub weight_decay: Option, -} - -impl Default for RMSpropConfig { - fn default() -> Self { - Self { - lr: 1e-2, - alpha: 0.9, - eps: 1e-8, - momentum: None, - centered: false, - weight_decay: None, - } - } -} +use super::{Optimizer, OptimizerUpdateError, UnusedTensors}; /// RMSprop As described in [Hinton, 2012](http://www.cs.toronto.edu/%7Etijmen/csc321/slides/lecture_slides_lec6.pdf). 
/// @@ -104,18 +63,6 @@ impl> RMSprop { } } -pub trait RMSpropKernel: Storage { - fn update( - &self, - cfg: &RMSpropConfig, - param: &mut Self::Vec, - momentum: &mut Self::Vec, - square_avg: &mut Self::Vec, - grad_avg: &mut Self::Vec, - grad: &Self::Vec, - ) -> Result<(), Self::Err>; -} - impl> TensorVisitor for (&mut RMSprop, &Gradients, UnusedTensors) { @@ -144,15 +91,7 @@ impl> TensorVisitor p.device.try_fill_with_ones(sa)?; } - RMSpropKernel::update( - &p.device, - &self.0.cfg, - Arc::make_mut(&mut p.data), - m, - sa, - ga, - g, - )?; + self.0.cfg.try_update(p, m, sa, ga, g)?; } } Ok(None) diff --git a/src/optim/sgd/mod.rs b/src/optim/sgd.rs similarity index 85% rename from src/optim/sgd/mod.rs rename to src/optim/sgd.rs index 7161066b0..6d821c7d7 100644 --- a/src/optim/sgd/mod.rs +++ b/src/optim/sgd.rs @@ -1,92 +1,14 @@ -mod cpu_kernel; - -#[cfg(feature = "cuda")] -mod cuda_kernel; - use std::marker::PhantomData; use crate::{ nn::tensor_collection::*, shapes::{Dtype, Shape}, tensor::{Gradients, Storage, Tensor}, - tensor_ops::Device, + tensor_ops::{Device, SgdConfig}, }; use super::optimizer::*; -/// Configuration of hyperparameters for [Sgd]. -/// -/// Using different learning rate: -/// ```rust -/// # use dfdx::{prelude::*, optim::*}; -/// SgdConfig { -/// lr: 1e-1, -/// momentum: None, -/// weight_decay: None, -/// }; -/// ``` -/// -/// Using classic momentum: -/// ```rust -/// # use dfdx::{prelude::*, optim::*}; -/// SgdConfig { -/// lr: 1e-2, -/// momentum: Some(Momentum::Classic(0.5)), -/// weight_decay: None, -/// }; -/// ``` -/// -/// Using nesterov momentum: -/// ```rust -/// # use dfdx::{prelude::*, optim::*}; -/// SgdConfig { -/// lr: 1e-3, -/// momentum: Some(Momentum::Nesterov(0.25)), -/// weight_decay: None, -/// }; -/// ``` -/// -/// Using L2 weight decay: -/// ```rust -/// # use dfdx::{prelude::*, optim::*}; -/// SgdConfig { -/// lr: 1e-3, -/// momentum: None, -/// weight_decay: Some(WeightDecay::L2(1e-2)), -/// }; -/// ``` -/// -/// Using decoupled weight decay: -/// ```rust -/// # use dfdx::{prelude::*, optim::*}; -/// SgdConfig { -/// lr: 1e-3, -/// momentum: None, -/// weight_decay: Some(WeightDecay::Decoupled(1e-2)), -/// }; -/// ``` -#[derive(Debug, Clone, Copy)] -pub struct SgdConfig { - /// Learning rate. Defaults to `1e-2` - pub lr: f64, - - /// Optional momentum. Defaults to `None`. - pub momentum: Option, - - /// Optional weight decay. Defaults to `None`. - pub weight_decay: Option, -} - -impl Default for SgdConfig { - fn default() -> Self { - Self { - lr: 1e-2, - momentum: None, - weight_decay: None, - } - } -} - /// Implementation of Stochastic Gradient Descent. 
Based on [pytorch's implementation](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) /// /// Nesterov Momentum is implemented as described in @@ -132,16 +54,6 @@ impl> Sgd { } } -pub trait SgdKernel: Storage { - fn update( - &self, - cfg: &SgdConfig, - param: &mut Self::Vec, - velocity: &mut Self::Vec, - grad: &Self::Vec, - ) -> Result<(), Self::Err>; -} - impl, M> TensorVisitor for (&mut Sgd, &Gradients, UnusedTensors) { @@ -163,13 +75,7 @@ impl, M> TensorVisitor None => self.2.add(p), Some(g) => { let v = self.0.velocity.get_or_alloc_mut(p)?; - SgdKernel::update( - &p.device, - &self.0.cfg, - std::sync::Arc::make_mut(&mut p.data), - v, - g, - )?; + self.0.cfg.try_update(p, v, g)?; } } Ok(None) diff --git a/src/optim/adam/adam.cu b/src/tensor_ops/adam/adam.cu similarity index 100% rename from src/optim/adam/adam.cu rename to src/tensor_ops/adam/adam.cu diff --git a/src/optim/adam/cpu_kernel.rs b/src/tensor_ops/adam/cpu_kernel.rs similarity index 91% rename from src/optim/adam/cpu_kernel.rs rename to src/tensor_ops/adam/cpu_kernel.rs index d685824c9..4abed5354 100644 --- a/src/optim/adam/cpu_kernel.rs +++ b/src/tensor_ops/adam/cpu_kernel.rs @@ -1,8 +1,8 @@ -use super::{AdamConfig, AdamKernel}; -use crate::{optim::WeightDecay, shapes::Dtype, tensor::Cpu}; +use super::{AdamConfig, AdamKernel, WeightDecay}; +use crate::{shapes::Dtype, tensor::Cpu}; impl AdamKernel for Cpu { - fn update( + fn adam_kernel( &self, t: i32, cfg: &AdamConfig, diff --git a/src/optim/adam/cuda_kernel.rs b/src/tensor_ops/adam/cuda_kernel.rs similarity index 97% rename from src/optim/adam/cuda_kernel.rs rename to src/tensor_ops/adam/cuda_kernel.rs index f6760da73..919333271 100644 --- a/src/optim/adam/cuda_kernel.rs +++ b/src/tensor_ops/adam/cuda_kernel.rs @@ -1,7 +1,7 @@ use crate::{ - optim::optimizer::*, shapes::*, tensor::{launch_cfg, Cuda}, + tensor_ops::optim::*, }; use cudarc::driver::{DeviceRepr, DeviceSlice, LaunchAsync}; @@ -58,7 +58,7 @@ impl super::AdamKernel for Cuda where Self: HasCudaKernel, { - fn update( + fn adam_kernel( &self, t: i32, cfg: &super::AdamConfig, diff --git a/src/tensor_ops/adam/mod.rs b/src/tensor_ops/adam/mod.rs new file mode 100644 index 000000000..0188f8aeb --- /dev/null +++ b/src/tensor_ops/adam/mod.rs @@ -0,0 +1,81 @@ +mod cpu_kernel; + +#[cfg(feature = "cuda")] +mod cuda_kernel; + +use crate::{ + shapes::{Dtype, Shape}, + tensor::{Storage, Tensor}, +}; + +use super::WeightDecay; + +/// Configuration of hyperparameters for [Adam]. +/// +/// Changing all default parameters: +/// ```rust +/// # use dfdx::{prelude::*, optim::*}; +/// AdamConfig { +/// lr: 1e-2, +/// betas: [0.1, 0.2], +/// eps: 1e-6, +/// weight_decay: Some(WeightDecay::L2(1e-1)), +/// }; +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct AdamConfig { + /// Learning rate. Defaults to `1e-3`. + pub lr: f64, + + /// Betas from Adam paper. Defaults to `[0.9, 0.999]`. + pub betas: [f64; 2], + + /// Epsilon for numerical stability. Defaults to `1e-8`. + pub eps: f64, + + /// Optional weight decay. Defaults to `None`. 
+ pub weight_decay: Option, +} + +impl Default for AdamConfig { + fn default() -> Self { + Self { + lr: 1e-3, + betas: [0.9, 0.999], + eps: 1e-8, + weight_decay: None, + } + } +} + +pub trait AdamKernel: Storage { + fn adam_kernel( + &self, + t: i32, + cfg: &AdamConfig, + param: &mut Self::Vec, + moment1: &mut Self::Vec, + moment2: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Self::Err>; +} + +impl AdamConfig { + pub fn try_update>( + &self, + t: i32, + param: &mut Tensor, + moment1: &mut D::Vec, + moment2: &mut D::Vec, + grad: &D::Vec, + ) -> Result<(), D::Err> { + param.device.adam_kernel( + t, + self, + std::sync::Arc::make_mut(&mut param.data), + moment1, + moment2, + grad, + ) + } +} diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index f961252ec..6607a1ab1 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -151,6 +151,7 @@ pub use utilities::*; mod abs; mod accurate_gelu; +mod adam; mod add; mod attention_reshape; pub(crate) mod axpy; @@ -181,6 +182,7 @@ mod mul; mod nans_to; mod negate; mod normalize; +pub(super) mod optim; mod permute_to; mod pow; mod prelu; @@ -188,8 +190,10 @@ mod realize_to; mod recip; mod relu; mod reshape_to; +mod rmsprop; mod roll; mod select_and_gather; +mod sgd; mod sigmoid; mod sin; mod slice; @@ -208,6 +212,7 @@ mod var_to; pub use abs::abs; pub use accurate_gelu::accurate_gelu; +pub use adam::AdamConfig; pub use add::{add, TryAdd}; pub use attention_reshape::TryAttentionReshape; pub use axpy::axpy; @@ -241,6 +246,7 @@ pub use mul::{mul, TryMul}; pub use nans_to::nans_to; pub use negate::negate; pub use normalize::normalize; +pub use optim::*; pub use permute_to::PermuteTo; pub use pow::{powf, powi}; pub use prelu::{leakyrelu, prelu, TryPReLU}; @@ -248,8 +254,10 @@ pub use realize_to::RealizeTo; pub use recip::recip; pub use relu::relu; pub use reshape_to::ReshapeTo; +pub use rmsprop::RMSpropConfig; pub use roll::Roll; pub use select_and_gather::{GatherTo, SelectTo}; +pub use sgd::SgdConfig; pub use sigmoid::sigmoid; pub use sin::sin; pub use slice::slice; diff --git a/src/tensor_ops/optim.rs b/src/tensor_ops/optim.rs new file mode 100644 index 000000000..1c905aa35 --- /dev/null +++ b/src/tensor_ops/optim.rs @@ -0,0 +1,58 @@ +/// L2 and decoupled regularization methods +#[derive(Debug, Clone, Copy)] +pub enum WeightDecay { + /// Weight decay applied to the gradients before any momentum updates. Equivalent to L2 regularization. + L2(f64), + + /// Weight decay applied after any momentum updates, without modifying the gradients. + /// See [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) + Decoupled(f64), +} + +/// Used to communicate the "WeightDecay" enum to cuda kernels +#[cfg(feature = "cuda")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub(super) enum WeightDecayType { + None, + L2, + Decoupled, +} + +#[cfg(feature = "cuda")] +pub(super) fn weight_decay_to_cuda(wd: Option) -> (WeightDecayType, f64) { + match wd { + None => (WeightDecayType::None, Default::default()), + Some(WeightDecay::L2(x)) => (WeightDecayType::L2, x), + Some(WeightDecay::Decoupled(x)) => (WeightDecayType::Decoupled, x), + } +} + +/// Momentum used for [super::Sgd] and others +#[derive(Debug, Clone, Copy)] +pub enum Momentum { + /// Momentum that is applied to the velocity of a parameter directly. + Classic(f64), + + /// Momentum that is applied to both velocity and gradients. See [super::Sgd] nesterov paper for more. 
+ Nesterov(f64), +} + +/// Used to communicate the "Momentum" enum to cuda kernels +#[cfg(feature = "cuda")] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(C)] +pub(super) enum MomentumType { + None, + Classic, + Nesterov, +} + +#[cfg(feature = "cuda")] +pub(super) fn momentum_to_cuda(wd: Option) -> (MomentumType, f64) { + match wd { + None => (MomentumType::None, Default::default()), + Some(Momentum::Classic(x)) => (MomentumType::Classic, x), + Some(Momentum::Nesterov(x)) => (MomentumType::Nesterov, x), + } +} diff --git a/src/optim/rmsprop/cpu_kernel.rs b/src/tensor_ops/rmsprop/cpu_kernel.rs similarity index 93% rename from src/optim/rmsprop/cpu_kernel.rs rename to src/tensor_ops/rmsprop/cpu_kernel.rs index d7679ae34..34c276d98 100644 --- a/src/optim/rmsprop/cpu_kernel.rs +++ b/src/tensor_ops/rmsprop/cpu_kernel.rs @@ -1,9 +1,9 @@ -use crate::{optim::WeightDecay, shapes::Dtype, tensor::cpu::Cpu}; +use crate::{shapes::Dtype, tensor::cpu::Cpu}; -use super::{RMSpropConfig, RMSpropKernel}; +use super::{RMSpropConfig, RMSpropKernel, WeightDecay}; impl RMSpropKernel for Cpu { - fn update( + fn rmsprop_kernel( &self, cfg: &RMSpropConfig, param: &mut Self::Vec, diff --git a/src/optim/rmsprop/cuda_kernel.rs b/src/tensor_ops/rmsprop/cuda_kernel.rs similarity index 97% rename from src/optim/rmsprop/cuda_kernel.rs rename to src/tensor_ops/rmsprop/cuda_kernel.rs index 6de231911..bdb534f54 100644 --- a/src/optim/rmsprop/cuda_kernel.rs +++ b/src/tensor_ops/rmsprop/cuda_kernel.rs @@ -1,8 +1,8 @@ use super::RMSpropConfig; use crate::{ - optim::optimizer::*, shapes::*, tensor::{launch_cfg, Cuda}, + tensor_ops::optim::*, }; use cudarc::driver::{DeviceRepr, DeviceSlice, LaunchAsync}; @@ -68,7 +68,7 @@ impl super::RMSpropKernel for Cuda where Self: HasCudaKernel, { - fn update( + fn rmsprop_kernel( &self, cfg: &RMSpropConfig, param: &mut Self::Vec, diff --git a/src/tensor_ops/rmsprop/mod.rs b/src/tensor_ops/rmsprop/mod.rs new file mode 100644 index 000000000..b4095a2a7 --- /dev/null +++ b/src/tensor_ops/rmsprop/mod.rs @@ -0,0 +1,79 @@ +mod cpu_kernel; + +#[cfg(feature = "cuda")] +mod cuda_kernel; + +use crate::{ + shapes::{Dtype, Shape}, + tensor::*, +}; + +use super::WeightDecay; + +/// Configuration of hyperparameters for [RMSprop]. +#[derive(Debug, Clone, Copy)] +pub struct RMSpropConfig { + /// Learning rate. Defaults to `1e-2`. + pub lr: f64, + + /// Value for exponential moving average. Defaults to `0.9`. + pub alpha: f64, + + /// Epsilon for stability. Defaults to `1e-8`. + pub eps: f64, + + /// Optional momentum. Defaults to `None`. + pub momentum: Option, + + /// Whether the avg should be centered by the grad's avg value. + /// Defaults to `false`. + pub centered: bool, + + /// Optional weight decay. Defaults to `None`. 
+ pub weight_decay: Option, +} + +impl Default for RMSpropConfig { + fn default() -> Self { + Self { + lr: 1e-2, + alpha: 0.9, + eps: 1e-8, + momentum: None, + centered: false, + weight_decay: None, + } + } +} + +pub trait RMSpropKernel: Storage { + fn rmsprop_kernel( + &self, + cfg: &RMSpropConfig, + param: &mut Self::Vec, + momentum: &mut Self::Vec, + square_avg: &mut Self::Vec, + grad_avg: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Self::Err>; +} + +impl RMSpropConfig { + pub fn try_update>( + &self, + param: &mut Tensor, + momentum: &mut D::Vec, + square_avg: &mut D::Vec, + grad_avg: &mut D::Vec, + grad: &D::Vec, + ) -> Result<(), D::Err> { + param.device.rmsprop_kernel( + self, + std::sync::Arc::make_mut(&mut param.data), + momentum, + square_avg, + grad_avg, + grad, + ) + } +} diff --git a/src/optim/rmsprop/rmsprop.cu b/src/tensor_ops/rmsprop/rmsprop.cu similarity index 100% rename from src/optim/rmsprop/rmsprop.cu rename to src/tensor_ops/rmsprop/rmsprop.cu diff --git a/src/optim/sgd/cpu_kernel.rs b/src/tensor_ops/sgd/cpu_kernel.rs similarity index 89% rename from src/optim/sgd/cpu_kernel.rs rename to src/tensor_ops/sgd/cpu_kernel.rs index f6766261b..27f35c7b6 100644 --- a/src/optim/sgd/cpu_kernel.rs +++ b/src/tensor_ops/sgd/cpu_kernel.rs @@ -1,13 +1,9 @@ -use crate::{ - optim::optimizer::{Momentum, WeightDecay}, - shapes::Dtype, - tensor::cpu::*, -}; +use crate::{shapes::Dtype, tensor::cpu::*}; -use super::{SgdConfig, SgdKernel}; +use super::{Momentum, SgdConfig, SgdKernel, WeightDecay}; impl SgdKernel for Cpu { - fn update( + fn sgd_kernel( &self, cfg: &SgdConfig, param: &mut Self::Vec, diff --git a/src/optim/sgd/cuda_kernel.rs b/src/tensor_ops/sgd/cuda_kernel.rs similarity index 97% rename from src/optim/sgd/cuda_kernel.rs rename to src/tensor_ops/sgd/cuda_kernel.rs index 30e979e85..0b839a8a9 100644 --- a/src/optim/sgd/cuda_kernel.rs +++ b/src/tensor_ops/sgd/cuda_kernel.rs @@ -1,8 +1,9 @@ use super::SgdConfig; + use crate::{ - optim::optimizer::*, shapes::*, tensor::{launch_cfg, Cuda}, + tensor_ops::optim::*, }; use cudarc::driver::{DeviceRepr, DeviceSlice, LaunchAsync}; @@ -58,7 +59,7 @@ impl super::SgdKernel for Cuda where Self: HasCudaKernel, { - fn update( + fn sgd_kernel( &self, cfg: &SgdConfig, param: &mut Self::Vec, diff --git a/src/tensor_ops/sgd/mod.rs b/src/tensor_ops/sgd/mod.rs new file mode 100644 index 000000000..be27f1393 --- /dev/null +++ b/src/tensor_ops/sgd/mod.rs @@ -0,0 +1,110 @@ +mod cpu_kernel; + +#[cfg(feature = "cuda")] +mod cuda_kernel; + +use crate::{ + shapes::{Dtype, Shape}, + tensor::{Storage, Tensor}, +}; + +use super::optim::{Momentum, WeightDecay}; + +/// Configuration of hyperparameters for [Sgd]. 
+/// +/// Using different learning rate: +/// ```rust +/// # use dfdx::{prelude::*, optim::*}; +/// SgdConfig { +/// lr: 1e-1, +/// momentum: None, +/// weight_decay: None, +/// }; +/// ``` +/// +/// Using classic momentum: +/// ```rust +/// # use dfdx::{prelude::*, optim::*}; +/// SgdConfig { +/// lr: 1e-2, +/// momentum: Some(Momentum::Classic(0.5)), +/// weight_decay: None, +/// }; +/// ``` +/// +/// Using nesterov momentum: +/// ```rust +/// # use dfdx::{prelude::*, optim::*}; +/// SgdConfig { +/// lr: 1e-3, +/// momentum: Some(Momentum::Nesterov(0.25)), +/// weight_decay: None, +/// }; +/// ``` +/// +/// Using L2 weight decay: +/// ```rust +/// # use dfdx::{prelude::*, optim::*}; +/// SgdConfig { +/// lr: 1e-3, +/// momentum: None, +/// weight_decay: Some(WeightDecay::L2(1e-2)), +/// }; +/// ``` +/// +/// Using decoupled weight decay: +/// ```rust +/// # use dfdx::{prelude::*, optim::*}; +/// SgdConfig { +/// lr: 1e-3, +/// momentum: None, +/// weight_decay: Some(WeightDecay::Decoupled(1e-2)), +/// }; +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct SgdConfig { + /// Learning rate. Defaults to `1e-2` + pub lr: f64, + + /// Optional momentum. Defaults to `None`. + pub momentum: Option, + + /// Optional weight decay. Defaults to `None`. + pub weight_decay: Option, +} + +impl Default for SgdConfig { + fn default() -> Self { + Self { + lr: 1e-2, + momentum: None, + weight_decay: None, + } + } +} + +pub trait SgdKernel: Storage { + fn sgd_kernel( + &self, + cfg: &SgdConfig, + param: &mut Self::Vec, + velocity: &mut Self::Vec, + grad: &Self::Vec, + ) -> Result<(), Self::Err>; +} + +impl SgdConfig { + pub fn try_update>( + &self, + param: &mut Tensor, + velocity: &mut D::Vec, + grad: &D::Vec, + ) -> Result<(), D::Err> { + param.device.sgd_kernel( + self, + std::sync::Arc::make_mut(&mut param.data), + velocity, + grad, + ) + } +} diff --git a/src/optim/sgd/sgd.cu b/src/tensor_ops/sgd/sgd.cu similarity index 100% rename from src/optim/sgd/sgd.cu rename to src/tensor_ops/sgd/sgd.cu diff --git a/src/tensor_ops/utilities/device.rs b/src/tensor_ops/utilities/device.rs index 31ee0e7c6..d226d2e29 100644 --- a/src/tensor_ops/utilities/device.rs +++ b/src/tensor_ops/utilities/device.rs @@ -19,9 +19,9 @@ pub trait Device: + super::super::concat_along::ConcatAlongKernel // optimizers - + crate::optim::AdamKernel - + crate::optim::SgdKernel - + crate::optim::RMSpropKernel + + super::super::adam::AdamKernel + + super::super::sgd::SgdKernel + + super::super::rmsprop::RMSpropKernel // allocation + crate::tensor::ZerosTensor
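A minimal usage sketch (not part of the diff above) illustrating that, after this move, the config types now defined under `tensor_ops` stay reachable through `dfdx::optim` via the re-exports added in `src/optim/mod.rs`. The construction style mirrors the existing rustdoc examples; the hyperparameter values are arbitrary and the `main` wrapper is only for illustration:

```rust
// Sketch only: the config structs live in `tensor_ops` after this change,
// but the `pub use crate::tensor_ops::{AdamConfig, Momentum, RMSpropConfig,
// SgdConfig, WeightDecay};` re-export keeps the old `optim` paths working.
use dfdx::{optim::*, prelude::*};

fn main() {
    // Same shape as the SgdConfig doc examples above.
    let _sgd_cfg = SgdConfig {
        lr: 1e-2,
        momentum: Some(Momentum::Classic(0.5)),
        weight_decay: Some(WeightDecay::L2(1e-2)),
    };

    // Same shape as the AdamConfig doc example above.
    let _adam_cfg = AdamConfig {
        lr: 1e-3,
        betas: [0.9, 0.999],
        eps: 1e-8,
        weight_decay: Some(WeightDecay::Decoupled(1e-2)),
    };
}
```

The optimizer wrappers in `src/optim` (`Adam`, `Sgd`, `RMSprop`) consume these configs exactly as before; only the per-device kernel dispatch now goes through the new `AdamConfig::try_update` / `SgdConfig::try_update` / `RMSpropConfig::try_update` methods shown in the diff.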