Skip to content

Commit

Permalink
fix: treat input and output as 6D NDHWGC tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
amberhassaan committed Sep 4, 2023
1 parent 568db6f commit 9ce51ec
Showing 1 changed file with 64 additions and 49 deletions.
113 changes: 64 additions & 49 deletions src/kernels/gpu_reference_kernel/naive_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,9 +1039,9 @@ template <typename src_data_t, typename acc_data_t, typename dst_data_t>
inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restrict__ p_in,
const src_data_t* __restrict__ p_wei,
dst_data_t* __restrict__ p_out,
Strides5D in_strides,
Strides6D in_strides,
Strides6D wei_strides,
Strides5D out_strides,
Strides6D out_strides,
int di,
int hi,
int wi,
Expand Down Expand Up @@ -1080,18 +1080,22 @@ inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restri
int in = (bid / do_) % n;
int ig = bid / (n * do_);

// TODO: what to do with ig * c_per-group and similarly ig * k_per_group
// p_in += static_cast<size_t>(in) * di * hi * wi * c + static_cast<size_t>(ig) * c_per_group;
//
// p_wei += static_cast<size_t>(ig) * k_per_group * fz * fy * fx * c_per_group;
//
// p_out += static_cast<size_t>(in) * do_ * ho * wo * k + static_cast<size_t>(ido) * ho * wo * k
// + static_cast<size_t>(ig) * k_per_group;

p_in += static_cast<size_t>(in) * in_strides[4] + static_cast<size_t>(ig) * c_per_group;
// assumes that group G is the highest dimension in the layout
// dim order NDHWGC
// replace C and K with G * C_per_G and G * K_per_G
p_in += static_cast<size_t>(in) * in_strides[5] + static_cast<size_t>(ig) * in_strides[1];

// Assumes that group G is the highest dimension in the layout
p_wei += static_cast<size_t>(ig) * wei_strides[5];

// p_out += static_cast<size_t>(in) * do_ * ho * wo * k + static_cast<size_t>(ido) * ho * wo * k
// + static_cast<size_t>(ig) * k_per_group;
p_out +=
static_cast<size_t>(in) * out_strides[4] + static_cast<size_t>(ido) * out_strides[3] + ? ? ;
p_out += static_cast<size_t>(in) * out_strides[5] + static_cast<size_t>(ido) * out_strides[4] +
static_cast<size_t>(ig) * out_strides[1];

for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x)
{
Expand Down Expand Up @@ -1123,19 +1127,19 @@ inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restri
{
if(valid_d & valid_w & valid_h)
{
/*
size_t i_idx = static_cast<size_t>(cur_d) * hi * wi * c +
static_cast<size_t>(cur_h) * wi * c +
static_cast<size_t>(cur_w) * c + static_cast<size_t>(ic);
size_t f_idx = static_cast<size_t>(ik) * fz * fy * fx * c_per_group +
static_cast<size_t>(iz) * fy * fx * c_per_group +
static_cast<size_t>(iy) * fx * c_per_group +
static_cast<size_t>(ix) * c_per_group +
static_cast<size_t>(ic);
*/
size_t i_idx = static_cast<size_t>(cur_d) * in_strides[3] +
static_cast<size_t>(cur_h) * in_strides[2] +
static_cast<size_t>(cur_w) * in_strides[1] +
// size_t i_idx = static_cast<size_t>(cur_d) * hi * wi * c +
// static_cast<size_t>(cur_h) * wi * c +
// static_cast<size_t>(cur_w) * c + static_cast<size_t>(ic);
//
// size_t f_idx = static_cast<size_t>(ik) * fz * fy * fx * c_per_group +
// static_cast<size_t>(iz) * fy * fx * c_per_group +
// static_cast<size_t>(iy) * fx * c_per_group +
// static_cast<size_t>(ix) * c_per_group +
// static_cast<size_t>(ic);

size_t i_idx = static_cast<size_t>(cur_d) * in_strides[4] +
static_cast<size_t>(cur_h) * in_strides[3] +
static_cast<size_t>(cur_w) * in_strides[2] +
static_cast<size_t>(ic) * in_strides[0];

size_t f_idx = static_cast<size_t>(ik) * wei_strides[4] +
Expand All @@ -1153,8 +1157,8 @@ inline __device__ void naive_conv_fwd_ndhwc_nonpacked(const src_data_t* __restri
}
// size_t o_idx = static_cast<size_t>(iho) * wo * k + static_cast<size_t>(iwo) * k +
// static_cast<size_t>(ik);
size_t o_idx = static_cast<size_t>(iho) * out_strides[2] +
static_cast<size_t>(iwo) * out_strides[1] +
size_t o_idx = static_cast<size_t>(iho) * out_strides[3] +
static_cast<size_t>(iwo) * out_strides[2] +
static_cast<size_t>(ik) * out_strides[0];
p_out[o_idx] = cast_to<acc_data_t, dst_data_t>(value);
}
Expand Down Expand Up @@ -1272,9 +1276,9 @@ template <typename src_data_t, typename acc_data_t, typename dst_data_t>
inline __device__ void naive_conv_bwd_ndhwc_nonpacked(dst_data_t* __restrict__ p_in,
const src_data_t* __restrict__ p_wei,
const src_data_t* __restrict__ p_out,
Strides5D in_strides,
Strides6D in_strides,
Strides6D wei_strides,
Strides5D out_strides,
Strides6D out_strides,
int di,
int hi,
int wi,
Expand Down Expand Up @@ -1315,14 +1319,17 @@ inline __device__ void naive_conv_bwd_ndhwc_nonpacked(dst_data_t* __restrict__ p

// p_in += static_cast<size_t>(in) * di * hi * wi * c + static_cast<size_t>(idi) * hi * wi * c +
// static_cast<size_t>(ig) * c_per_group;
//
// p_wei += static_cast<size_t>(ig) * k_per_group * fz * fy * fx * c_per_group;
//
// p_out += static_cast<size_t>(in) * do_ * ho * wo * k + static_cast<size_t>(ig) * k_per_group;

// TODO(Amber): figure out the last term in p_in related to c_per_group
p_in += static_cast<size_t>(in) * in_strides[4] + static_cast<size_t>(idi) * in_strides[3] +
static_cast<size_t>(ig) * c_per_group;
p_in += static_cast<size_t>(in) * in_strides[5] + static_cast<size_t>(idi) * in_strides[4] +
static_cast<size_t>(ig) * in_strides[1];

p_wei += static_cast<size_t>(ig) * in_strides[5];
p_out += static_cast<size_t>(in) * out_strides[4] + static_cast<size_t>(ig) * k_per_group;

p_out += static_cast<size_t>(in) * out_strides[5] + static_cast<size_t>(ig) * out_strides[1];

for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x)
{
Expand Down Expand Up @@ -1367,29 +1374,34 @@ inline __device__ void naive_conv_bwd_ndhwc_nonpacked(dst_data_t* __restrict__ p
// static_cast<size_t>(cur_ho) * wo * k +
// static_cast<size_t>(cur_wo) * k +
// static_cast<size_t>(ik);
//
// size_t f_idx = static_cast<size_t>(ik) * fz * fy * fx * c_per_group +
// static_cast<size_t>(iz) * fy * fx * c_per_group +
// static_cast<size_t>(iy) * fx * c_per_group +
// static_cast<size_t>(ix) * c_per_group +
// static_cast<size_t>(ic);
size_t o_idx = static_cast<size_t>(cur_do) * out_strides[3] +
static_cast<size_t>(cur_ho) * out_strides[2] +
static_cast<size_t>(cur_wo) * out_strides[1] +
size_t o_idx = static_cast<size_t>(cur_do) * out_strides[4] +
static_cast<size_t>(cur_ho) * out_strides[3] +
static_cast<size_t>(cur_wo) * out_strides[2] +
static_cast<size_t>(ik) * out_strides[0];

size_t f_idx = static_cast<size_t>(ik) * wei_strides[4] +
static_cast<size_t>(iz) * wei_strides[3] +
static_cast<size_t>(iy) * wei_strides[2] +
static_cast<size_t>(ix) * wei_strides[1] +
static_cast<size_t>(ic) * wei_strides[0];

value += cast_to<src_data_t, acc_data_t>(p_out[o_idx]) *
cast_to<src_data_t, acc_data_t>(p_wei[f_idx]);
}
}
}
}
}
size_t i_idx = static_cast<size_t>(ihi) * in_strides[2] +
static_cast<size_t>(iwi) * in_strides[1] +
// size_t i_idx = static_cast<size_t>(ihi) * wi * c + static_cast<size_t>(iwi) * c +
// static_cast<size_t>(ic);
size_t i_idx = static_cast<size_t>(ihi) * in_strides[3] +
static_cast<size_t>(iwi) * in_strides[2] +
static_cast<size_t>(ic) * in_strides[0];
p_in[i_idx] = cast_to<acc_data_t, dst_data_t>(value);
}
Expand Down Expand Up @@ -1499,9 +1511,9 @@ template <typename src_data_t, typename acc_data_t, typename dst_data_t>
inline __device__ void naive_conv_wrw_ndhwc_nonpacked(const src_data_t* __restrict__ p_in,
dst_data_t* __restrict__ p_wei,
const src_data_t* __restrict__ p_out,
Strides5D in_strides,
Strides6D in_strides,
Strides6D wei_strides,
Strides5D out_strides,
Strides6D out_strides,
int di,
int hi,
int wi,
Expand Down Expand Up @@ -1540,15 +1552,17 @@ inline __device__ void naive_conv_wrw_ndhwc_nonpacked(const src_data_t* __restri
int ig = bid / k_per_group;

// p_in += static_cast<size_t>(ig) * c_per_group;
//
// p_wei += static_cast<size_t>(ig) * k_per_group * fz * fy * fx * c_per_group +
// static_cast<size_t>(ik) * fz * fy * fx * c_per_group;
//
// p_out += static_cast<size_t>(ig) * k_per_group + static_cast<size_t>(ik);

// TODO(amber): c_per_group issue
p_in += static_cast<size_t>(ig) * c_per_group;
p_in += static_cast<size_t>(ig) * in_strides[1];

p_wei += static_cast<size_t>(ig) * wei_strides[5] + static_cast<size_t>(ik) * wei_strides[4];
// TODO(amber): k_per_group issue same as c_per_group
p_out += static_cast<size_t>(ig) * k_per_group + static_cast<size_t>(ik) * out_strides[0];

p_out += static_cast<size_t>(ig) * out_strides[1] + static_cast<size_t>(ik) * out_strides[0];

for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x)
{
Expand Down Expand Up @@ -1587,21 +1601,22 @@ inline __device__ void naive_conv_wrw_ndhwc_nonpacked(const src_data_t* __restri
// static_cast<size_t>(cur_d) * hi * wi * c +
// static_cast<size_t>(cur_h) * wi * c +
// static_cast<size_t>(cur_w) * c + static_cast<size_t>(ic);
//
// size_t o_idx = static_cast<size_t>(in) * do_ * ho * wo * k +
// static_cast<size_t>(ido) * ho * wo * k +
// static_cast<size_t>(iho) * wo * k +
// static_cast<size_t>(iwo) * k;

size_t i_idx = static_cast<size_t>(in) * in_strides[4] +
static_cast<size_t>(cur_d) * in_strides[3] +
static_cast<size_t>(cur_h) * in_strides[2] +
static_cast<size_t>(cur_w) * in_strides[1] +
size_t i_idx = static_cast<size_t>(in) * in_strides[5] +
static_cast<size_t>(cur_d) * in_strides[4] +
static_cast<size_t>(cur_h) * in_strides[3] +
static_cast<size_t>(cur_w) * in_strides[2] +
static_cast<size_t>(ic) * in_strides[0];

size_t o_idx = static_cast<size_t>(in) * out_strides[4] +
static_cast<size_t>(ido) * out_strides[3] +
static_cast<size_t>(iho) * out_strides[2] +
static_cast<size_t>(iwo) * out_strides[1];
size_t o_idx = static_cast<size_t>(in) * out_strides[5] +
static_cast<size_t>(ido) * out_strides[4] +
static_cast<size_t>(iho) * out_strides[3] +
static_cast<size_t>(iwo) * out_strides[2];

value += cast_to<src_data_t, acc_data_t>(p_in[i_idx]) *
cast_to<src_data_t, acc_data_t>(p_out[o_idx]);
Expand Down

0 comments on commit 9ce51ec

Please sign in to comment.