Add a check for packed tensors for convolution solvers #2471

Merged: 9 commits, Oct 25, 2023
11 changes: 11 additions & 0 deletions src/include/miopen/conv/problem_description.hpp
@@ -367,6 +367,17 @@ struct ProblemDescription : ProblemDescriptionBase
bool IsNCHWc_NCHWc() const;
bool IsNCHWc_CHWNc() const;

bool HasNonPackedTensors() const

Review thread:
Contributor Author: @atamazov: favoring the current name as it is more descriptive.
Contributor: [resolved]

{
return !(in.IsPacked() && weights.IsPacked() && out.IsPacked());
}

bool HasMixedDataTypes() const
{
return !(GetInDataType() == GetWeightsDataType() &&
GetWeightsDataType() == GetOutDataType());
}

void HeuristicUpdateLayouts();

void BuildConfKey(std::string& conf_key) const;
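For context, here is a minimal standalone sketch of what "packed" means for a tensor descriptor. This is an illustration only, not MIOpen's actual `TensorDescriptor::IsPacked()` implementation, and `IsPackedSketch` is a hypothetical name: a tensor is packed when every dimension's stride equals the product of all faster-varying lengths, so the data occupies one contiguous block with no padding.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the IsPacked() predicate used above: packed
// means each dimension's stride is exactly the product of the lengths of
// all faster-varying dimensions (no padding between rows or elements).
bool IsPackedSketch(const std::vector<std::size_t>& lens,
                    const std::vector<std::size_t>& strides)
{
    std::size_t expected = 1;
    for(std::size_t i = lens.size(); i-- > 0;) // fastest-varying dim last
    {
        if(strides[i] != expected)
            return false;
        expected *= lens[i];
    }
    return true;
}
```

For example, NCHW lengths {1, 3, 4, 4} with strides {48, 16, 4, 1} are packed, while strides {60, 20, 5, 1} (W padded out to 5) are not; `HasNonPackedTensors()` then returns true as soon as any of the input, weights, or output descriptors fails such a test.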
4 changes: 4 additions & 0 deletions src/include/miopen/fusion/utils.hpp
@@ -84,6 +84,10 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes
return false;
if(!conv_problem.IsFp32())
return false;
if(conv_problem.HasNonPackedTensors())
{
return false;
}
if(!conv_problem.IsLayoutDefault())
return false;
if(!conv_problem.direction.IsForward())
4 changes: 4 additions & 0 deletions src/solver/conv_MP_bidirectional_winograd.cpp
@@ -323,6 +323,10 @@ bool ConvMPBidirectWinograd<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>::IsA
{
// HIP backend required for sending ptr (buffer + offset)
// ROCBLAS for GEMM step
if(problem.HasNonPackedTensors())
{
return false;
}

if(!problem.IsLayoutDefault())
{
4 changes: 4 additions & 0 deletions src/solver/conv_asm_1x1u.cpp
@@ -527,6 +527,10 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
4 changes: 4 additions & 0 deletions src/solver/conv_asm_1x1u_bias_activ_fused.cpp
@@ -243,6 +243,10 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context,
const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward);
const auto conv_ctx = context.GetConvContext(conv_problem);

if(conv_problem.HasNonPackedTensors())
{
return false;
}
if(conv_problem.GetPadH() != conv_problem.GetPadW())
return false;
if(conv_problem.GetPadH() != 0)
4 changes: 4 additions & 0 deletions src/solver/conv_asm_1x1u_stride2.cpp
@@ -489,6 +489,10 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,
return false;
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
4 changes: 4 additions & 0 deletions src/solver/conv_asm_3x3u.cpp
@@ -176,6 +176,10 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
4 changes: 4 additions & 0 deletions src/solver/conv_asm_5x10u2v2b1.cpp
@@ -45,6 +45,10 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
4 changes: 4 additions & 0 deletions src/solver/conv_asm_5x10u2v2f1.cpp
@@ -46,6 +46,10 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
5 changes: 5 additions & 0 deletions src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
@@ -51,6 +51,11 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx
if(!ctx.rmv.IsV2orV3())
return false;

if(problem.HasNonPackedTensors())
{
return false;
}

if(problem.IsTensorsCasted())
return false;

4 changes: 4 additions & 0 deletions src/solver/conv_asm_dir_BwdWrW1x1.cpp
@@ -479,6 +479,10 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.direction.IsBackwardWrW())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
4 changes: 4 additions & 0 deletions src/solver/conv_asm_dir_BwdWrW3x3.cpp
@@ -362,6 +362,10 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.direction.IsBackwardWrW())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
5 changes: 5 additions & 0 deletions src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
@@ -143,6 +143,11 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx
if(!problem.direction.IsBackwardData())
return false;

if(problem.HasNonPackedTensors())
{
return false;
}

if(!problem.Is2d())
return false;

4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
@@ -992,6 +992,10 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
{
return false;
}
if(!problem.IsFp32() && !problem.IsFp16())
return false;

4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
@@ -938,6 +938,10 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
{
return false;
}
if(!problem.IsFp32() && !problem.IsFp16() &&
!(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94"))))
return false;
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp
@@ -1517,6 +1517,10 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext

if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}

if(!problem.IsFp32() && !problem.IsFp16())
return false;
5 changes: 5 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp
@@ -560,6 +560,11 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
{
return false;
}

if(!problem.IsLayoutNCHWc())
return false;

4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
@@ -876,6 +876,10 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
{
return false;
}
if(!problem.IsFp32() && !problem.IsFp16() &&
!(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94"))))
return false;
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
@@ -866,6 +866,10 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
{
return false;
}
if(!problem.IsFp32() && !problem.IsFp16() &&
!(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94"))))
return false;
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
@@ -291,6 +291,10 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx

if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}

if(!problem.IsFp32())
return false;
@@ -832,6 +832,10 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext

if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}

if(!problem.IsFp32() && !problem.IsFp16())
return false;
4 changes: 4 additions & 0 deletions src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp
@@ -317,6 +317,10 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx

if(!problem.IsFp32())
return false;
if(problem.HasNonPackedTensors())
{
return false;
}

if(problem.IsTensorsCasted())
return false;
6 changes: 6 additions & 0 deletions src/solver/conv_bin_wino3x3U.cpp
@@ -64,6 +64,12 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx,
// and able to correctly run with given parameters.
const auto device_is_gfx8 = StartsWith(name, "gfx8");
const auto grid_workgroup_count_x = ctx.GetStream().GetMaxComputeUnits();

if(problem.HasNonPackedTensors())
{
return false;
}

if(!problem.IsLayoutDefault())
{
return false;
4 changes: 4 additions & 0 deletions src/solver/conv_bin_winoRxS.cpp
@@ -223,6 +223,10 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx,
if(!(problem.IsFp32() || problem.IsFp16()))
return false;

if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.IsTensorsCasted())
return false;
if(miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_RXS{}))
13 changes: 8 additions & 5 deletions src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp
@@ -411,15 +411,18 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx,
return false;
const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward);

if(conv_problem.IsTensorsCasted())
return false;
if(conv_problem.GetConv().attribute.deterministic)
return false;
if(conv_problem.GetInDataType() != conv_problem.GetWeightsDataType() ||
conv_problem.GetInDataType() != conv_problem.GetOutDataType())
return false;
if(!conv_problem.Is2d())
return false;
if(conv_problem.HasNonPackedTensors())
{
return false;
}
if(conv_problem.HasMixedDataTypes())
return false;
if(conv_problem.IsTensorsCasted())
return false;
const std::string arch = ctx.GetStream().GetDeviceName();
if(arch != "gfx908" && arch != "gfx90a" && arch != "gfx940" && arch != "gfx941" &&
arch != "gfx942")
8 changes: 6 additions & 2 deletions src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
@@ -97,15 +97,19 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx,
return false;
if(!ck_utility::is_ck_supported_hardware(ctx.GetStream()))
return false;
if(!problem.IsLayoutDefault())
return false;
if(!problem.direction.IsForward())
return false;
if(!problem.Is2d())
return false;
if(!(problem.IsFp32() or problem.IsFp16()))
return false;

if(problem.HasNonPackedTensors())
{
return false;
}
if(!problem.IsLayoutDefault())
return false;
if(problem.IsTensorsCasted())
return false;
if(problem.GetGroupCount() != 1)
5 changes: 2 additions & 3 deletions src/solver/conv_direct_naive_conv_bwd.cpp
@@ -44,15 +44,14 @@ bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx,
if(!ConvDirectNaiveConvIsApplicableByKernelType(ctx, problem))
return false;

if(!problem.direction.IsBackwardData())
return false;
Review thread (on lines +47 to +48):
Collaborator: Existing code has been moved.
Contributor Author: I will revert it, but the existing order of checks itself is something to discuss, as it seems arbitrary, based on what the programmer wrote at the time or copy-pasted from somewhere else.
Collaborator: I agree that the current order of checking conditions may not be optimal, but we would first need to do a little research and evaluate the expected benefit. It's better to do this in a separate pull request. (A sketch of the ordering point follows this hunk.)

if(!problem.IsLayoutDefault() && !problem.IsLayoutNHWC())
return false;

if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16() || problem.IsFp8() ||
problem.IsBfp8()))
return false;

if(!problem.direction.IsBackwardData())
return false;
if(problem.IsTensorsCasted())
{
auto test_cast = [&](const TensorDescriptor& desc) {
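On the check-ordering discussion above: a hedged sketch of why ordering matters in `IsApplicable()`-style predicate chains. The names below are illustrative, not MIOpen API; since the checks are independent and side-effect-free, rejecting on the cheapest, most selective predicates first reduces average cost without changing the outcome.

```cpp
#include <functional>
#include <vector>

struct ProblemStub
{
    bool is_2d  = true;
    bool packed = true;
};

// Illustrative only: run cheap/selective predicates before expensive ones.
// Because each predicate is pure, reordering never changes the result,
// only how quickly a non-applicable problem is rejected.
bool IsApplicableSketch(const ProblemStub& p)
{
    const std::vector<std::function<bool(const ProblemStub&)>> checks = {
        [](const ProblemStub& q) { return q.is_2d; },  // trivial field test
        [](const ProblemStub& q) { return q.packed; }, // stride walk, dearer
        // ...costliest checks (e.g. tuning-table lookups) would go last
    };
    for(const auto& check : checks)
        if(!check(p))
            return false; // short-circuit on the first failure
    return true;
}
```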
@@ -310,9 +310,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable(
return false;
if(miopen::IsEnabled(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC{}))
return false;
if(problem.GetInDataType() != problem.GetWeightsDataType() ||
problem.GetWeightsDataType() != problem.GetOutDataType() ||
problem.GetInDataType() != problem.GetOutDataType())
if(problem.HasMixedDataTypes())
return false;
if(problem.IsTensorsCasted())
return false;
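This hunk and the following ones replace the open-coded three-way comparison with `HasMixedDataTypes()`. The two predicates are equivalent because the dropped `in == out` clause is implied by the other two; a small self-contained check (using a hypothetical `DataType` enum, not MIOpen's) verifies this over all combinations:

```cpp
#include <cassert>

enum class DataType { Fp32, Fp16, Bfp16 };

// Old form: three pairwise inequality tests.
bool oldMixed(DataType in, DataType wei, DataType out)
{
    return in != wei || wei != out || in != out;
}

// New form, mirroring HasMixedDataTypes(): if in == wei and wei == out,
// then in == out follows by transitivity, so the third test is redundant.
bool newMixed(DataType in, DataType wei, DataType out)
{
    return !(in == wei && wei == out);
}

int main()
{
    const DataType all[] = {DataType::Fp32, DataType::Fp16, DataType::Bfp16};
    for(auto in : all)
        for(auto wei : all)
            for(auto out : all)
                assert(oldMixed(in, wei, out) == newMixed(in, wei, out));
}
```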
@@ -310,9 +310,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(
return false;
if(problem.GetConv().attribute.deterministic)
return false;
if(problem.GetInDataType() != problem.GetWeightsDataType() ||
problem.GetWeightsDataType() != problem.GetOutDataType() ||
problem.GetInDataType() != problem.GetOutDataType())
if(problem.HasMixedDataTypes())
return false;
if(!problem.direction.IsForward())
return false;
@@ -306,9 +306,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(
return false;
if(miopen::IsEnabled(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC{}))
return false;
if(problem.GetInDataType() != problem.GetWeightsDataType() ||
problem.GetWeightsDataType() != problem.GetOutDataType() ||
problem.GetInDataType() != problem.GetOutDataType())
if(problem.HasMixedDataTypes())
return false;
if(!problem.direction.IsBackwardWrW())
return false;
8 changes: 5 additions & 3 deletions src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp
@@ -262,9 +262,11 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable(
return false;
if(problem.GetConv().attribute.deterministic)
return false;
if(problem.GetInDataType() != problem.GetWeightsDataType() ||
problem.GetWeightsDataType() != problem.GetOutDataType() ||
problem.GetInDataType() != problem.GetOutDataType())
if(problem.HasNonPackedTensors())
{
return false;
}
if(problem.HasMixedDataTypes())
return false;

if(problem.IsTensorsCasted())
4 changes: 4 additions & 0 deletions src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp
@@ -638,6 +638,10 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ExecutionContext& ctx,
return false;
if(!ctx.use_hip_kernels)
return false;
if(problem.HasNonPackedTensors())
{
return false;
}
if(!problem.IsLayoutDefault())
return false;
if(!IsComposableKernelSupportedHardware(ctx))