diff --git a/src/nn/upscale.rs b/src/nn/upscale.rs index 16e9968ef..8c7a9dc36 100644 --- a/src/nn/upscale.rs +++ b/src/nn/upscale.rs @@ -1,4 +1,4 @@ -use crate::prelude::{Const, NearestNeighbor, Upscale2DWithMethod, UpscaleMethod}; +use crate::prelude::{Const, GenericUpscale2D, NearestNeighbor, UpscaleMethod}; use crate::prelude::{Dim, Dtype, HasErr, Tape, Tensor, Upscale2DKernel, ZerosTensor}; #[allow(unused)] @@ -10,14 +10,14 @@ pub struct Upscale2D ZeroSizedModule for Upscale2D {} impl NonMutableModule for Upscale2D {} -impl> Module +impl> Module for Upscale2D { type Output = Img::Output, Const>; type Error = Img::Err; fn try_forward(&self, x: Img) -> Result { - x.try_upscale2d() + x.generic_upscale2d_like(M::default(), Const, Const) } } @@ -49,7 +49,7 @@ where &self, x: Tensor<(C, Const, Const), E, D, T>, ) -> Result { - x.try_upscale2d() + x.generic_upscale2d_like(M::default(), Const, Const) } } @@ -76,7 +76,7 @@ where &self, x: Tensor<(B, C, Const, Const), E, D, T>, ) -> Result { - x.try_upscale2d() + x.generic_upscale2d_like(M::default(), Const, Const) } } @@ -98,7 +98,7 @@ impl< x: Tensor<(C, usize, usize), E, D, T>, ) -> Result { let shape = x.shape; - x.try_upscale2d_like(shape.1 * H, shape.2 * W) + x.generic_upscale2d_like(M::default(), shape.1 * H, shape.2 * W) } } @@ -123,7 +123,7 @@ where x: Tensor<(B, C, usize, usize), E, D, T>, ) -> Result { let shape = x.shape; - x.try_upscale2d_like(shape.2 * H, shape.3 * W) + x.generic_upscale2d_like(M::default(), shape.2 * H, shape.3 * W) } } #[cfg(test)] diff --git a/src/tensor_ops/conv2d/conv2d.cu b/src/tensor_ops/conv2d/conv2d.cu index 70f9e2533..ed48dbc46 100644 --- a/src/tensor_ops/conv2d/conv2d.cu +++ b/src/tensor_ops/conv2d/conv2d.cu @@ -141,12 +141,12 @@ __device__ void sum_transposed_filters( auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2; auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; - const T *ptr = filters_tr + i_tr; + filters_tr += i_tr; T tmp = 0.0; - for (auto b = 0; b < op.batch; b++) { - tmp += *ptr; - ptr += numel; + for (int b = 0; b < op.batch; b++) { + tmp += *filters_tr; + filters_tr += numel; } filters[i_no] += tmp; diff --git a/src/tensor_ops/conv2d/mod.rs b/src/tensor_ops/conv2d/mod.rs index 239aae93b..4b6b1dc4b 100644 --- a/src/tensor_ops/conv2d/mod.rs +++ b/src/tensor_ops/conv2d/mod.rs @@ -398,54 +398,6 @@ mod tests { assert_close(&g.get(&bias).array(), &[0.44699076, 0.408709]); } - #[test] - fn test_batched_conv2d() { - let dev: TestDevice = Default::default(); - let weight: Tensor<_, TestDtype, _> = dev.tensor([ - [[[0.05998272]], [[-0.07759511]]], - [[[0.68307382]], [[-0.56570816]]], - [[[0.31137520]], [[0.41600472]]], - ]); - let bias: Tensor<_, TestDtype, _> = dev.tensor([0.49647599, 0.15591705, -0.12342280]); - #[rustfmt::skip] - let x: Tensor<_, TestDtype, _> = dev.tensor([ - [[[-0.5396145, -2.43986344], [-0.01883135, -1.19915044]],[[-1.30589044, 2.05276346], [-0.20004864, 0.19919693]]], - [[[-0.22305037, 0.63030297], [0.65323567, -0.68972057]],[[-0.50617385, -0.87281805], [0.30253950, -1.75082350]]], - [[[1.65487242, 0.44441956], [-0.45107457, 1.41857898]],[[1.00477660, -0.16381662], [0.40009478, -0.57880658]]], - ]); - let result = x.leaky_trace().conv2d::<1, 0>(weight.clone()) - + bias.leaky_trace().broadcast::<_, Axes3<0, 2, 3>>(); - - #[rustfmt::skip] - result.array().assert_close( - &[ - [[[0.56543916, 0.19084194], [0.51086920, 0.40909100]],[[0.52607340, -2.67195487], [0.25622299, 
-0.77587855]],[[-0.83470196, -0.02917651], [-0.21250761, -0.41394162]]], - [[[0.52237344, 0.60200971], [0.51218325, 0.59096003]],[[0.28990385, 1.08022082], [0.43097615, 0.67524213]],[[-0.40344587, -0.29025853], [0.20583645, -1.06653547]]], - [[[0.51777399, 0.53584486], [0.43837392, 0.62647879]],[[0.71790677, 0.55216080], [-0.37853706, 1.45234680]],[[0.80985522, -0.05319006], [-0.09743492, 0.07750125]]], - ], - 1e-4, - ); - let g = result.exp().mean().backward(); - - #[rustfmt::skip] - g.get(&x).array().assert_close( - &[ - [[[0.03879637, 0.01172858], [0.03428607, 0.01695974]],[[-0.02537140, 0.00752865], [-0.01455240, -0.00283930]]], - [[[0.03394239, 0.06539788], [0.04260371, 0.04326087]],[[-0.01691348, -0.04157412], [-0.01358072, -0.03078515]]], - [[[0.06113625, 0.04400696], [0.02342399, 0.09354331]],[[-0.00986115, -0.02002173], [-0.00362044, -0.05869440]]], - ], - 1e-4, - ); - - #[rustfmt::skip] - assert_close( - &g.get(&weight).array(), - &[[[[0.01032944]], [[-0.11132676]]],[[[0.26300028]], [[-0.24666277]]],[[[0.07612189]], [[0.05598290]]]], - ); - - assert_close(&g.get(&bias).array(), &[0.55381978, 0.55677116, 0.30686682]); - } - #[test] fn test_conv2d_s4p3k2() { let dev = TestDevice::seed_from_u64(432); @@ -464,4 +416,35 @@ mod tests { [[-0.19717735, -0.19717735, -0.19717735],[-0.19717735, 1.3412137, 2.9476144],[-0.19717735, 4.247249, -2.1779637]], ]); } + + #[test] + fn test_batched_conv2d() { + let dev: TestDevice = Default::default(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); + let w: Tensor, TestDtype, _> = dev.sample_normal(); + + let y: Tensor, _, _, _> = x.leaky_trace().conv2d::<3, 2>(w.clone()); + let y0 = y.array(); + let grads0 = y.square().mean().backward(); + let x0 = grads0.get(&x).array(); + let w0 = grads0.get(&w).array(); + + let x = x + .broadcast::, _>() + .reshape::>(); + + let y: Tensor, _, _, _> = x.leaky_trace().conv2d::<3, 2>(w.clone()); + for i in 0..10 { + assert_close(&y0, &y.retaped::().select(dev.tensor(i)).array()); + } + + let grads = y.square().mean().backward(); + + assert_close(&w0, &(grads.get(&w)).array()); + + let x_grad = grads.get(&x) * 10.0; + for i in 0..10 { + assert_close(&x0, &x_grad.clone().select(dev.tensor(i)).array()); + } + } } diff --git a/src/tensor_ops/convtrans2d/convtrans2d.cu b/src/tensor_ops/convtrans2d/convtrans2d.cu index ff267c416..0004a90b8 100644 --- a/src/tensor_ops/convtrans2d/convtrans2d.cu +++ b/src/tensor_ops/convtrans2d/convtrans2d.cu @@ -78,7 +78,7 @@ __device__ void unfold_output_into_patches( image_out += b * (op.chan_out * op.h_out * op.w_out) + o * (op.h_out * op.w_out); patches += y * op.w_in + x; patches += o * (op.kernel * op.kernel * op.h_in * op.w_in); - patches += b * (op.chan_in * op.kernel * op.kernel * op.h_in * op.w_in); + patches += b * (op.chan_out * op.kernel * op.kernel * op.h_in * op.w_in); for (int k1 = 0;k1 < op.kernel;k1++) { const size_t oh = y * op.stride + k1 - op.padding; @@ -142,12 +142,12 @@ __device__ void sum_transposed_filters( auto i_tr = c * (op.chan_out * op.kernel * op.kernel) + o * (op.kernel * op.kernel) + k1 * (op.kernel) + k2; auto i_no = o * strides[0] + c * strides[1] + k1 * strides[2] + k2 * strides[3]; - const T *ptr = filters_tr + i_tr; + filters_tr += i_tr; T tmp = 0.0; - for (auto b = 0; b < op.batch; b++) { - tmp += *ptr; - ptr += numel; + for (int b = 0; b < op.batch; b++) { + tmp += *filters_tr; + filters_tr += numel; } filters[i_no] += tmp; diff --git a/src/tensor_ops/convtrans2d/cpu_kernel.rs b/src/tensor_ops/convtrans2d/cpu_kernel.rs index 
36cdb4256..05edcc236 100644 --- a/src/tensor_ops/convtrans2d/cpu_kernel.rs +++ b/src/tensor_ops/convtrans2d/cpu_kernel.rs @@ -9,15 +9,11 @@ use super::{ConvTrans2DKernel, ConvTrans2DOp}; impl ConvTrans2DOp { #[inline(always)] fn unfold_idx(&self, [k1, k2, y, x]: [usize; 4]) -> Option<[usize; 2]> { - let mut oh = y * self.stride; - oh += k1; - oh -= self.padding; - - let mut ow = x * self.stride; - ow += k2; - ow -= self.padding; - - Some([oh, ow]) + (y * self.stride + k1) + .checked_sub(self.padding) + .zip((x * self.stride + k2).checked_sub(self.padding)) + .filter(|&(oh, ow)| oh < self.h_out && ow < self.w_out) + .map(|(oh, ow)| [oh, ow]) } } diff --git a/src/tensor_ops/convtrans2d/cuda_kernel.rs b/src/tensor_ops/convtrans2d/cuda_kernel.rs index c69af99df..63f680d51 100644 --- a/src/tensor_ops/convtrans2d/cuda_kernel.rs +++ b/src/tensor_ops/convtrans2d/cuda_kernel.rs @@ -146,7 +146,7 @@ where self.gemm_batch( (op.batch, m, k, n), &f_b1023, - [m * k, k, 1], + [0, k, 1], &patches, [k * n, n, 1], ::ONE, diff --git a/src/tensor_ops/convtrans2d/mod.rs b/src/tensor_ops/convtrans2d/mod.rs index 3352a8f7b..ed5c386d7 100644 --- a/src/tensor_ops/convtrans2d/mod.rs +++ b/src/tensor_ops/convtrans2d/mod.rs @@ -224,6 +224,7 @@ impl< mod tests { use super::*; use crate::{tensor::*, tensor_ops::*, tests::*}; + use num_traits::FromPrimitive; #[test] /// TODO @@ -341,4 +342,39 @@ mod tests { ], ); } + + #[test] + fn test_batched_convtrans2d() { + let dev: TestDevice = Default::default(); + let x: Tensor, TestDtype, _> = dev.sample_normal(); + let w: Tensor, TestDtype, _> = dev.sample_normal(); + + let y: Tensor, _, _, _> = x.leaky_trace().convtrans2d::<3, 2>(w.clone()); + let y0 = y.array(); + let grads0 = y.square().mean().backward(); + let x0 = grads0.get(&x).array(); + let w0 = grads0.get(&w).array(); + + let x = x + .broadcast::, _>() + .reshape::>(); + + let y: Tensor, _, _, _> = + x.leaky_trace().convtrans2d::<3, 2>(w.clone()); + for i in 0..10 { + y0.assert_close( + &y.retaped::().select(dev.tensor(i)).array(), + TestDtype::from_f32(1e-5).unwrap(), + ); + } + + let grads = y.square().mean().backward(); + + assert_close(&w0, &(grads.get(&w)).array()); + + let x_grad = grads.get(&x) * 10.0; + for i in 0..10 { + assert_close(&x0, &x_grad.clone().select(dev.tensor(i)).array()); + } + } } diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs index 14d7fa910..68daafec3 100644 --- a/src/tensor_ops/mod.rs +++ b/src/tensor_ops/mod.rs @@ -269,7 +269,7 @@ pub use convtrans2d::{ConvTransAlgebra, TryConvTrans2D, TryConvTrans2DTo}; mod upscale2d; pub(crate) use upscale2d::Upscale2DKernel; -pub use upscale2d::{Bilinear, NearestNeighbor, TryUpscale2D, Upscale2DWithMethod, UpscaleMethod}; +pub use upscale2d::{Bilinear, GenericUpscale2D, NearestNeighbor, TryUpscale2D, UpscaleMethod}; #[cfg(feature = "nightly")] mod pool2d; diff --git a/src/tensor_ops/prelu.rs b/src/tensor_ops/prelu.rs index 3d9966ebc..d3c336a70 100644 --- a/src/tensor_ops/prelu.rs +++ b/src/tensor_ops/prelu.rs @@ -1,6 +1,6 @@ use crate::{shapes::*, tensor::*}; -use super::{BroadcastTo, ChooseFrom, Device}; +use super::{BroadcastTo, ChooseFrom, Device, TryMul}; /// [Parametric Rectified Linear Unit (PReLU)](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html). 
`max(0, lhs) + rhs*min(0, lhs)` /// @@ -65,8 +65,8 @@ where { /// See [prelu] fn try_prelu(self, rhs: Tensor) -> Result { - let scaled = self.with_empty_tape() * rhs; - Ok(self.scalar_lt(E::default()).choose(scaled, self)) + let scaled = self.with_empty_tape().try_mul(rhs)?; + self.try_scalar_lt(E::default())?.try_choose(scaled, self) } } @@ -75,8 +75,8 @@ impl, T: Tape> TryPReLU for Tensor Result { let dev = self.device.clone(); let scale = dev.tensor(rhs).retaped::().broadcast_like(self.shape()); - let scaled = self.with_empty_tape() * scale; - Ok(self.scalar_lt(E::default()).choose(scaled, self)) + let scaled = self.with_empty_tape().try_mul(scale)?; + self.try_scalar_lt(E::default())?.try_choose(scaled, self) } } diff --git a/src/tensor_ops/upscale2d/cpu_kernel.rs b/src/tensor_ops/upscale2d/cpu_kernel.rs index 4e13423d9..328d33254 100644 --- a/src/tensor_ops/upscale2d/cpu_kernel.rs +++ b/src/tensor_ops/upscale2d/cpu_kernel.rs @@ -27,19 +27,21 @@ impl let istr = make_4d::(inp.strides); let ostr = make_4d::(out.strides); - let h_scale = ((istr[1] / istr[2]) as f32) / ((ostr[1] / ostr[2]) as f32); - let w_scale = ((istr[2] / istr[3]) as f32) / ((ostr[2] / ostr[3]) as f32); + let y_ratio = (op.h_in as f32) / (op.h_out as f32); + let x_ratio = (op.w_in as f32) / (op.w_out as f32); let buf = inp.data.as_ref(); let out_buf = Arc::make_mut(&mut out.data); for b in 0..op.batch { for c in 0..op.chan { - for oh in 0..op.h_out { - for ow in 0..op.w_out { - let ih = (h_scale * oh as f32) as usize; - let iw = (w_scale * ow as f32) as usize; - out_buf[b * ostr[0] + c * ostr[1] + oh * ostr[2] + ow * ostr[3]] = - buf[b * istr[0] + c * istr[1] + ih * istr[2] + iw * istr[3]]; + for y_out in 0..op.h_out { + for x_out in 0..op.w_out { + let y_in = (y_ratio * y_out as f32).floor() as usize; + let x_in = (x_ratio * x_out as f32).floor() as usize; + let y_in = y_in.min(op.h_in - 1); + let x_in = x_in.min(op.w_in - 1); + out_buf[b * ostr[0] + c * ostr[1] + y_out * ostr[2] + x_out * ostr[3]] = + buf[b * istr[0] + c * istr[1] + y_in * istr[2] + x_in * istr[3]]; } } } @@ -58,17 +60,19 @@ impl let istr = make_4d::(inp.strides); let ostr = make_4d::(out.strides); - let h_scale = ((istr[1] / istr[2]) as f32) / ((ostr[1] / ostr[2]) as f32); - let w_scale = ((istr[2] / istr[3]) as f32) / ((ostr[2] / ostr[3]) as f32); + let y_ratio = (op.h_in as f32) / (op.h_out as f32); + let x_ratio = (op.w_in as f32) / (op.w_out as f32); for b in 0..op.batch { for c in 0..op.chan { - for oh in 0..op.h_out { - for ow in 0..op.w_out { - let ih = (h_scale * oh as f32) as usize; - let iw = (w_scale * ow as f32) as usize; - grad_inp[b * istr[0] + c * istr[1] + ih * istr[2] + iw * istr[3]] += - grad_out[b * ostr[0] + c * ostr[1] + oh * ostr[2] + ow * ostr[3]]; + for y_out in 0..op.h_out { + for x_out in 0..op.w_out { + let y_in: usize = (y_ratio * y_out as f32).floor() as usize; + let y_in = y_in.min(op.h_in - 1); + let x_in: usize = (x_ratio * x_out as f32).floor() as usize; + let x_in = x_in.min(op.w_in - 1); + grad_inp[b * istr[0] + c * istr[1] + y_in * istr[2] + x_in * istr[3]] += + grad_out[b * ostr[0] + c * ostr[1] + y_out * ostr[2] + x_out * ostr[3]]; } } } @@ -77,9 +81,7 @@ impl } } -impl - super::Upscale2DKernel for Cpu -{ +impl super::Upscale2DKernel for Cpu { fn forward( &self, op: super::Upscale2DOp, @@ -89,51 +91,39 @@ impl let istr = make_4d::(inp.strides); let ostr = make_4d::(out.strides); - let h_scale = ((istr[1] / istr[2] - 1) as f32) / ((ostr[1] / ostr[2] - 1) as f32); - let w_scale = ((istr[2] / istr[3] - 1) as f32) / 
((ostr[2] / ostr[3] - 1) as f32); + let y_ratio = ((op.h_in - 1) as f32) / ((op.h_out - 1) as f32); + let x_ratio = ((op.w_in - 1) as f32) / ((op.w_out - 1) as f32); let buf = inp.data.as_ref(); let out_buf = Arc::make_mut(&mut out.data); for b in 0..op.batch { for c in 0..op.chan { - for oh in 0..op.h_out { - for ow in 0..op.w_out { - let ih = (h_scale * oh as f32) as usize; - let iw = (w_scale * ow as f32) as usize; - - let hs = E::from(h_scale * oh as f32 - (ih as f32)).unwrap(); - let ws = E::from(w_scale * ow as f32 - (iw as f32)).unwrap(); - - let one = E::from(1.0).unwrap(); - let zero = E::from(0.0).unwrap(); - - let ll = buf[b * istr[0] + c * istr[1] + ih * istr[2] + iw * istr[3]] - * (one - hs) - * (one - ws); - let lh = if ws != zero { - buf[b * istr[0] + c * istr[1] + ih * istr[2] + (iw + 1) * istr[3]] - * (one - hs) - * ws - } else { - zero - }; - let hl = if hs != zero { - buf[b * istr[0] + c * istr[1] + (ih + 1) * istr[2] + iw * istr[3]] - * hs - * (one - ws) - } else { - zero - }; - let hh = if hs != zero && ws != zero { - buf[b * istr[0] + c * istr[1] + (ih + 1) * istr[2] + (iw + 1) * istr[3]] - * hs - * ws - } else { - zero - }; - - out_buf[b * ostr[0] + c * ostr[1] + oh * ostr[2] + ow * ostr[3]] = - ll + lh + hl + hh; + for y_out in 0..op.h_out { + for x_out in 0..op.w_out { + let x_frac = x_ratio * x_out as f32; + let x0 = x_frac.floor().min((op.w_in - 1) as f32); + let x1 = x_frac.ceil().min((op.w_in - 1) as f32); + let xw = E::from_f32(x_frac - x0).unwrap(); + + let y_frac = y_ratio * y_out as f32; + let y0 = y_frac.floor().min((op.h_in - 1) as f32); + let y1 = y_frac.ceil().min((op.h_in - 1) as f32); + let yw = E::from_f32(y_frac - y0).unwrap(); + + let [x0, x1, y0, y1] = [x0, x1, y0, y1].map(|q| q as usize); + + let p_a = buf[b * istr[0] + c * istr[1] + y0 * istr[2] + x0 * istr[3]]; + let p_b = buf[b * istr[0] + c * istr[1] + y0 * istr[2] + x1 * istr[3]]; + let p_c = buf[b * istr[0] + c * istr[1] + y1 * istr[2] + x0 * istr[3]]; + let p_d = buf[b * istr[0] + c * istr[1] + y1 * istr[2] + x1 * istr[3]]; + + let p_a = p_a * (E::ONE - xw) * (E::ONE - yw); + let p_b = p_b * xw * (E::ONE - yw); + let p_c = p_c * (E::ONE - xw) * yw; + let p_d = p_d * xw * yw; + + out_buf[b * ostr[0] + c * ostr[1] + y_out * ostr[2] + x_out * ostr[3]] = + p_a + p_b + p_c + p_d; } } } @@ -152,42 +142,34 @@ impl let istr = make_4d::(inp.strides); let ostr = make_4d::(out.strides); - let h_scale = ((istr[1] / istr[2] - 1) as f32) / ((ostr[1] / ostr[2] - 1) as f32); - let w_scale = ((istr[2] / istr[3] - 1) as f32) / ((ostr[2] / ostr[3] - 1) as f32); + let y_ratio = ((op.h_in - 1) as f32) / ((op.h_out - 1) as f32); + let x_ratio = ((op.w_in - 1) as f32) / ((op.w_out - 1) as f32); for b in 0..op.batch { for c in 0..op.chan { - for oh in 0..op.h_out { - for ow in 0..op.w_out { - let ih = (h_scale * oh as f32) as usize; - let iw = (w_scale * ow as f32) as usize; - - let hs = E::from(h_scale * oh as f32 - (ih as f32)).unwrap(); - let ws = E::from(w_scale * ow as f32 - (iw as f32)).unwrap(); - - let one = E::from(1.0).unwrap(); - let zero = E::from(0.0).unwrap(); - - let g = grad_out[b * ostr[0] + c * ostr[1] + oh * ostr[2] + ow * ostr[3]]; - - grad_inp[b * istr[0] + c * istr[1] + ih * istr[2] + iw * istr[3]] += - g * (one - hs) * (one - ws); - if ws != zero { - grad_inp - [b * istr[0] + c * istr[1] + ih * istr[2] + (iw + 1) * istr[3]] += - g * (one - hs) * ws; - } - if hs != zero { - grad_inp - [b * istr[0] + c * istr[1] + (ih + 1) * istr[2] + iw * istr[3]] += - g * hs * (one - ws); - } - if ws != 
zero && hs != zero { - grad_inp[b * istr[0] - + c * istr[1] - + (ih + 1) * istr[2] - + (iw + 1) * istr[3]] += g * hs * ws; - } + let i_base = b * istr[0] + c * istr[1]; + for y_out in 0..op.h_out { + for x_out in 0..op.w_out { + let go = + grad_out[b * ostr[0] + c * ostr[1] + y_out * ostr[2] + x_out * ostr[3]]; + + let x_frac = x_ratio * x_out as f32; + let x0 = x_frac.floor().min((op.w_in - 1) as f32); + let x1 = x_frac.ceil().min((op.w_in - 1) as f32); + let xw = E::from_f32(x_frac - x0).unwrap(); + + let y_frac = y_ratio * y_out as f32; + let y0 = y_frac.floor().min((op.h_in - 1) as f32); + let y1 = y_frac.ceil().min((op.h_in - 1) as f32); + let yw = E::from_f32(y_frac - y0).unwrap(); + + let [x0, x1, y0, y1] = [x0, x1, y0, y1].map(|q| q as usize); + + grad_inp[i_base + y0 * istr[2] + x0 * istr[3]] += + go * (E::ONE - xw) * (E::ONE - yw); + grad_inp[i_base + y0 * istr[2] + x1 * istr[3]] += go * xw * (E::ONE - yw); + grad_inp[i_base + y1 * istr[2] + x0 * istr[3]] += go * (E::ONE - xw) * yw; + grad_inp[i_base + y1 * istr[2] + x1 * istr[3]] += go * xw * yw; } } } diff --git a/src/tensor_ops/upscale2d/cuda_kernel.rs b/src/tensor_ops/upscale2d/cuda_kernel.rs index 562fddd26..ea5ab70e2 100644 --- a/src/tensor_ops/upscale2d/cuda_kernel.rs +++ b/src/tensor_ops/upscale2d/cuda_kernel.rs @@ -7,107 +7,89 @@ use std::sync::Arc; use cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig}; -use super::{Bilinear, NearestNeighbor}; +use super::{Bilinear, NearestNeighbor, UpscaleMethod}; const PTX_SRC: &str = include_str!(concat!(env!("OUT_DIR"), "/upscale2d.ptx")); unsafe impl DeviceRepr for super::Upscale2DOp {} -fn make_4d(strides: S::Concrete) -> [usize; 4] { +fn make_4d(strides: S::Concrete, pad: usize) -> [usize; 4] { match S::NUM_DIMS { - 3 => [0, strides[0], strides[1], strides[2]], + 3 => [pad, strides[0], strides[1], strides[2]], 4 => [strides[0], strides[1], strides[2], strides[3]], _ => panic!("Only implemented for 3d & 4d arrays"), } } -fn make_4d_shape(sizes: S::Concrete) -> [usize; 4] { - match S::NUM_DIMS { - 3 => [1, sizes[0], sizes[1], sizes[2]], - 4 => [sizes[0], sizes[1], sizes[2], sizes[3]], - _ => panic!("Only implemented for 3d & 4d arrays"), - } +trait HasCudaKernel { + const FWD: &'static str; + const BWD: &'static str; } - -macro_rules! 
pool_impl { - ($Trait:tt<$TypeName:ty, $UpscaleType:ty>, $Fwd:tt, $Bwd:tt) => { - impl super::$Trait<$TypeName, $UpscaleType> for Cuda { - fn forward( - &self, - op: super::Upscale2DOp, - inp: &Tensor, - out: &mut Tensor, - ) -> Result<(), Self::Err> { - if !self.dev.has_func($Fwd, $Fwd) { - self.dev.load_ptx(PTX_SRC.into(), $Fwd, &[$Fwd, $Bwd])?; - } - - let inp_strides = self.dev.htod_copy(make_4d::(inp.strides).into())?; - let inp_sizes = self - .dev - .htod_copy(make_4d_shape::(inp.shape.concrete()).into())?; - let out_strides = self.dev.htod_copy(make_4d::(out.strides).into())?; - let out_sizes = self - .dev - .htod_copy(make_4d_shape::(out.shape.concrete()).into())?; - let fwd_fn = self.dev.get_func($Fwd, $Fwd).unwrap(); - let cfg = LaunchConfig::for_num_elems(out.shape().num_elements() as u32); - let params = ( - op, // const Pool2dOp op, - &inp_strides, // const size_t *inp_strides, - &inp_sizes, // const size_t *inp_sizes, - &out_strides, // const size_t *out_strides, - &out_sizes, // const size_t *out_sizes, - inp.data.as_ref(), // const float *inp, - Arc::make_mut(&mut out.data), // float *out - ); - unsafe { fwd_fn.launch(cfg, params) }?; - Ok(()) - } - fn backward( - &self, - op: super::Upscale2DOp, - inp: &Tensor, - grad_inp: &mut Self::Vec<$TypeName>, - out: &Tensor, - grad_out: &Self::Vec<$TypeName>, - ) -> Result<(), Self::Err> { - let inp_strides = self.dev.htod_copy(make_4d::(inp.strides).into())?; - let inp_sizes = self - .dev - .htod_copy(make_4d_shape::(inp.shape.concrete()).into())?; - let out_strides = self.dev.htod_copy(make_4d::(out.strides).into())?; - let out_sizes = self - .dev - .htod_copy(make_4d_shape::(out.shape.concrete()).into())?; - let bwd_fn = self.dev.get_func($Fwd, $Bwd).unwrap(); - let cfg = LaunchConfig::for_num_elems(inp.shape().num_elements() as u32); - let params = ( - op, // const Pool2dOp op, - &inp_strides, // const size_t *inp_strides, - &inp_sizes, // const size_t *inp_sizes, - &out_strides, // const size_t *out_strides, - &out_sizes, // const size_t *out_sizes, - inp.data.as_ref(), // const float *inp, - grad_inp, // float *grad_inp, - out.data.as_ref(), // const float *out, - grad_out, // const float *grad_out - ); - unsafe { bwd_fn.launch(cfg, params) }?; - Ok(()) - } - } - }; +impl HasCudaKernel for Cuda { + const FWD: &'static str = "nearest_upscale2d_fwd_f32"; + const BWD: &'static str = "nearest_upscale2d_bwd_f32"; } +impl HasCudaKernel for Cuda { + const FWD: &'static str = "bilinear_upscale2d_fwd_f32"; + const BWD: &'static str = "bilinear_upscale2d_bwd_f32"; +} +impl HasCudaKernel for Cuda { + const FWD: &'static str = "nearest_upscale2d_fwd_f64"; + const BWD: &'static str = "nearest_upscale2d_bwd_f64"; +} +impl HasCudaKernel for Cuda { + const FWD: &'static str = "bilinear_upscale2d_fwd_f64"; + const BWD: &'static str = "bilinear_upscale2d_bwd_f64"; +} +impl super::Upscale2DKernel for Cuda +where + Self: HasCudaKernel, +{ + fn forward( + &self, + op: super::Upscale2DOp, + inp: &Tensor, + out: &mut Tensor, + ) -> Result<(), Self::Err> { + if !self.dev.has_func(Self::FWD, Self::FWD) { + self.dev + .load_ptx(PTX_SRC.into(), Self::FWD, &[Self::FWD, Self::BWD])?; + } -pool_impl!( - Upscale2DKernel, - "nearest_upscale2d_fwd_f32", - "nearest_upscale2d_bwd_f32" -); - -pool_impl!( - Upscale2DKernel, - "bilinear_upscale2d_fwd_f32", - "bilinear_upscale2d_bwd_f32" -); + let inp_strides = self.dev.htod_copy(make_4d::(inp.strides, 0).into())?; + let out_strides = self.dev.htod_copy(make_4d::(out.strides, 0).into())?; + let fwd_fn = 
self.dev.get_func(Self::FWD, Self::FWD).unwrap(); + let cfg = LaunchConfig::for_num_elems(out.shape().num_elements() as u32); + let params = ( + op, // const Pool2dOp op, + &inp_strides, // const size_t *inp_strides, + &out_strides, // const size_t *out_strides, + inp.data.as_ref(), // const float *inp, + Arc::make_mut(&mut out.data), // float *out + ); + unsafe { fwd_fn.launch(cfg, params) }?; + Ok(()) + } + fn backward( + &self, + op: super::Upscale2DOp, + inp: &Tensor, + grad_inp: &mut Self::Vec, + out: &Tensor, + grad_out: &Self::Vec, + ) -> Result<(), Self::Err> { + let inp_strides = self.dev.htod_copy(make_4d::(inp.strides, 0).into())?; + let out_strides = self.dev.htod_copy(make_4d::(out.strides, 0).into())?; + let bwd_fn = self.dev.get_func(Self::FWD, Self::BWD).unwrap(); + let cfg = LaunchConfig::for_num_elems(out.shape().num_elements() as u32); + let params = ( + op, // const Pool2dOp op, + &inp_strides, // const size_t *inp_strides, + &out_strides, // const size_t *out_strides, + grad_inp, // float *grad_inp, + grad_out, // const float *grad_out + ); + unsafe { bwd_fn.launch(cfg, params) }?; + Ok(()) + } +} diff --git a/src/tensor_ops/upscale2d/mod.rs b/src/tensor_ops/upscale2d/mod.rs index bfd2057fc..40b6c9c86 100644 --- a/src/tensor_ops/upscale2d/mod.rs +++ b/src/tensor_ops/upscale2d/mod.rs @@ -40,16 +40,23 @@ impl Upscale2DOp { } } +/// Upscaling method to be used with [TryUpscale2D], can be either +/// [NearestNeighbor] or [Bilinear]. pub trait UpscaleMethod: Default {} +/// Upscales images using a pixel's nearest neighbor. +/// +/// **pytorch equivalent** `F.interpolate(..., mode="nearest")` #[derive(Clone, Copy, Default)] pub struct NearestNeighbor; - impl UpscaleMethod for NearestNeighbor {} +/// Upscales images using bilinear interpolation between +/// a pixels neighbors +/// +/// **pytorch equivalent**: `F.interpolate(..., mode="bilinear", align_corners=True)` #[derive(Clone, Copy, Default)] pub struct Bilinear; - impl UpscaleMethod for Bilinear {} pub trait Upscale2DKernel: DeviceStorage { @@ -70,57 +77,80 @@ pub trait Upscale2DKernel: DeviceStorage { ) -> Result<(), Self::Err>; } -pub trait Upscale2DWithMethod: HasErr { +pub trait GenericUpscale2D: HasErr { type Output; - fn try_upscale2d( - self, - ) -> Result, Const>, Self::Err> { - self.try_upscale2d_like(Const::, Const::) - } - - fn try_upscale2d_like( + fn generic_upscale2d_like( self, + method: M, height: OH, width: OW, ) -> Result, Self::Err>; } +/// Upscales an image to a new shape. Valid methods of upscaling are: +/// +/// - [NearestNeighbor] pytorch equivalent: `F.interpolate(..., mode="nearest")` +/// - [Bilinear] pytorch equivalent: `F.interpolate(..., mode="bilinear", align_corners=True)` +/// +/// Compile time upscale: +/// ```rust +/// # use dfdx::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let t: Tensor, f32, _> = dev.zeros(); +/// let y: Tensor, f32, _> = t.upscale2d(NearestNeighbor); +/// ``` +/// +/// Runtime upscale: +/// ```rust +/// # use dfdx::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let t: Tensor, f32, _> = dev.zeros(); +/// let y: Tensor<(Const<3>, usize, usize), f32, _> = t.upscale2d_like(NearestNeighbor, 64, 64); +/// ``` pub trait TryUpscale2D { - fn upscale_2d( + /// Upscale to compile time known dimensions. 
+ fn upscale2d( self, - ) -> >::Output, Const> + method: M, + ) -> >::Output, Const> where - Self: Upscale2DWithMethod, + Self: GenericUpscale2D, { - self.try_upscale2d().unwrap() + self.generic_upscale2d_like(method, Const, Const).unwrap() } - fn try_upscale_2d( + /// Fallibly upscale to compile time known dimensions. + fn try_upscale2d( self, - ) -> Result<>::Output, Const>, Self::Err> + method: M, + ) -> Result<>::Output, Const>, Self::Err> where - Self: Upscale2DWithMethod, + Self: GenericUpscale2D, { - Upscale2DWithMethod::try_upscale2d(self) + self.generic_upscale2d_like(method, Const, Const) } - fn upscale_2d_like( + /// Upscale to runtime known dimensions. + fn upscale2d_like( self, + method: M, height: OH, width: OW, - ) -> >::Output + ) -> >::Output where - Self: Upscale2DWithMethod, + Self: GenericUpscale2D, { - self.try_upscale2d_like(height, width).unwrap() + self.generic_upscale2d_like(method, height, width).unwrap() } - fn try_upscale_2d_like( + /// Fallibly upscale to runtime known dimensions. + fn try_upscale2d_like( self, + method: M, height: OH, width: OW, - ) -> Result<>::Output, Self::Err> + ) -> Result<>::Output, Self::Err> where - Self: Upscale2DWithMethod, + Self: GenericUpscale2D, { - Upscale2DWithMethod::try_upscale2d_like(self, height, width) + GenericUpscale2D::generic_upscale2d_like(self, method, height, width) } } impl TryUpscale2D for Tensor {} @@ -133,12 +163,13 @@ impl< M: UpscaleMethod, D: Upscale2DKernel + ZerosTensor, T: 'static + Tape, - > Upscale2DWithMethod for Tensor<(C, H, W), E, D, T> + > GenericUpscale2D for Tensor<(C, H, W), E, D, T> { type Output = Tensor<(C, OH, OW), E, D, T>; - fn try_upscale2d_like( + fn generic_upscale2d_like( self, + _method: M, out_height: OH, out_width: OW, ) -> Result, Self::Err> { @@ -174,12 +205,13 @@ impl< M: UpscaleMethod, D: Upscale2DKernel + ZerosTensor, T: 'static + Tape, - > Upscale2DWithMethod for Tensor<(B, C, H, W), E, D, T> + > GenericUpscale2D for Tensor<(B, C, H, W), E, D, T> { type Output = Tensor<(B, C, OH, OW), E, D, T>; - fn try_upscale2d_like( + fn generic_upscale2d_like( self, + _method: M, out_height: OH, out_width: OW, ) -> Result, Self::Err> { @@ -215,11 +247,11 @@ mod tests { use super::{Bilinear, NearestNeighbor, TryUpscale2D}; #[test] - fn nearest_upscale2d_even() { + fn test_upscale2d_nearest_even() { let dev = TestDevice::default(); let x = dev.tensor([[[1.0, 0.0], [2.0, 3.0]]]); - let y = x.leaky_trace().upscale_2d::<4, 4, NearestNeighbor>(); + let y = x.leaky_trace().upscale2d::<4, 4, _>(NearestNeighbor); assert_close( &y.array(), &[[ @@ -238,11 +270,11 @@ mod tests { } #[test] - fn nearest_upscale2d_uneven() { + fn test_upscale2d_nearest_uneven() { let dev = TestDevice::default(); let x = dev.tensor([[[1.0, 0.0, 2.0], [2.0, 3.0, 4.0]]]); - let y = x.leaky_trace().upscale_2d::<2, 7, NearestNeighbor>(); + let y = x.leaky_trace().upscale2d::<2, 7, _>(NearestNeighbor); assert_close( &y.array(), &[[ @@ -261,20 +293,56 @@ mod tests { ); } + #[test] + fn test_upscale2d_nearest_batched() { + let dev = TestDevice::default(); + + let x: Tensor<_, TestDtype, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let x: Tensor, _, _> = [x.clone(), x.clone(), x].stack(); + let x: Tensor, _, _> = + [x.clone(), x.clone(), x.clone(), x.clone(), x].stack(); + let y = x.leaky_trace().upscale2d::<5, 6, _>(NearestNeighbor); + let y_array = y.array(); + for img in y_array { + assert_eq!( + img, + [[ + [1., 1., 2., 2., 3., 3.], + [1., 1., 2., 2., 3., 3.], + [1., 1., 2., 2., 3., 3.], + [4., 4., 5., 5., 6., 6.], + 
[4., 4., 5., 5., 6., 6.] + ]; 3] + ); + } + + let grads = y.exp().mean().backward(); + let x_grad = grads.get(&x).array(); + for row in x_grad.iter() { + assert_close( + row, + &[[ + [0.03624376, 0.09852076, 0.26780716], + [0.48531687, 1.319228, 3.5860338], + ]; 3], + ); + } + } + // Use align_corners when comparing these #[test] - fn bilinear_upscale2d_even() { + fn test_upscale2d_bilinear_even() { let dev = TestDevice::default(); let x = dev.tensor([[[1.0, 0.0], [2.0, 3.0]]]); - let y = x.leaky_trace().upscale_2d::<4, 4, Bilinear>(); + let y = x.leaky_trace().upscale2d::<4, 4, _>(Bilinear); assert_close( &y.array(), &[[ - [1.0000000, 0.6666666, 0.3333333, 0.0000000], - [1.3333333, 1.2222222, 1.1111112, 1.0000000], - [1.6666667, 1.7777778, 1.8888890, 2.0000000], - [2.0000000, 2.3333333, 2.6666665, 3.0000000], + [1.0, 0.66666663, 0.33333331, 0.0], + [1.33333325, 1.22222221, 1.11111116, 1.0], + [1.66666675, 1.77777779, 1.88888907, 2.0], + [2.0, 2.33333325, 2.66666651, 3.0], ]], ); @@ -286,20 +354,16 @@ mod tests { } #[test] - fn bilinear_upscale2d_uneven() { + fn test_upscale2d_bilinear_uneven() { let dev = TestDevice::default(); let x = dev.tensor([[[1.0, 0.0, 2.0], [2.0, 3.0, 4.0]]]); - let y = x.leaky_trace().upscale_2d::<2, 7, Bilinear>(); + let y = x.leaky_trace().upscale2d::<2, 7, _>(Bilinear); assert_close( &y.array(), &[[ - [ - 1.0000000, 0.6666666, 0.3333333, 0.0000000, 0.6666667, 1.3333335, 2.0000000, - ], - [ - 2.0000000, 2.3333333, 2.6666665, 3.0000000, 3.3333335, 3.6666667, 4.0000000, - ], + [1.0, 0.6666666, 0.3333333, 0.0, 0.6666667, 1.3333335, 2.0], + [2.0, 2.3333333, 2.6666665, 3.0, 3.3333335, 3.6666667, 4.0], ]], ); @@ -312,4 +376,40 @@ mod tests { ]], ); } + + #[test] + fn test_bilinear_upscale2d_batched() { + let dev = TestDevice::default(); + + let x: Tensor<_, TestDtype, _> = dev.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let x: Tensor, _, _> = [x.clone(), x.clone(), x].stack(); + let x: Tensor, _, _> = + [x.clone(), x.clone(), x.clone(), x.clone(), x].stack(); + let y = x.leaky_trace().upscale2d::<5, 6, _>(Bilinear); + let y_array = y.array(); + for img in y_array { + assert_close( + &img, + &[[ + [1.0, 1.4, 1.8, 2.2, 2.6, 3.0], + [1.75, 2.15, 2.55, 2.95, 3.35, 3.75], + [2.5, 2.9, 3.3, 3.7, 4.1, 4.5], + [3.25, 3.65, 4.05, 4.45, 4.85, 5.25], + [4.0, 4.4, 4.8, 5.2, 5.6, 6.0], + ]; 3], + ); + } + + let grads = y.exp().mean().backward(); + let x_grad = grads.get(&x).array(); + for row in x_grad.iter() { + assert_close( + row, + &[[ + [0.10178878, 0.30509925, 0.47953573], + [0.42368498, 1.2699431, 1.9960163], + ]; 3], + ); + } + } } diff --git a/src/tensor_ops/upscale2d/upscale2d.cu b/src/tensor_ops/upscale2d/upscale2d.cu index 5e571bcc3..39786f4e5 100644 --- a/src/tensor_ops/upscale2d/upscale2d.cu +++ b/src/tensor_ops/upscale2d/upscale2d.cu @@ -13,20 +13,17 @@ template __device__ void nearest_upscale2d_fwd( const Upscale2dOp op, const size_t *inp_strides, - const size_t *inp_sizes, const size_t *out_strides, - const size_t *out_sizes, const T *inp, // 4d (Batch, Channels, Height, Width) T *out // 4d (Batch, Channels, HeightOut, WidthOut) ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - const size_t numel = op.batch * op.chan * op.h_out * op.w_out; - if (i >= numel) { + if (i >= op.batch * op.chan * op.h_out * op.w_out) { return; } - float h_scale = ((float)inp_sizes[2])/out_sizes[2]; - float w_scale = ((float)inp_sizes[3])/out_sizes[3]; + float h_scale = static_cast(op.h_in)/static_cast(op.h_out); + float w_scale = 
static_cast(op.w_in)/static_cast(op.w_out); unsigned int idx = i; const size_t ow = idx % op.w_out; @@ -36,10 +33,9 @@ __device__ void nearest_upscale2d_fwd( const size_t c = idx % op.chan; idx /= op.chan; const size_t b = idx % op.batch; - idx /= op.batch; - size_t ih = h_scale * oh; - size_t iw = w_scale * ow; + size_t ih = min(static_cast(h_scale * oh), op.h_out - 1); + size_t iw = min(static_cast(w_scale * ow), op.w_out - 1); size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3]; @@ -50,77 +46,49 @@ template __device__ void nearest_upscale2d_bwd( const Upscale2dOp op, const size_t *inp_strides, - const size_t *inp_sizes, const size_t *out_strides, - const size_t *out_sizes, - const T *inp, // 4d (Batch, Channels, Height, Width) T *grad_inp, - const T *out, // 4d (Batch, Channels, HeightOut, WidthOut) - const T *grad_out + const T *grad_out // 4d (Batch, Channels, HeightOut, WidthOut) ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - const size_t numel = op.batch * op.chan * op.h_in * op.w_in; - if (i >= numel) { + if (i >= op.batch * op.chan * op.h_out * op.w_out) { return; } - float h_scale = ((float)inp_sizes[2])/out_sizes[2]; - float w_scale = ((float)inp_sizes[3])/out_sizes[3]; + float h_scale = static_cast(op.h_in)/static_cast(op.h_out); + float w_scale = static_cast(op.w_in)/static_cast(op.w_out); unsigned int idx = i; - const size_t x = idx % op.w_in; - idx /= op.w_in; - const size_t y = idx % op.h_in; - idx /= op.h_in; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; const size_t c = idx % op.chan; idx /= op.chan; const size_t b = idx % op.batch; - idx /= op.batch; - // Probably isn't efficient, but it works - size_t oh_s = 0; - size_t ow_s = 0; - size_t oh_e = op.h_out-1; - size_t ow_e = op.w_out-1; - while (oh_s*h_scale < y) { - oh_s++; - } - while (ow_s*w_scale < x) { - ow_s++; - } - while (oh_e*h_scale >= y+1) { - oh_e--; - } - while (ow_e*w_scale >= x+1) { - ow_e--; - } + size_t ih = min(static_cast(h_scale * oh), op.h_out - 1); + size_t iw = min(static_cast(w_scale * ow), op.w_out - 1); - for (int oh = oh_s; oh <= oh_e; oh++) { - for (int ow = ow_s; ow <= ow_e; ow++) { - size_t out_i = b * out_strides[0] + c * out_strides[1] + oh * out_strides[2] + ow * out_strides[3]; - grad_inp[i] += grad_out[out_i]; - } - } + size_t inp_i = b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3]; + atomicAdd(grad_inp + inp_i, grad_out[i]); } template __device__ void bilinear_upscale2d_fwd( const Upscale2dOp op, const size_t *inp_strides, - const size_t *inp_sizes, const size_t *out_strides, - const size_t *out_sizes, const T *inp, // 4d (Batch, Channels, Height, Width) T *out // 4d (Batch, Channels, HeightOut, WidthOut) ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - const size_t numel = op.batch * op.chan * op.h_out * op.w_out; - if (i >= numel) { + if (i >= op.batch * op.chan * op.h_out * op.w_out) { return; } - float h_scale = ((float)inp_sizes[2]-1)/(out_sizes[2]-1); - float w_scale = ((float)inp_sizes[3]-1)/(out_sizes[3]-1); + float h_scale = ((float)op.h_in-1)/(op.h_out-1); + float w_scale = ((float)op.w_in-1)/(op.w_out-1); unsigned int idx = i; const size_t ow = idx % op.w_out; @@ -132,16 +100,20 @@ __device__ void bilinear_upscale2d_fwd( const size_t b = idx % op.batch; idx /= op.batch; - size_t ih = h_scale * oh; - size_t iw = w_scale * ow; + size_t y0 = min(static_cast(h_scale * oh), op.h_out - 1); + size_t y1 = min(y0 
+ 1, op.h_out - 1); + size_t x0 = min(static_cast(w_scale * ow), op.w_out - 1); + size_t x1 = min(x0 + 1, op.w_out - 1); + + T hs = h_scale * oh - y0; + T ws = w_scale * ow - x0; - T hs = h_scale * oh - ih; - T ws = w_scale * ow - iw; + inp += b * inp_strides[0] + c * inp_strides[1]; - T ll = inp[b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + iw * inp_strides[3]] * (1-hs) * (1-ws); - T lh = ws != 0 ? inp[b * inp_strides[0] + c * inp_strides[1] + ih * inp_strides[2] + (iw+1) * inp_strides[3]] * (1-hs) * ws : 0; - T hl = hs != 0 ? inp[b * inp_strides[0] + c * inp_strides[1] + (ih+1) * inp_strides[2] + iw * inp_strides[3]] * hs * (1-ws) : 0; - T hh = hs != 0 && ws != 0 ? inp[b * inp_strides[0] + c * inp_strides[1] + (ih+1) * inp_strides[2] + (iw+1) * inp_strides[3]] * hs * ws : 0; + T ll = inp[y0 * inp_strides[2] + x0 * inp_strides[3]] * (1-hs) * (1-ws); + T lh = inp[y0 * inp_strides[2] + x1 * inp_strides[3]] * (1-hs) * ws; + T hl = inp[y1 * inp_strides[2] + x0 * inp_strides[3]] * hs * (1-ws); + T hh = inp[y1 * inp_strides[2] + x1 * inp_strides[3]] * hs * ws; out[i] = ll + lh + hl + hh; } @@ -150,87 +122,64 @@ template __device__ void bilinear_upscale2d_bwd( const Upscale2dOp op, const size_t *inp_strides, - const size_t *inp_sizes, const size_t *out_strides, - const size_t *out_sizes, - const T *inp, // 4d (Batch, Channels, Height, Width) - T *grad_inp, - const T *out, // 4d (Batch, Channels, HeightOut, WidthOut) - const T *grad_out + T *grad_inp, // 4d (Batch, Channels, Height, Width) + const T *grad_out // 4d (Batch, Channels, HeightOut, WidthOut) ) { unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; - const size_t numel = op.batch * op.chan * op.h_in * op.w_in; - if (i >= numel) { + if (i >= op.batch * op.chan * op.h_out * op.w_out) { return; } - float h_scale = ((float)inp_sizes[2]-1)/(out_sizes[2]-1); - float w_scale = ((float)inp_sizes[3]-1)/(out_sizes[3]-1); + float h_scale = ((float)op.h_in-1)/(op.h_out-1); + float w_scale = ((float)op.w_in-1)/(op.w_out-1); unsigned int idx = i; - const size_t x = idx % op.w_in; - idx /= op.w_in; - const size_t y = idx % op.h_in; - idx /= op.h_in; + const size_t ow = idx % op.w_out; + idx /= op.w_out; + const size_t oh = idx % op.h_out; + idx /= op.h_out; const size_t c = idx % op.chan; idx /= op.chan; const size_t b = idx % op.batch; idx /= op.batch; - // Probably isn't efficient, but it works - size_t oh_s = 0; - size_t ow_s = 0; - size_t oh_e = op.h_out-1; - size_t ow_e = op.w_out-1; - while (ceil(oh_s*h_scale) < y) { - oh_s++; - } - while (ceil(ow_s*w_scale) < x) { - ow_s++; - } - while (floor(oh_e*h_scale) > y) { - oh_e--; - } - while (floor(ow_e*w_scale) > x) { - ow_e--; - } + size_t y0 = min(static_cast(h_scale * oh), op.h_out - 1); + size_t y1 = min(y0 + 1, op.h_out - 1); + size_t x0 = min(static_cast(w_scale * ow), op.w_out - 1); + size_t x1 = min(x0 + 1, op.w_out - 1); - for (int oh = oh_s; oh <= oh_e; oh++) { - for (int ow = ow_s; ow <= ow_e; ow++) { - size_t out_i = b * out_strides[0] + c * out_strides[1] + oh * out_strides[2] + ow * out_strides[3]; + T hs = h_scale * oh - y0; + T ws = w_scale * ow - x0; - T hs = abs(h_scale * oh - y); - T ws = abs(w_scale * ow - x); + T go = grad_out[i]; - grad_inp[i] += grad_out[out_i] * (1-hs)*(1-ws); - } - } + grad_inp += b * inp_strides[0] + c * inp_strides[1]; + + atomicAdd(grad_inp + y0 * inp_strides[2] + x0 * inp_strides[3], go * (1-hs) * (1-ws)); + atomicAdd(grad_inp + y0 * inp_strides[2] + x1 * inp_strides[3], go * (1-hs) * ws); + atomicAdd(grad_inp + y1 * inp_strides[2] + 
x0 * inp_strides[3], go * hs * (1-ws)); + atomicAdd(grad_inp + y1 * inp_strides[2] + x1 * inp_strides[3], go * hs * ws); } #define UPSCALE_OP(TYPENAME, fwd, bwd, fwd_FN, bwd_FN) \ extern "C" __global__ void fwd( \ const Upscale2dOp op, \ const size_t *inp_strides, \ - const size_t *inp_sizes, \ const size_t *out_strides, \ - const size_t *out_sizes, \ const TYPENAME *inp, \ TYPENAME *out \ ) { \ - fwd_FN(op, inp_strides, inp_sizes, out_strides, out_sizes, inp, out); \ + fwd_FN(op, inp_strides, out_strides, inp, out); \ } \ extern "C" __global__ void bwd( \ const Upscale2dOp op, \ const size_t *inp_strides, \ - const size_t *inp_sizes, \ const size_t *out_strides, \ - const size_t *out_sizes, \ - const TYPENAME *inp, \ TYPENAME *grad_inp, \ - const TYPENAME *out, \ const TYPENAME *grad_out \ ) { \ - bwd_FN(op, inp_strides, inp_sizes, out_strides, out_sizes, inp, grad_inp, out, grad_out); \ + bwd_FN(op, inp_strides, out_strides, grad_inp, grad_out); \ } UPSCALE_OP(
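For reference, a minimal usage sketch of the renamed upscale API introduced in this patch, assuming the dfdx prelude re-exports shown above; the Rank3 sizes and the Cpu device are illustrative, following the doc examples added in src/tensor_ops/upscale2d/mod.rs:

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();

    // Compile-time target size: the output shape annotation fixes OH/OW.
    let t: Tensor<Rank3<3, 32, 32>, f32, _> = dev.zeros();
    let y: Tensor<Rank3<3, 64, 64>, f32, _> = t.clone().upscale2d(NearestNeighbor);

    // Runtime target size via the `_like` variant; per the doc comments,
    // Bilinear matches F.interpolate(..., mode="bilinear", align_corners=True).
    let z: Tensor<(Const<3>, usize, usize), f32, _> = t.upscale2d_like(Bilinear, 64, 64);

    assert_eq!(y.shape().concrete(), [3, 64, 64]);
    assert_eq!(z.shape().concrete(), [3, 64, 64]);
}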