3D forward convolution solver with non-packed input tensors #2418

Merged Oct 12, 2023 · 51 commits · changes shown from 49 commits
6e37c43
Squash commits together
amberhassaan Aug 20, 2023
0db63b7
redo nonpack solver
iq136boy Sep 20, 2023
f9d542a
add gtest for nonpacked tensor solver
iq136boy Sep 21, 2023
94f5fd3
Squash commits together
amberhassaan Aug 20, 2023
5ab82f3
fix formatting. disable strides for fp8 kernel for now
amberhassaan Sep 23, 2023
be93522
fix the lengths of weight tensor
amberhassaan Sep 23, 2023
3762672
add new kernel to test non-packed tensor
iq136boy Sep 23, 2023
603fa64
solve conflict
iq136boy Sep 25, 2023
d85785b
use 64-bit integers for stride value
amberhassaan Sep 25, 2023
ee6abb3
Squash commits together
amberhassaan Aug 20, 2023
4fbcd77
fix test for non-packed strides
amberhassaan Sep 18, 2023
82e0ccf
fix format
amberhassaan Sep 18, 2023
7e8a258
Fix assertion check.
amberhassaan Sep 18, 2023
cadfb95
suppress cppcheck warning to test CI
junliume Sep 20, 2023
cde6e22
fix build and remove a check that prevents non-strided inputs
amberhassaan Sep 25, 2023
9c371eb
merge pr2334
iq136boy Sep 25, 2023
f7b606b
Merge remote-tracking branch 'origin/develop' into amber/non-packed-c…
amberhassaan Sep 25, 2023
0ad674b
Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa…
amberhassaan Sep 25, 2023
01b26cd
all gtest passed locally
iq136boy Sep 25, 2023
f633030
minor fix gtest
iq136boy Sep 25, 2023
33e251d
clean debug info
iq136boy Sep 25, 2023
4db6cf8
Merge remote-tracking branch 'origin/develop' into amber/non-packed-c…
amberhassaan Sep 25, 2023
8af6d47
Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa…
amberhassaan Sep 25, 2023
b95806f
resolve conflict
iq136boy Sep 26, 2023
85f415c
Merge branch 'develop' into dfeng/ck_nonpack_conv3d_fwd
iq136boy Sep 26, 2023
ec56121
bug fix after merge develop
iq136boy Sep 26, 2023
e06c523
addressed comments. Moved common code into an include file
amberhassaan Sep 26, 2023
23d0066
Merge remote-tracking branch 'origin/develop' into amber/non-packed-c…
amberhassaan Sep 26, 2023
35c9072
Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa…
amberhassaan Sep 26, 2023
67d9a77
address comments
amberhassaan Sep 26, 2023
0f16c62
address review comments
amberhassaan Sep 28, 2023
c66da71
Merge remote-tracking branch 'origin/develop' into amber/non-packed-c…
amberhassaan Sep 28, 2023
2f9867d
combine 3d fwd packed and non-packed solvers
iq136boy Sep 28, 2023
8bb3a7f
Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa…
amberhassaan Sep 28, 2023
dec88a7
add more checks for strides
amberhassaan Sep 29, 2023
1c45049
disable 3d wrw solver on gfx900
iq136boy Sep 30, 2023
700b623
Merge remote-tracking branch 'origin/develop' into amber/non-packed-c…
amberhassaan Oct 2, 2023
1253aed
Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa…
amberhassaan Oct 2, 2023
bdfb4d2
Merge branch 'amber/tests-non-packed-conv' into dfeng/ck_nonpack_conv…
amberhassaan Oct 2, 2023
85e8a62
fix test now that strides are supported
amberhassaan Oct 3, 2023
e2d7f1d
Merge branch 'amber/tests-non-packed-conv' into dfeng/ck_nonpack_conv…
amberhassaan Oct 3, 2023
39eee97
use C++17 to compile HIP Kernels
amberhassaan Oct 4, 2023
f60f182
Merge remote-tracking branch 'origin/develop' into amber/non-packed-c…
amberhassaan Oct 4, 2023
e7c9d03
Merge branch 'amber/non-packed-conv-ref-kern' into amber/tests-non-pa…
amberhassaan Oct 4, 2023
abac92d
Merge branch 'amber/tests-non-packed-conv' into dfeng/ck_nonpack_conv…
amberhassaan Oct 4, 2023
b20ae12
Merge branch 'develop' into dfeng/ck_nonpack_conv3d_fwd
junliume Oct 8, 2023
d0f57da
bug fix after merge develop
iq136boy Oct 9, 2023
936282c
merge develop to resolve conflict
iq136boy Oct 11, 2023
8091d1c
address comments
iq136boy Oct 11, 2023
fbbaebc
Merge branch 'develop' into dfeng/ck_nonpack_conv3d_fwd
junliume Oct 12, 2023
8c49fc2
resolve merge issue
junliume Oct 12, 2023
1 change: 1 addition & 0 deletions driver/random.hpp
@@ -91,6 +91,7 @@ inline T gen_0_to_B(T B)
template <typename T>
inline T gen_A_to_B(T A, T B)
{
+    assert(B > A);
return gen_0_to_B(B - A) + A;
}

10 changes: 10 additions & 0 deletions src/include/miopen/solver/ck_utility_common.hpp
@@ -63,6 +63,16 @@ static inline bool is_ck_supported_hardware(const Handle& handle)
StartsWith(handle.GetDeviceName(), "gfx1102");
}

+static inline bool is_conv_ck_supported_hardware(const std::string& device_name, bool is_wrw)
+{
+    auto res_wrw = StartsWith(device_name, "gfx908") || StartsWith(device_name, "gfx90a") ||
+                   StartsWith(device_name, "gfx940") || StartsWith(device_name, "gfx941") ||
+                   StartsWith(device_name, "gfx942");
+    return is_wrw ? res_wrw
+                  : (res_wrw || StartsWith(device_name, "gfx900") ||
+                     StartsWith(device_name, "gfx906"));
+}

Reviewer comment (Contributor): This is just a recommendation: place StartsWith(device_name, "gfx900") || StartsWith(device_name, "gfx906") in a separate bool function. That way we don't need to pass is_wrw.

static inline bool is_support_amd_buffer_atomic_fadd(const std::string& device_name)
{
return StartsWith(device_name, "gfx908");
5 changes: 0 additions & 5 deletions src/ocl/convolutionocl.cpp
@@ -416,11 +416,6 @@ void ConvolutionDescriptor::ConvolutionForward(Handle& handle,
{
MIOPEN_LOG_I("algo = " << algo << ", workspace = " << workSpaceSize);

-    if(!(xDesc.IsPacked() && wDesc.IsPacked() && yDesc.IsPacked()))
-    {
-        MIOPEN_THROW(miopenStatusNotImplemented, "Only fully packed tensors are supported");
-    }

const auto tensors = ConvFwdTensors{xDesc, x, wDesc, w, yDesc, y};
ValidateConvTensors(tensors);
ValidateAlphaBeta(alpha, beta);
11 changes: 5 additions & 6 deletions src/solver/conv_direct_naive_conv_bwd.cpp
@@ -134,12 +134,8 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx,
kernel.l_wk.push_back(1);
kernel.l_wk.push_back(1);

-    const auto is_f8 = [&]() {
-        if(kernel.kernel_file == "fp8_naive_conv.cpp")
-            return true;
-        else
-            return false;
-    }();
+    const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");

kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);

int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);
@@ -166,6 +162,9 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx,
handle.Run(kern)(tensors.out,
tensors.w,
tensors.in,
+                             out_strides,
+                             wei_strides,
+                             in_strides,
hi,
wi,
n,
11 changes: 5 additions & 6 deletions src/solver/conv_direct_naive_conv_fwd.cpp
@@ -122,12 +122,6 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
KernelInfo kernel;

kernel.kernel_file = ConvDirectNaiveConvKernelFile(ctx, problem);
-    const auto is_f8 = [&]() {
-        if(kernel.kernel_file == "fp8_naive_conv.cpp")
-            return true;
-        else
-            return false;
-    }();
kernel.kernel_name = ConvDirectNaiveConvKernelName(problem);
kernel.g_wk.clear();

@@ -139,6 +133,8 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
kernel.l_wk.push_back(1);
kernel.l_wk.push_back(1);

+    const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");

kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);

int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);
@@ -166,6 +162,9 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx,
handle.Run(kern)(tensors.in,
tensors.w,
tensors.out,
+                             in_strides,
+                             wei_strides,
+                             out_strides,
hi,
wi,
n,
11 changes: 5 additions & 6 deletions src/solver/conv_direct_naive_conv_wrw.cpp
@@ -121,13 +121,9 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx,
kernel.l_wk.push_back(1);
kernel.l_wk.push_back(1);

+    const auto is_f8 = (kernel.kernel_file == "fp8_naive_conv.cpp");

kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem);
-    const auto is_f8 = [&]() {
-        if(kernel.kernel_file == "fp8_naive_conv.cpp")
-            return true;
-        else
-            return false;
-    }();

int G_stride_idx = conv_internal::GetGroupStrideIndex(problem);

@@ -154,6 +150,9 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx,
handle.Run(kern)(tensors.x,
tensors.dw,
tensors.dy,
+                             in_strides,
+                             wei_strides,
+                             out_strides,
hi,
wi,
n,
18 changes: 12 additions & 6 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp
@@ -32,6 +32,7 @@
#include <miopen/conv/data_invoke_params.hpp>
#include <miopen/solver/problem_description_interpreter.hpp>
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+#include <miopen/solver/ck_utility_common.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp>
#endif
#include <miopen/solver/implicitgemm_ck_util.hpp>
@@ -86,10 +87,16 @@ struct CKArgs
output = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

-        // strides from NHWGC to GNCHW laout
-        in_strides  = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
-        out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
-        wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
+        // miopen strides to CK strides
+        auto miopen_in_strides  = problem.GetIn().GetStrides();
+        auto miopen_out_strides = problem.GetOut().GetStrides();
+        auto miopen_wei_strides = problem.GetWeights().GetStrides();
+        miopen_in_strides.insert(miopen_in_strides.begin(), C);
+        miopen_out_strides.insert(miopen_out_strides.begin(), K);
+        miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]);
+        std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin());
+        std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin());
+        std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin());

strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem),
ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),
@@ -315,8 +322,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable(
return false;
if(!problem.IsLayoutNHWC())
return false;
-    const std::string& arch = ctx.GetStream().GetDeviceName();
-    if(miopen::StartsWith(arch, "gfx11") || miopen::StartsWith(arch, "gfx10"))
+    if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), false))
return false;
switch(problem.GetInDataType())
{
18 changes: 12 additions & 6 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp
@@ -32,6 +32,7 @@
#include <miopen/conv/data_invoke_params.hpp>
#include <miopen/solver/problem_description_interpreter.hpp>
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+#include <miopen/solver/ck_utility_common.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp>
#endif
#include <miopen/solver/implicitgemm_ck_util.hpp>
@@ -86,10 +87,16 @@ struct CKArgs
output = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

-        // strides from NHWGC to GNCHW laout
-        in_strides  = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
-        out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
-        wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
+        // miopen strides to CK strides
+        auto miopen_in_strides  = problem.GetIn().GetStrides();
+        auto miopen_out_strides = problem.GetOut().GetStrides();
+        auto miopen_wei_strides = problem.GetWeights().GetStrides();
+        miopen_in_strides.insert(miopen_in_strides.begin(), C);
+        miopen_out_strides.insert(miopen_out_strides.begin(), K);
+        miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]);
+        std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin());
+        std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin());
+        std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin());

strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem),
ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),
@@ -313,8 +320,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(
return false;
if(!problem.IsLayoutNHWC())
return false;
-    const std::string& arch = ctx.GetStream().GetDeviceName();
-    if(!(arch == "gfx908" || arch == "gfx90a"))
+    if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), false))
return false;
switch(problem.GetInDataType())
{
20 changes: 12 additions & 8 deletions src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp
@@ -32,6 +32,7 @@
#include <miopen/conv/wrw_invoke_params.hpp>
#include <miopen/solver/problem_description_interpreter.hpp>
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+#include <miopen/solver/ck_utility_common.hpp>
#include <ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp>
#endif
#include <miopen/solver/implicitgemm_ck_util.hpp>
@@ -84,10 +85,16 @@ struct CKArgs
output = {G, N, K, Do, Ho, Wo};
weight = {G, K, C, Z, Y, X};

-        // strides from NHWGC to GNCHW laout
-        in_strides  = {C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C};
-        out_strides = {K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K};
-        wei_strides = {K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C};
+        // miopen strides to CK strides
+        auto miopen_in_strides  = problem.GetIn().GetStrides();
+        auto miopen_out_strides = problem.GetOut().GetStrides();
+        auto miopen_wei_strides = problem.GetWeights().GetStrides();
+        miopen_in_strides.insert(miopen_in_strides.begin(), C);
+        miopen_out_strides.insert(miopen_out_strides.begin(), K);
+        miopen_wei_strides.insert(miopen_wei_strides.begin(), K * miopen_wei_strides[0]);
+        std::copy(miopen_in_strides.begin(), miopen_in_strides.end(), in_strides.begin());
+        std::copy(miopen_out_strides.begin(), miopen_out_strides.end(), out_strides.begin());
+        std::copy(miopen_wei_strides.begin(), miopen_wei_strides.end(), wei_strides.begin());

strides = {ProblemInterpreter::GetAdjustedConvolutionStrideD(problem),
ProblemInterpreter::GetAdjustedConvolutionStrideH(problem),
@@ -309,10 +316,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(
return false;
if(!problem.IsLayoutNHWC())
return false;
-    const std::string& arch = ctx.GetStream().GetDeviceName();
-    if(miopen::StartsWith(arch, "gfx11") || miopen::StartsWith(arch, "gfx10"))
-        return false;
-    if(arch == "gfx906" || arch == "gfx900")
+    if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), true))
return false;
switch(problem.GetInDataType())
{