diff --git a/src/kernels/gpu_reference_kernel/naive_conv.cpp b/src/kernels/gpu_reference_kernel/naive_conv.cpp index b04dc8720a..e7ee07fe40 100644 --- a/src/kernels/gpu_reference_kernel/naive_conv.cpp +++ b/src/kernels/gpu_reference_kernel/naive_conv.cpp @@ -1039,9 +1039,9 @@ template inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, dst_data_t* __restrict__ p_out, - Strides5D in_strides, + Strides6D in_strides, Strides6D wei_strides, - Strides5D out_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -1080,18 +1080,22 @@ inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restri int in = (bid / do_) % n; int ig = bid / (n * do_); - // TODO: what to do with ig * c_per-group and similarly ig * k_per_group // p_in += static_cast(in) * di * hi * wi * c + static_cast(ig) * c_per_group; + // // p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group; + // + // p_out += static_cast(in) * do_ * ho * wo * k + static_cast(ido) * ho * wo * k + // + static_cast(ig) * k_per_group; - p_in += static_cast(in) * in_strides[4] + static_cast(ig) * c_per_group; - // assumes that group G is the highest dimension in the layout + // dim order NDHWGC + // replace C and K with G * C_per_G and G * K_per_G + p_in += static_cast(in) * in_strides[5] + static_cast(ig) * in_strides[1]; + + // Assumes that group G is the highest dimension in the layout p_wei += static_cast(ig) * wei_strides[5]; - // p_out += static_cast(in) * do_ * ho * wo * k + static_cast(ido) * ho * wo * k - // + static_cast(ig) * k_per_group; - p_out += - static_cast(in) * out_strides[4] + static_cast(ido) * out_strides[3] + ? ? 
; + p_out += static_cast(in) * out_strides[5] + static_cast(ido) * out_strides[4] + + static_cast(ig) * out_strides[1]; for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { @@ -1123,19 +1127,19 @@ inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restri { if(valid_d & valid_w & valid_h) { - /* - size_t i_idx = static_cast(cur_d) * hi * wi * c + - static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + static_cast(ic); - size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + - static_cast(iz) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); - */ - size_t i_idx = static_cast(cur_d) * in_strides[3] + - static_cast(cur_h) * in_strides[2] + - static_cast(cur_w) * in_strides[1] + + // size_t i_idx = static_cast(cur_d) * hi * wi * c + + // static_cast(cur_h) * wi * c + + // static_cast(cur_w) * c + static_cast(ic); + // + // size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + + // static_cast(iz) * fy * fx * c_per_group + + // static_cast(iy) * fx * c_per_group + + // static_cast(ix) * c_per_group + + // static_cast(ic); + + size_t i_idx = static_cast(cur_d) * in_strides[4] + + static_cast(cur_h) * in_strides[3] + + static_cast(cur_w) * in_strides[2] + static_cast(ic) * in_strides[0]; size_t f_idx = static_cast(ik) * wei_strides[4] + @@ -1153,8 +1157,8 @@ inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restri } // size_t o_idx = static_cast(iho) * wo * k + static_cast(iwo) * k + // static_cast(ik); - size_t o_idx = static_cast(iho) * out_strides[2] + - static_cast(iwo) * out_strides[1] + + size_t o_idx = static_cast(iho) * out_strides[3] + + static_cast(iwo) * out_strides[2] + static_cast(ik) * out_strides[0]; p_out[o_idx] = cast_to(value); } @@ -1272,9 +1276,9 @@ template inline __device__ void naive_conv_bwd_ndhwc_nonpacked(dst_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, const src_data_t* 
__restrict__ p_out, - Strides5D in_strides, + Strides6D in_strides, Strides6D wei_strides, - Strides5D out_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -1315,14 +1319,17 @@ inline __device__ void naive_conv_bwd_ndhwc_nonpacked(dst_data_t* __restrict__ p // p_in += static_cast(in) * di * hi * wi * c + static_cast(idi) * hi * wi * c + // static_cast(ig) * c_per_group; + // // p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group; + // // p_out += static_cast(in) * do_ * ho * wo * k + static_cast(ig) * k_per_group; - // TODO(Amber): figure out the last term in p_in related to c_per_group - p_in += static_cast(in) * in_strides[4] + static_cast(idi) * in_strides[3] + - static_cast(ig) * c_per_group; + p_in += static_cast(in) * in_strides[5] + static_cast(idi) * in_strides[4] + + static_cast(ig) * in_strides[1]; + p_wei += static_cast(ig) * wei_strides[5]; - p_out += static_cast(in) * out_strides[4] + static_cast(ig) * k_per_group; + + p_out += static_cast(in) * out_strides[5] + static_cast(ig) * out_strides[1]; for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { @@ -1367,20 +1374,23 @@ inline __device__ void naive_conv_bwd_ndhwc_nonpacked(dst_data_t* __restrict__ p // static_cast(cur_do) * ho * wo * k + // static_cast(cur_ho) * wo * k + // static_cast(cur_wo) * k + // static_cast(ik); + // // size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + // static_cast(iz) * fy * fx * c_per_group + // static_cast(iy) * fx * c_per_group + // static_cast(ix) * c_per_group + // static_cast(ic); - size_t o_idx = static_cast(cur_do) * out_strides[3] + - static_cast(cur_ho) * out_strides[2] + - static_cast(cur_wo) * out_strides[1] + + size_t o_idx = static_cast(cur_do) * out_strides[4] + + static_cast(cur_ho) * out_strides[3] + + static_cast(cur_wo) * out_strides[2] + static_cast(ik) * out_strides[0]; + size_t f_idx = static_cast(ik) * wei_strides[4] + static_cast(iz) * wei_strides[3] + static_cast(iy) * wei_strides[2] + static_cast(ix) * wei_strides[1] + 
static_cast(ic) * wei_strides[0]; + value += cast_to(p_out[o_idx]) * cast_to(p_wei[f_idx]); } @@ -1388,8 +1398,10 @@ inline __device__ void naive_conv_bwd_ndhwc_nonpacked(dst_data_t* __restrict__ p } } } - size_t i_idx = static_cast(ihi) * in_strides[2] + - static_cast(iwi) * in_strides[1] + + // size_t i_idx = static_cast(ihi) * wi * c + static_cast(iwi) * c + + // static_cast(ic); + size_t i_idx = static_cast(ihi) * in_strides[3] + + static_cast(iwi) * in_strides[2] + static_cast(ic) * in_strides[0]; p_in[i_idx] = cast_to(value); } @@ -1499,9 +1511,9 @@ template inline __device__ void naive_conv_wrw_ndhwc_nonpacked(const src_data_t* __restrict__ p_in, dst_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, - Strides5D in_strides, + Strides6D in_strides, Strides6D wei_strides, - Strides5D out_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -1540,15 +1552,17 @@ inline __device__ void naive_conv_wrw_ndhwc_nonpacked(const src_data_t* __restri int ig = bid / k_per_group; // p_in += static_cast(ig) * c_per_group; + // // p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group + // static_cast(ik) * fz * fy * fx * c_per_group; + // // p_out += static_cast(ig) * k_per_group + static_cast(ik); - // TODO(amber): c_per_group issue - p_in += static_cast(ig) * c_per_group; + p_in += static_cast(ig) * in_strides[1]; + p_wei += static_cast(ig) * wei_strides[5] + static_cast(ik) * wei_strides[4]; - // TODO(amber): k_per_group issue same as c_per_group - p_out += static_cast(ig) * k_per_group + static_cast(ik) * out_strides[0]; + + p_out += static_cast(ig) * out_strides[1] + static_cast(ik) * out_strides[0]; for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { @@ -1587,21 +1601,22 @@ inline __device__ void naive_conv_wrw_ndhwc_nonpacked(const src_data_t* __restri // static_cast(cur_d) * hi * wi * c + // static_cast(cur_h) * wi * c + // static_cast(cur_w) * c + static_cast(ic); + // // size_t o_idx = static_cast(in) * do_ * ho 
* wo * k + // static_cast(ido) * ho * wo * k + // static_cast(iho) * wo * k + // static_cast(iwo) * k; - size_t i_idx = static_cast(in) * in_strides[4] + - static_cast(cur_d) * in_strides[3] + - static_cast(cur_h) * in_strides[2] + - static_cast(cur_w) * in_strides[1] + + size_t i_idx = static_cast(in) * in_strides[5] + + static_cast(cur_d) * in_strides[4] + + static_cast(cur_h) * in_strides[3] + + static_cast(cur_w) * in_strides[2] + static_cast(ic) * in_strides[0]; - size_t o_idx = static_cast(in) * out_strides[4] + - static_cast(ido) * out_strides[3] + - static_cast(iho) * out_strides[2] + - static_cast(iwo) * out_strides[1]; + size_t o_idx = static_cast(in) * out_strides[5] + + static_cast(ido) * out_strides[4] + + static_cast(iho) * out_strides[3] + + static_cast(iwo) * out_strides[2]; value += cast_to(p_in[i_idx]) * cast_to(p_out[o_idx]);