Fixing remainder of cuda tests & fixing cblas/cublas matmul with strides [1,1] (#393)

* Adding failing tests for cuda

* Fixing all tests for cuda

* Fixing bug with cblas sgemm
coreylowman authored Jan 24, 2023
1 parent f5c012d commit 726338b
Showing 8 changed files with 166 additions and 90 deletions.
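Reviewer note on the headline fix: a tensor with a unit dimension, e.g. a contiguous 3x1 column or the 1x3 row produced by permuting it, has strides [1, 1]. The old match arms sent [1, 1] through the generic [ld, 1] arm, reporting ld = 1 even when BLAS needs ld = n. A minimal standalone sketch of the before/after mapping (old_ld and new_ld are hypothetical names for illustration, not dfdx functions):

fn old_ld((m, n): (usize, usize), strides: [usize; 2]) -> (usize, bool) {
    match strides {
        [1, 0] => (m, true),
        [0, 1] => (n, false),
        [ld, 1] => (ld, false), // strides [1, 1] land here with ld = 1
        [1, ld] => (ld, true),
        _ => panic!("at least one stride must be 1"),
    }
}

fn new_ld((m, n): (usize, usize), strides: [usize; 2]) -> (usize, bool) {
    match strides {
        [1, 0] => (m, true),
        [0, 1] => (n, false),
        [1, 1] => (n, false), // dedicated arm added by this PR
        [ld, 1] => (ld, false),
        [1, ld] => (ld, true),
        _ => panic!("at least one stride must be 1"),
    }
}

fn main() {
    // A 1x3 row obtained by permuting a contiguous 3x1 column keeps strides [1, 1].
    assert_eq!(old_ld((1, 3), [1, 1]), (1, false)); // ld = 1 is too small; cblas wants ld >= 3 here
    assert_eq!(new_ld((1, 3), [1, 1]), (3, false)); // ld = n = 3, row-major, not transposed
}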
11 changes: 5 additions & 6 deletions src/lib.rs
@@ -185,14 +185,13 @@ pub(crate) mod tests {
         }
     }
 
-    impl<const M: usize> AssertClose for [f32; M] {
+    impl AssertClose for f32 {
         fn get_far_pair(&self, rhs: &Self, tolerance: f32) -> Option<(f32, f32)> {
-            for (l, r) in self.iter().zip(rhs.iter()) {
-                if (l - r).abs() > tolerance {
-                    return Some((*l, *r));
-                }
+            if (self - rhs).abs() > tolerance {
+                Some((*self, *rhs))
+            } else {
+                None
             }
-            None
         }
     }
 
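Why the base case moves from [f32; M] to f32: with the elementwise check on the scalar, one generic array impl can recurse through any nesting depth ([f32; M], [[f32; N]; M], and so on). A hedged sketch of that composition; the trait shape and the generic array impl below are assumptions for illustration, not this file's verbatim contents:

trait AssertClose {
    fn get_far_pair(&self, rhs: &Self, tolerance: f32) -> Option<(f32, f32)>;
}

impl AssertClose for f32 {
    fn get_far_pair(&self, rhs: &Self, tolerance: f32) -> Option<(f32, f32)> {
        if (self - rhs).abs() > tolerance {
            Some((*self, *rhs))
        } else {
            None
        }
    }
}

// One generic impl now covers flat and nested arrays alike.
impl<T: AssertClose, const M: usize> AssertClose for [T; M] {
    fn get_far_pair(&self, rhs: &Self, tolerance: f32) -> Option<(f32, f32)> {
        self.iter()
            .zip(rhs.iter())
            .find_map(|(l, r)| l.get_far_pair(r, tolerance))
    }
}

fn main() {
    let a = [[0.0f32, 1.0], [2.0, 3.0]];
    let b = [[0.0f32, 1.0], [2.0, 3.5]];
    assert_eq!(a.get_far_pair(&b, 1e-4), Some((3.0, 3.5)));
}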
49 changes: 24 additions & 25 deletions src/losses.rs
@@ -172,8 +172,7 @@ pub fn binary_cross_entropy_with_logits_loss<S: Shape, D: Device<f32>, T: Tape<D
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::tensor::*;
-    use crate::tests::{assert_close, TestDevice};
+    use crate::{tensor::*, tests::*};
 
     #[test]
     fn test_mse() {
@@ -195,7 +194,7 @@ mod tests {
         let x = dev.tensor([0.87248087, -0.24252531, -1.0060949, 1.155084, 1.5545048]);
         let y = dev.tensor([-0.90954804, -1.0193186, -0.39221755, 2.2524886, 1.3035554]);
         let loss = mae_loss(x.trace(), y);
-        assert_eq!(loss.array(), 0.9042107);
+        assert_close(&loss.array(), &0.9042107);
         let g = loss.backward();
         assert_eq!(g.get(&x).array(), [0.2, 0.2, -0.2, -0.2, 0.2]);
     }
@@ -212,7 +211,7 @@ mod tests {
             [0.15627657, 0.29779273, 0.10897867, 0.2879545, 0.14899758],
         ]);
         let loss = cross_entropy_with_logits_loss(x.trace(), y.clone());
-        assert_eq!(loss.array(), 1.9889611);
+        assert_close(&loss.array(), &1.9889611);
         let g = loss.backward();
         assert_close(
             &g.get(&x).array(),
@@ -290,26 +289,26 @@ mod tests {
             [0.7026833, 0.5563793, 0.6429267],
         ]);
         let loss = binary_cross_entropy_with_logits_loss(logit.trace(), prob.clone());
-        assert_eq!(loss.array(), 0.7045728);
+        assert_close(&loss.array(), &0.7045728);
 
         let g = loss.backward();
 
-        assert_eq!(
-            g.get(&logit).array(),
-            [
+        assert_close(
+            &g.get(&logit).array(),
+            &[
                 [0.003761424, -0.054871976, 0.025817735],
                 [-0.0009343492, 0.0051718787, 0.0074731046],
-                [-0.047248676, -0.03401173, 0.0071035423]
-            ]
+                [-0.047248676, -0.03401173, 0.0071035423],
+            ],
         );
 
-        assert_eq!(
-            g.get(&prob).array(),
-            [
+        assert_close(
+            &g.get(&prob).array(),
+            &[
                 [0.04546672, 0.07451131, -0.10224107],
                 [0.18426175, -0.18865204, 0.16475087],
-                [0.10635218, 0.12190584, -0.097797275]
-            ]
+                [0.10635218, 0.12190584, -0.097797275],
+            ],
         );
     }
@@ -324,22 +323,22 @@ mod tests {
 
         let g = loss.backward();
 
-        assert_eq!(
-            g.get(&logit).array(),
-            [
+        assert_close(
+            &g.get(&logit).array(),
+            &[
                 [0.11111111, 0.055555556, 0.0],
                 [0.0, -0.055555556, -0.11111111],
-                [0.029882379, 0.0, -0.02988238]
-            ]
+                [0.029882379, 0.0, -0.02988238],
+            ],
         );
 
-        assert_eq!(
-            g.get(&targ).array(),
-            [
+        assert_close(
+            &g.get(&targ).array(),
+            &[
                 [-11.111112, -11.111112, -11.111112],
                 [11.111112, 11.111112, 11.111112],
-                [0.11111111, 0.0, -0.11111111]
-            ]
+                [0.11111111, 0.0, -0.11111111],
+            ],
         );
     }
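These assertion swaps (assert_eq! to assert_close) are what most of the cuda test fixes amount to, presumably because the CUDA kernels produce floats that differ from the CPU results in the last bits, so exact equality is the wrong contract for cross-device tests. A tiny illustration; the values and the 1e-6 tolerance are chosen for demonstration, not taken from the suite:

fn main() {
    let cpu = 0.9042107f32;
    let gpu = 0.9042108f32; // same quantity, off by roughly one ulp
    assert!(cpu != gpu);                // assert_eq! would fail here
    assert!((cpu - gpu).abs() <= 1e-6); // an assert_close-style check passes
}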
6 changes: 3 additions & 3 deletions src/tensor_ops/log_softmax.rs
@@ -57,9 +57,9 @@ mod tests {
         let dev: TestDevice = Default::default();
         let a = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
         let r = a.trace().log_softmax();
-        assert_eq!(
-            r.array(),
-            [-4.4519143, -3.4519143, -2.4519143, -1.4519143, -0.4519143]
+        assert_close(
+            &r.array(),
+            &[-4.4519143, -3.4519143, -2.4519143, -1.4519143, -0.4519143],
         );
         let g = r.mean().backward();
         assert_close(
2 changes: 1 addition & 1 deletion src/tensor_ops/logsumexp_to.rs
@@ -58,7 +58,7 @@ mod tests {
         let dev: TestDevice = Default::default();
         let a = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
         let r = a.trace().logsumexp();
-        assert_eq!(r.array(), 2.4519143);
+        assert_close(&r.array(), &2.4519143);
         let g = r.backward();
         assert_close(
             &g.get(&a).array(),
34 changes: 6 additions & 28 deletions src/tensor_ops/matmul/cpu_kernel.rs
@@ -30,40 +30,18 @@ pub(crate) fn matmul<M: Dim, K: Dim, N: Dim>(
 
     #[cfg(feature = "cblas")]
    unsafe {
+        let (lda, a_tr) = super::matrix_strides((m, k), a.strides);
+        let (ldb, b_tr) = super::matrix_strides((k, n), b.strides);
+        let (ldc, c_tr) = super::matrix_strides((m, n), c.strides);
         let (m, n, k) = (m as libc::c_int, n as libc::c_int, k as libc::c_int);
 
-        let (lda, a_tr) = match a.strides {
-            [1, 0] => (m as i32, true),
-            [0, 1] => (k as i32, false),
-            [ld, 1] => (ld as i32, false),
-            [1, ld] => (ld as i32, true),
-            _ => panic!("At least one of a's strides must be 1 for cblas"),
-        };
-
-        let (ldb, b_tr) = match b.strides {
-            [1, 0] => (k as i32, true),
-            [0, 1] => (n as i32, false),
-            [ld, 1] => (ld as i32, false),
-            [1, ld] => (ld as i32, true),
-            _ => panic!("At least one of b's strides must be 1 for cblas"),
-        };
-
-        let (ldc, c_trans) = match c.strides {
-            [1, 0] => (m as i32, true),
-            [0, 1] => (n as i32, false),
-            [ld, 1] => (ld as i32, false),
-            [1, ld] => (ld as i32, true),
-            _ => panic!("At least one of c's strides must be 1 for cblas"),
-        };
-
-        let layout = if c_trans { ColMajor } else { RowMajor };
-        let (a_tr, b_tr) = if c_trans {
+        let layout = if c_tr { ColMajor } else { RowMajor };
+        let (a_tr, b_tr) = if c_tr {
             (if a_tr { NoTr } else { Tr }, if b_tr { NoTr } else { Tr })
         } else {
             (if a_tr { Tr } else { NoTr }, if b_tr { Tr } else { NoTr })
         };
         sgemm(
-            layout, a_tr, b_tr, m, n, k, 1.0, ap, lda, bp, ldb, 1.0, cp, ldc,
+            layout, a_tr, b_tr, m, n, k, 1.0, ap, lda as i32, bp, ldb as i32, 1.0, cp, ldc as i32,
         )
     }
 }
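Reviewer note on the transpose flip kept above: the trans flag from matrix_strides means "stored transposed relative to row-major", i.e. column-major. When C itself is column-major, the call switches to ColMajor layout, under which every operand's flag inverts: column-major storage becomes NoTr and row-major storage becomes Tr. This works because a buffer holding an m x n matrix X row-major is byte-for-byte the column-major layout of the n x m matrix X^T. A small self-contained check of that identity (plain Rust, independent of dfdx):

fn main() {
    // Row-major 2x3 buffer for X = [[1, 2, 3], [4, 5, 6]].
    let x = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Reading the same buffer column-major as 3x2: element (i, j) sits at buf[j * 3 + i].
    let col_major_3x2 = |i: usize, j: usize| x[j * 3 + i];
    // That view is exactly X^T: entry (i, j) of X^T equals entry (j, i) of X.
    assert_eq!(col_major_3x2(0, 1), 4.0); // X^T[0][1] == X[1][0]
    assert_eq!(col_major_3x2(2, 0), 3.0); // X^T[2][0] == X[0][2]
}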
26 changes: 3 additions & 23 deletions src/tensor_ops/matmul/cuda_kernel.rs
@@ -22,29 +22,9 @@ fn sgemm_config<M: Dim, K: Dim, N: Dim>(
     beta: f32,
     out_strides: [usize; 2],
 ) -> (GemmConfig<f32>, bool) {
-    let (lhs_stride, lhs_trans) = match lhs_strides {
-        [1, 0] => (m.size(), true),
-        [0, 1] => (k.size(), false),
-        [ld, 1] => (ld, false),
-        [1, ld] => (ld, true),
-        _ => panic!("At least one of a's strides must be 1 for cublas"),
-    };
-
-    let (rhs_stride, rhs_trans) = match rhs_strides {
-        [1, 0] => (k.size(), true),
-        [0, 1] => (n.size(), false),
-        [ld, 1] => (ld, false),
-        [1, ld] => (ld, true),
-        _ => panic!("At least one of b's strides must be 1 for cublas"),
-    };
-
-    let (out_stride, out_trans) = match out_strides {
-        [1, 0] => (m.size(), true),
-        [0, 1] => (n.size(), false),
-        [ld, 1] => (ld, false),
-        [1, ld] => (ld, true),
-        _ => panic!("At least one of c's strides must be 1 for cublas"),
-    };
+    let (lhs_stride, lhs_trans) = super::matrix_strides((m.size(), k.size()), lhs_strides);
+    let (rhs_stride, rhs_trans) = super::matrix_strides((k.size(), n.size()), rhs_strides);
+    let (out_stride, out_trans) = super::matrix_strides((m.size(), n.size()), out_strides);
 
     if !out_trans {
         // out is stored in row major format
120 changes: 120 additions & 0 deletions src/tensor_ops/matmul/mod.rs
@@ -283,6 +283,20 @@ where
     }
 }
 
+/// Utility function returning the ld and whether the matrix is transposed
+/// for cublas & cblas.
+#[allow(unused)]
+pub(super) fn matrix_strides((m, n): (usize, usize), strides: [usize; 2]) -> (usize, bool) {
+    match strides {
+        [1, 0] => (m, true),
+        [0, 1] => (n, false),
+        [1, 1] => (n, false),
+        [ld, 1] => (ld, false),
+        [1, ld] => (ld, true),
+        _ => panic!("At least a single stride must be 1 for cublas"),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -556,4 +570,110 @@ mod tests {
         );
         assert_close(&g.get(&b).array(), &[-0.13630435, -1.6781758, -0.75171506]);
     }
+
+    #[test]
+    fn test_small_matmul_vv() {
+        let dev: TestDevice = Default::default();
+        let a = dev.tensor([0.5f32]);
+        let b = dev.tensor([2.0f32]);
+        let c = a.trace().matmul(b.clone());
+        assert_eq!(c.array(), [[1.0]]);
+        let g = c.exp().sum().backward();
+        assert_eq!(g.get(&a).array(), [5.4365635]);
+        assert_eq!(g.get(&b).array(), [1.3591409]);
+    }
+
+    #[test]
+    fn test_small_matmul_vm() {
+        let dev: TestDevice = Default::default();
+
+        // 1 * 1x1
+        let a = dev.tensor([0.5f32]);
+        let b = dev.tensor([[2.0f32]]);
+        let c = a.trace().matmul(b.clone());
+        assert_eq!(c.array(), [1.0]);
+        let g = c.exp().sum().backward();
+        assert_eq!(g.get(&a).array(), [5.4365635]);
+        assert_eq!(g.get(&b).array(), [[1.3591409]]);
+
+        // 1 * 1x1 (permuted)
+        let c = a.trace().matmul(b.trace().permute());
+        assert_eq!(c.array(), [1.0]);
+        let g = c.exp().sum().backward();
+        assert_eq!(g.get(&a).array(), [5.4365635]);
+        assert_eq!(g.get(&b).array(), [[1.3591409]]);
+
+        // 1 * 1x2
+        let a = dev.tensor([0.5f32]);
+        let b = dev.tensor([[2.0f32, 4.0]]);
+        let c = a.trace().matmul(b.clone());
+        assert_eq!(c.array(), [1.0, 2.0]);
+        let g = c.exp().sum().backward();
+        assert_eq!(g.get(&a).array(), [34.99279]);
+        assert_eq!(g.get(&b).array(), [[1.3591409, 3.694528]]);
+
+        // 1 * 1x2 (permuted)
+        let a = dev.tensor([0.5f32]);
+        let b = dev.tensor([[2.0f32], [4.0]]);
+        let c = a.trace().matmul(b.trace().permute());
+        assert_eq!(c.array(), [1.0, 2.0]);
+        let g = c.exp().sum().backward();
+        assert_eq!(g.get(&a).array(), [34.99279]);
+        assert_eq!(g.get(&b).array(), [[1.3591409], [3.694528]]);
+    }
+
+    #[test]
+    fn test_small_matmul_mm() {
+        let dev: TestDevice = Default::default();
+
+        {
+            // 1x1 * 1x1
+            let a = dev.tensor([[0.5f32]]);
+            let b = dev.tensor([[2.0f32]]);
+            let c = a.trace().matmul(b.clone());
+            assert_eq!(c.array(), [[1.0]]);
+            let g = c.exp().sum().backward();
+            assert_eq!(g.get(&a).array(), [[5.4365635]]);
+            assert_eq!(g.get(&b).array(), [[1.3591409]]);
+        }
+
+        {
+            // 1x2 * 2x1
+            let a = dev.tensor([[0.5f32, 0.1]]);
+            let b = dev.tensor([[2.0f32], [4.0]]);
+            let c = a.trace().matmul(b.clone());
+            assert_eq!(c.array(), [[1.4]]);
+            let g = c.exp().sum().backward();
+            g.get(&a).array().assert_close(&[[8.1104, 16.2208]], 1e-5);
+            g.get(&b)
+                .array()
+                .assert_close(&[[2.0276], [0.40552002]], 1e-5);
+        }
+
+        {
+            // 1x2 (permuted) * 2x1
+            let a = dev.tensor([[0.5f32], [0.1]]);
+            let b = dev.tensor([[2.0f32], [4.0]]);
+            let c = a.trace().permute().matmul(b.clone());
+            assert_eq!(c.array(), [[1.4]]);
+            let g = c.exp().sum().backward();
+            g.get(&a).array().assert_close(&[[8.1104], [16.2208]], 1e-5);
+            g.get(&b)
+                .array()
+                .assert_close(&[[2.0276], [0.40552002]], 1e-5);
+        }
+
+        {
+            // 1x2 * 2x1 (permuted)
+            let a = dev.tensor([[0.5f32, 0.1]]);
+            let b = dev.tensor([[2.0f32, 4.0]]);
+            let c = a.trace().matmul(b.trace().permute());
+            assert_eq!(c.array(), [[1.4]]);
+            let g = c.exp().sum().backward();
+            g.get(&a).array().assert_close(&[[8.1104, 16.2208]], 1e-5);
+            g.get(&b)
+                .array()
+                .assert_close(&[[2.0276, 0.40552002]], 1e-5);
+        }
+    }
 }
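The permuted cases in the new tests are exactly the ones that produce the [1, 1] stride pattern this PR fixes. An illustration under the assumption of row-major contiguous allocation; contiguous_strides and permuted_strides are hypothetical helpers, not dfdx APIs:

// Row-major strides for a contiguous (rows, cols) matrix.
fn contiguous_strides((_rows, cols): (usize, usize)) -> [usize; 2] {
    [cols, 1]
}

// A permute swaps the dims and their strides without moving data.
fn permuted_strides(shape: (usize, usize)) -> [usize; 2] {
    let [r, c] = contiguous_strides(shape);
    [c, r]
}

fn main() {
    // b = dev.tensor([[2.0f32], [4.0]]) is 2x1 with strides [1, 1], ...
    assert_eq!(contiguous_strides((2, 1)), [1, 1]);
    // ... and b.permute() views it as 1x2, still with strides [1, 1]: the fixed case.
    assert_eq!(permuted_strides((2, 1)), [1, 1]);
    // Contrast with a 2x3, where permuting yields strides [1, 3].
    assert_eq!(permuted_strides((2, 3)), [1, 3]);
}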
8 changes: 4 additions & 4 deletions src/tensor_ops/normalize.rs
@@ -92,13 +92,13 @@ mod tests {
         let dev: TestDevice = Default::default();
         let a = dev.tensor([[-2.0, 0.0], [1.0, 2.0], [4.0, 5.0]]);
         let r = a.trace().normalize::<Axis<0>>(1e-5);
-        assert_eq!(
-            r.array(),
-            [
+        assert_close(
+            &r.array(),
+            &[
                 [-1.2247438, -1.1355485],
                 [0.0, -0.16222118],
                 [1.2247438, 1.2977698],
-            ]
+            ],
         );
         let g = r.exp().mean().backward();
         assert_close(
