Removing usages of tensor aliases (#328)
coreylowman authored Dec 20, 2022
1 parent 62e415f · commit 78c9d27
Showing 9 changed files with 78 additions and 84 deletions.
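
The change is mechanical throughout: each rank-specific alias (`Tensor0D`, `Tensor1D`, `Tensor2D`, `Tensor3D`, `Tensor4D`) becomes the single generic `Tensor<Shape, Dtype, Device>` type, with the shape written via `Rank0` through `Rank4` and the `f32` dtype spelled out explicitly. A minimal before/after sketch of the pattern, assuming dfdx's prelude and a plain `Cpu` device (the diffs below use the crate-internal `TestDevice` instead):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // Before this commit (alias form, now removed):
    //     let t: Tensor1D<5, _> = dev.ones();
    //     let m: Tensor2D<2, 3, _> = dev.zeros();

    // After: one generic Tensor type with explicit shape and dtype.
    let t: Tensor<Rank1<5>, f32, _> = dev.ones();
    let m: Tensor<Rank2<2, 3>, f32, _> = dev.zeros();

    println!("{:?}", t.array());
    println!("{:?}", m.array());
}
```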
src/optim/adam/mod.rs (11 changes: 5 additions & 6 deletions)

@@ -162,15 +162,14 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::tensor::*;
-    use crate::tensor_ops::*;
     use crate::tests::{assert_close, TestDevice};
+    use crate::{shapes::*, tensor::*, tensor_ops::*};

     #[test]
     fn test_default_adam_params() {
         let dev: TestDevice = Default::default();
         let mut opt = Adam::default();
-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([1e-6, 1e-5, 1e-4, 1e-3, 1e-2]);
         let expected = [
             [0.99999994, 0.999996, 0.9997143, 0.9990244, 0.99900025],
@@ -201,7 +200,7 @@ mod tests {
             eps: 1e-8,
             weight_decay: None,
         });
-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([1e-4, 1e-3, 1e-2, 1e-1, 1e-0]);
         let expected = [
             [0.9997143, 0.9990244, 0.99900025, 0.999, 0.999],
@@ -231,7 +230,7 @@ mod tests {
             weight_decay: Some(WeightDecay::L2(1.0)),
             ..Default::default()
         });
-        let mut t: Tensor1D<5, _> = dev.tensor([-0.5, -0.25, 0.1, 0.6, 1.0]);
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.tensor([-0.5, -0.25, 0.1, 0.6, 1.0]);
         #[rustfmt::skip]
         let expected = [
             [-0.499, -0.249, 0.099, 0.59900004, 0.999],
@@ -261,7 +260,7 @@ mod tests {
             weight_decay: Some(WeightDecay::Decoupled(1.0)),
             ..Default::default()
         });
-        let mut t: Tensor1D<5, _> = dev.tensor([-0.5, -0.25, 0.1, 0.6, 1.0]);
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.tensor([-0.5, -0.25, 0.1, 0.6, 1.0]);
         #[rustfmt::skip]
         let expected = [
             [-0.5005, -0.25075, 0.098900005, 0.5984, 0.998],
src/optim/rmsprop/mod.rs (5 changes: 2 additions & 3 deletions)

@@ -181,14 +181,13 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::tensor::*;
-    use crate::tensor_ops::*;
     use crate::tests::TestDevice;
+    use crate::{shapes::*, tensor::*, tensor_ops::*};

     fn test_matches_expected(cfg: RMSpropConfig<f32>, expected: [[f32; 5]; 5]) {
         let dev: TestDevice = Default::default();
         let rate = dev.tensor([0.1, 1.0, 2.0, 10.0, 100.0]);
-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let mut opt = RMSprop::new(cfg);
         for e in expected.iter() {
             let gradients = (t.trace() * rate.clone()).square().sum().backward();
src/optim/sgd/mod.rs (19 changes: 9 additions & 10 deletions)

@@ -192,9 +192,8 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::tensor::*;
-    use crate::tensor_ops::*;
     use crate::tests::{assert_close, TestDevice};
+    use crate::{shapes::*, tensor::*, tensor_ops::*};

     #[test]
     fn test_perfect_sgd() {
@@ -205,8 +204,8 @@ mod tests {
             weight_decay: None,
         });

-        let mut pred: Tensor1D<5, _> = dev.zeros();
-        let targ: Tensor1D<5, _> = dev.ones();
+        let mut pred: Tensor<Rank1<5>, f32, _> = dev.zeros();
+        let targ: Tensor<Rank1<5>, f32, _> = dev.ones();
         for _ in 0..5 {
             let loss = (pred.trace() - targ.clone()).abs().mean();
             let gradients = loss.backward();
@@ -221,7 +220,7 @@ mod tests {
         let dev: TestDevice = Default::default();
         let mut sgd = Sgd::new(Default::default());

-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([0.1, 1.0, 2.0, 10.0, 100.0]);
         let expected = [
             [0.9998, 0.998, 0.996, 0.98, 0.8],
@@ -248,7 +247,7 @@ mod tests {
             weight_decay: None,
         });

-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([0.1, 1.0, 2.0, 10.0, 100.0]);
         let expected = [
             [0.9998, 0.998, 0.996, 0.98, 0.8],
@@ -275,7 +274,7 @@ mod tests {
             weight_decay: None,
         });

-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([0.1, 1.0, 2.0, 10.0, 100.0]);
         let expected = [
             [0.9997, 0.997, 0.994, 0.97, 0.70000005],
@@ -308,7 +307,7 @@ mod tests {
             weight_decay: Some(WeightDecay::Decoupled(1e-1)),
         });

-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([0.1, 1.0, 2.0, 10.0, 100.0]);
         let expected = [
             [0.9988, 0.997, 0.995, 0.979, 0.799],
@@ -340,7 +339,7 @@ mod tests {
             weight_decay: Some(WeightDecay::Decoupled(1e-1)),
         });

-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([0.1, 1.0, 2.0, 10.0, 100.0]);
         let expected = [
             [0.9988, 0.997, 0.995, 0.979, 0.799],
@@ -373,7 +372,7 @@ mod tests {
             weight_decay: None,
         });

-        let mut t: Tensor1D<5, _> = dev.ones();
+        let mut t: Tensor<Rank1<5>, f32, _> = dev.ones();
         let rate = dev.tensor([0.1, 1.0, 2.0, 10.0, 100.0]);
         let expected = [
             [0.9988, 0.997, 0.995, 0.979, 0.799],
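
For reference, here is the `test_perfect_sgd` case above as a self-contained sketch, with the update step that the collapsed hunk hides filled in. The `lr` value and the `update(&mut tensor, gradients)` signature are assumptions based on this era of the optimizer API, not shown in the diff:

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let mut sgd = Sgd::new(SgdConfig {
        lr: 1.0, // assumed; the diff elides this field
        momentum: None,
        weight_decay: None,
    });

    let mut pred: Tensor<Rank1<5>, f32, _> = dev.zeros();
    let targ: Tensor<Rank1<5>, f32, _> = dev.ones();
    for _ in 0..5 {
        // trace() puts `pred` on the gradient tape; `targ` stays constant.
        let loss = (pred.trace() - targ.clone()).abs().mean();
        let gradients = loss.backward();
        // Assumed call: applies the gradients to `pred` in place.
        sgd.update(&mut pred, gradients).expect("all params used");
    }
    // d/dpred mean(|pred - targ|) = -0.2 per element, so lr = 1.0 steps
    // each element by 0.2 and lands exactly on the target after 5 steps.
    assert_eq!(pred.array(), [1.0; 5]);
}
```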
src/tensor_ops/clamp/mod.rs (2 changes: 1 addition & 1 deletion)

@@ -45,7 +45,7 @@ mod tests {
     #[test]
     fn test_clamp() {
         let dev: TestDevice = Default::default();
-        let t: Tensor2D<2, 3, _> = dev.tensor([[-1.0, 0.0, 1.0], [-2.0, 2.0, 1.1]]);
+        let t = dev.tensor([[-1.0, 0.0, 1.0], [-2.0, 2.0, 1.1]]);
         let r = t.trace().clamp(-1.0, 1.0);
         assert_eq!(r.array(), [[-1.0, 0.0, 1.0], [-1.0, 1.0, 1.0]]);
         let g = r.exp().mean().backward();
src/tensor_ops/dropout/mod.rs (4 changes: 2 additions & 2 deletions)

@@ -56,7 +56,7 @@ mod tests {
     #[test]
     fn test_dropout_all_0d() {
         let dev: TestDevice = Default::default();
-        let t: Tensor0D<_> = dev.tensor(3.0);
+        let t = dev.tensor(3.0);
         let r = t.trace().dropout(1.0);
         assert_eq!(r.array(), 0.0);
         let g = r.backward();
@@ -66,7 +66,7 @@ mod tests {
     #[test]
     fn test_dropout_none_0d() {
         let dev: TestDevice = Default::default();
-        let t: Tensor0D<_> = dev.tensor(3.0);
+        let t = dev.tensor(3.0);
         let r = t.trace().dropout(0.0);
         assert_eq!(r.array(), 3.0);
         let g = r.backward();
src/tensor_ops/matmul/mod.rs (46 changes: 23 additions & 23 deletions)

@@ -291,33 +291,33 @@ mod tests {
         let dev: TestDevice = Default::default();

         {
-            let a: Tensor1D<3, _> = dev.zeros();
-            let b: Tensor2D<3, 2, _> = dev.zeros();
-            let _: Tensor1D<2, _> = a.matmul(b);
+            let a: Tensor<Rank1<3>, f32, _> = dev.zeros();
+            let b: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
+            let _: Tensor<Rank1<2>, f32, _> = a.matmul(b);
         }

         {
-            let a: Tensor2D<5, 3, _> = dev.zeros();
-            let b: Tensor2D<3, 2, _> = dev.zeros();
-            let _: Tensor2D<5, 2, _> = a.matmul(b);
+            let a: Tensor<Rank2<5, 3>, f32, _> = dev.zeros();
+            let b: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
+            let _: Tensor<Rank2<5, 2>, f32, _> = a.matmul(b);
         }

         {
-            let a: Tensor3D<10, 5, 3, _> = dev.zeros();
-            let b: Tensor2D<3, 2, _> = dev.zeros();
-            let _: Tensor3D<10, 5, 2, _> = a.matmul(b);
+            let a: Tensor<Rank3<10, 5, 3>, f32, _> = dev.zeros();
+            let b: Tensor<Rank2<3, 2>, f32, _> = dev.zeros();
+            let _: Tensor<Rank3<10, 5, 2>, f32, _> = a.matmul(b);
         }

         {
-            let a: Tensor3D<10, 5, 3, _> = dev.zeros();
-            let b: Tensor3D<10, 3, 2, _> = dev.zeros();
-            let _: Tensor3D<10, 5, 2, _> = a.matmul(b);
+            let a: Tensor<Rank3<10, 5, 3>, f32, _> = dev.zeros();
+            let b: Tensor<Rank3<10, 3, 2>, f32, _> = dev.zeros();
+            let _: Tensor<Rank3<10, 5, 2>, f32, _> = a.matmul(b);
         }

         {
-            let a: Tensor4D<10, 20, 5, 3, _> = dev.zeros();
-            let b: Tensor4D<10, 20, 3, 2, _> = dev.zeros();
-            let _: Tensor4D<10, 20, 5, 2, _> = a.matmul(b);
+            let a: Tensor<Rank4<10, 20, 5, 3>, f32, _> = dev.zeros();
+            let b: Tensor<Rank4<10, 20, 3, 2>, f32, _> = dev.zeros();
+            let _: Tensor<Rank4<10, 20, 5, 2>, f32, _> = a.matmul(b);
         }
     }

@@ -365,8 +365,8 @@ mod tests {
     #[test]
     fn test_matmul_transpose() {
         let dev: TestDevice = Default::default();
-        let a: Tensor2D<4, 3, _> = dev.sample_normal();
-        let b: Tensor2D<3, 2, _> = dev.sample_normal();
+        let a: Tensor<Rank2<4, 3>, f32, _> = dev.sample_normal();
+        let b: Tensor<Rank2<3, 2>, f32, _> = dev.sample_normal();

         let c = a.trace().matmul(b.clone());
         let g1 = c.exp().mean().backward();
@@ -382,9 +382,9 @@ mod tests {
     fn test_matul_broadcast() {
         const N: usize = 5;
         let dev: TestDevice = Default::default();
-        let a: Tensor3D<N, 4, 3, _> = dev.sample_normal();
+        let a: Tensor<Rank3<N, 4, 3>, f32, _> = dev.sample_normal();
         let a_array = a.array();
-        let b: Tensor2D<3, 2, _> = dev.sample_normal();
+        let b: Tensor<Rank2<3, 2>, f32, _> = dev.sample_normal();
         let r = a.trace().matmul(b.clone());
         let r_array = r.array();
         for i in 0..N {
@@ -429,9 +429,9 @@ mod tests {
     fn test_matmul_batched_3d() {
         let dev: TestDevice = Default::default();

-        let a: Tensor3D<5, 3, 2, _> = dev.sample_normal();
+        let a: Tensor<Rank3<5, 3, 2>, f32, _> = dev.sample_normal();
         let a_array = a.array();
-        let b: Tensor3D<5, 2, 4, _> = dev.sample_normal();
+        let b: Tensor<Rank3<5, 2, 4>, f32, _> = dev.sample_normal();
         let b_array = b.array();
         let c = a.trace().matmul(b.clone());
         let c_array = c.array();
@@ -455,9 +455,9 @@ mod tests {
     fn test_matmul_batched_4d() {
         let dev: TestDevice = Default::default();

-        let a: Tensor4D<7, 5, 3, 2, _> = dev.sample_normal();
+        let a: Tensor<Rank4<7, 5, 3, 2>, f32, _> = dev.sample_normal();
         let a_array = a.array();
-        let b: Tensor4D<7, 5, 2, 4, _> = dev.sample_normal();
+        let b: Tensor<Rank4<7, 5, 2, 4>, f32, _> = dev.sample_normal();
         let b_array = b.array();
         let c = a.trace().matmul(b.clone());
         let c_array = c.array();
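
Worth noting what `test_valid_matmuls` above is exercising: with the generic `Tensor` type, shape compatibility is still enforced entirely at compile time through the `Rank*` parameters. A standalone sketch of the idea, assuming the prelude and a `Cpu` device:

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // (5, 3) x (3, 2) -> (5, 2): the inner dimensions must agree.
    let a: Tensor<Rank2<5, 3>, f32, _> = dev.sample_normal();
    let b: Tensor<Rank2<3, 2>, f32, _> = dev.sample_normal();
    let c: Tensor<Rank2<5, 2>, f32, _> = a.matmul(b);
    println!("{:?}", c.array());

    // A mismatched inner dimension, e.g. Rank2<5, 3> x Rank2<4, 2>,
    // fails to type-check instead of panicking at runtime.
}
```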
src/tensor_ops/normalize.rs (6 changes: 2 additions & 4 deletions)

@@ -51,10 +51,8 @@ impl<S: Shape, D: Device<f32>, T: Tape<D>> Tensor<S, f32, D, T> {

 #[cfg(test)]
 mod tests {
-    use crate::shapes::Axis;
-    use crate::tensor::*;
-    use crate::tensor_ops::*;
     use crate::tests::{assert_close, TestDevice};
+    use crate::{shapes::*, tensor::*, tensor_ops::*};

     #[test]
     fn test_1d_normalize_axis_last() {
@@ -116,7 +114,7 @@ mod tests {
     #[test]
     fn test_3d_normalize_axis_last() {
         let dev: TestDevice = Default::default();
-        let a: Tensor3D<4, 2, 3, _> = dev.ones();
+        let a: Tensor<Rank3<4, 2, 3>, f32, _> = dev.ones();
         let r = a.trace().normalize::<Axis<2>>(1e-5);
         assert_eq!(r.array(), [[[0.0; 3]; 2]; 4]);
         let g = r.exp().mean().backward();
src/tensor_ops/pool2d/mod.rs (19 changes: 9 additions & 10 deletions)

@@ -235,9 +235,8 @@ where
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::tensor::*;
-    use crate::tensor_ops::*;
     use crate::tests::{assert_close, TestDevice};
+    use crate::{shapes::*, tensor::*, tensor_ops::*};

     #[test]
     fn test_pool2d_3d_max2d_eq_grads() {
@@ -262,7 +261,7 @@ mod tests {
     #[test]
     fn test_pool2d_3d_max2d() {
         let dev = TestDevice::seed_from_u64(234);
-        let x: Tensor3D<2, 3, 4, _> = dev.sample_normal();
+        let x: Tensor<Rank3<2, 3, 4>, f32, _> = dev.sample_normal();
         let r = x.trace().max_pool2d::<2, 2, 0>();
         assert_close(
             &r.array(),
@@ -282,7 +281,7 @@ mod tests {
     #[test]
     fn test_pool2d_3d_min2d() {
         let dev = TestDevice::seed_from_u64(234);
-        let x: Tensor3D<2, 3, 4, _> = dev.sample_normal();
+        let x: Tensor<Rank3<2, 3, 4>, f32, _> = dev.sample_normal();
         let r = x.trace().min_pool2d::<2, 2, 0>();
         assert_close(
             &r.array(),
@@ -302,12 +301,12 @@ mod tests {
     #[test]
     fn test_pool2d_3d_avg2d() {
         let dev = TestDevice::seed_from_u64(234);
-        let x: Tensor3D<2, 3, 4, _> = dev.sample_normal();
+        let x: Tensor<Rank3<2, 3, 4>, f32, _> = dev.sample_normal();
         let r = x.trace().avg_pool2d::<2, 2, 0>();
-        // assert_close(
-        //     &r.array(),
-        //     &[[[0.03031558, -0.25052455]], [[0.39499030, 0.04878314]]],
-        // );
+        assert_close(
+            &r.array(),
+            &[[[0.03031558, -0.25052455]], [[0.39499030, 0.04878314]]],
+        );
         let g = r.exp().mean().backward();
         #[rustfmt::skip]
         assert_close(
@@ -322,7 +321,7 @@ mod tests {
     #[test]
     fn test_pool2d_4d_avg2d() {
         let dev = TestDevice::seed_from_u64(234);
-        let x: Tensor4D<2, 4, 2, 2, _> = dev.sample_normal();
+        let x: Tensor<Rank4<2, 4, 2, 2>, f32, _> = dev.sample_normal();
         let r = x.trace().avg_pool2d::<1, 2, 0>();
         assert_close(
             &r.array(),
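
The pooling calls themselves are unchanged here; only the input tensor types move to the generic form. For reference, a minimal sketch of the call shape used in these tests, with the const-generic parameters spelled out (assuming the prelude and a `Cpu` device):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let x: Tensor<Rank3<2, 3, 4>, f32, _> = dev.sample_normal();

    // Const generics are <KERNEL, STRIDE, PADDING>: a 2x2 kernel with
    // stride 2 and no padding maps (2 channels, 3, 4) to (2, 1, 2).
    let r = x.trace().max_pool2d::<2, 2, 0>();

    // Gradients flow back only through each window's max element.
    let g = r.exp().mean().backward();
    let _ = g;
}
```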
src/tensor_ops/reshape_to/mod.rs (50 changes: 25 additions & 25 deletions)

@@ -64,35 +64,35 @@ mod tests {
     fn test_valid_reshapes() {
         let dev: TestDevice = Default::default();

-        let t: Tensor0D<_> = dev.zeros();
-        let _: Tensor1D<1, _> = t.clone().reshape();
-        let _: Tensor2D<1, 1, _> = t.clone().reshape();
-        let _: Tensor3D<1, 1, 1, _> = t.clone().reshape();
-        let _: Tensor4D<1, 1, 1, 1, _> = t.clone().reshape();
+        let t: Tensor<Rank0, f32, _> = dev.zeros();
+        let _: Tensor<Rank1<1>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank2<1, 1>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank3<1, 1, 1>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank4<1, 1, 1, 1>, f32, _> = t.clone().reshape();

-        let t: Tensor1D<16, _> = dev.zeros();
-        let _: Tensor1D<16, _> = t.clone().reshape();
-        let _: Tensor2D<2, 8, _> = t.clone().reshape();
-        let _: Tensor3D<2, 2, 4, _> = t.clone().reshape();
-        let _: Tensor4D<2, 2, 2, 2, _> = t.clone().reshape();
+        let t: Tensor<Rank1<16>, f32, _> = dev.zeros();
+        let _: Tensor<Rank1<16>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank2<2, 8>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank3<2, 2, 4>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank4<2, 2, 2, 2>, f32, _> = t.clone().reshape();

-        let t: Tensor2D<2, 8, _> = dev.zeros();
-        let _: Tensor1D<16, _> = t.clone().reshape();
-        let _: Tensor2D<8, 2, _> = t.clone().reshape();
-        let _: Tensor3D<2, 2, 4, _> = t.clone().reshape();
-        let _: Tensor4D<2, 2, 2, 2, _> = t.clone().reshape();
+        let t: Tensor<Rank2<2, 8>, f32, _> = dev.zeros();
+        let _: Tensor<Rank1<16>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank2<8, 2>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank3<2, 2, 4>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank4<2, 2, 2, 2>, f32, _> = t.clone().reshape();

-        let t: Tensor3D<2, 2, 4, _> = dev.zeros();
-        let _: Tensor1D<16, _> = t.clone().reshape();
-        let _: Tensor2D<2, 8, _> = t.clone().reshape();
-        let _: Tensor3D<4, 2, 2, _> = t.clone().reshape();
-        let _: Tensor4D<2, 2, 2, 2, _> = t.clone().reshape();
+        let t: Tensor<Rank3<2, 2, 4>, f32, _> = dev.zeros();
+        let _: Tensor<Rank1<16>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank2<2, 8>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank3<4, 2, 2>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank4<2, 2, 2, 2>, f32, _> = t.clone().reshape();

-        let t: Tensor4D<2, 2, 2, 2, _> = dev.zeros();
-        let _: Tensor1D<16, _> = t.clone().reshape();
-        let _: Tensor2D<2, 8, _> = t.clone().reshape();
-        let _: Tensor3D<2, 2, 4, _> = t.clone().reshape();
-        let _: Tensor4D<4, 1, 2, 2, _> = t.clone().reshape();
+        let t: Tensor<Rank4<2, 2, 2, 2>, f32, _> = dev.zeros();
+        let _: Tensor<Rank1<16>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank2<2, 8>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank3<2, 2, 4>, f32, _> = t.clone().reshape();
+        let _: Tensor<Rank4<4, 1, 2, 2>, f32, _> = t.clone().reshape();
     }

     #[test]
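
As `test_valid_reshapes` above shows, the target shape of `reshape` is driven entirely by the annotated result type, and the element counts on both sides must agree. A minimal sketch, assuming the prelude and a `Cpu` device:

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let t: Tensor<Rank1<16>, f32, _> = dev.zeros();

    // The annotation selects the output shape; 16 elements on both sides.
    let m: Tensor<Rank2<2, 8>, f32, _> = t.clone().reshape();
    let c: Tensor<Rank4<2, 2, 2, 2>, f32, _> = t.reshape();

    println!("{:?}", m.array());
    println!("{:?}", c.array());
}
```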
