Skip to content

Commit

Permalink
f64 kernels (#421)
Browse files Browse the repository at this point in the history
* impl Cpu kernels for num_traits::Float

* Update examples

* Fixing cuda issues

* Add dtype skeleton to cuda kernels

* Working commit of TestDtype & cuda kernels

* Adding test-f64 feature

* Temp commit

* Moving more things to generic dtype

* Possibly fix max_to and min_to for doubles (#431)

* Temp commit

* Fixing optim tests

* Compiling

* Tests passing for Cpu

* Cuda compiling

* f64 cuda tests passing

* Fixing warnings

* Requiring minimum compute_60

* f64 nightly tests passing for cpu

* pool2d f64 kernels

* Conv & pool passing on cuda

* Making nn/dropout.rs generic over dtype

* Styling & using DeviceBuildExt

* Cleanup optimizers

* Minifying cuda kernels

* Adding f64 cmp kernels

* Removing macros in optim

* Reduce macro usage in cuda kernels

* minify cuda kernel impls

* Fixing docs

---------

Co-authored-by: Viliam Vadocz <viliam.vadocz@gmail.com>
  • Loading branch information
coreylowman and ViliamVadocz authored Feb 13, 2023
1 parent 4dc8a2d commit 0a5a016
Show file tree
Hide file tree
Showing 204 changed files with 4,004 additions and 2,861 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ zip = { version = "0.6.2", default-features = false, optional = true }
cblas-sys = { version = "0.1.4", default-features = false, optional = true }
libc = { version = "0.2", default-features = false, optional = true }
cudarc = { version = "0.7.0", default-features = false, optional = true }
num-traits = { version = "0.2.15", default-features = false }

[features]
default = ["std", "numpy"]
Expand All @@ -42,6 +43,7 @@ cblas = ["dep:cblas-sys", "dep:libc"]
intel-mkl = ["cblas"]
cuda = ["dep:cudarc"]
test-cuda = ["cuda"]
test-f64 = []

[dev-dependencies]
rand = "0.8.5"
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type Mlp = (
fn main() {
let dev: Cpu = Default::default();
// OR `let dev: Cuda = Default::default();`
let mlp = Mlp::build_on_device(&dev);
let mlp = dev.build_module::<Mlp, f32>();
let x: Tensor<Rank1<10>, f32, Cpu> = dev.zeros();
let y /*: Tensor<Rank1<2>, f32, Cpu>*/ = mlp.forward(x);
println!("{:?}", y);
Expand Down Expand Up @@ -127,12 +127,12 @@ for sequentially executing modules.
```rust
// no idea why you would do this, but you could!
type Model = (ReLU, Sigmoid, Tanh);
let model = Model::build_on_device(&dev);
let model = dev.build_module::<Model, f32>();
```

```rust
type Model = (Linear<10, 5>, Tanh);
let model = Model::build_on_device(&dev);
let model = dev.build_module::<Model, f32>();
```

How implementing Module for a 2-tuple looks:
Expand Down
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ mod cuda {
for kernel_path in kernel_paths {
println!("cargo:rerun-if-changed={}", kernel_path.display());
let output = std::process::Command::new("nvcc")
.args(["--gpu-architecture", "compute_60"])
.arg("--ptx")
.args(["--output-directory", &out_dir])
.args(&include_options)
Expand Down
34 changes: 27 additions & 7 deletions examples/03-nn.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
//! Intro to dfdx::nn
use dfdx::{
nn::{builders::*, modules, BuildOnDevice, Module, ModuleMut, ResetParams},
nn::{builders::*, BuildOnDevice, DeviceBuildExt, Module, ModuleMut, ResetParams},
shapes::{Const, Rank1, Rank2},
tensor::{AsArray, Cpu, SampleTensor, Tensor, ZerosTensor},
};

fn main() {
let dev: Cpu = Default::default();

// nn exposes many different neural network types, like the Linear layer!
// you can use BuildModule::build to construct an initialized model
// `nn::builders` exposes many different neural network types, like the Linear layer!
type Model = Linear<4, 2>;
let mut m: modules::Linear<4, 2, f32, Cpu> = Model::build_on_device(&dev);

// ResetParams::reset_params also allows you to re-randomize the weights
// you can use DeviceBuildExt::build_module to construct an initialized model.
// the type of this is actually one of the structures under `nn::modules`.
let mut m = dev.build_module::<Model, f32>();

// `ResetParams::reset_params` also allows you to re-randomize the weights
m.reset_params();

// Modules act on tensors using either:
Expand All @@ -26,8 +28,7 @@ fn main() {

// most of them can also act on many different shapes of tensors
// here we see that Linear can also accept batches of inputs
// Note: the Rank2 with a batch size of 10 in the input
// AND the output
// Note: the Rank2 with a batch size of 10 in the input AND the output
let _: Tensor<Rank2<10, 2>, f32, _> = m.forward(dev.zeros::<Rank2<10, 4>>());

// Even dynamic size is supported;
Expand All @@ -43,4 +44,23 @@ fn main() {
let a = mlp.forward(x.clone());
let b = mlp.2.forward(mlp.1.forward(mlp.0.forward(x)));
assert_eq!(a.array(), b.array());

// There are actually two ways to specify nn types, depending on whether
// you want device/dtype agnostic types, or not.

// For device agnostic structure, you can use the `nn::builders` api
{
use dfdx::nn::builders::*;
type Model = (Linear<5, 2>, ReLU, Tanh, Linear<2, 3>);
let _ = dev.build_module::<Model, f32>();
}

// The other way is to specify modules with device & dtype generics
// using `nn::modules`. The structures in modules are actually
// what the nn::builders turn into!
{
use dfdx::nn::{modules::*, BuildModule};
type Model<E, D> = (Linear<5, 2, E, D>, ReLU, Tanh, Linear<2, 3, E, D>);
let _: Model<f32, _> = BuildModule::build(&dev);
}
}
4 changes: 2 additions & 2 deletions examples/05-optim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use dfdx::{
losses::mse_loss,
nn::{builders::*, BuildOnDevice, ModuleMut},
nn::{builders::*, DeviceBuildExt, ModuleMut},
optim::{Momentum, Optimizer, Sgd, SgdConfig},
shapes::Rank2,
tensor::{AsArray, Cpu, SampleTensor, Tensor},
Expand All @@ -20,7 +20,7 @@ fn main() {
let dev: Cpu = Default::default();

// First randomly initialize our model
let mut mlp = Mlp::build_on_device(&dev);
let mut mlp = dev.build_module::<Mlp, f32>();

// Here we construct a stochastic gradient descent optimizer
// for our Mlp.
Expand Down
2 changes: 1 addition & 1 deletion examples/06-mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fn main() {
let mut rng = StdRng::seed_from_u64(0);

// initialize model and optimizer
let mut model = Mlp::build_on_device(&dev);
let mut model = dev.build_module::<Mlp, f32>();
let mut opt = Adam::new(&model, Default::default());

// initialize dataset
Expand Down
4 changes: 2 additions & 2 deletions examples/11-multi-headed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use dfdx::{
nn::builders::{Linear, SplitInto},
nn::{BuildOnDevice, Module},
nn::{DeviceBuildExt, Module},
shapes::Rank1,
tensor::{Cpu, Tensor, TensorFrom},
};
Expand All @@ -15,7 +15,7 @@ fn main() {
// tuple must accept the same type of input.
// Note that here, both of the linears have the same size input (1)
type Model = SplitInto<(Linear<1, 3>, Linear<1, 5>)>;
let m = Model::build_on_device(&dev);
let m = dev.build_module::<Model, f32>();

// when we forward data through, we get a tuple back!
let _: (Tensor<Rank1<3>, f32, _>, Tensor<Rank1<5>, f32, _>) = m.forward(dev.tensor([1.0]));
Expand Down
2 changes: 1 addition & 1 deletion examples/nightly-conv-net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn main() {
);

let dev: Cpu = Default::default();
let m = Model::build_on_device(&dev);
let m = dev.build_module::<Model, f32>();

// single image forward
let x: Tensor<Rank3<3, 28, 28>, f32, _> = dev.sample_normal();
Expand Down
2 changes: 1 addition & 1 deletion examples/nightly-resnet18.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn main() {

let dev: Cpu = Default::default();
let x: Tensor<Rank3<3, 224, 224>, f32, _> = dev.sample_normal();
let m = Resnet18::<1000>::build_on_device(&dev);
let m = dev.build_module::<Resnet18<1000>, f32>();
for _ in 0.. {
let start = Instant::now();
let _ = m.forward(x.clone());
Expand Down
2 changes: 1 addition & 1 deletion examples/nightly-transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn main() {

let dev: Cpu = Default::default();
type Model = Transformer<16, 4, 3, 3, 8>;
let t = Model::build_on_device(&dev);
let t = dev.build_module::<Model, f32>();

let src: Tensor<Rank3<4, 12, 16>, f32, _> = dev.sample_normal();
let tgt: Tensor<Rank3<4, 6, 16>, f32, _> = dev.sample_normal();
Expand Down
2 changes: 1 addition & 1 deletion examples/rl-dqn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn main() {
let next_state = dev.sample_normal::<Rank2<BATCH, STATE>>();

// initialize model
let mut q_net = QNetwork::build_on_device(&dev);
let mut q_net = dev.build_module::<QNetwork, f32>();
let target_q_net = q_net.clone();

let mut sgd = Sgd::new(
Expand Down
2 changes: 1 addition & 1 deletion examples/rl-ppo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn main() {
let advantage = dev.sample_normal::<Rank1<BATCH>>();

// initialize model - all weights are 0s
let mut pi_net = PolicyNetwork::build_on_device(&dev);
let mut pi_net = dev.build_module::<PolicyNetwork, f32>();
let target_pi_net = pi_net.clone();

let mut sgd = Sgd::new(
Expand Down
55 changes: 43 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@
//! );
//! ```
//!
//! 3. Instantiate models with [crate::nn::BuildOnDevice]
//! 3. Instantiate models with [crate::nn::DeviceBuildExt]
//! ```rust
//! # use dfdx::prelude::*;
//! let dev: Cpu = Default::default();
//! type Model = (Linear<5, 2>, ReLU);
//! let mlp = Model::build_on_device(&dev);
//! let mlp = dev.build_module::<Model, f32>();
//! ```
//!
//! 4. Pass data through networks with [crate::nn::Module]
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # let mlp = <Linear<5, 2>>::build_on_device(&dev);
//! # let mlp = dev.build_module::<Linear<5, 2>, f32>();
//! let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
//! let y = mlp.forward(x); // compiler infers that `y` must be `Tensor<Rank1<2>>`
//! ```
Expand All @@ -51,7 +51,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! # let model = <Linear<10, 5>>::build_on_device(&dev);
//! # let model = dev.build_module::<Linear<10, 5>, f32>();
//! # let y_true: Tensor<Rank1<5>, f32, _> = dev.sample_normal().softmax();
//! // tensors default to not having a tape
//! let x: Tensor<Rank1<10>, f32, Cpu, NoneTape> = dev.zeros();
Expand All @@ -68,7 +68,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients};
//! # let dev: Cpu = Default::default();
//! # let model = <Linear<10, 5>>::build_on_device(&dev);
//! # let model = dev.build_module::<Linear<10, 5>, f32>();
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! // compute cross entropy loss
Expand All @@ -81,7 +81,7 @@
//! ```rust
//! # use dfdx::{prelude::*, gradients::Gradients, optim::*};
//! # let dev: Cpu = Default::default();
//! # let mut model = <Linear<10, 5>>::build_on_device(&dev);
//! # let mut model = dev.build_module::<Linear<10, 5>, f32>();
//! # let y_true = dev.sample_normal::<Rank1<5>>().softmax();
//! # let y = model.forward(dev.zeros::<Rank1<10>>().trace());
//! # let loss = cross_entropy_with_logits_loss(y, y_true);
Expand Down Expand Up @@ -166,17 +166,28 @@ pub fn keep_denormals() {

#[cfg(test)]
pub(crate) mod tests {
const TOLERANCE: f32 = 1e-6;

#[cfg(not(feature = "test-cuda"))]
pub type TestDevice = crate::tensor::Cpu;

#[cfg(feature = "test-cuda")]
pub type TestDevice = crate::tensor::Cuda;

#[cfg(not(feature = "test-f64"))]
pub type TestDtype = f32;

#[cfg(feature = "test-f64")]
pub type TestDtype = f64;

pub trait AssertClose {
fn get_far_pair(&self, rhs: &Self, tolerance: f32) -> Option<(f32, f32)>;
fn assert_close(&self, rhs: &Self, tolerance: f32)
type Elem: std::fmt::Display + std::fmt::Debug + Copy;
const DEFAULT_TOLERANCE: Self::Elem;
fn get_far_pair(
&self,
rhs: &Self,
tolerance: Self::Elem,
) -> Option<(Self::Elem, Self::Elem)>;
fn assert_close(&self, rhs: &Self, tolerance: Self::Elem)
where
Self: std::fmt::Debug,
{
Expand All @@ -187,6 +198,8 @@ pub(crate) mod tests {
}

impl AssertClose for f32 {
type Elem = f32;
const DEFAULT_TOLERANCE: Self::Elem = 1e-6;
fn get_far_pair(&self, rhs: &Self, tolerance: f32) -> Option<(f32, f32)> {
if (self - rhs).abs() > tolerance {
Some((*self, *rhs))
Expand All @@ -196,8 +209,26 @@ pub(crate) mod tests {
}
}

/// Closeness check for a single `f64` scalar.
impl AssertClose for f64 {
    type Elem = f64;

    /// Tolerance used when the caller does not supply one.
    const DEFAULT_TOLERANCE: Self::Elem = 1e-6;

    /// Returns `Some((*self, *rhs))` when the absolute difference between the
    /// two values is strictly greater than `tolerance`, and `None` otherwise.
    ///
    /// The comparison is a plain `>`, so a NaN difference never counts as
    /// "far" and yields `None`.
    fn get_far_pair(&self, rhs: &Self, tolerance: f64) -> Option<(f64, f64)> {
        match (self - rhs).abs() {
            gap if gap > tolerance => Some((*self, *rhs)),
            _ => None,
        }
    }
}

impl<T: AssertClose, const M: usize> AssertClose for [T; M] {
fn get_far_pair(&self, rhs: &Self, tolerance: f32) -> Option<(f32, f32)> {
type Elem = T::Elem;
const DEFAULT_TOLERANCE: Self::Elem = T::DEFAULT_TOLERANCE;
fn get_far_pair(
&self,
rhs: &Self,
tolerance: Self::Elem,
) -> Option<(Self::Elem, Self::Elem)> {
for (l, r) in self.iter().zip(rhs.iter()) {
if let Some(pair) = l.get_far_pair(r, tolerance) {
return Some(pair);
Expand All @@ -208,13 +239,13 @@ pub(crate) mod tests {
}

pub fn assert_close<T: AssertClose + std::fmt::Debug>(a: &T, b: &T) {
a.assert_close(b, TOLERANCE);
a.assert_close(b, T::DEFAULT_TOLERANCE);
}

pub fn assert_close_with_tolerance<T: AssertClose + std::fmt::Debug>(
a: &T,
b: &T,
tolerance: f32,
tolerance: T::Elem,
) {
a.assert_close(b, tolerance);
}
Expand Down
Loading

0 comments on commit 0a5a016

Please sign in to comment.