From dfee26c07ec3ce104c0315735b5c8da8ce0e695f Mon Sep 17 00:00:00 2001
From: "M. Amber Hassaan"
Date: Fri, 20 Oct 2023 16:44:00 +0000
Subject: [PATCH 1/7] Add a check for packed tensors for convolution solvers

---
 src/include/miopen/conv/problem_description.hpp | 11 +++++++++++
 src/include/miopen/fusion/utils.hpp | 4 ++++
 src/solver/conv_MP_bidirectional_winograd.cpp | 4 ++++
 src/solver/conv_asm_1x1u.cpp | 4 ++++
 src/solver/conv_asm_1x1u_bias_activ_fused.cpp | 4 ++++
 src/solver/conv_asm_1x1u_stride2.cpp | 4 ++++
 src/solver/conv_asm_3x3u.cpp | 4 ++++
 src/solver/conv_asm_5x10u2v2b1.cpp | 4 ++++
 src/solver/conv_asm_5x10u2v2f1.cpp | 4 ++++
 src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp | 5 +++++
 src/solver/conv_asm_dir_BwdWrW1x1.cpp | 4 ++++
 src/solver/conv_asm_dir_BwdWrW3x3.cpp | 4 ++++
 .../conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp | 5 +++++
 src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp | 4 ++++
 src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp | 4 ++++
 src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp | 4 ++++
 src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp | 5 +++++
 src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp | 4 ++++
 src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp | 4 ++++
 src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp | 4 ++++
 ...onv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp | 4 ++++
 .../conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp | 4 ++++
 src/solver/conv_bin_wino3x3U.cpp | 6 ++++++
 src/solver/conv_bin_winoRxS.cpp | 4 ++++
 src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp | 13 ++++++++-----
 src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp | 8 ++++++--
 src/solver/conv_direct_naive_conv_bwd.cpp | 9 ++++++---
 src/solver/conv_direct_naive_conv_fwd.cpp | 4 ++++
 ...conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp | 4 ++++
 ...conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp | 4 ++++
 ...conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp | 4 ++++
 .../conv_hip_implicit_gemm_bwd_data_xdlops.cpp | 8 +++++---
 src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp | 4 ++++
 .../conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp | 5 +++++
 src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp | 4 ++++
 .../conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp | 4 ++++
 src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp | 4 ++++
 src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp | 8 ++++++--
 .../conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp | 12 ++++++++----
 ...ip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp | 4 ++++
 .../conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp | 4 ++++
 src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp | 4 ++++
 .../conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp | 8 +++++---
 src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp | 12 ++++++++----
 .../conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp | 5 +++++
 ...ip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp | 10 +++++++---
 src/solver/conv_mlir_igemm_bwd.cpp | 4 ++++
 src/solver/conv_mlir_igemm_bwd_xdlops.cpp | 4 ++++
 src/solver/conv_mlir_igemm_fwd.cpp | 4 ++++
 src/solver/conv_mlir_igemm_fwd_xdlops.cpp | 4 ++++
 src/solver/conv_mlir_igemm_wrw.cpp | 4 ++++
 src/solver/conv_mlir_igemm_wrw_xdlops.cpp | 4 ++++
 src/solver/conv_multipass_wino3x3WrW.cpp | 4 ++++
 src/solver/conv_ocl_dir2D11x11.cpp | 4 ++++
 src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp | 4 ++++
 src/solver/conv_ocl_dir2D_bwdWrW_53.cpp | 4 ++++
 src/solver/conv_ocl_dir2Dfwd.cpp | 4 ++++
 src/solver/conv_ocl_dir2Dfwd1x1.cpp | 4 ++++
 src/solver/conv_ocl_dir2Dfwdgen.cpp | 4 ++++
 src/solver/conv_winoRxS.cpp | 4 ++++
 src/solver/conv_wino_fury_RxS.cpp | 4 ++++
 61 files changed, 281 insertions(+), 29 deletions(-)

diff --git
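Every solver touched by this first patch gains the same early-return guard near the top of its IsApplicable() method, so problems whose tensors are not contiguously packed are rejected before any kernel-specific checks run. A minimal, self-contained sketch of the pattern follows; StubProblem and IsApplicableSketch are stand-ins for illustration only, the real code uses ProblemDescription and each solver's IsApplicable override:

    // Sketch of the guard pattern added across the solvers in this patch.
    struct StubProblem
    {
        bool in_packed = true, wei_packed = true, out_packed = true;
        // Mirrors the new helper: true if any tensor is not contiguously packed.
        bool HasNonPackedTensors() const { return !(in_packed && wei_packed && out_packed); }
    };

    bool IsApplicableSketch(const StubProblem& problem)
    {
        if(problem.HasNonPackedTensors())
        {
            return false; // strided / non-contiguous tensors are rejected up front
        }
        // ... the solver's remaining applicability checks would follow here ...
        return true;
    }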
a/src/include/miopen/conv/problem_description.hpp b/src/include/miopen/conv/problem_description.hpp index d6b735291e..929ac49c36 100644 --- a/src/include/miopen/conv/problem_description.hpp +++ b/src/include/miopen/conv/problem_description.hpp @@ -367,6 +367,17 @@ struct ProblemDescription : ProblemDescriptionBase bool IsNCHWc_NCHWc() const; bool IsNCHWc_CHWNc() const; + bool HasNonPackedTensors() const + { + return !(in.IsPacked() && weights.IsPacked() && out.IsPacked()); + } + + bool HasMixedDataTypes() const + { + return !(GetInDataType() == GetWeightsDataType() && + GetWeightsDataType() == GetOutDataType()); + } + void HeuristicUpdateLayouts(); void BuildConfKey(std::string& conf_key) const; diff --git a/src/include/miopen/fusion/utils.hpp b/src/include/miopen/fusion/utils.hpp index d33d822429..720dcce5c1 100644 --- a/src/include/miopen/fusion/utils.hpp +++ b/src/include/miopen/fusion/utils.hpp @@ -84,6 +84,10 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes return false; if(!conv_problem.IsFp32()) return false; + if(conv_problem.HasNonPackedTensors()) + { + return false; + } if(!conv_problem.IsLayoutDefault()) return false; if(!conv_problem.direction.IsForward()) diff --git a/src/solver/conv_MP_bidirectional_winograd.cpp b/src/solver/conv_MP_bidirectional_winograd.cpp index a653157f58..d0a82ec805 100644 --- a/src/solver/conv_MP_bidirectional_winograd.cpp +++ b/src/solver/conv_MP_bidirectional_winograd.cpp @@ -323,6 +323,10 @@ bool ConvMPBidirectWinograd::IsA { // HIP backend required for sending ptr (buffer + offset) // ROCBLAS for GEMM step + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsLayoutDefault()) { diff --git a/src/solver/conv_asm_1x1u.cpp b/src/solver/conv_asm_1x1u.cpp index 0664c32eb8..466e3007cf 100644 --- a/src/solver/conv_asm_1x1u.cpp +++ b/src/solver/conv_asm_1x1u.cpp @@ -527,6 +527,10 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) diff --git a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp index c935a2aff6..9da0e8780c 100644 --- a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp +++ b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp @@ -243,6 +243,10 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context, const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); + if(conv_problem.HasNonPackedTensors()) + { + return false; + } if(conv_problem.GetPadH() != conv_problem.GetPadW()) return false; if(conv_problem.GetPadH() != 0) diff --git a/src/solver/conv_asm_1x1u_stride2.cpp b/src/solver/conv_asm_1x1u_stride2.cpp index b9925ee30c..604aac95fb 100644 --- a/src/solver/conv_asm_1x1u_stride2.cpp +++ b/src/solver/conv_asm_1x1u_stride2.cpp @@ -489,6 +489,10 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) diff --git a/src/solver/conv_asm_3x3u.cpp b/src/solver/conv_asm_3x3u.cpp index 18f07b9630..5c3d98d257 100644 --- 
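HasNonPackedTensors() builds on TensorDescriptor::IsPacked(), which holds when the strides leave no gaps between elements, i.e. the tensor occupies one dense block of memory. A rough, self-contained illustration of that condition for a row-major layout (not MIOpen's actual implementation) is:

    #include <cstddef>
    #include <vector>

    // True if 'strides' describe a dense row-major layout for 'lens':
    // the innermost stride is 1 and every other stride equals the product
    // of the lengths of the faster-varying dimensions (no padding anywhere).
    bool IsPackedRowMajor(const std::vector<std::size_t>& lens,
                          const std::vector<std::size_t>& strides)
    {
        std::size_t expected = 1;
        for(int i = static_cast<int>(lens.size()) - 1; i >= 0; --i)
        {
            if(strides[i] != expected)
                return false;
            expected *= lens[i];
        }
        return true;
    }

    // Example: NCHW lengths {1, 8, 16, 16} with strides {2048, 256, 16, 1} are
    // packed; the same lengths with a padded row pitch, strides {2560, 320, 20, 1},
    // are not, so HasNonPackedTensors() would make these solvers reject the problem.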
a/src/solver/conv_asm_3x3u.cpp +++ b/src/solver/conv_asm_3x3u.cpp @@ -176,6 +176,10 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) diff --git a/src/solver/conv_asm_5x10u2v2b1.cpp b/src/solver/conv_asm_5x10u2v2b1.cpp index 6da4863f6d..b271ffc755 100644 --- a/src/solver/conv_asm_5x10u2v2b1.cpp +++ b/src/solver/conv_asm_5x10u2v2b1.cpp @@ -45,6 +45,10 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) diff --git a/src/solver/conv_asm_5x10u2v2f1.cpp b/src/solver/conv_asm_5x10u2v2f1.cpp index 74301fe5fd..c1ba614f20 100644 --- a/src/solver/conv_asm_5x10u2v2f1.cpp +++ b/src/solver/conv_asm_5x10u2v2f1.cpp @@ -46,6 +46,10 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) diff --git a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp index 4426a3eeca..e28b00aca1 100644 --- a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp +++ b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp @@ -51,6 +51,11 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx if(!ctx.rmv.IsV2orV3()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_asm_dir_BwdWrW1x1.cpp b/src/solver/conv_asm_dir_BwdWrW1x1.cpp index 0abe71326f..8a7ec8cf28 100644 --- a/src/solver/conv_asm_dir_BwdWrW1x1.cpp +++ b/src/solver/conv_asm_dir_BwdWrW1x1.cpp @@ -479,6 +479,10 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) diff --git a/src/solver/conv_asm_dir_BwdWrW3x3.cpp b/src/solver/conv_asm_dir_BwdWrW3x3.cpp index ae58cfcd9b..80b7e8cec1 100644 --- a/src/solver/conv_asm_dir_BwdWrW3x3.cpp +++ b/src/solver/conv_asm_dir_BwdWrW3x3.cpp @@ -362,6 +362,10 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) diff --git a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp index 9e41d56c82..ddda5320f5 100644 --- a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp @@ -143,6 +143,11 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx if(!problem.direction.IsBackwardData()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(!problem.Is2d()) return false; diff --git 
a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp index ee6b16d38b..4eef538ddd 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp @@ -992,6 +992,10 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32() && !problem.IsFp16()) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp index 9cfdd8aeea..fbc858e405 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp @@ -938,6 +938,10 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp index 32b50167cf..ae8814004b 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp @@ -1517,6 +1517,10 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32() && !problem.IsFp16()) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp index b16258235e..bb70d40c20 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp @@ -560,6 +560,11 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(!problem.IsLayoutNCHWc()) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp index 4ab9ce1c37..76f17997df 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp @@ -876,6 +876,10 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index 8ac238395a..f0f5179186 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -866,6 +866,10 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp index 8e1450c7a3..04de997cb8 100644 --- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp @@ -291,6 +291,10 @@ bool 
ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32()) return false; diff --git a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp index a5d056178d..386ffd6e60 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp @@ -832,6 +832,10 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32() && !problem.IsFp16()) return false; diff --git a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp index fb5f0caf7c..1752c3c9e3 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp @@ -317,6 +317,10 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx if(!problem.IsFp32()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_bin_wino3x3U.cpp b/src/solver/conv_bin_wino3x3U.cpp index c8508cf372..b29bf1e915 100644 --- a/src/solver/conv_bin_wino3x3U.cpp +++ b/src/solver/conv_bin_wino3x3U.cpp @@ -64,6 +64,12 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx, // and able to correctly run with given parameters. const auto device_is_gfx8 = StartsWith(name, "gfx8"); const auto grid_workgroup_count_x = ctx.GetStream().GetMaxComputeUnits(); + + if(problem.HasNonPackedTensors()) + { + return false; + } + if(!problem.IsLayoutDefault()) { return false; diff --git a/src/solver/conv_bin_winoRxS.cpp b/src/solver/conv_bin_winoRxS.cpp index eb4d7386f1..2d73354b4e 100644 --- a/src/solver/conv_bin_winoRxS.cpp +++ b/src/solver/conv_bin_winoRxS.cpp @@ -223,6 +223,10 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, if(!(problem.IsFp32() || problem.IsFp16())) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted()) return false; if(miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_RXS{})) diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp index e2df6f8097..f60d8f168b 100644 --- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp +++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp @@ -411,15 +411,18 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx, return false; const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); - if(conv_problem.IsTensorsCasted()) - return false; if(conv_problem.GetConv().attribute.deterministic) return false; - if(conv_problem.GetInDataType() != conv_problem.GetWeightsDataType() || - conv_problem.GetInDataType() != conv_problem.GetOutDataType()) - return false; if(!conv_problem.Is2d()) return false; + if(conv_problem.HasNonPackedTensors()) + { + return false; + } + if(conv_problem.HasMixedDataTypes()) + return false; + if(conv_problem.IsTensorsCasted()) + return false; const std::string arch = ctx.GetStream().GetDeviceName(); if(arch != "gfx908" && arch != "gfx90a" && arch != "gfx940" && arch != "gfx941" && arch != "gfx942") diff --git a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp 
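In the fused CK solver above, the explicit pairwise comparison of the input, weights and output data types is replaced by the new HasMixedDataTypes() helper; the solver stays applicable only when all three types agree. An illustrative stand-in (hypothetical enum, for exposition only; MIOpen itself uses miopenDataType_t):

    // Hypothetical data-type enum used only to demonstrate the helper's behavior.
    enum class DataType { Half, Float, BFloat16 };

    inline bool HasMixedDataTypes(DataType in, DataType wei, DataType out)
    {
        return !(in == wei && wei == out); // true as soon as any type differs
    }

    // HasMixedDataTypes(DataType::Half, DataType::Half, DataType::Float) -> true  (rejected)
    // HasMixedDataTypes(DataType::Half, DataType::Half, DataType::Half)  -> false (check passes)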
b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp index 001f3a8cb7..726ae1160b 100644 --- a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp +++ b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp @@ -97,8 +97,6 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx, return false; if(!ck_utility::is_ck_supported_hardware(ctx.GetStream())) return false; - if(!problem.IsLayoutDefault()) - return false; if(!problem.direction.IsForward()) return false; if(!problem.Is2d()) @@ -106,6 +104,12 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx, if(!(problem.IsFp32() or problem.IsFp16())) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(!problem.IsLayoutDefault()) + return false; if(problem.IsTensorsCasted()) return false; if(problem.GetGroupCount() != 1) diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp index 1e8f006ef0..b433370961 100644 --- a/src/solver/conv_direct_naive_conv_bwd.cpp +++ b/src/solver/conv_direct_naive_conv_bwd.cpp @@ -44,15 +44,18 @@ bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx, if(!ConvDirectNaiveConvIsApplicableByKernelType(ctx, problem)) return false; + if(!problem.direction.IsBackwardData()) + return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsLayoutDefault() && !problem.IsLayoutNHWC()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16() || problem.IsFp8() || problem.IsBfp8())) return false; - - if(!problem.direction.IsBackwardData()) - return false; if(problem.IsTensorsCasted()) { auto test_cast = [&](const TensorDescriptor& desc) { diff --git a/src/solver/conv_direct_naive_conv_fwd.cpp b/src/solver/conv_direct_naive_conv_fwd.cpp index f1ed2f5b10..00e569114b 100644 --- a/src/solver/conv_direct_naive_conv_fwd.cpp +++ b/src/solver/conv_direct_naive_conv_fwd.cpp @@ -53,6 +53,10 @@ bool ConvDirectNaiveConvFwd::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsForward()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted()) { auto test_cast = [&](const TensorDescriptor& desc) { diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp index 0b04365fca..96c977c5fd 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp @@ -320,6 +320,10 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable( return false; if(!problem.Is3d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsLayoutNHWC()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index 7f012c4c10..0da32e4568 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -318,6 +318,10 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable( return false; if(!problem.Is3d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsLayoutNHWC()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index 
076bfaaa07..bcab0b289d 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -314,6 +314,10 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable( return false; if(!problem.Is3d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsLayoutNHWC()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index bf7440a1e6..5b5f802ec7 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -262,9 +262,11 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable( return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasNonPackedTensors()) + { + return false; + } + if(problem.HasMixedDataTypes()) return false; if(problem.IsTensorsCasted()) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp index b2b591b859..121c05a8cc 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp @@ -638,6 +638,10 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ExecutionContext& ctx, return false; if(!ctx.use_hip_kernels) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsLayoutDefault()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp index f657fa74fe..79456d4be2 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp @@ -784,6 +784,11 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp index e60d6c76a3..a64dcf18c5 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp @@ -754,6 +754,10 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx, if(!problem.IsFp32()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp index 3b3dc8b4d3..fade77c696 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp @@ -838,6 +838,10 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp index 39e8c71c16..b21886d916 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp +++ 
b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp @@ -57,6 +57,10 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!problem.IsFp32() && !problem.IsFp16() && !problem.IsBfp16()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp index 9cbe662180..d819b7d20d 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp @@ -583,8 +583,6 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(!ctx.use_hip_kernels) return false; - if(!problem.IsLayoutDefault()) - return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; if(!problem.direction.IsForward()) @@ -595,6 +593,12 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(problem.GetGroupCount() != 1) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(!problem.IsLayoutDefault()) + return false; if(!IsIndexRangeLargeEnough(problem)) return false; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp index 9c09efe397..558f4b33f1 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp @@ -987,16 +987,20 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext& if(!IsXdlopsSupport(ctx)) return false; - if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) + if(!problem.direction.IsForward()) return false; - if(problem.IsTensorsCasted()) + if(!problem.Is2d()) + return false; + if(problem.HasNonPackedTensors()) + { return false; + } - if(!problem.direction.IsForward()) + if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(!problem.Is2d()) + if(problem.IsTensorsCasted()) return false; if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp index d25ca1b68b..d46c4c1375 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp @@ -1065,6 +1065,10 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp index 4915c48e2e..b55e813f6d 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp @@ -1020,6 +1020,10 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ExecutionContext& if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index 95752d919d..4941229914 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ 
b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -271,6 +271,10 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!IsXdlopsSupport(ctx)) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp index 8135b570ea..8b65c16167 100644 --- a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp @@ -288,13 +288,15 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable( #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL if(miopen::IsDisabled(MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS{})) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted()) return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(!problem.direction.IsForward()) return false; diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp index 637486ef50..18148411e0 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp @@ -586,17 +586,21 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ExecutionContext& ctx, return false; if(!ctx.use_hip_kernels) return false; - if(!problem.IsLayoutDefault()) - return false; - if(!IsComposableKernelSupportedHardware(ctx)) - return false; if(!problem.direction.IsBackwardWrW()) return false; if(!problem.Is2d() && !problem.Is3d()) return false; if(!problem.IsFp32()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(!problem.IsLayoutDefault()) + return false; + if(!IsComposableKernelSupportedHardware(ctx)) + return false; if(problem.IsTensorsCasted()) return false; if(problem.GetGroupCount() != 1) diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp index 5a42ba3255..f2380f469f 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp @@ -1057,6 +1057,11 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ExecutionContext& ctx, if(!IsXdlopsSupport(ctx)) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp index abd178dcca..6d5da3657e 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp @@ -1123,14 +1123,18 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( if(!ctx.use_hip_kernels) return false; - if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) - return false; - if(!problem.direction.IsBackwardWrW()) return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } + + if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) + return false; if(problem.IsTensorsCasted()) return 
false; diff --git a/src/solver/conv_mlir_igemm_bwd.cpp b/src/solver/conv_mlir_igemm_bwd.cpp index 783c68350c..4594661811 100644 --- a/src/solver/conv_mlir_igemm_bwd.cpp +++ b/src/solver/conv_mlir_igemm_bwd.cpp @@ -47,6 +47,10 @@ bool ConvMlirIgemmBwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardData()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp index 41062cc32c..aa9392c4c6 100644 --- a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp @@ -50,6 +50,10 @@ bool ConvMlirIgemmBwdXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardData()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_mlir_igemm_fwd.cpp b/src/solver/conv_mlir_igemm_fwd.cpp index 2cc196ae10..ddf5908b70 100644 --- a/src/solver/conv_mlir_igemm_fwd.cpp +++ b/src/solver/conv_mlir_igemm_fwd.cpp @@ -167,6 +167,10 @@ bool ConvMlirIgemmFwd::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsForward()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp index c761abc137..b98b4c9f38 100644 --- a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp @@ -64,6 +64,10 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsForward()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_wrw.cpp b/src/solver/conv_mlir_igemm_wrw.cpp index cb9f6ae7b2..aebf8d7665 100644 --- a/src/solver/conv_mlir_igemm_wrw.cpp +++ b/src/solver/conv_mlir_igemm_wrw.cpp @@ -50,6 +50,10 @@ bool ConvMlirIgemmWrW::IsApplicable(const ExecutionContext& ctx, return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; // Note: ConvMlirIgemmWrW can run on a machine with xdlops support, however, it is diff --git a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp index fe11c828c8..40b16835ab 100644 --- a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp @@ -51,6 +51,10 @@ bool ConvMlirIgemmWrWXdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_multipass_wino3x3WrW.cpp b/src/solver/conv_multipass_wino3x3WrW.cpp index 233489c4fc..5829afdc80 100644 --- 
a/src/solver/conv_multipass_wino3x3WrW.cpp +++ b/src/solver/conv_multipass_wino3x3WrW.cpp @@ -436,6 +436,10 @@ bool ConvWinograd3x3MultipassWrW return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; diff --git a/src/solver/conv_ocl_dir2D11x11.cpp b/src/solver/conv_ocl_dir2D11x11.cpp index b76621a591..33fe4f8767 100644 --- a/src/solver/conv_ocl_dir2D11x11.cpp +++ b/src/solver/conv_ocl_dir2D11x11.cpp @@ -47,6 +47,10 @@ bool ConvOclDirectFwd11x11::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp index 4e0cda8629..360a9730f2 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp @@ -58,6 +58,10 @@ bool ConvOclBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.direction.IsBackwardWrW()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp index 4f00c8f55b..3dddf7136f 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp @@ -51,6 +51,10 @@ bool ConvOclBwdWrW53::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) diff --git a/src/solver/conv_ocl_dir2Dfwd.cpp b/src/solver/conv_ocl_dir2Dfwd.cpp index c7bd8c00df..86c0e4e0ba 100644 --- a/src/solver/conv_ocl_dir2Dfwd.cpp +++ b/src/solver/conv_ocl_dir2Dfwd.cpp @@ -48,6 +48,10 @@ bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) diff --git a/src/solver/conv_ocl_dir2Dfwd1x1.cpp b/src/solver/conv_ocl_dir2Dfwd1x1.cpp index b21effc0b3..fad70bd428 100644 --- a/src/solver/conv_ocl_dir2Dfwd1x1.cpp +++ b/src/solver/conv_ocl_dir2Dfwd1x1.cpp @@ -57,6 +57,10 @@ bool ConvOclDirectFwd1x1::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) diff --git a/src/solver/conv_ocl_dir2Dfwdgen.cpp b/src/solver/conv_ocl_dir2Dfwdgen.cpp index f35e57b71c..cb296c4548 100644 --- a/src/solver/conv_ocl_dir2Dfwdgen.cpp +++ b/src/solver/conv_ocl_dir2Dfwdgen.cpp @@ -45,6 +45,10 @@ bool ConvOclDirectFwdGen::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.Is2d()) return false; + 
if(problem.HasNonPackedTensors()) + { + return false; + } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp index d9cbeb713f..b8c07d8273 100644 --- a/src/solver/conv_winoRxS.cpp +++ b/src/solver/conv_winoRxS.cpp @@ -622,6 +622,10 @@ static bool IsApplicableBase(const ExecutionContext& ctx, const ProblemDescripti { if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(!(problem.IsFp32() || problem.IsFp16())) return false; diff --git a/src/solver/conv_wino_fury_RxS.cpp b/src/solver/conv_wino_fury_RxS.cpp index 89f870e35e..85a0c32fa1 100644 --- a/src/solver/conv_wino_fury_RxS.cpp +++ b/src/solver/conv_wino_fury_RxS.cpp @@ -171,6 +171,10 @@ bool ConvWinoFuryRxS::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + { + return false; + } if(is2x3() && miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3{})) return false; From c3166b9d2323aca875e93f54a5421268503d3b87 Mon Sep 17 00:00:00 2001 From: "M. Amber Hassaan" Date: Mon, 23 Oct 2023 15:58:43 +0000 Subject: [PATCH 2/7] remove check from naive and ck solvers that support strides --- src/solver/conv_direct_naive_conv_bwd.cpp | 4 ---- src/solver/conv_direct_naive_conv_fwd.cpp | 4 ---- .../conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp | 8 +------- .../conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp | 8 +------- .../conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp | 8 +------- src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp | 6 ++---- 6 files changed, 5 insertions(+), 33 deletions(-) diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp index b433370961..77406744b7 100644 --- a/src/solver/conv_direct_naive_conv_bwd.cpp +++ b/src/solver/conv_direct_naive_conv_bwd.cpp @@ -46,10 +46,6 @@ bool ConvDirectNaiveConvBwd::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardData()) return false; - if(problem.HasNonPackedTensors()) - { - return false; - } if(!problem.IsLayoutDefault() && !problem.IsLayoutNHWC()) return false; diff --git a/src/solver/conv_direct_naive_conv_fwd.cpp b/src/solver/conv_direct_naive_conv_fwd.cpp index 00e569114b..f1ed2f5b10 100644 --- a/src/solver/conv_direct_naive_conv_fwd.cpp +++ b/src/solver/conv_direct_naive_conv_fwd.cpp @@ -53,10 +53,6 @@ bool ConvDirectNaiveConvFwd::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsForward()) return false; - if(problem.HasNonPackedTensors()) - { - return false; - } if(problem.IsTensorsCasted()) { auto test_cast = [&](const TensorDescriptor& desc) { diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp index 96c977c5fd..1021930e88 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_bwd_xdlops.cpp @@ -310,9 +310,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable( return false; if(miopen::IsEnabled(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC{})) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(problem.IsTensorsCasted()) return false; @@ -320,10 +318,6 @@ bool 
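The naive reference kernels and the CK-backed solvers changed in this second commit address tensor elements through the descriptor strides rather than assuming a dense layout, which is why the packed-tensor guard from the first commit is dropped from them again. A schematic of stride-based addressing for an NCHW tensor (illustrative only, not the actual kernel code):

    #include <cstddef>

    // Offset of element (n, c, h, w) when each dimension carries its own stride;
    // gaps between rows, planes or images are handled transparently, so the
    // underlying buffer does not have to be packed.
    inline std::size_t Offset(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                              const std::size_t strides[4]) // {n, c, h, w} strides
    {
        return n * strides[0] + c * strides[1] + h * strides[2] + w * strides[3];
    }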
ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable( return false; if(!problem.Is3d()) return false; - if(problem.HasNonPackedTensors()) - { - return false; - } if(!problem.IsLayoutNHWC()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index 0da32e4568..008d87cf6e 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -310,18 +310,12 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable( return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(!problem.direction.IsForward()) return false; if(!problem.Is3d()) return false; - if(problem.HasNonPackedTensors()) - { - return false; - } if(!problem.IsLayoutNHWC()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) diff --git a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index bcab0b289d..0823002fc9 100644 --- a/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -306,18 +306,12 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable( return false; if(miopen::IsEnabled(MIOPEN_DEBUG_CONVOLUTION_DETERMINISTIC{})) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) + if(problem.HasMixedDataTypes()) return false; if(!problem.direction.IsBackwardWrW()) return false; if(!problem.Is3d()) return false; - if(problem.HasNonPackedTensors()) - { - return false; - } if(!problem.IsLayoutNHWC()) return false; if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName())) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index 4941229914..02368c45b7 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -263,10 +263,6 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( return false; if(problem.GetConv().attribute.deterministic) return false; - if(problem.GetInDataType() != problem.GetWeightsDataType() || - problem.GetWeightsDataType() != problem.GetOutDataType() || - problem.GetInDataType() != problem.GetOutDataType()) - return false; if(!problem.direction.IsForward()) return false; if(!problem.Is2d()) @@ -275,6 +271,8 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( { return false; } + if(problem.HasMixedDataTypes()) + return false; if(!IsXdlopsSupport(ctx)) return false; if(!IsComposableKernelSupportedHardware(ctx)) From 331b19d6c6194ee87c0cb1bb0e4980b3615c3587 Mon Sep 17 00:00:00 2001 From: "M. 
Amber Hassaan" Date: Mon, 23 Oct 2023 16:20:50 +0000 Subject: [PATCH 3/7] remove check from wingrad solvers --- src/solver/conv_winoRxS.cpp | 5 ----- src/solver/conv_wino_fury_RxS.cpp | 4 ---- 2 files changed, 9 deletions(-) diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp index b8c07d8273..011ae0ec1e 100644 --- a/src/solver/conv_winoRxS.cpp +++ b/src/solver/conv_winoRxS.cpp @@ -622,13 +622,8 @@ static bool IsApplicableBase(const ExecutionContext& ctx, const ProblemDescripti { if(!problem.Is2d()) return false; - if(problem.HasNonPackedTensors()) - { - return false; - } if(!(problem.IsFp32() || problem.IsFp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!ctx.use_asm_kernels) diff --git a/src/solver/conv_wino_fury_RxS.cpp b/src/solver/conv_wino_fury_RxS.cpp index 85a0c32fa1..89f870e35e 100644 --- a/src/solver/conv_wino_fury_RxS.cpp +++ b/src/solver/conv_wino_fury_RxS.cpp @@ -171,10 +171,6 @@ bool ConvWinoFuryRxS::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; - if(problem.HasNonPackedTensors()) - { - return false; - } if(is2x3() && miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3{})) return false; From a0dcdecdfef47c81e0ad26417fbcfcfb520de2d5 Mon Sep 17 00:00:00 2001 From: "M. Amber Hassaan" Date: Tue, 24 Oct 2023 15:11:34 +0000 Subject: [PATCH 4/7] remove redundant/repeated checks --- src/solver/conv_asm_1x1u_bias_activ_fused.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp index 9da0e8780c..16063a04b5 100644 --- a/src/solver/conv_asm_1x1u_bias_activ_fused.cpp +++ b/src/solver/conv_asm_1x1u_bias_activ_fused.cpp @@ -243,10 +243,6 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context, const auto conv_problem = problem.GetConvProblem(0, conv::Direction::Forward); const auto conv_ctx = context.GetConvContext(conv_problem); - if(conv_problem.HasNonPackedTensors()) - { - return false; - } if(conv_problem.GetPadH() != conv_problem.GetPadW()) return false; if(conv_problem.GetPadH() != 0) @@ -260,9 +256,6 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context, if(conv_problem.GetDilationH() != 1) return false; - if(conv_problem.IsTensorsCasted()) - return false; - // Check if the conovlution part is applicable return sol.IsApplicable(conv_ctx, conv_problem); } From b9734ca2d40dbdc7de9a199883ee741db7d81bf3 Mon Sep 17 00:00:00 2001 From: "M. 
Amber Hassaan" Date: Tue, 24 Oct 2023 15:11:56 +0000 Subject: [PATCH 5/7] remove braces around if statements --- src/include/miopen/fusion/utils.hpp | 2 -- src/solver/conv_MP_bidirectional_winograd.cpp | 6 ------ src/solver/conv_asm_1x1u.cpp | 6 ------ src/solver/conv_asm_1x1u_stride2.cpp | 6 ------ src/solver/conv_asm_3x3u.cpp | 4 ---- src/solver/conv_asm_5x10u2v2b1.cpp | 8 -------- src/solver/conv_asm_5x10u2v2f1.cpp | 9 --------- src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp | 8 -------- src/solver/conv_asm_dir_BwdWrW1x1.cpp | 7 ------- src/solver/conv_asm_dir_BwdWrW3x3.cpp | 5 ----- src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp | 4 ---- src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp | 5 +---- src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp | 3 +-- src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp | 5 +---- src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp | 2 -- src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp | 3 +-- src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp | 3 +-- src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp | 5 +---- .../conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp | 5 +---- src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp | 5 +---- src/solver/conv_bin_wino3x3U.cpp | 4 ---- src/solver/conv_bin_winoRxS.cpp | 5 ----- src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp | 2 -- src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp | 3 --- src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp | 3 --- src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp | 2 -- src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp | 5 +---- src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp | 5 +---- src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp | 3 --- src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp | 2 -- src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp | 2 -- src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp | 6 ++---- ...onv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp | 5 +---- src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp | 6 ++---- src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp | 2 -- src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp | 2 -- src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp | 3 --- src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp | 4 ---- ...onv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp | 4 ---- src/solver/conv_mlir_igemm_bwd.cpp | 2 -- src/solver/conv_mlir_igemm_bwd_xdlops.cpp | 2 -- src/solver/conv_mlir_igemm_fwd.cpp | 2 -- src/solver/conv_mlir_igemm_fwd_xdlops.cpp | 2 -- src/solver/conv_mlir_igemm_wrw.cpp | 2 -- src/solver/conv_mlir_igemm_wrw_xdlops.cpp | 2 -- src/solver/conv_multipass_wino3x3WrW.cpp | 7 ------- src/solver/conv_ocl_dir2D11x11.cpp | 5 ----- src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp | 5 ----- src/solver/conv_ocl_dir2D_bwdWrW_2.cpp | 3 --- src/solver/conv_ocl_dir2D_bwdWrW_53.cpp | 4 ---- src/solver/conv_ocl_dir2Dfwd.cpp | 5 ----- src/solver/conv_ocl_dir2Dfwd1x1.cpp | 5 ----- src/solver/conv_ocl_dir2Dfwdgen.cpp | 6 ------ src/solver/conv_winoRxS.cpp | 2 -- src/solver/fft.cpp | 2 -- 55 files changed, 15 insertions(+), 210 deletions(-) diff --git a/src/include/miopen/fusion/utils.hpp b/src/include/miopen/fusion/utils.hpp index 720dcce5c1..5669de990f 100644 --- a/src/include/miopen/fusion/utils.hpp +++ b/src/include/miopen/fusion/utils.hpp @@ -85,9 +85,7 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes if(!conv_problem.IsFp32()) return false; if(conv_problem.HasNonPackedTensors()) - { return false; - } if(!conv_problem.IsLayoutDefault()) return false; 
if(!conv_problem.direction.IsForward()) diff --git a/src/solver/conv_MP_bidirectional_winograd.cpp b/src/solver/conv_MP_bidirectional_winograd.cpp index 7b7286517e..1ecc210144 100644 --- a/src/solver/conv_MP_bidirectional_winograd.cpp +++ b/src/solver/conv_MP_bidirectional_winograd.cpp @@ -229,9 +229,7 @@ static bool IsApplicableTransform(const ExecutionContext& ctx, const ProblemDesc } if(!problem.IsLayoutDefault()) - { return false; - } { unsigned int const waves_in_group = 512 / wave_size; @@ -324,14 +322,10 @@ bool ConvMPBidirectWinograd::IsA // HIP backend required for sending ptr (buffer + offset) // ROCBLAS for GEMM step if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_asm_1x1u.cpp b/src/solver/conv_asm_1x1u.cpp index 466e3007cf..be72cf1fe3 100644 --- a/src/solver/conv_asm_1x1u.cpp +++ b/src/solver/conv_asm_1x1u.cpp @@ -528,9 +528,7 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) @@ -549,13 +547,9 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip const std::string name = ctx.GetStream().GetDeviceName(); if(name.find("gfx9") == std::string::npos) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_1x1u_stride2.cpp b/src/solver/conv_asm_1x1u_stride2.cpp index 604aac95fb..3a102d4701 100644 --- a/src/solver/conv_asm_1x1u_stride2.cpp +++ b/src/solver/conv_asm_1x1u_stride2.cpp @@ -490,9 +490,7 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx, if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -509,13 +507,9 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx, const std::string name = ctx.GetStream().GetDeviceName(); if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_3x3u.cpp b/src/solver/conv_asm_3x3u.cpp index 5c3d98d257..c486b4d867 100644 --- a/src/solver/conv_asm_3x3u.cpp +++ b/src/solver/conv_asm_3x3u.cpp @@ -177,9 +177,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) @@ -198,9 +196,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx90"))) return false; if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_5x10u2v2b1.cpp b/src/solver/conv_asm_5x10u2v2b1.cpp index b271ffc755..98b604ef3e 100644 --- 
a/src/solver/conv_asm_5x10u2v2b1.cpp +++ b/src/solver/conv_asm_5x10u2v2b1.cpp @@ -46,9 +46,7 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -67,17 +65,11 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx, return false; #endif if(!device_is_gfx8_9_no_xnack) - { return false; - } if(!problem.direction.IsBackwardData()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; diff --git a/src/solver/conv_asm_5x10u2v2f1.cpp b/src/solver/conv_asm_5x10u2v2f1.cpp index c1ba614f20..8d3c1a1716 100644 --- a/src/solver/conv_asm_5x10u2v2f1.cpp +++ b/src/solver/conv_asm_5x10u2v2f1.cpp @@ -47,9 +47,7 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -68,18 +66,11 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx, return false; #endif if(!device_is_gfx8_9_no_xnack) - { return false; - } if(!problem.direction.IsForward()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp index e28b00aca1..309b178dfb 100644 --- a/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp +++ b/src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp @@ -52,9 +52,7 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted()) return false; @@ -70,17 +68,11 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx #endif if(!(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" || name == "gfx900" || name == "gfx904" || name == "gfx906" || name == "gfx908")) - { return false; - } if(!problem.direction.IsForward()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off return problem.GetPadW() == 3 // -q diff --git a/src/solver/conv_asm_dir_BwdWrW1x1.cpp b/src/solver/conv_asm_dir_BwdWrW1x1.cpp index 8a7ec8cf28..04e0f89b4e 100644 --- a/src/solver/conv_asm_dir_BwdWrW1x1.cpp +++ b/src/solver/conv_asm_dir_BwdWrW1x1.cpp @@ -480,14 +480,11 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardWrW()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) return false; - if(problem.IsTensorsCasted()) return false; @@ -497,13 +494,9 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, const std::string name = ctx.GetStream().GetDeviceName(); if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(name == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; diff --git a/src/solver/conv_asm_dir_BwdWrW3x3.cpp b/src/solver/conv_asm_dir_BwdWrW3x3.cpp index 80b7e8cec1..6ae7330c63 100644 --- 
a/src/solver/conv_asm_dir_BwdWrW3x3.cpp +++ b/src/solver/conv_asm_dir_BwdWrW3x3.cpp @@ -363,9 +363,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardWrW()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!ctx.rmv.IsV2orV3()) @@ -379,10 +377,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx, if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx9"))) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; #if WORKAROUND_ISSUE_532 diff --git a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp index ddda5320f5..6ff60242dc 100644 --- a/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp @@ -144,9 +144,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.Is2d()) return false; @@ -164,9 +162,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp index 4eef538ddd..d0551138de 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp @@ -993,9 +993,8 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext return false; if(problem.HasNonPackedTensors()) - { return false; - } + if(!problem.IsFp32() && !problem.IsFp16()) return false; @@ -1010,9 +1009,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp index fbc858e405..42f1e9f03e 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp @@ -939,9 +939,8 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable( return false; if(problem.HasNonPackedTensors()) - { return false; - } + if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp index ae8814004b..5bc9da9b81 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp @@ -1517,10 +1517,9 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsFp32() && !problem.IsFp16()) return false; @@ -1535,9 +1534,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext return false; if(!problem.IsLayoutDefault()) - { return false; - } #if WORKAROUND_SWDEV_306318 if((problem.GetWeightsHeight_() == 1) && (problem.GetWeightsWidth_() == 1) && diff --git 
a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp index bb70d40c20..d7b692cfb3 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp @@ -561,9 +561,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable( return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsLayoutNCHWc()) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp index 76f17997df..2e38366b74 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp @@ -877,9 +877,8 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable( return false; if(problem.HasNonPackedTensors()) - { return false; - } + if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index f0f5179186..3db8021b26 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -867,9 +867,8 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable( return false; if(problem.HasNonPackedTensors()) - { return false; - } + if(!problem.IsFp32() && !problem.IsFp16() && !(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94")))) return false; diff --git a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp index 04de997cb8..baceb8089f 100644 --- a/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp @@ -291,10 +291,9 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsFp32()) return false; @@ -309,9 +308,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp index 386ffd6e60..c1e7d4454d 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp @@ -832,10 +832,9 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsFp32() && !problem.IsFp16()) return false; @@ -850,9 +849,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp index 1752c3c9e3..7adb28fdae 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_v4r1_dynamic.cpp @@ -317,10 +317,9 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const 
ExecutionContext& ctx if(!problem.IsFp32()) return false; + if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted()) return false; @@ -332,9 +331,7 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) diff --git a/src/solver/conv_bin_wino3x3U.cpp b/src/solver/conv_bin_wino3x3U.cpp index b29bf1e915..0fc9263413 100644 --- a/src/solver/conv_bin_wino3x3U.cpp +++ b/src/solver/conv_bin_wino3x3U.cpp @@ -66,14 +66,10 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx, const auto grid_workgroup_count_x = ctx.GetStream().GetMaxComputeUnits(); if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsLayoutDefault()) - { return false; - } if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_bin_winoRxS.cpp b/src/solver/conv_bin_winoRxS.cpp index 2d73354b4e..f1fd0868c3 100644 --- a/src/solver/conv_bin_winoRxS.cpp +++ b/src/solver/conv_bin_winoRxS.cpp @@ -186,9 +186,7 @@ static inline bool IsShaderContraintsMet(const miopen::ExecutionContext& ctx, } const auto grid_workgroup_count_x = ctx.GetStream().GetMaxComputeUnits(); if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off // Check implementation limits. @@ -222,11 +220,8 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.IsFp32() || problem.IsFp16())) return false; - if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted()) return false; if(miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_RXS{})) diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp index 864a4549a5..12682b2efd 100644 --- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp +++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp @@ -414,9 +414,7 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx, if(!conv_problem.Is2d()) return false; if(conv_problem.HasNonPackedTensors()) - { return false; - } if(conv_problem.HasMixedDataTypes()) return false; if(conv_problem.IsTensorsCasted()) diff --git a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp index 726ae1160b..b1ee39bf7f 100644 --- a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp +++ b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp @@ -103,11 +103,8 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx, return false; if(!(problem.IsFp32() or problem.IsFp16())) return false; - if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsLayoutDefault()) return false; if(problem.IsTensorsCasted()) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index 10a5081dc2..85bec04104 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -261,12 +261,9 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable( if(problem.GetConv().attribute.deterministic) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.HasMixedDataTypes()) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.direction.IsBackwardData()) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp index 
121c05a8cc..5b3cb4933f 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1.cpp @@ -639,9 +639,7 @@ bool ConvHipImplicitGemmBwdDataV1R1::IsApplicable(const ExecutionContext& ctx, if(!ctx.use_hip_kernels) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsLayoutDefault()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp index 79456d4be2..6db26950d5 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v1r1_xdlops.cpp @@ -785,9 +785,7 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext& return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted()) return false; @@ -799,9 +797,8 @@ bool ConvHipImplicitGemmBwdDataV1R1Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.IsLayoutDefault()) - { return false; - } + // gemm size int gemm_g = -1; int gemm_m = -1; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp index a64dcf18c5..a0cca73af9 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1.cpp @@ -755,9 +755,8 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx, return false; if(problem.HasNonPackedTensors()) - { return false; - } + if(problem.IsTensorsCasted()) return false; @@ -765,9 +764,7 @@ bool ConvHipImplicitGemmBwdDataV4R1::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.IsLayoutDefault()) - { return false; - } if(!IsIndexRangeLargeEnough(problem)) return false; diff --git a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp index fade77c696..2c4f36e820 100644 --- a/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_bwd_v4r1_xdlops.cpp @@ -839,12 +839,9 @@ bool ConvHipImplicitGemmBwdDataV4R1Xdlops::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!IsApplicableXdlops(ctx, problem)) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp index b21886d916..88d0f7b314 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r1.cpp @@ -58,9 +58,7 @@ bool ConvHipImplicitGemmV4R1Fwd::IsApplicable(const ExecutionContext& ctx, if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsFp32() && !problem.IsFp16() && !problem.IsBfp16()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp index d819b7d20d..f6ff58a7dc 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp @@ -594,9 +594,7 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx, if(problem.GetGroupCount() != 1) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!problem.IsLayoutDefault()) return false; if(!IsIndexRangeLargeEnough(problem)) diff --git 
a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp index 558f4b33f1..4edd2a7328 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp @@ -992,10 +992,9 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) - { return false; - } if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; @@ -1010,9 +1009,8 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.IsLayoutDefault()) - { return false; - } + // gemm size { int gemm_g = -1; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp index d46c4c1375..ed868b3d04 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops_padded_gemm.cpp @@ -1066,9 +1066,8 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( return false; if(problem.HasNonPackedTensors()) - { return false; - } + if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) return false; @@ -1076,9 +1075,7 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops_Padded_Gemm::IsApplicable( return false; if(!problem.IsLayoutDefault()) - { return false; - } // gemm size { int gemm_g = -1; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp index b55e813f6d..6006cb0caa 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r5_xdlops.cpp @@ -1021,9 +1021,8 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ExecutionContext& return false; if(problem.HasNonPackedTensors()) - { return false; - } + if(problem.IsTensorsCasted()) return false; @@ -1047,9 +1046,8 @@ bool ConvHipImplicitGemmForwardV4R5Xdlops::IsApplicable(const ExecutionContext& return false; if(!problem.IsLayoutDefault()) - { return false; - } + // gemm size { int gemm_g = -1; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index 98939d684c..5290754060 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -266,9 +266,7 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.HasMixedDataTypes()) return false; if(!IsXdlopsSupport(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp index 7e2823d0f3..8e7898ea1a 100644 --- a/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_grouped_fwd_xdlops.cpp @@ -287,9 +287,7 @@ bool ConvHipImplicitGemmGroupFwdXdlops::IsApplicable( if(miopen::IsDisabled(MIOPEN_DEBUG_GROUP_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS{})) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted()) return false; if(problem.GetConv().attribute.deterministic) diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp index 18148411e0..8c75a623e3 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp +++ 
b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp @@ -593,10 +593,7 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ExecutionContext& ctx, if(!problem.IsFp32()) return false; if(problem.HasNonPackedTensors()) - { return false; - } - if(!problem.IsLayoutDefault()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp index f2380f469f..0f38df9e6c 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops.cpp @@ -1058,9 +1058,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ExecutionContext& ctx, return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; @@ -1081,9 +1079,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops::IsApplicable(const ExecutionContext& ctx, return false; if(!problem.IsLayoutDefault()) - { return false; - } // this particular HeuristicInit is so comprehensive, that if it cannot predict a valid // performance config, the problem is probably not applicable diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp index 6d5da3657e..581559e655 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp @@ -1129,9 +1129,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; @@ -1146,9 +1144,7 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( return false; if(!problem.IsLayoutDefault()) - { return false; - } // this particular HeuristicInit is so comprehensive, that if it cannot predict a valid #if WORKAROUND_MI100_BF16_FATAL_COMPILER_ERRORS diff --git a/src/solver/conv_mlir_igemm_bwd.cpp b/src/solver/conv_mlir_igemm_bwd.cpp index 4594661811..6804cccf42 100644 --- a/src/solver/conv_mlir_igemm_bwd.cpp +++ b/src/solver/conv_mlir_igemm_bwd.cpp @@ -48,9 +48,7 @@ bool ConvMlirIgemmBwd::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardData()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp index aa9392c4c6..439b1ab394 100644 --- a/src/solver/conv_mlir_igemm_bwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_bwd_xdlops.cpp @@ -51,9 +51,7 @@ bool ConvMlirIgemmBwdXdlops::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardData()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_mlir_igemm_fwd.cpp b/src/solver/conv_mlir_igemm_fwd.cpp index ddf5908b70..034313e1d0 100644 --- a/src/solver/conv_mlir_igemm_fwd.cpp +++ b/src/solver/conv_mlir_igemm_fwd.cpp @@ -168,9 +168,7 @@ bool ConvMlirIgemmFwd::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsForward()) return false; if(problem.HasNonPackedTensors()) - { return false; - } 
if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp index b98b4c9f38..fa17520f07 100644 --- a/src/solver/conv_mlir_igemm_fwd_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_fwd_xdlops.cpp @@ -65,9 +65,7 @@ bool ConvMlirIgemmFwdXdlops::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsForward()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) diff --git a/src/solver/conv_mlir_igemm_wrw.cpp b/src/solver/conv_mlir_igemm_wrw.cpp index aebf8d7665..7b8286d896 100644 --- a/src/solver/conv_mlir_igemm_wrw.cpp +++ b/src/solver/conv_mlir_igemm_wrw.cpp @@ -51,9 +51,7 @@ bool ConvMlirIgemmWrW::IsApplicable(const ExecutionContext& ctx, if(!IsComposableKernelSupportedHardware(ctx)) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; // Note: ConvMlirIgemmWrW can run on a machine with xdlops support, however, it is diff --git a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp index 40b16835ab..5dd480da63 100644 --- a/src/solver/conv_mlir_igemm_wrw_xdlops.cpp +++ b/src/solver/conv_mlir_igemm_wrw_xdlops.cpp @@ -52,9 +52,7 @@ bool ConvMlirIgemmWrWXdlops::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardWrW()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8()) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_multipass_wino3x3WrW.cpp b/src/solver/conv_multipass_wino3x3WrW.cpp index 5829afdc80..4f7a1626cc 100644 --- a/src/solver/conv_multipass_wino3x3WrW.cpp +++ b/src/solver/conv_multipass_wino3x3WrW.cpp @@ -437,18 +437,13 @@ bool ConvWinograd3x3MultipassWrW if(!problem.direction.IsBackwardWrW()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) @@ -497,9 +492,7 @@ bool ConvWinograd3x3MultipassWrW return false; } if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off { diff --git a/src/solver/conv_ocl_dir2D11x11.cpp b/src/solver/conv_ocl_dir2D11x11.cpp index 33fe4f8767..482892946e 100644 --- a/src/solver/conv_ocl_dir2D11x11.cpp +++ b/src/solver/conv_ocl_dir2D11x11.cpp @@ -48,20 +48,15 @@ bool ConvOclDirectFwd11x11::IsApplicable(const ExecutionContext& ctx, if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } return problem.direction.IsForward() && problem.GetGroupCount() == 1 && problem.GetDilationH() == 1 && problem.GetDilationW() == 1 && diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp index 360a9730f2..2c13d29f11 100644 --- 
a/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_1x1.cpp @@ -59,18 +59,13 @@ bool ConvOclBwdWrW1x1::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardWrW()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp index 2b400909f8..c59449eab8 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_2.cpp @@ -464,10 +464,7 @@ bool ConvOclBwdWrW2::IsApplicableBase(const ExecutionContext& ctx if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.IsTensorsCasted()) return false; diff --git a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp index 3dddf7136f..7f50e7f3ee 100644 --- a/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp +++ b/src/solver/conv_ocl_dir2D_bwdWrW_53.cpp @@ -52,9 +52,7 @@ bool ConvOclBwdWrW53::IsApplicable(const ExecutionContext& ctx, if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) @@ -65,9 +63,7 @@ bool ConvOclBwdWrW53::IsApplicable(const ExecutionContext& ctx, if(!problem.direction.IsBackwardWrW()) return false; if(!problem.IsLayoutDefault()) - { return false; - } bool workaround = false; diff --git a/src/solver/conv_ocl_dir2Dfwd.cpp b/src/solver/conv_ocl_dir2Dfwd.cpp index 86c0e4e0ba..6b7d2f1f8e 100644 --- a/src/solver/conv_ocl_dir2Dfwd.cpp +++ b/src/solver/conv_ocl_dir2Dfwd.cpp @@ -49,20 +49,15 @@ bool ConvOclDirectFwd::IsApplicable(const ExecutionContext& ctx, if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } // clang-format off // Cases when dy has negative padding are not supported (issue 918) diff --git a/src/solver/conv_ocl_dir2Dfwd1x1.cpp b/src/solver/conv_ocl_dir2Dfwd1x1.cpp index fad70bd428..b3e0bd0a2b 100644 --- a/src/solver/conv_ocl_dir2Dfwd1x1.cpp +++ b/src/solver/conv_ocl_dir2Dfwd1x1.cpp @@ -58,20 +58,15 @@ bool ConvOclDirectFwd1x1::IsApplicable(const ExecutionContext& ctx, if(!(problem.direction.IsForward() || problem.direction.IsBackwardData())) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } return problem.GetDilationW() == 1 && problem.GetDilationH() == 1 && problem.GetWeightsWidth_() == 1 && problem.GetWeightsHeight_() == 1 && diff --git a/src/solver/conv_ocl_dir2Dfwdgen.cpp b/src/solver/conv_ocl_dir2Dfwdgen.cpp index cb296c4548..659f0ddf73 100644 --- a/src/solver/conv_ocl_dir2Dfwdgen.cpp +++ 
b/src/solver/conv_ocl_dir2Dfwdgen.cpp @@ -46,21 +46,15 @@ bool ConvOclDirectFwdGen::IsApplicable(const ExecutionContext& ctx, if(!problem.Is2d()) return false; if(problem.HasNonPackedTensors()) - { return false; - } if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW()) return false; if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(problem.IsTensorsCasted()) return false; if(!problem.IsLayoutDefault()) - { return false; - } - if(problem.GetGroupCount() > 1) return false; diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp index 011ae0ec1e..ef4b60a18b 100644 --- a/src/solver/conv_winoRxS.cpp +++ b/src/solver/conv_winoRxS.cpp @@ -285,9 +285,7 @@ inline bool IsShaderConstraintsMet(const ProblemDescription& problem, } if(!problem.IsLayoutDefault()) - { return false; - } return IsWinogradV21Preferred(asic, problem) ? IsShaderConstraintsMetV21(problem, R, S, C, K, H, W, OH, OW, N) diff --git a/src/solver/fft.cpp b/src/solver/fft.cpp index 9a3c7858cc..3ca98b4720 100644 --- a/src/solver/fft.cpp +++ b/src/solver/fft.cpp @@ -116,9 +116,7 @@ bool fft::IsApplicable(const ExecutionContext& ctx, const ProblemDescription& pr return false; if(!problem.IsLayoutDefault()) - { return false; - } const auto is_fwd = problem.direction.IsForward(); decltype(auto) conv = problem.GetConv(); From 82cf6d78cfbba831291d1b345f74d6e83fc0b699 Mon Sep 17 00:00:00 2001 From: "M. Amber Hassaan" Date: Tue, 24 Oct 2023 15:12:47 +0000 Subject: [PATCH 6/7] fix formatting --- src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp index c1e7d4454d..aa84c8c76e 100644 --- a/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp +++ b/src/solver/conv_asm_implicit_gemm_wrw_gtc_dynamic_xdlops.cpp @@ -832,7 +832,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext if(!problem.Is2d()) return false; - + if(problem.HasNonPackedTensors()) return false; From d702070b94bdddba7eb2e5074ed7eadb02dd8724 Mon Sep 17 00:00:00 2001 From: "M. Amber Hassaan" Date: Wed, 25 Oct 2023 02:44:47 +0000 Subject: [PATCH 7/7] address comments. Revert reordered checks. 
Re-add check to WinoGrad solvers --- src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp | 6 +++--- src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp | 4 ++-- src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp | 8 ++++---- .../conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp | 10 +++++----- src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp | 8 ++++---- src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp | 12 ++++++------ ...hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp | 7 ++++--- src/solver/conv_winoRxS.cpp | 2 ++ src/solver/conv_wino_fury_RxS.cpp | 3 +++ 9 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp index 12682b2efd..75f5cd1ac3 100644 --- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp +++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp @@ -409,15 +409,15 @@ bool ConvCKIgemmFwdBiasActivFused::IsApplicable(const FusionContext& ctx, return false; const auto conv_problem = fdesc_problem.GetConvProblem(0, conv::Direction::Forward); - if(conv_problem.GetConv().attribute.deterministic) + if(conv_problem.IsTensorsCasted()) return false; - if(!conv_problem.Is2d()) + if(conv_problem.GetConv().attribute.deterministic) return false; if(conv_problem.HasNonPackedTensors()) return false; if(conv_problem.HasMixedDataTypes()) return false; - if(conv_problem.IsTensorsCasted()) + if(!conv_problem.Is2d()) return false; const std::string arch = ctx.GetStream().GetDeviceName(); if(arch != "gfx908" && arch != "gfx90a" && arch != "gfx940" && arch != "gfx941" && diff --git a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp index b1ee39bf7f..9c42d8b8db 100644 --- a/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp +++ b/src/solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp @@ -97,6 +97,8 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx, return false; if(!ck_utility::is_ck_supported_hardware(ctx.GetStream())) return false; + if(!problem.IsLayoutDefault()) + return false; if(!problem.direction.IsForward()) return false; if(!problem.Is2d()) @@ -105,8 +107,6 @@ bool ConvCkIgemmFwdV6r1DlopsNchw::IsApplicable(const ExecutionContext& ctx, return false; if(problem.HasNonPackedTensors()) return false; - if(!problem.IsLayoutDefault()) - return false; if(problem.IsTensorsCasted()) return false; if(problem.GetGroupCount() != 1) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp index f6ff58a7dc..9babd8fff8 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_v4r4.cpp @@ -583,6 +583,10 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(!ctx.use_hip_kernels) return false; + if(!problem.IsLayoutDefault()) + return false; + if(problem.HasNonPackedTensors()) + return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; if(!problem.direction.IsForward()) @@ -593,10 +597,6 @@ bool ConvHipImplicitGemmV4R4Fwd::IsApplicable(const ExecutionContext& ctx, return false; if(problem.GetGroupCount() != 1) return false; - if(problem.HasNonPackedTensors()) - return false; - if(!problem.IsLayoutDefault()) - return false; if(!IsIndexRangeLargeEnough(problem)) return false; diff --git a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp index 4edd2a7328..8a7b4b150b 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp +++ 
b/src/solver/conv_hip_implicit_gemm_fwd_v4r4_xdlops.cpp @@ -987,19 +987,19 @@ bool ConvHipImplicitGemmForwardV4R4Xdlops::IsApplicable(const ExecutionContext& if(!IsXdlopsSupport(ctx)) return false; - if(!problem.direction.IsForward()) + if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) return false; - if(!problem.Is2d()) + if(problem.HasNonPackedTensors()) return false; - if(problem.HasNonPackedTensors()) + if(problem.IsTensorsCasted()) return false; - if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) + if(!problem.direction.IsForward()) return false; - if(problem.IsTensorsCasted()) + if(!problem.Is2d()) return false; if(ctx.GetStream().GetDeviceName() == "gfx90a" && problem.IsGfx90aFp16altRequired()) diff --git a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp index 5290754060..1b1127d6e1 100644 --- a/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -261,14 +261,14 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( return false; if(problem.GetConv().attribute.deterministic) return false; - if(!problem.direction.IsForward()) - return false; - if(!problem.Is2d()) - return false; if(problem.HasNonPackedTensors()) return false; if(problem.HasMixedDataTypes()) return false; + if(!problem.direction.IsForward()) + return false; + if(!problem.Is2d()) + return false; if(!IsXdlopsSupport(ctx)) return false; if(!IsComposableKernelSupportedHardware(ctx)) diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp index 8c75a623e3..c9b0d94241 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4.cpp @@ -586,18 +586,18 @@ bool ConvHipImplicitGemmV4R4WrW::IsApplicable(const ExecutionContext& ctx, return false; if(!ctx.use_hip_kernels) return false; - if(!problem.direction.IsBackwardWrW()) - return false; - if(!problem.Is2d() && !problem.Is3d()) - return false; - if(!problem.IsFp32()) - return false; if(problem.HasNonPackedTensors()) return false; if(!problem.IsLayoutDefault()) return false; if(!IsComposableKernelSupportedHardware(ctx)) return false; + if(!problem.direction.IsBackwardWrW()) + return false; + if(!problem.Is2d() && !problem.Is3d()) + return false; + if(!problem.IsFp32()) + return false; if(problem.IsTensorsCasted()) return false; if(problem.GetGroupCount() != 1) diff --git a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp index 581559e655..b910d9658a 100644 --- a/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp +++ b/src/solver/conv_hip_implicit_gemm_wrw_v4r4_xdlops_padded_gemm.cpp @@ -1123,15 +1123,16 @@ bool ConvHipImplicitGemmWrwV4R4Xdlops_Padded_Gemm::IsApplicable( if(!ctx.use_hip_kernels) return false; + if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) + return false; + if(!problem.direction.IsBackwardWrW()) return false; if(!problem.Is2d()) return false; - if(problem.HasNonPackedTensors()) - return false; - if(!(problem.IsFp32() || problem.IsFp16() || problem.IsBfp16())) + if(problem.HasNonPackedTensors()) return false; if(problem.IsTensorsCasted()) diff --git a/src/solver/conv_winoRxS.cpp b/src/solver/conv_winoRxS.cpp index ef4b60a18b..6c7d94632f 100644 --- a/src/solver/conv_winoRxS.cpp +++ b/src/solver/conv_winoRxS.cpp @@ -620,6 +620,8 @@ static bool IsApplicableBase(const ExecutionContext& ctx, const ProblemDescripti { 
if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; if(!(problem.IsFp32() || problem.IsFp16())) return false; if(problem.IsTensorsCasted()) diff --git a/src/solver/conv_wino_fury_RxS.cpp b/src/solver/conv_wino_fury_RxS.cpp index 89f870e35e..150deac983 100644 --- a/src/solver/conv_wino_fury_RxS.cpp +++ b/src/solver/conv_wino_fury_RxS.cpp @@ -171,6 +171,9 @@ bool ConvWinoFuryRxS::IsApplicable(const ExecutionContext& if(!problem.Is2d()) return false; + if(problem.HasNonPackedTensors()) + return false; + if(is2x3() && miopen::IsDisabled(MIOPEN_DEBUG_AMD_WINOGRAD_FURY_RXS_F2X3{})) return false;
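
Note for readers unfamiliar with the packed-tensor restriction these hunks rely on: a tensor is "packed" when its strides describe one dense, contiguous block of elements, and a solver whose kernels compute offsets as if the operands were contiguous must bail out of IsApplicable() when any operand is not. The sketch below illustrates that idea only; TensorView, ProblemSketch, the stride-based IsPacked(), and IsApplicableSketch() are simplified stand-ins invented for illustration, not MIOpen's real TensorDescriptor/ProblemDescription, which may handle additional cases (vectorized NCHWc layouts, implicit strides) differently.

#include <cstddef>
#include <iostream>
#include <vector>

struct TensorView
{
    std::vector<std::size_t> lens;    // e.g. {N, C, H, W}
    std::vector<std::size_t> strides; // element strides, same order as lens

    // Packed == the strides are exactly those of a contiguous row-major
    // allocation of `lens`, i.e. the data forms one dense block.
    bool IsPacked() const
    {
        std::size_t expected = 1;
        for(std::size_t i = lens.size(); i-- > 0;)
        {
            if(strides[i] != expected)
                return false;
            expected *= lens[i];
        }
        return true;
    }
};

struct ProblemSketch
{
    TensorView in, weights, out;

    // Returns true if any operand cannot be treated as one contiguous buffer.
    bool HasNonPackedTensors() const
    {
        return !in.IsPacked() || !weights.IsPacked() || !out.IsPacked();
    }
};

// The early-return guard pattern used throughout the IsApplicable() functions.
bool IsApplicableSketch(const ProblemSketch& problem)
{
    if(problem.HasNonPackedTensors())
        return false; // kernels assume densely packed operands
    // ... the solver's remaining applicability checks would follow here ...
    return true;
}

int main()
{
    TensorView packed{{2, 3, 4, 5}, {60, 20, 5, 1}}; // contiguous NCHW
    TensorView padded{{2, 3, 4, 5}, {64, 20, 5, 1}}; // padded batch stride

    ProblemSketch ok{packed, packed, packed};
    ProblemSketch bad{padded, packed, packed};

    std::cout << IsApplicableSketch(ok) << ' '    // 1: all operands packed
              << IsApplicableSketch(bad) << '\n'; // 0: rejected
}

For example, an NCHW tensor of lengths {2, 3, 4, 5} with strides {60, 20, 5, 1} is packed, while the same lengths with a padded batch stride of 64 describe a strided view; a kernel that indexed it as a dense buffer would touch the wrong elements, which is presumably why these solvers reject such problems up front rather than attempting to honor arbitrary strides.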