diff --git a/src/devices/conv.rs b/src/devices/conv.rs
new file mode 100644
index 000000000..f05d91bc5
--- /dev/null
+++ b/src/devices/conv.rs
@@ -0,0 +1,179 @@
+use super::Cpu;
+
+/// **Requires nightly** 2d convolution with stride and padding specified at trait level.
+///
+/// This allows the rest of the parameters to be inferred by inputs.
+pub trait DeviceConv2D<const S: usize, const P: usize> {
+    /// Forward operation that modifies the `out` image.
+    fn conv_forward<
+        const C: usize,
+        const O: usize,
+        const K: usize,
+        const H: usize,
+        const W: usize,
+    >(
+        img: &[[[f32; W]; H]; C],
+        weight: &[[[[f32; K]; K]; C]; O],
+        bias: &[f32; O],
+        out: &mut [[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
+    );
+
+    /// Backward operation that modifies the gradients of img, weight, and bias.
+    fn conv_backward<
+        const C: usize,
+        const O: usize,
+        const K: usize,
+        const H: usize,
+        const W: usize,
+    >(
+        img: &[[[f32; W]; H]; C],
+        weight: &[[[[f32; K]; K]; C]; O],
+        out_g: &[[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
+        img_g: &mut [[[f32; W]; H]; C],
+        weight_g: &mut [[[[f32; K]; K]; C]; O],
+        bias_g: &mut [f32; O],
+    );
+}
+
+impl<const S: usize, const P: usize> DeviceConv2D<S, P> for Cpu {
+    fn conv_forward<
+        const C: usize,
+        const O: usize,
+        const K: usize,
+        const H: usize,
+        const W: usize,
+    >(
+        img: &[[[f32; W]; H]; C],
+        weight: &[[[[f32; K]; K]; C]; O],
+        bias: &[f32; O],
+        out: &mut [[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
+    ) {
+        let out_height = (H + 2 * P - K) / S + 1;
+        let out_width = (W + 2 * P - K) / S + 1;
+        for c in 0..C {
+            for oc in 0..O {
+                for oh in 0..out_height {
+                    for ow in 0..out_width {
+                        let o = &mut out[oc][oh][ow];
+                        for k1 in 0..K {
+                            let y = (oh * S + k1).checked_sub(P);
+                            for k2 in 0..K {
+                                let x = (ow * S + k2).checked_sub(P);
+                                if let Some((y, x)) = y.zip(x) {
+                                    if y < H && x < W {
+                                        *o += weight[oc][c][k1][k2] * img[c][y][x];
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        for oc in 0..O {
+            for oh in 0..out_height {
+                for ow in 0..out_width {
+                    out[oc][oh][ow] += bias[oc];
+                }
+            }
+        }
+    }
+
+    fn conv_backward<
+        const C: usize,
+        const O: usize,
+        const K: usize,
+        const H: usize,
+        const W: usize,
+    >(
+        img: &[[[f32; W]; H]; C],
+        weight: &[[[[f32; K]; K]; C]; O],
+        out_g: &[[[f32; (W + 2 * P - K) / S + 1]; (H + 2 * P - K) / S + 1]; O],
+        img_g: &mut [[[f32; W]; H]; C],
+        weight_g: &mut [[[[f32; K]; K]; C]; O],
+        bias_g: &mut [f32; O],
+    ) {
+        let out_height = (H + 2 * P - K) / S + 1;
+        let out_width = (W + 2 * P - K) / S + 1;
+
+        for oc in 0..O {
+            for oh in 0..out_height {
+                for ow in 0..out_width {
+                    bias_g[oc] += out_g[oc][oh][ow];
+                }
+            }
+        }
+
+        for c in 0..C {
+            for oh in 0..out_height {
+                for ow in 0..out_width {
+                    for oc in 0..O {
+                        let o_g = &out_g[oc][oh][ow];
+                        for k1 in 0..K {
+                            let y = (oh * S + k1).wrapping_sub(P);
+                            if y < H {
+                                for k2 in 0..K {
+                                    let x = (ow * S + k2).wrapping_sub(P);
+                                    if x < W {
+                                        weight_g[oc][c][k1][k2] += img[c][y][x] * o_g;
+                                        img_g[c][y][x] += weight[oc][c][k1][k2] * o_g;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::devices::{AllocateZeros, FillElements};
+    use crate::tests::assert_close;
+    use rand::prelude::*;
+    use rand_distr::StandardNormal;
+
+    #[test]
+    fn test_conv2d_s4p3k2() {
+        let mut rng = StdRng::seed_from_u64(432);
+        let mut randn = |x: &mut f32| *x = rng.sample(StandardNormal);
+
+        let weight: Box<[[[[f32; 2]; 2]; 5]; 3]> = Cpu::filled(&mut randn);
+        let bias: Box<[f32; 3]> = Cpu::filled(&mut randn);
+        let x: Box<[[[f32; 6]; 7]; 5]> = Cpu::filled(&mut randn);
+
+        let mut out = [[[0.0; 3]; 3]; 3];
+        <Cpu as DeviceConv2D<4, 3>>::conv_forward(
+            x.as_ref(),
+            weight.as_ref(),
+            bias.as_ref(),
+            &mut out,
+        );
+
+        #[rustfmt::skip]
+        assert_close(&out, &[
+            [[-0.57176435, -0.57176435, -0.57176435],[-0.57176435, 1.0759051, 1.4307989],[-0.57176435, -0.86296344, -1.8794353]],
+            [[0.29306656, 0.29306656, 0.29306656],[0.29306656, 0.9771965, 1.467767],[0.29306656, -6.367015, -2.3370528]],
+            [[-0.19717735, -0.19717735, -0.19717735],[-0.19717735, 1.3412137, 2.9476144],[-0.19717735, 4.247249, -2.1779637]],
+        ]);
+
+        let mut wg: Box<[[[[f32; 2]; 2]; 5]; 3]> = Cpu::zeros();
+        let mut bg: Box<[f32; 3]> = Cpu::zeros();
+        let mut xg: Box<[[[f32; 6]; 7]; 5]> = Cpu::zeros();
+        <Cpu as DeviceConv2D<4, 3>>::conv_backward(
+            &x,
+            &weight,
+            &out,
+            xg.as_mut(),
+            wg.as_mut(),
+            bg.as_mut(),
+        );
+
+        assert_ne!(wg.as_ref(), &[[[[0.0; 2]; 2]; 5]; 3]);
+        assert_ne!(bg.as_ref(), &[0.0; 3]);
+        assert_ne!(xg.as_ref(), &[[[0.0; 6]; 7]; 5]);
+    }
+}
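The const expressions in the signatures above are the usual convolution shape formula: out = (in + 2 * P - K) / S + 1. For `test_conv2d_s4p3k2` the trait is instantiated as `DeviceConv2D<4, 3>`, so with `H = 7`, `W = 6`, `K = 2` the output is `(7 + 2*3 - 2) / 4 + 1 = 3` rows by `(6 + 2*3 - 2) / 4 + 1 = 3` columns, which is why the test allocates `[[[0.0; 3]; 3]; 3]`. Note also that `checked_sub(P)` in the forward loops and `wrapping_sub(P)` plus the `y < H` / `x < W` tests in the backward loops are two spellings of the same padding guard: an index that lands in the padded border either becomes `None` or wraps to a huge `usize` and fails the bounds check. A minimal sketch of calling the device op directly, with hypothetical sizes (1 input channel, 2 output channels, 3x3 kernel, stride 1, padding 0 on a 5x5 image):

    // (5 + 2*0 - 3) / 1 + 1 = 3, so each of the 2 output channels is 3x3.
    let img = [[[0.0f32; 5]; 5]; 1];
    let weight = [[[[0.0f32; 3]; 3]; 1]; 2];
    let bias = [0.0f32; 2];
    let mut out = [[[0.0f32; 3]; 3]; 2];
    <Cpu as DeviceConv2D<1, 0>>::conv_forward(&img, &weight, &bias, &mut out);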
diff --git a/src/devices/mod.rs b/src/devices/mod.rs
index f4b12be73..01c7b3dba 100644
--- a/src/devices/mod.rs
+++ b/src/devices/mod.rs
@@ -16,6 +16,11 @@ pub use matmul::*;
 pub use permute::*;
 pub use select::*;
 
+#[cfg(feature = "nightly")]
+mod conv;
+#[cfg(feature = "nightly")]
+pub use conv::*;
+
 use std::ops::*;
 
 /// The CPU device
diff --git a/src/tensor_ops/conv.rs b/src/tensor_ops/conv.rs
index 32647d2f4..93c4541d9 100644
--- a/src/tensor_ops/conv.rs
+++ b/src/tensor_ops/conv.rs
@@ -1,3 +1,4 @@
+use crate::devices::{Cpu, DeviceConv2D};
 use crate::gradients::Tape;
 use crate::prelude::*;
 
@@ -24,38 +25,21 @@ pub fn conv2d<
     TAPE,
 > {
     let mut result = Tensor3D::zeros();
-    conv_forward::<
-        IN_CHAN,
-        OUT_CHAN,
-        KERNEL,
-        STRIDE,
-        PADDING,
-        IN_HEIGHT,
-        IN_WIDTH,
-        { (IN_HEIGHT + 2 * PADDING - KERNEL) / STRIDE + 1 },
-        { (IN_WIDTH + 2 * PADDING - KERNEL) / STRIDE + 1 },
-    >(x.data(), filters.data(), bias.data(), result.mut_data());
+    <Cpu as DeviceConv2D<STRIDE, PADDING>>::conv_forward(
+        x.data(),
+        filters.data(),
+        bias.data(),
+        result.mut_data(),
+    );
 
     let f = filters.clone();
     let (x, mut tape) = x.split_tape();
-    let phantom_filters = filters.phantom();
-    let phantom_bias = bias.phantom();
-    let phantom_result = result.phantom();
+    let phf = filters.phantom();
+    let phb = bias.phantom();
+    let phr = result.phantom();
     tape.add_backward_op(move |grads| {
-        let (f_grad, b_grad, i_grad, r_grad) =
-            grads.muts_and_ref(&phantom_filters, &phantom_bias, &x, &phantom_result);
-        conv_backward::<
-            IN_CHAN,
-            OUT_CHAN,
-            KERNEL,
-            STRIDE,
-            PADDING,
-            IN_HEIGHT,
-            IN_WIDTH,
-            { (IN_HEIGHT + 2 * PADDING - KERNEL) / STRIDE + 1 },
-            { (IN_WIDTH + 2 * PADDING - KERNEL) / STRIDE + 1 },
-        >(x.data(), f.data(), r_grad, i_grad, f_grad, b_grad);
+        let (fg, bg, ig, rg) = grads.muts_and_ref(&phf, &phb, &x, &phr);
+        <Cpu as DeviceConv2D<STRIDE, PADDING>>::conv_backward(x.data(), f.data(), rg, ig, fg, bg);
     });
     result.put_tape(tape)
 }
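With stride and padding carried by the device trait, `conv2d` only forwards its `STRIDE` and `PADDING` const parameters to `<Cpu as DeviceConv2D<STRIDE, PADDING>>`; the channel, kernel, and image dimensions are inferred from the tensor types. A usage sketch mirroring the call style of the tests below, with placeholder zero tensors (the exact turbofish order and the `zeros` constructors are assumptions here):

    // 1 input channel, 2 output channels, 2x2 kernel, 2x3 image.
    let weight: Tensor4D<2, 1, 2, 2> = Tensor4D::zeros();
    let bias: Tensor1D<2> = Tensor1D::zeros();
    let x: Tensor3D<1, 2, 3> = Tensor3D::zeros();
    // Stride 1, padding 0: output height (2 - 2) / 1 + 1 = 1, width (3 - 2) / 1 + 1 = 2.
    let y = conv2d::<1, 0>(x.trace(), &weight, &bias);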
@@ -85,153 +69,26 @@ pub fn conv2d_batched<
     TAPE,
 > {
     let mut result = Tensor4D::zeros();
-    for i in 0..BATCH_SIZE {
-        conv_forward::<
-            IN_CHAN,
-            OUT_CHAN,
-            KERNEL,
-            STRIDE,
-            PADDING,
-            IN_HEIGHT,
-            IN_WIDTH,
-            { (IN_HEIGHT + 2 * PADDING - KERNEL) / STRIDE + 1 },
-            { (IN_WIDTH + 2 * PADDING - KERNEL) / STRIDE + 1 },
-        >(
-            &x.data()[i],
-            filters.data(),
-            bias.data(),
-            &mut result.mut_data()[i],
-        );
+    for (x_i, r_i) in x.data().iter().zip(result.mut_data().iter_mut()) {
+        <Cpu as DeviceConv2D<STRIDE, PADDING>>::conv_forward(x_i, filters.data(), bias.data(), r_i);
     }
 
     let f = filters.clone();
     let (x, mut tape) = x.split_tape();
-    let phantom_filters = filters.phantom();
-    let phantom_bias = bias.phantom();
-    let phantom_result = result.phantom();
+    let phf = filters.phantom();
+    let phb = bias.phantom();
+    let phr = result.phantom();
     tape.add_backward_op(move |grads| {
-        let (f_grad, b_grad, i_grad, r_grad) =
-            grads.muts_and_ref(&phantom_filters, &phantom_bias, &x, &phantom_result);
-
-        for i in 0..BATCH_SIZE {
-            conv_backward::<
-                IN_CHAN,
-                OUT_CHAN,
-                KERNEL,
-                STRIDE,
-                PADDING,
-                IN_HEIGHT,
-                IN_WIDTH,
-                { (IN_HEIGHT + 2 * PADDING - KERNEL) / STRIDE + 1 },
-                { (IN_WIDTH + 2 * PADDING - KERNEL) / STRIDE + 1 },
-            >(
-                &x.data()[i],
-                f.data(),
-                &r_grad[i],
-                &mut i_grad[i],
-                f_grad,
-                b_grad,
-            );
+        let (fg, bg, ig, r_grad) = grads.muts_and_ref(&phf, &phb, &x, &phr);
+        let f = f.data();
+        for ((x_i, rg_i), ig_i) in x.data().iter().zip(r_grad.iter()).zip(ig.iter_mut()) {
+            <Cpu as DeviceConv2D<STRIDE, PADDING>>::conv_backward(x_i, f, rg_i, ig_i, fg, bg);
         }
     });
     result.put_tape(tape)
 }
 
-fn conv_forward<
-    const C: usize,
-    const OC: usize,
-    const K: usize,
-    const S: usize,
-    const P: usize,
-    const H: usize,
-    const W: usize,
-    const OH: usize,
-    const OW: usize,
->(
-    img: &[[[f32; W]; H]; C],
-    weight: &[[[[f32; K]; K]; C]; OC],
-    bias: &[f32; OC],
-    out: &mut [[[f32; OW]; OH]; OC],
-) {
-    for c in 0..C {
-        for oc in 0..OC {
-            for oh in 0..OH {
-                for ow in 0..OW {
-                    let o = &mut out[oc][oh][ow];
-                    for k1 in 0..K {
-                        let y = (oh * S + k1).checked_sub(P);
-                        for k2 in 0..K {
-                            let x = (ow * S + k2).checked_sub(P);
-                            if let Some((y, x)) = y.zip(x) {
-                                if y < H && x < W {
-                                    *o += weight[oc][c][k1][k2] * img[c][y][x];
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-    for oc in 0..OC {
-        for oh in 0..OH {
-            for ow in 0..OW {
-                out[oc][oh][ow] += bias[oc];
-            }
-        }
-    }
-}
-
-fn conv_backward<
-    const C: usize,
-    const OC: usize,
-    const K: usize,
-    const S: usize,
-    const P: usize,
-    const H: usize,
-    const W: usize,
-    const OH: usize,
-    const OW: usize,
->(
-    img: &[[[f32; W]; H]; C],
-    weight: &[[[[f32; K]; K]; C]; OC],
-    out_g: &[[[f32; OW]; OH]; OC],
-    img_g: &mut [[[f32; W]; H]; C],
-    weight_g: &mut [[[[f32; K]; K]; C]; OC],
-    bias_g: &mut [f32; OC],
-) {
-    for oc in 0..OC {
-        for oh in 0..OH {
-            for ow in 0..OW {
-                bias_g[oc] += out_g[oc][oh][ow];
-            }
-        }
-    }
-
-    for c in 0..C {
-        for oh in 0..OH {
-            for ow in 0..OW {
-                for oc in 0..OC {
-                    let o_g = &out_g[oc][oh][ow];
-                    for k1 in 0..K {
-                        let y = (oh * S + k1).wrapping_sub(P);
-                        if y < H {
-                            for k2 in 0..K {
-                                let x = (ow * S + k2).wrapping_sub(P);
-                                if x < W {
-                                    weight_g[oc][c][k1][k2] += img[c][y][x] * o_g;
-                                    img_g[c][y][x] += weight[oc][c][k1][k2] * o_g;
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
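One detail worth noting in the batched backward op above: the same `fg` and `bg` buffers are handed to `conv_backward` for every batch element, while `ig` is zipped element-wise. Filter and bias gradients therefore accumulate over the whole batch, and each image's gradient is written independently, which is the expected math for parameters shared across a batch: running the `bias_g[oc] += out_g[oc][oh][ow]` loop once per element produces bias_g[oc] = sum over b, oh, ow of out_g[b][oc][oh][ow].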
@@ -289,18 +146,24 @@ mod tests {
             [[[0.44704646, -0.29563826], [0.29228759, -0.16575140]]],
             [[[-0.30488998, 0.25222939], [0.13279295, 0.38153177]]],
         ]);
+
         let bias = tensor([-0.44699109, 0.38371694]);
+
         let x = tensor([[
             [0.37100124, -0.59504986, -1.19781005],
             [-0.31547278, 0.58071911, 0.86612970],
         ]]);
+
        let result = conv2d::<1, 0>(x.trace(), &weight, &bias);
         assert_close(result.data(), &[[[-0.29368058]], [[0.30018353]]]);
+
         let g = backward(result.exp().mean());
+
         assert_close(
             g.ref_gradient(&x),
             &[[[-0.03917716, 0.06006697, 0.], [0.19859464, 0.19576924, 0.]]],
         );
+
         assert_close(
             g.ref_gradient(&weight),
             &[
@@ -308,66 +171,49 @@ mod tests {
                 [[[0.25044560, -0.40169036], [-0.21296094, 0.39201635]]],
             ],
         );
+
         assert_close(g.ref_gradient(&bias), &[0.37275729, 0.67505330]);
     }
 
     #[test]
     fn test_conv2d_padding_1() {
+        #[rustfmt::skip]
         let weight = tensor([
-            [
-                [[0.10215953, 0.06263646], [-0.04124039, -0.09729567]],
-                [[-0.32656857, 0.24254093], [-0.27209827, 0.15361503]],
-            ],
-            [
-                [[0.03449896, 0.22931078], [-0.17652659, 0.08222872]],
-                [[-0.06016779, 0.29082409], [-0.19154115, 0.13483226]],
-            ],
-            [
-                [[-0.14262493, 0.19654515], [0.15921101, 0.01759464]],
-                [[0.16749159, 0.33096817], [0.28376505, -0.05524009]],
-            ],
+            [[[0.10215953, 0.06263646], [-0.04124039, -0.09729567]], [[-0.32656857, 0.24254093], [-0.27209827, 0.15361503]]],
+            [[[0.03449896, 0.22931078], [-0.17652659, 0.08222872]],[[-0.06016779, 0.29082409], [-0.19154115, 0.13483226]]],
+            [[[-0.14262493, 0.19654515], [0.15921101, 0.01759464]],[[0.16749159, 0.33096817], [0.28376505, -0.05524009]]],
         ]);
+
         let bias = tensor([-0.22854491, 0.28763595, 0.20709404]);
+
         let x = tensor([[[-0.32224107, -0.32800716]], [[-1.13570976, 0.93713200]]]);
+
        let result = conv2d::<1, 1>(x.trace(), &weight, &bias);
+
+        #[rustfmt::skip]
         assert_close(
             result.data(),
             &[
-                [
-                    [-0.37165433, 0.26964033, -0.47000977],
-                    [-0.52418506, 0.3161699, -0.56809187],
-                ],
-                [
-                    [0.10800815, 0.66143924, 0.16603859],
-                    [-0.11654915, 0.5421771, 0.21993488],
-                ],
-                [
-                    [0.26416105, -0.22402346, 0.420797],
-                    [-0.23212466, 0.3085245, 0.41083777],
-                ],
+                [[-0.37165433, 0.26964033, -0.47000977],[-0.52418506, 0.3161699, -0.56809187]],
+                [[0.10800815, 0.66143924, 0.16603859],[-0.11654915, 0.5421771, 0.21993488]],
+                [[0.26416105, -0.22402346, 0.420797],[-0.23212466, 0.3085245, 0.41083777]],
             ],
         );
+
         let gradients = backward(result.exp().mean());
+
         assert_close(
             gradients.ref_gradient(&x),
             &[[[0.010052743, 0.038219165]], [[0.0013861917, 0.096129306]]],
         );
+
+        #[rustfmt::skip]
         assert_close(
             gradients.ref_gradient(&weight),
             &[
-                [
-                    [[-0.03488452, -0.035597768], [-0.03483199, -0.036207683]],
-                    [[-0.05705857, 0.03406856], [-0.05008337, 0.024666183]],
-                ],
-                [
-                    [[-0.053492695, -0.04727108], [-0.05620105, -0.055251926]],
-                    [[-0.04363727, 0.033381317], [-0.0607851, 0.030584559]],
-                ],
-                [
-                    [[-0.051853612, -0.03900232], [-0.04206547, -0.037880093]],
-                    [[-0.0073834136, 0.0208545], [0.02886929, -0.040557314]],
-                ],
+                [[[-0.03488452, -0.035597768], [-0.03483199, -0.036207683]],[[-0.05705857, 0.03406856], [-0.05008337, 0.024666183]]],
+                [[[-0.053492695, -0.04727108], [-0.05620105, -0.055251926]],[[-0.04363727, 0.033381317], [-0.0607851, 0.030584559]]],
+                [[[-0.051853612, -0.03900232], [-0.04206547, -0.037880093]],[[-0.0073834136, 0.0208545], [0.02886929, -0.040557314]]],
             ],
         );
@@ -379,67 +225,42 @@
     #[test]
     fn test_conv2d_stride_3_padding_4() {
+        #[rustfmt::skip]
         let weight = tensor([
-            [[
-                [-0.10252278, -0.14387409, -0.14627469],
-                [0.28396228, -0.14590892, 0.29269591],
-                [0.01090384, 0.14785287, 0.29242596],
-            ]],
-            [[
-                [-0.31163597, 0.13224581, -0.20954299],
-                [0.27902845, -0.14735751, 0.14001134],
-                [-0.05224654, 0.16499066, -0.13981307],
-            ]],
+            [[[-0.10252278, -0.14387409, -0.14627469],[0.28396228, -0.14590892, 0.29269591],[0.01090384, 0.14785287, 0.29242596]]],
+            [[[-0.31163597, 0.13224581, -0.20954299],[0.27902845, -0.14735751, 0.14001134],[-0.05224654, 0.16499066, -0.13981307]]],
         ]);
+
         let bias = tensor([-0.07123789, -0.17244765]);
-        let x = tensor([[
-            [0.69103152, 0.25624934],
-            [-0.38448590, 0.03110456],
-            [0.83753252, 0.53786588],
-            [1.15540242, -0.54148245],
-        ]]);
+
+        #[rustfmt::skip]
+        let x = tensor([[[0.69103152, 0.25624934],[-0.38448590, 0.03110456],[0.83753252, 0.53786588],[1.15540242, -0.54148245]]]);
+
        let result = conv2d::<3, 4>(x.trace(), &weight, &bias);
+
+        #[rustfmt::skip]
         assert_close(
             result.data(),
             &[
-                [
-                    [-0.07123789, -0.07123789, -0.07123789],
-                    [-0.07123789, -0.14481398, -0.07123789],
-                    [-0.07123789, -0.59748650, -0.07123789],
-                    [-0.07123789, -0.07123789, -0.07123789],
-                ],
-                [
-                    [-0.17244765, -0.17244765, -0.17244765],
-                    [-0.17244765, -0.3061839, -0.17244765],
-                    [-0.17244765, -0.42046443, -0.17244765],
-                    [-0.17244765, -0.17244765, -0.17244765],
-                ],
+                [[-0.07123789, -0.07123789, -0.07123789],[-0.07123789, -0.14481398, -0.07123789],[-0.07123789, -0.59748650, -0.07123789],[-0.07123789, -0.07123789, -0.07123789]],
+                [[-0.17244765, -0.17244765, -0.17244765],[-0.17244765, -0.3061839, -0.17244765],[-0.17244765, -0.42046443, -0.17244765],[-0.17244765, -0.17244765, -0.17244765]],
             ],
         );
+
         let gradients = backward(result.exp().mean());
+
+        #[rustfmt::skip]
         assert_close(
             gradients.ref_gradient(&x),
-            &[[
-                [-0.009780421, 0.01484663],
-                [0.010391434, 0.0062526874],
-                [0.00032053515, -0.009087289],
-                [-0.0073772445, 0.0105412705],
-            ]],
+            &[[[-0.009780421, 0.01484663],[0.010391434, 0.0062526874],[0.00032053515, -0.009087289],[-0.0073772445, 0.0105412705]]],
         );
+
+        #[rustfmt::skip]
         assert_close(
             gradients.ref_gradient(&weight),
             &[
-                [[
-                    [0.0, 0.019200183, 0.012330416],
-                    [0.0, 0.051398464, -0.003175714],
-                    [0.0, -0.013860448, 0.0011212977],
-                ]],
-                [[
-                    [0.0, 0.02291844, 0.01471829],
-                    [0.0, 0.05281557, -0.0069562597],
-                    [0.0, -0.011794927, 0.00095419877],
-                ]],
+                [[[0.0, 0.019200183, 0.012330416],[0.0, 0.051398464, -0.003175714],[0.0, -0.013860448, 0.0011212977]]],
+                [[[0.0, 0.02291844, 0.01471829],[0.0, 0.05281557, -0.0069562597],[0.0, -0.011794927, 0.00095419877]]],
             ],
         );
@@ -453,68 +274,42 @@
             [[[0.68307382]], [[-0.56570816]]],
             [[[0.31137520]], [[0.41600472]]],
         ]);
+
         let bias = tensor([0.49647599, 0.15591705, -0.12342280]);
+
+        #[rustfmt::skip]
         let x = tensor([
-            [
-                [[-0.5396145, -2.43986344], [-0.01883135, -1.19915044]],
-                [[-1.30589044, 2.05276346], [-0.20004864, 0.19919693]],
-            ],
-            [
-                [[-0.22305037, 0.63030297], [0.65323567, -0.68972057]],
-                [[-0.50617385, -0.87281805], [0.30253950, -1.75082350]],
-            ],
-            [
-                [[1.65487242, 0.44441956], [-0.45107457, 1.41857898]],
-                [[1.00477660, -0.16381662], [0.40009478, -0.57880658]],
-            ],
+            [[[-0.5396145, -2.43986344], [-0.01883135, -1.19915044]],[[-1.30589044, 2.05276346], [-0.20004864, 0.19919693]]],
+            [[[-0.22305037, 0.63030297], [0.65323567, -0.68972057]],[[-0.50617385, -0.87281805], [0.30253950, -1.75082350]]],
+            [[[1.65487242, 0.44441956], [-0.45107457, 1.41857898]],[[1.00477660, -0.16381662], [0.40009478, -0.57880658]]],
         ]);
        let result = conv2d_batched::<1, 0>(x.trace(), &weight, &bias);
+
+        #[rustfmt::skip]
         assert_close(
             result.data(),
             &[
-                [
-                    [[0.56543916, 0.19084194], [0.51086920, 0.40909100]],
-                    [[0.52607340, -2.67195487], [0.25622299, -0.77587855]],
-                    [[-0.83470196, -0.02917651], [-0.21250761, -0.41394162]],
-                ],
-                [
-                    [[0.52237344, 0.60200971], [0.51218325, 0.59096003]],
-                    [[0.28990385, 1.08022082], [0.43097615, 0.67524213]],
-                    [[-0.40344587, -0.29025853], [0.20583645, -1.06653547]],
-                ],
-                [
-                    [[0.51777399, 0.53584486], [0.43837392, 0.62647879]],
-                    [[0.71790677, 0.55216080], [-0.37853706, 1.45234680]],
-                    [[0.80985522, -0.05319006], [-0.09743492, 0.07750125]],
-                ],
+                [[[0.56543916, 0.19084194], [0.51086920, 0.40909100]],[[0.52607340, -2.67195487], [0.25622299, -0.77587855]],[[-0.83470196, -0.02917651], [-0.21250761, -0.41394162]]],
+                [[[0.52237344, 0.60200971], [0.51218325, 0.59096003]],[[0.28990385, 1.08022082], [0.43097615, 0.67524213]],[[-0.40344587, -0.29025853], [0.20583645, -1.06653547]]],
+                [[[0.51777399, 0.53584486], [0.43837392, 0.62647879]],[[0.71790677, 0.55216080], [-0.37853706, 1.45234680]],[[0.80985522, -0.05319006], [-0.09743492, 0.07750125]]],
             ],
         );
 
         let gradients = backward(result.exp().mean());
+
+        #[rustfmt::skip]
         assert_close(
             gradients.ref_gradient(&x),
             &[
-                [
-                    [[0.03879637, 0.01172858], [0.03428607, 0.01695974]],
-                    [[-0.02537140, 0.00752865], [-0.01455240, -0.00283930]],
-                ],
-                [
-                    [[0.03394239, 0.06539788], [0.04260371, 0.04326087]],
-                    [[-0.01691348, -0.04157412], [-0.01358072, -0.03078515]],
-                ],
-                [
-                    [[0.06113625, 0.04400696], [0.02342399, 0.09354331]],
-                    [[-0.00986115, -0.02002173], [-0.00362044, -0.05869440]],
-                ],
+                [[[0.03879637, 0.01172858], [0.03428607, 0.01695974]],[[-0.02537140, 0.00752865], [-0.01455240, -0.00283930]]],
+                [[[0.03394239, 0.06539788], [0.04260371, 0.04326087]],[[-0.01691348, -0.04157412], [-0.01358072, -0.03078515]]],
+                [[[0.06113625, 0.04400696], [0.02342399, 0.09354331]],[[-0.00986115, -0.02002173], [-0.00362044, -0.05869440]]],
             ],
         );
 
+        #[rustfmt::skip]
         assert_close(
             gradients.ref_gradient(&weight),
-            &[
-                [[[0.01032944]], [[-0.11132676]]],
-                [[[0.26300028]], [[-0.24666277]]],
-                [[[0.07612189]], [[0.05598290]]],
-            ],
+            &[[[[0.01032944]], [[-0.11132676]]],[[[0.26300028]], [[-0.24666277]]],[[[0.07612189]], [[0.05598290]]]],
         );
 
         assert_close(