
3D forward convolution solver with non-packed input tensors #2418

Merged: 51 commits, Oct 12, 2023 (diff shown from 47 commits)

Commits
6e37c43  Squash commits together (amberhassaan, Aug 20, 2023)
0db63b7  redo nonpack solver (iq136boy, Sep 20, 2023)
f9d542a  add gtest for nonpacked tensor solver (iq136boy, Sep 21, 2023)
94f5fd3  Squash commits together (amberhassaan, Aug 20, 2023)
5ab82f3  fix formatting. disable strides for fp8 kernel for now (amberhassaan, Sep 23, 2023)
be93522  fix the lengths of weight tensor (amberhassaan, Sep 23, 2023)
3762672  add new kernel to test non-packed tensor (iq136boy, Sep 23, 2023)
603fa64  solve conflict (iq136boy, Sep 25, 2023)
d85785b  use 64-bit integers for stride value (amberhassaan, Sep 25, 2023)
ee6abb3  Squash commits together (amberhassaan, Aug 20, 2023)
4fbcd77  fix test for non-packed strides (amberhassaan, Sep 18, 2023)
82e0ccf  fix format (amberhassaan, Sep 18, 2023)
7e8a258  Fix assertion check. (amberhassaan, Sep 18, 2023)
cadfb95  suppress cppcheck warning to test CI (junliume, Sep 20, 2023)
cde6e22  fix build and remove a check that prevents non-strided inputs (amberhassaan, Sep 25, 2023)
9c371eb  merge pr2334 (iq136boy, Sep 25, 2023)
f7b606b  Merge remote-tracking branch 'origin/develop' into amber/non-packed-c… (amberhassaan, Sep 25, 2023)
0ad674b  Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa… (amberhassaan, Sep 25, 2023)
01b26cd  all gtest passed locally (iq136boy, Sep 25, 2023)
f633030  minor fix gtest (iq136boy, Sep 25, 2023)
33e251d  clean debug info (iq136boy, Sep 25, 2023)
4db6cf8  Merge remote-tracking branch 'origin/develop' into amber/non-packed-c… (amberhassaan, Sep 25, 2023)
8af6d47  Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa… (amberhassaan, Sep 25, 2023)
b95806f  resolve conflict (iq136boy, Sep 26, 2023)
85f415c  Merge branch 'develop' into dfeng/ck_nonpack_conv3d_fwd (iq136boy, Sep 26, 2023)
ec56121  bug fix after merge develop (iq136boy, Sep 26, 2023)
e06c523  addressed comments. Moved common code into an include file (amberhassaan, Sep 26, 2023)
23d0066  Merge remote-tracking branch 'origin/develop' into amber/non-packed-c… (amberhassaan, Sep 26, 2023)
35c9072  Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa… (amberhassaan, Sep 26, 2023)
67d9a77  address comments (amberhassaan, Sep 26, 2023)
0f16c62  address review comments (amberhassaan, Sep 28, 2023)
c66da71  Merge remote-tracking branch 'origin/develop' into amber/non-packed-c… (amberhassaan, Sep 28, 2023)
2f9867d  combine 3d fwd packed and non-packed solvers (iq136boy, Sep 28, 2023)
8bb3a7f  Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa… (amberhassaan, Sep 28, 2023)
dec88a7  add more checks for strides (amberhassaan, Sep 29, 2023)
1c45049  disable 3d wrw solver on gfx900 (iq136boy, Sep 30, 2023)
700b623  Merge remote-tracking branch 'origin/develop' into amber/non-packed-c… (amberhassaan, Oct 2, 2023)
1253aed  Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa… (amberhassaan, Oct 2, 2023)
bdfb4d2  Merge branch 'amber/tests-non-packed-conv' into dfeng/ck_nonpack_conv… (amberhassaan, Oct 2, 2023)
85e8a62  fix test now that strides are supported (amberhassaan, Oct 3, 2023)
e2d7f1d  Merge branch 'amber/tests-non-packed-conv' into dfeng/ck_nonpack_conv… (amberhassaan, Oct 3, 2023)
39eee97  use C++17 to compile HIP Kernels (amberhassaan, Oct 4, 2023)
f60f182  Merge remote-tracking branch 'origin/develop' into amber/non-packed-c… (amberhassaan, Oct 4, 2023)
e7c9d03  Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa… (amberhassaan, Oct 4, 2023)
abac92d  Merge branch 'amber/tests-non-packed-conv' into dfeng/ck_nonpack_conv… (amberhassaan, Oct 4, 2023)
b20ae12  Merge branch 'develop' into dfeng/ck_nonpack_conv3d_fwd (junliume, Oct 8, 2023)
d0f57da  bug fix after merge develop (iq136boy, Oct 9, 2023)
936282c  merge develop to resolve conflict (iq136boy, Oct 11, 2023)
8091d1c  address comments (iq136boy, Oct 11, 2023)
fbbaebc  Merge branch 'develop' into dfeng/ck_nonpack_conv3d_fwd (junliume, Oct 12, 2023)
8c49fc2  resolve merge issue (junliume, Oct 12, 2023)
1 change: 1 addition & 0 deletions driver/random.hpp
@@ -91,6 +91,7 @@ inline T gen_0_to_B(T B)
template <typename T>
inline T gen_A_to_B(T A, T B)
{
assert(B > A);
return gen_0_to_B(B - A) + A;
}

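For readers skimming the diff: the new assert guards a precondition the helper has always relied on. A minimal self-contained sketch of the behavior, with gen_0_to_B simplified here (the driver's real RNG differs):

```cpp
// Sketch only: a simplified stand-in for the driver's gen_0_to_B so the
// guarded helper can be exercised standalone.
#include <cassert>
#include <cstdlib>

template <typename T>
inline T gen_0_to_B(T B)
{
    return static_cast<T>(std::rand() % B); // placeholder RNG, range [0, B)
}

template <typename T>
inline T gen_A_to_B(T A, T B)
{
    assert(B > A); // the check this PR adds: an empty/inverted range is a bug
    return gen_0_to_B(B - A) + A;
}

int main()
{
    int v = gen_A_to_B(3, 10); // fine: v lies in [3, 10)
    // gen_A_to_B(10, 3) would now trip the assert in debug builds instead of
    // silently feeding a negative range to gen_0_to_B.
    return (v >= 3 && v < 10) ? 0 : 1;
}
```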
4 changes: 4 additions & 0 deletions src/include/miopen/convolution.hpp
@@ -36,6 +36,7 @@
#include <miopen/names.hpp>
#include <miopen/invoke_params.hpp>
#include <miopen/invoker.hpp>
#include <miopen/conv/tensors.hpp>

#include <nlohmann/json_fwd.hpp>

@@ -404,6 +405,9 @@ struct ConvolutionDescriptor : miopenConvolutionDescriptor

friend void to_json(nlohmann::json& json, const ConvolutionDescriptor& conv);
friend void from_json(const nlohmann::json& json, ConvolutionDescriptor& conv);

private:
void ValidateConvTensors(const ConvTensors& conv_tensors) const;
};

void ConvolutionBackwardBias(const Handle& handle,
108 changes: 78 additions & 30 deletions src/ocl/convolutionocl.cpp
@@ -287,31 +287,6 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle,

namespace {

void ValidateConvTensors(const ConvTensors& tensors)
{
const auto invalid_buffers =
tensors.x == nullptr || tensors.w == nullptr || tensors.y == nullptr;

const auto tensor_sizes_not_matched = tensors.xDesc.GetSize() != tensors.yDesc.GetSize() ||
tensors.xDesc.GetSize() != tensors.wDesc.GetSize();

const auto trivial_tensor_types_not_matched =
tensors.xDesc.GetType() != tensors.yDesc.GetType() &&
tensors.xDesc.GetType() != miopenInt8 && tensors.xDesc.GetType() != miopenInt8x4;

// if(xDesc.GetLengths()[1] != wDesc.GetLengths()[1]) {
// MIOPEN_THROW(miopenStatusBadParm);
//}

const auto x_tensor_invalid = tensors.xDesc.GetSize() < 3;

const auto bad_parameters = invalid_buffers || tensor_sizes_not_matched ||
trivial_tensor_types_not_matched || x_tensor_invalid;

if(bad_parameters)
MIOPEN_THROW(miopenStatusBadParm);
}

void ValidateAlphaBeta(const void* alpha, const void* beta)
{
if(!float_equal(*(static_cast<const float*>(alpha)), 1.0) ||
@@ -402,6 +377,84 @@ static void ConvForwardCheckNumerics(const Handle& handle,
}
}

void ConvolutionDescriptor::ValidateConvTensors(const ConvTensors& tensors) const
{

// Group stride in current TensorDescriptor is implicit. When invoking kernels,
// we need to add the group dimension G and compute its stride. We want the stride
// left of C to be a multiple of group count G. e.g. for NCHW, the stride for N
// should be a multiple of G so that we can compute the strides for NGCHW
auto bad_group_stride = [this](const TensorDescriptor& td) {
auto l = td.GetLayout_t();
int g_stride_index = -1;
if(l == miopenTensorNCHW || l == miopenTensorNCDHW)
{
g_stride_index = 0; // stride index for N;
}
else if(l == miopenTensorNHWC || l == miopenTensorNDHWC)
{
// stride index for W. Normally this would be 2nd-last stride but we store
// strides in NCHW order for some weird reason.
g_stride_index = td.GetStrides().size() - 1;
}

if(g_stride_index != -1)
{
return (td.GetStrides()[g_stride_index] % this->group_count) != 0;
}

return false;
};

// invalid_buffers
if(tensors.x == nullptr || tensors.w == nullptr || tensors.y == nullptr)
{
MIOPEN_THROW(miopenStatusBadParm, "One of the convolution tensors is null");
}

// x_tensor_invalid =
if(tensors.xDesc.GetSize() < 3)
{
MIOPEN_THROW(miopenStatusBadParm, "input tensor's number of dimensions is wrong");
}

// tensor_sizes_not_matched =
if(tensors.xDesc.GetSize() != tensors.yDesc.GetSize() ||
tensors.xDesc.GetSize() != tensors.wDesc.GetSize())
{
MIOPEN_THROW(miopenStatusBadParm,
"number of dimensions mismatch between input, output and weights tensors");
}

// trivial_tensor_types_not_matched =
if(tensors.xDesc.GetType() != tensors.yDesc.GetType() &&
tensors.xDesc.GetType() != miopenInt8 && tensors.xDesc.GetType() != miopenInt8x4)
{
MIOPEN_THROW(miopenStatusBadParm, "input/output tensor data types do not match");
}

// check for bad_group_stride. This applies for input and output only. There
// is no check for weight tensor currently.
// no need to check for group_count == 1

if((this->group_count > 1) && bad_group_stride(tensors.xDesc))
{
MIOPEN_THROW(
miopenStatusBadParm,
"Invalid input tensor strides. Channel stride must be a multiple of group count");
}
if((this->group_count > 1) && bad_group_stride(tensors.yDesc))
{
MIOPEN_THROW(
miopenStatusBadParm,
"Invalid output tensor strides. Channel stride must be a multiple of group count");
}

// if(xDesc.GetLengths()[1] != wDesc.GetLengths()[1]) {
// MIOPEN_THROW(miopenStatusBadParm);
//}
}

void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
const void* alpha,
const TensorDescriptor& xDesc,
@@ -417,11 +470,6 @@ void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);

if(!(xDesc.IsPacked() && wDesc.IsPacked() && yDesc.IsPacked()))
{
MIOPEN_THROW(miopenStatusNotImplemented, "Only fully packed tensors are supported");
}

const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};
ValidateConvTensors(tensors);
ValidateAlphaBeta(alpha, beta);
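The group-stride comment in the new ValidateConvTensors is easiest to see with numbers. A minimal sketch, assuming (as the comment suggests) that the G stride for NGCHW is derived by dividing the stride immediately left of C by the group count; the helper name is illustrative, not MIOpen API:

```cpp
// Sketch of the NCHW -> NGCHW stride split that the divisibility check
// protects. SplitNCHWStridesIntoGroups is illustrative, not MIOpen API.
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> SplitNCHWStridesIntoGroups(const std::vector<std::size_t>& s,
                                                    std::size_t G)
{
    // s = {N_stride, C_stride, H_stride, W_stride}
    if(s.size() != 4)
        throw std::invalid_argument("expected 4 NCHW strides");
    if(s[0] % G != 0) // the rule ValidateConvTensors enforces for NCHW/NCDHW
        throw std::invalid_argument("N stride must be a multiple of group count");
    const std::size_t g_stride = s[0] / G; // one group's slice of the N step
    return {s[0], g_stride, s[1], s[2], s[3]}; // {N, G, C_per_group, H, W}
}

// Example: a packed NCHW tensor with C=8, H=W=4 has strides {128, 16, 4, 1}.
// With G=2 the NGCHW strides become {128, 64, 16, 4, 1}; a non-packed
// N stride of 130 would still pass the % 2 check, but 129 would be rejected.
```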
11 changes: 5 additions & 6 deletions src/solver/conv_direct_naive_conv_bwd.cpp
@@ -134,12 +134,8 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx,
kernel.l_wk.push_back(1);
kernel.l_wk.push_back(1);

const auto is_f8 = [&]() {
if(kernel.kernel_file == "fp8_naive_conv.cpp")
return true;
else
return false;
}();
const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");

kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);

int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);
@@ -166,6 +162,9 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx,
handle.Run(kern)(tensors.out,
tensors.w,
tensors.in,
out_strides,
wei_strides,
in_strides,
hi,
wi,
n,
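The same three stride arguments are added to the fwd and wrw invocations in the next two files. A hedged sketch of the packaging step, with MakeStrideArray and its types as illustrative names rather than MIOpen's actual helpers:

```cpp
// Illustrative only: pack a tensor's strides into a fixed-size 64-bit array
// so a HIP kernel can address non-packed tensors (64-bit per commit d85785b).
#include <array>
#include <cstdint>
#include <vector>

template <std::size_t N>
std::array<std::uint64_t, N> MakeStrideArray(const std::vector<std::size_t>& strides)
{
    std::array<std::uint64_t, N> out{}; // zero-filled; unused trailing slots stay 0
    for(std::size_t i = 0; i < strides.size() && i < N; ++i)
        out[i] = static_cast<std::uint64_t>(strides[i]);
    return out;
}

// Usage sketch: build one array per tensor, then pass them alongside the
// buffers, mirroring handle.Run(kern)(tensors.out, tensors.w, tensors.in,
// out_strides, wei_strides, in_strides, ...) in the hunk above.
```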
11 changes: 5 additions & 6 deletions src/solver/conv_direct_naive_conv_fwd.cpp
@@ -122,12 +122,6 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
KernelInfo kernel;

kernel.kernel_file = ConvDirectNaiveConvKernelFile(ctx, problem);
const auto is_f8 = [&]() {
if(kernel.kernel_file == "fp8_naive_conv.cpp")
return true;
else
return false;
}();
kernel.kernel_name = ConvDirectNaiveConvKernelName(problem);
kernel.g_wk.clear();

@@ -139,6 +133,8 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
kernel.l_wk.push_back(1);
kernel.l_wk.push_back(1);

const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");

kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);

int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);
@@ -166,6 +162,9 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
handle.Run(kern)(tensors.in,
tensors.w,
tensors.out,
in_strides,
wei_strides,
out_strides,
hi,
wi,
n,
11 changes: 5 additions & 6 deletions src/solver/conv_direct_naive_conv_wrw.cpp
@@ -121,13 +121,9 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx,
kernel.l_wk.push_back(1);
kernel.l_wk.push_back(1);

const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");

kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);
const auto is_f8 = [&]() {
if(kernel.kernel_file == "fp8_naive_conv.cpp")
return true;
else
return false;
}();

int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);

@@ -154,6 +150,9 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx,
handle.Run(kern)(tensors.x,
tensors.dw,
tensors.dy,
in_strides,
wei_strides,
out_strides,
hi,
wi,
n,
14 changes: 10 additions & 4 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
@@ -86,10 +86,16 @@ struct CKArgs
output = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

// strides from NHWGC to GNCHW layout
in_strides = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
// miopen strides to CK strides
auto miopen_in_strides = problem.GetIn().GetStrides();
auto miopen_out_strides = problem.GetOut().GetStrides();
auto miopen_wei_strides = problem.GetWeights().GetStrides();
miopen_in_strides.insert(miopen_in_strides.begin(), C);
miopen_out_strides.insert(miopen_out_strides.begin(), K);
miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]);
std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin());
std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin());
std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin());

strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem),
ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),
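The same conversion appears in the fwd and wrw CK solvers below. To make the hunk's intent explicit: instead of recomputing packed strides from the tensor lengths, the solver now takes the descriptor's actual strides (five values in NCDHW order) and prepends a group stride to match CK's six-dimension view. A sketch under those assumptions:

```cpp
// Sketch of the MIOpen -> CK stride conversion above. Assumes five MIOpen
// strides in NCDHW order and a CK layout that wants the G stride first.
#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

std::array<std::size_t, 6> ToCKStrides(std::vector<std::size_t> miopen_strides,
                                       std::size_t g_stride)
{
    miopen_strides.insert(miopen_strides.begin(), g_stride); // prepend G
    std::array<std::size_t, 6> ck{};
    std::copy(miopen_strides.begin(), miopen_strides.end(), ck.begin());
    return ck;
}

// For a packed NDHWC input, the NCDHW-order strides are
// {Di*Hi*Wi*G*C, 1, Hi*Wi*G*C, Wi*G*C, G*C}; prepending C (one group's
// channel block) reproduces the old hard-coded in_strides exactly, while a
// non-packed descriptor now flows through unchanged.
```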
16 changes: 11 additions & 5 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp
@@ -86,10 +86,16 @@ struct CKArgs
output = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

// strides from NHWGC to GNCHW layout
in_strides = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
// miopen strides to CK strides
auto miopen_in_strides = problem.GetIn().GetStrides();
auto miopen_out_strides = problem.GetOut().GetStrides();
auto miopen_wei_strides = problem.GetWeights().GetStrides();
miopen_in_strides.insert(miopen_in_strides.begin(), C);
miopen_out_strides.insert(miopen_out_strides.begin(), K);
miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]);
std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin());
std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin());
std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin());

strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem),
ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),
@@ -314,7 +320,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(
if(!problem.IsLayoutNHWC())
return false;
const std::string& arch = ctx.GetStream().GetDeviceName();
if(!(arch == "gfx908" || arch == "gfx90a"))
if(miopen::StartsWith(arch, "gfx11") || miopen::StartsWith(arch, "gfx10"))
Review thread on this line:

Collaborator: I would suggest we convert this to a whitelist, i.e. list the arches that CK supports.

Contributor: Good suggestion, and it should be used in all CK-based solvers.

Author: Great idea. I have applied the change to all CK-based solvers and am currently running local tests before pushing the changes.

return false;
switch(problem.GetInDataType())
{
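A minimal sketch of the whitelist shape the reviewers asked for; the helper name and arch set below are assumptions for illustration, not what this PR ships:

```cpp
// Hypothetical whitelist per the review thread. IsCKSupportedArch and the
// gfx9-class assumption are illustrative; the real list belongs with CK.
#include <string>

inline bool IsCKSupportedArch(const std::string& arch)
{
    auto starts_with = [](const std::string& s, const char* p) {
        return s.rfind(p, 0) == 0;
    };
    // Assumption: only gfx908/gfx90a/gfx94x-class devices are whitelisted.
    return starts_with(arch, "gfx908") || starts_with(arch, "gfx90a") ||
           starts_with(arch, "gfx94");
}

// IsApplicable would then begin with:
//   if(!IsCKSupportedArch(ctx.GetStream().GetDeviceName()))
//       return false;
```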
14 changes: 10 additions & 4 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
@@ -84,10 +84,16 @@ struct CKArgs
output = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

// strides from NHWGC to GNCHW layout
in_strides = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
// miopen strides to CK strides
auto miopen_in_strides = problem.GetIn().GetStrides();
auto miopen_out_strides = problem.GetOut().GetStrides();
auto miopen_wei_strides = problem.GetWeights().GetStrides();
miopen_in_strides.insert(miopen_in_strides.begin(), C);
miopen_out_strides.insert(miopen_out_strides.begin(), K);
miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]);
std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin());
std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin());
std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin());

strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem),
ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),