Skip to content

Commit

Permalink
Adding examples of numpy and tensor usage
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman committed Jul 14, 2022
1 parent 3b23b3a commit 770e2ea
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
19 changes: 19 additions & 0 deletions examples/npy_serialize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use dfdx::numpy as np;

fn main() {
np::save("0d-rs.npy", &1.234).expect("Saving failed");
np::save("1d-rs.npy", &[1.0, 2.0, 3.0]).expect("Saving failed");
np::save("2d-rs.npy", &[[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]).expect("Saving failed");

let mut expected_0d = 0.0;
np::load("0d-rs.npy", &mut expected_0d).expect("Loading failed");
assert_eq!(expected_0d, 1.234);

let mut expected_1d = [0.0; 3];
np::load("1d-rs.npy", &mut expected_1d).expect("Loading failed");
assert_eq!(expected_1d, [1.0, 2.0, 3.0]);

let mut expected_2d = [[0.0; 3]; 2];
np::load("2d-rs.npy", &mut expected_2d).expect("Loading failed");
assert_eq!(expected_2d, [[1.0, 2.0, 3.0], [-1.0, -2.0, -3.0]]);
}
22 changes: 22 additions & 0 deletions examples/tensors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use dfdx::prelude::*;

fn main() {
let a: Tensor2D<2, 3> = TensorCreator::zeros();

// since add() expects tensors with the same size, we dont need a type for this
let b = TensorCreator::ones();
let c = add(a, &b);

// tensors just store raw rust arrays, use `.data()` to access this.
assert_eq!(c.data(), &[[1.0; 3]; 2]);

// since we pass in an array, rust will figure out that we mean Tensor1D<5> since its an [f32; 5]
let mut d = Tensor1D::new([1.0, 2.0, 3.0, 4.0, 5.0]);

// use `.mut_data()` to access underlying mutable array. type is provided for readability
let raw_data: &mut [f32; 5] = d.mut_data();
for i in 0..5 {
raw_data[i] *= 2.0;
}
assert_eq!(d.data(), &[2.0, 4.0, 6.0, 8.0, 10.0]);
}

0 comments on commit 770e2ea

Please sign in to comment.