refactor(linear): pass device explicitly
Konstantin Matsiushonak committed Feb 4, 2023
1 parent 72f2255 commit a94946a
Showing 20 changed files with 86 additions and 73 deletions.
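In short, this commit drops the `= Cpu` default on `Linear`'s device type parameter (see `src/nn/linear.rs` below), so every type alias and call site now names its device explicitly. Below is a minimal sketch of the resulting call-site change, assuming only the `dfdx` prelude and the `build_on_device` API shown in the hunks that follow:

```rust
use dfdx::prelude::*;

// Before this commit, the device parameter defaulted to Cpu and could be omitted:
//     type Model = (Linear<10, 5>, Tanh);
// After this commit, the device is spelled out at the type level:
type Model = (Linear<10, 5, Cpu>, Tanh);

fn main() {
    let dev: Cpu = Default::default();
    let model = Model::build_on_device(&dev);

    // Shape checking is unchanged: a Rank1<10> input produces a Rank1<5> output.
    let x: Tensor<Rank1<10>, f32, _> = dev.zeros();
    let _y: Tensor<Rank1<5>, f32, _> = model.forward(x);
}
```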
10 changes: 6 additions & 4 deletions README.md
@@ -58,10 +58,12 @@ Check [examples/](examples/) for more details.
1. 👌 Simple Neural Networks API, completely shape checked at compile time.

```rust
+use dfdx::prelude::*;
+
type Mlp = (
-(Linear<10, 32>, ReLU),
-(Linear<32, 32>, ReLU),
-(Linear<32, 2>, Tanh),
+(Linear<10, 32, Cpu>, ReLU),
+(Linear<32, 32, Cpu>, ReLU),
+(Linear<32, 2, Cpu>, Tanh),
);

fn main() {
@@ -131,7 +133,7 @@ let model = Model::build_on_device(&dev);
```

```rust
-type Model = (Linear<10, 5>, Tanh)
+type Model = (Linear<10, 5, Cpu>, Tanh)
let model = Model::build_on_device(&dev);
```

4 changes: 2 additions & 2 deletions examples/03-nn.rs
@@ -11,7 +11,7 @@ fn main() {

// nn exposes many different neural network types, like the Linear layer!
// you can use Build::build to construct an initialized model
-let mut m = Linear::<4, 2>::build_on_device(&dev);
+let mut m = Linear::<4, 2, Cpu>::build_on_device(&dev);

// Build::reset_params also allows you to re-randomize the weights
m.reset_params();
@@ -34,7 +34,7 @@ fn main() {
let _: Tensor<(usize, Const<2>), f32, _> = m.forward(dev.zeros_like(&(batch_size, Const)));

// you can also combine multiple modules with tuples
-type Mlp = (Linear<4, 2>, ReLU, Linear<2, 1>);
+type Mlp = (Linear<4, 2, Cpu>, ReLU, Linear<2, 1, Cpu>);
let mlp = Mlp::build_on_device(&dev);

// and of course forward passes the input through each module sequentially:
6 changes: 3 additions & 3 deletions examples/05-optim.rs
@@ -11,9 +11,9 @@ use dfdx::{

// first let's declare our neural network to optimze
type Mlp = (
-(Linear<5, 32>, ReLU),
-(Linear<32, 32>, ReLU),
-(Linear<32, 2>, Tanh),
+(Linear<5, 32, Cpu>, ReLU),
+(Linear<32, 32, Cpu>, ReLU),
+(Linear<32, 2, Cpu>, Tanh),
);

fn main() {
8 changes: 4 additions & 4 deletions examples/06-mnist.rs
@@ -74,10 +74,10 @@ impl MnistDataset {

// our network structure
type Mlp = (
-(Linear<784, 512>, ReLU),
-(Linear<512, 128>, ReLU),
-(Linear<128, 32>, ReLU),
-Linear<32, 10>,
+(Linear<784, 512, Cpu>, ReLU),
+(Linear<512, 128, Cpu>, ReLU),
+(Linear<128, 32, Cpu>, ReLU),
+Linear<32, 10, Cpu>,
);

// training batch size
4 changes: 2 additions & 2 deletions examples/07-custom-module.rs
@@ -12,8 +12,8 @@ use dfdx::{
/// This case is trivial and should be done with a tuple of linears and relus,
/// but it demonstrates how to build models with custom behavior
struct Mlp<const IN: usize, const INNER: usize, const OUT: usize> {
-l1: nn::Linear<IN, INNER>,
-l2: nn::Linear<INNER, OUT>,
+l1: nn::Linear<IN, INNER, Cpu>,
+l2: nn::Linear<INNER, OUT, Cpu>,
relu: nn::ReLU,
}

2 changes: 1 addition & 1 deletion examples/11-multi-headed.rs
@@ -13,7 +13,7 @@ fn main() {
// SplitInto accepts a tuple of modules. Each one of the items in the
// tuple must accept the same type of input.
// Note that here, both of the linears have the same size input (1)
-type Model = SplitInto<(Linear<1, 3>, Linear<1, 5>)>;
+type Model = SplitInto<(Linear<1, 3, Cpu>, Linear<1, 5, Cpu>)>;
let m = Model::build_on_device(&dev);

// when we forward data through, we get a tuple back!
2 changes: 1 addition & 1 deletion examples/nightly-conv-net.rs
@@ -11,7 +11,7 @@ fn main() {
(Conv2D<4, 8, 3>, ReLU),
(Conv2D<8, 16, 3>, ReLU),
Flatten2D,
-Linear<7744, 10>,
+Linear<7744, 10, Cpu>,
);

let dev: Cpu = Default::default();
2 changes: 1 addition & 1 deletion examples/nightly-resnet18.rs
@@ -37,7 +37,7 @@ fn main() {
(Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
(Downsample<128, 256>, ReLU, BasicBlock<256>, ReLU),
(Downsample<256, 512>, ReLU, BasicBlock<512>, ReLU),
-(AvgPoolGlobal, Linear<512, NUM_CLASSES>),
+(AvgPoolGlobal, Linear<512, NUM_CLASSES, Cpu>),
);

let dev: Cpu = Default::default();
6 changes: 3 additions & 3 deletions examples/rl-dqn.rs
@@ -13,9 +13,9 @@ const ACTION: usize = 2;

// our simple 2 layer feedforward network with ReLU activations
type QNetwork = (
-(Linear<STATE, 32>, ReLU),
-(Linear<32, 32>, ReLU),
-Linear<32, ACTION>,
+(Linear<STATE, 32, Cpu>, ReLU),
+(Linear<32, 32, Cpu>, ReLU),
+Linear<32, ACTION, Cpu>,
);

fn main() {
6 changes: 3 additions & 3 deletions examples/rl-ppo.rs
@@ -11,9 +11,9 @@ const STATE: usize = 4;
const ACTION: usize = 2;

type PolicyNetwork = (
-(Linear<STATE, 32>, ReLU),
-(Linear<32, 32>, ReLU),
-Linear<32, ACTION>,
+(Linear<STATE, 32, Cpu>, ReLU),
+(Linear<32, 32, Cpu>, ReLU),
+Linear<32, ACTION, Cpu>,
);

fn main() {
14 changes: 7 additions & 7 deletions src/lib.rs
@@ -24,25 +24,25 @@
//! ```rust
//! # use dfdx::prelude::*;
//! type Mlp = (
-//! Linear<5, 3>,
+//! Linear<5, 3, Cpu>,
//! ReLU,
-//! Linear<3, 2>,
+//! Linear<3, 2, Cpu>,
//! );
//! ```
//!
//! 3. Instantiate models with [crate::nn::BuildOnDevice]
//! ```rust
//! # use dfdx::prelude::*;
//! let dev: Cpu = Default::default();
-//! type Model = (Linear<5, 2>, ReLU);
+//! type Model = (Linear<5, 2, Cpu>, ReLU);
//! let mlp = Model::build_on_device(&dev);
//! ```
//!
//! 4. Pass data through networks with [crate::nn::Module]
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
-//! # let mlp: Linear<5, 2> = BuildModule::build(&dev);
+//! # let mlp: Linear<5, 2, Cpu> = BuildModule::build(&dev);
//! let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
//! let y = mlp.forward(x); // compiler infers that `y` must be `Tensor<Rank1<2>>`
//! ```
@@ -51,7 +51,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
-//! # let model: Linear<10, 5> = BuildModule::build(&dev);
+//! # let model: Linear<10, 5, Cpu> = BuildModule::build(&dev);
//! # let y_true: Tensor<Rank1<5>, f32, _> = dev.sample_normal().softmax();
//! // tensors default to not having a tape
//! let x: Tensor<Rank1<10>, f32, Cpu, NoneTape> = dev.zeros();
@@ -68,7 +68,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients};
//! # let dev: Cpu = Default::default();
-//! # let model: Linear<10, 5> = BuildModule::build(&dev);
+//! # let model: Linear<10, 5, Cpu> = BuildModule::build(&dev);
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! // compute cross entropy loss
@@ -81,7 +81,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients, optim::*};
//! # let dev: Cpu = Default::default();
-//! # let mut model: Linear<10, 5> = BuildModule::build(&dev);
+//! # let mut model: Linear<10, 5, Cpu> = BuildModule::build(&dev);
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! # let loss = cross_entropy_with_logits_loss(y, y_true);
27 changes: 18 additions & 9 deletions src/nn/add_into.rs
@@ -14,7 +14,7 @@ use super::{BuildModule, Module, ModuleMut, ResetParams, ToDevice};
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
-/// type Model = AddInto<(Linear<2, 5>, Linear<3, 5>)>;
+/// type Model = AddInto<(Linear<2, 5, Cpu>, Linear<3, 5, Cpu>)>;
/// let model = Model::build_on_device(&dev);
/// let a: Tensor<Rank1<2>, f32, _> = dev.zeros();
/// let b: Tensor<Rank1<3>, f32, _> = dev.zeros();
@@ -105,7 +105,7 @@ mod tests {
unique_id::HasUniqueId,
};

-type TestAddIntoCpu = AddInto<(Linear<2, 5>, Linear<3, 5>)>;
+type TestAddIntoCpu = AddInto<(Linear<2, 5, Cpu>, Linear<3, 5, Cpu>)>;
#[allow(unused)]
type TestAddInto<D> = OnDevice<TestAddIntoCpu, D>;

@@ -143,7 +143,12 @@ mod tests {
#[test]
fn test_add_into_4() {
let dev: TestDevice = Default::default();
-type Model = AddInto<(Linear<2, 5>, Linear<3, 5>, Linear<4, 5>, Linear<5, 5>)>;
+type Model = AddInto<(
+Linear<2, 5, Cpu>,
+Linear<3, 5, Cpu>,
+Linear<4, 5, Cpu>,
+Linear<5, 5, Cpu>,
+)>;
let m = Model::build_on_device(&dev);
let _: Tensor<Rank1<5>, _, _, OwnedTape<_>> = m.forward((
dev.zeros::<Rank1<2>>().traced(),
@@ -163,11 +168,11 @@
fn test_add_into_5() {
let dev: TestDevice = Default::default();
type Model = AddInto<(
-Linear<2, 5>,
-Linear<3, 5>,
-Linear<4, 5>,
-Linear<5, 5>,
-Linear<6, 5>,
+Linear<2, 5, Cpu>,
+Linear<3, 5, Cpu>,
+Linear<4, 5, Cpu>,
+Linear<5, 5, Cpu>,
+Linear<6, 5, Cpu>,
)>;
let m = Model::build_on_device(&dev);
let _: Tensor<Rank1<5>, _, _, OwnedTape<_>> = m.forward((
@@ -249,7 +254,11 @@ mod tests {
fn longer_network() {
let dev: TestDevice = Default::default();
// check if it works in a longer neural net
-type Model = (AddInto<(Linear<5, 3>, Linear<5, 3>)>, ReLU, Linear<3, 1>);
+type Model = (
+AddInto<(Linear<5, 3, Cpu>, Linear<5, 3, Cpu>)>,
+ReLU,
+Linear<3, 1, Cpu>,
+);
let mut model = Model::build_on_device(&dev);
let _: Tensor<Rank1<1>, _, _, OwnedTape<_>> = model.forward((
dev.zeros::<Rank1<5>>().traced(),
2 changes: 1 addition & 1 deletion src/nn/impl_module_for_tuples.rs
@@ -251,7 +251,7 @@ mod tests {
#[test]
fn test_tuple_missing_gradients() {
let dev: TestDevice = Default::default();
-type Model = (Linear<5, 3>, Linear<5, 3>, Linear<5, 3>);
+type Model = (Linear<5, 3, Cpu>, Linear<5, 3, Cpu>, Linear<5, 3, Cpu>);
let mut model = Model::build_on_device(&dev);
let mut g: SimpleUpdater = Default::default();

18 changes: 9 additions & 9 deletions src/nn/linear.rs
@@ -17,15 +17,15 @@ use super::module::{BuildModule, Module, ModuleMut, ResetParams, ToDevice};
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
-/// type Model = Linear<5, 2>;
+/// type Model = Linear<5, 2, Cpu>;
/// let model = Model::build_on_device(&dev);
/// // single item forward
/// let _: Tensor<Rank1<2>, f32, _> = model.forward(dev.zeros::<Rank1<5>>());
/// // batched forward
/// let _: Tensor<Rank2<10, 2>, f32, _> = model.forward(dev.zeros::<Rank2<10, 5>>());
/// ```
#[derive(Debug, Clone)]
-pub struct Linear<const I: usize, const O: usize, D: Device<f32> = Cpu> {
+pub struct Linear<const I: usize, const O: usize, D: Device<f32>> {
/// Transposed weight matrix, shape (I, O)
pub weight: Tensor<Rank2<O, I>, f32, D>,

@@ -151,19 +151,19 @@ mod tests {
use super::super::module::OnDevice;

let cuda: Cuda = Default::default();
-let _: Linear<1, 1, _> = BuildModule::build(&cuda);
-let _: OnDevice<Linear<1, 1>, Cuda> = BuildModule::build(&cuda);
-let _: OnDevice<(Linear<1, 2>, Linear<2, 1>), Cuda> = BuildModule::build(&cuda);
+let _: Linear<1, 1, Cuda> = BuildModule::build(&cuda);
+let _: OnDevice<Linear<1, 1, Cuda>, Cuda> = BuildModule::build(&cuda);
+let _: OnDevice<(Linear<1, 2, Cuda>, Linear<2, 1, Cuda>), Cuda> = BuildModule::build(&cuda);

-let _: Linear<1, 1, Cuda> = Linear::<1, 1>::build_on_device(&cuda);
-let _: Linear<1, 1, _> = Linear::<1, 1>::build_on_device(&cuda);
-let _ = Linear::<1, 1>::build_on_device(&cuda);
+let _: Linear<1, 1, Cuda> = Linear::<1, 1, Cuda>::build_on_device(&cuda);
+let _: Linear<1, 1, Cuda> = Linear::<1, 1, Cuda>::build_on_device(&cuda);
+let _ = Linear::<1, 1, Cuda>::build_on_device(&cuda);
}

#[test]
fn test_linear_initialize() {
let dev: TestDevice = Default::default();
-let m = Linear::<2000, 1>::build_on_device(&dev);
+let m = Linear::<2000, 1, Cpu>::build_on_device(&dev);
let bound = 1.0 / 2000.0f32.sqrt();
for v in m.weight.as_vec() {
assert!(-bound <= v && v <= bound && v != 0.0);
24 changes: 12 additions & 12 deletions src/nn/mod.rs
@@ -30,7 +30,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
-//! type Model = Linear<5, 2>;
+//! type Model = Linear<5, 2, Cpu>;
//! let model = Model::build_on_device(&dev);
//! ```
//!
@@ -56,7 +56,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
-//! let mut model: Linear<5, 2> = BuildModule::build(&dev);
+//! let mut model: Linear<5, 2, Cpu> = BuildModule::build(&dev);
//! model.reset_params();
//! ```
//!
@@ -67,22 +67,22 @@
//! Here's a single layer MLP:
//! ```rust
//! # use dfdx::prelude::*;
-//! type Mlp = (Linear<5, 3>, ReLU, Linear<3, 2>);
+//! type Mlp = (Linear<5, 3, Cpu>, ReLU, Linear<3, 2, Cpu>);
//! ```
//!
//! Here's a more complex feedforward network that takes vectors of 5 elements and maps them to 2 elements.
//! ```rust
//! # use dfdx::prelude::*;
//! type ComplexNetwork = (
-//! DropoutOneIn<2>, // 1. dropout 50% of input
-//! Linear<5, 3>, // 2. pass into a linear layer
-//! LayerNorm1D<3>, // 3. normalize elements
-//! ReLU, // 4. activate with relu
-//! Residual<( // 5. residual connection that adds input to the result of it's sub layers
-//! Linear<3, 3>,// 5.a. Apply linear layer
-//! ReLU, // 5.b. Apply Relu
-//! )>, // 5.c. the input to the residual is added back in after the sub layers
-//! Linear<3, 2>, // 6. Apply another linear layer
+//! DropoutOneIn<2>, // 1. dropout 50% of input
+//! Linear<5, 3, Cpu>, // 2. pass into a linear layer
+//! LayerNorm1D<3>, // 3. normalize elements
+//! ReLU, // 4. activate with relu
+//! Residual<( // 5. residual connection that adds input to the result of it's sub layers
+//! Linear<3, 3, Cpu>, // 5.a. Apply linear layer
+//! ReLU, // 5.b. Apply Relu
+//! )>, // 5.c. the input to the residual is added back in after the sub layers
+//! Linear<3, 2, Cpu>, // 6. Apply another linear layer
//! );
//! ```
//!

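As a closing note, the updated test in `src/nn/linear.rs` above shows the main ergonomic consequence of this commit: with no default device on `Linear`, non-CPU models spell out their device in the type. A condensed sketch of those test lines, assuming dfdx is built with the `cuda` feature and a CUDA device is available (the function name here is illustrative, not from the commit):

```rust
use dfdx::prelude::*;

// Sketch only: condensed from the updated test in src/nn/linear.rs.
fn build_linear_on_cuda() {
    let cuda: Cuda = Default::default();

    // The device is now named in the type; it no longer defaults to Cpu.
    let _: Linear<1, 1, Cuda> = BuildModule::build(&cuda);

    // build_on_device needs the device in the turbofish as well,
    // or it can be left to inference entirely.
    let _: Linear<1, 1, Cuda> = Linear::<1, 1, Cuda>::build_on_device(&cuda);
    let _ = Linear::<1, 1, Cuda>::build_on_device(&cuda);
}
```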