Reworking crate level documentation (#644)
* #600 Reworking crate level documentation

* Fixing failing doctests
coreylowman authored Mar 30, 2023
1 parent f0a4f8c commit f53bf42
Showing 6 changed files with 255 additions and 150 deletions.
227 changes: 149 additions & 78 deletions src/lib.rs
@@ -1,104 +1,175 @@
//! Ergonomics & safety focused deep learning in Rust. Main features include:
//! 1. Tensor library with shapes up to 6d!
//! 2. Shapes with both compile and runtime sized dimensions. (e.g. `Tensor<(usize, Const<10>)>` and `Tensor<Rank2<5, 10>>`)
//! 3. A large library of tensor operations (including `matmul`, `conv2d`, and much more).
//!    a. All tensor operations are shape and type checked at compile time!!
//! 4. Ergonomic neural network building blocks (like `Linear`, `Conv2D`, and `Transformer`).
//! 5. Standard deep learning optimizers such as `Sgd`, `Adam`, `AdamW`, `RMSprop`, and more.
//! 6. Reverse mode auto differentiation implementation.
//! 7. Serialization to/from `.npy` and `.npz` for transferring models to/from python.
//!
//! # dfdx
//!
//! dfdx is a cuda accelerated tensor and neural network library, written
//! entirely in rust!
//!
//! Additionally, it can track compile time shapes across tensor operations,
//! ensuring that all your neural networks are checked **at compile time**.
//!
//! The following sections provide some high level core concepts & examples, and
//! there is more detailed documentation in each of dfdx's submodules.
//!
//! See [feature_flags] for details on feature flags.
//!
//! # Shapes & Tensors
//!
//! *See [shapes] and [tensor] for more information.*
//!
//! At its core a [`tensor::Tensor`] is just an nd-array. Just like
//! rust arrays, there are two parts:
//! 1. Shape
//! 2. Dtype
//!
//! dfdx represents shapes as **tuples** of dimensions ([`shapes::Dim`]),
//! where a dimension can either be known at:
//! 1. Compile time [`shapes::Const<M>`]
//! 2. Run time [`usize`]
//!
//! You can freely mix and match these dimensions together. Here are some
//! example shapes:
//! - `()` - unit shape
//! - `(usize,)` - 1d shape with a runtime known dimension
//! - `(usize, Const<5>)` - 2d shape with both types of dimensions
//! - `(Const<3>, usize, Const<5>)` - 3d shape!
//!
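//! As a quick illustration, `Const<M>` is a unit struct, so a shape tuple can
//! be written out directly as a value. A minimal sketch (variable name chosen
//! just for illustration):
//! ```rust
//! # use dfdx::prelude::*;
//! // a 3d shape mixing a runtime dimension with compile time ones
//! let shape: (Const<3>, usize, Const<5>) = (Const, 7, Const);
//! # let _ = shape;
//! ```
//!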
//! Here are some comparisons between representing nd arrays in rust vs dfdx:
//!
//! | rust array | dfdx `Tensor` |
//! | --- | --- |
//! | `f32` | `Tensor<(), f32, ...>` |
//! | `[u32; 5]` | `Tensor<Rank1<5>, u32, ...>` |
//! | `[[u8; 3]; 2]` | `Tensor<Rank2<2, 3>, u8, ...>` |
//! | `Vec<[bool; 5]>` | `Tensor<(usize, Const<5>), bool, ...>` |
//!
//! The `Rank1`, `Rank2` shapes used above are actually type aliases for
//! when **all dimensions are compile time**:
//! - [`shapes::Rank0`] is just `()`.
//! - [`shapes::Rank1<M>`] is `(Const<M>, )`
//! - [`shapes::Rank2<M, N>`] is `(Const<M>, Const<N>)`
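//!
//! Since these are plain type aliases, the compiler treats both spellings as
//! the same type. A tiny sketch (the helper function is just for illustration):
//! ```rust
//! # use dfdx::prelude::*;
//! // Rank2<2, 3> and (Const<2>, Const<3>) are interchangeable
//! fn same_shape(s: Rank2<2, 3>) -> (Const<2>, Const<3>) {
//!     s
//! }
//! # let _ = same_shape((Const, Const));
//! ```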
//!
//! # Allocating tensors with Devices
//!
//! *See [tensor] for more information.*
//!
//! Devices are used to allocate tensors (and neural networks!). They are akin
//! to [std::alloc::GlobalAlloc] in rust - they just allocate memory.
//! They are also used to execute tensor ops, which we will get to later on.
//!
//! There are two options for this currently, with more planned to be added in the future:
//!
//! 1. [tensor::Cpu] - for tensors stored on the heap
//! 2. [tensor::Cuda] - for tensors stored in GPU memory
//!
//! Both devices implement [Default], and you can also create them with a
//! specific seed and device ordinal.
//!
//! Here's how you might use a device:
//!
//! ```rust
//! # use dfdx::prelude::*;
//! let dev: Cpu = Default::default();
//! let x = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
//! let y: Tensor<Rank2<2, 3>, f32, Cpu> = dev.ones();
//! // Runtime shape
//! let z: Tensor<(usize, Const<3>), f32, _> = dev.ones_like(&(10, Const));
//! let t: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();
//! ```
//!
//! # Tensor Operations (tip of the iceberg)
//!
//! *See [tensor_ops] for more information.*
//!
//! Once you've instantiated tensors with a device, you can start doing operations on them!
//! There are **many many** operations; here are a few core ones and how they relate
//! to things like numpy/pytorch:
//!
//! | Operation | dfdx | numpy | pytorch |
//! | --- | --- | --- | --- |
//! | Unary Operations | `a.sqrt()` | `a.sqrt()` | `a.sqrt()` |
//! | Binary Operations | `a + b` | `a + b` | `a + b` |
//! | gemm/gemv | [tensor_ops::matmul] | `a @ b` | `a @ b` |
//! | 2d Convolution | [tensor_ops::TryConv2D] | - | `torch.conv2d` |
//! | 2d Transposed Convolution | [tensor_ops::TryConvTrans2D] | - | `torch.conv_transpose2d` |
//! | Slicing | [tensor_ops::slice] | `a[...]` | `a[...]` |
//! | Select | [tensor_ops::SelectTo] | `a[...]` | `torch.select` |
//! | Gather | [tensor_ops::GatherTo] | `np.take` | `torch.gather` |
//! | Broadcasting | [tensor_ops::BroadcastTo] | implicit/`np.broadcast` | implicit/`torch.broadcast_to` |
//! | Permute | [tensor_ops::PermuteTo] | `np.transpose(...)` | `torch.permute` |
//! | Where | [tensor_ops::ChooseFrom] | `np.where` | `torch.where` |
//! | Reshape | [tensor_ops::ReshapeTo] | `np.reshape(shape)` | `a.reshape(shape)` |
//! | View | [tensor_ops::ReshapeTo] | `np.view(...)` | `a.view(...)` |
//! | Roll | [tensor_ops::Roll] | `np.rollaxis(...)` | `a.roll(...)` |
//! | Stack | [tensor_ops::TryStack] | `np.stack` | `torch.stack` |
//! | Concat | [tensor_ops::TryConcat] | `np.concatenate` | `torch.concat` |
//!
//! and **much much more!**
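//!
//! As a small taste, here's a minimal sketch chaining a few of the ops from the
//! table above (shapes and variable names chosen just for illustration):
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! let a: Tensor<Rank2<2, 3>, f32, _> = dev.sample_normal();
//! let b: Tensor<Rank2<3, 4>, f32, _> = dev.sample_normal();
//! // gemm: the (2, 4) output shape is inferred at compile time
//! let c = a.matmul(b);
//! // a unary op, then a binary op between two same-shaped tensors
//! let d = c.clone().square() + c;
//! # let _ = d;
//! ```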
//!
//! # Neural networks
//!
//! *See [nn] for more information.*
//!
//! Neural networks are composed of building blocks that you can chain together. In
//! dfdx, sequential neural networks are represented by **tuples**! For example,
//! the following two networks are identical:
//!
//! | dfdx | pytorch |
//! | --- | --- |
//! | `(Linear<3, 5>, ReLU, Linear<5, 10>)` | `nn.Sequential(nn.Linear(3, 5), nn.ReLU(), nn.Linear(5, 10))` |
//! | `((Conv2D<3, 2, 1>, Tanh), Conv2D<3, 2, 1>)` | `nn.Sequential(nn.Sequential(nn.Conv2d(3, 2, 1), nn.Tanh()), nn.Conv2d(3, 2, 1))` |
//!
//! To build a neural network, you of course need a device:
//!
//! ```rust
//! # use dfdx::prelude::*;
//! let dev: Cpu = Default::default();
//! type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! let model = dev.build_module::<Model, f32>();
//! ```
//!
//! Note two things:
//! 1. We are using [nn::DeviceBuildExt] to instantiate the model
//! 2. We **need** to pass a dtype (in this case f32) to create the model.
//!
//! You can then pass tensors into the model with [nn::Module::forward()]:
//!
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! # let model = dev.build_module::<Model, f32>();
//! // tensor with runtime batch dimension of 10
//! let x: Tensor<(usize, Const<3>), f32, _> = dev.sample_normal_like(&(10, Const));
//! let y = model.forward(x);
//! ```
//!
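//! There is also a mutable forward, [nn::ModuleMut::forward_mut()], which takes
//! `&mut self` and is what training code typically uses (the optimizer example
//! below relies on it). A minimal sketch:
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! # let mut model = dev.build_module::<Model, f32>();
//! let x: Tensor<Rank1<3>, f32, _> = dev.sample_normal();
//! // forward_mut takes &mut self, so stateful layers can update themselves
//! let y = model.forward_mut(x);
//! # let _ = y;
//! ```
//!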
//! # Optimizers and Gradients
//!
//! *See [optim] for more information.*
//!
//! dfdx supports a number of the standard optimizers:
//!
//! | Optimizer | dfdx | pytorch |
//! | --- | --- | --- |
//! | SGD | [optim::Sgd] | `torch.optim.SGD` |
//! | Adam | [optim::Adam] | `torch.optim.Adam` |
//! | AdamW | [optim::Adam] with [optim::WeightDecay::Decoupled] | `torch.optim.AdamW` |
//! | RMSprop | [optim::RMSprop] | `torch.optim.RMSprop` |
//!
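//! For example, AdamW is just [optim::Adam] with [optim::WeightDecay::Decoupled].
//! A minimal sketch (the exact `AdamConfig` fields used here are an assumption):
//! ```rust
//! # use dfdx::{prelude::*, optim::*};
//! # let dev: Cpu = Default::default();
//! # type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! # let mut model = dev.build_module::<Model, f32>();
//! // AdamW = Adam + decoupled weight decay
//! let mut opt = Adam::new(&model, AdamConfig {
//!     lr: 1e-3,
//!     weight_decay: Some(WeightDecay::Decoupled(1e-2)),
//!     ..Default::default()
//! });
//! # let grads = model.alloc_grads();
//! # opt.update(&mut model, &grads);
//! ```
//!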
//! You can use optimizers to optimize neural networks (or even tensors!). Here's
//! a simple example of how to do this with [nn::ZeroGrads]:
//! ```rust
//! # use dfdx::{prelude::*, optim::*};
//! # let dev: Cpu = Default::default();
//! type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
//! let mut model = dev.build_module::<Model, f32>();
//! // 1. allocate gradients for the model
//! let mut grads = model.alloc_grads();
//! // 2. create our optimizer
//! let mut opt = Sgd::new(&model, Default::default());
//! // 3. trace gradients through forward pass
//! let x: Tensor<Rank2<10, 3>, f32, _> = dev.sample_normal();
//! let y = model.forward_mut(x.traced(grads));
//! // 4. compute loss & run backpropagation
//! let loss = y.square().mean();
//! grads = loss.backward();
//! // 5. apply gradients
//! opt.update(&mut model, &grads);
//! ```
#![cfg_attr(all(feature = "no-std", not(feature = "std")), no_std)]
35 changes: 23 additions & 12 deletions src/tensor_ops/broadcast_to.rs
@@ -1,19 +1,30 @@
use crate::{shapes::*, tensor::*};

/// Broadcast self into a new shape.
///
/// **pytorch equivalent** `torch.broadcast_to`.
///
/// Use the shape generic or the output type to dictate what shape you want:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<3, 7>, f32, _> = dev.zeros();
/// // broadcast axis 1
/// let _: Tensor<Rank3<3, 5, 7>, _, _> = a.clone().broadcast();
/// // broadcast axis 0 and axis 2
/// let _ = a.clone().broadcast::<Rank4<1, 3, 5, 7>, _>();
/// ```
///
/// Use the axes generic to disambiguate:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank1<1>, f32, _> = dev.zeros();
/// // It's ambiguous what axes to broadcast here - explicitly say axes 0 and 2
/// let _: Tensor<Rank3<1, 1, 1>, _, _> = a.clone().broadcast::<_, Axes2<0, 2>>();
/// ```
pub trait BroadcastTo: HasErr + HasShape {
    /// Broadcast into shape `Dst` along axes `Ax`.
    fn broadcast<Dst: ConstShape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: BroadcastShapeTo<Dst, Ax>,
10 changes: 10 additions & 0 deletions src/tensor_ops/choose/mod.rs
@@ -28,6 +28,16 @@ pub trait ChooseKernel<E: Dtype>: DeviceStorage {
}

/// Choose values from two tensors using a boolean mask. Equivalent to `torch.where` from pytorch.
///
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let cond: Tensor<Rank1<3>, bool, _> = dev.tensor([true, false, true]);
/// let a: Tensor<Rank1<3>, f32, _> = dev.tensor([1.0, 2.0, 3.0]);
/// let b: Tensor<Rank1<3>, f32, _> = dev.tensor([-1.0, -2.0, -3.0]);
/// let c = cond.choose(a, b);
/// assert_eq!(c.array(), [1.0, -2.0, 3.0]);
/// ```
pub trait ChooseFrom<Lhs, Rhs>: HasErr {
    type Output;

31 changes: 22 additions & 9 deletions src/tensor_ops/permute_to.rs
@@ -1,15 +1,28 @@
use crate::{shapes::*, tensor::*};

/// Changes order of dimensions/axes in a tensor.
///
/// **pytorch equivalent**: `torch.permute`.
///
/// Option 1: Specifying the shape generic:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
/// let b: Tensor<Rank2<3, 2>, f32, _> = a.permute::<Rank2<3, 2>, _>();
/// assert_eq!(b.array(), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
/// ```
///
/// Option 2: Specifying the axes generic:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let a: Tensor<Rank2<2, 3>, f32, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
/// let b: Tensor<Rank2<3, 2>, f32, _> = a.permute::<_, Axes2<1, 0>>();
/// assert_eq!(b.array(), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]);
/// ```
pub trait PermuteTo: HasErr + HasShape {
    /// Permutes the tensor.
    fn permute<Dst: Shape, Ax: Axes>(self) -> Self::WithShape<Dst>
    where
        Self::Shape: PermuteShapeTo<Dst, Ax>,