diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 954d366..0bf8153 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -on: [push, pull_request] +on: [push] name: CI jobs: @@ -25,6 +25,8 @@ jobs: test: name: Test runs-on: ubuntu-latest + env: + AF_VER: 3.8.0 steps: - name: Checkout sources uses: actions/checkout@v2 @@ -36,11 +38,35 @@ jobs: toolchain: stable override: true + - name: Cache ArrayFire + uses: actions/cache@v1 + id: arrayfire + with: + path: afbin + key: ${{ runner.os }}-af-${{ env.AF_VER }} + + - name: Download ArrayFire + # Only download and cache arrayfire if already not found + if: steps.arrayfire.outputs.cache-hit != 'true' + run: | + wget --quiet http://arrayfire.s3.amazonaws.com/${AF_VER}/ArrayFire-v${AF_VER}_Linux_x86_64.sh + chmod +x ./ArrayFire-v${AF_VER}_Linux_x86_64.sh + mkdir afbin + ./ArrayFire-v${AF_VER}_Linux_x86_64.sh --skip-license --exclude-subdir --prefix=./afbin + rm ./afbin/lib64/libcu*.so* + rm ./afbin/lib64/libafcuda*.so* + rm ./ArrayFire-v${AF_VER}_Linux_x86_64.sh + + - name: Export ArrayFire paths + run: | + echo "AF_PATH=${GITHUB_WORKSPACE}/afbin" >> $GITHUB_ENV + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${AF_PATH}/lib64" >> $GITHUB_ENV + - name: Run cargo test uses: actions-rs/cargo@v1 with: command: test - args: --all + args: --all-features lints: name: Lints @@ -67,11 +93,13 @@ jobs: uses: actions-rs/clippy-check@v1 with: token: ${{ secrets.GITHUB_TOKEN }} - args: --all --all-features -- -D warnings + args: --all-features -- -D warnings coverage: name: Coverage runs-on: ubuntu-latest + env: + AF_VER: 3.8.0 steps: - name: Checkout sources uses: actions/checkout@v2 @@ -83,6 +111,30 @@ jobs: toolchain: stable override: true + - name: Cache ArrayFire + uses: actions/cache@v1 + id: arrayfire + with: + path: afbin + key: ${{ runner.os }}-af-${{ env.AF_VER }} + + - name: Download ArrayFire + # Only download and cache arrayfire if already not found + if: steps.arrayfire.outputs.cache-hit != 'true' + run: | + wget --quiet http://arrayfire.s3.amazonaws.com/${AF_VER}/ArrayFire-v${AF_VER}_Linux_x86_64.sh + chmod +x ./ArrayFire-v${AF_VER}_Linux_x86_64.sh + mkdir afbin + ./ArrayFire-v${AF_VER}_Linux_x86_64.sh --skip-license --exclude-subdir --prefix=./afbin + rm ./afbin/lib64/libcu*.so* + rm ./afbin/lib64/libafcuda*.so* + rm ./ArrayFire-v${AF_VER}_Linux_x86_64.sh + + - name: Export ArrayFire paths + run: | + echo "AF_PATH=${GITHUB_WORKSPACE}/afbin" >> $GITHUB_ENV + echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${AF_PATH}/lib64" >> $GITHUB_ENV + - name: Run cargo-tarpaulin uses: actions-rs/tarpaulin@v0.1 with: diff --git a/Cargo.toml b/Cargo.toml index 98c0dd8..40f0318 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,19 @@ -[workspace] -members = [ - "mushin", - "mushin_derive", - "example", -] \ No newline at end of file +[package] +name = "mushin" +version = "0.2.0" +authors = ["Aitor Ruano "] +edition = "2021" +description = "Computational graphs with reverse automatic differentation in the GPU" +homepage = "https://github.com/c0dearm/mushin" +repository = "https://github.com/c0dearm/mushin" +readme = "README.md" +keywords = ["machine-learning", "automatic", "differentiation", "cuda", "opencl", "compute", "gpu", "cpu"] +categories = ["algorithms", "mathematics", "science"] +license = "MIT/Apache-2.0" + +[badges] +maintenance = { status = "actively-developed" } +codecov = { repository = "c0dearm/mushin" } + +[dependencies] +arrayfire = "3.8" diff --git a/README.md b/README.md index b074aed..f5105bc 100644 --- a/README.md +++ b/README.md @@ -11,85 +11,48 @@ ## Description -Mushin allows the developer to build neural networks at compile-time, with preallocated arrays with well defined sizes. This has mainly three very important benefits: +**Mushin** is to `Rust` what `Tensorflow` is to `Python`. A library to build computational graphs and compute the gradients of the outputs with respect to a given set of variables using [reverse automatic differentatiation](https://en.wikipedia.org/wiki/Automatic_differentiation). -1. **Compile-time network consistency check**: Any defect in your neural network (i.e. mismatching layers inputs/outputs) will be raised at compile-time. You can enjoy your coffee while your network inference or training process never fails! -2. **Awesome Rust compiler optimizations**: Because the neural network is completely defined at compile-time, the compiler is able -to perform smart optimizations, like unrolling loops or injecting [SIMD](https://en.wikipedia.org/wiki/SIMD) instructions. -3. **Support for embedded**: The `std` library is not required to build neural networks so it can run on any target that Rust supports. +Internally it uses the [arrayfire](https://crates.io/crates/arrayfire) crate to provide parallel computations on specialized hardware, such as Nvidia CUDA GPUs, Intel MKL CPUs... For details on what devices are available and installation instructions for your OS, please checkout the `arrayfire` crate documentation. **The installation of the `arrayfire` binaries is required for `Mushin` to work.** + +One clear benefit of this crate versus `Tensorflow` is Rust's strong type system. All operations performed on tensors during the graph build are checked at compile time for mathematical soundness, which means no runtime error after an hour of model training. **If it compiles, it works**. If at some point while building your horribly nested computational graph you make a mistake on the shape of a tensor you'll be stopped before feeling stupid. ## Usage -Add this to your `Cargo.toml`: +First, install the arrayfire binaries as indicated by the [arrayfire](https://crates.io/crates/arrayfire) crate. + +Then, add **Mushin** as one of your dependencies: ```toml [dependencies] -mushin = "0.1" -mushin_derive = "0.1" +mushin = "0.2" ``` -And this is a very simple example to get you started: +The following is a self-explanatory example of the basic usage of **Mushin**, for more details, please check the crate [docs](https://docs.rs/mushin/latest/mushin/). ```rust -use rand::distributions::Uniform; - -use mushin::{activations::ReLu, layers::Dense, NeuralNetwork}; -use mushin_derive::NeuralNetwork; - -// Builds a neural network with 2 inputs and 1 output -// Made of 3 feed forward layers, you can have as many as you want and with any name -#[derive(NeuralNetwork, Debug)] -struct MyNetwork { - // LayerType - input: Dense, - hidden: Dense, - output: Dense, -} - -impl MyNetwork { - // Initialize layer weights with a uniform distribution and set ReLU as activation function - fn new() -> Self { - let mut rng = rand::thread_rng(); - let dist = Uniform::from(-1.0..=1.0); - - MyNetwork { - input: Dense::random(&mut rng, &dist), - hidden: Dense::random(&mut rng, &dist), - output: Dense::random(&mut rng, &dist), - } - } -} +use mushin::{Context, Values, Class, Gradients, add, matmul}; fn main() { - // Init the weights and perform a forward pass - let nn = MyNetwork::new(); - println!("{:#?}", nn); - - let input = [0.0, 1.0]; - println!("Input: {:#?}", input); - let output = nn.forward(input); - println!("Output: {:#?}", output); -} -``` + let ctx = Context::new(); -You may wonder how the `forward` method works. The `NeuralNetwork` derive macro defines it for you, and it looks like this for this particular example: + let x = ctx.tensor::<1, 1, 2, 3>(Values::Eye(3.0), Class::Constant); + let w = ctx.tensor::<1, 1, 3, 2>(Values::Normal, Class::Persistent("weights")); + let b = ctx.tensor::<1, 1, 3, 3>(Values::Fill(0.0), Class::Persistent("bias")); + let z = add(&b, &matmul(&w, &x)); -```rust -fn forward(&self, input: [f32; 2]) -> [f32; 1] { - self.output.forward(self.hidden.forward(self.input.forward[input])) + let grads = Gradients::compute(&z); + let dz_dw = grads.wrt(&w); + let dz_db = grads.wrt(&b); } ``` -Note how the forward method expects two input values because that's what the first (`input`) layer expects, and returns one single value because that's what the last layer (`output`) returns. - ## Roadmap -- [x] Compile-time neural network consistency check -- [x] Docs, CI/CD & Benchmarks -- [ ] Backward pass -- [ ] More layer types (convolution, dropout, lstm...) -- [ ] More activation functions (sigmoid, softmax...) -- [ ] Maaaybeee, CPU and/or GPU concurrency +- [ ] Add more operations +- [ ] Allow for higher-order gradients +- [ ] Add benchmarks +- [ ] Add a cargo feature for deep learning, which adds layers, losses and activation functions (like `Keras`) ## Contributing @@ -105,4 +68,4 @@ Mushin is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and -[COPYRIGHT](COPYRIGHT) for details. +[COPYRIGHT](COPYRIGHT) for details. \ No newline at end of file diff --git a/example/Cargo.toml b/example/Cargo.toml deleted file mode 100644 index 4490260..0000000 --- a/example/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "example" -version = "0.1.0" -authors = ["Aitor Ruano "] -edition = "2018" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mushin = { path = "../mushin" } -mushin_derive = { path = "../mushin_derive" } -rand = "0.8" \ No newline at end of file diff --git a/example/src/main.rs b/example/src/main.rs deleted file mode 100644 index 6a2cef9..0000000 --- a/example/src/main.rs +++ /dev/null @@ -1,34 +0,0 @@ -use rand::distributions::Uniform; - -use mushin::{activations::ReLu, layers::Dense, NeuralNetwork}; -use mushin_derive::NeuralNetwork; - -#[derive(NeuralNetwork, Debug)] -struct MyNetwork { - input: Dense, - hidden: Dense, - output: Dense, -} - -impl MyNetwork { - fn new() -> Self { - let mut rng = rand::thread_rng(); - let dist = Uniform::from(-1.0..=1.0); - - MyNetwork { - input: Dense::random(&mut rng, &dist), - hidden: Dense::random(&mut rng, &dist), - output: Dense::random(&mut rng, &dist), - } - } -} - -fn main() { - let nn = MyNetwork::new(); - println!("{:#?}", nn); - - let input = [0.0, 1.0]; - println!("Input: {:#?}", input); - let output = nn.forward(input); - println!("Output: {:#?}", output); -} diff --git a/mushin/Cargo.toml b/mushin/Cargo.toml deleted file mode 100644 index 0ef61d7..0000000 --- a/mushin/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "mushin" -version = "0.1.1" -authors = ["Aitor Ruano "] -edition = "2018" -description = "Compile-time creation of neural networks" -homepage = "https://github.com/c0dearm/mushin" -repository = "https://github.com/c0dearm/mushin" -readme = "../README.md" -keywords = ["neural", "network", "artificial", "intelligence", "AI"] -categories = ["algorithms", "no-std", "mathematics"] -license = "MIT/Apache-2.0" - -[badges] -maintenance = { status = "actively-developed" } -codecov = { repository = "c0dearm/mushin" } - -[dependencies] -rand = "0.8" - -[dev-dependencies] -mushin_derive = { path = "../mushin_derive" } -rand_chacha = "0.3" -approx = "0.4" \ No newline at end of file diff --git a/mushin/src/activations.rs b/mushin/src/activations.rs deleted file mode 100644 index d9d3fe2..0000000 --- a/mushin/src/activations.rs +++ /dev/null @@ -1,51 +0,0 @@ -pub trait Activation { - fn activation(x: f32) -> f32; -} - -#[derive(Debug)] -pub struct Nope; -#[derive(Debug)] -pub struct ReLu; -#[derive(Debug)] -pub struct Sigmoid; - -impl Activation for Nope { - fn activation(x: f32) -> f32 { - x - } -} - -impl Activation for ReLu { - fn activation(x: f32) -> f32 { - x.max(0.0) - } -} - -impl Activation for Sigmoid { - fn activation(x: f32) -> f32 { - 1.0 / (1.0 + (-x).exp()) - } -} - -#[cfg(test)] -mod tests { - use super::Activation; - use super::{Nope, ReLu, Sigmoid}; - - #[test] - fn nope_activation() { - approx::assert_relative_eq!(Nope::activation(1.0), 1.0); - } - - #[test] - fn relu_activation() { - approx::assert_relative_eq!(ReLu::activation(-1.0), 0.0); - approx::assert_relative_eq!(ReLu::activation(1.0), 1.0); - } - - #[test] - fn sigmoid_activation() { - approx::assert_relative_eq!(Sigmoid::activation(-1.0), 0.26894143); - approx::assert_relative_eq!(Sigmoid::activation(1.0), 0.7310586); - } -} diff --git a/mushin/src/layers/dense.rs b/mushin/src/layers/dense.rs deleted file mode 100644 index d4a7889..0000000 --- a/mushin/src/layers/dense.rs +++ /dev/null @@ -1,90 +0,0 @@ -use crate::activations::Activation; -use core::marker::PhantomData; -use rand::{distributions::Distribution, Rng, RngCore}; - -#[derive(Debug)] -pub struct Dense { - weights: [[f32; I]; O], - bias: [f32; O], - activation: PhantomData, -} - -impl Dense { - pub fn new(weights: [[f32; I]; O], bias: [f32; O]) -> Self { - Dense { - weights, - bias, - activation: PhantomData, - } - } - - pub fn random>(rng: &mut R, dist: &D) -> Self { - let mut weights = [[0.0; I]; O]; - let mut bias = [0.0; O]; - - weights - .iter_mut() - .flatten() - .chain(bias.iter_mut()) - .zip(rng.sample_iter(dist)) - .for_each(|(w, r)| *w = r); - - Dense::new(weights, bias) - } - - pub fn forward(&self, input: [f32; I]) -> [f32; O] { - let mut output = [0.0; O]; - - output - .iter_mut() - .zip(self.bias.iter()) - .enumerate() - .for_each(|(k, (o, &b))| { - *o = A::activation( - input - .iter() - .zip(self.weights[k].iter()) - .fold(b, |acc, (x, &w)| x.mul_add(w, acc)), - ) - }); - - output - } -} - -#[cfg(test)] -mod tests { - use super::Dense; - use crate::activations::ReLu; - - use rand::{distributions::Uniform, SeedableRng}; - use rand_chacha::ChaCha8Rng; - - #[test] - fn dense_new() { - let layer = Dense::::new([[0.0, 1.0], [2.0, 3.0]], [1.0, 1.0]); - approx::assert_relative_eq!(layer.weights[0][..], [0.0, 1.0]); - approx::assert_relative_eq!(layer.weights[1][..], [2.0, 3.0]); - approx::assert_relative_eq!(layer.bias[..], [1.0, 1.0]); - } - - #[test] - fn dense_random() { - let mut rng = ChaCha8Rng::from_seed(Default::default()); - let between = Uniform::from(-1.0..=1.0); - let layer = Dense::::random(&mut rng, &between); - approx::assert_relative_eq!(layer.weights[0][..], [-0.6255188, 0.67383957]); - approx::assert_relative_eq!(layer.weights[1][..], [0.8181262, 0.26284897]); - approx::assert_relative_eq!(layer.bias[..], [0.5238807, -0.53516835]); - } - - #[test] - fn dense_forward() { - let layer = Dense::::new( - [[1.0, 1.0], [1.0, 1.0], [-2.0, -2.0], [-2.0, -2.0]], - [-2.0, 1.0, -2.0, 1.0], - ); - let output = layer.forward([1.0, 1.0]); - approx::assert_relative_eq!(output[..], [0.0, 3.0, 0.0, 0.0]); - } -} diff --git a/mushin/src/layers/mod.rs b/mushin/src/layers/mod.rs deleted file mode 100644 index 53c6432..0000000 --- a/mushin/src/layers/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod dense; - -pub use dense::Dense; diff --git a/mushin/src/lib.rs b/mushin/src/lib.rs deleted file mode 100644 index 2e34bef..0000000 --- a/mushin/src/lib.rs +++ /dev/null @@ -1,41 +0,0 @@ -pub mod activations; -pub mod layers; -pub mod losses; - -// Trait to use with `derive` on structs containing the neural network layers -pub trait NeuralNetwork { - fn forward(&self, input: [f32; I]) -> [f32; O]; -} - -#[cfg(test)] -mod tests { - use super::NeuralNetwork; - use crate::activations::ReLu; - use crate::layers::Dense; - use mushin_derive::NeuralNetwork; - - use rand::{distributions::Uniform, SeedableRng}; - use rand_chacha::ChaCha8Rng; - - #[derive(NeuralNetwork)] - struct TestNetwork { - input: Dense, - hidden: Dense, - output: Dense, - } - - #[test] - fn network_forward() { - let mut rng = ChaCha8Rng::from_seed(Default::default()); - let dist = Uniform::from(-1.0..=1.0); - - let nn = TestNetwork { - input: Dense::random(&mut rng, &dist), - hidden: Dense::random(&mut rng, &dist), - output: Dense::random(&mut rng, &dist), - }; - - let output = nn.forward([1.0, 1.0]); - approx::assert_relative_eq!(output[..], [0.0]); - } -} diff --git a/mushin/src/losses.rs b/mushin/src/losses.rs deleted file mode 100644 index 9d920e2..0000000 --- a/mushin/src/losses.rs +++ /dev/null @@ -1,52 +0,0 @@ -pub trait Loss { - fn loss(output: [f32; O], target: [f32; O]) -> f32; -} - -pub struct MeanSquaredError; -pub struct CrossEntropy; - -impl Loss for MeanSquaredError { - fn loss(output: [f32; O], target: [f32; O]) -> f32 { - output - .iter() - .zip(target.iter()) - .fold(0.0, |acc, (o, t)| acc + (o - t).powi(2)) - / O as f32 - } -} - -impl Loss for CrossEntropy { - fn loss(output: [f32; O], target: [f32; O]) -> f32 { - let sum = output.iter().map(|o| o.exp()).sum::(); - - output - .iter() - .zip(target.iter()) - .fold(0.0, |acc, (&o, &t)| (sum.ln() - o).mul_add(t, acc)) - } -} - -#[cfg(test)] -mod tests { - use super::Loss; - use super::{CrossEntropy, MeanSquaredError}; - - #[test] - fn mean_squared_error_loss() { - approx::assert_relative_eq!( - MeanSquaredError::loss([1.0, 1.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]), - 0.5 - ); - } - - #[test] - fn cross_entropy_loss() { - approx::assert_relative_eq!( - CrossEntropy::loss( - [0.05, 0.95, 0.0, 0.1, 0.8, 0.1], - [0.0, 1.0, 0.0, 0.0, 0.0, 1.0] - ), - 3.360576 - ); - } -} diff --git a/mushin_derive/Cargo.toml b/mushin_derive/Cargo.toml deleted file mode 100644 index fa4effb..0000000 --- a/mushin_derive/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "mushin_derive" -version = "0.1.1" -authors = ["Aitor Ruano "] -edition = "2018" -description = "Compile-time creation of neural networks" -homepage = "https://github.com/c0dearm/mushin" -repository = "https://github.com/c0dearm/mushin" -readme = "../README.md" -keywords = ["neural", "network", "artificial", "intelligence", "AI"] -categories = ["algorithms", "no-std", "mathematics"] -license = "MIT/Apache-2.0" - -[badges] -maintenance = { status = "actively-developed" } -codecov = { repository = "c0dearm/mushin" } - -[lib] -proc-macro = true - -[dependencies] -proc-macro2 = "1.0" -syn = "1.0" -quote = "1.0" \ No newline at end of file diff --git a/mushin_derive/src/lib.rs b/mushin_derive/src/lib.rs deleted file mode 100644 index f92955f..0000000 --- a/mushin_derive/src/lib.rs +++ /dev/null @@ -1,70 +0,0 @@ -use proc_macro2::TokenStream; -use quote::quote; -use syn::{ - parse_macro_input, punctuated::Punctuated, token::Comma, Data, DeriveInput, Expr, ExprLit, - Field, Fields, GenericArgument, Ident, Lit, Path, PathArguments::AngleBracketed, Type, - TypePath, -}; - -#[proc_macro_derive(NeuralNetwork)] -pub fn derive_neural_network(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(input as DeriveInput); - - let name = input.ident; - let fields = match input.data { - Data::Struct(data) => match data.fields { - Fields::Named(fields) => fields.named, - Fields::Unnamed(_) | Fields::Unit => unimplemented!(), - }, - Data::Enum(_) | Data::Union(_) => unimplemented!(), - }; - - proc_macro::TokenStream::from(impl_neural_network(name, fields)) -} - -fn get_field_type_args(field: &Field) -> &Punctuated { - let type_args = &match &field.ty { - Type::Path(TypePath { - qself: _, - path: Path { - leading_colon: _, - segments, - }, - }) => segments, - _ => unimplemented!(), - }[0] - .arguments; - - match type_args { - AngleBracketed(args) => &args.args, - _ => unimplemented!(), - } -} - -fn as_usize(arg: &GenericArgument) -> usize { - match arg { - GenericArgument::Const(Expr::Lit(ExprLit { - attrs: _, - lit: Lit::Int(v), - })) => v.base10_parse::().unwrap(), - _ => unimplemented!(), - } -} - -fn impl_neural_network(name: Ident, fields: Punctuated) -> TokenStream { - let forward_chain = fields.iter().fold(quote!(input), |acc, f| { - let name = &f.ident; - quote!(self.#name.forward(#acc)) - }); - - let input_size = as_usize(&get_field_type_args(fields.first().unwrap())[1]); - let output_size = as_usize(&get_field_type_args(fields.last().unwrap())[2]); - - quote! { - impl NeuralNetwork<#input_size, #output_size> for #name { - fn forward(&self, input: [f32; #input_size]) -> [f32; #output_size] { - #forward_chain - } - } - } -} diff --git a/src/context/function.rs b/src/context/function.rs new file mode 100644 index 0000000..ce7ba4f --- /dev/null +++ b/src/context/function.rs @@ -0,0 +1,191 @@ +use arrayfire::Array; + +use crate::tensor::{Origin, Tensor}; + +/// Represents an argument to a function and its value comes from a variable tensor +pub struct VariableArg { + value: Array, + function: usize, +} + +/// Represents an argument to a function but its value comes from a constant tensor +pub struct ConstantArg { + value: Array, +} + +/// Represents the arguments of a function with two arguments +pub enum DoubleArg { + /// Both arguments are variables + BothVariables(VariableArg, VariableArg), + /// First argument is a constant, second is a variable + ConstFirstArg(ConstantArg, VariableArg), + /// First argument is a variable, second is a constant + ConstSecondArg(VariableArg, ConstantArg), +} + +impl DoubleArg { + const fn values(&self) -> (Option, &Array, Option, &Array) { + match *self { + Self::BothVariables( + VariableArg { + value: ref a, + function: fa, + }, + VariableArg { + value: ref b, + function: fb, + }, + ) => (Some(fa), a, Some(fb), b), + Self::ConstFirstArg( + ConstantArg { value: ref a }, + VariableArg { + value: ref b, + function: fb, + }, + ) => (None, a, Some(fb), b), + Self::ConstSecondArg( + VariableArg { + value: ref a, + function: fa, + }, + ConstantArg { value: ref b }, + ) => (Some(fa), a, None, b), + } + } +} + +/// Type of the function performing the reverse pass on a single argument function +type OneArgBackwardFn = fn(df: &Array, arg: &Array) -> Array; +/// Type of the function performing the reverse pass on a double argument function +type TwoArgsBackwardFn = + fn(df: &Array, arg_a: &Array, arg_b: &Array) -> (Array, Array); + +/// `f(x) = cos(x)` if x is a variable +pub struct OneArg { + arg: VariableArg, + backward: OneArgBackwardFn, +} + +impl OneArg { + pub(crate) fn backward(&self, df: &Array) -> (usize, Array) { + (self.arg.function, (self.backward)(df, &self.arg.value)) + } +} + +/// `f(x, y) = x * y` if at least one of them is a variable +pub struct TwoArgs { + args: DoubleArg, + backward: TwoArgsBackwardFn, +} + +impl TwoArgs { + pub(crate) fn backward( + &self, + df: &Array, + ) -> (Option, Array, Option, Array) { + let (f_a, arg_a, f_b, arg_b) = self.args.values(); + let (partial_a, partial_b) = (self.backward)(df, arg_a, arg_b); + (f_a, partial_a, f_b, partial_b) + } +} + +/// Represents a node in the computation graph. +pub enum Function { + /// A variable declaration (constants are ignored) + Nary, + /// A function with only one arg (constants are ignored), like `cos(x)` + Unary(OneArg), + /// A function with two args if at least one of them is a variable, like `x * y` + Binary(TwoArgs), +} + +impl Function { + /// Creates a single argument function and pushes it to the tape, if the argument is a variable + /// Returns a new tensor origin with a reference to the newly created function + pub(crate) fn unary( + arg: &Tensor, + backward: OneArgBackwardFn, + ) -> Origin { + if let &Origin::Function(function) = arg.origin() { + let function = Self::Unary(OneArg { + arg: VariableArg { + value: arg.into(), + function, + }, + backward, + }); + Origin::Function(arg.context().push_function(function)) + } else { + // Single argument function applied to a constant is a constant + Origin::None + } + } + + /// Creates a double argument function and pushes it to the tape, if at least one of the arguments is a variable + /// Returns a new tensor origin with a reference to the newly created function + pub(crate) fn binary< + const XB: u64, + const XN: u64, + const XR: u64, + const XC: u64, + const YB: u64, + const YN: u64, + const YR: u64, + const YC: u64, + >( + arg_a: &Tensor, + arg_b: &Tensor, + backward: TwoArgsBackwardFn, + ) -> Origin { + match (arg_a.origin(), arg_b.origin()) { + // If both arguments are a constant, result is a constant + (&Origin::None, &Origin::None) => Origin::None, + (&Origin::Function(function), &Origin::None) => { + let function = Self::Binary(TwoArgs { + args: DoubleArg::ConstSecondArg( + VariableArg { + value: arg_a.into(), + function, + }, + ConstantArg { + value: arg_b.into(), + }, + ), + backward, + }); + Origin::Function(arg_a.context().push_function(function)) + } + (&Origin::None, &Origin::Function(function)) => { + let function = Self::Binary(TwoArgs { + args: DoubleArg::ConstFirstArg( + ConstantArg { + value: arg_a.into(), + }, + VariableArg { + value: arg_b.into(), + function, + }, + ), + backward, + }); + Origin::Function(arg_b.context().push_function(function)) + } + (&Origin::Function(function_a), &Origin::Function(function_b)) => { + let function = Self::Binary(TwoArgs { + args: DoubleArg::BothVariables( + VariableArg { + value: arg_a.into(), + function: function_a, + }, + VariableArg { + value: arg_b.into(), + function: function_b, + }, + ), + backward, + }); + Origin::Function(arg_a.context().push_function(function)) + } + } + } +} diff --git a/src/context/mod.rs b/src/context/mod.rs new file mode 100644 index 0000000..1f373c4 --- /dev/null +++ b/src/context/mod.rs @@ -0,0 +1,88 @@ +pub mod function; +mod storage; +mod tape; + +use std::cell::Ref; + +use arrayfire::{constant, dim4, identity, randn, randu}; + +use crate::tensor::{Class, Origin, Tensor, Values}; + +use function::Function; +use storage::Storage; +use tape::Tape; + +/// Stores the computation graph (tape) of functions and persistent values thorugh different tape builds +pub struct Context { + storage: Storage, + tape: Tape, +} + +impl Context { + /// Creates a new `Context` with fresh storage and computation graph + #[must_use] + #[inline] + pub fn new() -> Self { + Self { + storage: Storage::new(), + tape: Tape::new(), + } + } + + // Start the computation graph from scratch, but keep the persistent values + #[inline] + pub fn reset(&self) { + self.tape.reset(); + } + + /// Creates a new tensor in the computation graph with the given parameters + #[inline] + pub fn tensor( + &self, + values: Values, + class: Class, + ) -> Tensor { + let gen_values = |v| match v { + Values::Identity => identity(dim4!(R, C, L, B)), + Values::Uniform => randu!(R, C, L, B), + Values::Normal => randn!(R, C, L, B), + Values::Eye(x) => identity::(dim4!(R, C, L, B)) * x, + Values::Fill(x) => constant!(x; R, C, L, B), + }; + + match class { + Class::Constant => Tensor::new(gen_values(values), Origin::None, self), + Class::Variable => Tensor::new( + gen_values(values), + Origin::Function(self.tape.push_function(Function::Nary)), + self, + ), + Class::Persistent(key) => { + let function = self.tape.push_function(Function::Nary); + let value = self + .storage + .get_or_create(key, gen_values, values, function); + Tensor::new(value, Origin::Function(function), self) + } + } + } + + pub(crate) fn functions(&self) -> Ref> { + self.tape.functions() + } + + pub(crate) fn push_function(&self, function: Function) -> usize { + self.tape.push_function(function) + } + + pub(crate) fn tape_len(&self) -> usize { + self.tape.len() + } +} + +impl Default for Context { + #[inline] + fn default() -> Self { + Self::new() + } +} diff --git a/src/context/storage.rs b/src/context/storage.rs new file mode 100644 index 0000000..a3fb953 --- /dev/null +++ b/src/context/storage.rs @@ -0,0 +1,53 @@ +use std::cell::RefCell; +use std::collections::HashMap; + +use arrayfire::Array; + +/// Stores the tensor value and a reference to the function that originated it. +/// Used by `Storage` +pub struct PersistentValue { + pub(crate) value: Array, + pub(crate) function: usize, +} + +/// Stores the values of persistent tensors across different computation graph builds. +/// Used by the `Context` +pub struct Storage(RefCell>); + +impl Storage { + pub(crate) fn new() -> Self { + Self(RefCell::new(HashMap::new())) + } + + pub(crate) fn get_or_create Array>( + &self, + key: &'static str, + get_value_fn: F, + arg: A, + function: usize, + ) -> Array { + let mut map = self.0.borrow_mut(); + let item = map.get_mut(key); + + match item { + Some(&mut PersistentValue { + ref value, + function: ref mut f, + }) => { + *f = function; + value.clone() + } + None => { + let value = get_value_fn(arg); + map.insert( + key, + PersistentValue { + value: value.clone(), + function, + }, + ); + value + } + } + } +} diff --git a/src/context/tape.rs b/src/context/tape.rs new file mode 100644 index 0000000..f6007e4 --- /dev/null +++ b/src/context/tape.rs @@ -0,0 +1,31 @@ +use std::cell::{Ref, RefCell}; + +use crate::context::function::Function; + +/// Tape (or Wengert list) that keeps track of all the expressions evaluated since its declaration (a computation graph). +/// Used by the `Context` +pub struct Tape(RefCell>); + +impl Tape { + pub(crate) const fn new() -> Self { + Self(RefCell::new(Vec::new())) + } + + pub(crate) fn reset(&self) { + self.0.borrow_mut().clear(); + } + + pub(crate) fn functions(&self) -> Ref> { + self.0.borrow() + } + + pub(crate) fn len(&self) -> usize { + self.0.borrow().len() + } + + pub(crate) fn push_function(&self, function: Function) -> usize { + let index = self.functions().len(); + self.0.borrow_mut().push(function); + index + } +} diff --git a/src/gradient.rs b/src/gradient.rs new file mode 100644 index 0000000..e423bc0 --- /dev/null +++ b/src/gradient.rs @@ -0,0 +1,63 @@ +use arrayfire::{constant, Array}; + +use crate::context::function::Function; +use crate::tensor::{Origin, Tensor}; + +/// Stores the gradients for a given tensor +pub struct Gradients(Vec>); + +impl Gradients { + /// Given a root tensor, computes all the derivatives with respect to each of the variables it depends on + /// by performing reverse auto-differentiation on its computation graph + #[must_use] + #[inline] + pub fn compute( + z: &Tensor, + ) -> Self { + match *z.origin() { + Origin::Function(x_fid) => { + let mut gradients = vec![constant!(0.0; 1, 1, 1, 1); z.context().tape_len()]; + gradients[x_fid] = constant!(1.0; R, C, N, B); + + for (i, function) in z.context().functions().iter().enumerate().rev() { + match *function { + Function::Nary => {} + Function::Unary(ref f) => { + let (fid, partial) = f.backward(&gradients[i]); + gradients[fid] = &gradients[fid] + partial; + } + Function::Binary(ref f) => { + let (fid_a, partial_a, fid_b, partial_b) = f.backward(&gradients[i]); + if let Some(fa) = fid_a { + gradients[fa] = &gradients[fa] + partial_a; + } + if let Some(fb) = fid_b { + gradients[fb] = &gradients[fb] + partial_b; + } + } + } + } + Self(gradients) + } + Origin::None => Self(vec![]), + } + } + + /// Returns the gradient of the root tensor with respect to another tensor in the computation graph + #[must_use] + #[inline] + pub fn wrt<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + &self, + x: &'ctx Tensor<'ctx, B, L, R, C>, + ) -> Tensor<'ctx, B, L, R, C> { + let value: Array = match usize::try_from(x) { + Ok(function) => self + .0 + .get(function) + .map_or_else(|| constant!(0.0; R, C, L, B), std::clone::Clone::clone), + _ => constant!(0.0; R, C, L, B), + }; + + Tensor::new(value, Origin::None, x.context()) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0806086 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,70 @@ +//! # Mushin +//! +//! Mushin is a library for computing gradients on computational graphs using +//! reverse automatic differentiation. In other words, what Tensorflow is to +//! Python is what Mushin is to Rust. +//! +//! All the operations on tensors use the excellent [arrayfire](https://arrayfire.com/) +//! library as a backend. Which means **Mushin can perform computations on any device** +//! (Nvidia CUDA GPUs, `OpenCL`, Intel MKL... ). Plus, all operations are checked at +//! compile time for mathematical correctness. I.e. You won't be able to add two tensors +//! of different shape/dimensions. The shape of the resulting tensors for all your +//! operations is tracked through the computation graph so in that regard we can offer +//! a guarantee that Tensorflow can't: **If it compiles, your computation graph is +//! guaranteed to be correct** +//! +//! ## Usage +//! +//! All computational graphs start with a new context: +//! ```rust +//! # use mushin::Context; +//! let ctx = Context::new(); +//! ``` +//! The context contains the tape recording the computational graph as well as a storage +//! that lives through resets of the computational graph, to store for example tensors +//! whose values we want to keep, like trainable parameters. +//! +//! Once we have our context, we can start declaring tensors and use them in our operations: +//! ```rust +//! # use mushin::{Context, Values, Class, add, matmul}; +//! # let ctx = Context::new(); +//! let x = ctx.tensor::<1, 1, 2, 3>(Values::Eye(3.0), Class::Constant); +//! let w = ctx.tensor::<1, 1, 3, 2>(Values::Normal, Class::Persistent("weights")); +//! let b = ctx.tensor::<1, 1, 3, 3>(Values::Fill(0.0), Class::Persistent("bias")); +//! let z = add(&b, &matmul(&w, &x)); +//! ``` +//! The code above is an example of a perceptron neural network layer, where we have an input (`x`) +//! that we treat as a constant and a set of persistent (trainable) parameters, (`w`,`b`). +//! We then compute the output (`z`) as `WX + b`. Being this a reverse automatic differentation +//! library, we are now of course interested on the gradients of the output with respect to the graph +//! variables, which are obtained as follows: +//! ```rust +//! # use mushin::{Context, Gradients, Values, Class}; +//! # let ctx = Context::new(); +//! # let z = ctx.tensor::<1, 1, 1, 1>(Values::Identity, Class::Constant); +//! # let w = ctx.tensor::<1, 1, 1, 1>(Values::Identity, Class::Constant); +//! # let b = ctx.tensor::<1, 1, 1, 1>(Values::Identity, Class::Constant); +//! let grads = Gradients::compute(&z); +//! let dz_dw = grads.wrt(&w); +//! let dz_db = grads.wrt(&b); +//! ``` + +#![deny( + clippy::all, + clippy::pedantic, + clippy::nursery, + clippy::cargo, + clippy::module_name_repetitions, + clippy::pattern_type_mismatch, + clippy::shadow_unrelated, + clippy::missing_inline_in_public_items +)] +mod context; +mod gradient; +mod ops; +mod tensor; + +pub use context::Context; +pub use gradient::Gradients; +pub use ops::{add, div, matmul, multiply, pow, sin, sub, sum}; +pub use tensor::{Class, Values}; diff --git a/src/ops.rs b/src/ops.rs new file mode 100644 index 0000000..8c4ae71 --- /dev/null +++ b/src/ops.rs @@ -0,0 +1,288 @@ +use arrayfire::{Array, MatProp}; + +use crate::context::function::Function; +use crate::tensor::Tensor; + +/// Computes the element-wise sinus function +#[must_use] +#[inline] +pub fn sin<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, +) -> Tensor<'ctx, B, L, R, C> { + let backward = |df: &Array, a: &Array| df * arrayfire::cos(a); + + Tensor::new( + arrayfire::sin(x.into()), + Function::unary(x, backward), + x.context(), + ) +} + +/// Performs the element-wise addition +#[must_use] +#[inline] +pub fn add<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, + y: &Tensor, +) -> Tensor<'ctx, B, L, R, C> { + let backward = |df: &Array, _: &Array, _: &Array| (df.clone(), df.clone()); + + Tensor::new( + arrayfire::add(&Array::from(x), &Array::from(y), false), + Function::binary(x, y, backward), + x.context(), + ) +} + +/// Performs the element-wise substraction +#[must_use] +#[inline] +pub fn sub<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, + y: &Tensor, +) -> Tensor<'ctx, B, L, R, C> { + let backward = |df: &Array, _: &Array, _: &Array| (df.clone(), -df.clone()); + + Tensor::new( + arrayfire::sub(&Array::from(x), &Array::from(y), false), + Function::binary(x, y, backward), + x.context(), + ) +} + +/// Performs the common matrix multiplication +#[must_use] +#[inline] +pub fn matmul<'ctx, const B: u64, const L: u64, const R: u64, const C: u64, const YC: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, + y: &Tensor, +) -> Tensor<'ctx, B, L, R, YC> { + let backward = |df: &Array, a: &Array, b: &Array| { + ( + arrayfire::matmul(df, b, MatProp::NONE, MatProp::TRANS), + arrayfire::matmul(a, df, MatProp::TRANS, MatProp::NONE), + ) + }; + + Tensor::new( + arrayfire::matmul( + &Array::from(x), + &Array::from(y), + MatProp::NONE, + MatProp::NONE, + ), + Function::binary(x, y, backward), + x.context(), + ) +} + +/// Performs the Hadamard product (element-wise multiplication of two tensors) +#[must_use] +#[inline] +pub fn multiply<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, + y: &Tensor, +) -> Tensor<'ctx, B, L, R, C> { + let backward = |df: &Array, a: &Array, b: &Array| (df * b, df * a); + + Tensor::new( + arrayfire::mul(&Array::from(x), &Array::from(y), false), + Function::binary(x, y, backward), + x.context(), + ) +} + +/// Computes the element-wise power of two tensors +#[must_use] +#[inline] +pub fn pow<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, + y: &Tensor, +) -> Tensor<'ctx, B, L, R, C> { + let backward = |df: &Array, a: &Array, b: &Array| { + ( + df * b * arrayfire::pow(a, &(b - 1.0f32), false), + df * arrayfire::pow(a, b, false) * arrayfire::log(a), + ) + }; + + Tensor::new( + arrayfire::pow(&Array::from(x), &Array::from(y), false), + Function::binary(x, y, backward), + x.context(), + ) +} + +/// Computes the sum of all the elements in the tensor +#[must_use] +#[inline] +pub fn sum<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, +) -> Tensor<'ctx, 1, 1, 1, 1> { + let backward = |df: &Array, _: &Array| df.clone(); + + let (value, _) = arrayfire::sum_all(x.into()); + + Tensor::new( + arrayfire::constant!(value; 1,1,1,1), + Function::unary(x, backward), + x.context(), + ) +} + +/// Computes the element-wise division of two tensors +#[must_use] +#[inline] +pub fn div<'ctx, const B: u64, const L: u64, const R: u64, const C: u64>( + x: &'ctx Tensor<'ctx, B, L, R, C>, + y: &Tensor<'ctx, B, L, R, C>, +) -> Tensor<'ctx, B, L, R, C> { + let backward = |df: &Array, a: &Array, b: &Array| (df / b, -(df * a / b / b)); + + Tensor::new( + arrayfire::div(&Array::from(x), &Array::from(y), false), + Function::binary(x, y, backward), + x.context(), + ) +} + +#[cfg(test)] +mod tests { + use arrayfire::{abs, all_true_all, constant, dim4, le, Array}; + + use crate::{Class, Context, Gradients, Values}; + + use super::*; + + // Helper function to assert that two arryfire Arrays are equal + fn assert_equal(x: &Array, y: &Array) { + assert!(all_true_all(&le(&abs(&(x - y)), &1e-15, false)).0) + } + + #[test] + fn sin_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 2, 3>(Values::Eye(0.5), Class::Variable); + let z = sin(&x); + assert_equal( + &Array::from(&z), + &Array::new( + &[0.479425538604203, 0.0, 0.0, 0.479425538604203, 0.0, 0.0], + dim4!(2, 3, 1, 1), + ), + ); + + let grads = Gradients::compute(&z); + assert_equal( + &grads.wrt(&x).into(), + &Array::new( + &[0.8775825618903728, 1.0, 1.0, 0.8775825618903728, 1.0, 1.0], + dim4!(2, 3, 1, 1), + ), + ); + } + + #[test] + fn add_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 3, 2>(Values::Eye(3.0), Class::Variable); + let y = ctx.tensor::<1, 1, 3, 2>(Values::Fill(2.0), Class::Variable); + let z = add(&x, &y); + assert_equal( + &Array::from(&z), + &Array::new(&[5.0, 2.0, 2.0, 2.0, 5.0, 2.0], dim4!(3, 2, 1, 1)), + ); + + let grads = Gradients::compute(&z); + assert_equal(&grads.wrt(&x).into(), &constant!(1.0; 3,2,1,1)); + assert_equal(&grads.wrt(&y).into(), &constant!(1.0; 3,2,1,1)); + } + + #[test] + fn sub_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 3, 2>(Values::Eye(3.0), Class::Variable); + let y = ctx.tensor::<1, 1, 3, 2>(Values::Fill(2.0), Class::Variable); + let z = sub(&x, &y); + assert_equal( + &Array::from(&z), + &Array::new(&[1.0, -2.0, -2.0, -2.0, 1.0, -2.0], dim4!(3, 2, 1, 1)), + ); + + let grads = Gradients::compute(&z); + assert_equal(&grads.wrt(&x).into(), &constant!(1.0; 3,2,1,1)); + assert_equal(&grads.wrt(&y).into(), &constant!(-1.0; 3,2,1,1)); + } + + #[test] + fn matmul_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 3, 2>(Values::Eye(3.0), Class::Variable); + let y = ctx.tensor::<1, 1, 2, 4>(Values::Eye(2.0), Class::Variable); + let z = matmul(&x, &y); + assert_equal( + &Array::from(&z), + &Array::new( + &[6.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + dim4!(3, 4, 1, 1), + ), + ); + + let grads = Gradients::compute(&z); + assert_equal(&grads.wrt(&x).into(), &constant!(2.0f32; 3, 2, 1, 1)); + assert_equal(&grads.wrt(&y).into(), &constant!(3.0f32; 2, 4, 1, 1)); + } + + #[test] + fn multiply_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 2, 3>(Values::Fill(3.0), Class::Variable); + let y = ctx.tensor::<1, 1, 2, 3>(Values::Fill(2.0), Class::Variable); + let z = multiply(&x, &y); + assert_equal(&Array::from(&z), &constant!(6.0f32; 2, 3, 1, 1)); + + let grads = Gradients::compute(&z); + assert_equal(&grads.wrt(&x).into(), &constant!(2.0f32; 2, 3, 1, 1)); + assert_equal(&grads.wrt(&y).into(), &constant!(3.0f32; 2, 3, 1, 1)); + } + + #[test] + fn pow_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 3, 2>(Values::Fill(2.0), Class::Variable); + let y = ctx.tensor::<1, 1, 3, 2>(Values::Fill(3.0), Class::Variable); + let z = pow(&x, &y); + assert_equal(&Array::from(&z), &constant!(8.0; 3,2,1,1)); + + let grads = Gradients::compute(&z); + assert_equal(&grads.wrt(&x).into(), &constant!(12.0; 3,2,1,1)); + assert_equal(&grads.wrt(&y).into(), &constant!(5.5451775; 3,2,1,1)); + } + + #[test] + fn sum_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 3, 2>(Values::Fill(2.0), Class::Variable); + let z = sum(&x); + assert_equal(&Array::from(&z), &Array::new(&[12.0], dim4!(1, 1, 1, 1))); + + let grads = Gradients::compute(&z); + assert_equal(&grads.wrt(&x).into(), &constant!(1.0; 3,2,1,1)); + } + + #[test] + fn div_forward_backward() { + let ctx = Context::new(); + let x = ctx.tensor::<1, 1, 3, 2>(Values::Fill(2.0), Class::Variable); + let y = ctx.tensor::<1, 1, 3, 2>(Values::Fill(4.0), Class::Variable); + let z = div(&x, &y); + assert_equal( + &Array::from(&z), + &Array::new(&[0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dim4!(3, 2, 1, 1)), + ); + + let grads = Gradients::compute(&z); + assert_equal(&grads.wrt(&x).into(), &constant!(0.25; 3,2,1,1)); + assert_equal(&grads.wrt(&y).into(), &constant!(-0.125; 3,2,1,1)); + } +} diff --git a/src/tensor.rs b/src/tensor.rs new file mode 100644 index 0000000..9af1aac --- /dev/null +++ b/src/tensor.rs @@ -0,0 +1,105 @@ +use arrayfire::Array; + +use crate::Context; + +/// Origin of a tensor. Either it came from a function (a variable) or it is a constant +#[non_exhaustive] +pub enum Origin { + /// The value is a constant and originated from a constant function (not inserted in the tape) + None, + /// The value is a variable and originated from a function with given index in the tape + Function(usize), +} +/// Possible pre-defined values to create a tensor from +#[non_exhaustive] +pub enum Values { + /// The identity tensor + Identity, + /// Values come from a uniform distribution + Uniform, + /// Values come from a normal distribution + Normal, + /// Tensor with all values zero except for the main diagonal + Eye(f32), + /// Tensor with all values set to the given value + Fill(f32), +} + +/// The class of a tensor defines if its value matters for the computation graph +#[non_exhaustive] +#[derive(Clone, Copy)] +pub enum Class { + /// Tensor is a constant so the value is not added to the computation graph (constants don't compute derivatives) + Constant, + /// Tensor is a variable and the value is added to the computation graph, the value does not persist through different builds + Variable, + /// Tensor is a variable and the value is added to the computation graph, that does persist through different builds. + /// The given string is a key to retrieve the value from the persistent storage. + Persistent(&'static str), +} + +/// A mathematical tensor with a reference to its origin in the computation graph +pub struct Tensor<'ctx, const B: u64, const L: u64, const R: u64, const C: u64> { + value: Array, + context: &'ctx Context, + origin: Origin, +} + +impl<'ctx, const B: u64, const L: u64, const R: u64, const C: u64> Tensor<'ctx, B, L, R, C> { + pub(crate) fn new(value: Array, origin: Origin, context: &'ctx Context) -> Self { + Tensor { + value, + context, + origin, + } + } + + pub(crate) const fn context(&self) -> &Context { + self.context + } + + pub(crate) const fn origin(&self) -> &Origin { + &self.origin + } +} + +impl<'tsr, const B: u64, const L: u64, const R: u64, const C: u64> + From<&'tsr Tensor<'_, B, L, R, C>> for &'tsr Array +{ + #[inline] + fn from(t: &'tsr Tensor<'_, B, L, R, C>) -> Self { + &t.value + } +} + +impl From<&Tensor<'_, B, L, R, C>> + for Array +{ + #[inline] + fn from(t: &Tensor<'_, B, L, R, C>) -> Self { + t.value.clone() + } +} + +impl From> + for Array +{ + #[inline] + fn from(t: Tensor<'_, B, L, R, C>) -> Self { + t.value + } +} + +impl TryFrom<&Tensor<'_, B, L, R, C>> + for usize +{ + type Error = (); + + #[inline] + fn try_from(t: &Tensor) -> Result { + match t.origin { + Origin::Function(function) => Ok(function), + Origin::None => Err(()), + } + } +}