diff --git a/src/include/miopen/conv/problem_description.hpp b/src/include/miopen/conv/problem_description.hpp index 5a8a918b4e..7a25ecd587 100644 --- a/src/include/miopen/conv/problem_description.hpp +++ b/src/include/miopen/conv/problem_description.hpp @@ -367,6 +367,17 @@ struct ProblemDescription : ProblemDescriptionBase bool IsNCHWc_NCHWc() const; bool IsNCHWc_CHWNc() const; + bool HasNonPackedTensors() const + { + return !(in.IsPacked() && weights.IsPacked() && out.IsPacked()); + } + + bool HasMixedDataTypes() const + { + return !(GetInDataType() == GetWeightsDataType() && + GetWeightsDataType() == GetOutDataType()); + } + void HeuristicUpdateLayouts(); void MakeNetworkConfig(std::string& conf_key) const; diff --git a/src/include/miopen/fusion/utils.hpp b/src/include/miopen/fusion/utils.hpp index d33d822429..5669de990f 100644 --- a/src/include/miopen/fusion/utils.hpp +++ b/src/include/miopen/fusion/utils.hpp @@ -84,6 +84,8 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes return false; if(!conv_problem.IsFp32()) return false; + if(conv_problem.HasNonPackedTensors()) + return false; if(!conv_problem.IsLayoutDefault()) return false; if(!conv_problem.direction.IsForward()) diff --git a/src/solver/conv_MP_bidirectional_winograd.cpp b/src/solver/conv_MP_bidirectional_winograd.cpp index c364c5ae00..1ecc210144 100644 --- a/src/solver/conv_MP_bidirectional_winograd.cpp +++ b/src/solver/conv_MP_bidirectional_winograd.cpp @@ -229,9 +229,7 @@ static bool IsApplicableTransform(const ExecutionContext& ctx, const ProblemDesc } if(!problem.IsLayoutDefault()) - { return false; - } { unsigned int const waves_in_group = 512 / wave_size; @@ -323,11 +321,11 @@ bool ConvMPBidirectWinograd::IsA { // HIP backend required for sending ptr (buffer + offset) // ROCBLAS for GEMM step + if(problem.HasNonPackedTensors()) + return false; if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_asm_1x1u.cpp b/src/solver/conv_asm_1x1u.cpp index 0664c32eb8..be72cf1fe3 100644 --- a/src/solver/conv_asm_1x1u.cpp +++ b/src/solver/conv_asm_1x1u.cpp @@ -527,6 +527,8 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) @@ -545,13 +547,9 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip const std::string name = ctx.GetStream().GetDeviceName(); if(name.find("gfx9") == std::string::npos) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp index c935a2aff6..16063a04b5 100644 --- a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp +++ b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp @@ -256,9 +256,6 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context, if(conv_problem.GetDilationH() != 1) return false; - if(conv_problem.IsTensorsCasted()) - return false; - // Check if the conovlution part is applicable return sol.IsApplicable(conv_ctx, conv_problem); } diff --git a/src/solver/conv_asm_1x1u_stride2.cpp b/src/solver/conv_asm_1x1u_stride2.cpp index b9925ee30c..3a102d4701 100644 --- a/src/solver/conv_asm_1x1u_stride2.cpp +++ b/src/solver/conv_asm_1x1u_stride2.cpp @@ -489,6 +489,8 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -505,13 +507,9 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx, const std::string name = ctx.GetStream().GetDeviceName(); if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_3x3u.cpp b/src/solver/conv_asm_3x3u.cpp index 18f07b9630..c486b4d867 100644 --- a/src/solver/conv_asm_3x3u.cpp +++ b/src/solver/conv_asm_3x3u.cpp @@ -176,6 +176,8 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) @@ -194,9 +196,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx90"))) return false; if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_5x10u2v2b1.cpp b/src/solver/conv_asm_5x10u2v2b1.cpp index 6da4863f6d..98b604ef3e 100644 --- a/src/solver/conv_asm_5x10u2v2b1.cpp +++ b/src/solver/conv_asm_5x10u2v2b1.cpp @@ -45,6 +45,8 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -63,17 +65,11 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, return false; #endif if(!device_is_gfx8_9_no_xnack) - { return false; - } if(!problem.direction.IsBackwardData()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_5x10u2v2f1.cpp b/src/solver/conv_asm_5x10u2v2f1.cpp index 74301fe5fd..8d3c1a1716 100644 --- a/src/solver/conv_asm_5x10u2v2f1.cpp +++ b/src/solver/conv_asm_5x10u2v2f1.cpp @@ -46,6 +46,8 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -64,18 +66,11 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, return false; #endif if(!device_is_gfx8_9_no_xnack) - { return false; - } if(!problem.direction.IsForward()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp index 4426a3eeca..309b178dfb 100644 --- a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp +++ b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp @@ -51,6 +51,9 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx if(!ctx.rmv.IsV2orV3()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(problem.IsTensorsCasted()) return false; @@ -65,17 +68,11 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx #endif if(!(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" || name == "gfx900" || name == "gfx904" || name == "gfx906" || name == "gfx908")) - { return false; - } if(!problem.direction.IsForward()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off return problem.GetPadW() == 3 // -q diff --git a/src/solver/conv_asm_dir_BwdWrW1x1.cpp b/src/solver/conv_asm_dir_BwdWrW1x1.cpp index 0abe71326f..04e0f89b4e 100644 --- a/src/solver/conv_asm_dir_BwdWrW1x1.cpp +++ b/src/solver/conv_asm_dir_BwdWrW1x1.cpp @@ -479,11 +479,12 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) return false; - if(problem.IsTensorsCasted()) return false; @@ -493,13 +494,9 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, const std::string name = ctx.GetStream().GetDeviceName(); if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(name == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; diff --git a/src/solver/conv_asm_dir_BwdWrW3x3.cpp b/src/solver/conv_asm_dir_BwdWrW3x3.cpp index ae58cfcd9b..6ae7330c63 100644 --- a/src/solver/conv_asm_dir_BwdWrW3x3.cpp +++ b/src/solver/conv_asm_dir_BwdWrW3x3.cpp @@ -362,6 +362,8 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -375,10 +377,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx, if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx9"))) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; #if WORKAROUND_ISSUE_532 diff --git a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp index 9e41d56c82..6ff60242dc 100644 --- a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp @@ -143,6 +143,9 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx if(!problem.direction.IsBackwardData()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.Is2d()) return false; @@ -159,9 +162,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp index ee6b16d38b..d0551138de 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp @@ -992,6 +992,9 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsFp32() && !problem.IsFp16()) return false; @@ -1006,9 +1009,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp index 9cfdd8aeea..42f1e9f03e 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp @@ -938,6 +938,9 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp index 32b50167cf..5bc9da9b81 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp @@ -1518,6 +1518,9 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsFp32() && !problem.IsFp16()) return false; @@ -1531,9 +1534,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext return false; if(!problem.IsLayoutDefault()) - { return false; - } #if WORKAROUND_SWDEV_306318 if((problem.GetWeightsHeight_() == 1) && (problem.GetWeightsWidth_() == 1) && diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp index b16258235e..d7b692cfb3 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp @@ -560,6 +560,9 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsLayoutNCHWc()) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp index 4ab9ce1c37..2e38366b74 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp @@ -876,6 +876,9 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index 8ac238395a..3db8021b26 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -866,6 +866,9 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp index 8e1450c7a3..baceb8089f 100644 --- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp @@ -292,6 +292,9 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsFp32()) return false; @@ -305,9 +308,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp index a5d056178d..aa84c8c76e 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp @@ -833,6 +833,9 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsFp32() && !problem.IsFp16()) return false; @@ -846,9 +849,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp index fb5f0caf7c..7adb28fdae 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp @@ -318,6 +318,9 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx if(!problem.IsFp32()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(problem.IsTensorsCasted()) return false; @@ -328,9 +331,7 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_bin_wino3x3U.cpp b/src/solver/conv_bin_wino3x3U.cpp index c8508cf372..0fc9263413 100644 --- a/src/solver/conv_bin_wino3x3U.cpp +++ b/src/solver/conv_bin_wino3x3U.cpp @@ -64,10 +64,12 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx, // and able to correctly run with given parameters. const auto device_is_gfx8 = StartsWith(name, "gfx8"); const auto grid_workgroup_count_x = ctx.GetStream().GetMaxComputeUnits(); + + if(problem.HasNonPackedTensors()) + return false; + if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_bin_winoRxS.cpp b/src/solver/conv_bin_winoRxS.cpp index eb4d7386f1..f1fd0868c3 100644 --- a/src/solver/conv_bin_winoRxS.cpp +++ b/src/solver/conv_bin_winoRxS.cpp @@ -186,9 +186,7 @@ static inline bool IsShaderContraintsMet(const miopen::ExecutionContext& ctx, } const auto grid_workgroup_count_x = ctx.GetStream().GetMaxComputeUnits(); if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off // Check implementation limits. @@ -222,7 +220,8 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.IsFp32() || problem.IsFp16())) return false; - + if(problem.HasNonPackedTensors()) + return false; if(problem.IsTensorsCasted()) return false; if(miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_RXS{})) diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp index d9c0410d98..75f5cd1ac3 100644 --- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp +++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp @@ -413,8 +413,9 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx, return false; if(conv_problem.GetConv().attribute.deterministic) return false; - if(conv_problem.GetInDataType() != conv_problem.GetWeightsDataType() || - conv_problem.GetInDataType() != conv_problem.GetOutDataType()) + if(conv_problem.HasNonPackedTensors()) + return false; + if(conv_problem.HasMixedDataTypes()) return false; if(!conv_problem.Is2d()) return false; diff --git a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp index 001f3a8cb7..9c42d8b8db 100644 --- a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp +++ b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp @@ -105,7 +105,8 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.IsFp32() or problem.IsFp16())) return false; - + if(problem.HasNonPackedTensors()) + return false; if(problem.IsTensorsCasted()) return false; if(problem.GetGroupCount() != 1) diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp index 1e8f006ef0..77406744b7 100644 --- a/src/solver/conv_direct_naive_conv_bwd.cpp +++ b/src/solver/conv_direct_naive_conv_bwd.cpp @@ -44,15 +44,14 @@ bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx, if(!ConvDirectNaiveConvIsApplicableByKernelType(ctx, problem)) return false; + if(!problem.direction.IsBackwardData()) + return false; if(!problem.IsLayoutDefault() && !problem.IsLayoutNHWC()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16() || problem.IsFp8() || problem.IsBfp8())) return false; - - if(!problem.direction.IsBackwardData()) - return false; if(problem.IsTensorsCasted()) { auto test_cast = [&](const TensorDescriptor& desc) { diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp index b530f0ff37..82f3411cb8 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp @@ -308,9 +308,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable( return false; if(miopen::IsEnabled(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC{})) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index efd4416113..b2a09e26d5 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -308,9 +308,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable( return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(!problem.direction.IsForward()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index 29ed7266a7..d395e576f0 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -304,9 +304,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable( return false; if(miopen::IsEnabled(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC{})) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(!problem.direction.IsBackwardWrW()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index 025e1d2b5a..85bec04104 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -260,11 +260,10 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable( return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasNonPackedTensors()) + return false; + if(problem.HasMixedDataTypes()) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.direction.IsBackwardData()) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp index b2b591b859..5b3cb4933f 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp @@ -638,6 +638,8 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ExecutionContext& ctx, return false; if(!ctx.use_hip_kernels) return false; + if(problem.HasNonPackedTensors()) + return false; if(!problem.IsLayoutDefault()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp index f657fa74fe..6db26950d5 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp @@ -784,6 +784,9 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(problem.IsTensorsCasted()) return false; @@ -794,9 +797,8 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.IsLayoutDefault()) - { return false; - } + // gemm size int gemm_g = -1; int gemm_m = -1; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp index e60d6c76a3..a0cca73af9 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp @@ -754,6 +754,9 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx, if(!problem.IsFp32()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(problem.IsTensorsCasted()) return false; @@ -761,9 +764,7 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.IsLayoutDefault()) - { return false; - } if(!IsIndexRangeLargeEnough(problem)) return false; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp index 3b3dc8b4d3..2c4f36e820 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp @@ -838,9 +838,10 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!IsApplicableXdlops(ctx, problem)) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp index 39e8c71c16..88d0f7b314 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp @@ -57,6 +57,8 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!problem.IsFp32() && !problem.IsFp16() && !problem.IsBfp16()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp index 9cbe662180..9babd8fff8 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp @@ -585,6 +585,8 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.IsLayoutDefault()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; if(!problem.direction.IsForward()) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp index 9c09efe397..8a7b4b150b 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp @@ -990,6 +990,9 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext& if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; + if(problem.HasNonPackedTensors()) + return false; + if(problem.IsTensorsCasted()) return false; @@ -1006,9 +1009,8 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.IsLayoutDefault()) - { return false; - } + // gemm size { int gemm_g = -1; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp index d25ca1b68b..ed868b3d04 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp @@ -1065,6 +1065,9 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; @@ -1072,9 +1075,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( return false; if(!problem.IsLayoutDefault()) - { return false; - } // gemm size { int gemm_g = -1; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp index 4915c48e2e..6006cb0caa 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp @@ -1020,6 +1020,9 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ExecutionContext& if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; + if(problem.HasNonPackedTensors()) + return false; + if(problem.IsTensorsCasted()) return false; @@ -1043,9 +1046,8 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.IsLayoutDefault()) - { return false; - } + // gemm size { int gemm_g = -1; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index 5e395b5657..1b1127d6e1 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -261,9 +261,9 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasNonPackedTensors()) + return false; + if(problem.HasMixedDataTypes()) return false; if(!problem.direction.IsForward()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp index 4288b4d287..8e7898ea1a 100644 --- a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp @@ -286,13 +286,13 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable( #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL if(miopen::IsDisabled(MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS{})) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsTensorsCasted()) return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(!problem.direction.IsForward()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp index 637486ef50..c9b0d94241 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp @@ -586,6 +586,8 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ExecutionContext& ctx, return false; if(!ctx.use_hip_kernels) return false; + if(problem.HasNonPackedTensors()) + return false; if(!problem.IsLayoutDefault()) return false; if(!IsComposableKernelSupportedHardware(ctx)) @@ -596,7 +598,6 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.IsFp32()) return false; - if(problem.IsTensorsCasted()) return false; if(problem.GetGroupCount() != 1) diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp index 5a42ba3255..0f38df9e6c 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp @@ -1057,6 +1057,9 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ExecutionContext& ctx, if(!IsXdlopsSupport(ctx)) return false; + if(problem.HasNonPackedTensors()) + return false; + if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; @@ -1076,9 +1079,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.IsLayoutDefault()) - { return false; - } // this particular HeuristicInit is so comprehensive, that if it cannot predict a valid // performance config, the problem is probably not applicable diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp index abd178dcca..b910d9658a 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp @@ -1132,6 +1132,9 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(problem.IsTensorsCasted()) return false; @@ -1142,9 +1145,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( return false; if(!problem.IsLayoutDefault()) - { return false; - } // this particular HeuristicInit is so comprehensive, that if it cannot predict a valid #if WORKAROUND_MI100_BF16_FATAL_COMPILER_ERRORS diff --git a/src/solver/conv_mlir_igemm_bwd.cpp b/src/solver/conv_mlir_igemm_bwd.cpp index 783c68350c..6804cccf42 100644 --- a/src/solver/conv_mlir_igemm_bwd.cpp +++ b/src/solver/conv_mlir_igemm_bwd.cpp @@ -47,6 +47,8 @@ bool ConvMlirIgemmBwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardData()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp index 41062cc32c..439b1ab394 100644 --- a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp @@ -50,6 +50,8 @@ bool ConvMlirIgemmBwdXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardData()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_mlir_igemm_fwd.cpp b/src/solver/conv_mlir_igemm_fwd.cpp index 2cc196ae10..034313e1d0 100644 --- a/src/solver/conv_mlir_igemm_fwd.cpp +++ b/src/solver/conv_mlir_igemm_fwd.cpp @@ -167,6 +167,8 @@ bool ConvMlirIgemmFwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsForward()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp index c761abc137..fa17520f07 100644 --- a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp @@ -64,6 +64,8 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsForward()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_wrw.cpp b/src/solver/conv_mlir_igemm_wrw.cpp index cb9f6ae7b2..7b8286d896 100644 --- a/src/solver/conv_mlir_igemm_wrw.cpp +++ b/src/solver/conv_mlir_igemm_wrw.cpp @@ -50,6 +50,8 @@ bool ConvMlirIgemmWrW::IsApplicable(const ExecutionContext& ctx, return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; // Note: ConvMlirIgemmWrW can run on a machine with xdlops support, however, it is diff --git a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp index fe11c828c8..5dd480da63 100644 --- a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp @@ -51,6 +51,8 @@ bool ConvMlirIgemmWrWXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_multipass_wino3x3WrW.cpp b/src/solver/conv_multipass_wino3x3WrW.cpp index 233489c4fc..4f7a1626cc 100644 --- a/src/solver/conv_multipass_wino3x3WrW.cpp +++ b/src/solver/conv_multipass_wino3x3WrW.cpp @@ -436,15 +436,14 @@ bool ConvWinograd3x3MultipassWrW return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) @@ -493,9 +492,7 @@ bool ConvWinograd3x3MultipassWrW return false; } if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off { diff --git a/src/solver/conv_ocl_dir2D11x11.cpp b/src/solver/conv_ocl_dir2D11x11.cpp index b76621a591..482892946e 100644 --- a/src/solver/conv_ocl_dir2D11x11.cpp +++ b/src/solver/conv_ocl_dir2D11x11.cpp @@ -47,17 +47,16 @@ bool ConvOclDirectFwd11x11::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } return problem.direction.IsForward() && problem.GetGroupCount() == 1 && problem.GetDilationH() == 1 && problem.GetDilationW() == 1 && diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp index 4e0cda8629..2c13d29f11 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp @@ -58,15 +58,14 @@ bool ConvOclBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp index 2b400909f8..c59449eab8 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp @@ -464,10 +464,7 @@ bool ConvOclBwdWrW2::IsApplicableBase(const ExecutionContext& ctx if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp index 4f00c8f55b..7f50e7f3ee 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp @@ -51,6 +51,8 @@ bool ConvOclBwdWrW53::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) @@ -61,9 +63,7 @@ bool ConvOclBwdWrW53::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardWrW()) return false; if(!problem.IsLayoutDefault()) - { return false; - } bool workaround = false; diff --git a/src/solver/conv_ocl_dir2Dfwd.cpp b/src/solver/conv_ocl_dir2Dfwd.cpp index c7bd8c00df..6b7d2f1f8e 100644 --- a/src/solver/conv_ocl_dir2Dfwd.cpp +++ b/src/solver/conv_ocl_dir2Dfwd.cpp @@ -48,17 +48,16 @@ bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off // Cases when dy has negative padding are not supported (issue 918) diff --git a/src/solver/conv_ocl_dir2Dfwd1x1.cpp b/src/solver/conv_ocl_dir2Dfwd1x1.cpp index b21effc0b3..b3e0bd0a2b 100644 --- a/src/solver/conv_ocl_dir2Dfwd1x1.cpp +++ b/src/solver/conv_ocl_dir2Dfwd1x1.cpp @@ -57,17 +57,16 @@ bool ConvOclDirectFwd1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } return problem.GetDilationW() == 1 && problem.GetDilationH() == 1 && problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1 && diff --git a/src/solver/conv_ocl_dir2Dfwdgen.cpp b/src/solver/conv_ocl_dir2Dfwdgen.cpp index f35e57b71c..659f0ddf73 100644 --- a/src/solver/conv_ocl_dir2Dfwdgen.cpp +++ b/src/solver/conv_ocl_dir2Dfwdgen.cpp @@ -45,18 +45,16 @@ bool ConvOclDirectFwdGen::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.GetGroupCount() > 1) return false; diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp index d9cbeb713f..6c7d94632f 100644 --- a/src/solver/conv_winoRxS.cpp +++ b/src/solver/conv_winoRxS.cpp @@ -285,9 +285,7 @@ inline bool IsShaderConstraintsMet(const ProblemDescription& problem, } if(!problem.IsLayoutDefault()) - { return false; - } return IsWinogradV21Preferred(asic, problem) ? IsShaderConstraintsMetV21(problem, R, S, C, K, H, W, OH, OW, N) @@ -622,9 +620,10 @@ static bool IsApplicableBase(const ExecutionContext& ctx, const ProblemDescripti { if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!(problem.IsFp32() || problem.IsFp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!ctx.use_asm_kernels) diff --git a/src/solver/conv_wino_fury_RxS.cpp b/src/solver/conv_wino_fury_RxS.cpp index 89f870e35e..150deac983 100644 --- a/src/solver/conv_wino_fury_RxS.cpp +++ b/src/solver/conv_wino_fury_RxS.cpp @@ -171,6 +171,9 @@ bool ConvWinoFuryRxS::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(is2x3() && miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3{})) return false; diff --git a/src/solver/fft.cpp b/src/solver/fft.cpp index 9a3c7858cc..3ca98b4720 100644 --- a/src/solver/fft.cpp +++ b/src/solver/fft.cpp @@ -116,9 +116,7 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto is_fwd = problem.direction.IsForward(); decltype(auto) conv = problem.GetConv();